Package PyML :: Package classifiers :: Module baseClassifiers
[frames | no frames]

Source Code for Module PyML.classifiers.baseClassifiers

  1  import numpy 
  2  import time 
  3  import copy 
  4   
  5  from PyML.evaluators import assess 
  6  from PyML.utils import misc 
  7  from PyML.base.pymlObject import PyMLobject 
  8   
  9  """base class for PyML classifiers""" 
 10   
 11  __docformat__ = "restructuredtext en" 
 12   
     # class names of the dataset containers for which Classifier.train records
     # the training feature IDs and Classifier.project aligns test data to them;
     # data of any other container class is left untouched by project():
 13  containersRequiringProjection = ['VectorDataSet', 'PyVectorDataSet'] 
 14   
 15   
class Classifier(PyMLobject):
    """base class for PyML classifiers, specifying the classifier api

    Concrete classifiers are expected to override `classify` (and `save`
    if they can be written to a file); both raise NotImplementedError here.
    `twoClassClassify` additionally relies on a ``decisionFunc(data, i)``
    method that this base class does not define.
    """

    type = 'classifier'
    deepcopy = False

    # the type of Results object returned by testing a classifier:
    resultsObject = assess.ClassificationResults

    # evaluation routines of the assess module, exposed as methods:
    test = assess.test
    cv = assess.cv
    stratifiedCV = assess.stratifiedCV
    loo = assess.loo
    trainTest = assess.trainTest
    nCV = assess.nCV

    # timer used to measure training time.  time.clock is used wherever it
    # exists (so behaviour is unchanged on interpreters that provide it);
    # it was removed in Python 3.8, where time.process_time takes its place.
    _cpuTime = staticmethod(getattr(time, 'clock', None) or time.process_time)

    def __init__(self, arg=None, **args):
        """
        :Parameters:
          - `arg` - forwarded to ``PyMLobject.__init__``; when it is a
            string it is additionally passed to ``self.load``
            (``load`` is not defined in this class)
          - `args` - keyword arguments forwarded to ``PyMLobject.__init__``
        """

        PyMLobject.__init__(self, arg, **args)
        if isinstance(arg, str):
            self.load(arg)
        # container for run-time information such as the training time:
        self.log = misc.Container()

    def logger(self):
        """placeholder that does nothing in the base class"""

        pass

    def __repr__(self):

        return '<' + self.__class__.__name__ + ' instance>\n'

    def project(self, data):
        """
        project a test dataset to the training data features.

        Nothing is done when the class of `data` is not listed in
        ``containersRequiringProjection`` or when its feature IDs already
        equal the ones stored by `train`.  Otherwise every feature of `data`
        whose ID is not in ``self.featureDict`` is removed through
        ``data.eliminateFeatures`` (`data` is modified in place).

        :Raises: ValueError if some training feature is absent from `data`
        """

        if data.__class__.__name__ not in containersRequiringProjection:
            return
        if misc.listEqual(self.featureID, data.featureID):
            return

        # every training feature has to be present in the test data:
        if len(misc.intersect(self.featureID, data.featureID)) != len(self.featureID):
            raise ValueError('missing features in test data')

        featuresToEliminate = [i for i in range(data.numFeatures)
                               if data.featureID[i] not in self.featureDict]
        data.eliminateFeatures(featuresToEliminate)

    def save(self, fileHandle):
        """not implemented in the base class

        :Raises: NotImplementedError, always
        """

        raise NotImplementedError('your classifier does not implement this function')

    def train(self, data, **args):
        """
        book-keeping shared by all classifiers at the start of training:
        starts the training timer, stores the label and feature information
        needed later, and calls ``data.train(**args)``.

        :Raises: ValueError if the labels are not numeric and some class
          has no training examples
        """

        # store the current cpu time:
        self._clock = self._cpuTime()

        if not data.labels.numericLabels:
            # check if there is a class that is not represented in the training data:
            if min(data.labels.classSize) == 0:
                raise ValueError('there is a class with no data')

        # store just as much about the labels as is needed:
        self.labels = misc.Container()
        self.labels.addAttributes(data.labels, ['numClasses', 'classLabels'])
        if data.__class__.__name__ in containersRequiringProjection:
            # copies, so that later changes to data do not affect the classifier:
            self.featureID = data.featureID[:]
            self.featureDict = data.featureDict.copy()

        data.train(**args)
        # if there is some testing done on the data, it requires the training data:
        if data.testingFunc is not None:
            self.trainingData = data

    def trainFinalize(self):
        """record the time elapsed since `train` was called in the log"""

        self.log.trainingTime = self.getTrainingTime()

    def getTrainingTime(self):
        """return the time elapsed since `train` started its timer;
        `train` must have been called first (it sets ``self._clock``)"""

        return self._cpuTime() - self._clock

    def classify(self, data, i):
        """to be provided by subclasses

        :Raises: NotImplementedError, always
        """

        raise NotImplementedError

    def twoClassClassify(self, data, i):
        """
        threshold ``self.decisionFunc(data, i)`` at zero.

        :Returns: the tuple ``(1, val)`` when the decision value ``val`` is
          positive and ``(0, val)`` otherwise
        """

        val = self.decisionFunc(data, i)
        if val > 0:
            return (1, val)
        else:
            return (0, val)
110
class IteratorClassifier(Classifier):
    """
    a classifier that holds a sequence ``self.classifiers`` and whose
    `test`, `cv`, `stratifiedCV` and `loo` methods return an iterator
    (the object itself); each iteration step calls the method of the same
    name on the next classifier in ``self.classifiers`` and yields what
    that call returns.

    The iteration state is kept on the object, so only one iteration can be
    in progress at a time.  ``self.classifiers`` is not set in this class.
    """

    def __iter__(self):
        """restart the iteration before the first classifier"""

        self._classifierIdx = -1
        return self

    def getClassifier(self):
        """return the classifier used in the most recent iteration step,
        or None when no step has been taken since the iteration started"""

        if self._classifierIdx < 0:
            return None
        return self.classifiers[self._classifierIdx]

    def next(self):
        """
        advance to the next classifier and call the requested method on it.

        :Returns: the value returned by that classifier's method, called as
          ``method(data, **args)`` with the arguments given to
          `test` / `cv` / `stratifiedCV` / `loo`
        :Raises: StopIteration once all classifiers have been used
        """

        self._classifierIdx += 1
        # >= rather than ==, so that calling next() again on an exhausted
        # iterator keeps raising StopIteration instead of an IndexError:
        if self._classifierIdx >= len(self.classifiers):
            raise StopIteration
        func = getattr(self.classifiers[self._classifierIdx], self._method)

        return func(self._data, **self._args)

    # Python 3 spelling of the iterator protocol method:
    __next__ = next

    def _startIteration(self, method, data, args):
        """remember which method to call with which arguments and
        return a freshly reset iterator over the classifiers"""

        self._method = method
        self._data = data
        self._args = args
        return iter(self)

    def test(self, data, **args):
        """iterate over the results of ``test(data, **args)``
        for each classifier"""

        return self._startIteration('test', data, args)

    def cv(self, data, **args):
        """iterate over the results of ``cv(data, **args)``
        for each classifier"""

        return self._startIteration('cv', data, args)

    def stratifiedCV(self, data, **args):
        """iterate over the results of ``stratifiedCV(data, **args)``
        for each classifier"""

        return self._startIteration('stratifiedCV', data, args)

    def loo(self, data, **args):
        """iterate over the results of ``loo(data, **args)``
        for each classifier"""

        return self._startIteration('loo', data, args)
160