![PyTorch深度学习应用实战](https://wfqqreader-1252317822.image.myqcloud.com/cover/410/52842410/b_52842410.jpg)
2-2 万般皆自“回归”起
要探究神经网络优化的过程,要先了解简单线性回归求解,线性回归方程式如下:
y=wx+b
已知样本(x, y),要求解方程式中的参数权重(w)、偏差(b)。
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P21_1735.jpg?sign=1739499701-61WT1RlQ9FtEQhbzle3h2C2nbq5k9mI8-0-e5cb6f9044f8d6c2f25fc8bb7a062d66)
图2.2 简单线性回归
一般求解方法有两种:
(1)最小平方法(Ordinary Least Square, OLS);
(2)最大似然估计法(Maximum Likelihood Estimation, MLE)。
以最小平方法为例,首先定义目标函数(Object Function)或称损失函数(Loss Function)为均方误差(MSE),即预测值与实际值差距的平方和,MSE当然越小越好,所以它是一个最小化的问题,我们可以利用偏微分推导出公式,过程如下。
(1)
其中ε:误差,即实际值(y)与预测值之差;
n:样本个数。
(2)MSE=SSE/n,n为常数,不影响求解,可忽略。
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P21_25720.jpg?sign=1739499701-cMShgxR0FkYCoovgYs8jgon0ItQCcgHn-0-cc29ed3faca670e2f22f783f9457a33b)
(3)分别对w及b偏微分,并且令一阶导数=0,可以得到两个联立方程式,进而求得w及b。
(4)先对b偏微分,又因
f′(x)=g(x)g(x)=g′(x)g(x)+g(x)g′(x)=2g(x)g′(x)
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25730.jpg?sign=1739499701-aGTCOOG8KyUzinmSfv0IsHLYuBB39gmy-0-ff6d0750ac65719ad06f2dd13b27426c)
→两边同除以2
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25731.jpg?sign=1739499701-bp46hHgcS9Xe56ZvR6TxxeTfPDk2ECPl-0-3f2815a39d02db1c7ce5e5e4f95f3972)
→分解
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25732.jpg?sign=1739499701-kXWAfxTymz8ssAGl9qgozuN42fCjPK8f-0-c27f2bec66f89d07a494be1f7986f067)
→除以n,为x、y的平均数
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25733.jpg?sign=1739499701-hBlpNarcCxR398KtenwPUYM8GMQgb8ou-0-e0fc7708383231bb6bee73db01e91cf4)
→移项
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25734.jpg?sign=1739499701-cCOOioNqLxQxZ8hBlviQEHX6mNRXDT3T-0-8c767efa130c431a24caed883d77465a)
(5)对w偏微分:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25735.jpg?sign=1739499701-31fdhqmr5DSVyBO0P4zNWTbmf5u9vtMH-0-6752d1330e6d90d870d2e2c37982a666)
→两边同除以-2
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25737.jpg?sign=1739499701-K36VIdGOhqO1KMVQxrW3lZvsos7kPXBC-0-59ab911e194df9080b6a3d3b263dd103)
→分解
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25738.jpg?sign=1739499701-xGhs8qwxBFsdhJ6GUAC7f0HAuGtA3UiD-0-8f8ae71dced72bdde3362c40addc8853)
→代入步骤(4)的计算结果
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25739.jpg?sign=1739499701-yfLskKSuXRyirRAOq3VqijzW3l5QFqfC-0-660b99665bd0e99ea7c8a2ef8f4cbf2a)
→化简
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25740.jpg?sign=1739499701-TiGzjZswtXKiUI0jYWh4ZPeaB9zHHUEQ-0-8dd8a20365302b02fb95cda75b9a16d6)
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25741.jpg?sign=1739499701-JTBK71QDuKm5B4knOAUf16MdJ3VSr7lp-0-840ed7c613dd9621a37e1ab3da0ed3a5)
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25742.jpg?sign=1739499701-QrtqlSfH0cN8WGg03JllcGnZdQPEiFW3-0-a4deacc3fea42758a35097eba8a66380)
结论:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25743.jpg?sign=1739499701-JCbBbkn7IjcQxgkV9NlGcKY4wdbRtYJp-0-c6836643bacbbdc067c5aa48154b886b)
范例1.现有一个世界人口统计数据集,以年度(year)为x,人口数为y,按上述公式计算回归系数w、b。
下列程序代码请参考【02_01_线性回归.ipynb】。
(1)使用Pandas相关函数计算,程序如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_1947.jpg?sign=1739499701-wUe2ykHYiThIAVVDUy4hYlVHd8dcTcg2-0-00f5d5e9d4470ff3f5a46bbcc1445509)
执行结果:
w=0.061159358661557375,b=-116.35631056117687
(2)改用NumPy的现成函数polyfit验算:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_1954.jpg?sign=1739499701-KOxEXoxbGmM40nzwf8rKWpYwqrOD9nhy-0-8772ff589eb974fe45dad8306b36c605)
执行结果:答案相差不大。
w=0.061159358661554586, b=-116.35631056117121
(3)上面公式,x只限一个,若以矩阵计算则更具通用性,多元回归亦可适用,即模型可以有多个特征(x),为简化模型,将b视为w的一环:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_25752.jpg?sign=1739499701-M2jlsh6mjkcZ1rdoRdvrQTWptYvAvAN8-0-36f6266d829baad49801dd2d5fecbdd9)
一样对SSE偏微分,一阶导数=0有最小值,公式推导如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_25754.jpg?sign=1739499701-RXPq1NoOXuJz1jR0npM3aZjeZgU1fc52-0-614dcfe217f17a15c66458aea2a80e4d)
→移项、整理
(xx′)w=xy
→移项
w=(xx′)−1xy
(4)使用NumPy相关函数计算,程序如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_1969.jpg?sign=1739499701-juX4MYpkQsMMRDN46XF9w6k37l2oib4X-0-724c661af37cce96d4029615461cee11)
执行结果与上一段相同。
范例2.再以Scikit-Learn的房价数据集为例,求解线性回归,该数据集有多个特征(x)。
(1)以矩阵计算的方式,完全不变。
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P24_1991.jpg?sign=1739499701-taleUPBQbTuDgWvbmgIpQQ1ntZoQ7pW0-0-50506e935ff2750719279999300c85a2)
执行结果如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P24_1998.jpg?sign=1739499701-hqAO2ErsoXhoshMNnnehTRYioHTQ083B-0-4b90b9216ac38976c2a2f03d54a7f28a)
(2)以Scikit-Learn的线性回归类别验证答案。
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P24_2001.jpg?sign=1739499701-DTfi21U4Dr9ZVaxBvBU7RkzmDWF7MQH9-0-4715051ecd2d2a4359280f47ffb8fe42)
执行结果与采用矩阵计算的结果完全相同。
(3)PyTorch自v1.9起提供线性代数函数库[1],可直接调用,程序改写如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P24_2008.jpg?sign=1739499701-AttIZdS5GlqCnbTw10eFT0aO2sPKVaFV-0-90b1a22a1bcc4d2a8edca2bd0d4110f6)
执行结果与NumPy计算完全相同。