关于KNN算法识别MNIST数据集
来源:12-1 什么是决策树
慕九州9175731
2018-10-15
老师您好,我手写KNN算法识别MNIST数据集,但结果不尽如人意仅有30%左右的准确率,请问是什么原因导致的呢?以下是源代码
import numpy as np
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from math import sqrt
from collections import Counter
def KNNClassifier(X_train,y_train,k,x):
distances = [sqrt(np.sum((x_train - x)**2)) for x_train in X_train]
nearest = np.argsort(distances)
topK_y = [y_train[i] for i in nearest[:k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0]
def PredictMatrix(X_train,y_train,k,X_predict):
"""给定待预测数据集X_predict,返回表示X_predict的结果向量"""
y_predict = [KNNClassifier(X_train,y_train,k,x) for x in X_predict]
return np.array(y_predict)
def Accuracy(predict_matrix,label_matrix):
return sum(predict_matrix==label_matrix)/label_matrix.shape[0]
#获取MNIST数据集
mnist = fetch_mldata("MNIST original")
#随机获取MNIST中的5000个样本作为样本点
sample = np.array(np.random.randint(low=0,high=70000, size=5000))
data = mnist.data[sample]
target = mnist.target[sample]
#训练测试集划分
X_train, X_test, y_train, y_test = np.array(train_test_split(data, target, train_size=0.9))
prediction_matrix = PredictMatrix(X_train,y_train,5,X_test)
score = Accuracy(prediction_matrix,y_test)
print('knn score: %f' % score)
1回答
-
MNIST一共有70000个数据点。其中官方数据集的设计中,60000个数据点是训练数据,10000个数据点是测试数据。你只从中随机取样了5000个点进行训练,连10%都不到。这是无法保证你的取样数据是在0-9这十个数字中均匀取样的。在最坏的情况下,很有可能你取样的5000个点,只囊括1个个数字的样本!当然这是最坏情况,实际,根据你的结果,很可能你的取样,只充分囊括了3个数字对应的信息样本。
我在服务器级别的机器上,使用课程中第四章我们自己写的kNN,跑了一遍MNIST数据,用全部60000个点做训练,10000个点做测试,最终的结果是96.93%:)(k采用和sklearn中的kNN一样的默认值:5)
然而,对于服务器级别的机器,跑了近一个小时。在普通的家用机上,很可能要跑一天。
相关notebook我已经放到了github上,可以参考这里:https://gihub.com/liuyubobobo/Play-with-Machine-Learning-Algorithms/blob/master/04-kNN/Optional-03-kNN-for-MNIST/Optional-03-kNN-for-MNIST.ipynb
实际上,对于kNN来说,他的问题就在于:实际预测效率低下!这也是其实,在实际应用中,你很少看到使用kNN的原因。当然,如果有兴趣,可以找一台闲置的电脑跑跑看。如果写的代码和课程代码一致,应该会得到和我一样的准确率:)
这个课程的所有代码,都可以通过官方github获得,传送门:https://github.com/liuyubobobo/Play-with-Machine-Learning-Algorithms
加油!:)
122018-10-15
相似问题