FAST GRAPH REPRESENTATION LEARNING WITH PYTORCH GEOMETRIC
ABSTRACT
We introduce PyTorch Geometric, a library for deep learning on irregularly struc-
tured input data such as graphs, point clouds and manifolds, built upon PyTorch. In
addition to general graph data structures and processing methods, it contains a va-
riety of recently published methods from the domains of relational learning and 3D
data processing. PyTorch Geometric achieves high data throughput by leveraging
sparse GPU acceleration, by providing dedicated CUDA kernels and by introducing
efficient mini-batch handling for input examples of different sizes. In this work,
we present the library in detail and perform a comprehensive comparative study of
the implemented methods in homogeneous evaluation scenarios.
1 INTRODUCTION
Graph Neural Networks (GNNs) recently emerged as a powerful approach for representation learning
on graphs, point clouds and manifolds (Bronstein et al., 2017; Kipf & Welling, 2017). Similar to the
concepts of convolutional and pooling layers on regular domains, GNNs are able to (hierarchically)
extract localized embeddings by passing, transforming, and aggregating information (Bronstein et al.,
2017; Gilmer et al., 2017; Battaglia et al., 2018; Ying et al., 2018; Morris et al., 2019).
However, implementing GNNs is challenging, as high GPU throughput needs to be achieved on
highly sparse and irregular data of varying size. Here, we introduce PyTorch Geometric (PyG), a
geometric deep learning extension library for PyTorch (Paszke et al., 2017) which achieves high
performance by leveraging dedicated CUDA kernels. Following a simple message passing API, it
bundles most of the recently proposed convolutional and pooling layers into a single and unified
framework. All implemented methods support both CPU and GPU computations and follow an
immutable data flow paradigm that enables dynamic changes in graph structures through time. PyG
is released under the MIT license and is available on GitHub.¹ It is thoroughly documented and
provides accompanying tutorials and examples as a starting point.²
2 OVERVIEW
In PyG, we represent a graph $\mathcal{G} = (\mathbf{X}, (\mathbf{I}, \mathbf{E}))$ by a node feature matrix $\mathbf{X} \in \mathbb{R}^{N \times F}$ of $N$ nodes and a sparse adjacency tuple $(\mathbf{I}, \mathbf{E})$ of $E$ edges, where $\mathbf{I} \in \mathbb{N}^{2 \times E}$ encodes edge indices in COOrdinate (COO) format and $\mathbf{E} \in \mathbb{R}^{E \times D}$ (optionally) holds $D$-dimensional edge features. All user-facing APIs, e.g., data loading routines, multi-GPU support, data augmentation or model instantiations, are heavily inspired by PyTorch to keep them as familiar as possible.
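As a minimal sketch of this representation (plain NumPy arrays standing in for PyTorch tensors; the names `x`, `edge_index` and `edge_attr` are our illustrative choice), the star-shaped graph of Figure 1, with edges from nodes 2, 3 and 4 into node 1, could be stored as follows. We assume that the first row of I holds source nodes and the second row holds target nodes, and we use 0-based node indices.

```python
import numpy as np

# Node feature matrix X: N = 4 nodes with F = 3 features each, shape [N, F].
x = np.arange(12, dtype=np.float64).reshape(4, 3)

# Edge indices I in COO format, shape [2, E]; row 0 = source j, row 1 = target i.
# The edges 2->1, 3->1 and 4->1 of Figure 1 become 1->0, 2->0 and 3->0 (0-based).
edge_index = np.array([[1, 2, 3],
                       [0, 0, 0]], dtype=np.int64)

# Optional edge features E: one D = 2 dimensional vector per edge, shape [E, D].
edge_attr = np.ones((edge_index.shape[1], 2))

N, F = x.shape
E = edge_index.shape[1]
# Consistency checks implied by the definitions above.
assert edge_index.shape == (2, E) and edge_attr.shape == (E, 2)
assert edge_index.min() >= 0 and edge_index.max() < N
```

Storing only the E index pairs, rather than a dense N × N matrix, keeps memory linear in the number of edges.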
Published as a workshop paper at ICLR 2019
[Figure 1 diagram: node features 𝑥⃗1 to 𝑥⃗4 are gathered via I into the edge-wise messages 𝜙(𝑥⃗1, 𝑥⃗2, 𝑒⃗2,1), 𝜙(𝑥⃗1, 𝑥⃗3, 𝑒⃗3,1) and 𝜙(𝑥⃗1, 𝑥⃗4, 𝑒⃗4,1), which scatter_⬚ (I) aggregates before 𝛾 yields 𝑥⃗′1.]
Figure 1: Computation scheme of a GNN layer by leveraging gather and scatter methods based on
edge indices I, hence alternating between node-parallel space and edge-parallel space.
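Neighborhood Aggregation. Generalizing the convolutional operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme (Gilmer et al., 2017):

$$\vec{x}_i^{\,\prime} = \gamma\left(\vec{x}_i,\ \square_{j \in \mathcal{N}(i)}\, \phi\left(\vec{x}_i, \vec{x}_j, \vec{e}_{j,i}\right)\right),$$

where $\mathcal{N}(i)$ denotes the neighbors of node $i$, ⬚ denotes a differentiable, permutation invariant function, e.g., sum, mean or max, and 𝛾 and 𝜙 denote differentiable functions, e.g., MLPs. In practice, this can be achieved by gathering and scattering node features and by computing 𝛾 and 𝜙 element-wise in a vectorized manner, as visualized in Figure 1. Although it operates on irregularly structured input, this scheme can be heavily accelerated on the GPU. In contrast to implementations via sparse matrix multiplications, the use of gather/scatter proves to be advantageous for low-degree graphs and non-coalesced input (cf. Appendix A), and allows for the integration of central node and multi-dimensional edge information while aggregating.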
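The gather/scatter scheme can be sketched in a few lines of NumPy, where fancy indexing plays the role of gather and the unbuffered `np.add.at` plays the role of a scatter-add (sum aggregation). This illustrates the computation pattern only; it is not PyG's CUDA implementation, and `phi` and `gamma` are arbitrary user-supplied callables.

```python
import numpy as np

def message_passing(x, edge_index, edge_attr, phi, gamma):
    """x'_i = gamma(x_i, sum_{j in N(i)} phi(x_i, x_j, e_{j,i})) via gather/scatter.

    edge_index[0] holds source nodes j, edge_index[1] holds target nodes i.
    """
    src, dst = edge_index
    # Gather: move node features into edge-parallel space (one row per edge).
    x_i, x_j = x[dst], x[src]
    # Compute all messages at once, vectorized over the edges.
    messages = phi(x_i, x_j, edge_attr)
    # Scatter: sum the messages back into node-parallel space by target index.
    aggr = np.zeros((x.shape[0],) + messages.shape[1:], dtype=messages.dtype)
    np.add.at(aggr, dst, messages)
    # Update: combine each node's own features with its aggregated messages.
    return gamma(x, aggr)


x = np.array([[1.0], [2.0], [3.0], [4.0]])
edge_index = np.array([[1, 2, 3], [0, 0, 0]])
edge_attr = np.ones((3, 1))
out = message_passing(
    x, edge_index, edge_attr,
    phi=lambda x_i, x_j, e: x_j * e,   # message: edge-weighted neighbor features
    gamma=lambda x, aggr: x + aggr,    # update: add aggregated messages
)
```

Node 0 receives 2 + 3 + 4 from its three neighbors, so `out[0]` equals 10, while nodes without incoming edges keep their own features.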
We provide the user with a general MessagePassing interface to allow for rapid and clean prototyping of new research ideas. To use it, users only need to define the methods 𝜙, i.e., message, and 𝛾, i.e., update, and choose an aggregation scheme ⬚. When implementing 𝜙, node features are automatically mapped to the respective source and target nodes.
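To illustrate this division of labor, the toy class below mimics such an interface in NumPy. It is a hypothetical stand-in, not PyG's actual MessagePassing class: the class names, the `propagate` helper and its signature are our assumptions, and only sum ("add") and mean aggregation are supported.

```python
import numpy as np

class ToyMessagePassing:
    """Hypothetical base class: subclasses define message (phi) and update (gamma)."""

    def __init__(self, aggr="add"):
        if aggr not in ("add", "mean"):
            raise ValueError(f"unsupported aggregation: {aggr}")
        self.aggr = aggr

    def propagate(self, edge_index, x):
        src, dst = edge_index
        # Features are mapped to target nodes (x_i) and source nodes (x_j).
        msg = self.message(x[dst], x[src])
        out = np.zeros((x.shape[0],) + msg.shape[1:], dtype=msg.dtype)
        np.add.at(out, dst, msg)
        if self.aggr == "mean":
            count = np.bincount(dst, minlength=x.shape[0])
            out = out / np.maximum(count, 1).reshape(-1, *([1] * (out.ndim - 1)))
        return self.update(out, x)

    def message(self, x_i, x_j):
        return x_j  # default phi: pass neighbor features through

    def update(self, aggr_out, x):
        return aggr_out  # default gamma: keep the aggregation result


class MeanNeighborConv(ToyMessagePassing):
    """Hypothetical layer: x'_i = x_i + mean over neighbors j of (x_j W)."""

    def __init__(self, weight):
        super().__init__(aggr="mean")
        self.weight = weight

    def __call__(self, x, edge_index):
        return self.propagate(edge_index, x)

    def message(self, x_i, x_j):
        return x_j @ self.weight

    def update(self, aggr_out, x):
        return x + aggr_out


conv = MeanNeighborConv(weight=np.array([[2.0]]))
out = conv(np.array([[1.0], [2.0], [3.0], [4.0]]), np.array([[1, 2, 3], [0, 0, 0]]))
```

Only `message`, `update` and the aggregation name differ between layers; the gather/scatter plumbing in `propagate` is shared.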
Almost all recently proposed neighborhood aggregation functions can be lifted to this interface, in-
cluding (but not limited to) the methods already integrated into PyG: For learning on arbitrary graphs
we have implemented GCN (Kipf & Welling, 2017) and its simplified version (SGC) from Wu et al.
(2019), the spectral Chebyshev and ARMA filter convolutions (Defferrard et al., 2016; Bianchi et al.,
2019), GraphSAGE (Hamilton et al., 2017), the attention-based operators GAT (Veličković et al.,
2018) and AGNN (Thekumparampil et al., 2018), the Graph Isomorphism Network (GIN) from Xu
et al. (2019), the Approximate Personalized Propagation of Neural Predictions (APPNP) operator
(Klicpera et al., 2019), the Dynamic Neighborhood Aggregation (DNA) operator (Fey, 2019) and
the signed operator for learning in signed networks (Derr et al., 2018).
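As an example of lifting an operator to this interface, the GCN propagation rule $\hat{\mathbf{D}}^{-1/2}\hat{\mathbf{A}}\hat{\mathbf{D}}^{-1/2}\mathbf{X}\mathbf{W}$ (Kipf & Welling, 2017), where $\hat{\mathbf{A}}$ is the adjacency matrix with added self-loops and $\hat{\mathbf{D}}$ its degree matrix, can be written with per-edge normalized messages and sum aggregation. The NumPy sketch below is our own illustration, not PyG's GCN implementation; it counts degrees at the target node, which coincides with the usual degree for undirected graphs.

```python
import numpy as np

def gcn_layer(x, edge_index, weight):
    """Sketch of a GCN layer as gather/scatter message passing.

    Computes D^{-1/2} (A + self-loops) D^{-1/2} X W using only edge-wise ops.
    """
    n = x.shape[0]
    # Add self-loops so that every node also receives its own features.
    loops = np.arange(n)
    src = np.concatenate([edge_index[0], loops])
    dst = np.concatenate([edge_index[1], loops])
    # Degrees counted at the target node, including the self-loop.
    deg = np.bincount(dst, minlength=n).astype(x.dtype)
    norm = 1.0 / np.sqrt(deg[src] * deg[dst])
    # phi: transform the source features and scale them per edge.
    messages = norm[:, None] * (x[src] @ weight)
    out = np.zeros((n, weight.shape[1]), dtype=messages.dtype)
    np.add.at(out, dst, messages)  # sum aggregation
    return out
```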
For learning on point clouds, manifolds and graphs with multi-dimensional edge features, we pro-
vide the relational GCN operator from Schlichtkrull et al. (2018), PointNet++ (Qi et al., 2017),
PointCNN (Li et al., 2018), and the continuous kernel-based methods MPNN (Gilmer et al., 2017;
Simonovsky & Komodakis, 2017), MoNet (Monti et al., 2017), SplineCNN (Fey et al., 2018) and
the edge convolution operator (EdgeCNN) from Wang et al. (2018b).
In addition to these operators, we provide high-level implementations of, e.g., maximizing mutual in-
formation (Veličković et al., 2019), autoencoding graphs (Kipf & Welling, 2016; Pan et al., 2018), ag-
gregating jumping knowledge (Xu et al., 2018), and predicting temporal events in knowledge graphs
(Jin et al., 2019).
Global Pooling. PyG also supports graph-level outputs as opposed to node-level outputs by providing a variety of readout functions such as global add, mean or max pooling. We additionally offer more sophisticated methods such as set-to-set (Vinyals et al., 2016), sort pooling (Zhang et al., 2018) or the global soft attention layer from Li et al. (2016).
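A simplified soft attention readout can be sketched in NumPy as follows: every node receives a scalar gate score, the scores are softmax-normalized within each graph, and the normalized scores weight the node features before summation. This is a sketch under our own assumptions (a per-node graph assignment vector `batch`, a purely linear gate `w_gate`, no extra feature transformation), not necessarily the exact formulation of Li et al. (2016).

```python
import numpy as np

def global_attention_pool(x, batch, w_gate):
    """Soft-attention readout: one weighted sum of node features per graph.

    batch[n] is the index of the graph that node n belongs to.
    """
    num_graphs = int(batch.max()) + 1
    score = x @ w_gate  # [N] gate logits from a linear gate (assumption)
    # Subtract the per-graph maximum for a numerically stable softmax.
    graph_max = np.full(num_graphs, -np.inf)
    np.maximum.at(graph_max, batch, score)
    exp = np.exp(score - graph_max[batch])
    denom = np.zeros(num_graphs)
    np.add.at(denom, batch, exp)
    alpha = exp / denom[batch]  # attention weights sum to 1 within each graph
    out = np.zeros((num_graphs, x.shape[1]))
    np.add.at(out, batch, alpha[:, None] * x)
    return out
```

With a zero gate, all weights within a graph are equal and the readout reduces to global mean pooling.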
Hierarchical Pooling. To further extract hierarchical information and to allow deeper GNN mod-
els, various pooling approaches can be applied in a spatial or data-dependent manner. We currently
provide implementation examples for Graclus (Dhillon et al., 2007; Fagginger Auer & Bisseling,
2011) and voxel grid pooling (Simonovsky & Komodakis, 2017), the iterative farthest point sam-
pling algorithm (Qi et al., 2017) followed by 𝑘-NN or query ball graph generation (Qi et al., 2017;
Wang et al., 2018b), and differentiable pooling mechanisms such as DiffPool (Ying et al., 2018) and
top𝑘 pooling (Gao & Ji, 2018; Cangea et al., 2018).
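Iterative farthest point sampling greedily adds the point that is farthest from all points selected so far. A plain NumPy sketch (our own, O(N · num_samples) time) is given below; it starts from a fixed index for reproducibility, whereas implementations often pick the first point at random.

```python
import numpy as np

def farthest_point_sampling(pos, num_samples, start=0):
    """Return indices of points chosen by iterative farthest point sampling."""
    n = pos.shape[0]
    num_samples = min(num_samples, n)
    selected = np.empty(num_samples, dtype=np.int64)
    selected[0] = start
    # Distance from every point to its nearest already-selected point.
    dist = np.linalg.norm(pos - pos[start], axis=1)
    for k in range(1, num_samples):
        # The next sample is the point farthest from the current selection.
        selected[k] = int(np.argmax(dist))
        new_dist = np.linalg.norm(pos - pos[selected[k]], axis=1)
        dist = np.minimum(dist, new_dist)
    return selected
```

The sampled points then serve as cluster centers around which a new graph is built, e.g., via k-NN or a query ball.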
Table 1: Semi-supervised node classification with both fixed and random splits.
Method Cora (Fixed) Cora (Random) CiteSeer (Fixed) CiteSeer (Random) PubMed (Fixed) PubMed (Random)
Cheby 81.4 ± 0.7 77.8 ± 2.2 70.2 ± 1.0 67.7 ± 1.7 78.4 ± 0.4 75.8 ± 2.2
GCN 81.5 ± 0.6 79.4 ± 1.9 71.1 ± 0.7 68.1 ± 1.7 79.0 ± 0.6 77.4 ± 2.4
GAT 83.1 ± 0.4 81.0 ± 1.4 70.8 ± 0.5 69.2 ± 1.9 78.5 ± 0.3 78.3 ± 2.3
SGC 81.7 ± 0.1 80.2 ± 1.6 71.3 ± 0.2 68.7 ± 1.6 78.9 ± 0.1 76.5 ± 2.4
ARMA 82.8 ± 0.6 80.7 ± 1.4 72.3 ± 1.1 68.9 ± 1.6 78.8 ± 0.3 77.7 ± 2.6
APPNP 83.3 ± 0.5 82.2 ± 1.5 71.8 ± 0.5 70.0 ± 1.4 80.1 ± 0.2 79.4 ± 2.2
Mini-batch Handling. Our framework supports batches of multiple graph instances (of potentially different sizes) by automatically creating a single (sparse) block-diagonal adjacency matrix and concatenating feature matrices in the node dimension. Therefore, neighborhood aggregation methods can be applied without modification, since no messages are exchanged between disconnected graphs. In addition, an automatically generated assignment vector ensures that node-level information is not aggregated across graphs, e.g., when executing global aggregation operators.
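The batching step can be sketched as follows (plain NumPy; the `collate` name and the `(x, edge_index)` tuple format are our assumptions). Node features are concatenated, each graph's edge indices are shifted by the number of nodes that precede it, which yields the block-diagonal structure implicitly, and the assignment vector records which graph every node came from.

```python
import numpy as np

def collate(graphs):
    """Merge (x, edge_index) pairs into one disconnected batch graph."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift node indices so they point into the concatenated feature matrix.
        edge_indices.append(edge_index + offset)
        # Assignment vector: every node of this graph gets the same graph id.
        batch.append(np.full(x.shape[0], graph_id, dtype=np.int64))
        offset += x.shape[0]
    return (np.concatenate(xs, axis=0),
            np.concatenate(edge_indices, axis=1),
            np.concatenate(batch))
```

Because no edge connects nodes of different graphs, neighborhood aggregation on the batch yields the same node outputs as processing each graph separately.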
Processing of Datasets. We provide a consistent data format and an easy-to-use interface for the
creation and processing of datasets, both for large datasets and for datasets that can be kept in memory
during training. In order to create new datasets, users just need to read/download their data and
convert it to the PyG data format in the respective process method. In addition, datasets can be
modified by the use of transforms, which take in separate graphs and transform them, e.g., for
data augmentation, for enhancing node features with synthetic structural graph properties (Cai &
Wang, 2018), to automatically generate graphs from point clouds or to sample point clouds from
meshes.
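For instance, a transform that generates a graph from a point cloud might connect every point to its k nearest neighbors. The NumPy sketch below is a hypothetical illustration (the `knn_graph` name, the edge direction and the brute-force O(N²) distance matrix are our choices); it assumes k < N.

```python
import numpy as np

def knn_graph(pos, k):
    """Build COO edge indices where each point receives edges from its k nearest neighbors."""
    # Pairwise squared Euclidean distances, shape [N, N].
    diff = pos[:, None, :] - pos[None, :, :]
    dist = (diff ** 2).sum(-1)
    np.fill_diagonal(dist, np.inf)  # exclude self-loops
    # Indices of the k closest points for every target point, shape [N, k].
    nbrs = np.argsort(dist, axis=1)[:, :k]
    dst = np.repeat(np.arange(pos.shape[0]), k)
    src = nbrs.reshape(-1)
    return np.stack([src, dst])
```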
PyG already supports many common benchmark datasets found in the literature, which are automatically downloaded and processed on first instantiation. In detail, we provide over 60 graph kernel benchmark datasets³ (Kersting et al., 2016), e.g., PROTEINS or IMDB-BINARY, the citation graphs Cora, CiteSeer, PubMed and Cora-Full (Sen et al., 2008; Bojchevski & Günnemann,
2018), the Coauthor CS/Physics and Amazon Computers/Photo datasets from Shchur et al. (2018),
the molecule datasets QM7b (Montavon et al., 2013) and QM9 (Ramakrishnan et al., 2014), the
protein-protein interaction graphs from Hamilton et al. (2017), and the temporal datasets Bitcoin-
OTC (Kumar et al., 2016), ICEWS (Boschee et al., 2015) and GDELT (Leetaru & Schrodt, 2013). In
addition, we provide embedded datasets like MNIST superpixels (Monti et al., 2017), FAUST (Bogo
et al., 2014), ModelNet10/40 (Wu et al., 2015), ShapeNet (Chang et al., 2015), COMA (Ranjan et al.,
2018) and the PCPNet dataset from Guerrero et al. (2018).
3 EMPIRICAL EVALUATION
GIN-0 85.7 ± 7.7 72.1 ± 5.1 79.3 ± 2.7 72.8 ± 4.5 89.6 ± 2.6
GIN-𝜖 83.4 ± 7.5 72.6 ± 4.9 79.8 ± 2.4 72.1 ± 5.1 90.3 ± 3.0
Hier.
Graclus 77.1 ± 7.2 73.0 ± 4.1 79.6 ± 2.0 72.2 ± 4.2 88.8 ± 3.2
top𝑘 76.3 ± 7.5 72.7 ± 4.1 79.7 ± 2.2 72.5 ± 4.6 87.6 ± 2.4
DiffPool 85.0 ± 10.3 75.1 ± 3.5 78.9 ± 2.3 72.6 ± 3.9 92.1 ± 2.6
Global
SAGE w/o JK 73.7 ± 7.8 72.7 ± 3.6 79.6 ± 2.4 72.1 ± 4.4 87.9 ± 1.9
GlobalAttention 74.6 ± 8.0 72.5 ± 4.5 79.6 ± 2.2 72.3 ± 3.8 87.4 ± 2.5
Set2Set 73.7 ± 6.9 73.6 ± 3.7 79.6 ± 2.3 72.2 ± 4.2 89.6 ± 2.4
SortPool 77.3 ± 8.9 72.4 ± 4.1 77.7 ± 3.1 72.4 ± 3.8 74.9 ± 6.7
(Bianchi et al., 2019), SGC (Wu et al., 2019), GCN (Kipf & Welling, 2017) and GAT (Veličković
et al., 2018) operators follow closely behind.
Graph Classification. We report the average accuracy of 10-fold cross validation on a number of
common benchmark datasets (cf. Table 2) where we randomly sample a training fold to serve as a
validation set. We only make use of discrete node features. In case they are not given, we use one-hot
encodings of node degrees as feature input. For all experiments, we use the global mean operator
to obtain graph-level outputs. Inspired by the Jumping Knowledge framework (Xu et al., 2018), we
compute graph-level outputs after each convolutional layer and combine them via concatenation.
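Both ingredients are simple to sketch in NumPy. The helpers below are our own illustrations: `one_hot_degree` clips degrees at a chosen `max_degree` (an assumption; the text does not say how large degrees are handled), and `jk_readout` mean-pools the node embeddings produced after each layer per graph and concatenates the results, assuming every graph has at least one node.

```python
import numpy as np

def one_hot_degree(edge_index, num_nodes, max_degree):
    """One-hot encodings of node degrees (counted at the target node)."""
    deg = np.bincount(edge_index[1], minlength=num_nodes)
    deg = np.minimum(deg, max_degree)   # clip rare large degrees (assumption)
    return np.eye(max_degree + 1)[deg]  # shape [num_nodes, max_degree + 1]


def jk_readout(layer_outputs, batch, num_graphs):
    """Mean-pool every layer's node embeddings per graph, then concatenate."""
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    pooled = []
    for h in layer_outputs:
        s = np.zeros((num_graphs, h.shape[1]))
        np.add.at(s, batch, h)          # per-graph sums of this layer's output
        pooled.append(s / counts)       # per-graph means
    return np.concatenate(pooled, axis=1)
```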
For evaluating the (global) pooling operators, we use the GraphSAGE operator as our baseline. We
omit Jumping Knowledge when comparing global pooling operators, and hence report an additional
baseline based on global mean pooling. For each dataset, we tune (1) the number of hidden units
∈ {16, 32, 64, 128} and (2) the number of layers ∈ {2, 3, 4, 5} with respect to the validation set.
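The evaluation protocol can be sketched with the standard library alone. The sketch below is our own: folds are formed from a single random permutation and are not stratified by class, which may differ from the exact protocol behind Table 2.

```python
import random

def cv_splits(num_examples, num_folds=10, seed=0):
    """Yield (train, val, test) index lists for k-fold cross validation.

    For every test fold, one of the remaining folds is drawn at random to
    act as validation set; all other folds form the training set.
    """
    rng = random.Random(seed)
    indices = list(range(num_examples))
    rng.shuffle(indices)
    folds = [indices[k::num_folds] for k in range(num_folds)]
    for test_k in range(num_folds):
        val_k = rng.choice([k for k in range(num_folds) if k != test_k])
        train = [i for k in range(num_folds) if k not in (test_k, val_k)
                 for i in folds[k]]
        yield train, folds[val_k], folds[test_k]
```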
Due to standardized evaluations and network architectures, not all results are aligned with their officially reported values. For example, except for DiffPool (Ying et al., 2018), (global) pooling operators do not perform as beneficially as expected compared to their respective (flat) counterparts, especially when baselines are enhanced by Jumping Knowledge (Xu et al., 2018). However, the potential of more sophisticated approaches may not be well reflected in these simple benchmark tasks (Cai & Wang, 2018). Among the flat GNN approaches, the GIN layer (Xu et al., 2019) generally achieves the best results.
Point Cloud Classification. We evaluate various point cloud methods on ModelNet10 (Wu et al.,
2015) where we uniformly sample 1,024 points from mesh surfaces based on face area (cf. Table 3).
As hierarchical pooling layers, we use the iterative farthest point sampling algorithm followed by a
new graph generation based on a larger query ball (PointNet++ (Qi et al., 2017), MPNN (Gilmer
et al., 2017; Simonovsky & Komodakis, 2017) and SplineCNN (Fey et al., 2018)) or based on a fixed
number of nearest neighbors (EdgeCNN (Wang et al., 2018b) and PointCNN (Li et al., 2018)). We
have taken care to use approximately the same number of parameters for each model.
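Sampling points uniformly from a mesh surface amounts to choosing faces with probability proportional to their area and then drawing a uniform point inside each chosen triangle. A NumPy sketch under our own assumptions (triangle faces given as a [num_faces, 3] index array, 3D vertex positions, and the reflection trick for uniform barycentric coordinates) follows.

```python
import numpy as np

def sample_points(vertices, faces, num_points, seed=0):
    """Uniformly sample points on the surface of a triangle mesh."""
    rng = np.random.default_rng(seed)
    a, b, c = (vertices[faces[:, i]] for i in range(3))
    # Triangle areas via the cross product of two edge vectors.
    areas = 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=1)
    # Pick faces with probability proportional to their area.
    face_idx = rng.choice(len(faces), size=num_points, p=areas / areas.sum())
    # Uniform barycentric coordinates; reflect samples that fall outside.
    u, v = rng.random(num_points), rng.random(num_points)
    outside = u + v > 1
    u[outside], v[outside] = 1 - u[outside], 1 - v[outside]
    a, b, c = a[face_idx], b[face_idx], c[face_idx]
    return a + u[:, None] * (b - a) + v[:, None] * (c - a)
```

Without the area weighting, small faces would be oversampled and the resulting point cloud would not be uniform over the surface.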
All approaches perform nearly identically, with PointCNN (Li et al., 2018) taking a slight lead. We attribute this to the fact that all operators are based on similar principles and might have the same expressive power for the given task.
ACKNOWLEDGMENTS
This work has been supported by the German Research Association (DFG) within the Collaborative
Research Center SFB 876, Providing Information by Resource-Constrained Analysis, projects A6
and B2. We thank Moritz Ludolph and all other contributors for their amazing involvement in this
project. Last but not least, we thank Christopher Morris for fruitful discussions, proofreading and
helpful advice.
REFERENCES
P. W. Battaglia, J. B. Hamrick, V. Bapst, A. Sanchez-Gonzalez, V. F. Zambaldi, M. Malinowski,
A. Tacchetti, D. Raposo, A. Santoro, R. Faulkner, Ç. Gülçehre, F. Song, A. J. Ballard, J. Gilmer,
G. E. Dahl, A. Vaswani, K. Allen, C. Nash, V. Langston, C. Dyer, N. Heess, D. Wierstra, P. Kohli,
M. Botvinick, O. Vinyals, Y. Li, and R. Pascanu. Relational inductive biases, deep learning, and
graph networks. CoRR, abs/1806.01261, 2018.
F. M. Bianchi, D. Grattarola, L. Livi, and C. Alippi. Graph neural networks with convolutional
ARMA filters. CoRR, abs/1901.01343, 2019.
F. Bogo, J. Romero, M. Loper, and M. J. Black. FAUST: Dataset and evaluation for 3D mesh regis-
tration. In CVPR, 2014.
A. Bojchevski and S. Günnemann. Deep Gaussian embedding of attributed graphs: Unsupervised
inductive learning via ranking. In ICLR, 2018.
E. Boschee, J. Lautenschlager, S. O’Brien, S. Shellman, J. Starz, and M. Ward. ICEWS coded event
data. Harvard Dataverse, 2015.
M. M. Bronstein, J. Bruna, Y. LeCun, A. Szlam, and P. Vandergheynst. Geometric deep learning:
Going beyond Euclidean data. In Signal Processing Magazine, 2017.
C. Cai and Y. Wang. A simple yet effective baseline for non-attribute graph classification. CoRR,
abs/1811.03508, 2018.
C. Cangea, P. Veličković, N. Jovanović, T. N. Kipf, and P. Liò. Towards sparse hierarchical graph
classifiers. In NeurIPS-W, 2018.
A. X. Chang, T. Funkhouser, L. J. Guibas, P. Hanrahan, Q. Huang, Z. Li, S. Savarese, M. Savva,
S. Song, H. Su, J. Xiao, L. Yi, and F. Yu. ShapeNet: An information-rich 3D model repository.
CoRR, abs/1512.03012, 2015.
S. Pan, R. Hu, G. Long, J. Jiang, L. Yao, and C. Zhang. Adversarially regularized graph autoencoder
for graph embedding. In IJCAI, 2018.
A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga,
and A. Lerer. Automatic differentiation in PyTorch. In NIPS-W, 2017.
C. R. Qi, L. Yi, H. Su, and L. J. Guibas. PointNet++: Deep hierarchical feature learning on point
sets in a metric space. In NIPS, 2017.
R. Ramakrishnan, P. O. Dral, M. Rupp, and O. A. von Lilienfeld. Quantum chemistry structures and
properties of 134 kilo molecules. Scientific Data, 2014.
A. Ranjan, T. Bolkart, S. Sanyal, and M. J. Black. Generating 3D faces using convolutional mesh
autoencoders. In ECCV, 2018.
M. S. Schlichtkrull, T. N. Kipf, P. Bloem, R. van den Berg, I. Titov, and M. Welling. Modeling
relational data with graph convolutional networks. In ESWC, 2018.
G. Sen, G. Namata, M. Bilgic, and L. Getoor. Collective classification in network data. AI Magazine,
29, 2008.
O. Shchur, M. Mumme, A. Bojchevski, and S. Günnemann. Pitfalls of graph neural network evalu-
ation. In NeurIPS-W, 2018.
M. Simonovsky and N. Komodakis. Dynamic edge-conditioned filters in convolutional neural net-
works on graphs. In CVPR, 2017.
K. K. Thekumparampil, C. Wang, S. Oh, and L. Li. Attention-based graph neural network for semi-
supervised learning. CoRR, abs/1803.03735, 2018.
P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò, and Y. Bengio. Graph attention net-
works. In ICLR, 2018.
P. Veličković, W. Fedus, W. L. Hamilton, P. Liò, Y. Bengio, and R. D. Hjelm. Deep graph infomax.
In ICLR, 2019.
O. Vinyals, S. Bengio, and M. Kudlur. Order matters: Sequence to sequence for sets. In ICLR, 2016.
M. Wang, L. Yu, A. Gan, D. Zheng, Y. Gai, Z. Ye, M. Li, J. Zhou, Q. Huang, J. Zhao, H. Lin,
C. Ma, D. Deng, Q. Guo, H. Zhang, J. Li, A. J. Smola, and Z. Zhang. Deep graph library. http://dgl.ai, 2018a.
Y. Wang, Y. Sun, Z. Liu, S. E. Sarma, M. M. Bronstein, and J. M. Solomon. Dynamic graph CNN
for learning on point clouds. CoRR, abs/1801.07829, 2018b.
F. Wu, T. Zhang, A. H. de Souza Jr., C. Fifty, T. Yu, and K. Q. Weinberger. Simplifying graph
convolutional networks. CoRR, abs/1902.07153, 2019.
Z. Wu, S. Song, A. Khosla, F. Yu, L. Zhang, X. Tang, and J. Xiao. 3D ShapeNets: A deep represen-
tation for volumetric shapes. In CVPR, 2015.
K. Xu, C. Li, Y. Tian, T. Sonobe, K. Kawarabayashi, and S. Jegelka. Representation learning on
graphs with jumping knowledge networks. In ICML, 2018.
K. Xu, W. Hu, J. Leskovec, and S. Jegelka. How powerful are graph neural networks? In ICLR,
2019.
R. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec. Hierarchical graph representation
learning with differentiable pooling. In NeurIPS, 2018.
M. Zhang, Z. Cui, M. Neumann, and Y. Chen. An end-to-end deep learning architecture for graph
classification. In AAAI, 2018.
[Figure 2 plot: runtime of GS and SpMM versus the average degree of 10 000 nodes, for average degrees 1, 2, 4, 8, 16, 32, 64, 128 and 256]
Figure 2: Forward and backward runtimes of 1 000 runs of gather and scatter operations (GS) in
comparison to sparse-matrix multiplications (SpMM) on Erdős-Rényi graphs with 10 000 nodes
and varying average node degrees. Runtimes are capped at two seconds for visualization purposes. We
report runtimes for both coalesced (i.e., ordered by row index) and non-coalesced sparse layouts.
PyTorch Geometric makes heavy use of gather and scatter operations to map node and edge information into edge and node parallel space, respectively. Despite inevitable non-coalesced memory access, these operations achieve high data throughput by leveraging parallelization over all elements and making use of atomic operations with approximately constant runtime in practice. Building upon the PyTorch scatter_add implementation, we provide our own mean and max operations to support all kinds of aggregation.⁵
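Mean and max aggregation can be built on the same scatter pattern. In the NumPy sketch below (our illustration, not the pytorch_scatter code), `np.add.at` and `np.maximum.at` stand in for atomic adds and atomic maxima; the mean divides summed values by per-group counts, and groups that receive no entries are set to zero, an assumed convention. Finite inputs are assumed.

```python
import numpy as np

def scatter_mean(src, index, dim_size):
    """Average the rows of `src` that share the same target `index`."""
    out = np.zeros((dim_size,) + src.shape[1:])
    np.add.at(out, index, src)                      # sum per group
    count = np.bincount(index, minlength=dim_size)  # entries per group
    count = np.maximum(count, 1).reshape((-1,) + (1,) * (src.ndim - 1))
    return out / count                              # empty groups stay 0


def scatter_max(src, index, dim_size):
    """Element-wise maximum of the rows of `src` per target `index`."""
    out = np.full((dim_size,) + src.shape[1:], -np.inf)
    np.maximum.at(out, index, src)
    out[np.isneginf(out)] = 0.0                     # empty groups become 0
    return out
```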
Figure 2 compares the runtime of gather and scatter operations (GS) to the frequently used alternative of sparse-matrix multiplications (SpMM). It shows that atomic operations only begin to throttle the runtime when scattering over graphs with a high average node degree (≈ 128); only in this regime does GS exceed the runtime of highly optimized SpMM executions, both for forward and backward passes.
Since SpMM first converts adjacency matrices into Compressed Sparse Row (CSR) format, it expects coalesced sparse tensors (i.e., ordered by row index). Coalescing is expensive to compute on GPUs and should hence be performed as part of the pre-processing. However, for the backward pass of SpMM, coalescing is performed in any case due to the need to transpose the sparse tensor. In contrast, GS is fast regardless of whether the input is coalesced. Additionally, it allows for modifications of the graph connectivity (e.g., adding self-loops), allows bidirectional data flow, and naturally supports the integration of central node and multi-dimensional edge information.
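Coalescing and the CSR conversion it enables can be sketched in NumPy as follows. The helpers are our own illustrations; sparse tensor libraries typically sum the values of duplicate entries (as PyTorch's `coalesce()` does), whereas this sketch carries no edge values and simply drops duplicate index pairs.

```python
import numpy as np

def coalesce(edge_index, num_nodes):
    """Sort COO edges by (row, col) and drop duplicate index pairs."""
    row, col = edge_index
    key = row * num_nodes + col  # unique linear id per (row, col) pair
    key = np.unique(key)         # sorts and deduplicates in one step
    return np.stack([key // num_nodes, key % num_nodes])


def to_csr(edge_index, num_nodes):
    """Convert a coalesced COO edge index into CSR (rowptr, col) arrays."""
    row, col = edge_index
    rowptr = np.zeros(num_nodes + 1, dtype=np.int64)
    # rowptr[r]:rowptr[r + 1] delimits the entries of row r in `col`.
    rowptr[1:] = np.cumsum(np.bincount(row, minlength=num_nodes))
    return rowptr, col
```

The `rowptr` slices are only meaningful once rows are sorted, which is why SpMM requires coalesced input.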
However, we do think that our GS scheme can still be improved, e.g., in highly dense graph settings
and towards reducing the memory footprint in the edge parallel space. In addition, it should be noted
that scatter operations are non-deterministic by nature on the GPU. Although we did not observe any
deviations for inference, training results can vary across runs with the same manual seed.
B DATASETS
We give detailed descriptions and statistics (cf. Table 5) of the datasets used in our experiments:
Citation Networks. In the citation network datasets Cora, CiteSeer and PubMed, nodes represent documents and edges represent citation links. The networks contain bag-of-words feature vectors for each document. We treat the citation links as (undirected) edges. For training, we use 20 labels per class.
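Both preprocessing steps mentioned here can be sketched in NumPy. The helpers are hypothetical illustrations: `to_undirected` adds the reverse of every edge and removes duplicates, and `per_class_train_mask` draws 20 random training nodes per class, which does not reproduce any particular fixed split.

```python
import numpy as np

def to_undirected(edge_index):
    """Add reverse edges (j -> i for every i -> j) and drop duplicate columns."""
    both = np.concatenate([edge_index, edge_index[::-1]], axis=1)
    return np.unique(both, axis=1)


def per_class_train_mask(labels, num_per_class=20, seed=0):
    """Boolean mask selecting `num_per_class` random nodes from every class."""
    rng = np.random.default_rng(seed)
    mask = np.zeros(labels.shape[0], dtype=bool)
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        chosen = rng.choice(idx, size=min(num_per_class, idx.size), replace=False)
        mask[chosen] = True
    return mask
```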
⁵ GitHub repository: https://github.com/rusty1s/pytorch_scatter
Social Network Datasets. COLLAB is derived from three public scientific collaboration datasets. Each graph corresponds to the ego-network of a researcher from one of these fields, and the task is to label each graph with the field the corresponding researcher belongs to. IMDB-BINARY is a movie collaboration dataset where each graph corresponds to an ego-network of actors/actresses. An edge is drawn between two actors/actresses if they appear in the same movie. The task is to infer the genre of the graph. REDDIT-BINARY is an online discussion dataset where each graph corresponds to a thread. An edge is drawn between two users if one of them responded to the other's comment. The task is to label each graph with the community/subreddit it belongs to.