Menu

[572a65]: / src / junit / classification / TestKNN.java  Maximize  Restore  History

Download this file

142 lines (122 with data), 4.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
/**
* %SVN.HEADER%
*/
package junit.classification;
import java.io.File;
import java.io.IOException;
import java.util.Map;
import java.util.Random;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.classification.KNearestNeighbors;
import net.sf.javaml.classification.evaluation.CrossValidation;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.core.exception.TrainingRequiredException;
import net.sf.javaml.tools.InstanceTools;
import net.sf.javaml.tools.data.FileHandler;
import org.junit.Assert;
import org.junit.Test;
import be.abeel.util.TimeInterval;
/**
 * JUnit tests for {@link KNearestNeighbors}: use before training, training on
 * single-class data, cross-validation on the Iris data set, and classification
 * of a small sparse data set.
 *
 * @author Thomas Abeel
 *
 */
public class TestKNN {

    /** Number of attributes in each randomly generated instance. */
    private static final int DIMENSIONS = 5;

    /** Number of instances in a generated single-class data set. */
    private static final int SINGLE_CLASS_SIZE = 10;

    /** Class label shared by every instance of a generated single-class data set. */
    private static final String SINGLE_CLASS_LABEL = "class";

    /** Relative path of the sparse sample used by {@link #testSparseKNN()}. */
    private static final String SPARSE_SAMPLE = "devtools/data/smallsparse.tsv";

    /** Maximum number of sparse instances classified by {@link #testSparseKNN()}. */
    private static final int SPARSE_INSTANCES_TO_CLASSIFY = 15;

    /**
     * Builds a data set of {@link #SINGLE_CLASS_SIZE} random instances that
     * all carry the label {@link #SINGLE_CLASS_LABEL}.
     *
     * @return a new single-class data set
     */
    private static Dataset singleClassDataset() {
        Dataset data = new DefaultDataset();
        for (int i = 0; i < SINGLE_CLASS_SIZE; i++) {
            Instance is = InstanceTools.randomInstance(DIMENSIONS);
            is.setClassValue(SINGLE_CLASS_LABEL);
            data.add(is);
        }
        return data;
    }

    /**
     * Requesting a class distribution before {@code buildClassifier} has been
     * called must raise {@link TrainingRequiredException}.
     */
    @Test(expected = TrainingRequiredException.class)
    public void testNoTraining() {
        KNearestNeighbors knn = new KNearestNeighbors(1);
        knn.classDistribution(InstanceTools.randomInstance(DIMENSIONS));
    }

    /**
     * Trains a 1-NN classifier on single-class data. The distribution for an
     * unseen instance must contain that class with a non-NaN weight.
     */
    @Test
    public void testSingleClass2() {
        KNearestNeighbors knn = new KNearestNeighbors(1);
        knn.buildClassifier(singleClassDataset());
        Map<Object, Double> distr = knn.classDistribution(InstanceTools.randomInstance(DIMENSIONS));
        System.out.println(distr);
        Double weight = distr.get(SINGLE_CLASS_LABEL);
        // Check for null first so a missing class is reported as an assertion
        // failure rather than as a NullPointerException while unboxing.
        Assert.assertNotNull("no entry for the training class", weight);
        Assert.assertFalse(Double.isNaN(weight));
    }

    /**
     * Trains a 1-NN classifier on single-class data. The distribution for an
     * unseen instance must mention that class.
     */
    @Test
    public void testSingleClass() {
        // The original loop labelled the instances but never added them, so
        // the classifier was trained on an empty data set.
        Dataset data = singleClassDataset();
        KNearestNeighbors knn = new KNearestNeighbors(1);
        knn.buildClassifier(data);
        Map<Object, Double> distr = knn.classDistribution(InstanceTools.randomInstance(DIMENSIONS));
        System.out.println(distr);
        Assert.assertTrue("distribution lacks the training class", distr.containsKey(SINGLE_CLASS_LABEL));
    }

    /**
     * Runs 5-fold cross-validation of a 5-NN classifier on the Iris data.
     * The loader receives column index 4, presumably as the class column.
     *
     * @throws IOException if the data file cannot be read; it is propagated
     *             so the test fails instead of passing silently
     */
    @Test
    public void testKNNIris() throws IOException {
        Dataset data = FileHandler.loadDataset(new File("devtools/data/iris.data"), 4, ",");
        System.out.println("Loader: " + data.classes());
        Classifier knn = new KNearestNeighbors(5);
        CrossValidation cv = new CrossValidation(knn);
        // Seeded Random so the run is repeatable.
        Map<Object, PerformanceMeasure> p = cv.crossValidation(data, 5, new Random(10));
        System.out.println(p);
        Assert.assertFalse("cross-validation produced no results", p.isEmpty());
    }

    /**
     * Test KNN usage on the sparse sample. A 5-NN classifier is trained on the
     * sparse file. It then classifies at most the first
     * {@link #SPARSE_INSTANCES_TO_CLASSIFY} instances of a second load of the
     * same file, printing per-instance timing and the correct/wrong counts.
     *
     * @throws IOException if the sparse sample cannot be read
     */
    @Test
    public void testSparseKNN() throws IOException {
        /* Load a data set */
        Dataset data = FileHandler.loadSparseDataset(new File(SPARSE_SAMPLE), 0, "\t", ":");
        /*
         * Construct a KNN classifier that uses 5 neighbors to make a
         * decision.
         */
        Classifier knn = new KNearestNeighbors(5);
        knn.buildClassifier(data);
        System.out.println("Building complete!");
        /*
         * Load a data set, this can be a different one, but we will use the
         * same one.
         */
        Dataset dataForClassification = FileHandler.loadSparseDataset(new File(SPARSE_SAMPLE), 0, "\t", ":");
        /* Counters for correct and wrong predictions. */
        int correct = 0;
        int wrong = 0;
        // Clamp so a sample with fewer instances does not index out of bounds.
        int limit = Math.min(SPARSE_INSTANCES_TO_CLASSIFY, dataForClassification.size());
        /* Classify the instances and check with the correct class values */
        for (int i = 0; i < limit; i++) {
            Instance inst = dataForClassification.instance(i);
            long time = System.currentTimeMillis();
            System.out.print("Processing instance: " + (i + 1) + "\t");
            Object predictedClassValue = knn.classify(inst);
            Object realClassValue = inst.classValue();
            if (predictedClassValue.equals(realClassValue)) {
                correct++;
            } else {
                wrong++;
            }
            System.out.println(new TimeInterval(System.currentTimeMillis() - time));
        }
        System.out.println("Correct predictions " + correct);
        System.out.println("Wrong predictions " + wrong);
    }
}
Want the latest updates on software, tech news, and AI?
Get the latest updates about software, tech news, and AI from SourceForge delivered to your inbox once a month.