"F" "N" "A" "D" "X" "Y" "Z" "M"
"F" "N" "A" "D" "X" "Y" "Z" "M"
multilayerperceptron;
import java.util.ArrayList;
import java.util.Collection;
import weka.classifiers.Evaluation;
import weka.classifiers.neural.common.NeuralModel;
import weka.classifiers.neural.common.SimpleNeuron;
import weka.classifiers.neural.common.WekaAlgorithmAncestor;
import weka.classifiers.neural.common.learning.LearningKernelFactory;
import weka.classifiers.neural.common.learning.LearningRateKernel;
import weka.classifiers.neural.common.training.TrainerFactory;
import weka.classifiers.neural.common.transfer.TransferFunction;
import weka.classifiers.neural.common.transfer.TransferFunctionFactory;
import
weka.classifiers.neural.multilayerperceptron.algorithm.BackPropagationAlgorith
m;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Utils;
/**
* <p>Title: Weka Neural Implementation</p>
* <p>Description: ...</p>
* <p>Copyright: Copyright (c) 2003</p>
* <p>Company: N/A</p>
* @author Jason Brownlee
* @version 1.0
*/
/**
 * Command-line flags for the extra options of this algorithm. Entry i is
 * paired with entry i of {@link #EXTRA_PARAMETER_NOTES}, which holds the
 * matching usage placeholder.
 */
public static final String[] EXTRA_PARAMETERS = {
    "F", // transfer function
    "N", // training mode
    "A", // momentum
    "D", // weight decay
    "X", // hidden layer 1 num nodes
    "Y", // hidden layer 2 num nodes
    "Z", // hidden layer 3 num nodes
    "M"  // learning rate function
};

/**
 * Usage placeholders for the flags in {@link #EXTRA_PARAMETERS}, in the same
 * order (entry i describes the argument of flag i).
 */
public static final String[] EXTRA_PARAMETER_NOTES = {
    "<transfer function>",         // transfer function
    "<training mode>",             // training mode
    "<momentum value>",            // momentum
    "<weight decay value>",        // weight decay
    "<total first layer nodes>",   // hidden layer 1 num nodes
    "<total second layer nodes>",  // hidden layer 2 num nodes
    "<total third layer nodes>",   // hidden layer 3 num nodes
    "<learning function>"          // learning rate function
};
/** Momentum setting (command-line flag "A"). */
protected double momentum = 0.0;

/** Weight decay setting (command-line flag "D"). */
protected double weightDecay = 0.0;

/** Node counts for hidden layers one, two and three (flags "X", "Y", "Z"). */
protected int hiddenLayer1 = 0, hiddenLayer2 = 0, hiddenLayer3 = 0;
/**
 * Builds a learner preloaded with default settings.
 *
 * <ul>
 *   <li>sigmoid transfer function</li>
 *   <li>batch training for 500 iterations</li>
 *   <li>the default neuron bias input</li>
 *   <li>a static learning-rate function with rate 0.1</li>
 *   <li>random seed 0</li>
 *   <li>momentum 0.2</li>
 *   <li>no weight decay</li>
 *   <li>zero nodes in each of the three hidden layers</li>
 * </ul>
 */
public BackPropagation()
{
    // General training settings. These fields are not declared in this
    // section; NOTE(review): presumably inherited from WekaAlgorithmAncestor.
    this.transferFunction = TransferFunctionFactory.TRANSFER_SIGMOID;
    this.trainingMode = TrainerFactory.TRAINER_BATCH;
    this.trainingIterations = 500;
    this.biasInput = SimpleNeuron.DEFAULT_BIAS_VALUE;
    this.learningRate = 0.1;
    this.learningRateFunction = LearningKernelFactory.LEARNING_FUNCTION_STATIC;
    this.randomNumberSeed = 0;

    // Back-propagation specific settings.
    this.momentum = 0.2;
    this.weightDecay = 0.0;

    // Topology: every hidden layer starts empty.
    this.hiddenLayer1 = 0;
    this.hiddenLayer2 = 0;
    this.hiddenLayer3 = 0;
}
return algorithm;
list.add("-"+EXTRA_PARAMETERS[PARAM_TRANSFER_FUNCTION]);
list.add(Integer.toString(transferFunction));
list.add("-"+EXTRA_PARAMETERS[PARAM_TRAINING_MODE]);
list.add(Integer.toString(trainingMode));
list.add("-"+EXTRA_PARAMETERS[PARAM_MOMENTUM]);
list.add(Double.toString(momentum));
list.add("-"+EXTRA_PARAMETERS[PARAM_WEIGHT_DECAY]);
list.add(Double.toString(weightDecay));
list.add("-"+EXTRA_PARAMETERS[PARAM_HIDDEN_1]);
list.add(Integer.toString(hiddenLayer1));
list.add("-"+EXTRA_PARAMETERS[PARAM_HIDDEN_2]);
list.add(Integer.toString(hiddenLayer2));
list.add("-"+EXTRA_PARAMETERS[PARAM_HIDDEN_3]);
list.add(Integer.toString(hiddenLayer3));
list.add("-"+EXTRA_PARAMETERS[PARAM_LEARNING_RATE_FUNCTION]);
list.add(Integer.toString(learningRateFunction));
return list;
}
return list;
}
return buffer.toString();
}
switch(i)
{
case PARAM_TRANSFER_FUNCTION:
{
transferFunction =
Integer.parseInt(data);
break;
}
case PARAM_TRAINING_MODE:
{
trainingMode = Integer.parseInt(data);
break;
}
case PARAM_MOMENTUM:
{
momentum = Double.parseDouble(data);
break;
}
case PARAM_WEIGHT_DECAY:
{
weightDecay = Double.parseDouble(data);
break;
}
case PARAM_HIDDEN_1:
{
hiddenLayer1 = Integer.parseInt(data);
break;
}
case PARAM_HIDDEN_2:
{
hiddenLayer2 = Integer.parseInt(data);
break;
}
case PARAM_HIDDEN_3:
{
hiddenLayer3 = Integer.parseInt(data);
break;
}
case PARAM_LEARNING_RATE_FUNCTION:
{
learningRateFunction =
Integer.parseInt(data);
break;
}
default:
{
throw new Exception("Invalid option
offset: " + i);
}
}
}
}
/**
 * Command-line entry point for using the algorithm directly.
 *
 * <p>The method works as follows:</p>
 * <ul>
 *   <li>It evaluates a freshly constructed {@code BackPropagation} with
 *       {@link Evaluation#evaluateModel}, passing {@code args} through
 *       unchanged.</li>
 *   <li>It prints the resulting text to standard output.</li>
 *   <li>Any exception is caught and reported on standard output; the
 *       process does not exit with an error status.</li>
 * </ul>
 *
 * <p>If the exception has no message, its {@code toString()} (class name)
 * is printed instead of the literal text "null", so the failure is still
 * identifiable.</p>
 *
 * @param args command-line options forwarded to Evaluation.evaluateModel
 */
public static void main(String [] args)
{
    try
    {
        System.out.println(Evaluation.evaluateModel(new BackPropagation(), args));
    }
    catch (Exception e)
    {
        // The message is printed as-is to stdout, as before.
        // NOTE(review): Weka's Evaluation appears to deliver usage/help text
        // through the exception message; confirm before changing the stream.
        String message = e.getMessage();
        // println((String) null) would print "null"; fall back to
        // toString(), which includes the exception class name.
        System.out.println(message != null ? message : e.toString());
    }
}
}