
Add pipeline visualizer? #14061

Closed
@amueller

Description


When working with more complex pipelines and ColumnTransformer, it can be hard to see what's going on.
I think it would be nice to have a way to visualize the estimator graph.
Unfortunately, this will be tricky without using graphviz.
With graphviz, though, it's pretty easy.

On the column transformer example, this produces:
[image: graph of the column transformer example]

or (with graph_attr={'rankdir': 'LR'}):
[image: the same graph laid out left to right]
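As a rough illustration of why the graphviz route is easy: for a plain linear pipeline, the DOT source a `Digraph` ends up emitting is just one box per step and one edge per consecutive pair. The sketch below builds that text by hand, without the graphviz package. `pipeline_to_dot` and its `(step_name, class_name)` input format are hypothetical; `rankdir` is the same DOT attribute set through `graph_attr={'rankdir': 'LR'}` above.

```python
def pipeline_to_dot(steps, rankdir="TB"):
    """Return DOT source for a linear pipeline.

    Hypothetical helper: `steps` is a list of (step_name, class_name) pairs,
    mirroring Pipeline.steps but with class names instead of estimators.
    """
    lines = ["digraph pipeline {", f"\trankdir={rankdir}", "\tnode [shape=box]"]
    for i, (name, cls) in enumerate(steps):
        # One box per step, labelled with the step name and the class name.
        lines.append(f'\tstep{i} [label="{name}: {cls}"]')
    for i in range(len(steps) - 1):
        # Consecutive steps are connected by a single edge.
        lines.append(f"\tstep{i} -> step{i + 1}")
    lines.append("}")
    return "\n".join(lines)


print(pipeline_to_dot([("scale", "StandardScaler"), ("clf", "LogisticRegression")], rankdir="LR"))
```

The resulting string can be fed to any DOT renderer (for example `dot -Tpng`) to get the boxes-and-arrows picture; graphviz's Python wrapper mainly saves you from writing this text yourself.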

I'll put this into dabl for sure, but it might also be interesting for sklearn.
This is obviously just a hack, and we could make it much nicer, for example with graphviz records. Even better would be avoiding graphviz altogether.
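One graphviz-free direction is a plain-text tree. The sketch below walks the same structure as `pipe_to_viz`, recursing by duck typing into `steps` (pairs of name and estimator) and `transformers` (triples of name, transformer, columns). The function `describe` is hypothetical, and so that the sketch runs without scikit-learn it defines minimal stand-in classes named after the sklearn ones; real `Pipeline` and `ColumnTransformer` objects expose `steps` and `transformers` in the same shape.

```python
def describe(est, name=None, indent=0):
    """Return an indented text tree of a (possibly nested) estimator."""
    # 'drop' / 'passthrough' appear as plain strings inside a ColumnTransformer.
    label = est if isinstance(est, str) else type(est).__name__
    prefix = "  " * indent + (f"{name}: " if name else "")
    lines = [prefix + label]
    if hasattr(est, "steps"):  # Pipeline-like: (name, estimator) pairs
        children = list(est.steps)
    elif hasattr(est, "transformers"):  # ColumnTransformer-like: (name, trans, cols)
        children = [(n, t) for n, t, _cols in est.transformers]
    else:
        children = []
    for child_name, child in children:
        lines.extend(describe(child, child_name, indent + 1))
    return lines


# Hypothetical stand-ins so the sketch runs without scikit-learn installed.
class StandardScaler: pass
class OneHotEncoder: pass
class LogisticRegression: pass

class Pipeline:
    def __init__(self, steps):
        self.steps = steps

class ColumnTransformer:
    def __init__(self, transformers):
        self.transformers = transformers


pipe = Pipeline([
    ("prep", ColumnTransformer([
        ("num", StandardScaler(), ["age"]),
        ("cat", OneHotEncoder(), ["city"]),
        ("rest", "drop", ["id"]),
    ])),
    ("clf", LogisticRegression()),
])
print("\n".join(describe(pipe)))
```

This loses the data-flow arrows of the graphviz version, but it needs no external binary and nesting maps directly onto indentation.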

import itertools

from graphviz import Digraph
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

# Unique suffixes for node names (random integers could collide).
_node_ids = itertools.count()


def pipe_to_viz(est, name_from_parent=None, parent=None, **kwargs):
    # graph is not actually used in the singleton case
    # unless there's no parent (and it's not actually a graph)
    if name_from_parent is None:
        graph = Digraph("Some Pipeline", **kwargs)
    else:
        # Subgraphs whose name starts with "cluster" are drawn as boxes.
        graph = Digraph("cluster " + name_from_parent, **kwargs)
        graph.attr(color='grey')
        graph.attr(label=name_from_parent)
    # compound=true allows edges to be clipped at cluster borders (ltail/lhead).
    graph.attr(compound='true')

    if isinstance(est, Pipeline):
        prev_step = None
        prev_sub = None
        for step_name, step_est in est.steps:
            sub, node = pipe_to_viz(step_est, name_from_parent=step_name, parent=graph)
            if sub is not None:
                graph.subgraph(sub)
            if prev_step is not None:
                # Connect consecutive steps, clipping at their cluster borders.
                ltail = getattr(prev_sub, 'name', None)
                lhead = getattr(sub, 'name', None)
                graph.edge(prev_step, node, ltail=ltail, lhead=lhead)
            prev_step = node
            prev_sub = sub
        return graph, node
    elif isinstance(est, ColumnTransformer):
        # One point node (unique per ColumnTransformer) joins all outputs.
        out = "out_" + str(next(_node_ids))
        graph.node(out, shape="point")
        for trans_name, trans, _columns in est.transformers:
            sub, node = pipe_to_viz(trans, name_from_parent=trans_name, parent=graph)
            if sub is not None:
                graph.subgraph(sub)
            ltail = getattr(sub, 'name', None)
            graph.edge(node, out, ltail=ltail)
        return graph, out
    else:
        # 'drop' and 'passthrough' are given as strings in a ColumnTransformer.
        label = est if isinstance(est, str) else type(est).__name__
        node = label + "_" + str(next(_node_ids))
        if parent is None:
            graph.node(node, label=label, shape='box')
            return graph, node
        # we could draw boxes around each estimator with the name
        # or we could use the name in the label? hm
        parent.node(node, label=label, shape='box')
        return None, node
