From 2556f5ea5c72ab759dd356c352436c6f90a03c70 Mon Sep 17 00:00:00 2001 From: mirand863 Date: Mon, 8 Apr 2024 14:41:31 +0200 Subject: [PATCH] Revert "Explainer implementation for LCPN #minor (#108)" This reverts commit d1bbd334a8503cb0feb04053b8558a6cf431a724. --- README.md | 2 +- docs/examples/plot_lcpn_explainer.py | 46 ------------------- docs/examples/plot_lcppn_explainer.py | 2 +- hiclass/Explainer.py | 49 +------------------- tests/test_Explainer.py | 64 +++------------------------ 5 files changed, 11 insertions(+), 152 deletions(-) delete mode 100644 docs/examples/plot_lcpn_explainer.py diff --git a/README.md b/README.md index c835ba91..7fd8bd28 100644 --- a/README.md +++ b/README.md @@ -202,7 +202,7 @@ predictions = pipeline.predict(X_test) ``` ## Explaining Hierarchical Classifiers -Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html). +Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html). ## Step-by-step walk-through diff --git a/docs/examples/plot_lcpn_explainer.py b/docs/examples/plot_lcpn_explainer.py deleted file mode 100644 index 39494fbf..00000000 --- a/docs/examples/plot_lcpn_explainer.py +++ /dev/null @@ -1,46 +0,0 @@ -# -*- coding: utf-8 -*- -""" -========================================= -Explaining Local Classifier Per Node -========================================= - -A minimalist example showing how to use HiClass Explainer to obtain SHAP values of LCPN model. -A detailed summary of the Explainer class has been given at Algorithms Overview Section for :ref:`Hierarchical Explainability`. -SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here `_. 
-""" -import numpy as np -from sklearn.ensemble import RandomForestClassifier -from hiclass import LocalClassifierPerNode, Explainer -from hiclass.datasets import load_platypus -import shap - -# Load train and test splits -X_train, X_test, Y_train, Y_test = load_platypus() - -# Use random forest classifiers for every node -rfc = RandomForestClassifier() -classifier = LocalClassifierPerNode(local_classifier=rfc, replace_classifiers=False) - -# Train local classifier per node -classifier.fit(X_train, Y_train) - -# Define Explainer -explainer = Explainer(classifier, data=X_train.values, mode="tree") -explanations = explainer.explain(X_test.values) -print(explanations) - -# Filter samples which only predicted "Respiratory" at first level -respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory" - -# Specify additional filters to obtain only level 0 -shap_filter = {"level": 0, "class": "Respiratory_1", "sample": respiratory_idx} - -# Use .sel() method to apply the filter and obtain filtered results -shap_val_respiratory = explanations.sel(shap_filter) - -# Plot feature importance on test set -shap.plots.violin( - shap_val_respiratory.shap_values, - feature_names=X_train.columns.values, - plot_size=(13, 8), -) diff --git a/docs/examples/plot_lcppn_explainer.py b/docs/examples/plot_lcppn_explainer.py index ab27ce38..c9d3dad0 100644 --- a/docs/examples/plot_lcppn_explainer.py +++ b/docs/examples/plot_lcppn_explainer.py @@ -26,7 +26,7 @@ classifier.fit(X_train, Y_train) # Define Explainer -explainer = Explainer(classifier, data=X_train.values, mode="tree") +explainer = Explainer(classifier, data=X_train, mode="tree") explanations = explainer.explain(X_test.values) print(explanations) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 585f4b85..36510256 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -189,46 +189,6 @@ def _get_traversed_nodes_lcppn(self, samples): ).flatten() return traversals - def _get_traversed_nodes_lcpn(self, samples): - """ - Return a list of all traversed nodes as per the provided LocalClassifierPerNode model. - - Parameters - ---------- - samples : array-like - Sample data for which to generate traversed nodes. - - Returns - ------- - traversals : list - A list of all traversed nodes as per LocalClassifierPerNode (LCPN) strategy. - """ - traversals = np.empty( - (samples.shape[0], self.hierarchical_model.max_levels_), - dtype=self.hierarchical_model.dtype_, - ) - - predictions = self.hierarchical_model.predict(samples) - - traversals[:, 0] = predictions[:, 0] - separator = np.full( - (samples.shape[0], 3), - self.hierarchical_model.separator_, - dtype=self.hierarchical_model.dtype_, - ) - - for level in range(1, traversals.shape[1]): - traversals[:, level] = np.char.add( - traversals[:, level - 1], - np.char.add(separator[:, 0], predictions[:, level]), - ) - - # For inconsistent hierarchies, levels with empty nodes should be ignored - mask = predictions == "" - traversals[mask] = "" - - return traversals - def _calculate_shap_values(self, X): """ Return an xarray.Dataset object for a single sample provided. This dataset is aligned on the `level` attribute. 
@@ -246,16 +206,11 @@ def _calculate_shap_values(self, X): traversed_nodes = [] if isinstance(self.hierarchical_model, LocalClassifierPerParentNode): traversed_nodes = self._get_traversed_nodes_lcppn(X)[0] - elif isinstance(self.hierarchical_model, LocalClassifierPerNode): - traversed_nodes = self._get_traversed_nodes_lcpn(X)[0] datasets = [] level = 0 for node in traversed_nodes: - # Skip if node is empty or classifier is not found, can happen in case of imbalanced hierarchies - if ( - node == "" - or "classifier" not in self.hierarchical_model.hierarchy_.nodes[node] - ): + # Skip if classifier is not found, can happen in case of imbalanced hierarchies + if "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]: continue local_classifier = self.hierarchical_model.hierarchy_.nodes[node][ diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py index c1caa5e7..c4af7a30 100644 --- a/tests/test_Explainer.py +++ b/tests/test_Explainer.py @@ -1,7 +1,10 @@ import numpy as np import pytest from sklearn.ensemble import RandomForestClassifier -from hiclass import LocalClassifierPerNode, LocalClassifierPerParentNode, Explainer +from hiclass import ( + LocalClassifierPerParentNode, + Explainer, +) try: import shap @@ -73,31 +76,6 @@ def test_explainer_tree_lcppn(data, request): assert explanation.data[j].split(lcppn.separator_)[-1] == y_pred[j] -@pytest.mark.skipif(not shap_installed, reason="shap not installed") -@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") -@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) -def test_explainer_tree_lcpn(data, request): - rfc = RandomForestClassifier() - lcpn = LocalClassifierPerNode(local_classifier=rfc, replace_classifiers=False) - - x_train, x_test, y_train = request.getfixturevalue(data) - - lcpn.fit(x_train, y_train) - - explainer = Explainer(lcpn, data=x_train, mode="tree") - explanations = explainer.explain(x_test) - - # Assert if explainer returns an xarray.Dataset object - assert isinstance(explanations, xarray.Dataset) - y_preds = lcpn.predict(x_test) - - # Assert if predictions made are consistent with the explanation object - for i in range(len(x_test)): - y_pred = y_preds[i] - for j in range(len(y_pred)): - assert str(explanations["node"][i].data[j]) == y_pred[j] - - @pytest.mark.skipif(not shap_installed, reason="shap not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) def test_traversal_path_lcppn(data, request): @@ -120,34 +98,10 @@ def test_traversal_path_lcppn(data, request): assert label == preds[i][j - 1] -@pytest.mark.skipif(not shap_installed, reason="shap not installed") -@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) -def test_traversal_path_lcpn(data, request): - x_train, x_test, y_train = request.getfixturevalue(data) - rfc = RandomForestClassifier() - lcpn = LocalClassifierPerNode(local_classifier=rfc, replace_classifiers=False) - - lcpn.fit(x_train, y_train) - explainer = Explainer(lcpn, data=x_train, mode="tree") - traversals = explainer._get_traversed_nodes_lcpn(x_test) - preds = lcpn.predict(x_test) - - # Assert if predictions and traversals are of same length - assert len(preds) == len(traversals) - - # Assert if traversal path in predictions is same as the computed traversal path - for i in range(len(x_test)): - for j in range(len(traversals[i])): - label = traversals[i][j].split(lcpn.separator_)[-1] - assert label == preds[i][j] - - @pytest.mark.skipif(not shap_installed, reason="shap 
not installed") @pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) -@pytest.mark.parametrize( - "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode] -) +@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode]) def test_explain_with_xr(data, request, classifier): x_train, x_test, y_train = request.getfixturevalue(data) rfc = RandomForestClassifier() @@ -161,9 +115,7 @@ def test_explain_with_xr(data, request, classifier): assert isinstance(explanations, xarray.Dataset) -@pytest.mark.parametrize( - "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode] -) +@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode]) def test_imports(classifier): x_train = [[76, 12, 49], [88, 63, 31], [5, 42, 24], [17, 90, 55]] y_train = [["a", "b", "d"], ["a", "b", "e"], ["a", "c", "f"], ["a", "c", "g"]] @@ -176,9 +128,7 @@ def test_imports(classifier): assert isinstance(explainer.data, np.ndarray) -@pytest.mark.parametrize( - "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode] -) +@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode]) @pytest.mark.parametrize("data", ["explainer_data"]) @pytest.mark.parametrize("mode", ["linear", "gradient", "deep", "tree", ""]) def test_explainers(data, request, classifier, mode):