In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
In [2]:
pip list
Package                            Version            
---------------------------------- -------------------
alabaster                          0.7.12             
anaconda-client                    1.7.2              
anaconda-navigator                 1.9.12             
anaconda-project                   0.8.3              
argh                               0.26.2             
asn1crypto                         1.3.0              
astroid                            2.3.3              
astropy                            4.0                
atomicwrites                       1.3.0              
attrs                              19.3.0             
autopep8                           1.4.4              
Babel                              2.8.0              
backcall                           0.1.0              
backports.functools-lru-cache      1.6.1              
backports.shutil-get-terminal-size 1.0.0              
backports.tempfile                 1.0                
backports.weakref                  1.0.post1          
bcrypt                             3.1.7              
beautifulsoup4                     4.8.2              
bitarray                           1.2.1              
bkcharts                           0.2                
bleach                             3.1.0              
bokeh                              1.4.0              
boto                               2.49.0             
Bottleneck                         1.3.2              
certifi                            2019.11.28         
cffi                               1.14.0             
chardet                            3.0.4              
chart-studio                       1.1.0              
Click                              7.0                
cloudpickle                        1.3.0              
clyent                             1.2.2              
colorama                           0.4.3              
colorlover                         0.3.0              
comtypes                           1.1.7              
conda                              4.8.2              
conda-build                        3.18.11            
conda-package-handling             1.6.0              
conda-verify                       3.4.2              
contextlib2                        0.6.0.post1        
cryptography                       2.8                
cycler                             0.10.0             
Cython                             0.29.15            
cytoolz                            0.10.1             
dask                               2.11.0             
decorator                          4.4.1              
defusedxml                         0.6.0              
diff-match-patch                   20181111           
distributed                        2.11.0             
docutils                           0.16               
entrypoints                        0.3                
et-xmlfile                         1.0.1              
fastcache                          1.1.0              
filelock                           3.0.12             
flake8                             3.7.9              
Flask                              1.1.1              
fsspec                             0.6.2              
future                             0.18.2             
gevent                             1.4.0              
glob2                              0.7                
graphviz                           0.19               
greenlet                           0.4.15             
h5py                               2.10.0             
HeapDict                           1.0.1              
html5lib                           1.0.1              
hypothesis                         5.5.4              
idna                               2.8                
imageio                            2.6.1              
imagesize                          1.2.0              
importlib-metadata                 1.5.0              
intervaltree                       3.0.2              
ipykernel                          5.1.4              
ipython                            7.12.0             
ipython-genutils                   0.2.0              
ipywidgets                         7.5.1              
isort                              4.3.21             
itsdangerous                       1.1.0              
jdcal                              1.4.1              
jedi                               0.14.1             
Jinja2                             2.11.1             
joblib                             0.14.1             
json5                              0.9.1              
jsonschema                         3.2.0              
jupyter                            1.0.0              
jupyter-client                     5.3.4              
jupyter-console                    6.1.0              
jupyter-core                       4.6.1              
jupyterlab                         1.2.6              
jupyterlab-server                  1.0.6              
keyring                            21.1.0             
kiwisolver                         1.1.0              
lazy-object-proxy                  1.4.3              
libarchive-c                       2.8                
llvmlite                           0.31.0             
locket                             0.2.0              
lxml                               4.5.0              
MarkupSafe                         1.1.1              
matplotlib                         3.1.3              
mccabe                             0.6.1              
menuinst                           1.4.16             
mistune                            0.8.4              
mkl-fft                            1.0.15             
mkl-random                         1.1.0              
mkl-service                        2.3.0              
mlxtend                            0.19.0             
mock                               4.0.1              
more-itertools                     8.2.0              
mpmath                             1.1.0              
msgpack                            0.6.1              
multipledispatch                   0.6.0              
navigator-updater                  0.2.1              
nbconvert                          5.6.1              
nbformat                           5.0.4              
nbmerge                            0.0.4              
networkx                           2.4                
nltk                               3.4.5              
nose                               1.3.7              
notebook                           6.0.3              
numba                              0.48.0             
numexpr                            2.7.1              
numpy                              1.18.1             
numpydoc                           0.9.2              
olefile                            0.46               
openpyxl                           3.0.3              
packaging                          20.1               
pandas                             1.0.1              
pandasql                           0.7.3              
pandocfilters                      1.4.2              
paramiko                           2.7.1              
parso                              0.5.2              
partd                              1.1.0              
path                               13.1.0             
pathlib2                           2.3.5              
pathtools                          0.1.2              
patsy                              0.5.1              
pep8                               1.7.1              
pexpect                            4.8.0              
pickleshare                        0.7.5              
Pillow                             7.0.0              
pip                                20.0.2             
pkginfo                            1.5.0.1            
plotly                             5.5.0              
pluggy                             0.13.1             
ply                                3.11               
prometheus-client                  0.7.1              
prompt-toolkit                     3.0.3              
psutil                             5.6.7              
py                                 1.8.1              
pycodestyle                        2.5.0              
pycosat                            0.6.3              
pycparser                          2.19               
pycrypto                           2.6.1              
pycurl                             7.43.0.5           
pydocstyle                         4.0.1              
pyflakes                           2.1.1              
Pygments                           2.5.2              
pylint                             2.4.4              
PyNaCl                             1.3.0              
pyodbc                             4.0.0-unsupported  
pyOpenSSL                          19.1.0             
pyparsing                          2.4.6              
pyreadline                         2.1                
pyrsistent                         0.15.7             
PySocks                            1.7.1              
pytest                             5.3.5              
pytest-arraydiff                   0.3                
pytest-astropy                     0.8.0              
pytest-astropy-header              0.1.2              
pytest-doctestplus                 0.5.0              
pytest-openfiles                   0.4.0              
pytest-remotedata                  0.3.2              
python-dateutil                    2.8.1              
python-jsonrpc-server              0.3.4              
python-language-server             0.31.7             
pytz                               2019.3             
PyWavelets                         1.1.1              
pywin32                            227                
pywin32-ctypes                     0.2.0              
pywinpty                           0.5.7              
PyYAML                             5.3                
pyzmq                              18.1.1             
QDarkStyle                         2.8                
QtAwesome                          0.6.1              
qtconsole                          4.6.0              
QtPy                               1.9.0              
requests                           2.22.0             
retrying                           1.3.3              
rope                               0.16.0             
Rtree                              0.9.3              
ruamel-yaml                        0.15.87            
scikit-image                       0.16.2             
scikit-learn                       0.22.1             
scipy                              1.4.1              
seaborn                            0.10.0             
Send2Trash                         1.5.0              
setuptools                         45.2.0.post20200210
simplegeneric                      0.8.1              
singledispatch                     3.4.0.3            
six                                1.14.0             
snowballstemmer                    2.0.0              
sortedcollections                  1.1.2              
sortedcontainers                   2.1.0              
soupsieve                          1.9.5              
Sphinx                             2.4.0              
sphinxcontrib-applehelp            1.0.1              
sphinxcontrib-devhelp              1.0.1              
sphinxcontrib-htmlhelp             1.0.2              
sphinxcontrib-jsmath               1.0.1              
sphinxcontrib-qthelp               1.0.2              
sphinxcontrib-serializinghtml      1.1.3              
sphinxcontrib-websupport           1.2.0              
spyder                             4.0.1              
spyder-kernels                     1.8.1              
SQLAlchemy                         1.3.13             
statsmodels                        0.11.0             
sympy                              1.5.1              
tables                             3.6.1              
tblib                              1.6.0              
tenacity                           8.0.1              
terminado                          0.8.3              
testpath                           0.4.4              
toolz                              0.10.0             
tornado                            6.0.3              
tqdm                               4.42.1             
traitlets                          4.3.3              
ujson                              1.35               
unicodecsv                         0.14.1             
urllib3                            1.25.8             
watchdog                           0.10.2             
wcwidth                            0.1.8              
webencodings                       0.5.1              
Werkzeug                           1.0.0              
wheel                              0.34.2             
widgetsnbextension                 3.5.1              
win-inet-pton                      1.1.0              
win-unicode-console                0.5                
wincertstore                       0.2                
wrapt                              1.11.2             
xlrd                               1.2.0              
XlsxWriter                         1.2.7              
xlwings                            0.17.1             
xlwt                               1.3.0              
xmltodict                          0.12.0             
yapf                               0.28.0             
zict                               1.0.0              
zipp                               2.2.0              
Note: you may need to restart the kernel to use updated packages.

辅助函数1:切分数据集函数

In [3]:
def binSplitDataSet(dataSet, feature, value):
    """
    函数说明:根据特征切分数据集合
    Parameters:
        dataSet - DataFrame,数据集
        feature - int,带切分的特征
        value   - int,该特征的值
    Returns:
        mat0 - 切分的数据集合0
        mat1 - 切分的数据集合1
    """
    mat0 = dataSet.loc[dataSet.iloc[:,feature] > value,:]
    mat1 = dataSet.loc[dataSet.iloc[:,feature] <= value,:]
    return mat0, mat1
In [4]:
data = pd.read_csv('split_data.csv', header=None)
data
Out[4]:
0 1 2
0 0 1 1
1 0 2 0
2 1 0 1
3 1 1 1
4 1 2 1
5 2 0 0
6 2 1 1
7 2 2 0
In [5]:
mat0, mat1 = binSplitDataSet(data, 0, 1) # 以第0列切分,切分值为1。
In [6]:
mat0 # 大于1的划分到mat0
Out[6]:
0 1 2
5 2 0 0
6 2 1 1
7 2 2 0
In [7]:
mat1 # <=1的划分到mat1
Out[7]:
0 1 2
0 0 1 1
1 0 2 0
2 1 0 1
3 1 1 1
4 1 2 1

辅助函数2:计算总误差函数

In [8]:
def errType(dataSet):
    var = dataSet.iloc[:,-1].var() * dataSet.shape[0] # .var计算均方误差,然后乘以样本数算出总误差。
    return var
In [9]:
errType(data)
Out[9]:
2.142857142857143

辅助函数3:叶子节点存放均值

In [10]:
def leafType(dataSet):
    leaf = dataSet.iloc[:,-1].mean()
    return leaf
In [11]:
ex00 = pd.read_csv('ex00.csv', header=None)
ex00
Out[11]:
0 1
0 0.036098 0.155096
1 0.993349 1.077553
2 0.530897 0.893462
3 0.712386 0.564858
4 0.343554 -0.371700
... ... ...
195 0.552381 1.369630
196 0.683886 0.999985
197 0.210334 -0.006899
198 0.604529 1.212685
199 0.250744 0.046297

200 rows × 2 columns

In [12]:
leafType(ex00)
Out[12]:
0.5717430049999996

辅助函数4:找到数据的最佳二元切分方式函数

In [13]:
def chooseBestSplit(dataSet, leafType = leafType, errType = errType, ops = (1,4)):
    """
    函数说明:找到数据的最佳二元切分方式函数
    Parameters:
        dataSet  - DataFrame,数据集合
        leafType - function,生成叶结点函数
        regErr   - function,误差估计函数
        ops      - tuple,用户定义的参数构成的元组
       Returns:
       tuple,
        bestIndex - 最佳切分特征
        bestValue - 最佳特征值
    """
    import types
    tolS = ops[0] # tolS允许的误差下降值
    tolN = ops[1] # tolN切分的最少样本数
    
    #如果当前所有值相等,则退出。(根据set的特性)
    if len(set(dataSet.iloc[:,-1].values)) == 1:
        return None, leafType(dataSet)
    #统计数据集合的行m和列n
    m, n = np.shape(dataSet)
    #默认最后一个特征为最佳切分特征,计算其误差估计
    S = errType(dataSet)
   
    bestS = float('inf') # 分别为最佳误差
    bestIndex = 0        # 最佳特征切分的索引值
    bestValue = 0        # 最佳特征值

    #遍历所有特征列
    for featIndex in range(n - 1):
        #遍历所有特征值
        for splitVal in set(dataSet.iloc[:,featIndex].values):
            #根据特征和特征值切分数据集
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            #如果数据少于tolN,则退出
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
            #计算误差估计
            newS = errType(mat0) + errType(mat1)
            #如果误差估计更小,则更新特征索引值和特征值
            if newS < bestS: 
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #如果误差减少不大则退出
    if (S - bestS) < tolS: 
        return None, leafType(dataSet)
    #根据最佳的切分特征和特征值切分数据集合
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    #如果切分出的数据集很小则退出
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)    
    #返回最佳切分特征和特征值
    return bestIndex, bestValue
In [14]:
chooseBestSplit(ex00)
Out[14]:
(0, 0.48813)

主函数:创建树

找到最佳的带切分特征:

如果该节点不可再分,将该节点保存围殴叶子节点。
执行切分。
在右子树递归调用createTree()函数。
在左子树递归调用createTree()函数。

In [15]:
def createTree(dataSet, leafType = leafType, errType = errType, ops = (1, 4)):
    """
    函数说明:树构建函数
    Parameters:
        dataSet  - DataFrame,数据集合
        leafType - function,建立叶结点的函数
        errType  - function,误差计算函数
        ops      - tuple,包含树构建所有其他参数的元组
    Returns:
        retTree  - dict,构建的回归树
    """
    
    # 最佳切分特征和特征值
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # 如果没有特征,则返回特征值
    if feat == None: return val
    
    # 空字典存放最后的回归树
    retTree = {}
    retTree['spInd'] = feat # 列索引
    retTree['spVal'] = val  # 特征值
    
    # 分成左数据集和右数据集
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    
    # 创建左子树和右子树
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree 
In [16]:
createTree(ex00) 
# 因为只有一列特征,所以只分一次就结束了。
Out[16]:
{'spInd': 0,
 'spVal': 0.48813,
 'left': 1.018096767241379,
 'right': -0.04465028571428573}

分析一维度的一个例子

In [17]:
ex0 = pd.read_table('ex0.txt', header=None)
ex0
Out[17]:
0 1 2
0 1.0 0.409175 1.883180
1 1.0 0.182603 0.063908
2 1.0 0.663687 3.042257
3 1.0 0.517395 2.305004
4 1.0 0.013643 -0.067698
... ... ... ...
195 1.0 0.321632 0.896855
196 1.0 0.845148 4.220850
197 1.0 0.012003 -0.217283
198 1.0 0.018883 -0.300577
199 1.0 0.071476 0.006014

200 rows × 3 columns

In [18]:
ex0.iloc[:, 0].value_counts()
# 第一列全是1.0
Out[18]:
1.0    200
Name: 0, dtype: int64
In [19]:
ex0.shape
Out[19]:
(200, 3)
In [20]:
# 不要第一列的数据,只要后面的两列。
# 要注意这个案例是一维度的,看似后面的两列像一组点坐标,其实第二列是特征,第三列是标签(并不是第二个维度!)
plt.scatter(ex0.iloc[:,1].values, ex0.iloc[:,2].values)
Out[20]:
<matplotlib.collections.PathCollection at 0x1820bd53f08>
In [21]:
tree = createTree(ex0)
tree
Out[21]:
{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': 3.9871632000000004,
   'right': 2.9836209534883724},
  'right': 1.9800350714285717},
 'right': {'spInd': 1,
  'spVal': 0.19783399999999998,
  'left': 1.0289583666666666,
  'right': -0.02383815555555556}}

image.png

上图中共有5个叶子节点,如何使用这些节点呢?
可以先画出原始数据的图像:

In [22]:
# 可视化
plt.scatter(ex0.iloc[:,1].values, ex0.iloc[:,2].values)
Out[22]:
<matplotlib.collections.PathCollection at 0x1820be0a848>

然后划出一条回归线,如何画出回归线?
答:先随机生成 0 - 1 之间的100个数据作为x点坐标,然后带入到上面的回归树中得到对应的y值。——得到x和y后即可绘出一条直线。
下面我们用sklearn库中的回归树对上述的案例绘制出树和回归线。

使用sklearn训练模型

In [23]:
from sklearn.tree import DecisionTreeRegressor
In [24]:
x = (ex0.iloc[:,1].values).reshape(-1, 1)
y = (ex0.iloc[:,-1].values).reshape(-1, 1)
In [25]:
model1 = DecisionTreeRegressor(max_depth=3)
model1.fit(x, y)
Out[25]:
DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse', max_depth=3,
                      max_features=None, max_leaf_nodes=None,
                      min_impurity_decrease=0.0, min_impurity_split=None,
                      min_samples_leaf=1, min_samples_split=2,
                      min_weight_fraction_leaf=0.0, presort='deprecated',
                      random_state=None, splitter='best')
In [26]:
# 画出决策树
import graphviz # http://www.graphviz.org/
from sklearn import tree

dot_data = tree.export_graphviz(model1, 
                                out_file = None, 
                                feature_names = ['feature_names'],
                                class_names = ['no','yes'],
                                filled = True,
                                rounded = True,
                                special_characters = True)
graph = graphviz.Source(dot_data)
graph.render('cart')
Out[26]:
'cart.pdf'
In [27]:
graph
Out[27]:
Tree 0 feature_names ≤ 0.397 mse = 2.1 samples = 200 value = 2.004 1 feature_names ≤ 0.203 mse = 0.3 samples = 75 value = 0.397 0->1 True 8 feature_names ≤ 0.596 mse = 0.702 samples = 125 value = 2.968 0->8 False 2 feature_names ≤ 0.15 mse = 0.028 samples = 45 value = -0.024 1->2 5 feature_names ≤ 0.213 mse = 0.045 samples = 30 value = 1.029 1->5 3 mse = 0.031 samples = 34 value = -0.055 2->3 4 mse = 0.004 samples = 11 value = 0.072 2->4 6 mse = 0.013 samples = 2 value = 1.375 5->6 7 mse = 0.038 samples = 28 value = 1.004 5->7 9 feature_names ≤ 0.487 mse = 0.035 samples = 42 value = 1.98 8->9 12 feature_names ≤ 0.807 mse = 0.296 samples = 83 value = 3.467 8->12 10 mse = 0.018 samples = 16 value = 1.881 9->10 11 mse = 0.036 samples = 26 value = 2.041 9->11 13 mse = 0.036 samples = 43 value = 2.984 12->13 14 mse = 0.054 samples = 40 value = 3.987 12->14
In [28]:
# 生成 0 - 1 之间100个数据,用于后续可视化画线。
X_test = np.arange(0, 1, 0.01)[:, np.newaxis] # 人工生成的数据点,作为x点坐标
y_test = model1.predict(X_test)               # 输入测试数据,输出预测结果。得到y点坐标
In [29]:
X_test.tolist()
Out[29]:
[[0.0],
 [0.01],
 [0.02],
 [0.03],
 [0.04],
 [0.05],
 [0.06],
 [0.07],
 [0.08],
 [0.09],
 [0.1],
 [0.11],
 [0.12],
 [0.13],
 [0.14],
 [0.15],
 [0.16],
 [0.17],
 [0.18],
 [0.19],
 [0.2],
 [0.21],
 [0.22],
 [0.23],
 [0.24],
 [0.25],
 [0.26],
 [0.27],
 [0.28],
 [0.29],
 [0.3],
 [0.31],
 [0.32],
 [0.33],
 [0.34],
 [0.35000000000000003],
 [0.36],
 [0.37],
 [0.38],
 [0.39],
 [0.4],
 [0.41000000000000003],
 [0.42],
 [0.43],
 [0.44],
 [0.45],
 [0.46],
 [0.47000000000000003],
 [0.48],
 [0.49],
 [0.5],
 [0.51],
 [0.52],
 [0.53],
 [0.54],
 [0.55],
 [0.56],
 [0.5700000000000001],
 [0.58],
 [0.59],
 [0.6],
 [0.61],
 [0.62],
 [0.63],
 [0.64],
 [0.65],
 [0.66],
 [0.67],
 [0.68],
 [0.6900000000000001],
 [0.7000000000000001],
 [0.71],
 [0.72],
 [0.73],
 [0.74],
 [0.75],
 [0.76],
 [0.77],
 [0.78],
 [0.79],
 [0.8],
 [0.81],
 [0.8200000000000001],
 [0.8300000000000001],
 [0.84],
 [0.85],
 [0.86],
 [0.87],
 [0.88],
 [0.89],
 [0.9],
 [0.91],
 [0.92],
 [0.93],
 [0.9400000000000001],
 [0.9500000000000001],
 [0.96],
 [0.97],
 [0.98],
 [0.99]]
In [30]:
y_test.tolist()
Out[30]:
[-0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 -0.054810500000000005,
 0.07189454545454545,
 0.07189454545454545,
 0.07189454545454545,
 0.07189454545454545,
 0.07189454545454545,
 0.07189454545454545,
 1.3753635,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.0042151428571429,
 1.88108975,
 1.88108975,
 1.88108975,
 1.88108975,
 1.88108975,
 1.88108975,
 1.88108975,
 1.88108975,
 1.88108975,
 2.0409245,
 2.0409245,
 2.0409245,
 2.0409245,
 2.0409245,
 2.0409245,
 2.0409245,
 2.0409245,
 2.0409245,
 2.0409245,
 2.0409245,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 2.9836209534883724,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999,
 3.987163199999999]
In [31]:
y_test_df = pd.DataFrame(y_test)
y_test_df
y_test_df[0].value_counts()
# 共计8种结果对应上图八个叶子节点。
Out[31]:
 2.983621    21
 3.987163    19
 1.004215    18
-0.054811    15
 2.040925    11
 1.881090     9
 0.071895     6
 1.375363     1
Name: 0, dtype: int64
In [32]:
# 可视化
plt.figure()
plt.scatter(x, y, s=20, edgecolors="black", c="darkorange", label="data")
plt.plot(X_test, y_test, color="black", label="max_depth=3", linewidth=2)
plt.legend()
plt.show()
In [ ]: