# -*- coding: utf-8 -*-
"""
===========================
Binary Training Policies
===========================

The siblings policy is used by default on the local classifier per node, but the remaining policies can be selected with the parameter :literal:`binary_policy`, for example:

.. tabs::

    .. code-tab:: python
        :caption: Exclusive

        rf = RandomForestClassifier()
        classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="exclusive")

    .. code-tab:: python
        :caption: Less exclusive

        rf = RandomForestClassifier()
        classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="less_exclusive")

    .. code-tab:: python
        :caption: Less inclusive

        rf = RandomForestClassifier()
        classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="less_inclusive")

    .. code-tab:: python
        :caption: Inclusive

        rf = RandomForestClassifier()
        classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="inclusive")

    .. code-tab:: python
        :caption: Siblings

        rf = RandomForestClassifier()
        classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="siblings")

    .. code-tab:: python
        :caption: Exclusive siblings

        rf = RandomForestClassifier()
        classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="exclusive_siblings")

In the code below, the inclusive policy is selected.
However, the code can be easily updated by replacing the two lines that instantiate the classifier with the examples shown in the tabs above.

.. seealso::

    The mathematical definitions of the different policies are given at :ref:`Training Policies`.
"""
from sklearn.ensemble import RandomForestClassifier

from hiclass import LocalClassifierPerNode

# Define data
X_train = [[1], [2], [3], [4]]
X_test = [[4], [3], [2], [1]]
Y_train = [
    ["Animal", "Mammal", "Sheep"],
    ["Animal", "Mammal", "Cow"],
    ["Animal", "Reptile", "Snake"],
    ["Animal", "Reptile", "Lizard"],
]

# Use random forest classifiers for every node
# and the inclusive policy to select training examples for the binary classifiers.
rf = RandomForestClassifier()
classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="inclusive")

# Train local classifier per node
classifier.fit(X_train, Y_train)

# Predict
predictions = classifier.predict(X_test)
print(predictions)
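
# The block below is an illustrative sketch, not part of the original example.
# It assumes only the binary_policy values shown in the tabs above and simply
# re-trains with the default "siblings" policy so its predictions can be
# compared against those of the inclusive policy used above. The names
# rf_siblings and classifier_siblings are hypothetical and chosen for clarity.
rf_siblings = RandomForestClassifier()
classifier_siblings = LocalClassifierPerNode(
    local_classifier=rf_siblings, binary_policy="siblings"
)

# Train and predict with the alternative policy
classifier_siblings.fit(X_train, Y_train)
print(classifier_siblings.predict(X_test))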