1 import numpy
2
3 from PyML.classifiers.ext import knn
4 from PyML.evaluators import assess
5 from PyML.classifiers.baseClassifiers import Classifier
6 import time
7 from PyML.utils import arrayWrap
8
9 -class KNN (Classifier) :
10 """
11 a K-Nearest-Neighbors classifier
12
13 :Keywords:
14 - `k` - the number of nearest neighbors on which to base the classification
15 [default : 3]
16
17 if the training data is a C++ dataset (e.g. SparseDataSet) classification
18 is much faster since everything is done in C++; if a python container is
19 used then it's a slower pure python implementation.
20 """
21
22 attributes = {'k' : 3}
23
24 - def __init__(self, arg = None, **args) :
27
29
30 rep = '<' + self.__class__.__name__ + ' instance>\n'
31 rep += 'number of nearest neighbors: ' + str(self.k)
32
33 return rep
34
35
36 - def train(self, data, **args) :
45
46
48
49 '''For each class the sum of the distances to the k nearest neighbors
50 is computed. The distance is computed using the given kernel'''
51
52 x = data.X[i]
53 if self.projectionRequired :
54 xproj = self.project(x)
55 else :
56 xproj = x
57 numClasses = self.labels.numClasses
58 s = []
59 for c in range(numClasses) :
60 sim = [self.data.dotProduct(xproj, self.data.X[i])
61 for i in self.data.labels.classes[c]]
62
63 s.append(numpy.sum(numpy.sort(sim)[-self.k:]))
64
65 if numClasses > 2 :
66 return numpy.argmax(s), max(s) - numpy.sort(s)[-2]
67 elif numClasses == 2 :
68 return numpy.argmax(s), s[0] - s[1]
69 else :
70 raise ValueError, 'wrong number of classes'
71
72 - def test(self, data, **args) :
78
def testC(self, data, **args) :
    """
    Classify all the patterns in a dataset by delegating to ``self.knnc``
    and collect the predictions in a results object.

    :Parameters:
      - `data` - the dataset to classify; it is converted with
        ``data.castToBase()`` before being handed to ``self.knnc.test``

    :Keywords:
      - `stats` - whether to call ``computeStats()`` on the results; only
        honoured when ``data.labels.L`` is not None [default : False]

    All keyword arguments are also forwarded to ``data.test`` (when the
    dataset defines a ``testingFunc``) and to ``self.resultsObject``.

    :Returns:
      the object created by ``self.resultsObject``, holding one
      (label, decision function value) prediction per pattern; its ``log``
      is ``self.log`` with ``testingTime`` set to the elapsed time.
    """

    # time.clock() was removed in Python 3.8; use perf_counter when the
    # interpreter provides it and fall back to clock() on older versions
    # so that timing behaviour there is unchanged.
    timer = getattr(time, 'perf_counter', None) or time.clock
    testStart = timer()

    if data.testingFunc is not None :
        data.test(self.trainingData, **args)

    # cdecisionFunc is passed in empty and read back per pattern below, so
    # the call to test() is expected to fill it in place.
    cdecisionFunc = arrayWrap.doubleVector([])
    cY = self.knnc.test(data.castToBase(), cdecisionFunc)

    res = self.resultsObject(data, self, **args)
    for i in range(len(data)) :
        res.appendPrediction((cY[i], cdecisionFunc[i]), data, i)

    res.log = self.log
    # a plain dictionary lookup with a default replaces the bare
    # try/except, which would also have hidden unrelated errors
    computeStats = args.get('stats', False)
    if computeStats and data.labels.L is not None :
        res.computeStats()

    res.log.testingTime = timer() - testStart

    return res
103
107