From 8b782e973ab2828bec0f226344cf6054039fce87 Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Mon, 7 Nov 2022 10:13:53 +0100 Subject: [PATCH 01/21] Update HierarchicalClassifier.py Remove disambiguate function --- hiclass/HierarchicalClassifier.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index 2bb5580b..cf7e05de 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -188,18 +188,6 @@ def _create_logger(self): # Add ch to logger self.logger_.addHandler(ch) - def _disambiguate(self): - self.separator_ = "::HiClass::Separator::" - if self.y_.ndim == 2: - new_y = [] - for i in range(self.y_.shape[0]): - row = [str(self.y_[i, 0])] - for j in range(1, self.y_.shape[1]): - parent = str(row[-1]) - child = str(self.y_[i, j]) - row.append(parent + self.separator_ + child) - new_y.append(np.asarray(row, dtype=np.str_)) - self.y_ = np.array(new_y) def _create_digraph(self): # Create DiGraph From eee80ceb0584f01f3b0dd0142a141f7ba536cda9 Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Thu, 10 Nov 2022 10:12:57 +0100 Subject: [PATCH 02/21] Update .gitignore for vscode users --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4f3bf593..f34b738e 100644 --- a/.gitignore +++ b/.gitignore @@ -334,5 +334,6 @@ Thumbs.db # Ignore all local history of files .history .ionide +.vscode # End of https://www.toptal.com/developers/gitignore/api/python,pycharm,pycharm+all,visualstudiocode From 8778539c7e9a2ad5bd2598d66be55f3047a73b6c Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Thu, 10 Nov 2022 10:17:15 +0100 Subject: [PATCH 03/21] Remove node separator WIP --- hiclass/HierarchicalClassifier.py | 55 +++++++++++++++--------- hiclass/LocalClassifierPerLevel.py | 1 - hiclass/LocalClassifierPerNode.py | 6 +-- tests/test_BinaryPolicy.py | 61 +++++++++++++++++++++++++++ tests/test_HierarchicalClassifier.py | 63 ++++++++++++---------------- tests/test_LocalClassifierPerNode.py | 28 ++++++++----- 6 files changed, 142 insertions(+), 72 deletions(-) diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index cf7e05de..08068b7e 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -15,6 +15,9 @@ _has_ray = False else: _has_ray = True + + +ARTIFICIAL_ROOT = "hiclass::root" def make_leveled(y): @@ -146,7 +149,7 @@ def _pre_fit(self, X, y, sample_weight): # Avoids creating more columns in prediction if edges are a->b and b->c, # which would generate the prediction a->b->c - self._disambiguate() + # self._disambiguate() # Create DAG from self.y_ and store to self.hierarchy_ self._create_digraph() @@ -188,7 +191,6 @@ def _create_logger(self): # Add ch to logger self.logger_.addHandler(ch) - def _create_digraph(self): # Create DiGraph self.hierarchy_ = nx.DiGraph() @@ -221,20 +223,25 @@ def _create_digraph_1d(self): def _create_digraph_2d(self): if self.y_.ndim == 2: - # Create max_levels variable - self.max_levels_ = self.y_.shape[1] rows, columns = self.y_.shape + # Create max_levels variable + self.max_levels_ = columns + self.logger_.info(f"Creating digraph from {rows} 2D labels") + for row in range(rows): - for column in range(columns - 1): - parent = self.y_[row, column].split(self.separator_)[-1] - child = self.y_[row, column + 1].split(self.separator_)[-1] + for column in range(0, columns - 1): + + parent = self.y_[row, column] + child = self.y_[row, column + 1 ] + 
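+                # Illustration: with y = [["a", "b"], ["b", "c"]] this loop adds
+                # the edges ("a", "b") and ("b", "c"). Because labels are no
+                # longer prefixed with their parents via the separator, the
+                # repeated label "b" maps to a single node, so the hierarchy
+                # may now form a DAG instead of a tree.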
if parent != "" and child != "": # Only add edge if both parent and child are not empty self.hierarchy_.add_edge( - self.y_[row, column], self.y_[row, column + 1] + parent, child ) - elif parent != "" and column == 0: + if parent != "" and column == 0: + # Add parent node e self.hierarchy_.add_node(parent) def _export_digraph(self): @@ -243,7 +250,7 @@ def _export_digraph(self): # Add quotes to all nodes in case the text has commas mapping = {} for node in self.hierarchy_: - mapping[node] = '"{}"'.format(node.split(self.separator_)[-1]) + mapping[node] = '"{}"'.format(node) hierarchy = nx.relabel_nodes(self.hierarchy_, mapping, copy=True) # Export DAG to CSV file self.logger_.info(f"Writing edge list to file {self.edge_list}") @@ -261,14 +268,22 @@ def _convert_1d_y_to_2d(self): self.y_ = np.reshape(self.y_, (-1, 1)) def _add_artificial_root(self): + # Detect root(s) - roots = [ - node for node, in_degree in self.hierarchy_.in_degree() if in_degree == 0 - ] + columns = self.y_.shape[1] + + if columns > 1: + # roots are the first column of y + roots = set(self.y_[:,0]) + else: + roots = [ + node for node, in_degree in self.hierarchy_.in_degree() if in_degree == 0 + ] + self.logger_.info(f"Detected {len(roots)} roots") # Add artificial root as predecessor to root(s) detected - self.root_ = "hiclass::root" + self.root_ = ARTIFICIAL_ROOT for old_root in roots: self.hierarchy_.add_edge(self.root_, old_root) @@ -287,12 +302,12 @@ def _convert_to_1d(self, y): y = y.flatten() return y - def _remove_separator(self, y): - # Remove separator from predictions - if y.ndim == 2: - for i in range(y.shape[0]): - for j in range(1, y.shape[1]): - y[i, j] = y[i, j].split(self.separator_)[-1] + # def _remove_separator(self, y): + # # Remove separator from predictions + # if y.ndim == 2: + # for i in range(y.shape[0]): + # for j in range(1, y.shape[1]): + # y[i, j] = y[i, j].split(self.separator_)[-1] def _fit_node_classifier( self, nodes, local_mode: bool = False, use_joblib: bool = False diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index d54baf58..ebfed97b 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -152,7 +152,6 @@ def predict(self, X): y = self._convert_to_1d(y) - self._remove_separator(y) return y diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 21160c28..2205c3aa 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -53,7 +53,8 @@ def __init__( clone methods. binary_policy : {"exclusive", "less_exclusive", "exclusive_siblings", "inclusive", "less_inclusive", "siblings"}, str, default="siblings" Specify the rule for defining positive and negative training examples, using one of the following options: - + + # TODO : Phrasing? - `exclusive`: Positive examples belong only to the class being considered. All classes are negative examples, except for the selected class; - `less_exclusive`: Positive examples belong only to the class being considered. All classes are negative examples, except for the selected class and its descendants; - `exclusive_siblings`: Positive examples belong only to the class being considered. 
All sibling classes are negative examples; @@ -166,7 +167,7 @@ def predict(self, X): if subset_x.shape[0] > 0: probabilities = np.zeros((subset_x.shape[0], len(successors))) for i, successor in enumerate(successors): - successor_name = str(successor).split(self.separator_)[-1] + successor_name = str(successor) self.logger_.info(f"Predicting for node '{successor_name}'") classifier = self.hierarchy_.nodes[successor]["classifier"] positive_index = np.where(classifier.classes_ == 1)[0] @@ -185,7 +186,6 @@ def predict(self, X): y = self._convert_to_1d(y) - self._remove_separator(y) return y diff --git a/tests/test_BinaryPolicy.py b/tests/test_BinaryPolicy.py index 13905f06..7a625528 100644 --- a/tests/test_BinaryPolicy.py +++ b/tests/test_BinaryPolicy.py @@ -404,3 +404,64 @@ def test_siblings_get_binary_examples_sparse_3(digraph, features_sparse, labels) assert_array_equal(ground_truth_x, x.todense()) assert_array_equal(ground_truth_y, y) assert weights is None + +########################################################### + +@pytest.fixture +def digraph_dag(): + return nx.DiGraph( + [ + ("r", "a"), + ("r", "b"), + ("a", "b"), + ("b", "c"), + ("d", "e"), + ]) + + +@pytest.fixture +def features_dag_1d(): + return np.array( + [ + 1, + 2, + 3 + ] + ) + + +@pytest.fixture +def features_dag_2d(): + return np.array( + [ + [1, 2], + [3, 4], + [5, 6] + ] + ) + +@pytest.fixture +def labels_dag(): + return np.array( + [ + ["a", "b"], + ["b", "c"], + ["d", "e"], + ] + ) + +@pytest.mark.parametrize( + ["node", "expected"], [["a", [True, False, False]], ["b", [True, True, False]], ["c", [False, True, False]]] +) +def test_exclusive_policy_positive_examples_1_dag(digraph_dag, features_dag_1d, labels_dag, node, expected): + policy = ExclusivePolicy(digraph_dag, features_dag_1d, labels_dag) + result = policy.positive_examples(node) + assert_array_equal(expected, result) + +@pytest.mark.parametrize( + ["node", "expected"], [["a", [False, True, True]], ["b", [False, False, True]], ["c", [True, False, True]]] +) +def test_exclusive_policy_negative_examples_1_dag(digraph_dag, features_dag_1d, labels_dag, node, expected): + policy = ExclusivePolicy(digraph_dag, features_dag_1d, labels_dag) + result = policy.negative_examples(node) + assert_array_equal(expected, result) \ No newline at end of file diff --git a/tests/test_HierarchicalClassifier.py b/tests/test_HierarchicalClassifier.py index bff36e32..e9beb184 100644 --- a/tests/test_HierarchicalClassifier.py +++ b/tests/test_HierarchicalClassifier.py @@ -7,37 +7,9 @@ from numpy.testing import assert_array_equal from sklearn.linear_model import LogisticRegression -from hiclass.HierarchicalClassifier import HierarchicalClassifier, make_leveled +from hiclass.HierarchicalClassifier import HierarchicalClassifier, make_leveled, ARTIFICIAL_ROOT -@pytest.fixture -def ambiguous_node_str(): - classifier = HierarchicalClassifier() - classifier.y_ = np.array([["a", "b"], ["b", "c"]]) - return classifier - - -def test_disambiguate_str(ambiguous_node_str): - ground_truth = np.array( - [["a", "a::HiClass::Separator::b"], ["b", "b::HiClass::Separator::c"]] - ) - ambiguous_node_str._disambiguate() - assert_array_equal(ground_truth, ambiguous_node_str.y_) - - -@pytest.fixture -def ambiguous_node_int(): - classifier = HierarchicalClassifier() - classifier.y_ = np.array([[1, 2], [2, 3]]) - return classifier - - -def test_disambiguate_int(ambiguous_node_int): - ground_truth = np.array( - [["1", "1::HiClass::Separator::2"], ["2", "2::HiClass::Separator::3"]] - ) - 
ambiguous_node_int._disambiguate() - assert_array_equal(ground_truth, ambiguous_node_int.y_) @pytest.fixture @@ -81,14 +53,13 @@ def digraph_2d(): classifier.hierarchy_ = nx.DiGraph([("a", "b"), ("b", "c"), ("d", "e"), ("e", "f")]) classifier.logger_ = logging.getLogger("HC") classifier.edge_list = tempfile.TemporaryFile() - classifier.separator_ = "::HiClass::Separator::" return classifier def test_create_digraph_2d(digraph_2d): ground_truth = nx.DiGraph([("a", "b"), ("b", "c"), ("d", "e"), ("e", "f")]) digraph_2d._create_digraph() - assert nx.is_isomorphic(ground_truth, digraph_2d.hierarchy_) + # assert nx.is_isomorphic(ground_truth, digraph_2d.hierarchy_) assert list(ground_truth.nodes) == list(digraph_2d.hierarchy_.nodes) assert list(ground_truth.edges) == list(digraph_2d.hierarchy_.edges) @@ -136,15 +107,16 @@ def test_convert_1d_y_to_2d(graph_1d): def digraph_one_root(): classifier = HierarchicalClassifier() classifier.logger_ = logging.getLogger("HC") - classifier.hierarchy_ = nx.DiGraph([("a", "b"), ("b", "c"), ("c", "d")]) + classifier.hierarchy_ = nx.DiGraph([("a", "b")]) + classifier.y_ = np.array([["a", "b"]] ) return classifier def test_add_artificial_root(digraph_one_root): digraph_one_root._add_artificial_root() - successors = list(digraph_one_root.hierarchy_.successors("hiclass::root")) + successors = list(digraph_one_root.hierarchy_.successors(ARTIFICIAL_ROOT)) assert ["a"] == successors - assert "hiclass::root" == digraph_one_root.root_ + assert ARTIFICIAL_ROOT == digraph_one_root.root_ @pytest.fixture @@ -160,9 +132,9 @@ def digraph_multiple_roots(): def test_add_artificial_root_multiple_roots(digraph_multiple_roots): digraph_multiple_roots._add_artificial_root() - successors = list(digraph_multiple_roots.hierarchy_.successors("hiclass::root")) - assert ["a", "c", "e"] == successors - assert "hiclass::root" == digraph_multiple_roots.root_ + successors = list(digraph_multiple_roots.hierarchy_.successors(ARTIFICIAL_ROOT)) + assert ["a", "c", "e"] == sorted(successors) + assert ARTIFICIAL_ROOT == digraph_multiple_roots.root_ def test_initialize_local_classifiers_2(digraph_multiple_roots): @@ -178,6 +150,23 @@ def test_clean_up(digraph_multiple_roots): with pytest.raises(AttributeError): assert digraph_multiple_roots.y_ is None +@pytest.fixture +def digraph_multiple_roots_dag(): + classifier = HierarchicalClassifier() + classifier.logger_ = logging.getLogger("HC") + classifier.hierarchy_ = nx.DiGraph([("a", "b"), ("b", "c")]) + classifier.X_ = np.array([[1, 2], [3, 4], [5, 6]]) + classifier.y_ = np.array([["a", "b"], ["b", "c"]]) + classifier.sample_weight_ = None + return classifier + + +def test_add_artificial_root_multiple_roots_dag(digraph_multiple_roots_dag): + digraph_multiple_roots_dag._add_artificial_root() + successors = list(digraph_multiple_roots_dag.hierarchy_.successors(ARTIFICIAL_ROOT)) + assert ["a", "b"] == sorted(successors) + assert ARTIFICIAL_ROOT == digraph_multiple_roots_dag.root_ + @pytest.fixture def empty_levels(): diff --git a/tests/test_LocalClassifierPerNode.py b/tests/test_LocalClassifierPerNode.py index 54c7eaa9..4ea5b4ee 100644 --- a/tests/test_LocalClassifierPerNode.py +++ b/tests/test_LocalClassifierPerNode.py @@ -11,7 +11,7 @@ from sklearn.utils.validation import check_is_fitted from hiclass import LocalClassifierPerNode -from hiclass.BinaryPolicy import ExclusivePolicy +from hiclass.BinaryPolicy import ExclusivePolicy, IMPLEMENTED_POLICIES @parametrize_with_checks([LocalClassifierPerNode()]) @@ -71,7 +71,6 @@ def 
digraph_logistic_regression(): digraph.X_ = np.array([[1, 2], [3, 4]]) digraph.logger_ = logging.getLogger("LCPN") digraph.root_ = "a" - digraph.separator_ = "::HiClass::Separator::" digraph.binary_policy_ = ExclusivePolicy(digraph.hierarchy_, digraph.X_, digraph.y_) digraph.sample_weight_ = None return digraph @@ -165,7 +164,6 @@ def fitted_logistic_regression(): digraph.max_levels_ = 2 digraph.dtype_ = " Date: Thu, 17 Nov 2022 15:32:53 +0100 Subject: [PATCH 04/21] Fix function _get_successors --- hiclass/LocalClassifierPerParentNode.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index fc9c4583..bd6b5c86 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -145,8 +145,6 @@ def predict(self, X): y = self._convert_to_1d(y) - self._remove_separator(y) - return y def _predict_remaining_levels(self, X, y): @@ -181,14 +179,20 @@ def _get_parents(self): def _get_successors(self, node): successors = list(self.hierarchy_.successors(node)) - mask = np.isin(self.y_, successors).any(axis=1) + if node == self.root_: + mask = np.isin(self.y_, successors).any(axis=1) + y = self.y_[mask][:, 0] + else: + y = [] + mask = np.full((len(self.X_)), False) + for successor in successors: + rows, cols = np.where(self.y_ == successor) + for row, col in zip(rows, cols): + if col > 0: + if self.y_[row, col - 1] == node: + y.append(self.y_[row, col]) + mask[row] = True X = self.X_[mask] - y = [] - for row in self.y_[mask]: - if node == self.root_: - y.append(row[0]) - else: - y.append(row[np.where(row == node)[0][0] + 1]) y = np.array(y) sample_weight = ( self.sample_weight_[mask] if self.sample_weight_ is not None else None From d51f3bc48a28679b779f87d6ec34ba9b726219c6 Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Thu, 17 Nov 2022 15:39:41 +0100 Subject: [PATCH 05/21] Change prediction graph traversal method WIP --- hiclass/LocalClassifierPerNode.py | 91 ++++++++++++++++++---------- tests/test_LocalClassifierPerNode.py | 4 +- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 2205c3aa..0bd0c31c 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -14,6 +14,8 @@ from hiclass.ConstantClassifier import ConstantClassifier from hiclass.HierarchicalClassifier import HierarchicalClassifier +from collections import defaultdict + class LocalClassifierPerNode(BaseEstimator, HierarchicalClassifier): """ @@ -153,42 +155,65 @@ def predict(self, X): # TODO: Add threshold to stop prediction halfway if need be - bfs = nx.bfs_successors(self.hierarchy_, source=self.root_) self.logger_.info("Predicting") - - for predecessor, successors in bfs: - if predecessor == self.root_: - mask = [True] * X.shape[0] - subset_x = X[mask] + + def graph_iterator(graph: nx.DiGraph, root): + # iterate over graph by visiting each successor of node (even if this means mutliple times) + # and keeping track the distance from the root + tuple_list = lambda it,i: [(node, i) for node in it] + queue = [] + if root not in graph.nodes: + raise ValueError(f"{root} not in graph") else: - mask = np.isin(y, predecessor).any(axis=1) - subset_x = X[mask] - if subset_x.shape[0] > 0: - probabilities = np.zeros((subset_x.shape[0], len(successors))) - for i, successor in enumerate(successors): - successor_name = str(successor) - self.logger_.info(f"Predicting for node 
'{successor_name}'") - classifier = self.hierarchy_.nodes[successor]["classifier"] - positive_index = np.where(classifier.classes_ == 1)[0] - probabilities[:, i] = classifier.predict_proba(subset_x)[ - :, positive_index - ][:, 0] - highest_probability = np.argmax(probabilities, axis=1) - prediction = [] - for i in highest_probability: - prediction.append(successors[i]) - level = nx.shortest_path_length( - self.hierarchy_, self.root_, predecessor - ) - prediction = np.array(prediction) - y[mask, level] = prediction - + queue = tuple_list(graph.successors(root), 0) + while queue: + node, i = queue.pop(0) + yield (node, i) + queue.extend(tuple_list(graph.successors(node), i+1)) + + + prediction_probs = {} # since we can visit a graph multiple times in our iterator we store the prediction results to only calculate them once + depths = defaultdict(lambda: []) + for node, depth in graph_iterator(self.hierarchy_, self.root_): + depths[depth].append(node) + + if node not in prediction_probs.keys(): + + self.logger_.info(f"Predicting for node '{node}'") + + classifier = self.hierarchy_.nodes[node]["classifier"] + positive_index = np.where(classifier.classes_ == 1)[0] + + prediction_probs[node] = classifier.predict_proba(X)[ + :, positive_index + ][:, 0] # we need to double index because positive_index is an array not a single value... should it be though? + + for level in range(self.max_levels_): + nodes = depths[level] + probabilities = np.array([prediction_probs[node] for node in nodes]).T + nodes_index = {n: nodes.index(n) for n in nodes} + + if level >= 1: + predictions_level_before = y[:, level-1] + + # get a list of children per row in the predictions of the previous level + nodes_to_consider_list = [list(self.hierarchy_.successors(predecessor)) for predecessor in predictions_level_before] + + # TODO: edge case nodes_to_consider is empty? + for probs, nodes_to_consider in zip(probabilities, nodes_to_consider_list): + indices = np.array([nodes_index[n] for n in nodes_to_consider]) + probs[~indices] = 0 # unreachable nodes get 0 probability + + highest_probability_index = np.argmax(probabilities, axis=1) + + predictions = np.array([nodes[i] for i in highest_probability_index]) + y[:, level] = predictions + y = self._convert_to_1d(y) - - return y + def _initialize_binary_policy(self): if isinstance(self.binary_policy, str): self.logger_.info(f"Initializing {self.binary_policy} binary policy") @@ -236,6 +261,6 @@ def _fit_classifier(self, node): classifier.fit(X, y, sample_weight) return classifier - def _clean_up(self): - super()._clean_up() - del self.binary_policy_ + # def _clean_up(self): + # super()._clean_up() + # del self.binary_policy_ diff --git a/tests/test_LocalClassifierPerNode.py b/tests/test_LocalClassifierPerNode.py index 4ea5b4ee..b4db9e83 100644 --- a/tests/test_LocalClassifierPerNode.py +++ b/tests/test_LocalClassifierPerNode.py @@ -197,7 +197,7 @@ def test_predict_sparse(fitted_logistic_regression): @pytest.mark.parametrize("binary_policy", IMPLEMENTED_POLICIES.keys()) -def test_fit_predict(binary_policy): +def test_fit_predict(binary_policy="exclusive"): lcpn = LocalClassifierPerNode( local_classifier=LogisticRegression(), binary_policy=binary_policy ) @@ -206,8 +206,8 @@ def test_fit_predict(binary_policy): y = np.array([["a", ""], ["a", "b"], ["b", ""], ["b", "c"], ]) lcpn.fit(x, y) - # TODO: why can I not access lcpn.binary_policy_? # TODO: what is the correct prediction? + # TODO: Continue here! 
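    # Note: lcpn.binary_policy_ is normally deleted by _clean_up() at the end
    # of fit(); it is only reachable here while the _clean_up() override in
    # LocalClassifierPerNode is commented out.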
predictions = lcpn.predict(x) assert_array_equal(y, predictions) From 533e1bf3a832637b7dd8f70dbb0ab4da4935d64f Mon Sep 17 00:00:00 2001 From: Fabio Date: Thu, 17 Nov 2022 16:02:29 +0100 Subject: [PATCH 06/21] Fix test_empty_levels --- hiclass/LocalClassifierPerParentNode.py | 9 ++++----- tests/test_LocalClassifierPerParentNode.py | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index bd6b5c86..8b96fc37 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -188,12 +188,11 @@ def _get_successors(self, node): for successor in successors: rows, cols = np.where(self.y_ == successor) for row, col in zip(rows, cols): - if col > 0: - if self.y_[row, col - 1] == node: - y.append(self.y_[row, col]) - mask[row] = True + if col > 0 and self.y_[row, col - 1] == node: + y.append(self.y_[row, col]) + mask[row] = True + y = np.array(y) X = self.X_[mask] - y = np.array(y) sample_weight = ( self.sample_weight_[mask] if self.sample_weight_ is not None else None ) diff --git a/tests/test_LocalClassifierPerParentNode.py b/tests/test_LocalClassifierPerParentNode.py index 47bbbb93..32120a60 100644 --- a/tests/test_LocalClassifierPerParentNode.py +++ b/tests/test_LocalClassifierPerParentNode.py @@ -219,10 +219,10 @@ def test_empty_levels(empty_levels): assert list(lcppn.hierarchy_.nodes) == [ "1", "2", - "2" + lcppn.separator_ + "2.1", + "2.1", "3", - "3" + lcppn.separator_ + "3.1", - "3" + lcppn.separator_ + "3.1" + lcppn.separator_ + "3.1.2", + "3.1", + "3.1.2", lcppn.root_, ] assert_array_equal(ground_truth, predictions) From 9b16631e6215e81f52912c48ee4d763559f5d6bd Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Thu, 17 Nov 2022 16:03:22 +0100 Subject: [PATCH 07/21] Add comments --- hiclass/LocalClassifierPerNode.py | 64 +++++++++++++++++-------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 0bd0c31c..afbf1328 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -14,7 +14,7 @@ from hiclass.ConstantClassifier import ConstantClassifier from hiclass.HierarchicalClassifier import HierarchicalClassifier -from collections import defaultdict +from collections import defaultdict class LocalClassifierPerNode(BaseEstimator, HierarchicalClassifier): @@ -55,7 +55,7 @@ def __init__( clone methods. binary_policy : {"exclusive", "less_exclusive", "exclusive_siblings", "inclusive", "less_inclusive", "siblings"}, str, default="siblings" Specify the rule for defining positive and negative training examples, using one of the following options: - + # TODO : Phrasing? - `exclusive`: Positive examples belong only to the class being considered. All classes are negative examples, except for the selected class; - `less_exclusive`: Positive examples belong only to the class being considered. 
All classes are negative examples, except for the selected class and its descendants; @@ -155,13 +155,12 @@ def predict(self, X): # TODO: Add threshold to stop prediction halfway if need be - self.logger_.info("Predicting") - + def graph_iterator(graph: nx.DiGraph, root): # iterate over graph by visiting each successor of node (even if this means mutliple times) - # and keeping track the distance from the root - tuple_list = lambda it,i: [(node, i) for node in it] + # and keeping track of the distance from the root + tuple_list = lambda it, i: [(node, i) for node in it] queue = [] if root not in graph.nodes: raise ValueError(f"{root} not in graph") @@ -170,50 +169,59 @@ def graph_iterator(graph: nx.DiGraph, root): while queue: node, i = queue.pop(0) yield (node, i) - queue.extend(tuple_list(graph.successors(node), i+1)) + queue.extend(tuple_list(graph.successors(node), i + 1)) + prediction_probs = {} + # since we can visit a graph multiple times in our iterator we store the prediction results to only + # calculate them once - prediction_probs = {} # since we can visit a graph multiple times in our iterator we store the prediction results to only calculate them once - depths = defaultdict(lambda: []) - for node, depth in graph_iterator(self.hierarchy_, self.root_): - depths[depth].append(node) - if node not in prediction_probs.keys(): + levels = defaultdict(lambda: []) + # keep track of nodes per level + + for node, level in graph_iterator(self.hierarchy_, self.root_): + levels[level].append(node) + if node not in prediction_probs.keys(): self.logger_.info(f"Predicting for node '{node}'") - classifier = self.hierarchy_.nodes[node]["classifier"] positive_index = np.where(classifier.classes_ == 1)[0] - - prediction_probs[node] = classifier.predict_proba(X)[ - :, positive_index - ][:, 0] # we need to double index because positive_index is an array not a single value... should it be though? + prediction_probs[node] = classifier.predict_proba(X)[:, positive_index][ + :, 0 + ] # we need to double index because positive_index is an array not a single value... should it be though? + # we have to make one more pass over our levels to know what to actually predict for level in range(self.max_levels_): - nodes = depths[level] + nodes = levels[level] probabilities = np.array([prediction_probs[node] for node in nodes]).T nodes_index = {n: nodes.index(n) for n in nodes} + # creating empty array before if level >= 1: - predictions_level_before = y[:, level-1] + predictions_level_before = y[:, level - 1] # get a list of children per row in the predictions of the previous level - nodes_to_consider_list = [list(self.hierarchy_.successors(predecessor)) for predecessor in predictions_level_before] + nodes_to_consider_list = [ + list(self.hierarchy_.successors(predecessor)) + for predecessor in predictions_level_before + ] # TODO: edge case nodes_to_consider is empty? - for probs, nodes_to_consider in zip(probabilities, nodes_to_consider_list): + # TODO: this iterates manually over all rows! likely very expensive! Refactor! 
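+            # Sketch of a possible refactor (untested): build one boolean
+            # reachability mask per row with np.isin over the successor lists
+            # and multiply it into `probabilities`, removing this Python loop.
+            # Also note that `~indices` below bit-inverts integer positions
+            # rather than taking a set complement, so a boolean mask would be
+            # more robust as well.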
+ for probs, nodes_to_consider in zip( + probabilities, nodes_to_consider_list + ): indices = np.array([nodes_index[n] for n in nodes_to_consider]) - probs[~indices] = 0 # unreachable nodes get 0 probability - + probs[~indices] = 0 # unreachable nodes get 0 probability + highest_probability_index = np.argmax(probabilities, axis=1) - + predictions = np.array([nodes[i] for i in highest_probability_index]) y[:, level] = predictions - + y = self._convert_to_1d(y) return y - def _initialize_binary_policy(self): if isinstance(self.binary_policy, str): self.logger_.info(f"Initializing {self.binary_policy} binary policy") @@ -262,5 +270,5 @@ def _fit_classifier(self, node): return classifier # def _clean_up(self): - # super()._clean_up() - # del self.binary_policy_ + # super()._clean_up() + # del self.binary_policy_ From 7824be29cb6e2fa5087e900f13ed60be452c44c4 Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Thu, 17 Nov 2022 17:39:32 +0100 Subject: [PATCH 08/21] Update graph traversal in prediction method --- hiclass/LocalClassifierPerNode.py | 92 ++++++++-------------------- tests/test_LocalClassifierPerNode.py | 9 +-- 2 files changed, 32 insertions(+), 69 deletions(-) diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index afbf1328..08c49269 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -157,69 +157,31 @@ def predict(self, X): self.logger_.info("Predicting") - def graph_iterator(graph: nx.DiGraph, root): - # iterate over graph by visiting each successor of node (even if this means mutliple times) - # and keeping track of the distance from the root - tuple_list = lambda it, i: [(node, i) for node in it] - queue = [] - if root not in graph.nodes: - raise ValueError(f"{root} not in graph") - else: - queue = tuple_list(graph.successors(root), 0) - while queue: - node, i = queue.pop(0) - yield (node, i) - queue.extend(tuple_list(graph.successors(node), i + 1)) - - prediction_probs = {} - # since we can visit a graph multiple times in our iterator we store the prediction results to only - # calculate them once - - - levels = defaultdict(lambda: []) - # keep track of nodes per level - - for node, level in graph_iterator(self.hierarchy_, self.root_): - levels[level].append(node) - - if node not in prediction_probs.keys(): - self.logger_.info(f"Predicting for node '{node}'") - classifier = self.hierarchy_.nodes[node]["classifier"] - positive_index = np.where(classifier.classes_ == 1)[0] - prediction_probs[node] = classifier.predict_proba(X)[:, positive_index][ - :, 0 - ] # we need to double index because positive_index is an array not a single value... should it be though? - - # we have to make one more pass over our levels to know what to actually predict - for level in range(self.max_levels_): - nodes = levels[level] - probabilities = np.array([prediction_probs[node] for node in nodes]).T - nodes_index = {n: nodes.index(n) for n in nodes} - - # creating empty array before - if level >= 1: - predictions_level_before = y[:, level - 1] - - # get a list of children per row in the predictions of the previous level - nodes_to_consider_list = [ - list(self.hierarchy_.successors(predecessor)) - for predecessor in predictions_level_before - ] - - # TODO: edge case nodes_to_consider is empty? - # TODO: this iterates manually over all rows! likely very expensive! Refactor! 
- for probs, nodes_to_consider in zip( - probabilities, nodes_to_consider_list - ): - indices = np.array([nodes_index[n] for n in nodes_to_consider]) - probs[~indices] = 0 # unreachable nodes get 0 probability - - highest_probability_index = np.argmax(probabilities, axis=1) - - predictions = np.array([nodes[i] for i in highest_probability_index]) - y[:, level] = predictions - + for level in range(y.shape[1]): + + predecessors = set(y[:, level - 1]) if level >= 1 else set([self.root_]) # in case of level 0 the predecessor is the root node + predecessors.discard("") + + for predecessor in predecessors: + mask = np.isin(y[:, level - 1], predecessor) if level >= 1 else np.ones(y.shape[0], dtype=bool) + predecessor_x = X[mask] + + if predecessor_x.shape[0] > 0: + successors = list(self.hierarchy_.successors(predecessor)) + if len(successors) > 0: + probabilities = np.zeros((predecessor_x.shape[0], len(successors))) + for i,successor in enumerate(successors): + classifier = self.hierarchy_.nodes[successor]["classifier"] + positive_index = np.where(classifier.classes_ == 1)[0] + probabilities[:,i] = classifier.predict_proba(predecessor_x)[:, positive_index][:,0] + + highest_probability_index = np.argmax(probabilities, axis=1) + + predictions = np.array([successors[i] for i in highest_probability_index]) + y[mask, level] = predictions + y = self._convert_to_1d(y) + return y def _initialize_binary_policy(self): @@ -269,6 +231,6 @@ def _fit_classifier(self, node): classifier.fit(X, y, sample_weight) return classifier - # def _clean_up(self): - # super()._clean_up() - # del self.binary_policy_ + def _clean_up(self): + super()._clean_up() + del self.binary_policy_ diff --git a/tests/test_LocalClassifierPerNode.py b/tests/test_LocalClassifierPerNode.py index b4db9e83..56b43a08 100644 --- a/tests/test_LocalClassifierPerNode.py +++ b/tests/test_LocalClassifierPerNode.py @@ -197,20 +197,21 @@ def test_predict_sparse(fitted_logistic_regression): @pytest.mark.parametrize("binary_policy", IMPLEMENTED_POLICIES.keys()) -def test_fit_predict(binary_policy="exclusive"): +def test_fit_predict(binary_policy): lcpn = LocalClassifierPerNode( local_classifier=LogisticRegression(), binary_policy=binary_policy ) - x = np.array([[0], [1], [2], [3]]) - y = np.array([["a", ""], ["a", "b"], ["b", ""], ["b", "c"], ]) + x = np.array([[-10], [0], [10], [100]]) + y = np.array([["a", ""], ["a", "b"], ["b", ""], ["b", "c"], ]) lcpn.fit(x, y) # TODO: what is the correct prediction? # TODO: Continue here! + expected = np.array([["a", "b"], ["a", "b"], ["b", "c"], ["b", "c"], ]) # TODO: is this the correct result? 
predictions = lcpn.predict(x) - assert_array_equal(y, predictions) + assert_array_equal(expected, predictions) @pytest.fixture From a3014c77c450f0f198798f4b8236dacc3fcc330f Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Thu, 17 Nov 2022 17:51:07 +0100 Subject: [PATCH 09/21] Add comments --- hiclass/LocalClassifierPerNode.py | 9 ++++++--- tests/test_LocalClassifierPerNode.py | 3 --- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 08c49269..591b4806 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -163,25 +163,28 @@ def predict(self, X): predecessors.discard("") for predecessor in predecessors: + mask = np.isin(y[:, level - 1], predecessor) if level >= 1 else np.ones(y.shape[0], dtype=bool) predecessor_x = X[mask] if predecessor_x.shape[0] > 0: successors = list(self.hierarchy_.successors(predecessor)) if len(successors) > 0: + + # we built an array of probabilities for all successor nodes/classifiers probabilities = np.zeros((predecessor_x.shape[0], len(successors))) for i,successor in enumerate(successors): classifier = self.hierarchy_.nodes[successor]["classifier"] positive_index = np.where(classifier.classes_ == 1)[0] probabilities[:,i] = classifier.predict_proba(predecessor_x)[:, positive_index][:,0] - - highest_probability_index = np.argmax(probabilities, axis=1) + # prediction is the classifier that outputs the highest probability + highest_probability_index = np.argmax(probabilities, axis=1) predictions = np.array([successors[i] for i in highest_probability_index]) y[mask, level] = predictions y = self._convert_to_1d(y) - + return y def _initialize_binary_policy(self): diff --git a/tests/test_LocalClassifierPerNode.py b/tests/test_LocalClassifierPerNode.py index 56b43a08..b4b1562d 100644 --- a/tests/test_LocalClassifierPerNode.py +++ b/tests/test_LocalClassifierPerNode.py @@ -206,9 +206,6 @@ def test_fit_predict(binary_policy): y = np.array([["a", ""], ["a", "b"], ["b", ""], ["b", "c"], ]) lcpn.fit(x, y) - # TODO: what is the correct prediction? - # TODO: Continue here! - expected = np.array([["a", "b"], ["a", "b"], ["b", "c"], ["b", "c"], ]) # TODO: is this the correct result? 
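    # Why this expectation is plausible (assuming LogisticRegression separates
    # the 1-D feature cleanly): predict() keeps descending while a node has
    # successors, so the rows trained with an empty second level (["a", ""]
    # and ["b", ""]) still receive a second-level label, giving ["a", "b"] and
    # ["b", "c"] rather than blanks.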
predictions = lcpn.predict(x) assert_array_equal(expected, predictions) From 7e88be7734245c62917e15da7e298c4a17918724 Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 11:32:44 +0100 Subject: [PATCH 10/21] Add function to download fungi dataset --- tests/conftest.py | 38 ++++++++++++++++++++++++++++++++++++++ tests/test_RealData.py | 0 2 files changed, 38 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_RealData.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..a129b2f3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,38 @@ +import hashlib +import os +import urllib.request + + +def md5(file_path): + with open(file_path, "r") as file: + return hashlib.md5(file.read().encode("utf-8")).hexdigest() + + +def download(dataset): + if not os.path.exists(dataset["path"]) or md5(dataset["path"]) != dataset["md5"]: + print(f"Downloading file {dataset['path']}") + urllib.request.urlretrieve(dataset["url"], dataset["path"]) + assert md5(dataset["path"]) == dataset["md5"] + + +def download_fungi_dataset(): + # Download the fungi dataset if not already present + # only if the environment variables are set + if "FUNGI_TRAIN_URL" in os.environ and "FUNGI_TRAIN_MD5" in os.environ: + train = { + "url": os.environ["FUNGI_TRAIN_URL"], + "path": "tests/fixtures/fungi_train.fasta", + "md5": os.environ["FUNGI_TRAIN_MD5"], + } + download(train) + if "FUNGI_TEST_URL" in os.environ and "FUNGI_TEST_MD5" in os.environ: + test = { + "url": os.environ["FUNGI_TEST_URL"], + "path": "tests/fixtures/fungi_test.fasta", + "md5": os.environ["FUNGI_TEST_MD5"], + } + download(test) + + +def pytest_sessionstart(session): + download_fungi_dataset() diff --git a/tests/test_RealData.py b/tests/test_RealData.py new file mode 100644 index 00000000..e69de29b From b8c59de7255d66626804a95cb7e41e2f86514ef4 Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 12:31:37 +0100 Subject: [PATCH 11/21] Expand download function for google drive URLs --- tests/conftest.py | 117 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 93 insertions(+), 24 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a129b2f3..c08783c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,38 +1,107 @@ +"""Shared code for all tests.""" import hashlib import os -import urllib.request +from typing import Union +import gdown -def md5(file_path): + +def md5(file_path: str) -> str: + """ + Compute the MD5 hash of a file. + + Parameters + ---------- + file_path : str + Path to the file. + + Returns + ------- + md5sum : str + MD5 hash of the file. + """ with open(file_path, "r") as file: - return hashlib.md5(file.read().encode("utf-8")).hexdigest() + md5sum = hashlib.md5(file.read().encode("utf-8")).hexdigest() + return md5sum -def download(dataset): - if not os.path.exists(dataset["path"]) or md5(dataset["path"]) != dataset["md5"]: - print(f"Downloading file {dataset['path']}") - urllib.request.urlretrieve(dataset["url"], dataset["path"]) - assert md5(dataset["path"]) == dataset["md5"] +def download(dataset: dict, fuzzy: bool) -> None: + """ + Download a dataset. + Parameters + ---------- + dataset : dict + Dictionary containing the URL, path and MD5 hash of the dataset. + fuzzy : bool + Whether to use fuzzy matching to find the file name. 
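
    Examples
    --------
    Illustrative only; the real URL and MD5 hash come from get_dataset().

    >>> dataset = {
    ...     "url": "https://...",
    ...     "path": "tests/fixtures/fungi_train.csv",
    ...     "md5": "...",
    ... }
    >>> download(dataset, fuzzy=False)  # skipped if the file already matches the MD5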
+ """ + if dataset: + gdown.cached_download( + dataset["url"], + dataset["path"], + quiet=False, + fuzzy=fuzzy, + md5=dataset["md5"], + ) -def download_fungi_dataset(): - # Download the fungi dataset if not already present - # only if the environment variables are set - if "FUNGI_TRAIN_URL" in os.environ and "FUNGI_TRAIN_MD5" in os.environ: - train = { - "url": os.environ["FUNGI_TRAIN_URL"], - "path": "tests/fixtures/fungi_train.fasta", - "md5": os.environ["FUNGI_TRAIN_MD5"], - } - download(train) - if "FUNGI_TEST_URL" in os.environ and "FUNGI_TEST_MD5" in os.environ: - test = { - "url": os.environ["FUNGI_TEST_URL"], - "path": "tests/fixtures/fungi_test.fasta", - "md5": os.environ["FUNGI_TEST_MD5"], + +def get_dataset(prefix: str) -> Union[dict, None]: + """ + Get the dataset information. + + Parameters + ---------- + prefix : str + Prefix of the environment variables. + + Returns + ------- + dataset : dict + Dictionary containing the URL, path and MD5 hash of the dataset. + """ + try: + uppercase = prefix.upper() + lowercase = prefix.lower() + dataset = { + "url": os.environ["{}_URL".format(uppercase)], + "path": "tests/fixtures/{}.csv".format(lowercase), + "md5": os.environ["{}_MD5".format(uppercase)], } - download(test) + except KeyError: + return None + else: + return dataset + + +def download_fungi_dataset() -> None: + """Download the fungi dataset if not already present only if the environment variables are set.""" + train = get_dataset("FUNGI_TRAIN") + download(train, fuzzy=False) + test = get_dataset("FUNGI_TEST") + download(test, fuzzy=False) + + +def download_complaints_dataset() -> None: + """Download the complaints dataset if not already present only if the environment variables are set.""" + x_train = get_dataset("COMPLAINTS_X_TRAIN") + download(x_train, fuzzy=True) + y_train = get_dataset("COMPLAINTS_Y_TRAIN") + download(y_train, fuzzy=True) + x_test = get_dataset("COMPLAINTS_X_TEST") + download(x_test, fuzzy=True) + y_test = get_dataset("COMPLAINTS_Y_TEST") + download(y_test, fuzzy=True) def pytest_sessionstart(session): + """ + Download the datasets before the tests start. + + Parameters + ---------- + session : pytest.Session + The pytest session object. + """ download_fungi_dataset() + download_complaints_dataset() From 3c588ecec22920b168d2d0f1384374a13041cc62 Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 12:36:19 +0100 Subject: [PATCH 12/21] Update docstrings --- tests/conftest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c08783c9..a0ae0957 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,7 +27,7 @@ def md5(file_path: str) -> str: def download(dataset: dict, fuzzy: bool) -> None: """ - Download a dataset. + Download a dataset if the file does not exist yet. 
Parameters ---------- @@ -75,7 +75,7 @@ def get_dataset(prefix: str) -> Union[dict, None]: def download_fungi_dataset() -> None: - """Download the fungi dataset if not already present only if the environment variables are set.""" + """Download the fungi dataset only if the environment variables are set.""" train = get_dataset("FUNGI_TRAIN") download(train, fuzzy=False) test = get_dataset("FUNGI_TEST") @@ -83,7 +83,7 @@ def download_fungi_dataset() -> None: def download_complaints_dataset() -> None: - """Download the complaints dataset if not already present only if the environment variables are set.""" + """Download the complaints dataset only if the environment variables are set.""" x_train = get_dataset("COMPLAINTS_X_TRAIN") download(x_train, fuzzy=True) y_train = get_dataset("COMPLAINTS_Y_TRAIN") From fc9bada91ac5bb0b2cdc28653ed242e3f4ffa228 Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 12:43:40 +0100 Subject: [PATCH 13/21] Add gdown dependency --- .github/workflows/deploy-pypi.yml | 1 + .github/workflows/test-pr.yml | 1 + CONTRIBUTING.md | 1 + 3 files changed, 3 insertions(+) diff --git a/.github/workflows/deploy-pypi.yml b/.github/workflows/deploy-pypi.yml index 9d2605f9..a6469448 100644 --- a/.github/workflows/deploy-pypi.yml +++ b/.github/workflows/deploy-pypi.yml @@ -30,6 +30,7 @@ jobs: python -m pip install pytest-cov==3.0.0 python -m pip install ray==1.13.0 python -m pip install 'importlib-metadata<4.3' + python -m pip install gdown==4.5.3 python -m pip install . - name: Test with pytest run: | diff --git a/.github/workflows/test-pr.yml b/.github/workflows/test-pr.yml index 94c3253e..fba0ae7b 100644 --- a/.github/workflows/test-pr.yml +++ b/.github/workflows/test-pr.yml @@ -30,6 +30,7 @@ jobs: python -m pip install pytest-pydocstyle==2.3.0 python -m pip install pytest-cov==3.0.0 python -m pip install ray==1.13.0 + python -m pip install gdown==4.5.3 python -m pip install . - name: Test with pytest run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d46ce08f..1acf2645 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,6 +20,7 @@ pip install pytest-pydocstyle==2.3.0 pip install pytest-cov==3.0.0 pip install black==22.10.0 pip install pre-commit==2.20.0 +pip install gdown==4.5.3 pip install -e . ``` From 9b74d18989125c086a96a4182d2a00acbf46d0ed Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 12:50:04 +0100 Subject: [PATCH 14/21] Update docstring --- tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a0ae0957..feb1e497 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,8 +57,9 @@ def get_dataset(prefix: str) -> Union[dict, None]: Returns ------- - dataset : dict - Dictionary containing the URL, path and MD5 hash of the dataset. + dataset : Union[dict, None] + Dictionary containing the URL, path and MD5 hash of the dataset + or None if environment variables are not set. 
""" try: uppercase = prefix.upper() From b68b97244bbdba6a457ea5b7a40fa16ed200e8ed Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 13:40:10 +0100 Subject: [PATCH 15/21] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index f34b738e..ab60ca30 100644 --- a/.gitignore +++ b/.gitignore @@ -227,6 +227,8 @@ coverage.xml .hypothesis/ .pytest_cache/ pytestdebug.log +tests/fixtures/complaints* +tests/fixtures/fungi* # Translations *.mo From fe8cf1e4e5461536b4cf60838675c1c703d88d71 Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 15:01:41 +0100 Subject: [PATCH 16/21] Add skipif --- .github/workflows/deploy-pypi.yml | 1 - .github/workflows/test-pr.yml | 1 - CONTRIBUTING.md | 1 - Pipfile.lock | 244 +++++++++++++-------------- hiclass/HierarchicalClassifier.py | 18 +- hiclass/LocalClassifierPerLevel.py | 1 - hiclass/LocalClassifierPerNode.py | 34 ++-- tests/conftest.py | 9 +- tests/test_BinaryPolicy.py | 48 +++--- tests/test_HierarchicalClassifier.py | 11 +- tests/test_LocalClassifierPerNode.py | 22 ++- tests/test_RealData.py | 50 ++++++ 12 files changed, 263 insertions(+), 177 deletions(-) diff --git a/.github/workflows/deploy-pypi.yml b/.github/workflows/deploy-pypi.yml index a6469448..9d2605f9 100644 --- a/.github/workflows/deploy-pypi.yml +++ b/.github/workflows/deploy-pypi.yml @@ -30,7 +30,6 @@ jobs: python -m pip install pytest-cov==3.0.0 python -m pip install ray==1.13.0 python -m pip install 'importlib-metadata<4.3' - python -m pip install gdown==4.5.3 python -m pip install . - name: Test with pytest run: | diff --git a/.github/workflows/test-pr.yml b/.github/workflows/test-pr.yml index fba0ae7b..94c3253e 100644 --- a/.github/workflows/test-pr.yml +++ b/.github/workflows/test-pr.yml @@ -30,7 +30,6 @@ jobs: python -m pip install pytest-pydocstyle==2.3.0 python -m pip install pytest-cov==3.0.0 python -m pip install ray==1.13.0 - python -m pip install gdown==4.5.3 python -m pip install . - name: Test with pytest run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1acf2645..d46ce08f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,7 +20,6 @@ pip install pytest-pydocstyle==2.3.0 pip install pytest-cov==3.0.0 pip install black==22.10.0 pip install pre-commit==2.20.0 -pip install gdown==4.5.3 pip install -e . 
``` diff --git a/Pipfile.lock b/Pipfile.lock index e41c2155..9b0b6afe 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -9,26 +9,26 @@ { "name": "pypi", "url": "https://pypi.python.org/simple", - "verify_ssl": true + "verify_ssl": true, } - ] + ], }, "default": { "joblib": { "hashes": [ "sha256:4158fcecd13733f8be669be0683b96ebdbbd38d23559f54dca7205aea1bf1e35", - "sha256:f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6" + "sha256:f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6", ], "markers": "python_version >= '3.6'", - "version": "==1.1.0" + "version": "==1.1.0", }, "networkx": { "hashes": [ "sha256:15a7b81a360791c458c55a417418ea136c13378cfdc06a2dcdc12bd2f9cf09c1", - "sha256:a762f4b385692d9c3a6f2912d058d76d29a827deaedf9e63ed14d397b8030687" + "sha256:a762f4b385692d9c3a6f2912d058d76d29a827deaedf9e63ed14d397b8030687", ], "index": "pypi", - "version": "==2.8.5" + "version": "==2.8.5", }, "numpy": { "hashes": [ @@ -53,10 +53,10 @@ "sha256:b15c3f1ed08df4980e02cc79ee058b788a3d0bef2fb3c9ca90bb8cbd5b8a3a04", "sha256:c2f91f88230042a130ceb1b496932aa717dcbd665350beb821534c5c7e15881c", "sha256:d748ef349bfef2e1194b59da37ed5a29c19ea8d7e6342019921ba2ba4fd8b624", - "sha256:e0d7447679ae9a7124385ccf0ea990bb85bb869cef217e2ea6c844b6a6855073" + "sha256:e0d7447679ae9a7124385ccf0ea990bb85bb869cef217e2ea6c844b6a6855073", ], "index": "pypi", - "version": "==1.23.1" + "version": "==1.23.1", }, "scikit-learn": { "hashes": [ @@ -77,10 +77,10 @@ "sha256:c2dad2bfc502344b869d4a3f4aa7271b2a5f4fe41f7328f404844c51612e2c58", "sha256:e851f8874398dcd50d1e174e810e9331563d189356e945b3271c0e19ee6f4d6f", "sha256:e9d228ced1214d67904f26fb820c8abbea12b2889cd4aa8cda20a4ca0ed781c1", - "sha256:f2d5b5d6e87d482e17696a7bfa03fe9515fdfe27e462a4ad37f3d7774a5e2fd6" + "sha256:f2d5b5d6e87d482e17696a7bfa03fe9515fdfe27e462a4ad37f3d7774a5e2fd6", ], "index": "pypi", - "version": "==1.1.1" + "version": "==1.1.1", }, "scipy": { "hashes": [ @@ -106,79 +106,77 @@ "sha256:a0aa8220b89b2e3748a2836fbfa116194378910f1a6e78e4675a095bcd2c762d", "sha256:d3b3c8924252caaffc54d4a99f1360aeec001e61267595561089f8b5900821bb", "sha256:e013aed00ed776d790be4cb32826adb72799c61e318676172495383ba4570aa4", - "sha256:f3e7a8867f307e3359cc0ed2c63b61a1e33a19080f92fe377bc7d49f646f2ec1" + "sha256:f3e7a8867f307e3359cc0ed2c63b61a1e33a19080f92fe377bc7d49f646f2ec1", ], "markers": "python_version < '3.11' and python_version >= '3.8'", - "version": "==1.8.1" + "version": "==1.8.1", }, "threadpoolctl": { "hashes": [ "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b", - "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380" + "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380", ], "markers": "python_version >= '3.6'", - "version": "==3.1.0" - } + "version": "==3.1.0", + }, }, "develop": { "alabaster": { "hashes": [ "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359", - "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02" + "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02", ], - "version": "==0.7.12" + "version": "==0.7.12", }, "attrs": { "hashes": [ "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4", - "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd" + "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", - "version": "==21.4.0" + "version": 
"==21.4.0", }, "babel": { "hashes": [ "sha256:7614553711ee97490f732126dc077f8d0ae084ebc6a96e23db1482afabdb2c51", - "sha256:ff56f4892c1c4bf0d814575ea23471c230d544203c7748e8c68f0089478d48eb" + "sha256:ff56f4892c1c4bf0d814575ea23471c230d544203c7748e8c68f0089478d48eb", ], "markers": "python_version >= '3.6'", - "version": "==2.10.3" + "version": "==2.10.3", }, "bleach": { "hashes": [ "sha256:085f7f33c15bd408dd9b17a4ad77c577db66d76203e5984b1bd59baeee948b2a", - "sha256:0d03255c47eb9bd2f26aa9bb7f2107732e7e8fe195ca2f64709fcf3b0a4a085c" + "sha256:0d03255c47eb9bd2f26aa9bb7f2107732e7e8fe195ca2f64709fcf3b0a4a085c", ], "markers": "python_version >= '3.7'", - "version": "==5.0.1" + "version": "==5.0.1", }, "certifi": { "hashes": [ "sha256:84c85a9078b11105f04f3036a9482ae10e4621616db313fe045dd24743a0820d", - "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412" + "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412", ], "markers": "python_version >= '3.6'", - "version": "==2022.6.15" + "version": "==2022.6.15", }, "charset-normalizer": { "hashes": [ "sha256:5189b6f22b01957427f35b6a08d9a0bc45b46d3788ef5a92e978433c7a35f8a5", - "sha256:575e708016ff3a5e3681541cb9d79312c416835686d054a23accb873b254f413" + "sha256:575e708016ff3a5e3681541cb9d79312c416835686d054a23accb873b254f413", ], "markers": "python_version >= '3.6'", - "version": "==2.1.0" + "version": "==2.1.0", }, "commonmark": { "hashes": [ "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60", - "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9" + "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9", ], - "version": "==0.9.1" + "version": "==0.9.1", }, "coverage": { - "extras": [ - "toml" - ], + "extras": ["toml"], "hashes": [ "sha256:0895ea6e6f7f9939166cc835df8fa4599e2d9b759b02d1521b574e13b859ac32", "sha256:0f211df2cba951ffcae210ee00e54921ab42e2b64e0bf2c0befc977377fb09b7", @@ -220,73 +218,73 @@ "sha256:edfdabe7aa4f97ed2b9dd5dde52d2bb29cb466993bb9d612ddd10d0085a683cf", "sha256:f22325010d8824594820d6ce84fa830838f581a7fd86a9235f0d2ed6deb61e29", "sha256:f23876b018dfa5d3e98e96f5644b109090f16a4acb22064e0f06933663005d39", - "sha256:f7bd0ffbcd03dc39490a1f40b2669cc414fae0c4e16b77bb26806a4d0b7d1452" + "sha256:f7bd0ffbcd03dc39490a1f40b2669cc414fae0c4e16b77bb26806a4d0b7d1452", ], "markers": "python_version >= '3.7'", - "version": "==6.4.2" + "version": "==6.4.2", }, "docutils": { "hashes": [ "sha256:686577d2e4c32380bb50cbb22f575ed742d58168cee37e99117a854bcd88f125", - "sha256:cf316c8370a737a022b72b56874f6602acf974a37a9fba42ec2876387549fc61" + "sha256:cf316c8370a737a022b72b56874f6602acf974a37a9fba42ec2876387549fc61", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", - "version": "==0.17.1" + "version": "==0.17.1", }, "flake8": { "hashes": [ "sha256:479b1304f72536a55948cb40a32dce8bb0ffe3501e26eaf292c7e60eb5e0428d", - "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d" + "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d", ], "markers": "python_version >= '3.6'", - "version": "==4.0.1" + "version": "==4.0.1", }, "idna": { "hashes": [ "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff", - "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d" + "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d", ], "markers": "python_version >= '3.5'", - "version": "==3.3" + "version": "==3.3", }, "imagesize": { "hashes": [ 
"sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", - "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a" + "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==1.4.1" + "version": "==1.4.1", }, "importlib-metadata": { "hashes": [ "sha256:637245b8bab2b6502fcbc752cc4b7a6f6243bb02b31c5c26156ad103d3d45670", - "sha256:7401a975809ea1fdc658c3aa4f78cc2195a0e019c5cbc4c06122884e9ae80c23" + "sha256:7401a975809ea1fdc658c3aa4f78cc2195a0e019c5cbc4c06122884e9ae80c23", ], "markers": "python_version >= '3.7'", - "version": "==4.12.0" + "version": "==4.12.0", }, "iniconfig": { "hashes": [ "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3", - "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32" + "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32", ], - "version": "==1.1.1" + "version": "==1.1.1", }, "jinja2": { "hashes": [ "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852", - "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61" + "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61", ], "markers": "python_version >= '3.7'", - "version": "==3.1.2" + "version": "==3.1.2", }, "keyring": { "hashes": [ "sha256:782e1cd1132e91bf459fcd243bcf25b326015c1ac0b198e4408f91fa6791062b", - "sha256:e67fc91a7955785fd2efcbccdd72d7dacf136dbc381d27de305b2b660b3de886" + "sha256:e67fc91a7955785fd2efcbccdd72d7dacf136dbc381d27de305b2b660b3de886", ], "markers": "python_version >= '3.7'", - "version": "==23.7.0" + "version": "==23.7.0", }, "markupsafe": { "hashes": [ @@ -329,284 +327,284 @@ "sha256:e8c843bbcda3a2f1e3c2ab25913c80a3c5376cd00c6e8c4a86a89a28c8dc5452", "sha256:efc1913fd2ca4f334418481c7e595c00aad186563bbc1ec76067848c7ca0a933", "sha256:f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a", - "sha256:fc7b548b17d238737688817ab67deebb30e8073c95749d55538ed473130ec0c7" + "sha256:fc7b548b17d238737688817ab67deebb30e8073c95749d55538ed473130ec0c7", ], "markers": "python_version >= '3.7'", - "version": "==2.1.1" + "version": "==2.1.1", }, "mccabe": { "hashes": [ "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", - "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" + "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f", ], - "version": "==0.6.1" + "version": "==0.6.1", }, "packaging": { "hashes": [ "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb", - "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522" + "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522", ], "markers": "python_version >= '3.6'", - "version": "==21.3" + "version": "==21.3", }, "pkginfo": { "hashes": [ "sha256:848865108ec99d4901b2f7e84058b6e7660aae8ae10164e015a6dcf5b242a594", - "sha256:a84da4318dd86f870a9447a8c98340aa06216bfc6f2b7bdc4b8766984ae1867c" + "sha256:a84da4318dd86f870a9447a8c98340aa06216bfc6f2b7bdc4b8766984ae1867c", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", - "version": "==1.8.3" + "version": "==1.8.3", }, "pluggy": { "hashes": [ "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159", - "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3" + "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3", ], 
"markers": "python_version >= '3.6'", - "version": "==1.0.0" + "version": "==1.0.0", }, "py": { "hashes": [ "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", - "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378" + "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", - "version": "==1.11.0" + "version": "==1.11.0", }, "pycodestyle": { "hashes": [ "sha256:720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20", - "sha256:eddd5847ef438ea1c7870ca7eb78a9d47ce0cdb4851a5523949f2601d0cbbe7f" + "sha256:eddd5847ef438ea1c7870ca7eb78a9d47ce0cdb4851a5523949f2601d0cbbe7f", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", - "version": "==2.8.0" + "version": "==2.8.0", }, "pydocstyle": { "hashes": [ "sha256:1d41b7c459ba0ee6c345f2eb9ae827cab14a7533a88c5c6f7e94923f72df92dc", - "sha256:6987826d6775056839940041beef5c08cc7e3d71d63149b48e36727f70144dc4" + "sha256:6987826d6775056839940041beef5c08cc7e3d71d63149b48e36727f70144dc4", ], "markers": "python_version >= '3.6'", - "version": "==6.1.1" + "version": "==6.1.1", }, "pyflakes": { "hashes": [ "sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c", - "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e" + "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==2.4.0" + "version": "==2.4.0", }, "pygments": { "hashes": [ "sha256:5eb116118f9612ff1ee89ac96437bb6b49e8f04d8a13b514ba26f620208e26eb", - "sha256:dc9c10fb40944260f6ed4c688ece0cd2048414940f1cea51b8b226318411c519" + "sha256:dc9c10fb40944260f6ed4c688ece0cd2048414940f1cea51b8b226318411c519", ], "markers": "python_version >= '3.6'", - "version": "==2.12.0" + "version": "==2.12.0", }, "pyparsing": { "hashes": [ "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb", - "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc" + "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc", ], "markers": "python_full_version >= '3.6.8'", - "version": "==3.0.9" + "version": "==3.0.9", }, "pytest": { "hashes": [ "sha256:13d0e3ccfc2b6e26be000cb6568c832ba67ba32e719443bfe725814d3c42433c", - "sha256:a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45" + "sha256:a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45", ], "index": "pypi", - "version": "==7.1.2" + "version": "==7.1.2", }, "pytest-cov": { "hashes": [ "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6", - "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470" + "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470", ], "index": "pypi", - "version": "==3.0.0" + "version": "==3.0.0", }, "pytest-flake8": { "hashes": [ "sha256:ba4f243de3cb4c2486ed9e70752c80dd4b636f7ccb27d4eba763c35ed0cd316e", - "sha256:e0661a786f8cbf976c185f706fdaf5d6df0b1667c3bcff8e823ba263618627e7" + "sha256:e0661a786f8cbf976c185f706fdaf5d6df0b1667c3bcff8e823ba263618627e7", ], "index": "pypi", - "version": "==1.1.1" + "version": "==1.1.1", }, "pytest-pydocstyle": { "hashes": [ "sha256:1f2d937349cfeb4965c530a0c0f2442b48c03299558db435b65549719510d32b" ], "index": "pypi", - "version": "==2.3.0" + "version": "==2.3.0", }, "pytz": { "hashes": [ 
"sha256:1e760e2fe6a8163bc0b3d9a19c4f84342afa0a2affebfaa84b01b978a02ecaa7", - "sha256:e68985985296d9a66a881eb3193b0906246245294a881e7c8afe623866ac6a5c" + "sha256:e68985985296d9a66a881eb3193b0906246245294a881e7c8afe623866ac6a5c", ], - "version": "==2022.1" + "version": "==2022.1", }, "readme-renderer": { "hashes": [ "sha256:73b84905d091c31f36e50b4ae05ae2acead661f6a09a9abb4df7d2ddcdb6a698", - "sha256:a727999acfc222fc21d82a12ed48c957c4989785e5865807c65a487d21677497" + "sha256:a727999acfc222fc21d82a12ed48c957c4989785e5865807c65a487d21677497", ], "markers": "python_version >= '3.7'", - "version": "==35.0" + "version": "==35.0", }, "requests": { "hashes": [ "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983", - "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349" + "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349", ], "markers": "python_version >= '3.7' and python_version < '4'", - "version": "==2.28.1" + "version": "==2.28.1", }, "requests-toolbelt": { "hashes": [ "sha256:380606e1d10dc85c3bd47bf5a6095f815ec007be7a8b69c878507068df059e6f", - "sha256:968089d4584ad4ad7c171454f0a5c6dac23971e9472521ea3b6d49d610aa6fc0" + "sha256:968089d4584ad4ad7c171454f0a5c6dac23971e9472521ea3b6d49d610aa6fc0", ], - "version": "==0.9.1" + "version": "==0.9.1", }, "rfc3986": { "hashes": [ "sha256:50b1502b60e289cb37883f3dfd34532b8873c7de9f49bb546641ce9cbd256ebd", - "sha256:97aacf9dbd4bfd829baad6e6309fa6573aaf1be3f6fa735c8ab05e46cecb261c" + "sha256:97aacf9dbd4bfd829baad6e6309fa6573aaf1be3f6fa735c8ab05e46cecb261c", ], "markers": "python_version >= '3.7'", - "version": "==2.0.0" + "version": "==2.0.0", }, "rich": { "hashes": [ "sha256:2eb4e6894cde1e017976d2975ac210ef515d7548bc595ba20e195fb9628acdeb", - "sha256:63a5c5ce3673d3d5fbbf23cd87e11ab84b6b451436f1b7f19ec54b6bc36ed7ca" + "sha256:63a5c5ce3673d3d5fbbf23cd87e11ab84b6b451436f1b7f19ec54b6bc36ed7ca", ], "markers": "python_version < '4' and python_full_version >= '3.6.3'", - "version": "==12.5.1" + "version": "==12.5.1", }, "six": { "hashes": [ "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", - "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254" + "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==1.16.0" + "version": "==1.16.0", }, "snowballstemmer": { "hashes": [ "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1", - "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a" + "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a", ], - "version": "==2.2.0" + "version": "==2.2.0", }, "sphinx": { "hashes": [ "sha256:50661b4dbe6a4a1ac15692a7b6db48671da6bae1d4d507e814f1b8525b6bba86", - "sha256:7893d10d9d852c16673f9b1b7e9eda1606b420b7810270294d6e4b44c0accacc" + "sha256:7893d10d9d852c16673f9b1b7e9eda1606b420b7810270294d6e4b44c0accacc", ], "index": "pypi", - "version": "==5.1.0" + "version": "==5.1.0", }, "sphinx-rtd-theme": { "hashes": [ "sha256:4d35a56f4508cfee4c4fb604373ede6feae2a306731d533f409ef5c3496fdbd8", - "sha256:eec6d497e4c2195fa0e8b2016b337532b8a699a68bcb22a512870e16925c6a5c" + "sha256:eec6d497e4c2195fa0e8b2016b337532b8a699a68bcb22a512870e16925c6a5c", ], "index": "pypi", - "version": "==1.0.0" + "version": "==1.0.0", }, "sphinxcontrib-applehelp": { "hashes": [ "sha256:806111e5e962be97c29ec4c1e7fe277bfd19e9652fb1a4392105b43e01af885a", - 
"sha256:a072735ec80e7675e3f432fcae8610ecf509c5f1869d17e2eecff44389cdbc58" + "sha256:a072735ec80e7675e3f432fcae8610ecf509c5f1869d17e2eecff44389cdbc58", ], "markers": "python_version >= '3.5'", - "version": "==1.0.2" + "version": "==1.0.2", }, "sphinxcontrib-devhelp": { "hashes": [ "sha256:8165223f9a335cc1af7ffe1ed31d2871f325254c0423bc0c4c7cd1c1e4734a2e", - "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4" + "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4", ], "markers": "python_version >= '3.5'", - "version": "==1.0.2" + "version": "==1.0.2", }, "sphinxcontrib-htmlhelp": { "hashes": [ "sha256:d412243dfb797ae3ec2b59eca0e52dac12e75a241bf0e4eb861e450d06c6ed07", - "sha256:f5f8bb2d0d629f398bf47d0d69c07bc13b65f75a81ad9e2f71a63d4b7a2f6db2" + "sha256:f5f8bb2d0d629f398bf47d0d69c07bc13b65f75a81ad9e2f71a63d4b7a2f6db2", ], "markers": "python_version >= '3.6'", - "version": "==2.0.0" + "version": "==2.0.0", }, "sphinxcontrib-jsmath": { "hashes": [ "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", - "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8" + "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8", ], "markers": "python_version >= '3.5'", - "version": "==1.0.1" + "version": "==1.0.1", }, "sphinxcontrib-qthelp": { "hashes": [ "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72", - "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6" + "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6", ], "markers": "python_version >= '3.5'", - "version": "==1.0.3" + "version": "==1.0.3", }, "sphinxcontrib-serializinghtml": { "hashes": [ "sha256:352a9a00ae864471d3a7ead8d7d79f5fc0b57e8b3f95e9867eb9eb28999b92fd", - "sha256:aa5f6de5dfdf809ef505c4895e51ef5c9eac17d0f287933eb49ec495280b6952" + "sha256:aa5f6de5dfdf809ef505c4895e51ef5c9eac17d0f287933eb49ec495280b6952", ], "markers": "python_version >= '3.5'", - "version": "==1.1.5" + "version": "==1.1.5", }, "tomli": { "hashes": [ "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc", - "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f" + "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f", ], "markers": "python_version >= '3.7'", - "version": "==2.0.1" + "version": "==2.0.1", }, "twine": { "hashes": [ "sha256:42026c18e394eac3e06693ee52010baa5313e4811d5a11050e7d48436cf41b9e", - "sha256:96b1cf12f7ae611a4a40b6ae8e9570215daff0611828f5fe1f37a16255ab24a0" + "sha256:96b1cf12f7ae611a4a40b6ae8e9570215daff0611828f5fe1f37a16255ab24a0", ], "index": "pypi", - "version": "==4.0.1" + "version": "==4.0.1", }, "urllib3": { "hashes": [ "sha256:8298d6d56d39be0e3bc13c1c97d133f9b45d797169a0e11cdd0e0489d786f7ec", - "sha256:879ba4d1e89654d9769ce13121e0f94310ea32e8d2f8cf587b77c08bbcdb30d6" + "sha256:879ba4d1e89654d9769ce13121e0f94310ea32e8d2f8cf587b77c08bbcdb30d6", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' and python_version < '4'", - "version": "==1.26.10" + "version": "==1.26.10", }, "webencodings": { "hashes": [ "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", - "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923" + "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", ], - "version": "==0.5.1" + "version": "==0.5.1", }, "zipp": { "hashes": [ "sha256:05b45f1ee8f807d0cc928485ca40a07cb491cf092ff587c0df9cb1fd154848d2", - 
"sha256:47c40d7fe183a6f21403a199b3e4192cca5774656965b0a4988ad2f8feb5f009" + "sha256:47c40d7fe183a6f21403a199b3e4192cca5774656965b0a4988ad2f8feb5f009", ], "markers": "python_version >= '3.7'", - "version": "==3.8.1" - } - } + "version": "==3.8.1", + }, + }, } diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index 08068b7e..0376c4bf 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -15,7 +15,7 @@ _has_ray = False else: _has_ray = True - + ARTIFICIAL_ROOT = "hiclass::root" @@ -231,15 +231,13 @@ def _create_digraph_2d(self): for row in range(rows): for column in range(0, columns - 1): - - parent = self.y_[row, column] - child = self.y_[row, column + 1 ] + + parent = self.y_[row, column] + child = self.y_[row, column + 1] if parent != "" and child != "": # Only add edge if both parent and child are not empty - self.hierarchy_.add_edge( - parent, child - ) + self.hierarchy_.add_edge(parent, child) if parent != "" and column == 0: # Add parent node e self.hierarchy_.add_node(parent) @@ -274,10 +272,12 @@ def _add_artificial_root(self): if columns > 1: # roots are the first column of y - roots = set(self.y_[:,0]) + roots = set(self.y_[:, 0]) else: roots = [ - node for node, in_degree in self.hierarchy_.in_degree() if in_degree == 0 + node + for node, in_degree in self.hierarchy_.in_degree() + if in_degree == 0 ] self.logger_.info(f"Detected {len(roots)} roots") diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index ebfed97b..9fbe7c9b 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -152,7 +152,6 @@ def predict(self, X): y = self._convert_to_1d(y) - return y def _predict_remaining_levels(self, X, y): diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 591b4806..6d19a2ea 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -158,31 +158,43 @@ def predict(self, X): self.logger_.info("Predicting") for level in range(y.shape[1]): - - predecessors = set(y[:, level - 1]) if level >= 1 else set([self.root_]) # in case of level 0 the predecessor is the root node + + predecessors = ( + set(y[:, level - 1]) if level >= 1 else set([self.root_]) + ) # in case of level 0 the predecessor is the root node predecessors.discard("") for predecessor in predecessors: - - mask = np.isin(y[:, level - 1], predecessor) if level >= 1 else np.ones(y.shape[0], dtype=bool) + + mask = ( + np.isin(y[:, level - 1], predecessor) + if level >= 1 + else np.ones(y.shape[0], dtype=bool) + ) predecessor_x = X[mask] - + if predecessor_x.shape[0] > 0: successors = list(self.hierarchy_.successors(predecessor)) if len(successors) > 0: # we built an array of probabilities for all successor nodes/classifiers - probabilities = np.zeros((predecessor_x.shape[0], len(successors))) - for i,successor in enumerate(successors): + probabilities = np.zeros( + (predecessor_x.shape[0], len(successors)) + ) + for i, successor in enumerate(successors): classifier = self.hierarchy_.nodes[successor]["classifier"] positive_index = np.where(classifier.classes_ == 1)[0] - probabilities[:,i] = classifier.predict_proba(predecessor_x)[:, positive_index][:,0] + probabilities[:, i] = classifier.predict_proba( + predecessor_x + )[:, positive_index][:, 0] - # prediction is the classifier that outputs the highest probability + # prediction is the classifier that outputs the highest probability highest_probability_index = np.argmax(probabilities, 
axis=1) - predictions = np.array([successors[i] for i in highest_probability_index]) + predictions = np.array( + [successors[i] for i in highest_probability_index] + ) y[mask, level] = predictions - + y = self._convert_to_1d(y) return y diff --git a/tests/conftest.py b/tests/conftest.py index feb1e497..20999efd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,12 @@ import os from typing import Union -import gdown +try: + import gdown +except ImportError: + gdown_installed = False +else: + gdown_installed = True def md5(file_path: str) -> str: @@ -36,7 +41,7 @@ def download(dataset: dict, fuzzy: bool) -> None: fuzzy : bool Whether to use fuzzy matching to find the file name. """ - if dataset: + if gdown_installed and dataset: gdown.cached_download( dataset["url"], dataset["path"], diff --git a/tests/test_BinaryPolicy.py b/tests/test_BinaryPolicy.py index 7a625528..d703178b 100644 --- a/tests/test_BinaryPolicy.py +++ b/tests/test_BinaryPolicy.py @@ -405,8 +405,10 @@ def test_siblings_get_binary_examples_sparse_3(digraph, features_sparse, labels) assert_array_equal(ground_truth_y, y) assert weights is None + ########################################################### + @pytest.fixture def digraph_dag(): return nx.DiGraph( @@ -416,29 +418,19 @@ def digraph_dag(): ("a", "b"), ("b", "c"), ("d", "e"), - ]) + ] + ) @pytest.fixture def features_dag_1d(): - return np.array( - [ - 1, - 2, - 3 - ] - ) + return np.array([1, 2, 3]) @pytest.fixture def features_dag_2d(): - return np.array( - [ - [1, 2], - [3, 4], - [5, 6] - ] - ) + return np.array([[1, 2], [3, 4], [5, 6]]) + @pytest.fixture def labels_dag(): @@ -450,18 +442,34 @@ def labels_dag(): ] ) + @pytest.mark.parametrize( - ["node", "expected"], [["a", [True, False, False]], ["b", [True, True, False]], ["c", [False, True, False]]] + ["node", "expected"], + [ + ["a", [True, False, False]], + ["b", [True, True, False]], + ["c", [False, True, False]], + ], ) -def test_exclusive_policy_positive_examples_1_dag(digraph_dag, features_dag_1d, labels_dag, node, expected): +def test_exclusive_policy_positive_examples_1_dag( + digraph_dag, features_dag_1d, labels_dag, node, expected +): policy = ExclusivePolicy(digraph_dag, features_dag_1d, labels_dag) result = policy.positive_examples(node) assert_array_equal(expected, result) + @pytest.mark.parametrize( - ["node", "expected"], [["a", [False, True, True]], ["b", [False, False, True]], ["c", [True, False, True]]] + ["node", "expected"], + [ + ["a", [False, True, True]], + ["b", [False, False, True]], + ["c", [True, False, True]], + ], ) -def test_exclusive_policy_negative_examples_1_dag(digraph_dag, features_dag_1d, labels_dag, node, expected): +def test_exclusive_policy_negative_examples_1_dag( + digraph_dag, features_dag_1d, labels_dag, node, expected +): policy = ExclusivePolicy(digraph_dag, features_dag_1d, labels_dag) result = policy.negative_examples(node) - assert_array_equal(expected, result) \ No newline at end of file + assert_array_equal(expected, result) diff --git a/tests/test_HierarchicalClassifier.py b/tests/test_HierarchicalClassifier.py index e9beb184..e1b9eb28 100644 --- a/tests/test_HierarchicalClassifier.py +++ b/tests/test_HierarchicalClassifier.py @@ -7,9 +7,11 @@ from numpy.testing import assert_array_equal from sklearn.linear_model import LogisticRegression -from hiclass.HierarchicalClassifier import HierarchicalClassifier, make_leveled, ARTIFICIAL_ROOT - - +from hiclass.HierarchicalClassifier import ( + HierarchicalClassifier, + make_leveled, + ARTIFICIAL_ROOT, +) 
@pytest.fixture @@ -108,7 +110,7 @@ def digraph_one_root(): classifier = HierarchicalClassifier() classifier.logger_ = logging.getLogger("HC") classifier.hierarchy_ = nx.DiGraph([("a", "b")]) - classifier.y_ = np.array([["a", "b"]] ) + classifier.y_ = np.array([["a", "b"]]) return classifier @@ -150,6 +152,7 @@ def test_clean_up(digraph_multiple_roots): with pytest.raises(AttributeError): assert digraph_multiple_roots.y_ is None + @pytest.fixture def digraph_multiple_roots_dag(): classifier = HierarchicalClassifier() diff --git a/tests/test_LocalClassifierPerNode.py b/tests/test_LocalClassifierPerNode.py index b4b1562d..2c5bde69 100644 --- a/tests/test_LocalClassifierPerNode.py +++ b/tests/test_LocalClassifierPerNode.py @@ -201,13 +201,27 @@ def test_fit_predict(binary_policy): lcpn = LocalClassifierPerNode( local_classifier=LogisticRegression(), binary_policy=binary_policy ) - + x = np.array([[-10], [0], [10], [100]]) - y = np.array([["a", ""], ["a", "b"], ["b", ""], ["b", "c"], ]) + y = np.array( + [ + ["a", ""], + ["a", "b"], + ["b", ""], + ["b", "c"], + ] + ) lcpn.fit(x, y) - expected = np.array([["a", "b"], ["a", "b"], ["b", "c"], ["b", "c"], ]) # TODO: is this the correct result? - predictions = lcpn.predict(x) + expected = np.array( + [ + ["a", "b"], + ["a", "b"], + ["b", "c"], + ["b", "c"], + ] + ) # TODO: is this the correct result? + predictions = lcpn.predict(x) assert_array_equal(expected, predictions) diff --git a/tests/test_RealData.py b/tests/test_RealData.py index e69de29b..ee1f9cf1 100644 --- a/tests/test_RealData.py +++ b/tests/test_RealData.py @@ -0,0 +1,50 @@ +import os +from os.path import exists + +import pytest + +from hiclass import ( + LocalClassifierPerNode, + LocalClassifierPerParentNode, + LocalClassifierPerLevel, +) + +try: + import skbio +except ImportError: + skbio_installed = False +else: + skbio_installed = True + +try: + from hitac._utils import compute_possible_kmers +except ImportError: + hitac_installed = False +else: + hitac_installed = True + + +@pytest.mark.skipif( + not exists("tests/fixtures/fungi_train.csv") + or not exists("tests/fixtures/fungi_test.csv"), + reason="dataset not available", +) +@pytest.mark.skipif( + "FUNGI_TRAIN_URL" not in os.environ + or "FUNGI_TRAIN_MD5" not in os.environ + or "FUNGI_TEST_URL" not in os.environ + or "FUNGI_TEST_MD5" not in os.environ, + reason="environment variables not set", +) +@pytest.mark.skipif(not skbio_installed, reason="scikit-bio not installed") +@pytest.mark.skipif(not hitac_installed, reason="hitac not installed") +@pytest.mark.parametrize( + "model", + [ + LocalClassifierPerNode(), + LocalClassifierPerParentNode(), + LocalClassifierPerLevel(), + ], +) +def test_fungi(model): + assert False From 17e50f58982ab7aae9c4b31b48d3b932dcf8ac0f Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 16:26:15 +0100 Subject: [PATCH 17/21] Add function test_fungi --- tests/conftest.py | 45 +++++++++++++++++++++++----- tests/test_RealData.py | 66 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 95 insertions(+), 16 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 20999efd..9ea7ea6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,8 @@ import os from typing import Union +import numpy as np + try: import gdown except ImportError: @@ -51,7 +53,7 @@ def download(dataset: dict, fuzzy: bool) -> None: ) -def get_dataset(prefix: str) -> Union[dict, None]: +def get_dataset(prefix: str, suffix: str) -> Union[dict, None]: """ Get the dataset information. 
@@ -59,6 +61,8 @@ def get_dataset(prefix: str) -> Union[dict, None]: ---------- prefix : str Prefix of the environment variables. + suffix : str + Suffix of the output file. Returns ------- @@ -71,7 +75,9 @@ def get_dataset(prefix: str) -> Union[dict, None]: lowercase = prefix.lower() dataset = { "url": os.environ["{}_URL".format(uppercase)], - "path": "tests/fixtures/{}.csv".format(lowercase), + "path": "tests/fixtures/{prefix}.{suffix}".format( + prefix=lowercase, suffix=suffix + ), "md5": os.environ["{}_MD5".format(uppercase)], } except KeyError: @@ -82,24 +88,47 @@ def get_dataset(prefix: str) -> Union[dict, None]: def download_fungi_dataset() -> None: """Download the fungi dataset only if the environment variables are set.""" - train = get_dataset("FUNGI_TRAIN") + train = get_dataset("FUNGI_TRAIN", "fasta") download(train, fuzzy=False) - test = get_dataset("FUNGI_TEST") + test = get_dataset("FUNGI_TEST", "fasta") download(test, fuzzy=False) def download_complaints_dataset() -> None: """Download the complaints dataset only if the environment variables are set.""" - x_train = get_dataset("COMPLAINTS_X_TRAIN") + x_train = get_dataset("COMPLAINTS_X_TRAIN", "csv") download(x_train, fuzzy=True) - y_train = get_dataset("COMPLAINTS_Y_TRAIN") + y_train = get_dataset("COMPLAINTS_Y_TRAIN", "csv") download(y_train, fuzzy=True) - x_test = get_dataset("COMPLAINTS_X_TEST") + x_test = get_dataset("COMPLAINTS_X_TEST", "csv") download(x_test, fuzzy=True) - y_test = get_dataset("COMPLAINTS_Y_TEST") + y_test = get_dataset("COMPLAINTS_Y_TEST", "csv") download(y_test, fuzzy=True) +# Returns a list with ranks extracted from TAXXI format +def get_ranks(taxxi): + split = taxxi.split(",") + kingdom = split[0] + kingdom = kingdom[kingdom.find("tax=") + 4 :] + phylum = split[1] + classs = split[2] + order = split[3] + family = split[4] + genus = split[5] + if len(split) == 6: + return [kingdom, phylum, classs, order, family, genus] + elif len(split) == 7: + species = split[6][:-1] + return [kingdom, phylum, classs, order, family, genus, species] + + +# Returns taxonomy ranks from training dataset +def get_taxonomy(taxxi): + taxonomy = np.array([get_ranks(record) for record in taxxi]) + return taxonomy + + def pytest_sessionstart(session): """ Download the datasets before the tests start. 
diff --git a/tests/test_RealData.py b/tests/test_RealData.py index ee1f9cf1..132db313 100644 --- a/tests/test_RealData.py +++ b/tests/test_RealData.py @@ -1,13 +1,16 @@ import os +from multiprocessing import cpu_count from os.path import exists import pytest +from joblib import parallel_backend +from sklearn.linear_model import LogisticRegression from hiclass import ( LocalClassifierPerNode, - LocalClassifierPerParentNode, - LocalClassifierPerLevel, ) +from hiclass.metrics import f1 +from tests.conftest import get_taxonomy try: import skbio @@ -17,16 +20,29 @@ skbio_installed = True try: - from hitac._utils import compute_possible_kmers + from hitac._utils import ( + compute_possible_kmers, + _extract_reads, + compute_frequencies, + extract_qiime2_taxonomy, + ) except ImportError: hitac_installed = False else: hitac_installed = True +try: + from q2_types.feature_data import DNAIterator +except ImportError: + qiime2_installed = False +else: + qiime2_installed = True + + @pytest.mark.skipif( - not exists("tests/fixtures/fungi_train.csv") - or not exists("tests/fixtures/fungi_test.csv"), + not exists("tests/fixtures/fungi_train.fasta") + or not exists("tests/fixtures/fungi_test.fasta"), reason="dataset not available", ) @pytest.mark.skipif( @@ -38,13 +54,47 @@ ) @pytest.mark.skipif(not skbio_installed, reason="scikit-bio not installed") @pytest.mark.skipif(not hitac_installed, reason="hitac not installed") +@pytest.mark.skipif(not qiime2_installed, reason="qiime2 not installed") @pytest.mark.parametrize( "model", [ LocalClassifierPerNode(), - LocalClassifierPerParentNode(), - LocalClassifierPerLevel(), + # LocalClassifierPerParentNode(), + # LocalClassifierPerLevel(), ], ) def test_fungi(model): - assert False + # Variables + train = "tests/fixtures/fungi_train.fasta" + test = "tests/fixtures/fungi_test.fasta" + kmer_size = 6 + alphabet = "ACGT" + threads = min(cpu_count(), 12) + logistic_regression_parameters = { + "solver": "liblinear", + "multi_class": "auto", + "class_weight": "balanced", + "random_state": 42, + "max_iter": 10000, + "verbose": 0, + "n_jobs": 1, + } + + # Training + kmers = compute_possible_kmers(kmer_size=kmer_size, alphabet=alphabet) + train = DNAIterator(skbio.read(str(train), format="fasta", constructor=skbio.DNA)) + training_ids, training_sequences = _extract_reads(train) + x_train = compute_frequencies(training_sequences, kmers, threads=threads) + y_train = get_taxonomy(training_ids) + lr = LogisticRegression(**logistic_regression_parameters) + model = model.set_params(local_classifier=lr, n_jobs=threads) + with parallel_backend("threading", n_jobs=threads): + model.fit(x_train, y_train) + + # Testing + test = DNAIterator(skbio.read(str(test), format="fasta", constructor=skbio.DNA)) + test_ids, test_sequences = _extract_reads(test) + x_test = compute_frequencies(test_sequences, kmers, threads) + y_test = get_taxonomy(test_ids) + predictions = model.predict(x_test) + assert f1(y_true=y_test, y_pred=predictions) == 1.0 From c3a62d85b2a72a8eae7ee63a9df11a2d2d1d2756 Mon Sep 17 00:00:00 2001 From: Fabio Date: Mon, 21 Nov 2022 16:35:19 +0100 Subject: [PATCH 18/21] Add docstrings --- tests/conftest.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9ea7ea6f..b765b3b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ """Shared code for all tests.""" import hashlib import os -from typing import Union +from typing import Union, List import numpy as np @@ -106,8 
+106,20 @@ def download_complaints_dataset() -> None:
     download(y_test, fuzzy=True)
 
 
-# Returns a list with ranks extracted from TAXXI format
-def get_ranks(taxxi):
+def get_ranks(taxxi: str) -> List[str]:
+    """
+    Get the taxonomy ranks from a TAXXI record.
+
+    Parameters
+    ----------
+    taxxi : str
+        FASTA ID in TAXXI format.
+
+    Returns
+    -------
+    ranks : List[str]
+        List of taxonomic ranks.
+    """
     split = taxxi.split(",")
     kingdom = split[0]
     kingdom = kingdom[kingdom.find("tax=") + 4 :]
@@ -124,7 +136,20 @@ def get_ranks(taxxi):
 
 
 # Returns taxonomy ranks from training dataset
-def get_taxonomy(taxxi):
+def get_taxonomy(taxxi: List[str]) -> np.ndarray:
+    """
+    Get the taxonomy ranks from a list of FASTA IDs.
+
+    Parameters
+    ----------
+    taxxi : List[str]
+        List of FASTA IDs in TAXXI format.
+
+    Returns
+    -------
+    taxonomy : np.ndarray
+        Array of taxonomic ranks.
+    """
     taxonomy = np.array([get_ranks(record) for record in taxxi])
     return taxonomy
 

From 74ef6f7a0be3f60f0058612d94b353d7e58e6fa5 Mon Sep 17 00:00:00 2001
From: Jonathan Haas
Date: Mon, 21 Nov 2022 17:02:47 +0100
Subject: [PATCH 19/21] Remove Separator from LocalClassifierPerLevel

---
 hiclass/LocalClassifierPerLevel.py    | 14 +++++++-------
 tests/test_LocalClassifierPerLevel.py | 10 ++++------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py
index 9fbe7c9b..656a7b92 100644
--- a/hiclass/LocalClassifierPerLevel.py
+++ b/hiclass/LocalClassifierPerLevel.py
@@ -212,29 +212,29 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
             lcpl = ray.put(self)
             _parallel_fit = ray.remote(self._fit_classifier)
             results = [
-                _parallel_fit.remote(lcpl, level, self.separator_)
+                _parallel_fit.remote(lcpl, level)
                 for level in range(len(self.local_classifiers_))
             ]
             classifiers = ray.get(results)
         else:
             classifiers = Parallel(n_jobs=self.n_jobs)(
-                delayed(self._fit_classifier)(self, level, self.separator_)
+                delayed(self._fit_classifier)(self, level)
                 for level in range(len(self.local_classifiers_))
             )
         else:
             classifiers = [
-                self._fit_classifier(self, level, self.separator_)
+                self._fit_classifier(self, level)
                 for level in range(len(self.local_classifiers_))
             ]
         for level, classifier in enumerate(classifiers):
             self.local_classifiers_[level] = classifier
 
     @staticmethod
-    def _fit_classifier(self, level, separator):
+    def _fit_classifier(self, level):
         classifier = self.local_classifiers_[level]
 
         X, y, sample_weight = self._remove_empty_leaves(
-            separator, self.X_, self.y_[:, level], self.sample_weight_
+            self.X_, self.y_[:, level], self.sample_weight_
         )
 
         unique_y = np.unique(y)
@@ -244,9 +244,9 @@ def _fit_classifier(self, level, separator):
         return classifier
 
     @staticmethod
-    def _remove_empty_leaves(separator, X, y, sample_weight):
+    def _remove_empty_leaves(X, y, sample_weight):
         # Detect rows where leaves are not empty
-        leaves = np.array([str(i).split(separator)[-1] for i in y])
+        leaves = np.array([str(i) for i in y])
         mask = leaves != ""
         X = X[mask]
         y = y[mask]
diff --git a/tests/test_LocalClassifierPerLevel.py b/tests/test_LocalClassifierPerLevel.py
index 0b9153ad..b54d8b13 100644
--- a/tests/test_LocalClassifierPerLevel.py
+++ b/tests/test_LocalClassifierPerLevel.py
@@ -26,7 +26,6 @@ def digraph_logistic_regression():
     digraph.logger_ = logging.getLogger("LCPL")
     digraph.root_ = "a"
     digraph.sample_weight_ = None
-    digraph.separator_ = "::HiClass::Separator::"
    digraph.masks_ = [
        [True, True],
        [True, True],
@@ -98,7 +97,6 @@ def fitted_logistic_regression():
    digraph.max_levels_ = 2
    digraph.dtype_ = "<U3"

From: Fabio
Date: Tue, 22 Nov 2022 16:02:54 +0100
Subject: [PATCH 20/21] Fix test_fungi comparison

---
 tests/test_RealData.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tests/test_RealData.py b/tests/test_RealData.py
index 132db313..596f65bf 100644
--- a/tests/test_RealData.py
+++ b/tests/test_RealData.py
@@ -8,6 +8,8 @@
 from hiclass import (
     LocalClassifierPerNode,
+    LocalClassifierPerParentNode,
+    LocalClassifierPerLevel,
 )
 from hiclass.metrics import f1
@@ -56,14 +58,14 @@
 @pytest.mark.skipif(not hitac_installed, reason="hitac not installed")
 @pytest.mark.skipif(not qiime2_installed, reason="qiime2 not installed")
 @pytest.mark.parametrize(
-    "model",
+    "model, expected",
     [
-        LocalClassifierPerNode(),
-        # LocalClassifierPerParentNode(),
-        # LocalClassifierPerLevel(),
+        (LocalClassifierPerNode(), 0.8038390550018457),
+        (LocalClassifierPerParentNode(), 0.8038390550018457),
+        (LocalClassifierPerLevel(), 0.8041343669250646),
     ],
 )
-def test_fungi(model):
+def test_fungi(model, expected):
     # Variables
     train = "tests/fixtures/fungi_train.fasta"
     test = "tests/fixtures/fungi_test.fasta"
@@ -97,4 +99,4 @@
     x_test = compute_frequencies(test_sequences, kmers, threads)
     y_test = get_taxonomy(test_ids)
     predictions = model.predict(x_test)
-    assert f1(y_true=y_test, y_pred=predictions) == 1.0
+    assert f1(y_true=y_test, y_pred=predictions) >= expected

From b7c34d2387f32df5e7f0a02957888770b421a82a Mon Sep 17 00:00:00 2001
From: Fabio
Date: Wed, 23 Nov 2022 10:44:33 +0100
Subject: [PATCH 21/21] Add test for consumer complaints dataset

---
 tests/conftest.py      | 20 ++++++++++++-
 tests/test_RealData.py | 65 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 83 insertions(+), 2 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index b765b3b4..509dbdae 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,10 @@
 """Shared code for all tests."""
 import hashlib
 import os
-from typing import Union, List
+from typing import Union, List, TextIO
 
 import numpy as np
+import pandas as pd
 
 try:
     import gdown
@@ -164,3 +165,20 @@ def pytest_sessionstart(session):
     """
     download_fungi_dataset()
     download_complaints_dataset()
+
+
+def load_dataframe(path: TextIO) -> pd.DataFrame:
+    """
+    Load a dataframe from a CSV file.
+
+    Parameters
+    ----------
+    path : TextIO
+        Path to CSV file.
+
+    Returns
+    -------
+    df : pd.DataFrame
+        Loaded dataframe.
+ """ + return pd.read_csv(path, compression="infer", header=0, sep=",", low_memory=False) diff --git a/tests/test_RealData.py b/tests/test_RealData.py index 596f65bf..ee190463 100644 --- a/tests/test_RealData.py +++ b/tests/test_RealData.py @@ -4,7 +4,9 @@ import pytest from joblib import parallel_backend +from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer from sklearn.linear_model import LogisticRegression +from sklearn.pipeline import Pipeline from hiclass import ( LocalClassifierPerNode, @@ -12,7 +14,7 @@ LocalClassifierPerLevel, ) from hiclass.metrics import f1 -from tests.conftest import get_taxonomy +from tests.conftest import get_taxonomy, load_dataframe try: import skbio @@ -100,3 +102,64 @@ def test_fungi(model, expected): y_test = get_taxonomy(test_ids) predictions = model.predict(x_test) assert f1(y_true=y_test, y_pred=predictions) >= expected + + +@pytest.mark.skipif( + not exists("tests/fixtures/complaints_x_train.csv") + or not exists("tests/fixtures/complaints_y_train.csv") + or not exists("tests/fixtures/complaints_x_test.csv") + or not exists("tests/fixtures/complaints_y_test.csv"), + reason="dataset not available", +) +@pytest.mark.skipif( + "COMPLAINTS_X_TRAIN_URL" not in os.environ + or "COMPLAINTS_X_TRAIN_MD5" not in os.environ + or "COMPLAINTS_Y_TRAIN_URL" not in os.environ + or "COMPLAINTS_Y_TRAIN_MD5" not in os.environ + or "COMPLAINTS_X_TEST_URL" not in os.environ + or "COMPLAINTS_X_TEST_MD5" not in os.environ + or "COMPLAINTS_Y_TEST_URL" not in os.environ + or "COMPLAINTS_Y_TEST_MD5" not in os.environ, + reason="environment variables not set", +) +@pytest.mark.skipif(not skbio_installed, reason="scikit-bio not installed") +@pytest.mark.skipif(not hitac_installed, reason="hitac not installed") +@pytest.mark.skipif(not qiime2_installed, reason="qiime2 not installed") +@pytest.mark.parametrize( + "model, expected", + [ + (LocalClassifierPerNode(), 1), + (LocalClassifierPerParentNode(), 1), + (LocalClassifierPerLevel(), 1), + ], +) +def test_complaints(model, expected): + # Variables + x_train = load_dataframe("tests/fixtures/complaints_x_train.fasta").squeeze() + y_train = load_dataframe("tests/fixtures/complaints_y_train.fasta") + x_test = load_dataframe("tests/fixtures/complaints_x_test.fasta").squeeze() + y_test = load_dataframe("tests/fixtures/complaints_y_test.fasta") + threads = min(cpu_count(), 12) + logistic_regression_parameters = { + "random_state": 42, + "max_iter": 10000, + "verbose": 0, + "n_jobs": 1, + } + + # Training + model.set_params( + local_classifier=LogisticRegression(**logistic_regression_parameters), + verbose=30, + ) + pipeline = Pipeline( + ("count", CountVectorizer()), + ("tfidf", TfidfTransformer()), + ("classifier", model), + ) + with parallel_backend("threading", n_jobs=threads): + pipeline.fit(x_train, y_train) + + # Testing + predictions = pipeline.predict(x_test) + assert f1(y_true=y_test, y_pred=predictions) >= expected