Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add continual learning #minor #131

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions hiclass/HierarchicalClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
bert: bool = False,
classifier_abbreviation: str = "",
tmp_dir: str = None,
warm_start: bool = False,
):
"""
Initialize a local hierarchical classifier.
Expand Down Expand Up @@ -99,6 +100,9 @@ def __init__(
tmp_dir : str, default=None
Temporary directory to persist local classifiers that are trained. If the job needs to be restarted,
it will skip the pre-trained local classifier found in the temporary directory.
warm_start : bool, default=False
When set to true, the hierarchical classifier reuses the solution of the previous call to fit, that is,
new classes can be added.
"""
self.local_classifier = local_classifier
self.verbose = verbose
Expand All @@ -108,6 +112,7 @@ def __init__(
self.bert = bert
self.classifier_abbreviation = classifier_abbreviation
self.tmp_dir = tmp_dir
self.warm_start = warm_start

def fit(self, X, y, sample_weight=None):
"""
Expand Down Expand Up @@ -155,6 +160,8 @@ def _pre_fit(self, X, y, sample_weight):
else:
self.sample_weight_ = None

self.warm_start_ = self.warm_start

self.y_ = make_leveled(self.y_)

# Create and configure logger
Expand All @@ -164,7 +171,7 @@ def _pre_fit(self, X, y, sample_weight):
# which would generate the prediction a->b->c
self._disambiguate()

# Create DAG from self.y_ and store to self.hierarchy_
# Create or update DAG from self.y_ and store to self.hierarchy_
self._create_digraph()

# If user passes edge_list, then export
Expand Down Expand Up @@ -229,7 +236,7 @@ def _create_digraph(self):
self._create_digraph_2d()

if self.y_.ndim > 2:
# Unsuported dimension
# Unsupported dimension
self.logger_.error(f"y with {self.y_.ndim} dimensions detected")
raise ValueError(
f"Creating graph from y with {self.y_.ndim} dimensions is not supported"
Expand All @@ -250,7 +257,10 @@ def _create_digraph_1d(self):
def _create_digraph_2d(self):
if self.y_.ndim == 2:
# Create max_levels variable
self.max_levels_ = self.y_.shape[1]
if self.warm_start_:
self.max_levels_ = max(self.max_levels_, self.y_.shape[1])
else:
self.max_levels_ = self.y_.shape[1]
rows, columns = self.y_.shape
self.logger_.info(f"Creating digraph from {rows} 2D labels")
for row in range(rows):
Expand Down Expand Up @@ -296,7 +306,10 @@ def _add_artificial_root(self):
self.logger_.info(f"Detected {len(roots)} roots")

# Add artificial root as predecessor to root(s) detected
self.root_ = "hiclass::root"
if self.warm_start_:
roots.remove(self.root_)
else:
self.root_ = "hiclass::root"
for old_root in roots:
self.hierarchy_.add_edge(self.root_, old_root)

Expand Down
7 changes: 7 additions & 0 deletions hiclass/LocalClassifierPerLevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
n_jobs: int = 1,
bert: bool = False,
tmp_dir: str = None,
warm_start: bool = False,
):
"""
Initialize a local classifier per level.
Expand All @@ -79,6 +80,9 @@ def __init__(
tmp_dir : str, default=None
Temporary directory to persist local classifiers that are trained. If the job needs to be restarted,
it will skip the pre-trained local classifier found in the temporary directory.
warm_start : bool, default=False
When set to true, the hierarchical classifier reuses the solution of the previous call to fit, that is,
new classes can be added.
"""
super().__init__(
local_classifier=local_classifier,
Expand All @@ -89,6 +93,7 @@ def __init__(
classifier_abbreviation="LCPL",
bert=bert,
tmp_dir=tmp_dir,
warm_start=warm_start,
)

def fit(self, X, y, sample_weight=None):
Expand All @@ -115,6 +120,8 @@ def fit(self, X, y, sample_weight=None):
# Execute common methods necessary before fitting
super()._pre_fit(X, y, sample_weight)

# TODO: add partial_fit here if warm_start=True

# Fit local classifiers in DAG
super().fit(X, y)

Expand Down
7 changes: 7 additions & 0 deletions hiclass/LocalClassifierPerNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
n_jobs: int = 1,
bert: bool = False,
tmp_dir: str = None,
warm_start: bool = False,
):
"""
Initialize a local classifier per node.
Expand Down Expand Up @@ -85,6 +86,9 @@ def __init__(
tmp_dir : str, default=None
Temporary directory to persist local classifiers that are trained. If the job needs to be restarted,
it will skip the pre-trained local classifier found in the temporary directory.
warm_start : bool, default=False
When set to true, the hierarchical classifier reuses the solution of the previous call to fit, that is,
new classes can be added.
"""
super().__init__(
local_classifier=local_classifier,
Expand All @@ -95,6 +99,7 @@ def __init__(
classifier_abbreviation="LCPN",
bert=bert,
tmp_dir=tmp_dir,
warm_start=warm_start,
)
self.binary_policy = binary_policy

Expand Down Expand Up @@ -125,6 +130,8 @@ def fit(self, X, y, sample_weight=None):
# Initialize policy
self._initialize_binary_policy()

# TODO: add partial_fit here if warm_start=True

# Fit local classifiers in DAG
super().fit(X, y)

Expand Down
40 changes: 39 additions & 1 deletion hiclass/LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
n_jobs: int = 1,
bert: bool = False,
tmp_dir: str = None,
warm_start: bool = False,
):
"""
Initialize a local classifier per parent node.
Expand All @@ -72,6 +73,9 @@ def __init__(
tmp_dir : str, default=None
Temporary directory to persist local classifiers that are trained. If the job needs to be restarted,
it will skip the pre-trained local classifier found in the temporary directory.
warm_start : bool, default=False
When set to true, the hierarchical classifier reuses the solution of the previous call to fit, that is,
new classes can be added.
"""
super().__init__(
local_classifier=local_classifier,
Expand All @@ -82,6 +86,7 @@ def __init__(
classifier_abbreviation="LCPPN",
bert=bert,
tmp_dir=tmp_dir,
warm_start=warm_start,
)

def fit(self, X, y, sample_weight=None):
Expand Down Expand Up @@ -165,6 +170,38 @@ def predict(self, X):

return y

def partial_fit(self, X, y, sample_weight=None):
    """
    Add new parent nodes for the local classifier per parent node.

    Reuses the solution of the previous call to :meth:`fit`: the existing
    hierarchy DAG is updated in place and already-trained local classifiers
    are kept, so only classifiers for new parent nodes are created.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. Internally, its dtype will be converted
        to ``dtype=np.float32``. If a sparse matrix is provided, it will be
        converted into a sparse ``csc_matrix``.
    y : array-like of shape (n_samples, n_levels)
        The target values, i.e., hierarchical class labels for classification.
    sample_weight : array-like of shape (n_samples,), default=None
        Array of weights that are assigned to individual samples.
        If not provided, then each sample is given unit weight.

    Returns
    -------
    self : object
        Fitted estimator.
    """
    # Force warm start via the public flag: _pre_fit copies
    # self.warm_start into self.warm_start_ (see HierarchicalClassifier),
    # so assigning only self.warm_start_ here would be overwritten and
    # the DAG update / root reuse logic would never trigger.
    self.warm_start = True
    self.warm_start_ = True

    # Execute common methods necessary before fitting
    super()._pre_fit(X, y, sample_weight)

    # Fit local classifiers in DAG
    super().fit(X, y)

    # Return the classifier
    return self

def _predict_remaining_levels(self, X, y):
for level in range(1, y.shape[1]):
predecessors = set(y[:, level - 1])
Expand All @@ -183,7 +220,8 @@ def _initialize_local_classifiers(self):
local_classifiers = {}
nodes = self._get_parents()
for node in nodes:
local_classifiers[node] = {"classifier": deepcopy(self.local_classifier_)}
if "classifier" not in self.hierarchy_.nodes[node]:
local_classifiers[node] = {"classifier": deepcopy(self.local_classifier_)}
nx.set_node_attributes(self.hierarchy_, local_classifiers)

def _get_parents(self):
Expand Down
56 changes: 56 additions & 0 deletions tests/test_HierarchicalClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ def test_create_digraph_1d(graph_1d):
assert list(ground_truth.edges) == list(graph_1d.hierarchy_.edges)


def test_update_digraph_1d(graph_1d):
    """Re-running _create_digraph_1d must keep the graph and node attributes intact."""
    expected = nx.DiGraph()
    expected.add_nodes_from(np.array(["a", "b", "c", "d", "e", "f"]))
    graph_1d._create_digraph()
    # Mark one node as already trained, then rebuild with new labels.
    nx.set_node_attributes(graph_1d.hierarchy_, {"a": {"trained_classifier": "yes"}})
    graph_1d.y_ = np.array(["a", "b", "c", "d", "e", "f"])
    graph_1d._create_digraph_1d()
    assert nx.is_isomorphic(expected, graph_1d.hierarchy_)
    assert list(expected.nodes) == list(graph_1d.hierarchy_.nodes)
    assert list(expected.edges) == list(graph_1d.hierarchy_.edges)
    # The pre-existing attribute must survive the update.
    assert graph_1d.hierarchy_.nodes["a"]["trained_classifier"] == "yes"


@pytest.fixture
def graph_1d_disguised_as_2d():
classifier = HierarchicalClassifier()
Expand All @@ -82,6 +97,8 @@ def digraph_2d():
classifier.logger_ = logging.getLogger("HC")
classifier.edge_list = tempfile.TemporaryFile()
classifier.separator_ = "::HiClass::Separator::"
classifier.warm_start_ = True
classifier.max_levels_ = 3
return classifier


Expand All @@ -93,6 +110,31 @@ def test_create_digraph_2d(digraph_2d):
assert list(ground_truth.edges) == list(digraph_2d.hierarchy_.edges)


def test_update_digraph_2d(digraph_2d):
    """Updating a 2D digraph must merge new paths, keep attributes, and grow max_levels_."""
    expected_edges = [
        ("a", "b"),
        ("b", "c"),
        ("d", "e"),
        ("e", "f"),
        ("g", "h"),
        ("h", "i"),
        ("i", "j"),
    ]
    expected = nx.DiGraph(expected_edges)
    digraph_2d._create_digraph()
    # Tag an existing node, then merge in a deeper label path.
    nx.set_node_attributes(digraph_2d.hierarchy_, {"b": {"trained_classifier": "yes"}})
    digraph_2d.y_ = np.array([["g", "h", "i", "j"]])
    digraph_2d._create_digraph_2d()
    assert nx.is_isomorphic(expected, digraph_2d.hierarchy_)
    assert list(expected.nodes) == list(digraph_2d.hierarchy_.nodes)
    assert list(expected.edges) == list(digraph_2d.hierarchy_.edges)
    # Attribute set before the update must still be present.
    assert digraph_2d.hierarchy_.nodes["b"]["trained_classifier"] == "yes"
    # The new path has four levels, so max_levels_ must be raised.
    assert digraph_2d.max_levels_ == 4


@pytest.fixture
def digraph_3d():
classifier = HierarchicalClassifier()
Expand Down Expand Up @@ -137,6 +179,7 @@ def digraph_one_root():
classifier = HierarchicalClassifier()
classifier.logger_ = logging.getLogger("HC")
classifier.hierarchy_ = nx.DiGraph([("a", "b"), ("b", "c"), ("c", "d")])
classifier.warm_start_ = False
return classifier


Expand All @@ -155,6 +198,7 @@ def digraph_multiple_roots():
classifier.X_ = np.array([[1, 2], [3, 4], [5, 6]])
classifier.y_ = np.array([["a", "b"], ["c", "d"], ["e", "f"]])
classifier.sample_weight_ = None
classifier.warm_start_ = False
return classifier


Expand All @@ -165,6 +209,17 @@ def test_add_artificial_root_multiple_roots(digraph_multiple_roots):
assert "hiclass::root" == digraph_multiple_roots.root_


def test_add_artificial_new_nodes(digraph_multiple_roots):
    """Warm-started _add_artificial_root must attach newly added orphan nodes to the existing root."""
    digraph_multiple_roots._add_artificial_root()
    # Simulate new classes appearing between fit calls.
    for new_node in ("g", "h"):
        digraph_multiple_roots.hierarchy_.add_node(new_node)
    digraph_multiple_roots.warm_start_ = True
    digraph_multiple_roots._add_artificial_root()
    root_children = list(
        digraph_multiple_roots.hierarchy_.successors("hiclass::root")
    )
    assert root_children == ["a", "c", "e", "g", "h"]
    assert digraph_multiple_roots.root_ == "hiclass::root"


def test_initialize_local_classifiers_2(digraph_multiple_roots):
digraph_multiple_roots.local_classifier = None
digraph_multiple_roots._initialize_local_classifiers()
Expand Down Expand Up @@ -224,6 +279,7 @@ def test_fit_digraph():
def test_pre_fit_bert():
classifier = HierarchicalClassifier()
classifier.logger_ = logging.getLogger("HC")
classifier.warm_start_ = False
classifier.bert = True
x = [[0, 1], [2, 3]]
y = [["a", "b"], ["c", "d"]]
Expand Down
Loading