A Gentle Introduction To Graph Neural Networks
Graph
Before we get into GNNs, let's first understand what a graph is. In computer science, a
graph is a data structure consisting of two components: vertices and edges. A graph G
can be well described by the set of vertices V and edges E it contains.
The vertices are often called nodes. In this article, these two terms are interchangeable.
In the node classification problem setup, each node v is characterized by its feature x_v
and associated with a ground-truth label t_v. Given a partially labeled graph G, the
goal is to leverage the labeled nodes to predict the labels of the unlabeled ones. The GNN
learns to represent each node with a d-dimensional vector (state) h_v that contains the
information of its neighborhood. Specifically,
h_v = f(x_v, x_co[v], h_ne[v], x_ne[v])  (https://arxiv.org/pdf/1812.08434)
where x_co[v] denotes the features of the edges connected to v, h_ne[v] denotes the
embeddings of the neighboring nodes of v, and x_ne[v] denotes the features of the
neighboring nodes of v. The function f is the transition function that projects these
inputs onto a d-dimensional space. Since we are seeking a unique solution for h_v, we
can apply the Banach fixed-point theorem and rewrite the above equation as an iterative
update process. Such an operation is often referred to as message passing or
neighborhood aggregation.
H^(t+1) = F(H^t, X)  (https://arxiv.org/pdf/1812.08434)
where H^t and X denote the stacked states and features of all the nodes, respectively.
The output of the GNN is computed by passing the state h_v as well as the feature x_v
to an output function g.
o_v = g(h_v, x_v)  (https://arxiv.org/pdf/1812.08434)
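To make the fixed-point iteration concrete, here is a minimal numpy sketch of the update h_v = f(...) and the readout o_v = g(h_v, x_v), assuming a toy adjacency list, randomly initialized linear stand-ins for f and g, and no edge features; it illustrates the iteration itself, not a trained model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph as an adjacency list, plus a feature vector x_v per node.
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
d_x, d_h = 4, 8
x = {v: rng.normal(size=d_x) for v in neighbors}

# Random stand-ins for the learned transition function f and output function g.
# The small weight scale keeps f a contraction, so the iteration converges.
W_self = 0.1 * rng.normal(size=(d_h, d_x))
W_nbr = 0.05 * rng.normal(size=(d_h, d_x + d_h))
W_out = rng.normal(size=(2, d_h + d_x))

def f(v, h):
    """Transition function: combine x_v with the features and states of v's
    neighbors (edge features x_co[v] are omitted for brevity)."""
    msgs = [W_nbr @ np.concatenate([x[u], h[u]]) for u in neighbors[v]]
    return np.tanh(W_self @ x[v] + np.mean(msgs, axis=0))

def g(v, h):
    """Output function: map the state h_v and the feature x_v to an output o_v."""
    return W_out @ np.concatenate([h[v], x[v]])

# Message passing: iterate h_v <- f(...) until the states stop changing.
h = {v: np.zeros(d_h) for v in neighbors}
for step in range(200):
    h_new = {v: f(v, h) for v in neighbors}
    delta = max(np.linalg.norm(h_new[v] - h[v]) for v in neighbors)
    h = h_new
    if delta < 1e-6:
        break

print(step, {v: g(v, h) for v in neighbors})
```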
However, there are three main limitations with this original proposal of GNN pointed
out by this paper:
1. It is inefficient to update the hidden states iteratively to reach the fixed point;
relaxing the fixed-point assumption allows a multi-layer architecture in which each
layer has its own parameters.
2. It cannot process edge information (e.g. different edges in a knowledge graph may
indicate different relationships between nodes).
3. The fixed point can discourage the diversification of node distributions, and thus may
not be suitable for learning to represent nodes.
Several variants of GNN have been proposed to address the above issues. However, they
are not covered here as they are not the focus of this post.
DeepWalk
DeepWalk is the first algorithm to propose learning node embeddings in an unsupervised
manner. It highly resembles word embedding in terms of the training process. The
motivation is that the distributions of both nodes in a graph and words in a corpus
follow a power law, as shown in the following figure:
Power-law distribution of nodes in a graph and words in a corpus (http://www.perozzi.net/publications/14_kdd_deepwalk.pdf)
The algorithm contains two steps:
1. Perform random walks on the nodes of the graph to generate node sequences
2. Run skip-gram to learn the embedding of each node based on the node sequences
generated in step 1
At each time step of the random walk, the next node is sampled uniformly from the
neighbors of the previous node. Each sequence is then truncated into sub-sequences of
length 2|w| + 1, where w denotes the window size in skip-gram. If you are not
familiar with skip-gram, my previous blog post gives a brief overview of how it works.
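As a rough illustration of step 1, here is a minimal, dependency-free sketch of uniform random-walk generation; the toy adjacency list, walk length, and number of walks per node are made-up values, and the gensim call mentioned in the comments is only one possible way to run skip-gram on the resulting sequences.

```python
import random

def random_walks(neighbors, walk_length=10, walks_per_node=5, seed=0):
    """Step 1 of DeepWalk: uniform random walks starting from every node."""
    rng = random.Random(seed)
    walks = []
    for _ in range(walks_per_node):
        nodes = list(neighbors)
        rng.shuffle(nodes)          # visit start nodes in a random order
        for start in nodes:
            walk = [start]
            while len(walk) < walk_length:
                # The next node is sampled uniformly from the previous node's neighbors.
                walk.append(rng.choice(neighbors[walk[-1]]))
            walks.append(walk)
    return walks

# Toy adjacency list (illustrative only).
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
walks = random_walks(neighbors)
print(walks[0])

# Step 2 would treat each walk as a "sentence" and run skip-gram on it, e.g.:
# from gensim.models import Word2Vec
# model = Word2Vec([[str(v) for v in w] for w in walks],
#                  vector_size=64, window=5, sg=1, hs=1, min_count=0)
```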
To train skip-gram, we need the probability of observing a neighboring node given the
current node, and the softmax that produces this probability is normalized over every
vertex in the graph. Therefore, the computation time is O(|V|) for the original softmax,
where V denotes the set of vertices in the graph.
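To make that cost concrete, here is a small sketch of the full softmax skip-gram would otherwise need; the embedding matrix is a random stand-in, and the point is simply that every prediction touches all |V| output vectors.

```python
import numpy as np

rng = np.random.default_rng(0)
num_vertices, d = 10_000, 64
out_embeddings = 0.01 * rng.normal(size=(num_vertices, d))  # one output vector per vertex

def full_softmax(context_embedding):
    """P(v | context) for every vertex: the denominator sums over all |V| vertices."""
    scores = out_embeddings @ context_embedding   # O(|V| * d) work per prediction
    scores -= scores.max()                        # for numerical stability
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum()

probs = full_softmax(rng.normal(size=d))
print(probs.shape, probs.sum())   # (10000,) ~1.0
```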
Hierarchical softmax utilizes a binary tree to deal with this problem. In this binary tree,
all the leaves (v1, v2, … v8 in the figure below) are the vertices of the graph. At each
inner node, a binary classifier decides which path to follow. To compute the probability
of a given vertex v_k, one simply computes the probability of each sub-path along the
path from the root node to the leaf v_k. Since the probabilities of each node's children
sum to 1, the property that the probabilities of all the vertices sum to 1 still holds in the
hierarchical softmax. The computation time for an element is now reduced to O(log|V|),
as the longest path in a binary tree is bounded by O(log n), where n is the number of
leaves.
Hierarchical Softmax (http://www.perozzi.net/publications/14_kdd_deepwalk.pdf)
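A minimal sketch of how such a root-to-leaf product could be computed is shown below; the complete-binary-tree (heap) layout, the number of leaves, and the per-inner-node weight vectors are all assumptions made for illustration rather than the exact parameterization used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                        # embedding dimension (illustrative)
num_leaves = 8               # v1 ... v8, as in the figure above
num_inner = num_leaves - 1   # inner nodes of a complete binary tree

# One binary classifier (here just a weight vector) per inner node.
inner_w = 0.1 * rng.normal(size=(num_inner, d))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def leaf_probability(leaf_index, context_embedding):
    """P(v_k | context) as the product of left/right decisions on the root-to-leaf path.

    Nodes are stored heap-style: inner node i has children 2i+1 and 2i+2,
    and the leaves occupy indices num_inner .. 2 * num_inner.
    """
    node = num_inner + leaf_index
    prob = 1.0
    while node > 0:
        parent = (node - 1) // 2
        p_left = sigmoid(inner_w[parent] @ context_embedding)
        prob *= p_left if node == 2 * parent + 1 else 1.0 - p_left
        node = parent
    return prob

context = rng.normal(size=d)
probs = [leaf_probability(k, context) for k in range(num_leaves)]
print(sum(probs))   # ~1.0: the leaf probabilities still sum to one
```

Each call only evaluates the classifiers on the root-to-leaf path, i.e. about log2(|V|) of them, instead of touching all |V| output vectors.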
After the DeepWalk model is trained, it has learned a good representation of each
node, as shown in the following figure. Different colors indicate different labels in the
input graph. We can see that in the output graph (embeddings with 2 dimensions),
nodes with the same label are clustered together, while most nodes with different
labels are properly separated.
Node representations learned by DeepWalk (http://www.perozzi.net/publications/14_kdd_deepwalk.pdf)
However, the main issue with DeepWalk is that it lacks the ability to generalize.
Whenever a new node comes in, the model has to be re-trained in order to represent this
node (it is transductive). Thus, such a model is not suitable for dynamic graphs where the
nodes are ever-changing.
GraphSage
GraphSage provides a solution to the aforementioned problem by learning the
embedding for each node in an inductive way. Specifically, each node is represented
by the aggregation of its neighborhood. Thus, even if a new node unseen during
training appears in the graph, it can still be properly represented by its
neighboring nodes. The GraphSage algorithm is shown below.
The GraphSage embedding-generation algorithm (https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)
The outer loop indicates the number of update iterations, while h^k_v denotes the
latent vector of node v at update iteration k. At each update iteration, h^k_v is computed
from an aggregation function, the latent vectors of v and of v's neighborhood from the
previous iteration, and a weight matrix W^k. The paper proposes three aggregation
functions:
1. Mean aggregator:
The mean aggregator takes the average of the latent vectors of a node and all of its
neighbors.
h^k_v = σ(W^k · MEAN({h^(k-1)_v} ∪ {h^(k-1)_u, ∀u ∈ N(v)}))  (https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)
Compared with the original update equation, it removes the concatenation operation at
line 5 of the above pseudo-code. That concatenation can be viewed as a “skip connection”,
which the paper later shows to largely improve performance. (A code sketch of how the
aggregators plug into the update loop follows the list below.)
2. LSTM aggregator:
Since the nodes in a graph have no natural order but an LSTM processes its inputs
sequentially, the authors apply the LSTM to a random permutation of the node's neighbors.
3. Pooling aggregator:
In the pooling aggregator, each neighbor's latent vector is fed through a fully connected
layer, followed by an element-wise max-pooling operation:
AGGREGATE^pool_k = max({σ(W_pool · h^k_(u_i) + b), ∀u_i ∈ N(v)})  (https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)
where max can be replaced with mean-pooling or any other symmetric pooling function.
The paper points out that the pooling aggregator performs the best, and that mean-pooling
and max-pooling have similar performance, so max-pooling is used as the default
aggregation function.
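As referenced above, here is a minimal numpy sketch of K = 2 GraphSage update iterations on a toy graph, showing both the mean aggregator and the max-pooling aggregator; the graph, dimensions, and weights are random stand-ins, so this only illustrates the structure of the update, not the paper's trained model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph; h^0_v is initialised with the node features x_v.
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
d, K = 6, 2
h = {v: rng.normal(size=d) for v in neighbors}

# Random stand-ins for the learned weights (one W^k per iteration).
W_mean = [0.3 * rng.normal(size=(d, d)) for _ in range(K)]
W_cat = [0.3 * rng.normal(size=(d, 2 * d)) for _ in range(K)]
W_pool = 0.3 * rng.normal(size=(d, d))
b_pool = 0.1 * rng.normal(size=d)
relu = lambda z: np.maximum(0.0, z)

use_mean_aggregator = False   # flip to True to try the mean aggregator

for k in range(K):
    h_new = {}
    for v in neighbors:
        nbr_vecs = [h[u] for u in neighbors[v]]
        if use_mean_aggregator:
            # Mean aggregator: average h^{k-1}_v with its neighbours' vectors;
            # this variant drops the concatenation at line 5 of the pseudo-code.
            z = relu(W_mean[k] @ np.mean([h[v]] + nbr_vecs, axis=0))
        else:
            # Pooling aggregator: a dense layer per neighbour, element-wise max,
            # then concatenation with h^{k-1}_v as in line 5 of the pseudo-code.
            agg = np.max([relu(W_pool @ u + b_pool) for u in nbr_vecs], axis=0)
            z = relu(W_cat[k] @ np.concatenate([h[v], agg]))
        h_new[v] = z / (np.linalg.norm(z) + 1e-12)   # L2 normalisation step
    h = h_new

print(h[0])
```

Switching the use_mean_aggregator flag swaps between the two variants without changing the rest of the loop.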
GraphSage can also be trained in an unsupervised manner with the following graph-based
loss function:
J_G(z_u) = −log(σ(z_u^T z_v)) − Q · E_(v_n∼P_n(v))[log(σ(−z_u^T z_(v_n)))]  (https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)
where u and v co-occur in a fixed-length random walk, while the v_n are negative
samples that do not co-occur with u. Such a loss function encourages nearby nodes to have
similar embeddings, while nodes far apart are pushed away from each other in the
projected space. Via this approach, the nodes gain more and more information about their
neighborhoods.
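As a rough sketch of this objective for a single positive pair (u, v) and Q negative samples, assuming toy random embeddings in place of the model's outputs z:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def unsupervised_loss(z_u, z_v, z_negatives):
    """-log(sigma(z_u . z_v)) - Q * E[log(sigma(-z_u . z_vn))], with the
    expectation approximated by the average over the Q sampled negatives."""
    positive_term = -np.log(sigmoid(z_u @ z_v))
    negative_term = -np.mean([np.log(sigmoid(-z_u @ z_n)) for z_n in z_negatives])
    return positive_term + len(z_negatives) * negative_term

# Toy embeddings: u and v co-occur on a random walk, the v_n do not.
d, Q = 16, 5
z_u, z_v = rng.normal(size=d), rng.normal(size=d)
z_negatives = [rng.normal(size=d) for _ in range(Q)]
print(unsupervised_loss(z_u, z_v, z_negatives))
```

In practice, z_u, z_v, and the negatives would be the output representations produced by the aggregation steps above, and the loss would be minimized with gradient descent.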