查看原文
其他

Sklearn入门之线性判别分析

爬虫俱乐部 Stata and Python数据分析 2022-03-15

本文作者:杨长青

本文编辑:胡   婧

技术总编:张学人

爬虫俱乐部郑重推出中国高校上市公司高管排行榜,受到社会各界的广泛关注。我们所采用的数据全部来源于百度词条、新浪财经、国泰安等,客观公正!如有需要,可联系爬虫俱乐部!

爬虫俱乐部将于2019年1月19日至25日在武汉举行两期Stata编程技术定制培训,此次采取初级班和高级班分批次培训模式,采用理论与案例相结合的方式,旨在帮助大家熟悉Stata核心的爬虫技术,以及Stata与其他软件交互的高端技术。目前还有少量名额大家抓紧时间报名啦!详细培训大纲及报名方式,请见往期推文《2019寒假Stata编程技术定制培训班》。报名表下载请点击文末阅读原文呦~

上一篇推文《Sklearn入门之多元线性回归》介绍了最基础的回归模型,本文将主要介绍最经典的判别方法——线性判别分析(LDA)以及如何运用python中的Sklearn库来实现线性判别分析

LDA算法的思想很简单,我们以二维情况为例,如下图所示,LDA实质就是讲将训练集样本投影到一条直线上,使同类之间的投影点尽可能接近,异类之间的投影点尽可能的远离。对新样本分类预测时,将其投影到同一条直线上,再根据投影点的位置,确定新样本的类别。LDA可以推广到多分类问题上,即可以将样本投影到N-1维的平面上(N样本类别)。具体的算法过程及推导这里就不详细阐述,如有需要可以参阅周志华著的机器学习

LDA也可以从贝叶斯决策理论的角度来阐述,当每一类数据的观测值满足同先验,满足高斯分布且协方差相等的假设时,可以将测试集类别频数看作先验概率,拟合每一类数据高斯分布参数,最终得到后验概率,将样本分到后仰概率大的一类,可以证明在满足假设条件时,LDA算法能达到最优。从贝叶斯的角度,这里就没有二分类和多分类的差别。具体理论推导可以参考王星等译的统计学习导论——基于R应用

下面,我们通过Sklearn自带示例的数据集Iris来实现线性判别分析。该数据集包含150个数据,分为3类(分别是Setosa,Versicolour,Virginica三种鸢尾花),每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于三个种类中的哪一类。

为了方便可视化,让大家直视LDA划分的效果,我们只选取花瓣长度和花瓣宽度这两个变特征对鸢尾花进行分类。

首先,我们画出一张x轴为花瓣长度,y轴为花瓣宽度,不同鸢尾花种类对应不同颜色的散点图。程序如下所示:

import matplotlib.pyplot as plt
from sklearn import datasets #载入示例数据集
iris = datasets.load_iris() #载入示例数据集
X = iris.data[:, :2] #选前两个特征
y = iris.target#花的类别
target_names = iris.target_names#获得标签名
colors = ['navy', 'turquoise', 'darkorange'] #设置颜色
for color, i, target_name in zip(colors, [0, 1, 2], target_names):     plt.scatter(X[y == i, 0], X[y == i, 1], color=color,label=target_name) plt.legend(loc='best', shadow=False, scatterpoints=1) #设置图例
plt.title("花萼长度和宽度对花蕊类别的影响") #设置标题
plt.xlabel("length") #设置x轴备注
plt.ylabel("width") #设置y轴备注
plt.show()

散点图如下所示:

可以直观的看到setosa的花萼长度较长,但花萼较宽。而virginica长度比较长。但是由于选取特征较少,有一部分还是分隔不是很明显。下面我们用Sklearn库中的LinearDiscriminantAnalysis 模块通过这两个特征对花卉进行LDA分类。

爬虫俱乐部是您身边的科研助手,能够为您在数据处理实证研究中提供帮助。承蒙30000+粉丝的支持与厚爱,我们在腾讯课堂推出了网络视频课程,专注于数据整理、网络爬虫、循环命令编制和结果输出…李老师及团队精彩地讲解,深入浅出,注重案例与实战,让您更加快速高效地掌握Stata技巧及数据处理的精髓,而且可以无限次重复观看,在原有课程基础上已上传了三节全新的内容!百分百好评,简单易学,一个月让您从入门到精通。绝对物超所值!观看学习网址:

https://ke.qq.com/course/286526?tuin=1b60b462

敬请关注!

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap#色彩设置
from sklearn import datasets #载入示例数据集
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis iris = datasets.load_iris()  #载入示例数据集
X = iris.data[:, :2]  #选前两个特征
y = iris.target #花的类别
h=0.02 #绘制范围图的步数
#定义LDA模型并训练
clf = LinearDiscriminantAnalysis() clf.fit(X, y) #LDA训练
# 绘制颜色地图
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) #三种分类颜色的clour code
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) #三种点的颜色
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 #绘制决策边界。 为此,我们将为每个分配颜色
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),                     np.arange(y_min, y_max, h)) #将两个一维数组变成,间隔0.02的矩阵
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #将xx,yy变成一维数组
Z = Z.reshape(xx.shape)#_是按行连接两个矩阵,就是把两矩阵左右相加,要求行数相等
plt.figure() plt.pcolormesh(xx, yy, Z, cmap=cmap_light) #plt.pcolormesh()会根据y_predict的结果自动在cmap里选择颜色
# 绘制样本散点图
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,            edgecolor='k', s=20) plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.title("LDA分类效果图") plt.xlabel("length") #设置x轴备注
plt.ylabel("width") #设置y轴备注
plt.show()

最终分类效果图如下:

可以比对原始散点图,分类效果还是比较好的。下面我们利用鸢尾花卉的全部特征,划分训练集和测试集,利用训练集数据的四个变量进行模型训练,使用测试集来验证模型的准确度。程序如下:

from sklearn import datasets
from sklearn import model_selection
from sklearn.discriminant_analysis  import LinearDiscriminantAnalysis iris = datasets.load_iris() #载入示例数据集
X = iris.data #选取全部特征
y = iris.target#花的类别
x_train,x_test,y_train,y_test=model_selection.train_test_split(X, y,test_size=0.25,random_state=10010) #切分0.25的测试集,随机数种子为10010
clf = LinearDiscriminantAnalysis() clf.fit(x_train, y_train) #LDA训练
y_predict=clf.predict(x_test) clf.score(x_train,y_train) #训练集准确度
clf.score(x_test,y_test) #测试集准确度

最终输出结果如下:

对于训练集,分类准确度达到了97.32%,测试集准确度达到了100%。效果比两个特征时好了很多,但是达到100%很大一部分是由于数据集太小,一共只有150条观测,切分出来测试集仅仅只有38条。

以上便是运用Sklearn进行线性判别分析的全部内容,如果各位读者在日常使用中遇到此类问题,都可通过留言或者发邮件与我们联系,我们会竭诚为您解答。

有问题,不要怕!访问 

http://www.wuhanstring.com/uploads/5_aboutus/爬虫俱乐部-用户问题登记表.docx (复制到浏览器中)下载爬虫俱乐部用户问题登记表并按要求填写后发送至邮箱statatraining@163.com,我们会及时为您解答哟~

爬虫俱乐部的github主站正式上线了!我们的网站地址是:https://stata-club.github.io,粉丝们可以通过该网站访问过去的推文哟~

爬虫俱乐部隆重推出数据定制及处理业务,您有任何网页数据获取及处理方面的难题,请发邮件至我们邮箱statatraining@163.com,届时会有俱乐部高级会员为您排忧解难!

对爬虫俱乐部的推文累计打赏超过1000元我们即可给您开具发票,发票类别为“咨询费”。用心做事,只为做您更贴心的小爬虫!

往期推文推荐

关于我们

微信公众号“爬虫俱乐部”分享实用的stata命令,欢迎转载、打赏。爬虫俱乐部是由李春涛教授领导下的研究生及本科生组成的大数据分析和数据挖掘团队。


此外,欢迎大家踊跃投稿,介绍一些关于stata的数据处理和分析技巧。

投稿邮箱:statatraining@163.com

投稿要求:
1)必须原创,禁止抄袭;
2)必须准确,详细,有例子,有截图;
注意事项:
1)所有投稿都会经过本公众号运营团队成员的审核,审核通过才可录用,一经录用,会在该推文里为作者署名,并有赏金分成。
2)邮件请注明投稿,邮件名称为“投稿+推文名称”。
3)应广大读者要求,现开通有偿问答服务,如果大家遇到关于stata分析数据的问题,可以在公众号中提出,只需支付少量赏金,我们会在后期的推文里给予解答。


您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存