查看原文
其他

“物以类聚”、“近朱者赤”——机器学习初探之KNN

爬虫俱乐部 Stata and Python数据分析 2022-03-15

本文作者:王   歌

文字编辑:孙晓玲

技术总编:张   邯

导读

前面的推文向大家介绍了Python中两个重要的库——NumpyPandas的一些基本用法(点此复习:NumPy数组基本介绍Pandas基本数据类型介绍),这些库都是我们利用Python进行数据分析的好帮手。在此基础上,我们将结合另一重要的库——sklearn对机器学习的一些常用算法进行介绍,希望能通过易懂的语言和简单的小例子使读者对机器学习的算法有一个初步的了解,不再产生艰深畏难的情绪。在正式介绍今天的算法之前,作为这一系列的开始,小编首先和大家一起对知识结构进行简单的梳理,并了解一些基本概念,做到纲举目张、心中有数。



1.机器学习引入


机器学习,顾名思义就是让机器从已有的数据中进行学习从而更准确地预测。从数据中学得模型的过程称为“学习”或“训练”,用来训练的数据称为“训练集”,学得的模型有时也称为“学习器”,训练集中每个训练样本的结果信息称为“标签”或“标记”,利用模型进行预测的过程称为“测试”,被预测的样本为“测试样本”,学得的模型对新样本的预测能力称为“泛化能力”,一个模型泛化能力越强,在新样本上的适用性越好。机器学习的过程通常可以分为以下几个步骤:(1)收集数据;(2)处理数据,包括数据的清洗、调整数据的格式等;(3)训练算法;(4)测试算法;(5)使用算法。

由于机器学习的任务是对测试样本给出类似于训练集标签的结果信息,因此根据训练集是否有标签,将有标签的称为有监督学习,没有标签的称为无监督学习,介于这两者之间的是半监督学习,它是指不依赖外界交互,自动利用未标记样本来提升学习能力,在这一学习过程中同时使用了未标记数据和标记数据。而有监督学习中又可根据标签值是否连续分为回归和分类两类问题。我们今天要学习的KNN算法就是属于有监督学习的算法,它既可以用于分类,也可以用于回归。



2.KNN算法的基本理论


KNN(k-Nearest Neighbor)也叫k-近邻算法,其分类的基本思想是给定一个训练数据集,对于新的输入数据,在训练数据集中找到与它最邻近的K个样本,如果某个类别在这K个样本中出现次数最多,就把它分到这个类中。这就类似于我们所说的“物以类聚”、“近朱者赤”的道理。一般情况下我们只选择前K个最邻近的样本,K通常是不大于20的整数。根据算法的基本思想,我们可以将伪代码总结如下:

(1)对于训练集中的所有样本,计算每个样本与当前输入数据的距离;

(2)将计算出的距离按照升序排列,并取出前K个距离最小的样本;

(3)统计这K个样本的标签值,并找出出现频率最高的标签;

(4)输入数据的标签值即为该频率最高的标签值。

在计算距离时,通常使用欧氏距离,当然也可以使用闵氏距离等进行计算。而对于K值的选取,如果选择较小,就相当于用一个较小范围的训练样本进行预测,只有与输入实例较近的训练实例才会对预测起作用,容易发生过拟合。如果K值选择较大,就相当于用较大的范围中的训练样本进行预测,容易发生欠拟合。

从上面的流程中我们可以看到,KNN算法没有进行数据的训练,直接使用未知的数据与已知的数据进行比较得到结果,不具有显式的学习过程,属于一种懒惰学习的模式,因此算法的训练复杂度为0。但由于要将训练数据存储下来以便进行计算比较,所以KNN算法会有计算复杂度高、空间复杂度高、存储空间需求大的缺点。不过由于准确性高,对异常值和噪声有较高的容忍度,并且对输入的数据没有假定,因此这一算法也有很好的应用效果。



3.算法实例


我们可以通过Python程序构造距离函数来实现这一算法,大家有兴趣可以尝试一下。这里我们主要用Python中的sklearn库来实现。sklearn(scikit-learn)库是Python中用来实现数据分析、机器学习的有效工具。下面我们就来看一看如何在KNN算法中应用。这里使用的数据来自sklearn官方文档提供的鸢尾花(Iris)数据。该数据共有150个样本,分为3类——iris-setosa(山鸢尾)、iris-versicolour(杂色鸢尾)和iris-virginica(维吉尼亚鸢尾),其对应的标签值分别为0、1、2,每类各50个样本,每个样本包含4个属性:Sepal Length(花萼长度)、Sepal Width(花萼宽度)、Petal Length(花瓣长度)、Petal Width(花瓣宽度),属性值都为正浮点数。

首先我们导入数据集并将前10个样本的数据输出,程序如下:

import numpy as npfrom sklearn import datasets #导入sklearn的内置数据集iris_sample = datasets.load_iris()print(iris_sample.data[:10, :], iris_sample.target[:10]) 

输出的结果如下:

数据集中的data数组是存储四个属性的二维数组,target是对应的分类标签,由于这个数据是整齐排列的,因此前50个均为山鸢尾,标签值均为0,51-100标签为1,101-150标签为2。然后我们将这个数据集随机分出75%的数据作为训练数据,剩余25%作为测试数据,程序如下:

from sklearn.model_selection import train_test_splitfrom sklearn.neighbors import KNeighborsClassifierx_train, x_test, y_train, y_test = train_test_split(iris_sample.data, iris_sample.target, test_size=0.25, random_state=123)knclf = KNeighborsClassifier(n_neighbors=5)knclf.fit(x_train, y_train) #拟合学习器y_test_pre = knclf.predict(x_test)score = knclf.score(x_test, y_test)print('测试集预测结果为:', y_test_pre)print('测试集正确结果为:', y_test)print('测试集准确度为:', score)

这里我们首先使用sklearn的train_test_split ()将数据集进行划分,该函数第一个参数是被划分的属性集,第二个参数为被划分的标签数据集,参数test_size 确定了测试集所占比例,如给此参数传入整数则为样本数量,random_state设置了随机数的种子。KNeighborsClassifier()定义一个分类器对象knclf,其基本的语法结构如下:

KNeighborsClassifier(n_neighbors=5,weights=’uniform’,algorithm=’auto’,leaf_size=30,p=2,metric=’minkowski’,metric_params=None,n_jobs=None,**kwargs=object)

其中n_neighbors即为我们所说的k值,默认为5,参数p和metric说明我们使用的是欧式距离,metric_params是一些特殊的metric选项需要的参数,weights是在进行分类时给最近邻加权,默认的uniform是等权加权,还有distance选项是按照距离的倒数进行加权,也可以自己设置其他加权方法,algorithm是分类时采取的算法,默认的auto选项会在学习时自动选择最合适的算法,n_jobs是并行计算的线程数量。然后使用fit()方法传入训练集拟合模型,并使用predict()方法来预测测试集的标签。最后利用score()方法得到预测的准确率。程序运行结果如下:

可以看到除了第一个样本预测错误以外,其余样本均预测正确,正确率为97.37%。

以上就是我们对于KNN算法基本知识的介绍,这里我们只做了关于使用KNN进行分类的算法演示,类似的也可以用它进行回归。大家也可以使用生活中的真实数据来操作一下,有任何问题和心得都欢迎联系小编,小编愿意和您共同进步!





对我们的推文累计打赏超过1000元,我们即可给您开具发票,发票类别为“咨询费”。用心做事,不负您的支持!
往期推文推荐
SFI:Stata与Python的数据交互手册(二)

从流调数据中寻找感染真相

熟悉又陌生的reshape

NBA球员薪资分析——基于随机森林算法(二)

NBA球员薪资分析——基于随机森林算法(一)

高亮输出之唐诗作者

湖北省各市疫情数据爬取

古代诗人总去的这些地方你一定要知道!

DataFrame数组常用方法(二)

ftools命令——畅游大数据时代的加速器

卫健委的“糊涂账”

Pandas中数据的排序与切片

DataFrame数组常用方法

巧用局部宏扩展函数dir

过了14天潜伏期真的没事了?

Pandas基本数据类型介绍

NumPy数组基本介绍

“个性化”sortobs命令,教你实现排序自由

关于我们



微信公众号“Stata and Python数据分析”分享实用的stata、python等软件的数据处理知识,欢迎转载、打赏。我们是由李春涛教授领导下的研究生及本科生组成的大数据处理和分析团队。

此外,欢迎大家踊跃投稿,介绍一些关于stata和python的数据处理和分析技巧。
投稿邮箱:statatraining@163.com
投稿要求:
1)必须原创,禁止抄袭;
2)必须准确,详细,有例子,有截图;
注意事项:
1)所有投稿都会经过本公众号运营团队成员的审核,审核通过才可录用,一经录用,会在该推文里为作者署名,并有赏金分成。
2)邮件请注明投稿,邮件名称为“投稿+推文名称”。
3)应广大读者要求,现开通有偿问答服务,如果大家遇到有关数据处理、分析等问题,可以在公众号中提出,只需支付少量赏金,我们会在后期的推文里给予解答。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存