Lesson 3.3 Manual Implementation of Linear Regression and Model Limitations

  With the data generator in place, we can now run a hands-on linear regression modeling experiment.

In [2]:
# Scientific computing modules
import numpy as np
import pandas as pd

# Plotting modules
import matplotlib as mpl
import matplotlib.pyplot as plt

# Custom module
from ML_basic_function import *

I. Implementing Linear Regression by Hand

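  Next, we build a linear regression model by hand. The process follows the general machine learning modeling workflow and is implemented with the tools NumPy provides. The experiment should deepen our understanding of that workflow and give us more practice with the basic programming tools.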

1. Building a linear regression model following the machine learning workflow

  • Step 1. Data preparation

  First, we prepare the dataset. We use the data generator to create a dataset with a small noise term:

In [3]:
# Set the random seed
np.random.seed(24)

# Noise level set to 0.01
features, labels = arrayGenReg(delta=0.01)
In [4]:
features
Out[4]:
array([[ 1.32921217, -0.77003345,  1.        ],
       [-0.31628036, -0.99081039,  1.        ],
       [-1.07081626, -1.43871328,  1.        ],
       ...,
       [ 1.5507578 , -0.35986144,  1.        ],
       [-1.36267161, -0.61353562,  1.        ],
       [-1.44029131,  0.50439425,  1.        ]])

Here, features is also called the feature matrix, and labels is also called the label array.
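The arrayGenReg function comes from the course's ML_basic_function module and is not shown in this lesson. For readers without that module, the following is a minimal stand-in sketch, not the course implementation. The name make_regression_data, its signature, and the use of np.random.default_rng are all assumptions; the behavior (1000 standard-normal samples, a trailing column of ones, labels built from the features raised to deg plus Gaussian noise) is inferred from the outputs in this lesson, and it will not reproduce the exact numbers shown.

```python
import numpy as np

def make_regression_data(num_examples=1000, w=(2, -1, 1), delta=0.01, deg=1, seed=24):
    """Hypothetical stand-in for arrayGenReg (an assumption, not the course code)."""
    rng = np.random.default_rng(seed)
    w = np.asarray(w, dtype=float).reshape(-1, 1)
    num_inputs = w.shape[0] - 1                       # the last weight is the intercept
    x = rng.standard_normal((num_examples, num_inputs))
    features = np.hstack([x, np.ones((num_examples, 1))])   # append the column of ones
    labels_true = np.power(features, deg).dot(w)      # 1 ** deg == 1, so the intercept is kept
    labels = labels_true + rng.normal(scale=delta, size=labels_true.shape)
    return features, labels

features_demo, labels_demo = make_regression_data()
print(features_demo.shape, labels_demo.shape)
```

The last column of the returned feature matrix is constant 1, which is what lets the intercept be treated as just another weight in the steps that follow.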

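  • Step 2. Model selection

  Next, we choose a model for this regression problem. Here we use a multiple linear regression equation with an intercept term. The basic model is:

$$ f(x) = w_1x_1+w_2x_2+b $$

Let $\hat w = [w_1,w_2,b]^T$ and $\hat x = [x_1,x_2, 1]^T$; the equation above can then be written as

$ f(x) = \hat w^T\hat x $

Note: to build a model without an intercept term, let X be the original feature matrix (without the column of ones) and use it for modeling.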
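As a quick check that the augmented form really is the same model, the sketch below evaluates both expressions on a few random samples. The parameter values and variable names are illustrative assumptions only.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 2))           # five samples, two raw features
w1, w2, b = 2.0, -1.0, 1.0                # illustrative parameter values

# Explicit form: f(x) = w1*x1 + w2*x2 + b
f_explicit = w1 * x[:, 0] + w2 * x[:, 1] + b

# Augmented form: append a constant 1 to every sample and fold b into w_hat
x_hat = np.hstack([x, np.ones((x.shape[0], 1))])
w_hat = np.array([[w1], [w2], [b]])
f_augmented = x_hat.dot(w_hat).ravel()

print(np.allclose(f_explicit, f_augmented))
```

Folding the intercept into the weight vector is what allows a single matrix product, features.dot(w), to produce all predictions at once.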

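  • Step 3. Constructing the loss function

  For linear regression, the loss function can be built from the SSE, MSE, or RMSE computation. Because the model parameters so far exist only implicitly (they do not yet appear in the code), we start with a simple experiment: set $\hat w$ by hand and compute the SSE.

  • Let $\hat w$ be a set of random values and compute the SSE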
In [5]:
np.random.seed(24)
w = np.random.randn(3).reshape(-1, 1)
w
Out[5]:
array([[ 1.32921217],
       [-0.77003345],
       [-0.31628036]])

  The model's predictions are then:

In [6]:
y_hat = features.dot(w)
y_hat[:10]
Out[6]:
array([[ 2.04347616],
       [ 0.02627308],
       [-0.63176501],
       [ 0.20623364],
       [-2.64718921],
       [-0.86880796],
       [ 0.88171608],
       [-1.61055557],
       [ 0.80113619],
       [-0.49279524]])

  From these, using $SSE= ||y - X\hat w||_2^2 = (y - \hat y)^T(y - \hat y)$, the SSE is:

In [7]:
(labels - y_hat).T.dot(labels - y_hat)
Out[7]:
array([[2093.52940481]])
In [8]:
# Compute MSE
(labels - y_hat).T.dot(labels - y_hat) / len(labels)
Out[8]:
array([[2.0935294]])
In [9]:
labels[:10]
Out[9]:
array([[ 4.43811826],
       [ 1.375912  ],
       [ 0.30286597],
       [ 1.81970897],
       [-2.47783626],
       [ 0.47374318],
       [ 2.83085905],
       [-0.83695165],
       [ 2.84344069],
       [ 0.8176895 ]])

As can be seen, when $w$ is generated at random, the model's outputs are far from the true labels.

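To compute the SSE quickly later on, we wrap the calculation above in a function that takes the feature matrix, the parameter vector, and the label array and returns the SSE: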

In [10]:
def SSELoss(X, w, y):
    """
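    Compute the SSE.
    
    :param X: feature matrix of the input data
    :param w: parameters of the linear equation
    :param y: label array of the input data
    :return SSE: sum of squared errors between the predictions and the true labels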
    """
    y_hat = X.dot(w)
    SSE = (y - y_hat).T.dot(y - y_hat)
    return SSE
In [11]:
# Quick test of the function
SSELoss(features, w, labels)
Out[11]:
array([[2093.52940481]])

After the experiment, the SSELoss function above should be written into ML_basic_function.py.

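  • Step 4. Solving the loss function with least squares

  Next, we need to find the parameter values that minimize SSELoss, so that the model's predictions are as close as possible to the true labels. Here we use the least squares method introduced in Lesson 2, whose closed-form solution for the model parameters is:

$$\hat w = (X^TX)^{-1}X^Ty$$

  Note that this solution requires the cross product of the feature matrix to be invertible, that is, $X^TX$ must have an inverse. We can check this condition by computing its determinant: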

In [12]:
np.linalg.det(features.T.dot(features))
Out[12]:
967456500.1798325

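  The determinant is nonzero, so the inverse of $X^TX$ exists and the least squares solution can be computed. There are two ways to do so: apply NumPy's basic linear algebra routines to the formula above, or call the lstsq function directly.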
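The determinant check works here, but its magnitude depends on how the features are scaled, so a very small or very large value says little by itself. As a hedged alternative sketch (the data below is randomly generated for illustration and is not the lesson's dataset), the column rank and condition number of X give a scale-independent view of the same question, since $X^TX$ is invertible exactly when X has full column rank.

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.hstack([rng.standard_normal((1000, 2)), np.ones((1000, 1))])

# X^T X is invertible exactly when X has full column rank
rank_X = np.linalg.matrix_rank(X)
cond_X = np.linalg.cond(X)                 # large values signal near-singularity
print(rank_X == X.shape[1], cond_X)

# Rescaling the features changes the determinant by many orders of magnitude,
# although invertibility is unaffected
X_small = X * 1e-3
det_full = np.linalg.det(X.T.dot(X))
det_small = np.linalg.det(X_small.T.dot(X_small))
rank_small = np.linalg.matrix_rank(X_small)
print(det_full, det_small, rank_small)
```

For a 3-column matrix, multiplying X by 1e-3 multiplies the determinant of $X^TX$ by 1e-18, while the rank stays at 3.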

  • Solving with basic linear algebra
In [13]:
w = np.linalg.inv(features.T.dot(features)).dot(features.T).dot(labels)
w
Out[13]:
array([[ 1.99961892],
       [-0.99985281],
       [ 0.99970541]])

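  This gives the optimal parameters w. "Optimal" here means that any other parameter values yield a worse evaluation result than this set. First, we can compute the model's SSE at these parameters: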

In [14]:
SSELoss(features, w, labels)
Out[14]:
array([[0.09300731]])

This is far smaller than the SSE obtained with the random w used earlier. We can also compute the model's MSE:

In [59]:
SSELoss(features, w, labels) / len(labels)
Out[59]:
array([[9.30073138e-05]])

Of course, since the dataset itself was generated according to $y=2x_1-x_2+1$, the fitted parameters also show that the model predicts well.

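Among the evaluation metrics, SSE, MSE, and RMSE all reflect the same fact. We build the loss function on SSE only because its minimum is very convenient to derive; when SSE reaches its minimum, so do MSE and RMSE.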
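The relationship between the three metrics can be sketched as follows. MSE is SSE divided by the sample count and RMSE is the square root of MSE, so both are increasing functions of SSE and rank any two parameter vectors the same way. The helper names sse_loss, mse_loss, and rmse_loss and the generated data are assumptions made for this illustration.

```python
import numpy as np

def sse_loss(X, w, y):
    residual = y - X.dot(w)
    return float(np.sum(residual ** 2))

def mse_loss(X, w, y):
    return sse_loss(X, w, y) / len(y)

def rmse_loss(X, w, y):
    return float(np.sqrt(mse_loss(X, w, y)))

rng = np.random.default_rng(0)
X = np.hstack([rng.standard_normal((200, 2)), np.ones((200, 1))])
y = X.dot(np.array([[2.0], [-1.0], [1.0]])) + rng.normal(scale=0.01, size=(200, 1))

w_best = np.linalg.lstsq(X, y, rcond=None)[0]
w_other = w_best + 0.1                      # an arbitrary different parameter vector

# Each metric prefers the least squares solution over the shifted one
comparison = {loss.__name__: loss(X, w_best, y) < loss(X, w_other, y)
              for loss in (sse_loss, mse_loss, rmse_loss)}
print(comparison)
```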

  • Solving with NumPy's lstsq function

  We can also use the lstsq function to compute the least squares solution. The two results agree.

In [285]:
np.linalg.lstsq(features, labels, rcond=-1)
Out[285]:
(array([[ 1.99961892],
        [-0.99985281],
        [ 0.99970541]]),
 array([0.09300731]),
 3,
 array([32.70582436, 31.3166949 , 30.3678959 ]))

The final values of w are:

In [286]:
np.linalg.lstsq(features, labels, rcond=-1)[0]
Out[286]:
array([[ 1.99961892],
       [-0.99985281],
       [ 0.99970541]])

  This completes the full machine learning modeling workflow for linear regression.
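The whole workflow can be condensed into one short, self-contained sketch. It uses freshly generated data rather than arrayGenReg, so the numbers differ from the outputs above. It also solves the normal equations $X^TX\hat w=X^Ty$ with np.linalg.solve instead of forming the inverse explicitly, which is a substitution made here for numerical convenience and not the method used in the cells above. The final check illustrates the optimality claim from Step 4: random perturbations of the solution never lower the SSE.

```python
import numpy as np

rng = np.random.default_rng(24)
X = np.hstack([rng.standard_normal((1000, 2)), np.ones((1000, 1))])
y = X.dot(np.array([[2.0], [-1.0], [1.0]])) + rng.normal(scale=0.01, size=(1000, 1))

# Solve the normal equations X^T X w = X^T y without forming the inverse
w_solve = np.linalg.solve(X.T.dot(X), X.T.dot(y))
w_lstsq = np.linalg.lstsq(X, y, rcond=None)[0]
print(np.allclose(w_solve, w_lstsq))

def sse(w):
    r = y - X.dot(w)
    return float(np.sum(r ** 2))

# Perturb the solution 100 times and record the resulting SSE values
perturbed = [sse(w_solve + rng.normal(scale=0.05, size=(3, 1))) for _ in range(100)]
print(min(perturbed) >= sse(w_solve))
```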

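II. Limitations of the Linear Regression Model

  The modeling process above shows that the model performs well on a dataset with little white noise and a clear linear relationship. In practice, however, most datasets do not exhibit a clear linear relationship and contain a fair amount of white noise (data error). In such cases the performance of a multiple linear regression model suffers considerably.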

In [64]:
plt.plot(features[:, 0], labels, 'o')
Out[64]:
[<matplotlib.lines.Line2D at 0x1c069fca748>]
  • Nonlinear relationships

  For example, here we create a dataset that follows $y=2x^3+1$ with very little white noise and fit a model to it:

In [65]:
# Set the random seed
np.random.seed(24)

# Noise level set to 0.01
features, labels = arrayGenReg(w=[2,1], deg=3, delta=0.01)
In [66]:
features
Out[66]:
array([[ 1.32921217,  1.        ],
       [-0.77003345,  1.        ],
       [-0.31628036,  1.        ],
       ...,
       [ 0.84682091,  1.        ],
       [ 1.73889649,  1.        ],
       [ 1.93991673,  1.        ]])
In [38]:
plt.plot(features[:, 0], labels, 'o')
Out[38]:
[<matplotlib.lines.Line2D at 0x1c06a08f808>]

  We solve for the least squares parameters directly with the function NumPy provides:

In [67]:
np.linalg.lstsq(features, labels, rcond=-1)
Out[67]:
(array([[5.91925985],
        [0.96963333]]),
 array([28466.58711077]),
 2,
 array([32.13187164, 31.53261742]))
In [68]:
w = np.linalg.lstsq(features, labels, rcond=-1)[0]
In [69]:
w
Out[69]:
array([[5.91925985],
       [0.96963333]])
In [70]:
SSELoss(features, w, labels)
Out[70]:
array([[28466.58711077]])
In [73]:
y_hat = features.dot(w)
In [74]:
y_hat
Out[74]:
array([[ 8.83758557e+00],
       [-3.58839477e+00],
       [-9.02512307e-01],
       ...,
       [ 5.98218635e+00],
       [ 1.12626135e+01],
       [ 1.24525045e+01]])
In [75]:
plt.plot(features[:, 0], labels, 'o')
plt.plot(features[:, 0], y_hat, 'r-')
Out[75]:
[<matplotlib.lines.Line2D at 0x1c0697ad108>]
Out[75]:
[<matplotlib.lines.Line2D at 0x1c0694b4408>]

  The result shows that the fitted model deviates substantially from the distribution of the data.
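To put a number on that mismatch, the sketch below fits a straight line to freshly generated cubic data and compares the resulting SSE with the SSE that the noise alone would produce, roughly $n\delta^2$. The data generation is an assumption standing in for arrayGenReg, so the values differ from the cells above.

```python
import numpy as np

rng = np.random.default_rng(24)
n, delta = 1000, 0.01
x = rng.standard_normal((n, 1))
X = np.hstack([x, np.ones((n, 1))])
y = 2 * x ** 3 + 1 + rng.normal(scale=delta, size=(n, 1))

# Best straight line in the least squares sense
w_line = np.linalg.lstsq(X, y, rcond=None)[0]
sse_line = float(np.sum((y - X.dot(w_line)) ** 2))

# SSE that would remain if only the noise were left unexplained
noise_floor = n * delta ** 2

print(sse_line, noise_floor)
```

The gap of several orders of magnitude indicates that the error comes from the form of the model, not from the noise or from the solver.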

  • Increased noise

  In addition, we slightly increase the white noise in the data and test the linear regression model again:

In [45]:
# Set the random seed
np.random.seed(24)

# Noise level set to 2
features, labels = arrayGenReg(w=[2,1], delta=2)
In [46]:
features
Out[46]:
array([[ 1.32921217,  1.        ],
       [-0.77003345,  1.        ],
       [-0.31628036,  1.        ],
       ...,
       [ 0.84682091,  1.        ],
       [ 1.73889649,  1.        ],
       [ 1.93991673,  1.        ]])
In [47]:
plt.plot(features[:, 0], labels, 'o')
Out[47]:
[<matplotlib.lines.Line2D at 0x1c0696b0ac8>]
In [48]:
np.linalg.lstsq(features, labels, rcond=-1)
Out[48]:
(array([[1.91605821],
        [0.90602215]]),
 array([3767.12804359]),
 2,
 array([32.13187164, 31.53261742]))
In [49]:
w = np.linalg.lstsq(features, labels, rcond=-1)[0]
In [50]:
w
Out[50]:
array([[1.91605821],
       [0.90602215]])
In [51]:
SSELoss(features, w, labels)
Out[51]:
array([[3767.12804359]])
In [52]:
X = np.linspace(-5, 5, 1000)
y = w[0] * X + w[1]
In [53]:
plt.plot(features[:, 0], labels, 'o')
plt.plot(X, y, 'r-')
Out[53]:
[<matplotlib.lines.Line2D at 0x1c0694efc08>]
Out[53]:
[<matplotlib.lines.Line2D at 0x1c069535ec8>]

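As can be seen, the model error is large (SSE ≈ 3767), although the fitted parameters stay close to the generating values 2 and 1.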
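The size of that error is close to what the noise alone predicts: with 1000 samples and a noise standard deviation of 2, an SSE near $1000 \times 2^2 = 4000$ is expected even for the best possible line. The sketch below, using generated data as an assumed stand-in for arrayGenReg, shows the MSE of the least squares fit tracking $\delta^2$ across several noise levels.

```python
import numpy as np

rng = np.random.default_rng(24)
n = 1000
x = rng.standard_normal((n, 1))
X = np.hstack([x, np.ones((n, 1))])

results = {}
for delta in (0.01, 0.5, 2.0):
    y = 2 * x + 1 + rng.normal(scale=delta, size=(n, 1))
    w_fit = np.linalg.lstsq(X, y, rcond=None)[0]
    results[delta] = float(np.mean((y - X.dot(w_fit)) ** 2))
    print(delta, results[delta], delta ** 2)
```

In other words, this part of the error is irreducible for any model fitted to the same data; more noise raises the floor without implying that the linear form is wrong.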

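  • Conditions required by least squares

  Beyond this, linear regression faces another major problem: if the cross product of the feature matrix is not invertible, the least squares solution below no longer holds.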

$$\hat w = (X^TX)^{-1}X^Ty$$
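A small constructed example makes the failure concrete. In the sketch below, the second feature is exactly twice the first, so X has rank 2 with 3 columns and $X^TX$ is singular. The data and the explicit rcond cutoffs are assumptions chosen for the illustration. Both np.linalg.lstsq and np.linalg.pinv still return a least squares solution (the one with the smallest norm), and the last check shows that other parameter vectors give identical predictions, which is why no unique solution exists.

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.standard_normal((100, 1))
X = np.hstack([x1, 2 * x1, np.ones((100, 1))])    # second column is exactly 2 * first
y = 3 * x1 + 1 + rng.normal(scale=0.01, size=(100, 1))

rank_X = np.linalg.matrix_rank(X)
print(rank_X, X.shape[1])                          # rank is below the column count

# Both routines return the minimum-norm least squares solution
w_lstsq = np.linalg.lstsq(X, y, rcond=1e-10)[0]
w_pinv = np.linalg.pinv(X, rcond=1e-10).dot(y)
print(np.allclose(w_lstsq, w_pinv))

# Moving weight along the direction [2, -1, 0] leaves every prediction unchanged
w_alt = w_lstsq + np.array([[2.0], [-1.0], [0.0]])
print(np.allclose(X.dot(w_lstsq), X.dot(w_alt)))
```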

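This situation also indicates severe multicollinearity in the dataset; in other words, some columns of the feature matrix can be expressed as linear combinations of the others. The matrix equation $X^TX\hat w=X^Ty$ then does not necessarily have a unique solution. There are many ways to address the problem. From a mathematical point of view, we can approach it from three directions:

  • First, reduce the dimensionality of the data:
      We can apply SVD or PCA (principal component analysis) to the dataset. Both perform an orthogonal transformation, so the columns of the resulting dataset are uncorrelated. This does change the structure of the dataset, however, and the resulting features are no longer interpretable.
  • Second, change the method used to solve the loss function:
      We can try to compute a generalized inverse for the original equation; for some matrix equations, the generalized inverse also yields an approximately optimal solution. Alternatively, we can use other optimization methods, such as gradient descent;
  • Third, modify the loss function:
      We can modify the original loss function so that it satisfies the conditions for a least squares solution. If $X^TX$ is not invertible, we can add a regularization term to the loss function to make it solvable. Following the derivation in Lesson 2, the loss function currently built on SSE is: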
$$SSELoss(\hat w) = ||y - X\hat w||_2^2 = (y - X\hat w)^T(y - X\hat w)$$

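It can be shown mathematically that adding to the original loss function a term based on the 1-norm ($||\hat w||_1$) or 2-norm ($||\hat w||_2$) of the parameters $\hat w$ allows the solvability condition to be met. The resulting estimate changes from an unbiased estimate to a biased one. For example, when we add $\lambda ||\hat w||_2^2$ (where $\lambda$ is a parameter) to the loss function, the model's loss function becomes:

$$Loss(\hat w) = ||y - X\hat w||_2^2 +\lambda ||\hat w||_2^2$$

Differentiating this expression with respect to $\hat w$ and setting the derivative to zero gives:

$$(X^TX+\lambda I) \hat w = X^Ty$$

where $I$ is the identity matrix. Because $X^TX$ is positive semidefinite, $(X^TX+\lambda I)$ is invertible for any $\lambda > 0$, so $\hat w$ can be solved for directly:

$$\hat w = (X^TX+\lambda I)^{-1}X^Ty$$

This procedure is known as ridge regression. Similarly, if we instead add an expression based on the 1-norm of $\hat w$, giving the loss function

$$Loss(\hat w) = ||y - X\hat w||_2^2 +\lambda ||\hat w||_1$$

the procedure is known as Lasso. Going one step further, if the loss function contains both the 1-norm and the 2-norm of $\hat w$, in the form

$$Loss(\hat w) = \frac{1}{2n}||y - X\hat w||_2^2 + \lambda \alpha ||\hat w||_1 +\frac{\lambda(1-\alpha)}{2} ||\hat w||_2 ^ 2$$

we obtain the Elastic-Net model, where $\lambda$ and $\alpha$ are parameters and n is the number of samples. It is easy to see that ridge regression and Lasso are both special cases of Elastic-Net. More methods related to linear models will be introduced in later lessons.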

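The 1-norm is also called the L1 norm, and adding the 1-norm of the parameters to the loss function is known as L1 regularization of the loss function; L2 regularization is analogous. In most cases, adding a regularization term can also be described as adding a penalty function $p(w)$, whose main purpose is to reduce the model's tendency to overfit.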
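The ridge closed form $\hat w = (X^TX+\lambda I)^{-1}X^Ty$ can be sketched in a few lines. The function name ridge_fit, the value of lam, and the collinear test data are assumptions for illustration. As in the formula, the penalty here applies to every component of $\hat w$, including the intercept; library implementations often leave the intercept unpenalized.

```python
import numpy as np

def ridge_fit(X, y, lam):
    """Closed-form ridge solution: solve (X^T X + lam * I) w = X^T y."""
    n_features = X.shape[1]
    return np.linalg.solve(X.T.dot(X) + lam * np.eye(n_features), X.T.dot(y))

rng = np.random.default_rng(0)
x1 = rng.standard_normal((100, 1))
X = np.hstack([x1, 2 * x1, np.ones((100, 1))])    # collinear columns: X^T X is singular
y = 3 * x1 + 1 + rng.normal(scale=0.01, size=(100, 1))

w_ridge = ridge_fit(X, y, lam=1e-3)
max_abs_residual = float(np.max(np.abs(y - X.dot(w_ridge))))
print(w_ridge.ravel(), max_abs_residual)
```

Even though plain least squares has no unique solution for this matrix, the regularized system is well defined, and a small lam leaves the quality of the fit essentially unchanged.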

III. The Coefficient of Determination for Linear Regression

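  For linear regression models, besides SSE we can also use the coefficient of determination (R-square, a goodness-of-fit measure) as an evaluation metric. Computing it relies on the concepts of the between-group sum of squares and the total sum of squared deviations introduced earlier. In regression analysis, SSR (sum of squares of the regression) plays a role similar to the between-group sum of squares in clustering; it is the sum of squared differences between the predictions and the mean of the labels:

$$SSR =\sum^{n}_{i=1}(\bar{y}-\hat{y}_i)^2$$

SST (total sum of squares) is the sum of squared differences between the actual values and their mean. It can be computed directly from the data and is a property of the dataset itself, similar to the variance:

$$SST =\sum^{n}_{i=1}(\bar{y}-y_i)^2$$

Moreover, for a least squares fit with an intercept term, $SST = SSR+SSE$. The coefficient of determination is determined jointly by $SSR$ and $SST$:

$$R\text{-}square=\frac{SSR}{SST}=\frac{SST-SSE}{SST}=1-\frac{SSE}{SST}$$

For such a fit, the coefficient of determination lies in [0,1], and the closer it is to 1, the better the model fits. It can be computed as follows. As the cell numbers In [19] to In [21] indicate, these cells use the first dataset (delta=0.01) and its least squares solution w, before features, labels, and w were reassigned in Part II: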

In [19]:
sst = np.power(labels - labels.mean(), 2).sum()
sse = SSELoss(features, w, labels)
In [20]:
r = 1-(sse/sst)
In [21]:
r
Out[21]:
array([[0.99998114]])
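The calculation can be wrapped in a small helper, and the decomposition $SST = SSR + SSE$ can be checked numerically at the same time. The function name r_square and the generated data are assumptions for this sketch; the decomposition is verified here for a least squares fit that includes an intercept column.

```python
import numpy as np

def r_square(X, w, y):
    """Coefficient of determination computed as 1 - SSE / SST."""
    y_hat = X.dot(w)
    sse = np.sum((y - y_hat) ** 2)
    sst = np.sum((y - y.mean()) ** 2)
    return float(1 - sse / sst)

rng = np.random.default_rng(24)
X = np.hstack([rng.standard_normal((1000, 2)), np.ones((1000, 1))])
y = X.dot(np.array([[2.0], [-1.0], [1.0]])) + rng.normal(scale=0.01, size=(1000, 1))
w_fit = np.linalg.lstsq(X, y, rcond=None)[0]

y_hat = X.dot(w_fit)
sse = np.sum((y - y_hat) ** 2)
ssr = np.sum((y_hat - y.mean()) ** 2)
sst = np.sum((y - y.mean()) ** 2)

print(np.isclose(sst, ssr + sse))     # decomposition for least squares with an intercept
print(r_square(X, w_fit, y))
```

Because the helper returns a plain float, it avoids the 1x1 array that results from dividing the output of SSELoss directly.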