热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

关于KNN的python3实现

关于KNN,有幸看到这篇文章,写的很好,这里就不在赘述。直接贴上代码了,有小的改动。(原来是python2版本的,这里改为python3的,主要就是print)环境:win732bit+

  关于KNN,有幸看到这篇文章,写的很好,这里就不在赘述。直接贴上代码了,有小的改动。(原来是python2版本的,这里改为python3的,主要就是print)

  环境:win7 32bit + spyder + anaconda3.5

  一、初阶

# -*- coding: utf-8 -*-
"""
Created on Sun Nov 6 16:09:00 2016

@author: Administrator
"""

#Input:
#newInput:待测的数据点(1xM)
#dataSet:已知的数据(NxM)
#labels:已知数据的标签(1xM)
#k:选取的最邻近数据点的个数
#
#Output:
#待测数据点的分类标签
#

from numpy import *

# creat a dataset which contain 4 samples with 2 class
def createDataSet():
# creat a matrix: each row as a sample
group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels


#classify using KNN
def KNNClassify(newInput, dataSet, labels, k):
numSamples = dataSet.shape[0] # row number
# step1:calculate Euclidean distance
# tile(A, reps):Constract an array by repeating A reps times
diff = tile(newInput, (numSamples, 1)) - dataSet
squreDiff = diff**2
squreDist = sum(squreDiff, axis=1) # sum if performed by row
distance = squreDist ** 0.5

#step2:sort the distance
# argsort() returns the indices that would sort an array in a ascending order
sortedDistIndices = argsort(distance)

classCount = {}
for i in range(k):
# choose the min k distance
voteLabel = labels[sortedDistIndices[i]]

#step4:count the times labels occur
# when the key voteLabel is not in dictionary classCount,
# get() will return 0
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
#step5:the max vote class will return
maxCount = 0
for k, v in classCount.items():
if v > maxCount:
maxCount = v
maxIndex = k

return maxIndex


# test

dataSet, labels = createDataSet()

testX = array([1.2, 1.0])
k = 3
outputLabel = KNNClassify(testX, dataSet, labels, 3)

print("Your input is:", testX, "and classified to class: ", outputLabel)


testX = array([0.1, 0.3])
k = 3
outputLabel = KNNClassify(testX, dataSet, labels, 3)

print("Your input is:", testX, "and classified to class: ", outputLabel)

  运行结果:

 

  二、进阶

  用到的手写识别数据库资料在这里下载。关于资料的介绍在上面的博文也已经介绍的很清楚了。

# -*- coding: utf-8 -*-
"""
Created on Sun Nov 6 16:09:00 2016

@author: Administrator
"""

#Input:
#newInput:待测的数据点(1xM)
#dataSet:已知的数据(NxM)
#labels:已知数据的标签(1xM)
#k:选取的最邻近数据点的个数
#
#Output:
#待测数据点的分类标签
#

from numpy import *



#classify using KNN
def KNNClassify(newInput, dataSet, labels, k):
numSamples = dataSet.shape[0] # row number
# step1:calculate Euclidean distance
# tile(A, reps):Constract an array by repeating A reps times
diff = tile(newInput, (numSamples, 1)) - dataSet
squreDiff = diff**2
squreDist = sum(squreDiff, axis=1) # sum if performed by row
distance = squreDist ** 0.5

#step2:sort the distance
# argsort() returns the indices that would sort an array in a ascending order
sortedDistIndices = argsort(distance)

classCount = {}
for i in range(k):
# choose the min k distance
voteLabel = labels[sortedDistIndices[i]]

#step4:count the times labels occur
# when the key voteLabel is not in dictionary classCount,
# get() will return 0
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
#step5:the max vote class will return
maxCount = 0
for k, v in classCount.items():
if v > maxCount:
maxCount = v
maxIndex = k

return maxIndex



# convert image to vector
def img2vector(filename):
rows = 32
cols = 32
imgVector = zeros((1, rows * cols))
fileIn = open(filename)
for row in range(rows):
lineStr = fileIn.readline()
for col in range(cols):
imgVector[0, row * 32 + col] = int(lineStr[col])

return imgVector


# load dataSet
def loadDataSet():
## step 1: Getting training set
print("---Getting training set...")
dataSetDir = 'F:\\Techonolgoy\\算法学习\\KNN\\进阶\\'
trainingFileList = os.listdir(dataSetDir + 'trainingDigits') # load the training set
numSamples = len(trainingFileList)

train_x = zeros((numSamples, 1024))
train_y = []
for i in range(numSamples):
filename = trainingFileList[i]

# get train_x
train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename)

# get label from file name such as "1_18.txt"
label = int(filename.split('_')[0]) # return 1
train_y.append(label)

## step 2: Getting testing set
print("---Getting testing set...")
testingFileList = os.listdir(dataSetDir + 'testDigits') # load the testing set
numSamples = len(testingFileList)
test_x = zeros((numSamples, 1024))
test_y = []
for i in range(numSamples):
filename = testingFileList[i]

# get train_x
test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename)

# get label from file name such as "1_18.txt"
label = int(filename.split('_')[0]) # return 1
test_y.append(label)

return train_x, train_y, test_x, test_y

# test hand writing class
def testHandWritingClass():
## step 1: load data
print("step 1: load data...")
train_x, train_y, test_x, test_y = loadDataSet()

## step 2: training...
print("step 2: training...")
pass

## step 3: testing
print("step 3: testing...")
numTestSamples = test_x.shape[0]
matchCount = 0
for i in range(numTestSamples):
predict = KNNClassify(test_x[i], train_x, train_y, 3)
if predict == test_y[i]:
matchCount += 1
accuracy = float(matchCount) / numTestSamples

## step 4: show the result
print("step 4: show the result...")
print('The classify accuracy is: %.2f%%' % (accuracy * 100))



testHandWritingClass()

  运行结果:

 


推荐阅读
  • 在 Flutter 开发过程中,开发者经常会遇到 Widget 构造函数中的可选参数 Key。对于初学者来说,理解 Key 的作用和使用场景可能是一个挑战。本文将详细探讨 Key 的概念及其应用场景,并通过实例帮助你更好地掌握这一重要工具。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • 题目描述:给定n个半开区间[a, b),要求使用两个互不重叠的记录器,求最多可以记录多少个区间。解决方案采用贪心算法,通过排序和遍历实现最优解。 ... [详细]
  • 使用 Azure Service Principal 和 Microsoft Graph API 获取 AAD 用户列表
    本文介绍了一段通用代码示例,该代码不仅能够操作 Azure Active Directory (AAD),还可以通过 Azure Service Principal 的授权访问和管理 Azure 订阅资源。Azure 的架构可以分为两个层级:AAD 和 Subscription。 ... [详细]
  • 深入解析Spring Cloud Ribbon负载均衡机制
    本文详细介绍了Spring Cloud中的Ribbon组件如何实现服务调用的负载均衡。通过分析其工作原理、源码结构及配置方式,帮助读者理解Ribbon在分布式系统中的重要作用。 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • UNP 第9章:主机名与地址转换
    本章探讨了用于在主机名和数值地址之间进行转换的函数,如gethostbyname和gethostbyaddr。此外,还介绍了getservbyname和getservbyport函数,用于在服务器名和端口号之间进行转换。 ... [详细]
  • 本文详细介绍了如何构建一个高效的UI管理系统,集中处理UI页面的打开、关闭、层级管理和页面跳转等问题。通过UIManager统一管理外部切换逻辑,实现功能逻辑分散化和代码复用,支持多人协作开发。 ... [详细]
  • 本文探讨了 Objective-C 中的一些重要语法特性,包括 goto 语句、块(block)的使用、访问修饰符以及属性管理等。通过实例代码和详细解释,帮助开发者更好地理解和应用这些特性。 ... [详细]
  • 本文详细介绍了 Apache Jena 库中的 Txn.executeWrite 方法,通过多个实际代码示例展示了其在不同场景下的应用,帮助开发者更好地理解和使用该方法。 ... [详细]
  • 本文详细介绍了Java中的访问器(getter)和修改器(setter),探讨了它们在保护数据完整性、增强代码可维护性方面的重要作用。通过具体示例,展示了如何正确使用这些方法来控制类属性的访问和更新。 ... [详细]
  • 题目Link题目学习link1题目学习link2题目学习link3%%%受益匪浅!-----&# ... [详细]
  • 本文详细介绍了 React 中的两个重要 Hook 函数:useState 和 useEffect。通过具体示例,解释了如何使用它们来管理组件状态和处理副作用。 ... [详细]
author-avatar
书友62423539
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有