关于KNN算法识别MNIST数据集

来源:12-1 什么是决策树

慕九州9175731

2018-10-15

老师您好,我手写KNN算法识别MNIST数据集,但结果不尽如人意仅有30%左右的准确率,请问是什么原因导致的呢?以下是源代码

import numpy as np
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from math import sqrt
from collections import Counter

def KNNClassifier(X_train,y_train,k,x):
    distances = [sqrt(np.sum((x_train - x)**2)) for x_train in X_train]
    nearest = np.argsort(distances)

    topK_y = [y_train[i] for i in nearest[:k]]
    votes = Counter(topK_y)

    return votes.most_common(1)[0][0]

def PredictMatrix(X_train,y_train,k,X_predict):
    """给定待预测数据集X_predict,返回表示X_predict的结果向量"""

    y_predict = [KNNClassifier(X_train,y_train,k,x) for x in X_predict]
    return np.array(y_predict)

def Accuracy(predict_matrix,label_matrix):
    return sum(predict_matrix==label_matrix)/label_matrix.shape[0]

#获取MNIST数据集
mnist = fetch_mldata("MNIST original")

#随机获取MNIST中的5000个样本作为样本点

sample = np.array(np.random.randint(low=0,high=70000, size=5000))

data = mnist.data[sample]
target = mnist.target[sample]

#训练测试集划分
X_train, X_test, y_train, y_test = np.array(train_test_split(data, target, train_size=0.9))
        
prediction_matrix = PredictMatrix(X_train,y_train,5,X_test)

score = Accuracy(prediction_matrix,y_test)

print('knn score: %f' % score)
写回答

1回答

liuyubobobo

2018-10-15

MNIST一共有70000个数据点。其中官方数据集的设计中,60000个数据点是训练数据,10000个数据点是测试数据。你只从中随机取样了5000个点进行训练,连10%都不到。这是无法保证你的取样数据是在0-9这十个数字中均匀取样的。在最坏的情况下,很有可能你取样的5000个点,只囊括1个个数字的样本!当然这是最坏情况,实际,根据你的结果,很可能你的取样,只充分囊括了3个数字对应的信息样本。


我在服务器级别的机器上,使用课程中第四章我们自己写的kNN,跑了一遍MNIST数据,用全部60000个点做训练,10000个点做测试,最终的结果是96.93%:)(k采用和sklearn中的kNN一样的默认值:5)

然而,对于服务器级别的机器,跑了近一个小时。在普通的家用机上,很可能要跑一天。

相关notebook我已经放到了github上,可以参考这里:https://gihub.com/liuyubobobo/Play-with-Machine-Learning-Algorithms/blob/master/04-kNN/Optional-03-kNN-for-MNIST/Optional-03-kNN-for-MNIST.ipynb


实际上,对于kNN来说,他的问题就在于:实际预测效率低下!这也是其实,在实际应用中,你很少看到使用kNN的原因。当然,如果有兴趣,可以找一台闲置的电脑跑跑看。如果写的代码和课程代码一致,应该会得到和我一样的准确率:)


这个课程的所有代码,都可以通过官方github获得,传送门:https://github.com/liuyubobobo/Play-with-Machine-Learning-Algorithms


加油!:)

1
2
慕九州9175731
谢谢老师,我自己也猜测应该是随机取样造成的偏差,但没有想到具体原因是什么,谢谢老师为我解答这个疑惑,感谢~
2018-10-15
共2条回复

Python3入门机器学习 经典算法与应用  

Python3+sklearn,兼顾原理、算法底层实现和框架使用。

5839 学习 · 2437 问题

查看课程