In [1]:
import numpy as np
import matplotlib.pyplot as plt  
import pandas as pd
In [2]:
pip list
Package                            Version            
---------------------------------- -------------------
alabaster                          0.7.12             
anaconda-client                    1.7.2              
anaconda-navigator                 1.9.12             
anaconda-project                   0.8.3              
argh                               0.26.2             
asn1crypto                         1.3.0              
astroid                            2.3.3              
astropy                            4.0                
atomicwrites                       1.3.0              
attrs                              19.3.0             
autopep8                           1.4.4              
Babel                              2.8.0              
backcall                           0.1.0              
backports.functools-lru-cache      1.6.1              
backports.shutil-get-terminal-size 1.0.0              
backports.tempfile                 1.0                
backports.weakref                  1.0.post1          
bcrypt                             3.1.7              
beautifulsoup4                     4.8.2              
bitarray                           1.2.1              
bkcharts                           0.2                
bleach                             3.1.0              
bokeh                              1.4.0              
boto                               2.49.0             
Bottleneck                         1.3.2              
certifi                            2019.11.28         
cffi                               1.14.0             
chardet                            3.0.4              
chart-studio                       1.1.0              
Click                              7.0                
cloudpickle                        1.3.0              
clyent                             1.2.2              
colorama                           0.4.3              
colorlover                         0.3.0              
comtypes                           1.1.7              
conda                              4.8.2              
conda-build                        3.18.11            
conda-package-handling             1.6.0              
conda-verify                       3.4.2              
contextlib2                        0.6.0.post1        
cryptography                       2.8                
cycler                             0.10.0             
Cython                             0.29.15            
cytoolz                            0.10.1             
dask                               2.11.0             
decorator                          4.4.1              
defusedxml                         0.6.0              
diff-match-patch                   20181111           
distributed                        2.11.0             
docutils                           0.16               
entrypoints                        0.3                
et-xmlfile                         1.0.1              
fastcache                          1.1.0              
filelock                           3.0.12             
flake8                             3.7.9              
Flask                              1.1.1              
fsspec                             0.6.2              
future                             0.18.2             
gevent                             1.4.0              
glob2                              0.7                
graphviz                           0.19               
greenlet                           0.4.15             
h5py                               2.10.0             
HeapDict                           1.0.1              
html5lib                           1.0.1              
hypothesis                         5.5.4              
idna                               2.8                
imageio                            2.6.1              
imagesize                          1.2.0              
importlib-metadata                 1.5.0              
intervaltree                       3.0.2              
ipykernel                          5.1.4              
ipython                            7.12.0             
ipython-genutils                   0.2.0              
ipywidgets                         7.5.1              
isort                              4.3.21             
itsdangerous                       1.1.0              
jdcal                              1.4.1              
jedi                               0.14.1             
Jinja2                             2.11.1             
joblib                             0.14.1             
json5                              0.9.1              
jsonschema                         3.2.0              
jupyter                            1.0.0              
jupyter-client                     5.3.4              
jupyter-console                    6.1.0              
jupyter-core                       4.6.1              
jupyterlab                         1.2.6              
jupyterlab-server                  1.0.6              
keyring                            21.1.0             
kiwisolver                         1.1.0              
lazy-object-proxy                  1.4.3              
libarchive-c                       2.8                
llvmlite                           0.31.0             
locket                             0.2.0              
lxml                               4.5.0              
MarkupSafe                         1.1.1              
matplotlib                         3.1.3              
mccabe                             0.6.1              
menuinst                           1.4.16             
mistune                            0.8.4              
mkl-fft                            1.0.15             
mkl-random                         1.1.0              
mkl-service                        2.3.0              
mlxtend                            0.19.0             
mock                               4.0.1              
more-itertools                     8.2.0              
mpmath                             1.1.0              
msgpack                            0.6.1              
multipledispatch                   0.6.0              
navigator-updater                  0.2.1              
nbconvert                          5.6.1              
nbformat                           5.0.4              
nbmerge                            0.0.4              
networkx                           2.4                
nltk                               3.4.5              
nose                               1.3.7              
notebook                           6.0.3              
numba                              0.48.0             
numexpr                            2.7.1              
numpy                              1.18.1             
numpydoc                           0.9.2              
olefile                            0.46               
openpyxl                           3.0.3              
packaging                          20.1               
pandas                             1.0.1              
pandasql                           0.7.3              
pandocfilters                      1.4.2              
paramiko                           2.7.1              
parso                              0.5.2              
partd                              1.1.0              
path                               13.1.0             
pathlib2                           2.3.5              
pathtools                          0.1.2              
patsy                              0.5.1              
pep8                               1.7.1              
pexpect                            4.8.0              
pickleshare                        0.7.5              
Pillow                             7.0.0              
pip                                20.0.2             
pkginfo                            1.5.0.1            
plotly                             5.5.0              
pluggy                             0.13.1             
ply                                3.11               
prometheus-client                  0.7.1              
prompt-toolkit                     3.0.3              
psutil                             5.6.7              
py                                 1.8.1              
pycodestyle                        2.5.0              
pycosat                            0.6.3              
pycparser                          2.19               
pycrypto                           2.6.1              
pycurl                             7.43.0.5           
pydocstyle                         4.0.1              
pyflakes                           2.1.1              
Pygments                           2.5.2              
pylint                             2.4.4              
PyNaCl                             1.3.0              
pyodbc                             4.0.0-unsupported  
pyOpenSSL                          19.1.0             
pyparsing                          2.4.6              
pyreadline                         2.1                
pyrsistent                         0.15.7             
PySocks                            1.7.1              
pytest                             5.3.5              
pytest-arraydiff                   0.3                
pytest-astropy                     0.8.0              
pytest-astropy-header              0.1.2              
pytest-doctestplus                 0.5.0              
pytest-openfiles                   0.4.0              
pytest-remotedata                  0.3.2              
python-dateutil                    2.8.1              
python-jsonrpc-server              0.3.4              
python-language-server             0.31.7             
pytz                               2019.3             
PyWavelets                         1.1.1              
pywin32                            227                
pywin32-ctypes                     0.2.0              
pywinpty                           0.5.7              
PyYAML                             5.3                
pyzmq                              18.1.1             
QDarkStyle                         2.8                
QtAwesome                          0.6.1              
qtconsole                          4.6.0              
QtPy                               1.9.0              
requests                           2.22.0             
retrying                           1.3.3              
rope                               0.16.0             
Rtree                              0.9.3              
ruamel-yaml                        0.15.87            
scikit-image                       0.16.2             
scikit-learn                       0.22.1             
scipy                              1.4.1              
seaborn                            0.10.0             
Send2Trash                         1.5.0              
setuptools                         45.2.0.post20200210
simplegeneric                      0.8.1              
singledispatch                     3.4.0.3            
six                                1.14.0             
snowballstemmer                    2.0.0              
sortedcollections                  1.1.2              
sortedcontainers                   2.1.0              
soupsieve                          1.9.5              
Sphinx                             2.4.0              
sphinxcontrib-applehelp            1.0.1              
sphinxcontrib-devhelp              1.0.1              
sphinxcontrib-htmlhelp             1.0.2              
sphinxcontrib-jsmath               1.0.1              
sphinxcontrib-qthelp               1.0.2              
sphinxcontrib-serializinghtml      1.1.3              
sphinxcontrib-websupport           1.2.0              
spyder                             4.0.1              
spyder-kernels                     1.8.1              
SQLAlchemy                         1.3.13             
statsmodels                        0.11.0             
sympy                              1.5.1              
tables                             3.6.1              
tblib                              1.6.0              
tenacity                           8.0.1              
terminado                          0.8.3              
testpath                           0.4.4              
toolz                              0.10.0             
tornado                            6.0.3              
tqdm                               4.42.1             
traitlets                          4.3.3              
ujson                              1.35               
unicodecsv                         0.14.1             
urllib3                            1.25.8             
watchdog                           0.10.2             
wcwidth                            0.1.8              
webencodings                       0.5.1              
Werkzeug                           1.0.0              
wheel                              0.34.2             
widgetsnbextension                 3.5.1              
win-inet-pton                      1.1.0              
win-unicode-console                0.5                
wincertstore                       0.2                
wrapt                              1.11.2             
xlrd                               1.2.0              
XlsxWriter                         1.2.7              
xlwings                            0.17.1             
xlwt                               1.3.0              
xmltodict                          0.12.0             
yapf                               0.28.0             
zict                               1.0.0              
zipp                               2.2.0              
Note: you may need to restart the kernel to use updated packages.

导入数据

In [3]:
data = pd.read_csv('kmeans.csv', header=None)
data
Out[3]:
0 1
0 1.658985 4.285136
1 -3.453687 3.424321
2 4.838138 -1.151539
3 -5.379713 -3.362104
4 0.972564 2.924086
... ... ...
75 -2.793241 -2.149706
76 2.884105 3.043438
77 -2.967647 2.848696
78 4.479332 -1.764772
79 -4.905566 -2.911070

80 rows × 2 columns

In [4]:
# 载入数据
#data = np.genfromtxt("data.txt", delimiter=" ")

plt.scatter(data.iloc[:,0],data.iloc[:,1])  # 第一列作为x轴坐标,第二列作为y轴坐标
plt.show()
In [5]:
data.shape
Out[5]:
(80, 2)

辅助函数1:计算距离

In [6]:
def euclDistance(vector1, vector2):  
    """
    函数说明:计算距离
    Parameters:
        vector1  - numpy.ndarray,质心坐标
        vector2  - numpy.ndarray,样本点坐标
    Returns:
        np.sqrt(sum((vector2 - vector1)**2)) - float,距离
    """
    # 欧氏距离
    return np.sqrt(sum((vector2 - vector1)**2))

辅助函数2:初始化质心

In [7]:
def initCentroids(data, k):  
    """
    函数说明:初始化质心,随机从样本点中抽取一个样本。
    Parameters:
        data  - DataFrame,数据集合
        k     - int,聚类个数参数
    Returns:
        centroids,  - numpy.ndarray,簇的中心点坐标,每个数据表示一组点坐标,即图上五角星的坐标。
    """
    
    numSamples, dim = data.values.shape
    # k个质心,列数跟样本的列数一样
    centroids = np.zeros((k, dim))  
    
    # 随机选出k个质心
    for i in range(k):  
        # 随机选取一个样本的索引
        index = int(np.random.uniform(0, numSamples))  
        # 根据index索引从data中拿到数据作为初始化的质心
        centroids[i, :] = data.values[index, :]   
    return centroids  

辅助函数3:画图

In [8]:
def showCluster(data, k, centroids, clusterData):  
    """
    函数说明:显示结果
    Parameters:
        data        - DataFrame,数据集合
        k           - int,聚类个数参数
        centroids,  - numpy.ndarray,簇的中心点坐标,每个数据表示一组点坐标,即图上五角星的坐标。
        clusterData - numpy.ndarray,第一列保存该样本属于哪个簇,第二列保存该样本跟它所属簇的误差。
    Returns:
    """
    
    numSamples, dim = data.shape  
    if dim != 2:  
        print("dimension of your data is not 2!")  
        return 1  
  
    # 用不同颜色形状来表示各个类别
    mark = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr', '<r', 'pr']  
    if k > len(mark):  
        print("Your k is too large!")  
        return 1  
  
    # 画样本点  
    for i in range(numSamples):  
        markIndex = int(clusterData[i, 0])  
        plt.plot(data.values[i, 0], data.values[i, 1], mark[markIndex])  
  
    # 用不同颜色形状来表示各个类别
    mark = ['*r', '*b', '*g', '*k', '^b', '+b', 'sb', 'db', '<b', 'pb']  
    # 画质心点 
    for i in range(k):  
        plt.plot(centroids[i, 0], centroids[i, 1], mark[i], markersize = 20)  
  
    plt.show()

主函数:训练模型

In [9]:
def kmeans(data, k):  
    """
    函数说明:传入数据集和k的值
    Parameters:
        data  - DataFrame,数据集合
        k     - int,聚类个数参数
    Returns:
        centroids,  - numpy.ndarray,簇的中心点坐标,每个数据表示一组点坐标,即图上五角星的坐标。
        clusterData - numpy.ndarray,样本的属性,第一列保存该样本属于哪个簇,第二列保存该样本跟它所属簇的误差。
    """
    
    numSamples = data.shape[0] # 样本个数
    clusterData = np.array(np.zeros((numSamples, 2)))  # clusterData第一列保存该样本属于哪个簇,第二列保存该样本跟它所属簇的误差。
    clusterChanged = True   # 质心是否要改变?当质心不改变时为False即终止循环。
  
    # 初始化质心  
    centroids = initCentroids(data, k)  
  
    while clusterChanged:  
        clusterChanged = False  
        # 循环每一个样本 
        for i in range(numSamples):  
            # 最小距离
            minDist  = 100000.0  
            # 定义样本所属的簇
            minIndex = 0  
            
            # 循环计算每一个质心与该样本的距离
            for j in range(k):  
                # 循环每一个质心和样本,计算距离
                distance = euclDistance(centroids[j, :], data.values[i, :])  
                # 如果计算的距离小于最小距离,则更新最小距离
                if distance < minDist:  
                    minDist  = distance 
                    # 更新最小距离,保存在clusterData第2列
                    clusterData[i, 1] = minDist
                    # 更新样本所属的簇
                    minIndex = j  
              
            # 如果样本的所属的簇发生了变化
            if clusterData[i, 0] != minIndex:  
                # 质心要重新计算
                clusterChanged = True
                # 更新样本的簇,clusterData第1列保存样本所属的簇。
                clusterData[i, 0] = minIndex
  
        # 更新质心
        for j in range(k):  
            # 获取第j个簇所有的样本所在的索引
            cluster_index = np.nonzero(clusterData[:, 0] == j)
            # 第j个簇所有的样本点
            pointsInCluster = data.values[cluster_index]  
            # 计算质心
            centroids[j, :] = np.mean(pointsInCluster, axis = 0) 
 
    return centroids, clusterData  

启动函数

In [10]:
# # 设置k值
# k = 2

# centroids, clusterData = kmeans(data, k)  # 下图五角星是聚类重心点

# # centroids中出现任意空值就报错。
# if np.isnan(centroids).any():
#     print('错误!centroids含有空值,不合法!')
# else:
#     print('聚类成功!')   

# # 显示结果
# showCluster(data, k, centroids, clusterData)  
In [11]:
# 设置k值
k = 4

min_loss = 10000
min_loss_centroids = np.array([])
min_loss_clusterData = np.array([])

for i in range(50):
    # centroids 簇的中心点 
    # cluster Data样本的属性,第一列保存该样本属于哪个簇,第二列保存该样本跟它所属簇的误差
    centroids, clusterData = kmeans(data, k)  
    
    ########################################## 代价函数 ################################################
    loss = sum(clusterData[:,1])/data.shape[0] # 损失函数
    if loss < min_loss:
        min_loss = loss
        min_loss_centroids = centroids
        min_loss_clusterData = clusterData
        
        print('循环:loss',min_loss)
    ####################################################################################################
print('最终:loss',min_loss)    
centroids = min_loss_centroids
clusterData = min_loss_clusterData

# 显示结果
showCluster(data, k, centroids, clusterData)
循环:loss 1.1696794346725077
循环:loss 1.1675654672086737
D:\ProgramData\Anaconda3\lib\site-packages\numpy\core\fromnumeric.py:3335: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
D:\ProgramData\Anaconda3\lib\site-packages\numpy\core\_methods.py:154: RuntimeWarning: invalid value encountered in true_divide
  ret, rcount, out=ret, casting='unsafe', subok=False)
最终:loss 1.1675654672086737

做预测

In [12]:
# 做预测
x_test = [0,1]
np.tile(x_test,(k,1))
Out[12]:
array([[0, 1],
       [0, 1],
       [0, 1],
       [0, 1]])
In [13]:
# 误差
np.tile(x_test,(k,1))-centroids
Out[13]:
array([[-2.6265299 , -2.10868015],
       [ 3.53973889,  3.89384326],
       [-2.65077367,  3.79019029],
       [ 2.46154315, -1.78737555]])
In [14]:
# 误差平方
(np.tile(x_test,(k,1))-centroids)**2
Out[14]:
array([[ 6.89865932,  4.44653198],
       [12.52975144, 15.16201536],
       [ 7.02660103, 14.3655424 ],
       [ 6.05919468,  3.19471136]])
In [15]:
# 误差平方和
((np.tile(x_test,(k,1))-centroids)**2).sum(axis=1)
Out[15]:
array([11.34519129, 27.6917668 , 21.39214343,  9.25390604])
In [16]:
# 最小值所在的索引号
np.argmin(((np.tile(x_test,(k,1))-centroids)**2).sum(axis=1))
Out[16]:
3
In [17]:
def predict(datas):
    return np.array([np.argmin(((np.tile(data,(k,1))-centroids)**2).sum(axis=1)) for data in datas])

画出簇的作用区域

In [18]:
# 获取数据值所在的范围
# 生成网格矩阵,通过生成网络点坐标,代入预测函数得到聚类结果以绘制出作用区域。
x_min, x_max = data.values[:, 0].min() - 1, data.values[:, 0].max() + 1
y_min, y_max = data.values[:, 1].min() - 1, data.values[:, 1].max() + 1

xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
                     np.arange(y_min, y_max, 0.02))

z = predict(np.c_[xx.ravel(), yy.ravel()])# ravel与flatten类似,多维数据转一维。flatten不会改变原始数据,ravel会改变原始数据
z = z.reshape(xx.shape)
# 等高线图
cs = plt.contourf(xx, yy, z)
# 显示结果
showCluster(data, k, centroids, clusterData)  
In [ ]: