Feature importances with a forest of trees
This example shows the use of a forest of trees to evaluate the importance of features on an artificial classification task. As expected, the plots below suggest that 3 features are informative, while the remaining are not.

We generate a synthetic dataset with only 3 informative features. We explicitly do not shuffle the dataset, so that the informative features correspond to the first three columns of X; we then split the data into training and testing subsets.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_informative=3,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    random_state=0,
    shuffle=False,
)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)
A random forest classifier is fitted to compute the feature importances.

from sklearn.ensemble import RandomForestClassifier

feature_names = [f"feature {i}" for i in range(X.shape[1])]
forest = RandomForestClassifier(random_state=0)
forest.fit(X_train, y_train)
Feature importances are provided by the fitted attribute feature_importances_; they are computed as the mean, across trees, of the accumulated impurity decrease within each tree.

Warning: Impurity-based feature importances can be misleading for high cardinality features (many unique values). See Permutation feature importance as an alternative below.
import time

import numpy as np

start_time = time.time()
importances = forest.feature_importances_
std = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)
elapsed_time = time.time() - start_time

print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")
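As a quick sanity check (a sketch added for this write-up, not part of the original example): in scikit-learn the forest-level MDI importance is the average of the per-tree importances, which is why the standard deviation above is taken across forest.estimators_.

# Each tree exposes normalized importances; the forest aggregates by averaging.
per_tree = np.array([tree.feature_importances_ for tree in forest.estimators_])
assert np.allclose(importances, per_tree.mean(axis=0))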
Let’s plot the impurity-based importance.
import matplotlib.pyplot as plt
import pandas as pd

forest_importances = pd.Series(importances, index=feature_names)

fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=std, ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_ylabel("Mean decrease in impurity")
fig.tight_layout()
We observe that, as expected, the first three features are found to be important.
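To make the earlier warning about impurity-based importances concrete, here is a small illustrative sketch (our own addition; the names rng, X_noisy, and forest_noisy are not from the original example). We append a purely random, high-cardinality numeric column, refit, and compare how the two methods score it: MDI typically assigns the random column a noticeable score, while permutation importance on the held-out set keeps it near zero.

from sklearn.inspection import permutation_importance

rng = np.random.RandomState(0)
# Add an uninformative column with ~1000 unique values.
X_noisy = np.hstack([X, rng.randn(X.shape[0], 1)])
Xn_train, Xn_test, yn_train, yn_test = train_test_split(
    X_noisy, y, stratify=y, random_state=42
)
forest_noisy = RandomForestClassifier(random_state=0).fit(Xn_train, yn_train)

print("MDI score of random column:        ", forest_noisy.feature_importances_[-1])
perm = permutation_importance(
    forest_noisy, Xn_test, yn_test, n_repeats=10, random_state=42
)
print("Permutation score of random column:", perm.importances_mean[-1])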
Permutation feature importance overcomes limitations of the impurity-based feature importance: it has no bias toward high-cardinality features and can be computed on a left-out test set.

from sklearn.inspection import permutation_importance

start_time = time.time()
result = permutation_importance(
    forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)
elapsed_time = time.time() - start_time

print(f"Elapsed time to compute the importances: {elapsed_time:.3f} seconds")
The computation of full permutation importance is more costly: each feature is shuffled n_repeats times and the model is re-scored on the permuted data to estimate that feature's importance (the model itself is not refitted). Please see Permutation feature importance for more details.
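For intuition, here is a minimal sketch of what permutation_importance computes (the helper name manual_permutation_importance is ours, and this version ignores refinements such as custom scorers, sample weights, and parallelism):

def manual_permutation_importance(model, X_eval, y_eval, n_repeats=10, seed=42):
    # Mean drop in accuracy when each column is shuffled, one at a time.
    rng = np.random.RandomState(seed)
    baseline = model.score(X_eval, y_eval)
    drops = np.zeros(X_eval.shape[1])
    for j in range(X_eval.shape[1]):
        for _ in range(n_repeats):
            X_perm = X_eval.copy()
            # Break the association between column j and the target.
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j] += baseline - model.score(X_perm, y_eval)
    return drops / n_repeats

# Should roughly agree with result.importances_mean above.
print(manual_permutation_importance(forest, X_test, y_test))

We can now plot the importance ranking.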
forest_importances = pd.Series(result.importances_mean, index=feature_names)

fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()
The same features are detected as most important by both methods, although the relative importances vary. As the plots show, MDI is less likely than permutation importance to fully omit a feature.
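To check this claim programmatically (our own small addition, not in the original example):

# Rank features from most to least important under each method.
mdi_rank = np.argsort(importances)[::-1]
perm_rank = np.argsort(result.importances_mean)[::-1]
print("MDI ranking:        ", mdi_rank)
print("Permutation ranking:", perm_rank)
print("Same top-3 set:     ", set(mdi_rank[:3]) == set(perm_rank[:3]))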
© 2007 - 2023, scikit-learn developers (BSD License).