KNN 图像分类（基于 Python 3.6）
1.数据来源
CIFAR-10 是一个常用的图像分类数据集，包含 60000 张 32×32 像素的彩色小图片，每张图片都标注了一个类别（共 10 类），其中 50000 张作为训练集，10000 张作为测试集。
在 Python 中读取 CIFAR-10 数据文件的代码如下：
A Python 2 routine that opens such a file and returns a dictionary:
def unpickle(file):
    """Load and return the object pickled in ``file`` (Python 2 only).

    ``cPickle`` is the C-accelerated pickle module of Python 2; it does not
    exist in Python 3, which needs a different routine.
    """
    import cPickle
    with open(file, 'rb') as fo:
        # Named ``data`` so the builtin ``dict`` is not shadowed.
        data = cPickle.load(fo)
    return data
And a Python 3 version:
def unpickle(file):
    """Load and return the object pickled in ``file`` (Python 3).

    ``encoding='bytes'`` makes Python 2 ``str`` objects in the pickle come
    back as ``bytes`` instead of being decoded as ASCII, so dictionary keys of
    a Python 2 pickle are ``bytes`` (e.g. ``b'data'``).
    """
    import pickle
    with open(file, 'rb') as fo:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        data = pickle.load(fo, encoding='bytes')
    return data
2. K 最近邻分类器（K-Nearest Neighbor Classifier）
import os
import numpy as np
import pickle
def load_CIFAR_batch(filename):
    """Load one pickled CIFAR-10 batch file.

    The file is unpickled with ``encoding='bytes'``, so its dictionary keys
    are bytes (``b'data'`` and ``b'labels'``).

    Returns:
        X: float array of shape (n, 32, 32, 3), channels last, where n is the
            number of images stored in the batch.
        Y: array of shape (n,) holding the label of each image.
    """
    with open(filename, 'rb') as fo:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        d = pickle.load(fo, encoding='bytes')
    X = d[b'data']
    Y = d[b'labels']
    # Each row holds 3 channel planes of 32*32 values; rebuild
    # (n, channel, row, col) and move the channel axis last.  -1 takes the
    # image count from the data instead of assuming 10000 rows per batch.
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(Y)
    return X, Y
def load_CIFAR10(ROOT):
    """Load the full CIFAR-10 dataset from the directory ``ROOT``.

    The five files ``data_batch_1`` .. ``data_batch_5`` are read in order and
    stacked into the training set; ``test_batch`` becomes the test set.

    Returns:
        (X_train, Y_train, X_test, Y_test) as produced by ``load_CIFAR_batch``,
        with the training batches concatenated along the first axis.
    """
    train_batches = [
        load_CIFAR_batch(os.path.join(ROOT, "data_batch_%d" % (batch_no, )))
        for batch_no in range(1, 6)
    ]
    batch_images, batch_labels = zip(*train_batches)
    X_train = np.concatenate(batch_images)
    Y_train = np.concatenate(batch_labels)
    X_test, Y_test = load_CIFAR_batch(os.path.join(ROOT, "test_batch"))
    return X_train, Y_train, X_test, Y_test
# Load CIFAR-10 and flatten every 32x32x3 image into a single row vector.
# Raw string: in a plain literal "\p" is an invalid escape sequence (deprecated
# since Python 3.6, a SyntaxWarning since 3.12); the resulting path is unchanged.
X_train, Y_train, X_test, Y_test = load_CIFAR10(r'F:\python/cifar-10-batches-py/')
Xtr_rows = X_train.reshape(X_train.shape[0], 32 * 32 * 3)
Xte_rows = X_test.reshape(X_test.shape[0], 32 * 32 * 3)
class NearestNeighbor:
    """k-nearest-neighbour classifier using the L1 (sum of absolute differences) distance.

    ``train`` only stores the training data; all computation happens in ``predict``.
    """

    def __init__(self):
        pass

    def train(self, X, y):
        """Store the training samples ``X`` (one sample per row) and their labels ``y``."""
        self.Xtr = X
        self.ytr = y

    def predict(self, X, k=1):
        """Predict a label for each row of ``X`` by majority vote of its k nearest training rows.

        Args:
            X: 2-D array with one test sample per row and the same number of
                columns as the training data.
            k: number of nearest training samples that vote (default 1, i.e.
                plain nearest neighbour).

        Returns:
            1-D array of ``X.shape[0]`` predicted labels, with the dtype of the
            training labels.

        Raises:
            ValueError: if ``k`` is smaller than 1.

        A vote tie goes to the label that appears first among the neighbours
        ordered by distance (equal distances keep training-set order).  If ``k``
        exceeds the number of training samples, every training sample votes once.
        """
        from collections import Counter

        if k < 1:
            raise ValueError("k must be a positive integer, got %r" % (k,))
        num_test = X.shape[0]
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        for i in range(num_test):
            # NOTE: assumes a signed or floating dtype; unsigned integers would
            # wrap around in the subtraction.
            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
            # A stable argsort selects each training row at most once and breaks
            # distance ties by training index.
            nearest = np.argsort(distances, kind='stable')[:k]
            votes = Counter(self.ytr[nearest])
            Ypred[i] = votes.most_common(1)[0][0]
        return Ypred
# Hold out the first 1000 training rows as a validation split; the remaining
# rows become the reference set of the classifier.
Xval_rows, Yval = Xtr_rows[:1000, :], Y_train[:1000]
Xtr_rows, Ytr = Xtr_rows[1000:, :], Y_train[1000:]

nn = NearestNeighbor()
nn.train(Xtr_rows, Ytr)

# Report the validation accuracy for each candidate value of k.
for k in (3, 5, 7, 10, 20):
    Yte_predict = nn.predict(Xval_rows, k)
    print('k=%d' % k, 'accuracy: %f' % np.mean(Yte_predict == Yval))
print("end")