1
2 import numpy
3 from numpy import arange
4 import matplotlib
5 from matplotlib import pylab
6 pylab.rcParams['contour.negative_linestyle'] = 'solid'
7 from PyML.containers.vectorDatasets import VectorDataSet
8 from PyML.containers.labels import Labels
9
10 """
11 demo2d: display decision boundaries and contours of the decision function
12 of a classifier on two dimensional data.
13 it usually works with the latest version of matplotlib.
14
15 USAGE::
16
17 first you need to generate some data; you need to call
18 demo2d.getData()
data is generated by pressing the '1' or '2' key while the mouse pointer is
at the position on the figure where you want the data point to be.
press 'q' when you're done.
22 demo2d.decisionSurface(classifier) then plots the decision boundary and
23 contours of the decision function of the given classifier on the data
24 that was generated.
25 demo2d.decisionSurface can be called several times using different classifiers.
26 """
27
# Coordinates (X) and class keys (Y) of the points entered interactively.
# The key-press handler appends to them and resets them once the points
# have been packed into a dataset.
X, Y = [], []

# matplotlib format strings, indexed by class: red circles, blue plus signs.
plotStr = ['or', '+b']

# Bounds of the square region used for the figure and the evaluation grid.
xmin, xmax = -1, 1
ymin, ymax = -1, 1
36
38
def pick(event):
    """Key-press callback that collects labelled 2-d points.

    Pressing '1' or '2' while the pointer is inside the axes appends the
    pointer's data coordinates to ``X`` and the pressed key to ``Y``, and
    draws the point with the matching entry of ``plotStr``.

    Pressing 'q' packs the collected points into the module-level ``data``
    (a ``VectorDataSet`` with ``Labels`` attached), empties ``X`` and ``Y``
    and disconnects this handler.  If no points were entered yet, 'q' is
    ignored.

    event -- the key event delivered by matplotlib; its ``key``, ``inaxes``,
             ``xdata`` and ``ydata`` attributes are read.
    """
    global data
    global X
    global Y

    if event.key == 'q':
        if not X:
            # nothing entered so far: keep the handler connected
            return
        data = VectorDataSet(X)
        data.attachLabels(Labels(Y))
        X = []
        Y = []
        # print(...) with a single argument gives the same output on
        # Python 2 and Python 3, unlike the print statement.
        print('done creating data. close this window and use the decisionSurface function')
        pylab.disconnect(binding_id)
    elif event.key in ('1', '2') and event.inaxes is not None:
        print('data coords %s %s' % (event.xdata, event.ydata))
        X.append([event.xdata, event.ydata])
        # the label is the pressed key itself ('1' or '2')
        Y.append(event.key)
        pylab.plot([event.xdata], [event.ydata],
                   plotStr[int(event.key) - 1])
        pylab.draw()
58
60
def getData():
    """Open a figure in which data points are entered from the keyboard.

    Connects ``pick`` to the figure's key-press events, remembers the
    connection id in the module-level ``binding_id`` and then shows the
    figure.
    """
    global binding_id

    pylab.subplot(111)

    # Draw the four corners of the region as black dots
    # (presumably so that the axes span the whole region -- confirm).
    cornersX = [xmin, xmin, xmax, xmax]
    cornersY = [ymin, ymax, ymin, ymax]
    pylab.plot(cornersX, cornersY, '.k')

    pylab.title("press the numbers 1 or 2 to generate data points and 'q' to quit")
    binding_id = pylab.connect('key_press_event', pick)
    pylab.show()
67
69
def setData(data_):
    """Store ``data_`` in the module-level ``data`` variable.

    NOTE(review): the ``def`` line was missing from the reviewed source; the
    function name is reconstructed -- confirm against the original module.
    """
    global data
    data = data_
73
75
def scatter(data, markersize=5):
    """Plot the first two coordinates of every pattern, one style per class.

    data       -- dataset providing ``labels.numClasses``, ``labels.classes``
                  (pattern ids per class) and ``getPattern(id)``.
    markersize -- marker size handed to ``pylab.plot``.

    Class ``c`` is drawn with the format string ``plotStr[c]``.
    """
    for classIndex in range(data.labels.numClasses):
        members = data.labels.classes[classIndex]
        coords = [(pattern[0], pattern[1])
                  for pattern in (data.getPattern(p) for p in members)]
        firstCoord = [pair[0] for pair in coords]
        secondCoord = [pair[1] for pair in coords]
        pylab.plot(firstCoord, secondCoord, plotStr[classIndex],
                   markersize=markersize)
84
85
87
def decisionSurface(classifier, fileName=None, **args):
    """Train a classifier on the current data and plot its decision function.

    The classifier is trained on the module-level ``data`` and evaluated on a
    regular grid over [xmin, xmax) x [ymin, ymax).  The decision values are
    drawn as an image with contour lines, and the data points are plotted on
    top.

    classifier -- object providing ``train(data)`` and ``test(data)``; the
                  value returned by ``test`` must have a ``decisionFunc``
                  that can be indexed by grid point number.
    fileName   -- if not None, the figure is saved to this file.

    Optional keyword arguments:
    numContours     -- 1 draws only the zero level, 3 (the default) draws the
                       levels -1, 0 and 1; any other value is handed to
                       ``pylab.contour`` as its levels argument.
    title           -- figure title (default: no title).
    markersize      -- size of the data point markers (default 5).
    fontsize        -- font size of tick labels and title (default 'medium').
    contourFontsize -- font size of the contour labels (default 10).
    showColorbar    -- whether to add a colorbar (default False).
    show            -- whether to call ``pylab.show()``; defaults to True
                       unless ``fileName`` is given.
    delta           -- spacing of the evaluation grid (default 0.01).
    """
    global data
    classifier.train(data)

    numContours = args.get('numContours', 3)
    title = args.get('title')
    markersize = args.get('markersize', 5)
    fontsize = args.get('fontsize', 'medium')
    contourFontsize = args.get('contourFontsize', 10)
    showColorbar = args.get('showColorbar', False)
    # saving to a file suppresses the interactive window unless asked for
    show = args.get('show', fileName is None)
    delta = args.get('delta', 0.01)

    x = arange(xmin, xmax, delta)
    y = arange(ymin, ymax, delta)
    numPoints = len(x) * len(y)

    # One row per grid point; point number i * len(y) + j is (x[i], y[j]).
    # The builtin float is used as dtype because the numpy.float_ alias
    # no longer exists in NumPy 2.
    gridX = numpy.empty((numPoints, 2), float)
    gridX[:, 0] = numpy.repeat(x, len(y))
    gridX[:, 1] = numpy.tile(y, len(x))

    gridData = VectorDataSet(gridX)
    gridData.attachKernel(data.kernel)
    results = classifier.test(gridData)

    # Gather the decision values in grid point order, fold them back into
    # an (x, y) array, and transpose so that rows run along y for plotting.
    decisionFunc = results.decisionFunc
    Z = numpy.array([decisionFunc[n] for n in range(numPoints)], float)
    Zt = numpy.transpose(Z.reshape(len(x), len(y)))

    extent = (xmin, xmax, ymin, ymax)
    im = pylab.imshow(Zt,
                      interpolation='bilinear', origin='lower',
                      cmap=pylab.cm.gray, extent=extent)

    if numContours == 1:
        # decision boundary only
        C = pylab.contour(Zt, [0],
                          origin='lower',
                          linewidths=3,
                          colors='black',
                          extent=extent)
    elif numContours == 3:
        # decision boundary (thick) plus the -1 and +1 levels
        C = pylab.contour(Zt, [-1, 0, 1],
                          origin='lower',
                          linewidths=(1, 3, 1),
                          colors='black',
                          extent=extent)
    else:
        C = pylab.contour(Zt, numContours,
                          origin='lower',
                          linewidths=2,
                          extent=extent)

    pylab.clabel(C,
                 inline=1,
                 fmt='%1.1f',
                 fontsize=contourFontsize)

    scatter(data, markersize)

    axes = pylab.gca()
    pylab.setp(pylab.getp(axes, 'xticklabels'), fontsize=fontsize)
    pylab.setp(pylab.getp(axes, 'yticklabels'), fontsize=fontsize)

    if title is not None:
        pylab.title(title, fontsize=fontsize)
    if showColorbar:
        pylab.colorbar(im)

    # switch the colormap to 'hot' before saving or showing the figure
    pylab.hot()
    if fileName is not None:
        pylab.savefig(fileName)
    if show:
        pylab.show()
191