import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
pip list
def binSplitDataSet(dataSet, feature, value):
"""
函数说明:根据特征切分数据集合
Parameters:
dataSet - DataFrame,数据集
feature - int,带切分的特征
value - int,该特征的值
Returns:
mat0 - 切分的数据集合0
mat1 - 切分的数据集合1
"""
mat0 = dataSet.loc[dataSet.iloc[:,feature] > value,:]
mat1 = dataSet.loc[dataSet.iloc[:,feature] <= value,:]
return mat0, mat1
data = pd.read_csv('split_data.csv', header=None)
data
mat0, mat1 = binSplitDataSet(data, 0, 1) # 以第0列切分,切分值为1。
mat0 # 大于1的划分到mat0
mat1 # <=1的划分到mat1
def errType(dataSet):
var = dataSet.iloc[:,-1].var() * dataSet.shape[0] # .var计算均方误差,然后乘以样本数算出总误差。
return var
errType(data)
def leafType(dataSet):
leaf = dataSet.iloc[:,-1].mean()
return leaf
ex00 = pd.read_csv('ex00.csv', header=None)
ex00
leafType(ex00)
def chooseBestSplit(dataSet, leafType = leafType, errType = errType, ops = (1,4)):
"""
函数说明:找到数据的最佳二元切分方式函数
Parameters:
dataSet - DataFrame,数据集合
leafType - function,生成叶结点函数
regErr - function,误差估计函数
ops - tuple,用户定义的参数构成的元组
Returns:
tuple,
bestIndex - 最佳切分特征
bestValue - 最佳特征值
"""
import types
tolS = ops[0] # tolS允许的误差下降值
tolN = ops[1] # tolN切分的最少样本数
#如果当前所有值相等,则退出。(根据set的特性)
if len(set(dataSet.iloc[:,-1].values)) == 1:
return None, leafType(dataSet)
#统计数据集合的行m和列n
m, n = np.shape(dataSet)
#默认最后一个特征为最佳切分特征,计算其误差估计
S = errType(dataSet)
bestS = float('inf') # 分别为最佳误差
bestIndex = 0 # 最佳特征切分的索引值
bestValue = 0 # 最佳特征值
#遍历所有特征列
for featIndex in range(n - 1):
#遍历所有特征值
for splitVal in set(dataSet.iloc[:,featIndex].values):
#根据特征和特征值切分数据集
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
#如果数据少于tolN,则退出
if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
#计算误差估计
newS = errType(mat0) + errType(mat1)
#如果误差估计更小,则更新特征索引值和特征值
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#如果误差减少不大则退出
if (S - bestS) < tolS:
return None, leafType(dataSet)
#根据最佳的切分特征和特征值切分数据集合
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
#如果切分出的数据集很小则退出
if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
return None, leafType(dataSet)
#返回最佳切分特征和特征值
return bestIndex, bestValue
chooseBestSplit(ex00)
找到最佳的带切分特征:
如果该节点不可再分,将该节点保存围殴叶子节点。
执行切分。
在右子树递归调用createTree()函数。
在左子树递归调用createTree()函数。
def createTree(dataSet, leafType = leafType, errType = errType, ops = (1, 4)):
"""
函数说明:树构建函数
Parameters:
dataSet - DataFrame,数据集合
leafType - function,建立叶结点的函数
errType - function,误差计算函数
ops - tuple,包含树构建所有其他参数的元组
Returns:
retTree - dict,构建的回归树
"""
# 最佳切分特征和特征值
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
# 如果没有特征,则返回特征值
if feat == None: return val
# 空字典存放最后的回归树
retTree = {}
retTree['spInd'] = feat # 列索引
retTree['spVal'] = val # 特征值
# 分成左数据集和右数据集
lSet, rSet = binSplitDataSet(dataSet, feat, val)
# 创建左子树和右子树
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
createTree(ex00)
# 因为只有一列特征,所以只分一次就结束了。
ex0 = pd.read_table('ex0.txt', header=None)
ex0
ex0.iloc[:, 0].value_counts()
# 第一列全是1.0
ex0.shape
# 不要第一列的数据,只要后面的两列。
# 要注意这个案例是一维度的,看似后面的两列像一组点坐标,其实第二列是特征,第三列是标签(并不是第二个维度!)
plt.scatter(ex0.iloc[:,1].values, ex0.iloc[:,2].values)
tree = createTree(ex0)
tree
上图中共有5个叶子节点,如何使用这些节点呢?
可以先画出原始数据的图像:
# 可视化
plt.scatter(ex0.iloc[:,1].values, ex0.iloc[:,2].values)
然后划出一条回归线,如何画出回归线?
答:先随机生成 0 - 1 之间的100个数据作为x点坐标,然后带入到上面的回归树中得到对应的y值。——得到x和y后即可绘出一条直线。
下面我们用sklearn库中的回归树对上述的案例绘制出树和回归线。
from sklearn.tree import DecisionTreeRegressor
x = (ex0.iloc[:,1].values).reshape(-1, 1)
y = (ex0.iloc[:,-1].values).reshape(-1, 1)
model1 = DecisionTreeRegressor(max_depth=3)
model1.fit(x, y)
# 画出决策树
import graphviz # http://www.graphviz.org/
from sklearn import tree
dot_data = tree.export_graphviz(model1,
out_file = None,
feature_names = ['feature_names'],
class_names = ['no','yes'],
filled = True,
rounded = True,
special_characters = True)
graph = graphviz.Source(dot_data)
graph.render('cart')
graph
# 生成 0 - 1 之间100个数据,用于后续可视化画线。
X_test = np.arange(0, 1, 0.01)[:, np.newaxis] # 人工生成的数据点,作为x点坐标
y_test = model1.predict(X_test) # 输入测试数据,输出预测结果。得到y点坐标
X_test.tolist()
y_test.tolist()
y_test_df = pd.DataFrame(y_test)
y_test_df
y_test_df[0].value_counts()
# 共计8种结果对应上图八个叶子节点。
# 可视化
plt.figure()
plt.scatter(x, y, s=20, edgecolors="black", c="darkorange", label="data")
plt.plot(X_test, y_test, color="black", label="max_depth=3", linewidth=2)
plt.legend()
plt.show()