python

超轻量级php框架startmvc

使用python实现knn算法

更新时间:2020-05-14 21:12:01 作者:startmvc
本文实例为大家分享了python实现knn算法的具体代码,供大家参考,具体内容如下knn算法描述

本文实例为大家分享了python实现knn算法的具体代码,供大家参考,具体内容如下

knn算法描述

对需要分类的点依次执行以下操作: 1.计算已知类别数据集中每个点与该点之间的距离 2.按照距离递增顺序排序 3.选取与该点距离最近的k个点 4.确定前k个点所在类别出现的频率 5.返回前k个点出现频率最高的类别作为该点的预测分类

knn算法实现

数据处理


#从文件中读取数据,返回的数据和分类均为二维数组
def loadDataSet(filename):
 dataSet = []
 labels = []
 fr = open(filename)
 for line in fr.readlines():
 lineArr = line.strip().split(",")
 dataSet.append([float(lineArr[0]),float(lineArr[1])])
 labels.append([float(lineArr[2])])
 return dataSet , labels


knn算法


#计算两个向量之间的欧氏距离
def calDist(X1 , X2):
 sum = 0
 for x1 , x2 in zip(X1 , X2):
 sum += (x1 - x2) ** 2
 return sum ** 0.5

def knn(data , dataSet , labels , k):
 n = shape(dataSet)[0]
 for i in range(n):
 dist = calDist(data , dataSet[i])
 #只记录两点之间的距离和已知点的类别
 labels[i].append(dist)
 #按照距离递增排序
 labels.sort(key=lambda x:x[1])
 count = {}
 #统计每个类别出现的频率
 for i in range(k):
 key = labels[i][0]
 if count.has_key(key):
 count[key] += 1
 else : count[key] = 1
 #按频率递减排序
 sortCount = sorted(count.items(),key=lambda item:item[1],reverse=True)
 return sortCount[0][0]#返回频率最高的key,即label


结果测试

已知类别数据(来源于西瓜书+虚构)

0.697,0.460,1 0.774,0.376,1 0.720,0.330,1 0.634,0.264,1 0.608,0.318,1 0.556,0.215,1 0.403,0.237,1 0.481,0.149,1 0.437,0.211,1 0.525,0.186,1 0.666,0.091,0 0.639,0.161,0 0.657,0.198,0 0.593,0.042,0 0.719,0.103,0 0.671,0.196,0 0.703,0.121,0 0.614,0.116,0

绘图方法


def drawPoints(data , dataSet, labels):
 xcord1 = [];
 ycord1 = [];
 xcord2 = [];
 ycord2 = [];
 for i in range(shape(dataSet)[0]):
 if labels[i][0] == 0:
 xcord1.append(dataSet[i][0])
 ycord1.append(dataSet[i][1])
 if labels[i][0] == 1:
 xcord2.append(dataSet[i][0])
 ycord2.append(dataSet[i][1])
 fig = plt.figure()
 ax = fig.add_subplot(111)
 ax.scatter(xcord1, ycord1, s=30, c='blue', marker='s',label=0)
 ax.scatter(xcord2, ycord2, s=30, c='green',label=1)
 ax.scatter(data[0], data[1], s=30, c='red',label="testdata")
 plt.legend(loc='upper right')
 plt.show()


测试代码


dataSet , labels = loadDataSet('dataSet.txt')
data = [0.6767,0.2122]
drawPoints(data , dataSet, labels)
newlabels = knn(data, dataSet , labels , 5)
print newlabels

运行结果

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

python knn