Neuroph
java.lang.Object
java.util.Observable
org.neuroph.core.learning.LearningRule
org.neuroph.core.learning.IterativeLearning
org.neuroph.core.learning.SupervisedLearning
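public abstract class SupervisedLearning
extends IterativeLearning
Base class for all supervised learning algorithms. Building on IterativeLearning, it provides the logic common to supervised learning: it runs learning epochs over a training set, computes the error for each training pattern, and checks the error-based stopping conditions, leaving the weight update and the accumulation of the total network error to subclasses.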
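SupervisedLearning follows a template-method design: doLearningEpoch drives the loop over training elements, while subclasses supply updateNetworkWeights and updateTotalNetworkError. The following is a minimal plain-Java sketch of that split, assuming no Neuroph dependency. MiniSupervisedLearning, LinearLmsRule, computeOutput and the double[][] training pairs are all hypothetical, the order of calls inside learnPattern is an assumption, and the LMS (delta) rule is only one example of what a subclass might implement.

```java
import java.util.List;

// Hypothetical stand-in for SupervisedLearning (not the Neuroph API): it fixes the
// per-epoch loop and leaves the algorithm-specific steps to abstract hooks.
abstract class MiniSupervisedLearning {
    protected double totalNetworkError;
    protected double previousEpochError;

    // One pass over all training pairs; element[0] is the input, element[1] the desired output.
    public void doLearningEpoch(List<double[][]> trainingSet) {
        previousEpochError = totalNetworkError;
        totalNetworkError = 0.0;
        for (double[][] element : trainingSet) {
            learnPattern(element[0], element[1]);
        }
    }

    protected void learnPattern(double[] input, double[] desiredOutput) {
        double[] output = computeOutput(input);
        double[] patternError = getPatternError(output, desiredOutput);
        updateNetworkWeights(patternError);
        updateTotalNetworkError(patternError);
    }

    // Sign convention assumed here: desired minus actual.
    protected double[] getPatternError(double[] output, double[] desiredOutput) {
        double[] error = new double[output.length];
        for (int i = 0; i < output.length; i++) {
            error[i] = desiredOutput[i] - output[i];
        }
        return error;
    }

    protected abstract double[] computeOutput(double[] input);

    protected abstract void updateNetworkWeights(double[] patternError);

    protected abstract void updateTotalNetworkError(double[] patternError);
}

// Hypothetical subclass: one linear neuron y = w.x + b trained with the LMS (delta) rule.
class LinearLmsRule extends MiniSupervisedLearning {
    final double[] weights;
    double bias;
    private final double learningRate;
    private double[] lastInput;

    LinearLmsRule(int inputCount, double learningRate) {
        this.weights = new double[inputCount];
        this.learningRate = learningRate;
    }

    @Override
    protected double[] computeOutput(double[] input) {
        lastInput = input; // cached so the weight update can see the current input
        double sum = bias;
        for (int i = 0; i < input.length; i++) {
            sum += weights[i] * input[i];
        }
        return new double[] {sum};
    }

    @Override
    protected void updateNetworkWeights(double[] patternError) {
        double e = patternError[0]; // single output neuron
        for (int i = 0; i < weights.length; i++) {
            weights[i] += learningRate * e * lastInput[i];
        }
        bias += learningRate * e;
    }

    @Override
    protected void updateTotalNetworkError(double[] patternError) {
        // Adds half the squared error of this pattern to the epoch total.
        for (double e : patternError) {
            totalNetworkError += 0.5 * e * e;
        }
    }
}
```

Because the pairs (0, 1), (0.5, 2) and (1, 3) lie exactly on y = 2x + 1, repeated calls to doLearningEpoch with a learning rate of 0.1 drive weights[0] toward 2, bias toward 1, and the per-epoch error toward zero.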
Field Summary | |
---|---|
protected double | maxError: Maximum allowed network error (stopping condition for learning) |
protected double | previousEpochError: Total network error in the previous epoch |
protected double | totalNetworkError: Total network error in the current epoch. TODO: this field should become transient in the future |
Fields inherited from class org.neuroph.core.learning.IterativeLearning |
---|
currentIteration, iterationsLimited, learningRate, maxIterations |
Fields inherited from class org.neuroph.core.learning.LearningRule |
---|
neuralNetwork |
Constructor Summary | |
---|---|
SupervisedLearning() | Creates a new supervised learning rule. |
SupervisedLearning(NeuralNetwork network) | Creates a new supervised learning rule and sets the neural network to train. |
Method Summary | |
---|---|
void | doLearningEpoch(TrainingSet trainingSet): Implements the basic logic of one learning epoch for supervised learning algorithms. |
protected boolean | errorChangeStalled(): Returns true if the absolute error change has been sufficiently small (<= minErrorChange) for the number of iterations set by setMinErrorChangeIterationsLimit. |
double | getMaxError(): Returns the learning error tolerance, i.e. the total network error at which learning stops. |
double | getMinErrorChange(): Returns the min error change stopping criterion. |
int | getMinErrorChangeIterationsCount(): Returns the iteration count for the min error change stopping criterion. |
int | getMinErrorChangeIterationsLimit(): Returns the number of iterations for the min error change stopping criterion. |
protected java.util.Vector<java.lang.Double> | getPatternError(java.util.Vector<java.lang.Double> output, java.util.Vector<java.lang.Double> desiredOutput): Calculates the network error for the current pattern, i.e. the difference between the desired and actual output. |
double | getPreviousEpochError(): Returns the total network error in the previous learning epoch. |
java.lang.Double | getTotalNetworkError(): Returns the total network error in the current learning epoch. |
protected boolean | hasReachedStopCondition(): Returns true if the stop condition has been reached, false otherwise. |
void | learn(TrainingSet trainingSet, double maxError): Trains the network on the specified training set, using maxError as the learning error tolerance. |
void | learn(TrainingSet trainingSet, double maxError, int maxIterations): Trains the network on the specified training set, using maxError as the learning error tolerance and maxIterations as the iteration limit. |
protected void | learnPattern(SupervisedTrainingElement trainingElement): Trains the network with the pattern from the specified training element. |
protected void | reset(): Resets the iteration counter. |
void | setMaxError(double maxError): Sets the allowed network error, which determines when to stop training. |
void | setMinErrorChange(double minErrorChange): Sets the min error change stopping criterion. |
void | setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit): Sets the number of iterations for the min error change stopping criterion. |
protected abstract void | updateNetworkWeights(java.util.Vector<java.lang.Double> patternError): Implements the weight update procedure; must be provided by subclasses. |
protected abstract void | updateTotalNetworkError(java.util.Vector<java.lang.Double> patternError): Updates the total network error for each training pattern; must be provided by subclasses. |
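The summary above implies two stopping criteria: the total network error falling below maxError, and the error change staying at or below minErrorChange for a set number of iterations. The following is a minimal sketch of that logic under stated assumptions: the class StopCriteria is hypothetical, the strict < comparison against maxError and the >= comparison against the iterations limit are assumed, and resetting the counter whenever the change exceeds minErrorChange is also an assumption rather than documented behavior.

```java
// Hypothetical helper (not the Neuroph API) mirroring hasReachedStopCondition()
// and errorChangeStalled(); call it once per epoch.
class StopCriteria {
    private final double maxError;
    private final double minErrorChange;
    private final int minErrorChangeIterationsLimit;
    private int minErrorChangeIterationsCount = 0;

    StopCriteria(double maxError, double minErrorChange, int minErrorChangeIterationsLimit) {
        this.maxError = maxError;
        this.minErrorChange = minErrorChange;
        this.minErrorChangeIterationsLimit = minErrorChangeIterationsLimit;
    }

    boolean hasReachedStopCondition(double previousEpochError, double totalNetworkError) {
        // Stop when the error tolerance is met, or when the error has stopped improving.
        return totalNetworkError < maxError
                || errorChangeStalled(previousEpochError, totalNetworkError);
    }

    boolean errorChangeStalled(double previousEpochError, double totalNetworkError) {
        double absErrorChange = Math.abs(previousEpochError - totalNetworkError);
        if (absErrorChange <= minErrorChange) {
            minErrorChangeIterationsCount++;
        } else {
            // Assumption: any sufficiently large change restarts the count.
            minErrorChangeIterationsCount = 0;
        }
        return minErrorChangeIterationsCount >= minErrorChangeIterationsLimit;
    }

    int getMinErrorChangeIterationsCount() {
        return minErrorChangeIterationsCount;
    }
}
```

With a limit of 3, a run of three epochs whose error change is at most minErrorChange triggers the stop, while a single larger change in between restarts the count.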
Methods inherited from class org.neuroph.core.learning.IterativeLearning |
---|
doOneLearningIteration, getCurrentIteration, getLearningRate, isPausedLearning, learn, learn, pause, resume, setLearningRate, setMaxIterations |
Methods inherited from class org.neuroph.core.learning.LearningRule |
---|
getNeuralNetwork, getTrainingSet, isStopped, notifyChange, run, setNeuralNetwork, setTrainingSet, stopLearning |
Methods inherited from class java.util.Observable |
---|
addObserver, clearChanged, countObservers, deleteObserver, deleteObservers, hasChanged, notifyObservers, notifyObservers, setChanged |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Field Detail |
---|
protected double totalNetworkError
protected transient double previousEpochError
protected double maxError
Constructor Detail |
---|
public SupervisedLearning()
public SupervisedLearning(NeuralNetwork network)
Parameters:
network - network to train
Method Detail |
---|
public void learn(TrainingSet trainingSet, double maxError)
Parameters:
trainingSet - training set to learn
maxError - maximum allowed network error (learning error tolerance)
public void learn(TrainingSet trainingSet, double maxError, int maxIterations)
Parameters:
trainingSet - training set to learn
maxError - maximum allowed network error (learning error tolerance)
maxIterations - maximum number of iterations to learn
protected void reset()
Description copied from class: IterativeLearning
Resets the iteration counter.
Overrides:
reset in class IterativeLearning
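Both learn overloads run epochs until a stop condition holds; the three-argument form also caps the number of iterations, which corresponds to the inherited iterationsLimited, maxIterations and currentIteration fields, and reset() restarts the counter. The loop below is a hedged sketch of that control flow only: IterativeLoop and its methods are hypothetical, an IntToDoubleFunction stands in for running one epoch and returning its total network error, and the stop check is reduced to the maxError test. That the two-argument form runs without an iteration limit is an assumption.

```java
import java.util.function.IntToDoubleFunction;

// Hypothetical sketch of the two learn(...) overloads' control flow (not the Neuroph API).
// The IntToDoubleFunction stands in for "run epoch n and return the total network error".
final class IterativeLoop {
    private IterativeLoop() {}

    // Like learn(trainingSet, maxError): assumed to have no iteration limit.
    static int learn(IntToDoubleFunction epoch, double maxError) {
        return run(epoch, maxError, false, 0);
    }

    // Like learn(trainingSet, maxError, maxIterations).
    static int learn(IntToDoubleFunction epoch, double maxError, int maxIterations) {
        return run(epoch, maxError, true, maxIterations);
    }

    // Returns the number of epochs actually run.
    private static int run(IntToDoubleFunction epoch, double maxError,
                           boolean iterationsLimited, int maxIterations) {
        int currentIteration = 0; // the counter starts at zero, as after reset()
        while (true) {
            double totalNetworkError = epoch.applyAsDouble(currentIteration);
            currentIteration++;
            if (totalNetworkError < maxError) {
                break; // error tolerance reached
            }
            if (iterationsLimited && currentIteration >= maxIterations) {
                break; // iteration cap reached
            }
        }
        return currentIteration;
    }
}
```

Without an iteration limit this sketch only ends once the error drops below maxError; in SupervisedLearning the stall criterion (errorChangeStalled) presumably provides the additional exit.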
public void doLearningEpoch(TrainingSet trainingSet)
Specified by:
doLearningEpoch in class IterativeLearning
Parameters:
trainingSet - training set for training the network
protected boolean hasReachedStopCondition()
protected boolean errorChangeStalled()
protected void learnPattern(SupervisedTrainingElement trainingElement)
Parameters:
trainingElement - supervised training element which contains the input and desired output
protected java.util.Vector<java.lang.Double> getPatternError(java.util.Vector<java.lang.Double> output, java.util.Vector<java.lang.Double> desiredOutput)
Parameters:
output - actual network output
desiredOutput - desired network output
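getPatternError returns one error component per output neuron. The sketch below uses the same java.util.Vector<Double> types as the signature above; the class name PatternError is hypothetical, the desired-minus-actual sign convention is an assumption (some texts use actual minus desired), and the size check is an addition of this sketch.

```java
import java.util.Vector;

// Hypothetical stand-alone version of getPatternError (not the Neuroph implementation).
final class PatternError {
    private PatternError() {}

    // Element-wise difference desiredOutput[i] - output[i] (sign convention assumed).
    static Vector<Double> getPatternError(Vector<Double> output, Vector<Double> desiredOutput) {
        if (output.size() != desiredOutput.size()) {
            throw new IllegalArgumentException("output and desiredOutput must have the same size");
        }
        Vector<Double> patternError = new Vector<>(output.size());
        for (int i = 0; i < output.size(); i++) {
            patternError.add(desiredOutput.get(i) - output.get(i));
        }
        return patternError;
    }
}
```

For an actual output of [0.25, 1.0] and a desired output of [1.0, 0.0], this returns [0.75, -1.0].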
public void setMaxError(double maxError)
Parameters:
maxError - maximum allowed network error
public double getMaxError()
public java.lang.Double getTotalNetworkError()
public double getPreviousEpochError()
public double getMinErrorChange()
public void setMinErrorChange(double minErrorChange)
Parameters:
minErrorChange - value for the min error change stopping criterion
public int getMinErrorChangeIterationsLimit()
public void setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit)
Parameters:
minErrorChangeIterationsLimit - number of iterations for the min error change stopping criterion
public int getMinErrorChangeIterationsCount()
protected abstract void updateTotalNetworkError(java.util.Vector<java.lang.Double> patternError)
Parameters:
patternError - pattern error vector
protected abstract void updateNetworkWeights(java.util.Vector<java.lang.Double> patternError)
Parameters:
patternError - pattern error vector
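updateTotalNetworkError is left to subclasses, so how the per-pattern errors are combined into the total network error is a design choice. One common choice, assumed here rather than taken from Neuroph, is the mean squared error over all output components of all patterns in the epoch, which keeps the scale of maxError independent of the training set size. MseAccumulator and startEpoch are hypothetical names.

```java
import java.util.Vector;

// Hypothetical accumulator (not the Neuroph API): sums squared pattern errors during
// an epoch and reports the mean squared error per output component.
class MseAccumulator {
    private double sumSquaredError = 0.0;
    private int count = 0;

    // Called once per training pattern with that pattern's error vector.
    void updateTotalNetworkError(Vector<Double> patternError) {
        for (double e : patternError) {
            sumSquaredError += e * e;
            count++;
        }
    }

    // Mean squared error so far in this epoch (0.0 before any pattern is seen).
    double getTotalNetworkError() {
        return count == 0 ? 0.0 : sumSquaredError / count;
    }

    // Clears the accumulated error at the start of a new epoch.
    void startEpoch() {
        sumSquaredError = 0.0;
        count = 0;
    }
}
```

For pattern errors [1.0, -1.0] and [0.5, 0.5], the squared errors sum to 2.5 over four components, giving a total network error of 0.625.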