查看原文
其他

线性回归的正则化 ——岭回归与LASSO回归

爬虫俱乐部 Stata and Python数据分析 2022-03-15

本文作者:王   歌

文字编辑:孙晓玲

技术总编:张   邯

导读

在《基于广义线性模型的机器学习算法——线性回归》中我们介绍了如何使用线性回归算法来拟合学习器,但有时使用线性回归可能会产生过拟合的现象,此时我们通常有两种途径解决:一是对特征进行选择,减少特征数量,二是使用正则化的方法,这样可以保留所有的特征,而在正则化时我们通常会采用岭回归或LASSO回归,今天我们就来介绍一下这两种正则化方法。

 1 算法原理
1.1正则化

在介绍岭回归与LASSO回归之前,首先我们考虑这样一个问题。在线性回归中我们求解使得损失函数即均方误差最小的w和b作为回归中参数的估计值,求解的过程在于对损失函数中的w求偏导,此时我们会得到:

由此可以得到w的值并且是唯一解。但当X的列数多于行数、不满秩时,此时可以解出多个w,它们都能使损失函数最小,这是我们该如何选择最优解呢?

在机器学习中通常会选择正则化的方法来解决这一问题,即在损失函数中加入正则项(惩罚约束)来防止模型过拟合。其防止过拟合的原理,从本质上来说属于数学上的特征缩减(shrinkage),由于过拟合过分追求“小偏差”使模型过于复杂,导致拟合的数据分布与真实分布偏差很小但方差很大,此时拟合出的曲线斜率的绝对值大,也就是函数的偏导数大,因而要避免偏导数过大就要减小参数,即通过设置惩罚算子α,使得影响较小的特征的系数衰减到0,只保留重要特征从而减少模型复杂度进而达到避免过拟合的目的。这里的模型复杂度并非是幂次上的数值大,而是模型空间中可选模型的数量多。所以上面问题中加入αI,此时满秩可求逆。

正则化的优化目标函数一般为:

其中α≥0,对于线性回归模型,当时的回归称为LASSO回归,称为L1正则化(或L1范数);当时的回归称为岭回归,称为L2正则化(或L2范数)。

1.2岭回归

岭回归的损失函数完整表达为:

岭回归的求解与一般的线性回归求解方法是类似的。加入了惩罚项后,在求解参数时仍然要求偏导,此时若采用最小二乘法,则

而若采用梯度下降法求解,则

其中,该式为w的一个迭代更新公式,β为步长(学习率)。岭回归保留了所有的变量,缩小了参数,所以也就没有信息的损失,但相应的模型的变量过多解释性会降低。

1.3 LASSO回归

LASSO回归的损失函数完整表达为

其中n为样本量。由于L1范数为绝对值的形式,导致LASSO的损失函数不是连续可导的,所以最小二乘法,梯度下降法,牛顿法,拟牛顿法都不能用。通常我们可以使用坐标下降法、最小角回归法或近端梯度下降法等方法,这几种方法的具体推导过程就不再介绍了,大家感兴趣可以进一步了解。当模型中变量较多需要压缩时可以选择LASSO回归的方法。

 2 类参数介绍

在sklearn中进行岭回归和LASSO回归的类分别是Ridge()和Lasso(),首先我们介绍一下这两个类中用到的参数。

2.1 Ridge()

·alpha:即上面讲到的正则化惩罚算子α,接受浮点型,默认为1.0。取值越大对共线性的鲁棒性越强;

·copy_X:默认为True,表示是否复制X数组,为True时将复制X数组,否则将覆盖原数组X;

·fit_intercept:默认为True,表示是否计算此模型的截距;

·max_iter:最大迭代次数,接受整型,默认为None;

·normalize:默认为False,表示是否先进行归一化,当fit_intercept设置为False时,将忽略此参数;

·solver:选择计算时的求解方法,默认为‘auto’,主要有几种选择:(1)auto根据数据类型自动选择求解器;(2)svd使用X的奇异值分解来计算Ridge系数;(3)cholesky使用标准的scipy.linalg.solve函数来获得闭合形式的解;(4)sparse_cg使用在scipy.sparse.linalg.cg中找到的共轭梯度求解器;(5)lsqr使用正则化最小二乘常数scipy.sparse.linalg.lsqr;(6)sag使用随机平均梯度下降;

·tol:接受浮点型,用于选择解的精度;

·random_state:默认为None,用于设置种子,接受整型。

2.2 LASSO()

这个类中有部分参数用法是与Ridge()类相同的,即alpha、fit_intercept 、normalize、copy_X、max_iter、random_state。除此之外,还有以下参数:

·precompute:默认=False ,表示是否使用预计算的Gram矩阵(特征间未减去均值的协方差阵)来加速计算;

·warm_start : 选择为 True 时,重复使用上一次学习作为初始化,否则直接清除上次方案;

·positive : 选择为 True 时,强制使系数为正;

·tol:优化容忍度,接受浮点型;

·selection :默认为'cyclic',若设置为'random',表示每次循环随机更新参数,按照默认设置则会依次更新。

3 算法实例

这里我们使用的数据仍然是上次介绍过的sklearn库中自带的波士顿房价数据。首先我们分别使用岭回归和LASSO回归进行拟合,并输出回归系数和截距,程序如下:
from sklearn.datasets import load_bostonfrom sklearn.model_selection import train_test_splitfrom sklearn.linear_model import Ridge, Lassoboston_sample = load_boston()x_train, x_test, y_train, y_test = train_test_split( boston_sample.data, boston_sample.target, test_size=0.25,random_state=123)# 岭回归ridge = Ridge()ridge.fit(x_train, y_train)print('岭回归系数:', ridge.coef_)print('岭回归截距:', ridge.intercept_)y_ridge_pre = ridge.predict(x_test)# LASSOlasso = Lasso()lasso.fit(x_train, y_train)print('LASSO回归系数:', lasso.coef_)print('LASSO回归系数:', lasso.intercept_)y_lasso_pre = lasso.predict(x_test)

输出结果如下:

而后我们分别输出两个模型回归的拟合优度和均方误差,程序如下:

from sklearn.metrics import mean_squared_error,r2_scoreprint('岭回归的r2为:', r2_score(y_test, y_ridge_pre))print('岭回归的均方误差为:', mean_squared_error(y_test,y_ridge_pre))print('LASSO回归的r2为:', r2_score(y_test, y_lasso_pre))print('LASSO回归的均方误差为:', mean_squared_error(y_test,y_lasso_pre))

结果如下:

我们将线性回归、岭回归和LASSO的预测结果与真实值画在一张图上对比一下,程序如下:

import matplotlib.pyplot as pltplt.plot(y_test, label='True')plt.plot(y, label='linear')plt.plot(y_ridge_pre, label='Ridge')plt.plot(y_lasso_pre, label='LASSO')plt.legend()plt.show()

结果如下图:

可以看到在这个例子中用这三种回归方法得到的结果差别不太大,当然,这里由于篇幅所限,我们在岭回归和LASSO回归中使用了默认的α值,其实还可以通过调参对模型进行优化。同时,我们也要注意,岭回归和LASSO都是使用特征缩减的思想,因此得到的结果是有偏的,所以在选择方法时要注意自己的需求。以上就是我们对线性回归进一步的延伸,大家快去练习一下吧!







对我们的推文累计打赏超过1000元,我们即可给您开具发票,发票类别为“咨询费”。用心做事,不负您的支持!
往期推文推荐
Pandas中节约空间的小tip—categorical类型
Ftools命令组之flevelsof命令介绍
疫情下的家庭关系|《请回答1988》影评爬取
教你把Python当美图秀秀用(二)
自己动手进行线性回归计算
personage与年龄
原来这才是查看盲评结果的正确方式
教你把Python当美图秀秀用(一)
用数据透视表剖析泰坦尼克号乘客数据
读入文本文档,intext来帮忙
matchit——解锁文本相似度的钥匙
基于广义线性模型的机器学习算法——线性回归
听说你会魔法?
dummieslab——从分类变量到虚拟变量的“一步之遥”
线上Python课程都面向哪些方向?
子类与父类
用requests库爬取淘宝数据

关于我们



微信公众号“Stata and Python数据分析”分享实用的stata、python等软件的数据处理知识,欢迎转载、打赏。我们是由李春涛教授领导下的研究生及本科生组成的大数据处理和分析团队。

此外,欢迎大家踊跃投稿,介绍一些关于stata和python的数据处理和分析技巧。
投稿邮箱:statatraining@163.com
投稿要求:
1)必须原创,禁止抄袭;
2)必须准确,详细,有例子,有截图;
注意事项:
1)所有投稿都会经过本公众号运营团队成员的审核,审核通过才可录用,一经录用,会在该推文里为作者署名,并有赏金分成。
2)邮件请注明投稿,邮件名称为“投稿+推文名称”。
3)应广大读者要求,现开通有偿问答服务,如果大家遇到有关数据处理、分析等问题,可以在公众号中提出,只需支付少量赏金,我们会在后期的推文里给予解答。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存