knn.py 2.0 KB
Newer Older
F
feilong 已提交
1
# -*- coding: UTF-8 -*-
F
fix bug  
feilong 已提交
2
# 作者:qq_44193969
F
feilong 已提交
3 4 5 6 7 8 9
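# Title: supervised classification with k-nearest neighbours (KNN)
# Description: KNN classification - a sample takes the label of its nearest neighbours ("one takes on the colour of one's company")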

import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier

F
feilong 已提交
10

F
feilong 已提交
11 12 13 14 15 16 17 18 19 20 21 22 23 24
def generate_data(class1_num, class2_num, seed=2021, train_ratio=0.7):
    """Build a shuffled two-class, two-feature Gaussian dataset and split it.

    Class 0 samples are drawn from normal distributions with (loc, scale) of
    (2, 1.0) for the first feature and (3, 1.0) for the second; class 1
    samples use (6, 2.0) and (8, 2.0).

    Args:
        class1_num: Number of class-0 samples to draw.
        class2_num: Number of class-1 samples to draw.
        seed: Seed of the private random generator. The default reproduces
            the data this function has always produced.
        train_ratio: Fraction of the shuffled samples placed in the training
            split; must lie in [0, 1].

    Returns:
        A tuple ``(x_train, y_train, x_test, y_test)``. The ``x`` arrays have
        shape ``(n, 2)``; the ``y`` arrays hold the integer labels 0 / 1.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError(
            'train_ratio must be in [0, 1], got {}'.format(train_ratio))

    # A private RandomState yields the same stream as np.random.seed(seed)
    # followed by np.random.* calls, without reseeding the global generator
    # that other code may be relying on.
    rng = np.random.RandomState(seed)

    # The draw order below determines the generated values; keep it fixed.
    x1_1 = rng.normal(loc=2, scale=1.0, size=class1_num)
    x2_1 = rng.normal(loc=3, scale=1.0, size=class1_num)
    x1_2 = rng.normal(loc=6, scale=2.0, size=class2_num)
    x2_2 = rng.normal(loc=8, scale=2.0, size=class2_num)

    x = np.column_stack((
        np.concatenate((x1_1, x1_2)),
        np.concatenate((x2_1, x2_2)),
    ))
    # Integer arrays (rather than Python lists) keep the label dtype integral
    # even when one of the classes is empty.
    y = np.concatenate((
        np.zeros(class1_num, dtype=int),
        np.ones(class2_num, dtype=int),
    ))

    data_size_all = class1_num + class2_num
    shuffled_index = rng.permutation(data_size_all)
    x = x[shuffled_index]
    y = y[shuffled_index]

    split_index = int(data_size_all * train_ratio)
    return x[:split_index], y[:split_index], x[split_index:], y[split_index:]

F
feilong 已提交
40

F
feilong 已提交
41
def show_data(x_train, y_train, x_test, y_test):
    """Display the training set and then the test set as scatter plots.

    Each split gets its own figure: the first feature column is plotted on
    the x-axis, the second on the y-axis, and points are coloured by label.

    Args:
        x_train: Training features, indexed as ``[:, 0]`` and ``[:, 1]``.
        y_train: Training labels used as point colours.
        x_test: Test features, indexed as ``[:, 0]`` and ``[:, 1]``.
        y_test: Test labels used as point colours.
    """
    for features, labels in ((x_train, y_train), (x_test, y_test)):
        plt.scatter(features[:, 0], features[:, 1], c=labels, marker='.')
        plt.show()


F
feilong 已提交
48
def train_and_predict(is_show=False, n_neighbors=2):
    """Fit a KNN classifier on the synthetic dataset and report its accuracy.

    The data comes from ``generate_data(300, 500)``. The accuracy on the
    test split is printed and also returned.

    Args:
        is_show: When true, plot the training and test splits before fitting.
        n_neighbors: Number of neighbours the classifier consults; the
            default of 2 keeps the original behaviour.

    Returns:
        The fraction of test samples whose predicted label equals the true
        label.
    """
    x_train, y_train, x_test, y_test = generate_data(300, 500)
    if is_show:
        show_data(x_train, y_train, x_test, y_test)

    neigh = KNeighborsClassifier(n_neighbors=n_neighbors)
    neigh.fit(x_train, y_train)

    # Predict the whole test split in one call instead of once per sample;
    # the classifier accepts a 2-D batch of shape (n_samples, n_features).
    predictions = neigh.predict(x_test)
    acc_count = int(np.count_nonzero(predictions == y_test))
    acc = acc_count / len(x_test)
    print('准确率为: {}'.format(acc))
    return acc

F
feilong 已提交
62

F
feilong 已提交
63 64
# Script entry point: run the demo without plotting when executed directly.
if __name__ == "__main__":
    train_and_predict(is_show=False)