Exploring Computation Graphs in Rust
17 Jun 2018
Paul Kernfeld dot com
Recently, I’ve been trying to figure out a good way to model computation graphs in Rust. In this post, I
explore using a graph with vector indices. I’m not sure if this is the best approach, but writing it out has
helped me to understand the advantages and disadvantages better.
2   a   b
 \ / \ /
  *   *
   \ /
    +
In my ASCII diagram above, all edges go downwards. In general, edges don’t really have any interesting
information associated with them so I’m going to pretty much ignore them.
A homogeneous graph
Below is a homogeneous DAG as a jumping-off point. It is strongly inspired by the post “Modeling graphs
in Rust using vector indices”. This isn’t a computation graph yet, but it does let us implement an
algorithm that traverses the graph and memoizes intermediate values. This matters because the number of
paths through a DAG can grow exponentially with the number of nodes, so a naive recursive
implementation that re-walks every path can be very slow.
#[derive(Default)]
pub struct Graph {
    nodes: Vec<Node>,
}

/// Graph maintains the invariant that nodes can only be added, never removed. This means that
/// a particular Idx will always be valid as long as it is used with the correct Graph.
impl Graph {
    pub fn push(&mut self, children: Vec<Idx>) -> Idx {
        self.nodes.push(Node { children });
        Idx(self.nodes.len() - 1)
    }
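
    /// This returns the number of paths between each leaf node and the final node. This
    /// implementation memoizes the number of paths from leaves to each node.
    ///
    /// Note that "final node" is only a meaningful concept in a DAG where there is one node
    /// that is the ancestor of every other node in the graph; I'm using it here for simplicity.
    pub fn count_paths(&self) -> usize {
        // Children are always pushed before their parents, so one forward pass fills the memo.
        let mut path_counts: Vec<usize> = Vec::new();
        for node in &self.nodes {
            let count: usize = if node.children.is_empty() {
                1
            } else {
                node.children.iter().map(|child| path_counts[child.0]).sum()
            };
            path_counts.push(count);
        }
        path_counts[path_counts.len() - 1]
    }
}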
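To make the exponential blow-up concrete, here is a minimal, self-contained sketch that compares a memoized count with a naive recursive one on a chain of "diamonds". The `Node` and `Idx` definitions, `count_paths_naive`, and the `diamond_chain` helper are assumptions made for illustration; they are not taken from the original listing.

```rust
/// Index of a node inside `Graph::nodes` (assumed definition).
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Idx(usize);

/// A homogeneous node that only knows its children (assumed definition).
pub struct Node {
    children: Vec<Idx>,
}

#[derive(Default)]
pub struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    pub fn push(&mut self, children: Vec<Idx>) -> Idx {
        self.nodes.push(Node { children });
        Idx(self.nodes.len() - 1)
    }

    /// Memoized: one forward pass, because children are pushed before parents.
    pub fn count_paths(&self) -> usize {
        let mut path_counts: Vec<usize> = Vec::with_capacity(self.nodes.len());
        for node in &self.nodes {
            let count: usize = if node.children.is_empty() {
                1
            } else {
                node.children.iter().map(|child| path_counts[child.0]).sum()
            };
            path_counts.push(count);
        }
        path_counts.last().copied().unwrap_or(0)
    }

    /// Naive: re-walks every path below `idx`, so the work grows with the path count.
    pub fn count_paths_naive(&self, idx: Idx) -> usize {
        let node = &self.nodes[idx.0];
        if node.children.is_empty() {
            1
        } else {
            node.children
                .iter()
                .map(|&child| self.count_paths_naive(child))
                .sum()
        }
    }
}

/// Hypothetical helper: stacks `layers` diamonds on top of a single leaf and
/// returns the graph together with the index of its final node.
pub fn diamond_chain(layers: usize) -> (Graph, Idx) {
    let mut g = Graph::default();
    let mut top = g.push(vec![]);
    for _ in 0..layers {
        let left = g.push(vec![top]);
        let right = g.push(vec![top]);
        top = g.push(vec![left, right]);
    }
    (g, top)
}
```

Each diamond doubles the number of paths, so a graph with `3 * layers + 1` nodes has `2^layers` paths. The memoized version visits each node once, while the naive version does work proportional to the number of paths.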
• There are three different kinds of Node: Constant, Variable, and Sum
• There is a Subgraph type that represents an ordered set of nodes
• Idx implements std::ops::Add for cute graph-building syntax
• There is a derivative method that transforms a subgraph
• Graph implements Index<Idx> for slightly more type safety
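As a rough illustration of the `Add` and `Index<Idx>` bullets, the sketch below shows one way these pieces could fit together. The exact shape of the `Node` enum and the `build_sum` helper are assumptions for this example, chosen to match the names used in the rest of the post.

```rust
use std::ops::{Add, Index};

#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub struct Idx(usize);

/// Assumed shape of the three node kinds listed above.
#[derive(Debug, PartialEq)]
pub enum Node {
    Constant(f64),
    Variable,
    Sum { children: Vec<Idx> },
}

/// `a + b` on two indices builds a `Sum` node; it does not insert it into a graph.
impl Add for Idx {
    type Output = Node;

    fn add(self, rhs: Idx) -> Node {
        Node::Sum {
            children: vec![self, rhs],
        }
    }
}

#[derive(Default)]
pub struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    pub fn push(&mut self, node: Node) -> Idx {
        self.nodes.push(node);
        Idx(self.nodes.len() - 1)
    }
}

/// Lets callers write `graph[idx]` instead of indexing `nodes` with a raw usize.
impl Index<Idx> for Graph {
    type Output = Node;

    fn index(&self, idx: Idx) -> &Node {
        &self.nodes[idx.0]
    }
}

/// Hypothetical helper: builds c = 1 + b and returns the graph with its three indices.
pub fn build_sum() -> (Graph, Idx, Idx, Idx) {
    let mut g = Graph::default();
    let a = g.push(Node::Constant(1.0));
    let b = g.push(Node::Variable);
    let c = g.push(a + b);
    (g, a, b, c)
}
```

Because `Idx` wraps a private `usize`, code outside the module cannot index a graph with an arbitrary integer, which is where the "slightly more type safety" comes from. Nothing here stops an `Idx` from one graph being used with another graph.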
/// To enable this to be used in HashMap and HashSet, this derives Eq, PartialEq, and Hash
#[derive(Copy, Clone, Eq, Hash, PartialEq)]
pub struct Idx(usize);
impl Node {
    fn get_value(&self, my_index: &Idx, values: &HashMap<Idx, f64>) -> f64 {
        match self {
            Node::Constant(value) => *value,
            Node::Variable => values[my_index],
            Node::Sum { children } => children.iter().map(|child| values[child]).sum(),
        }
    }

    fn derivative(
        &self,
        my_index: &Idx,
        wrt: &HashSet<Idx>,
        derivatives: &HashMap<Idx, Idx>,
    ) -> Node {
        match self {
            Node::Constant(_) => Node::Constant(0.0),
            Node::Variable => {
                if wrt.contains(my_index) {
                    Node::Constant(1.0)
                } else {
                    Node::Constant(0.0)
                }
            }
            Node::Sum { children } => Node::Sum {
                children: children.iter().map(|child| derivatives[child]).collect(),
            },
        }
    }
}

/// This helps us to represent the idea that only a subset of the nodes in a graph might be
/// relevant for a particular computation. The indices in a Subgraph are ordered such that a
/// child always comes before one of its parents.
pub struct Subgraph {
    indices: Vec<Idx>,
}
impl Subgraph {
    fn new(indices_unsorted: impl Iterator<Item = Idx>) -> Self {
        let mut indices: Vec<Idx> = indices_unsorted.collect();

#[derive(Default)]
pub struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    pub fn push(&mut self, node: Node) -> Idx {
        self.nodes.push(node);
        Idx(self.nodes.len() - 1)
    }
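
    /// Given values for each relevant variable, this computes the value for each node in the
    /// subgraph.
    pub fn evaluate_subgraph(
        &self,
        subgraph: Subgraph,
        variable_to_value: HashMap<Idx, f64>,
    ) -> HashMap<Idx, f64> {
        let mut result = variable_to_value;
        // A Subgraph lists children before parents, so each node's inputs are already in result.
        for idx in subgraph.indices {
            let value = self[idx].get_value(&idx, &result);
            result.insert(idx, value);
        }
        result
    }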
        (
            derivatives[&of],
            Subgraph::new(derivatives.values().cloned()),
        )
    }
}
// c = 1 + b
let mut g = Graph::default();
let a = g.push(Node::Constant(1.0));
let b = g.push(Node::Variable);
let c = g.push(a + b);
// 1 + 2 = 3
let variable_to_value = {
    let mut result = HashMap::new();
    result.insert(b, 2.0);
    result
};
assert_eq!(3.0, g.evaluate(variable_to_value)[&c]);
Overall, I’m pretty happy with this implementation. However, adding many different types of node will
cause the match statements to balloon. Additionally, I would rather see all the code for one type of Node
in one place.
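One way to keep all the code for a node type in one place is to make `Node` a trait and give each kind of node its own `impl` block. The following is only a sketch under assumptions: the trait has a single `get_value` method, `push` is generic over the node type, and `evaluate` walks nodes in insertion order. The `one_plus` helper is hypothetical.

```rust
use std::collections::HashMap;

#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub struct Idx(usize);

/// Each node type implements this trait, so its behaviour lives in one `impl` block.
pub trait Node {
    fn get_value(&self, my_index: &Idx, values: &HashMap<Idx, f64>) -> f64;
}

pub struct Constant(pub f64);
pub struct Variable;
pub struct Sum {
    pub children: Vec<Idx>,
}

impl Node for Constant {
    fn get_value(&self, _my_index: &Idx, _values: &HashMap<Idx, f64>) -> f64 {
        self.0
    }
}

impl Node for Variable {
    /// Panics if no value was supplied for this variable.
    fn get_value(&self, my_index: &Idx, values: &HashMap<Idx, f64>) -> f64 {
        values[my_index]
    }
}

impl Node for Sum {
    fn get_value(&self, _my_index: &Idx, values: &HashMap<Idx, f64>) -> f64 {
        self.children.iter().map(|child| values[child]).sum()
    }
}

#[derive(Default)]
pub struct Graph {
    nodes: Vec<Box<dyn Node>>,
}

impl Graph {
    /// Boxes any concrete node type and appends it to the graph.
    pub fn push<N: Node + 'static>(&mut self, node: N) -> Idx {
        self.nodes.push(Box::new(node));
        Idx(self.nodes.len() - 1)
    }

    /// Evaluates every node in insertion order; children always precede parents.
    pub fn evaluate(&self, variable_to_value: HashMap<Idx, f64>) -> HashMap<Idx, f64> {
        let mut values = variable_to_value;
        for (i, node) in self.nodes.iter().enumerate() {
            let idx = Idx(i);
            let value = node.get_value(&idx, &values);
            values.insert(idx, value);
        }
        values
    }
}

/// Hypothetical helper: builds c = 1 + b and evaluates c for the given b.
pub fn one_plus(b_value: f64) -> f64 {
    let mut g = Graph::default();
    let a = g.push(Constant(1.0));
    let b = g.push(Variable);
    let c = g.push(Sum { children: vec![a, b] });
    let mut inputs = HashMap::new();
    inputs.insert(b, b_value);
    g.evaluate(inputs)[&c]
}
```

Adding a new kind of node in this style means adding one struct and one `impl Node` block, with no existing `match` statement to extend.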
    fn derivative(
        &self,
        my_index: &Idx,
        wrt: &HashSet<Idx>,
        derivatives: &HashMap<Idx, Idx>,
    ) -> Box<dyn Node>;
}

    fn derivative(
        &self,
        _my_index: &Idx,
        _wrt: &HashSet<Idx>,
        _derivatives: &HashMap<Idx, Idx>,
    ) -> Box<dyn Node> {
        Box::from(Constant(0.0))
    }
}

    fn derivative(
        &self,
        my_index: &Idx,
        wrt: &HashSet<Idx>,
        _derivatives: &HashMap<Idx, Idx>,
    ) -> Box<dyn Node> {
        if wrt.contains(my_index) {
            Box::from(Constant(1.0))
        } else {
            Box::from(Constant(0.0))
        }
    }
}

    fn derivative(
        &self,
        _my_index: &Idx,
        _wrt: &HashSet<Idx>,
        derivatives: &HashMap<Idx, Idx>,
    ) -> Box<dyn Node> {
        Box::from(Sum {
            children: self.children
                .iter()
                .map(|child| derivatives[child])
                .collect(),
        })
    }
}

/// Since a Node trait object is not Sized, we need to box it so we can put it into a Vec.
#[derive(Default)]
pub struct Graph {
    nodes: Vec<Box<dyn Node>>,
}

/// This is almost identical to the Graph implementation with the enum, except that the push
/// fn now accepts a Box<dyn Node>, and I've added push_box.
impl Graph {
    pub fn push_box(&mut self, box_node: Box<dyn Node>) -> Idx {
        self.nodes.push(box_node);
        Idx(self.nodes.len() - 1)
    }
// c = 1 + b
let mut g = Graph::default();
let a = g.push(Constant(1.0));
let b = g.push(Variable);
let c = g.push_box(a + b);
// 1 + 2 = 3
let variable_to_value = {
    let mut result = HashMap::new();
    result.insert(b, 2.0);
    result
};
assert_eq!(3.0, g.evaluate(variable_to_value)[&c]);
While implementing this, I noticed that making Node a trait enforces a clean separation of responsibility
between the graph and the node. I actually used the implementation of this version to clean up the
factorization of the enum version.
However, making Node a trait brings with it the ergonomic disadvantage that nodes often need to be
passed around inside a Box, which is slightly annoying. Separately, there is a performance penalty
because we are now using dynamic dispatch instead of static dispatch. I don’t think that I care too much
about this, because I’m interested in using these graphs with large tensors where the cost of the actual
computation will dwarf the cost of traversing the graph.
Advantages
I was happy about several aspects of this experiment:
Disadvantages
• The syntax for creating a graph is a bit cumbersome, since we have to write g.push(...) every
time we want to add a node to the graph.
• Having Graph, Subgraph, Node, and Idx is a lot of structs even for this toy implementation.
• I don’t entirely understand why Node needs to have the static lifetime. Hopefully that doesn’t
mean anything bad.
• Since a Node doesn’t know its own index, the index needs to be passed around a lot.
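On the static lifetime point, a likely explanation is Rust's default lifetime for boxed trait objects: `Box<dyn Node>` is shorthand for `Box<dyn Node + 'static>`. That bound only says the boxed type contains no non-static borrows. It does not mean the value lives forever. The sketch below illustrates this with a hypothetical one-method `Node` trait and two made-up node types, `Owned` and `Borrowed`.

```rust
/// Hypothetical, simplified trait used only for this illustration.
pub trait Node {
    fn value(&self) -> f64;
}

/// Owns its data, holds no borrows, and therefore satisfies `'static`.
pub struct Owned(pub f64);

/// Borrows its data for some lifetime `'a`, so it is not `'static` in general.
pub struct Borrowed<'a>(pub &'a f64);

impl Node for Owned {
    fn value(&self) -> f64 {
        self.0
    }
}

impl<'a> Node for Borrowed<'a> {
    fn value(&self) -> f64 {
        *self.0
    }
}

/// `Box<dyn Node>` here means `Box<dyn Node + 'static>`.
pub fn boxed_static(x: f64) -> Box<dyn Node> {
    Box::new(Owned(x))
}

/// Spelling out the lifetime lets the trait object hold borrowed data.
pub fn boxed_borrowed<'a>(x: &'a f64) -> Box<dyn Node + 'a> {
    Box::new(Borrowed(x))
}
```

If this is indeed the cause, the bound is harmless as long as nodes own their data, which is the case for constants, variables, and sums of indices.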
Future directions
I have an unhealthy obsession with building an elegant DSL in vanilla Rust. I would love to be able to
create a graph by writing something like this:
let x = variable();
let y = variable();
let z = x * 2.0 + y;
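One possible way to get close to this syntax is to hide the graph in thread-local state, so that `variable()` and the arithmetic operators can push nodes without a `g` argument. This is only a sketch of that idea, not something the post implements: the thread-local `GRAPH`, the `Product` variant, and the `node_at` helper are all assumptions.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Idx(usize);

#[derive(Clone, Debug, PartialEq)]
pub enum Node {
    Constant(f64),
    Variable,
    Sum(Idx, Idx),
    Product(Idx, Idx),
}

thread_local! {
    // Assumption: one implicit graph per thread, so `variable()` needs no argument.
    static GRAPH: RefCell<Vec<Node>> = RefCell::new(Vec::new());
}

/// Appends a node to this thread's implicit graph and returns its index.
fn push(node: Node) -> Idx {
    GRAPH.with(|graph| {
        let mut nodes = graph.borrow_mut();
        nodes.push(node);
        Idx(nodes.len() - 1)
    })
}

pub fn variable() -> Idx {
    push(Node::Variable)
}

/// `x + y` pushes a Sum node and returns its index.
impl Add for Idx {
    type Output = Idx;

    fn add(self, rhs: Idx) -> Idx {
        push(Node::Sum(self, rhs))
    }
}

/// `x * 2.0` first pushes the constant, then a Product node that refers to it.
impl Mul<f64> for Idx {
    type Output = Idx;

    fn mul(self, rhs: f64) -> Idx {
        let constant = push(Node::Constant(rhs));
        push(Node::Product(self, constant))
    }
}

/// Returns a copy of the node stored at `idx` in this thread's graph.
pub fn node_at(idx: Idx) -> Node {
    GRAPH.with(|graph| graph.borrow()[idx.0].clone())
}
```

The trade-off is that the graph becomes implicit global state, which makes it harder to work with several graphs at once or to reason about which graph an index belongs to.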
One somewhat crazy direction of exploration would be to allow different nodes in the graph to implement
different traits.
It would be good to have a way to represent functions that are composed of smaller functions, like
softmax.
These questions aside, the most obvious ways to make this more useful would be to implement many
different functions and to allow computation on data such as tensors.
About
This blog post was produced using cargo-readme to ensure that all of the code actually works. The source
code is here.