Skip to content

Improve decision tree plotting in Jupyter environment #6261

@nvdv

Description

@nvdv

Recently I used built-in visualizer for DecisionTreeClassifier in Jupyter and can say that
its interface could be better (example is taken from docs):

>>> from IPython.display import Image  
>>> dot_data = StringIO()  
>>> tree.export_graphviz(clf, out_file=dot_data,  
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         special_characters=True)  
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue())  
>>> Image(graph.create_png())  

In addition this does not work on Python 3, since at the time of writing pydot cannot be installed for Python 3.
The ideal solution will be something like

from sklearn import tree
tc = tree.DecisionTreeClassifier()
...
tree.plot(tc) # or even tc.plot()

but in this case tree module should depend on pydot and IPython.display.image modules.
I can fix this issue, but what is the best way to do this?

Metadata

Metadata

Assignees

No one assigned

    Labels

    EnhancementModerateAnything that requires some knowledge of conventions and best practiceshelp wanted

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions