在创建好了数据生成器之后,接下来即可进行手动线性回归建模实验。
# 科学计算模块
import numpy as np
import pandas as pd
# 绘图模块
import matplotlib as mpl
import matplotlib.pyplot as plt
# 自定义模块
from ML_basic_function import *
接下来,我们尝试进行线性回归模型的手动建模实验。建模过程将遵照机器学习的一般建模流程,并且借助NumPy所提供的相关工具来进行实现。通过本次实验,我们将进一步深化对机器学习建模流程的理解,并且也将进一步熟悉对编程基础工具的掌握。
首先,是准备数据集。我们利用数据生成器创建一个扰动项不太大的数据集:
# 设置随机数种子
np.random.seed(24)
# 扰动项取值为0.01
features, labels = arrayGenReg(delta=0.01)
features
array([[ 1.32921217, -0.77003345, 1. ],
[-0.31628036, -0.99081039, 1. ],
[-1.07081626, -1.43871328, 1. ],
...,
[ 1.5507578 , -0.35986144, 1. ],
[-1.36267161, -0.61353562, 1. ],
[-1.44029131, 0.50439425, 1. ]])
其中,features也被称为特征矩阵、labels也被称为标签数组
接下来,选取模型对上述回归类问题数据进行建模。此处我们选取带有截距项的多元线性回归方程进行建模,基本模型为:
令$\hat w = [w_1,w_2,b]^T$,$\hat x = [x_1,x_2, 1]^T$,则上式可写为
注,此处如果要构建一个不带截距项的模型,则可另X为原始特征矩阵带入进行建模。
对于线性回归来说,我们可以参照SSE、MSE或者RMSE的计算过程构造损失函数。由于目前模型参数还只是隐式的值(在代码中并不显示),我们可以简单尝试,通过人工设置一组$\hat w$来计算SSE。
np.random.seed(24)
w = np.random.randn(3).reshape(-1, 1)
w
array([[ 1.32921217],
[-0.77003345],
[-0.31628036]])
此时模型输出的预测结果为:
y_hat = features.dot(w)
y_hat[:10]
array([[ 2.04347616],
[ 0.02627308],
[-0.63176501],
[ 0.20623364],
[-2.64718921],
[-0.86880796],
[ 0.88171608],
[-1.61055557],
[ 0.80113619],
[-0.49279524]])
据此,根据公式$SSE= ||y - X\hat w||_2^2 = (y - \hat y)^T(y - \hat y)$,SSE计算结果为:
(labels - y_hat).T.dot(labels - y_hat)
array([[2093.52940481]])
# 计算MSE
(labels - y_hat).T.dot(labels - y_hat) / len(labels)
array([[2.0935294]])
labels[:10]
array([[ 4.43811826],
[ 1.375912 ],
[ 0.30286597],
[ 1.81970897],
[-2.47783626],
[ 0.47374318],
[ 2.83085905],
[-0.83695165],
[ 2.84344069],
[ 0.8176895 ]])
能够看出,在当前参数取值下,$w$值随机生成的话,模型输出结果和当前真实结果相距甚远。
不过,为了后续快速计算SSE,我们可以将上述SSE计算过程封装为一个函数,令其在输入特征矩阵、标签数组和真实参数情况下即可输出SSE计算结果:
def SSELoss(X, w, y):
"""
SSE计算函数
:param X:输入数据的特征矩阵
:param w:线性方程参数
:param y:输入数据的标签数组
:return SSE:返回对应数据集预测结果和真实结果的误差平方和
"""
y_hat = X.dot(w)
SSE = (y - y_hat).T.dot(y - y_hat)
return SSE
# 简单测试函数性能
SSELoss(features, w, labels)
array([[2093.52940481]])
实验结束后,需要将上述SSELoss函数写入ML_basic_function.py中。
接下来,我们需要在SSELoss中找到一组最佳的参数取值,另模型预测结果和真实结果尽可能接近。此处我们利用Lesson 2中介绍的最小二乘法来进行求解,最小二乘法求解模型参数公式为:
值得注意的是,最小二乘法在进行求解过程中,需要特征矩阵的交叉乘积可逆,也就是$X^TX$必须存在逆矩阵。我们可以通过计算其行列式来判断该条件是否满足:
np.linalg.det(features.T.dot(features))
967456500.1798325
行列式不为0,因此$X^TX$逆矩阵存在,可以通过最小二乘法求解。具体求解方法分为两种,其一是使用NumPy中线性代数基本方法,根据上述公式进行求解,同时也可以直接使用lstsq函数进行求解。
w = np.linalg.inv(features.T.dot(features)).dot(features.T).dot(labels)
w
array([[ 1.99961892],
[-0.99985281],
[ 0.99970541]])
即可算出模型最优参数w。所谓模型最优参数,指的是参数取得任何其他数值,模型评估结果都不如该组参数时计算结果更好。首先,我们也可以计算此时模型SSE指标:
SSELoss(features, w, labels)
array([[0.09300731]])
明显小于此前所采用的随机w取值时SSE结果。此外,我们还可以计算模型MSE
SSELoss(features, w, labels) / len(y)
array([[9.30073138e-05]])
当然,由于数据集本身是依据$y=2x_1+x_2-1$规律构建的,因此从模型参数也能够看出模型预测效果较好。
模型评估指标中SSE、MSE、RMSE三者反应的是一个事实,我们根据SSE构建损失函数只是因为SSE计算函数能够非常方便的进行最小值推导,SSE取得最小值时MSE、RMSE也取得最小值。
当然,我们也可以利用lstsq函数进行最小二乘法结果求解。二者结果一致。
np.linalg.lstsq(features, labels, rcond=-1)
(array([[ 1.99961892],
[-0.99985281],
[ 0.99970541]]),
array([0.09300731]),
3,
array([32.70582436, 31.3166949 , 30.3678959 ]))
最终w参数取值为:
np.linalg.lstsq(features, labels, rcond=-1)[0]
array([[ 1.99961892],
[-0.99985281],
[ 0.99970541]])
至此,我们即完成了整个线性回归的机器学习建模流程。
尽管上述建模过程能够发现,面对白噪声不是很大、并且线性相关性非常明显的数据集,模型整体表现较好,但在实际应用中,大多数数据集可能都不具备明显的线性相关性,并且存在一定的白噪声(数据误差)。此时多元线性回归模型效果会受到极大影响。
plt.plot(features[:, 0], labels, 'o')
[<matplotlib.lines.Line2D at 0x1c069fca748>]
例如,此处创建一个满足$y=x^3+1$基本规律,并且白噪声很小的数据集进行建模测试
# 设置随机数种子
np.random.seed(24)
# 扰动项取值为0.01
features, labels = arrayGenReg(w=[2,1], deg=3, delta=0.01)
features
array([[ 1.32921217, 1. ],
[-0.77003345, 1. ],
[-0.31628036, 1. ],
...,
[ 0.84682091, 1. ],
[ 1.73889649, 1. ],
[ 1.93991673, 1. ]])
plt.plot(features[:, 0], labels, 'o')
[<matplotlib.lines.Line2D at 0x1c06a08f808>]
进行最小二乘法模型参数求解,直接使用numpy封装的函数
np.linalg.lstsq(features, labels, rcond=-1)
(array([[5.91925985],
[0.96963333]]),
array([28466.58711077]),
2,
array([32.13187164, 31.53261742]))
w = np.linalg.lstsq(features, labels, rcond=-1)[0]
w
array([[5.91925985],
[0.96963333]])
SSELoss(features, w, labels)
array([[28466.58711077]])
y_hat = features.dot(w)
y_hat
array([[ 8.83758557e+00],
[-3.58839477e+00],
[-9.02512307e-01],
[-4.89523081e+00],
[-5.36880634e+00],
[-7.54648442e+00],
[ 4.31056333e+00],
[ 2.72008802e+00],
[-8.65747595e+00],
[ 2.26929679e+00],
[ 4.98765532e+00],
[ 1.21527295e+01],
[ 6.66122896e+00],
[ 1.58530262e+00],
[-1.87850922e+00],
[ 6.00235693e+00],
[ 9.57283160e+00],
[ 7.23065606e+00],
[ 1.94963550e+00],
[ 4.01816093e+00],
[-6.94403640e+00],
[ 4.30135465e+00],
[ 9.21430297e+00],
[ 5.94778537e-01],
[ 1.68981997e+00],
[ 8.11774654e+00],
[ 9.57556764e-01],
[ 1.06049793e+01],
[ 3.06796824e+00],
[ 7.11102898e+00],
[-1.31332760e+00],
[ 4.04657115e+00],
[ 1.09529557e+01],
[-6.87908708e+00],
[ 9.42815917e+00],
[-1.13977976e+01],
[ 2.01195383e-01],
[ 4.70778176e+00],
[-2.50223789e+00],
[ 2.69048103e+00],
[ 8.45218962e+00],
[ 2.68642457e+00],
[-1.06930163e+01],
[ 5.72816118e+00],
[ 7.06972854e+00],
[ 1.66868570e+00],
[ 8.40277804e-01],
[ 1.24689568e+00],
[-8.67137808e+00],
[-1.35285096e+00],
[ 1.10381328e+01],
[ 7.25191993e+00],
[ 5.08827548e+00],
[-1.61109631e+00],
[-1.00113464e+00],
[ 4.53382427e+00],
[ 1.61358344e+00],
[ 1.18726633e+00],
[-2.22063113e+00],
[ 3.92439695e+00],
[-3.24459412e+00],
[-4.33236995e-01],
[ 6.04314480e+00],
[-1.01724967e+01],
[ 3.45989605e+00],
[-5.35909332e+00],
[-1.42834635e+01],
[-6.28100682e+00],
[-5.87748151e+00],
[ 6.48226530e+00],
[ 6.79570714e+00],
[ 1.41468491e+01],
[-1.50760009e+00],
[-1.00608345e+00],
[ 9.69753807e-02],
[ 1.02114191e+01],
[ 5.02260972e+00],
[ 1.14402872e+00],
[-4.07638724e+00],
[ 1.26981564e+01],
[-8.67993598e+00],
[-4.21595911e-01],
[-1.54710011e-03],
[ 4.86734526e+00],
[-6.77211998e+00],
[ 8.98303875e+00],
[ 2.08874359e-01],
[-4.79145910e+00],
[-3.14098948e+00],
[ 6.80527990e-01],
[-2.46121858e+00],
[ 6.12593677e+00],
[-4.77394443e+00],
[ 3.47560787e+00],
[ 1.26556995e+01],
[ 2.78341517e+00],
[ 6.88452436e+00],
[ 6.13321737e+00],
[-1.00874274e+01],
[ 8.25472740e+00],
[ 2.08920978e-01],
[ 5.95191250e+00],
[ 7.50765741e+00],
[ 3.73502394e+00],
[ 2.18414660e+00],
[-6.62325357e+00],
[ 5.77553430e+00],
[ 1.52309945e+00],
[ 3.78495734e+00],
[ 6.83337756e+00],
[-5.65260030e+00],
[ 6.84661196e+00],
[ 9.19321700e+00],
[-1.29206945e+00],
[-9.42404419e+00],
[ 1.34400030e+01],
[ 4.93251288e+00],
[ 3.87807361e+00],
[ 1.01506008e+00],
[-2.68849937e+00],
[-2.90873944e+00],
[ 1.75564891e+00],
[ 1.37783358e+00],
[-1.12484544e+00],
[-4.35672419e+00],
[ 1.85798634e+00],
[-5.86779704e+00],
[-8.72663440e+00],
[-7.54432962e+00],
[-8.89327885e+00],
[ 4.34549552e+00],
[-3.71324294e+00],
[ 5.31871039e+00],
[ 3.01085614e+00],
[ 1.10546961e+01],
[ 5.07292701e+00],
[ 5.35433412e+00],
[-2.51543377e+00],
[-9.79005996e+00],
[ 4.53722456e+00],
[ 7.67942726e+00],
[-1.64420720e+00],
[-2.30134628e+00],
[-1.55651823e+00],
[-3.38607010e+00],
[ 7.74897608e+00],
[-6.30554995e-02],
[ 5.31595714e+00],
[ 2.80202686e+00],
[-4.38996808e+00],
[ 1.21060852e+01],
[-3.70423941e-01],
[-3.79687976e+00],
[ 9.29415152e+00],
[ 4.80596309e+00],
[-1.19010242e+00],
[-5.07652463e+00],
[-7.85638661e+00],
[-4.93045128e+00],
[ 3.37973834e+00],
[-4.44496725e+00],
[ 6.90510608e+00],
[-1.66980604e+00],
[ 5.86013963e+00],
[-4.11768380e+00],
[ 1.75221858e+00],
[ 1.44600003e+00],
[ 9.17417088e-03],
[ 8.58842068e-01],
[-2.01572094e+00],
[-3.55803047e-02],
[ 1.18765279e+01],
[ 2.30374067e+00],
[ 9.18640736e+00],
[ 9.37541005e+00],
[ 4.30575741e+00],
[ 8.23993827e+00],
[ 3.83274090e+00],
[ 1.16147093e+00],
[-5.07683714e-01],
[-2.96918860e+00],
[-4.08541312e+00],
[ 1.24371552e+01],
[ 1.02922846e+00],
[ 1.32762876e+00],
[ 6.56897321e+00],
[-1.60192113e+01],
[ 3.05149689e+00],
[ 2.54113572e+00],
[ 1.59537139e+00],
[-2.44991257e+00],
[-1.42992822e+00],
[ 7.42331904e+00],
[ 9.41802438e+00],
[ 5.07156716e+00],
[ 2.06588285e+00],
[ 1.60554600e+00],
[-2.92430176e-01],
[-5.72791855e+00],
[ 2.81145393e+00],
[-3.77957572e+00],
[-9.53804514e+00],
[ 1.51218151e-01],
[-1.81759164e+00],
[-8.88384293e-01],
[-8.54380243e-01],
[-6.72951509e+00],
[-2.51163555e+00],
[-2.91754804e+00],
[-8.93748935e+00],
[ 2.86270030e+00],
[ 5.74400157e+00],
[-4.41873443e-01],
[-5.54889044e-01],
[ 1.99793416e-01],
[ 1.97784209e+00],
[-6.16342550e+00],
[ 2.90812161e-01],
[-2.16357252e+00],
[-5.61839332e+00],
[-5.66191666e-01],
[-4.90362591e+00],
[-1.75976015e+00],
[ 5.29519267e+00],
[ 1.07215531e+01],
[-2.32226151e+00],
[-6.72454624e+00],
[ 7.62441250e+00],
[-5.50603630e+00],
[ 2.65106968e+00],
[ 1.08471484e+00],
[ 1.87249937e+00],
[-1.00375660e+00],
[-7.03246426e+00],
[ 4.90215632e+00],
[-5.79103180e-01],
[-5.02994764e+00],
[ 1.71601251e+00],
[-3.74550809e+00],
[ 1.40718161e+01],
[ 4.00976617e+00],
[-4.61217194e+00],
[-1.15033854e-01],
[-3.15070254e+00],
[-4.66109819e+00],
[-4.51988890e+00],
[-1.07192244e+00],
[ 2.44151262e+00],
[ 5.33652817e-01],
[-1.34355703e+00],
[ 2.83839840e+00],
[ 6.78527979e+00],
[-2.27927855e+00],
[ 8.84306084e+00],
[ 9.54310590e+00],
[-4.85454642e+00],
[ 1.51025739e+00],
[-3.56917895e-01],
[ 3.92621392e+00],
[ 1.02052928e+01],
[ 6.45630477e+00],
[-3.54854025e+00],
[ 4.35876593e+00],
[ 1.84321305e+00],
[ 1.21246377e+01],
[ 1.68363558e+00],
[-1.00409288e+00],
[ 2.04127332e+00],
[ 8.29518244e-01],
[-5.77456391e+00],
[-2.33652988e+00],
[-3.68188422e+00],
[-5.19887716e+00],
[-5.88141128e+00],
[ 2.04772771e+00],
[ 1.50779810e+00],
[ 8.95100441e+00],
[-4.43914079e+00],
[-2.51122413e+00],
[-1.93724685e+00],
[ 1.06747209e+01],
[-3.33816172e-01],
[ 9.09857876e+00],
[-1.56402310e+01],
[-4.41404230e+00],
[ 1.42071898e+01],
[-2.44376901e+00],
[ 6.51277509e+00],
[-4.01975942e+00],
[-7.10767270e+00],
[ 3.81326494e+00],
[ 2.14267860e+00],
[ 1.75381993e+00],
[ 1.39570572e+01],
[-3.34252306e+00],
[ 3.78543807e-01],
[ 8.22884392e+00],
[ 8.94757704e+00],
[-2.68980725e+00],
[ 4.96460991e+00],
[ 2.34695220e+00],
[-5.80542019e+00],
[ 6.30542949e+00],
[-2.53703972e-01],
[-7.26190802e-01],
[-4.68368135e-01],
[ 5.19514492e+00],
[-9.83301478e+00],
[ 2.16244601e+00],
[-2.72011038e+00],
[-7.03899325e+00],
[-2.37437646e+00],
[-2.19750947e-01],
[ 7.93035479e-01],
[-5.48783227e+00],
[ 2.03326486e+00],
[-7.00562187e-01],
[-7.36853599e+00],
[ 4.53859726e-01],
[ 2.36555941e+00],
[-5.45885136e+00],
[ 8.74192673e+00],
[ 1.09239547e+01],
[-1.35543493e+01],
[ 2.51237655e-01],
[-6.27202597e+00],
[-7.82886756e+00],
[ 6.01869065e+00],
[-6.61841696e-01],
[-1.52695889e+00],
[ 3.74130190e+00],
[-1.07766971e+00],
[-8.08450599e+00],
[-1.92159900e+00],
[ 8.28223444e+00],
[ 7.57348991e-01],
[ 4.87212469e+00],
[-2.33102974e-01],
[-8.85136128e+00],
[-7.43257672e+00],
[ 4.36955794e+00],
[ 3.76697399e+00],
[ 1.18111677e+01],
[ 7.13456678e+00],
[ 3.80951631e-01],
[ 7.89735705e+00],
[-1.24124424e+01],
[-3.73110533e+00],
[-2.78855453e+00],
[-1.28164914e+01],
[ 1.19782507e+00],
[-5.85213353e+00],
[ 2.76743381e-01],
[ 6.03010152e+00],
[-4.92606824e+00],
[-2.09304613e+00],
[ 4.98737450e+00],
[-6.52983609e+00],
[ 8.50127372e+00],
[ 6.32092764e-01],
[-8.99262029e+00],
[-1.04544064e+00],
[ 7.91557989e+00],
[-5.10288545e+00],
[ 6.00199124e+00],
[-3.16918282e+00],
[-5.05021103e+00],
[ 1.44358240e+00],
[ 7.51037238e+00],
[-4.62539865e+00],
[-9.15683421e+00],
[ 3.59363997e+00],
[ 5.02032578e+00],
[ 2.85642311e+00],
[ 5.37684546e+00],
[ 1.12122832e-01],
[-3.05482584e+00],
[-1.19749974e+00],
[ 1.00503733e+01],
[-5.86675495e-01],
[ 1.99730695e+00],
[ 4.29368066e+00],
[ 6.92385865e+00],
[-4.19493607e+00],
[-1.19984361e+00],
[-3.88486837e+00],
[ 1.02205575e+01],
[ 3.64358765e-01],
[-1.93760783e+00],
[ 7.59803677e+00],
[-7.98762295e-01],
[ 3.24236253e+00],
[ 1.06397102e+01],
[ 8.48083223e+00],
[ 1.11949283e+01],
[-3.10434164e+00],
[-2.21737016e+00],
[ 2.16953035e+00],
[ 3.51538430e+00],
[-7.20976979e+00],
[ 3.04470232e-01],
[ 1.01365360e+01],
[ 1.73494317e+00],
[-3.12716389e+00],
[-1.30761278e+00],
[-7.08944587e+00],
[ 7.77688768e+00],
[ 5.63764843e+00],
[-9.42298711e+00],
[ 5.34381877e+00],
[ 6.03743768e+00],
[ 8.04306815e+00],
[-9.12580173e+00],
[ 9.14047426e+00],
[ 8.32115594e+00],
[ 1.40686847e+00],
[-6.92908580e-01],
[-4.17650402e+00],
[-1.06097427e+01],
[ 6.38254779e+00],
[-1.57135909e+01],
[ 2.37847416e-01],
[ 5.13682120e+00],
[ 1.18176059e+00],
[ 5.57261866e+00],
[ 3.64628032e-01],
[-3.97003495e+00],
[-2.94845530e+00],
[-6.06125326e+00],
[-8.15744414e-01],
[-4.48451083e+00],
[ 7.06577491e+00],
[-6.33714189e+00],
[ 3.44506334e-02],
[ 3.23376069e+00],
[ 2.49416167e+00],
[ 7.00223140e+00],
[-8.28359419e+00],
[-5.99726292e+00],
[-5.00351039e+00],
[ 6.44663174e+00],
[ 4.77035282e-01],
[-1.19319850e+00],
[ 5.79403562e+00],
[ 3.02391308e+00],
[-8.55830274e+00],
[ 5.74759564e-01],
[ 2.40524432e+00],
[ 1.02151134e+01],
[ 4.19426385e+00],
[ 2.91576867e+00],
[ 2.53294049e+00],
[-5.28196156e+00],
[-8.47150946e+00],
[-5.05970472e+00],
[ 9.14129438e+00],
[ 3.93624236e-01],
[-5.32255542e-02],
[-1.52691627e+01],
[-1.16617862e+00],
[-1.74393776e+00],
[-6.39748132e+00],
[-4.18193080e+00],
[ 4.97351924e+00],
[-8.40941116e-01],
[-7.55135277e+00],
[-1.91529498e+00],
[-2.84313459e-01],
[-2.87533601e+00],
[ 1.46239808e+00],
[ 7.52767224e+00],
[ 3.65547697e+00],
[-1.10260149e+01],
[-6.36396420e+00],
[-1.78428200e+00],
[ 3.81590207e+00],
[ 7.08055946e+00],
[ 2.67574632e+00],
[-1.53430241e+00],
[ 7.00686336e+00],
[ 1.13174957e+01],
[-6.52259004e+00],
[-1.82537691e+00],
[ 1.83537626e+00],
[ 1.26272939e+00],
[ 2.71405129e+00],
[ 7.09457962e+00],
[ 1.07659229e+00],
[-4.32430234e+00],
[ 7.12993138e+00],
[ 2.43425014e+00],
[ 5.44854383e+00],
[-1.08375343e+01],
[-3.13064630e+00],
[ 1.53288107e+01],
[-7.29643136e+00],
[-9.69000277e+00],
[ 9.36228686e-01],
[ 6.79112641e+00],
[-6.87628698e+00],
[-6.64632539e-01],
[ 1.99670570e+00],
[ 1.01477393e+00],
[ 5.30093797e+00],
[ 1.26788009e+01],
[ 2.95834594e+00],
[ 6.76251395e-01],
[ 1.09998634e+01],
[ 8.66295496e+00],
[-1.67015000e+00],
[ 1.59985463e+00],
[ 3.76713563e+00],
[ 9.43996118e+00],
[ 6.36844165e+00],
[ 2.21187178e-01],
[ 4.05248486e+00],
[-4.45053916e+00],
[-6.00836991e-01],
[ 6.62727384e+00],
[-3.55840715e-02],
[ 3.65421836e+00],
[ 7.42106110e+00],
[-2.34652930e+00],
[ 1.12752804e+01],
[-1.07737300e+01],
[-3.95511809e+00],
[ 7.42417097e+00],
[ 7.11743022e+00],
[-3.24137314e+00],
[-9.11895388e+00],
[ 4.94005931e+00],
[ 2.64404294e+00],
[-4.14424348e+00],
[ 2.25363002e+00],
[ 5.87718018e+00],
[-6.13403519e+00],
[-7.77828932e+00],
[ 2.94835043e+00],
[ 5.10001455e+00],
[ 1.62070353e+01],
[-1.59139503e+00],
[ 1.00031266e+01],
[ 3.41131545e+00],
[ 3.36611250e+00],
[ 4.63699041e+00],
[ 6.29486189e+00],
[-1.45442183e+00],
[ 1.88411227e+00],
[ 1.72337794e+00],
[-8.96764593e+00],
[ 1.05507421e+01],
[ 2.90230176e+00],
[-6.09435734e+00],
[ 3.91008555e+00],
[-8.05992833e+00],
[-3.49026675e+00],
[ 3.63509138e-01],
[ 6.58472400e+00],
[ 6.64283199e+00],
[ 1.90977113e+00],
[ 2.33252487e+00],
[-3.23102753e-01],
[ 7.80064568e+00],
[ 3.29674817e+00],
[ 2.90709521e+00],
[ 5.21314748e-01],
[-6.21316291e-01],
[ 6.41298199e+00],
[ 3.15085273e+00],
[-6.74226153e-03],
[-7.31426492e+00],
[ 3.34382173e+00],
[ 6.32762483e+00],
[-5.77048958e+00],
[-1.24795757e+01],
[ 2.86454967e+00],
[ 7.05605649e-01],
[ 1.53642691e+00],
[ 4.13226268e+00],
[ 8.61470126e+00],
[-4.73546591e+00],
[-2.61780622e+00],
[-7.59664141e+00],
[ 3.27055944e+00],
[-3.39901981e-01],
[-2.43457779e+00],
[ 1.44743669e+00],
[-1.04620814e+01],
[ 2.92615010e+00],
[ 3.76399122e+00],
[ 7.35590240e+00],
[ 1.05943629e+01],
[ 4.72591399e+00],
[-3.26902804e+00],
[-2.01207307e+00],
[ 8.80477316e+00],
[-1.10927280e+01],
[ 4.34004970e+00],
[ 1.80573727e+00],
[ 2.44186949e-01],
[-8.53157912e+00],
[-1.16743408e+01],
[ 9.68644740e+00],
[ 1.40241457e+01],
[ 3.60384263e+00],
[ 5.67655564e+00],
[-5.13046491e+00],
[-1.88495233e+00],
[ 1.28387057e+00],
[ 8.79086720e-01],
[-2.27150130e+00],
[-5.15486809e+00],
[-5.60552360e+00],
[ 7.29416678e+00],
[ 5.51951668e+00],
[ 6.54425132e+00],
[ 6.61936901e+00],
[ 8.70231833e+00],
[-2.83846717e+00],
[ 7.17836261e+00],
[-3.29528596e+00],
[ 1.34795344e+00],
[ 8.23426169e+00],
[ 4.68026854e+00],
[-6.95321231e-02],
[ 7.48012135e+00],
[ 5.66938673e+00],
[-6.08116384e+00],
[ 5.11449277e-02],
[-1.05142300e+01],
[-7.30032720e+00],
[-1.00682906e+00],
[-4.40845938e+00],
[ 7.04496880e+00],
[ 7.34269391e+00],
[-2.37617243e+00],
[ 7.88416377e-01],
[ 9.95485682e+00],
[ 9.57928439e+00],
[ 6.18982798e-01],
[ 5.25750689e+00],
[ 2.59672911e+00],
[ 1.31148932e+01],
[ 1.01344852e+01],
[ 1.16132475e+00],
[ 3.58054674e+00],
[ 1.57291870e+00],
[ 1.79108037e+00],
[-3.45403549e+00],
[-5.54004739e+00],
[-1.19000115e+00],
[ 4.18498187e+00],
[-7.28403730e+00],
[-4.53492852e+00],
[-9.27591892e-01],
[-7.08785174e+00],
[ 7.69841472e+00],
[ 4.79928490e+00],
[ 5.37084949e+00],
[-3.70749948e+00],
[-5.39278024e+00],
[-3.06773289e+00],
[ 3.06673148e+00],
[-3.79874804e+00],
[ 5.55605504e+00],
[ 2.27935443e+00],
[ 5.91729848e+00],
[ 1.25424142e-01],
[ 5.55948682e+00],
[ 6.05015053e+00],
[ 8.23766742e+00],
[-5.18341368e+00],
[-7.78004082e+00],
[-3.77206923e+00],
[ 8.43849400e+00],
[-3.42218246e-01],
[ 7.63918340e+00],
[ 2.68225664e+00],
[ 5.88768660e+00],
[ 8.00940423e+00],
[-3.05163967e+00],
[-3.06302995e-01],
[ 5.13960415e+00],
[ 4.72403262e+00],
[-6.12296457e+00],
[ 6.74076544e+00],
[-9.28946988e+00],
[ 4.58978415e-01],
[ 2.17685463e-01],
[ 6.78672271e+00],
[ 6.51741330e+00],
[ 2.60401886e+00],
[ 8.39785741e+00],
[ 2.52359890e+00],
[-6.39753922e+00],
[-3.14275018e+00],
[ 6.60712859e+00],
[ 4.23328533e+00],
[ 4.14377274e+00],
[ 6.94386244e+00],
[ 2.77992741e+00],
[-6.81942856e+00],
[ 7.36501543e+00],
[ 1.36049435e+01],
[ 8.23390601e+00],
[-3.91066490e+00],
[-3.34233662e+00],
[-6.95976347e+00],
[ 4.29514208e+00],
[ 2.68840196e+00],
[-4.06318264e+00],
[ 4.24837207e+00],
[-2.36996555e+00],
[ 3.05940467e+00],
[ 6.11522913e+00],
[-7.00302840e+00],
[ 1.63027432e+01],
[-7.80228600e+00],
[-8.68671272e-01],
[ 6.52517125e+00],
[-4.55357487e+00],
[-3.37834225e+00],
[ 6.86519476e+00],
[ 3.05838316e+00],
[-3.18749494e+00],
[-2.71038284e+00],
[ 8.43879759e+00],
[ 1.26580781e+01],
[ 1.01629243e+01],
[-9.14190408e+00],
[-1.78997719e+00],
[-1.14437745e+01],
[-4.46041084e+00],
[-1.07451005e+00],
[ 4.76564492e+00],
[ 1.04751922e+01],
[-5.32395967e+00],
[-3.37011613e+00],
[ 2.02597858e-01],
[ 9.68359079e+00],
[-6.96087787e+00],
[-4.38241475e+00],
[ 2.59024684e+00],
[ 1.36946094e+01],
[-4.40708083e+00],
[ 5.68561498e+00],
[-3.69718829e+00],
[ 6.86463201e+00],
[-3.04276684e+00],
[ 6.93392012e-01],
[-1.79630393e+00],
[ 5.51796153e+00],
[ 1.57026602e-01],
[ 9.61468366e+00],
[ 3.51627248e+00],
[-4.61681280e+00],
[ 5.55920826e+00],
[-7.92017924e+00],
[-3.71644003e+00],
[ 1.77687697e-01],
[-1.09282007e+01],
[ 3.57330305e+00],
[-4.63217839e+00],
[-7.72004636e-02],
[-6.43847919e+00],
[ 8.76613884e+00],
[-4.18969202e+00],
[-5.97273377e+00],
[ 6.55821705e+00],
[-3.10111304e+00],
[-7.99684806e+00],
[-1.35950712e+00],
[ 6.87730741e+00],
[-5.05824930e-01],
[-3.26801593e+00],
[-1.21906768e+00],
[-2.92881446e+00],
[-7.17904474e+00],
[ 9.69777024e+00],
[-7.72623431e-01],
[ 6.36232484e+00],
[ 5.36387966e+00],
[-5.31242699e+00],
[ 9.63469589e-01],
[-7.89102704e-01],
[-1.84175026e+00],
[ 2.13067022e+00],
[-4.65782501e+00],
[-1.08937746e+01],
[ 5.54064451e+00],
[-9.17945922e-01],
[-2.10631708e+00],
[-6.30854048e+00],
[-2.89485244e+00],
[ 7.67127178e+00],
[-5.89359537e+00],
[ 6.49035820e+00],
[ 5.03227493e+00],
[ 1.51211666e+01],
[ 1.08700967e+00],
[ 3.51295260e+00],
[-1.05076352e+01],
[ 2.96952266e+00],
[ 6.25816306e+00],
[ 7.96471769e+00],
[ 9.04174485e+00],
[ 1.54585404e+00],
[ 4.41159763e+00],
[-6.56939175e+00],
[-5.73631150e+00],
[ 4.26333370e+00],
[ 1.71426498e+00],
[ 3.60652755e+00],
[ 4.54297940e+00],
[-2.70206402e+00],
[-1.19895871e+01],
[ 6.33580204e+00],
[-5.52166414e+00],
[ 2.53451520e+00],
[ 5.83348799e+00],
[-4.71613791e+00],
[-3.77772431e+00],
[ 1.45249281e+00],
[-6.81019973e-01],
[ 1.40399528e+01],
[ 9.77910610e+00],
[-5.63522374e+00],
[ 9.12427323e+00],
[ 5.66869541e-01],
[ 4.36878765e+00],
[-1.27102401e+00],
[ 5.60880509e+00],
[ 1.86207457e+01],
[-9.06016431e-01],
[-1.32972814e+01],
[ 2.96857455e+00],
[-2.72705550e+00],
[ 4.20000866e+00],
[ 3.37801566e+00],
[-3.70343102e-01],
[ 2.11411325e+00],
[ 1.40783349e+01],
[ 1.39710577e+00],
[-6.75476083e+00],
[ 3.95204839e+00],
[-1.55216782e+00],
[ 3.96793436e-01],
[-4.70607758e+00],
[ 8.80130607e+00],
[ 4.17132557e+00],
[-5.55083289e+00],
[-9.30471725e+00],
[ 3.08269744e+00],
[-1.53509348e+00],
[ 9.16584744e-01],
[-5.57384093e-01],
[-6.52226143e+00],
[ 4.88912630e+00],
[-5.17770088e-01],
[-5.86951085e-01],
[ 6.61337615e+00],
[ 5.86931747e+00],
[-6.09389200e+00],
[ 7.51763547e+00],
[ 1.39251802e+01],
[-3.81392402e+00],
[ 5.13016749e+00],
[ 1.28912126e+01],
[-1.04542587e+01],
[-2.97406874e+00],
[ 4.75625986e-03],
[ 1.88708435e+00],
[-1.52744258e+00],
[ 6.59945780e+00],
[-1.82024869e+00],
[ 7.64087923e+00],
[-1.16174243e+01],
[ 2.12108632e+00],
[ 1.50359731e+00],
[ 8.07711662e+00],
[-2.08422791e+00],
[ 7.42229794e+00],
[ 5.08637568e+00],
[ 5.42397145e+00],
[-3.50642042e+00],
[ 1.06281894e+01],
[-4.93897801e+00],
[ 2.47614058e+01],
[-7.38817603e+00],
[-7.38939762e+00],
[ 3.82736051e+00],
[ 7.61436215e+00],
[ 7.00040163e+00],
[ 2.94023044e+00],
[-2.54267629e-01],
[-4.42696489e+00],
[ 7.57222560e+00],
[ 8.13916680e+00],
[ 1.00072001e+00],
[ 3.55644609e-01],
[ 5.66744985e-01],
[ 2.23060848e+00],
[ 8.12906695e+00],
[-1.62078500e+00],
[-6.78766372e-01],
[-4.42784129e+00],
[ 7.40400077e+00],
[-1.23810437e-01],
[ 2.07875798e+00],
[ 1.47671911e+01],
[ 2.75597590e+00],
[ 4.73053466e+00],
[ 8.78038743e+00],
[ 4.86179521e-01],
[ 6.64412565e+00],
[ 9.70077796e+00],
[-1.02165591e+00],
[ 2.78495142e+00],
[ 3.67552815e+00],
[ 9.52630670e+00],
[ 2.25545442e+00],
[ 1.26982512e+01],
[ 3.90765230e+00],
[ 1.38238680e+01],
[-3.40174647e+00],
[ 3.32093728e+00],
[-3.38411311e+00],
[ 5.24682755e+00],
[ 4.61048041e-01],
[ 1.24965375e+01],
[ 1.96823873e+00],
[ 5.46546171e+00],
[-2.19492306e+00],
[-3.46464546e+00],
[ 1.13821169e+00],
[ 5.32229105e+00],
[ 3.41865106e+00],
[-1.77677274e+00],
[ 1.83889826e+00],
[-9.20030098e+00],
[-2.83993504e+00],
[ 3.42985591e+00],
[ 3.41469553e+00],
[ 6.47521959e+00],
[-1.16462115e+01],
[-4.77143723e+00],
[-4.25007104e+00],
[-2.37134525e+00],
[-2.72884747e+00],
[ 6.00640342e+00],
[ 1.28446067e+00],
[ 1.23547181e+00],
[ 6.69383884e+00],
[ 3.18055125e+00],
[ 7.08002475e+00],
[ 2.53581476e+00],
[-4.80575509e+00],
[ 5.37187541e+00],
[-2.96781016e+00],
[-6.40627443e+00],
[ 5.09670278e+00],
[ 8.55327137e+00],
[-1.65951180e+01],
[ 9.75830445e+00],
[-8.68610848e+00],
[-4.93733888e+00],
[ 1.04000538e+00],
[ 1.64159645e+00],
[-2.88926310e+00],
[-2.32919762e+00],
[-2.63562433e+00],
[-5.68279220e-01],
[ 7.78675183e-01],
[ 8.37641937e+00],
[-3.41330543e+00],
[ 7.25245284e-01],
[-5.83235275e-01],
[-4.59085651e+00],
[ 1.28484485e+00],
[ 1.12486876e+01],
[-4.86298271e+00],
[ 4.89120189e+00],
[ 5.45115655e-01],
[-1.90947576e+00],
[-7.73627902e+00],
[-2.27943920e+00],
[-9.18621392e+00],
[-1.53010365e+00],
[ 1.47136927e+01],
[-4.40800552e+00],
[-1.57169672e+00],
[ 8.37167459e+00],
[ 4.77947274e+00],
[-8.25886108e+00],
[-1.49518675e+01],
[-4.28216486e+00],
[ 2.18318115e+00],
[ 5.98218635e+00],
[ 1.12626135e+01],
[ 1.24525045e+01]])
plt.plot(features[:, 0], labels, 'o')
plt.plot(features[:, 0], y_hat, 'r-')
[<matplotlib.lines.Line2D at 0x1c0697ad108>]
[<matplotlib.lines.Line2D at 0x1c0694b4408>]
从模型结果能够看出,模型和数据集分布规律相差较大
此外,我们稍微增加模型白噪声,测试线性回归模型效果
# 设置随机数种子
np.random.seed(24)
# 扰动项取值为2
features, labels = arrayGenReg(w=[2,1], delta=2)
features
array([[ 1.32921217, 1. ],
[-0.77003345, 1. ],
[-0.31628036, 1. ],
...,
[ 0.84682091, 1. ],
[ 1.73889649, 1. ],
[ 1.93991673, 1. ]])
plt.plot(features[:, 0], labels, 'o')
[<matplotlib.lines.Line2D at 0x1c0696b0ac8>]
np.linalg.lstsq(features, labels, rcond=-1)
(array([[1.91605821],
[0.90602215]]),
array([3767.12804359]),
2,
array([32.13187164, 31.53261742]))
w = np.linalg.lstsq(features, labels, rcond=-1)[0]
w
array([[1.91605821],
[0.90602215]])
SSELoss(features, w, labels)
array([[3767.12804359]])
X = np.linspace(-5, 5, 1000)
y = w[0] * X + w[1]
plt.plot(features[:, 0], labels, 'o')
plt.plot(X, y, 'r-')
[<matplotlib.lines.Line2D at 0x1c0694efc08>]
[<matplotlib.lines.Line2D at 0x1c069535ec8>]
能够发现,模型误差较大。
并且,除此之外,线性回归模型还面临这一个重大问题就是,如果特征矩阵的交叉乘积不可逆,则最小二乘法求解过程就不成立了。
当然,此时也代表着数据集存在着较为严重的多重共线性,换而言之就是数据集的特征矩阵可能可以相互线性表出。这个时候矩阵方程$X^TX\hat w=X^Ty$并不一定存在唯一解。解决该问题的方法有很多种,从数学角度出发,我们可以从以下三个方面入手:
通过数学过程可以证明,此时如果我们在原有损失函数基础上添加一个关于参数$w$的1-范数($||\hat w||_1$)或者2-范数($||\hat w||_2$)的某个计算结果,则可令最小二乘法条件得到满足。此时最小二乘法计算结果由无偏估计变为有偏估计。例如,当我们在损失函数中加上$\lambda ||\hat w||_2^2$(其中$\lambda$为参数)时,模型损失函数为:
经过数学转化,上述矩阵表达式对$\hat w$求导后令其为零,则可解出:
其中$I$为单位矩阵。此时由于$(X^TX+\lambda I)$肯定是可逆矩阵,因此可以顺利求解出$\hat w$:
该过程也被称为岭回归。而类似的,如果是通过添加了$\hat w$的1-范数的某个表达式,从而构造损失函数如下:
则该过程被称为Lasso。而更进一步,如果构建的损失函数同时包含$\hat w$的1-范数和2-范数,形如如下形式:
则构建的是弹性网模型(Elastic-Net),其中$\lambda、\alpha$都是参数,n是样本个数。不难发现,岭回归和Lasso其实都是弹性网的一种特殊形式。更多关于线性模型的相关方法,我们将在后续逐渐介绍。
1-范数也被称为L1范数,将参数的1-范数添加入损失函数的做法,也被称为损失函数的L1正则化,L2正则化也类似。在大多数情况下,添加正则化项也可称为添加惩罚函数$p(w)$,核心作用是缓解模型过拟合倾向。
对于线性回归模型来说,除了SSE以外,我们还可使用决定系数(R-square,也被称为拟合优度检验)作为其模型评估指标。决定系数的计算需要使用之前介绍的组间误差平方和和离差平方和的概念。在回归分析中,SSR表示聚类中类似的组间平方和概念,表意为Sum of squares of the regression,由预测数据与标签均值之间差值的平方和计算的出:
而SST(Total sum of squares)则是实际值和均值之间的差值的平方和计算得到:SST可直接计算,是数据集本身的特征,类似于方差
并且,$SST$可由$SSR+SSE$计算得出。而决定系数,则由$SSR$和$SST$共同决定:
很明显,决定系数是一个鉴于[0,1]之间的值,并且约趋近于1,模型拟合效果越好。我们可以通过如下过程,进行决定系数的计算:
sst = np.power(labels - labels.mean(), 2).sum()
sse = SSELoss(features, w, labels)
r = 1-(sse/sst)
r
array([[0.99998114]])