
Implementing graph neural networks with TensorFlow-Keras

Patrick Reiser (1,2), Andre Eberhard (2) and Pascal Friederich (1,2)

(1) Institute of Nanotechnology, Karlsruhe Institute of Technology (KIT), Hermann-von-Helmholtz-Platz 1, 76344 Eggenstein-Leopoldshafen, Germany
(2) Institute of Theoretical Informatics, Karlsruhe Institute of Technology (KIT), Am Fasanengarten 5, 76131 Karlsruhe, Germany

E-Mail: [email protected], [email protected]

Graph neural networks are a versatile machine learning architecture that has received a lot of attention recently. In this technical report, we present an implementation of convolution and pooling layers for TensorFlow-Keras models, which allows a seamless and flexible integration into standard Keras layers to set up graph models in a functional way. This implies the usage of mini-batches as the first tensor dimension, which can be realized via TensorFlow's new RaggedTensor class, which is best suited for graphs. We developed the Keras Graph Convolutional Neural Network Python package kgcnn based on TensorFlow-Keras, which provides a set of Keras layers for graph networks focusing on a transparent tensor structure passed between layers and an ease-of-use mindset.

Introduction

Graph neural networks (GNNs) are a natural extension of common neural network architectures like convolutional neural networks (CNN) for image classification to graph-structured data.1 For example, recurrent,2,3 convolutional,1,4-6 and spatial-temporal7 graph neural networks as well as graph autoencoders8,9 and graph transformer models10,11 have been reported in literature. A graph G = (V, E) is defined as a set of vertices or nodes v_i ∈ V and edges e_ij = (v_i, v_j) ∈ E connecting two nodes. There are already comprehensive and extensive review articles for graph neural networks, which summarize and categorize relevant literature on graph learning.12 The most frequent applications of GNNs are either node classification or graph embedding tasks. While node classification is a common task for very large graphs such as citation networks9 or social graphs,13 graph embedding learns a representation of smaller graphs like molecules4 or texts for classification.14,15 Graph convolutional neural networks (GCN) stack multiple convolutional and pooling layers for deep learning to generate a high-level node representation, from which both a local node and a global graph classification can be obtained.

Most GCNs can be considered as message passing networks,16 where neighbouring nodes propagate information between each other along edges. In each update step t, the hidden representation h_v of a central node v is convolved with its neighbourhood, given by:

    h_v^{t+1} = U_t(h_v^t, m_v^{t+1}),

where m_v^{t+1} denotes the aggregated message and U_t the update function. The message to update is usually acquired by summing message functions M_t over the neighbourhood N(v) = {u ∈ V | (u, v) ∈ E} of node v (more complex aggregation is of course possible):

    m_v^{t+1} = ∑_{w ∈ N(v)} M_t(h_v^t, h_w^t, e_{vw}).

There is a large variety of convolution operators, which can be spectral-based17 or spatial-based, involving direct neighbours or a path of connected nodes to walk along and collect information from.11,18 Moreover, the message and update functions can be built from recurrent networks,19 multi-layer perceptrons (MLP)20 or attention heads,21 which are complemented by a set of possible aggregation or pooling operations. Aggregation is usually done by a simple average of node representations or by a more refined set2set encoder part22 as proposed by Gilmer et al.16 A reduction of the number of nodes in the graph is achieved by pooling, similar to CNNs, but this is much more challenging on arbitrarily structured data. Examples of possibly differentiable and learnable pooling filters introduced in literature are DiffPool,23 EdgePool,24 gPool,25 HGP-SL,23 SAGPool,26 iPool,27 EigenPool28 and graph-based clustering methods such as the Graclus algorithm.29-32

In order to utilize the full scope of different graph operations for setting up a custom GNN model, a modular framework of convolution and pooling layers is necessary. We briefly summarize and discuss existing graph libraries and their code coverage. Then, a short overview of representing graphs in tensor form is given. Finally, we introduce our graph package kgcnn for TensorFlow's 2.0 Keras API,33-35 which seamlessly integrates graph layers into the Keras36 environment.
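To make the message and update equations above concrete, the following is a minimal sketch of one message passing step in plain TensorFlow, written against a disjoint edge index list as discussed in the graph representation section below. The function and argument names are illustrative assumptions and not the kgcnn API; message_mlp and update_mlp could, for example, be small dense networks.

    import tensorflow as tf

    # Illustrative sketch of one message passing step (not the kgcnn API).
    # nodes: (N, F) node features h_v^t; edges: (M, F_e) edge features e_vw;
    # edge_index: (M, 2) integer pairs (receiver v, sender w) into the node list.
    def message_passing_step(nodes, edges, edge_index, message_mlp, update_mlp):
        h_v = tf.gather(nodes, edge_index[:, 0])   # receiver features per edge
        h_w = tf.gather(nodes, edge_index[:, 1])   # sender features per edge
        # Message function M_t evaluated on every edge
        messages = message_mlp(tf.concat([h_v, h_w, edges], axis=-1))
        # Sum messages over the neighbourhood of each receiver: m_v^{t+1}
        m_v = tf.math.unsorted_segment_sum(
            messages, edge_index[:, 0], num_segments=tf.shape(nodes)[0])
        # Update function U_t combining the old state with the aggregated message
        return update_mlp(tf.concat([nodes, m_v], axis=-1))   # h_v^{t+1}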

Graph Libraries

Since graph neural networks require modified convolution and pooling operators, many Python packages for deep learning have emerged for either TensorFlow33,34 or PyTorch37 to work with graphs. We try to summarize the most notable ones without any claim that this list is complete.

PyTorch Geometric.38 A PyTorch-based graph library which is probably the largest and most used graph learning Python package to date. It implements a huge variety of different graph models and uses a disjoint graph representation to deal with batched graphs (graph representations are discussed in the next section).

Deep Graph Library (DGL).39 A graph model library with a flexible backend and a performance-optimized implementation. It has its own graph data class with many loading options. Moreover, variants such as generative graph models,40 capsule networks41 and transformers42 are included.

Spektral.43 A Keras36 implementation of graph convolutional networks. Originally restricted to spectral graph filters,17 it now includes spatial convolution and pooling operations. The graph representation is made flexible by different graph modes detected by each layer.

StellarGraph.44 A Keras36 implementation that provides a set of convolution layers and a few pooling layers plus a custom graph data format.

With PyTorch Geometric and DGL, there are already large graph libraries with many contributors from both academia and industry. The focus of the graph package presented here is on a neat integration of graphs into the TensorFlow-Keras framework in the most straightforward way. Thereby, we hope to provide Keras graph layers which can be quickly rearranged, changed and extended to build custom graph models with little effort. This implementation is focused on TensorFlow's new RaggedTensor class, which is best suited for flexible data structures such as graphs and natural language.
Graph representation

A main issue with handling graphs is their flexible size, which is why graph data cannot be easily arranged in tensors as is done, for example, in image processing. Especially arranging smaller graphs of different size in mini-batches poses a problem with fixed-size tensors. A way to circumvent this problem is to use zero-padding with masking or composite tensors such as ragged or sparse tensors. Another possibility is to join small graphs into a single large graph, in which the individual subgraphs are not connected to each other; this is illustrated in Figure 1 and often referred to as disjoint representation. The tensors used to describe a graph are typically given by a node list n of shape ([batch], N, F), a connection table m of the edge indices of incoming and outgoing nodes with shape ([batch], M, 2), and a corresponding edge feature list e of shape ([batch], M, F). Here, N denotes the number of nodes, F the dimension of the node representation and M the number of edges. A common representation of a graph's structure is given by the adjacency matrix A of shape ([batch], N, N), which has A_ij = 1 if the graph has an edge between nodes i and j and A_ij = 0 otherwise.

Figure 1: Disjoint graph representation with adjacency matrix A, node list n and connection table m. Edge features are added in form of a feature list e or a feature matrix E matching A. The indices in m match the total graph as indicated by arrows. The subgraph distinction encoded by colour has to be stored separately.
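As a concrete illustration of these tensors, the following sketch builds the node list n, connection table m, edge feature list e and adjacency matrix A for a single toy graph of three nodes; the graph and its feature values are made up for demonstration.

    import numpy as np

    # Toy triangle graph with N = 3 nodes and F = 2 features per node.
    n = np.array([[0.0, 1.0],
                  [1.0, 0.0],
                  [0.5, 0.5]])                 # node list n, shape (N, F)

    # Connection table m; undirected edges are stored in both directions.
    m = np.array([[0, 1], [1, 0],
                  [1, 2], [2, 1],
                  [0, 2], [2, 0]])             # edge indices, shape (M, 2)

    e = np.ones((len(m), 1))                   # edge feature list e, shape (M, 1)

    # Equivalent adjacency matrix A of shape (N, N) with A_ij = 1 for edge (i, j).
    A = np.zeros((3, 3))
    A[m[:, 0], m[:, 1]] = 1.0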
However, with RaggedTensors, node features and edge index lists can be passed to Keras models with a flexible tensor dimension that incorporates different numbers of nodes and edges. For example, a ragged tensor of shape (batch, None, F) can accommodate a flexible graph size in the second dimension. Note that even sparse matrices, which are commonly used to represent the adjacency matrix in a disjoint representation, are internally stored as a value plus index tensor. This means that the ragged tensor representation can be cast into a sparse or padded representation at little cost, if necessary. TensorFlow 2.0 further supports limited sparse matrix operations, which can be used for graph convolution models like GCN.
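To illustrate this, the node lists of two differently sized graphs can be held in one RaggedTensor and cast to a padded or flattened disjoint form. This is plain TensorFlow usage, shown here only to demonstrate the representation, with made-up feature values.

    import tensorflow as tf

    # Two graphs with 3 and 2 nodes, each node carrying F = 2 features.
    nodes = tf.ragged.constant(
        [[[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]],
         [[1.0, 1.0], [0.0, 0.0]]],
        ragged_rank=1, inner_shape=(2,))       # shape (batch, None, F)

    padded = nodes.to_tensor()                 # zero-padded tensor, shape (2, 3, 2)
    flat = nodes.values                        # disjoint node list, shape (5, 2)
    graph_id = nodes.value_rowids()            # [0, 0, 0, 1, 1] node-to-graph map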

Keras graph package - kgcnn

A flexible and simple integration of graph operations into the TensorFlow-Keras framework can be achieved via ragged tensors. As mentioned above, ragged tensors are capable of efficiently representing graphs and inherently have access to various methods within TensorFlow. For more sophisticated pooling algorithms which cannot operate on batches, a parallelization over the individual graphs within the batch could be achieved with the TensorFlow map functionality, although this is less efficient than vectorized operations and depends on implementation details.
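A sketch of this map-based fallback is given below: tf.map_fn applies a per-graph function to each row of a ragged batch. The pooling function here is a deliberately simple placeholder standing in for an algorithm that cannot be vectorized over the batch.

    import tensorflow as tf

    def pool_single_graph(node_features):
        # Placeholder for a sophisticated pooling algorithm; here a simple mean.
        return tf.reduce_mean(node_features, axis=0)

    nodes = tf.ragged.constant([[[0.0, 1.0], [1.0, 0.0]], [[1.0, 1.0]]],
                               ragged_rank=1, inner_shape=(2,))
    # Apply the function graph by graph (slower than vectorized operations).
    pooled = tf.map_fn(pool_single_graph, nodes,
                       fn_output_signature=tf.TensorSpec(shape=(2,),
                                                         dtype=tf.float32))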
Model    HOMO [eV]   LUMO [eV]   EG [eV]
MPN      0.061       0.047       0.083
Schnet   0.044       0.038       0.067
MegNet   0.045       0.037       0.066

Table 1: Mean absolute validation error for a single training run on the QM9 dataset for the targets HOMO level, LUMO level and gap EG in eV, using popular GNN architectures implemented in kgcnn. No hyperparameter optimization or feature engineering was performed.

Consequently, we introduce the Python package kgcnn (https://github.com/aimat-lab/gcnn_keras), which uses RaggedTensors that are passed between graph layers for graph convolution and message passing models. We believe that the use of RaggedTensors makes it easy to debug code, allows a transparent and readable coding style, and enables a seamless integration with many TensorFlow methods which are available for custom layers. We implemented a set of basic Keras layers for TensorFlow 2.0 from which many models reported in literature can be constructed. As examples, the Python package implements GCN,1 Interaction network,6 message passing,16 Schnet,4 MegNet20 and Unet.25 The focus is set on graph embedding tasks, but node and link classification tasks can also be implemented using kgcnn. The models were tested with common benchmark datasets such as Cora,45 MUTAG46 and QM9.47 Typical benchmark accuracies, such as chemical accuracy on the QM9 dataset, are achieved with the corresponding models implemented in kgcnn (Table 1).
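To show how ragged graph layers combine with standard Keras in a functional way, the following is a minimal sketch of a graph embedding model; the SimpleGraphConv layer and its behaviour are assumptions for demonstration and do not reproduce the actual kgcnn layer API. A real graph network would additionally take ragged edge features and edge indices as inputs, as in the message passing sketch above.

    import tensorflow as tf
    from tensorflow.keras import layers

    class SimpleGraphConv(layers.Layer):
        """Dense transformation of node features followed by mean pooling."""
        def __init__(self, units):
            super().__init__()
            self.dense = layers.Dense(units, activation="relu")

        def call(self, nodes):                        # ragged (batch, None, F)
            hidden = tf.ragged.map_flat_values(self.dense, nodes)
            return tf.reduce_mean(hidden, axis=1)     # embedding (batch, units)

    node_input = tf.keras.Input(shape=(None, 2), ragged=True)  # flexible node count
    embedding = SimpleGraphConv(32)(node_input)
    output = layers.Dense(1)(embedding)               # e.g. a scalar regression target
    model = tf.keras.Model(inputs=node_input, outputs=output)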
Conclusion

In summary, we discussed a way to integrate graph convolution models into the TensorFlow-Keras deep learning framework. The main focus of our kgcnn package is the transparency of the tensor representation and the seamless integration with other Keras models. We plan to continue to extend the kgcnn library to incorporate new models, in particular GNNExplainer48 and DiffPool,23 and to improve functionality.

Code availability

The package is available on the github repository https://github.com/aimat-lab/gcnn_keras and through the Python Package Index via pip install kgcnn.

Acknowledgement

P.F. acknowledges funding from the European Union's Horizon 2020 research and innovation programme under the Marie Sklodowska-Curie grant agreement No 795206.

References

(1) Kipf, T. N.; Welling, M. Semi-Supervised Classification with Graph Convolutional Networks. ArXiv160902907 Cs Stat 2017.
(2) Scarselli, F.; Gori, M.; Tsoi, A. C.; Hagenbuchner, M.; Monfardini, G. The Graph Neural Network Model. IEEE Trans. Neural Netw. 2009, 20 (1), 61-80. https://doi.org/10.1109/TNN.2008.2005605.
(3) Dai, H.; Kozareva, Z.; Dai, B.; Smola, A.; Song, L. Learning Steady-States of Iterative Algorithms over Graphs. In International Conference on Machine Learning; PMLR, 2018; pp 1106-1114.
(4) Schütt, K. T.; Sauceda, H. E.; Kindermans, P.-J.; Tkatchenko, A.; Müller, K.-R. SchNet – A Deep Learning Architecture for Molecules and Materials. J. Chem. Phys. 2018, 148 (24), 241722. https://doi.org/10.1063/1.5019779.
(5) Niepert, M.; Ahmed, M.; Kutzkov, K. Learning Convolutional Neural Networks for Graphs. ArXiv160505273 Cs Stat 2016.
(6) Battaglia, P. W.; Pascanu, R.; Lai, M.; Rezende, D.; Kavukcuoglu, K. Interaction Networks for Learning about Objects, Relations and Physics. ArXiv161200222 Cs 2016.
(7) Yu, B.; Yin, H.; Zhu, Z. Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting. Proc. Twenty-Seventh Int. Jt. Conf. Artif. Intell. 2018, 3634-3640. https://doi.org/10.24963/ijcai.2018/505.
(8) Pan, S.; Hu, R.; Long, G.; Jiang, J.; Yao, L.; Zhang, C. Adversarially Regularized Graph Autoencoder for Graph Embedding. ArXiv180204407 Cs Stat 2019.
(9) Kipf, T. N.; Welling, M. Variational Graph Auto-Encoders. ArXiv161107308 Cs Stat 2016.
(10) Yao, S.; Wang, T.; Wan, X. Heterogeneous Graph Transformer for Graph-to-Sequence Learning. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics; Association for Computational Linguistics: Online, 2020; pp 7145-7154. https://doi.org/10.18653/v1/2020.acl-main.640.
(11) Chen, B.; Barzilay, R.; Jaakkola, T. Path-Augmented Graph Transformer Network. ArXiv190512712 Cs Stat 2019.
(12) Wu, Z.; Pan, S.; Chen, F.; Long, G.; Zhang, C.; Yu, P. S. A Comprehensive Survey on Graph Neural Networks. IEEE Trans. Neural Netw. Learn. Syst. 2020, 1-21. https://doi.org/10.1109/TNNLS.2020.2978386.
(13) Benchettara, N.; Kanawati, R.; Rouveirol, C. Supervised Machine Learning Applied to Link Prediction in Bipartite Social Networks. In 2010 International Conference on Advances in Social Networks Analysis and Mining; 2010; pp 326-330. https://doi.org/10.1109/ASONAM.2010.87.
(14) Angelova, R.; Weikum, G. Graph-Based Text Classification: Learn from Your Neighbors. In Proceedings of the 29th annual international ACM SIGIR conference on Research and development in information retrieval; SIGIR '06; Association for Computing Machinery: New York, NY, USA, 2006; pp 485-492. https://doi.org/10.1145/1148170.1148254.
(15) Rousseau, F.; Kiagias, E.; Vazirgiannis, M. Text Categorization as a Graph Classification Problem. In Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (Volume 1: Long Papers); Association for Computational Linguistics: Beijing, China, 2015; pp 1702-1712. https://doi.org/10.3115/v1/P15-1164.
(16) Gilmer, J.; Schoenholz, S. S.; Riley, P. F.; Vinyals, O.; Dahl, G. E. Neural Message Passing for Quantum Chemistry. ArXiv170401212 Cs 2017.
(17) Levie, R.; Monti, F.; Bresson, X.; Bronstein, M. M. CayleyNets: Graph Convolutional Neural Networks With Complex Rational Spectral Filters. IEEE Trans. Signal Process. 2019, 67 (1), 97-109. https://doi.org/10.1109/TSP.2018.2879624.
(18) Flam-Shepherd, D.; Wu, T.; Friederich, P.; Aspuru-Guzik, A. Neural Message Passing on High Order Paths. ArXiv200210413 Cs Stat 2020.
(19) Yan, T.; Zhang, H.; Li, Z.; Xia, Y. Stochastic Graph Recurrent Neural Network. ArXiv200900538 Cs Stat 2020.
(20) Chen, C.; Ye, W.; Zuo, Y.; Zheng, C.; Ong, S. P. Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals. Chem. Mater. 2019, 31 (9), 3564-3572. https://doi.org/10.1021/acs.chemmater.9b01294.
(21) Veličković, P.; Cucurull, G.; Casanova, A.; Romero, A.; Liò, P.; Bengio, Y. Graph Attention Networks. ArXiv171010903 Cs Stat 2018.
(22) Vinyals, O.; Bengio, S.; Kudlur, M. Order Matters: Sequence to Sequence for Sets. ArXiv151106391 Cs Stat 2016.
(23) Ying, R.; You, J.; Morris, C.; Ren, X.; Hamilton, W. L.; Leskovec, J. Hierarchical Graph Representation Learning with Differentiable Pooling. ArXiv180608804 Cs Stat 2019.
(24) Diehl, F. Edge Contraction Pooling for Graph Neural Networks. ArXiv190510990 Cs Stat 2019.
(25) Gao, H.; Ji, S. Graph U-Nets. ArXiv190505178 Cs Stat 2019.
(26) Lee, J.; Lee, I.; Kang, J. Self-Attention Graph Pooling. ArXiv190408082 Cs Stat 2019.
(27) Gao, X.; Xiong, H.; Frossard, P. IPool -- Information-Based Pooling in Hierarchical Graph Neural Networks. ArXiv190700832 Cs Stat 2019.
(28) Ma, Y.; Wang, S.; Aggarwal, C. C.; Tang, J. Graph Convolutional Networks with EigenPooling. ArXiv190413107 Cs Stat 2019.
(29) Defferrard, M.; Bresson, X.; Vandergheynst, P. Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering. ArXiv160609375 Cs Stat 2017.
(30) Rhee, S.; Seo, S.; Kim, S. Hybrid Approach of Relation Network and Localized Graph Convolutional Filtering for Breast Cancer Subtype Classification. ArXiv171105859 Cs 2018.
(31) Dhillon, I. S.; Guan, Y.; Kulis, B. Weighted Graph Cuts without Eigenvectors: A Multilevel Approach. IEEE Trans. Pattern Anal. Mach. Intell. 2007, 29 (11), 1944-1957. https://doi.org/10.1109/TPAMI.2007.1115.
(32) Simonovsky, M.; Komodakis, N. Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs. ArXiv170402901 Cs 2017.
(33) Abadi, M.; Barham, P.; Chen, J.; Chen, Z.; Davis, A.; Dean, J.; Devin, M.; Ghemawat, S.; Irving, G.; Isard, M.; Kudlur, M.; Levenberg, J.; Monga, R.; Moore, S.; Murray, D. G.; Steiner, B.; Tucker, P.; Vasudevan, V.; Warden, P.; Wicke, M.; Yu, Y.; Zheng, X. TensorFlow: A System for Large-Scale Machine Learning; 2016; pp 265-283.
(34) Abadi, M.; Agarwal, A.; Barham, P.; Brevdo, E.; Chen, Z.; Citro, C.; Corrado, G. S.; Davis, A.; Dean, J.; Devin, M.; Ghemawat, S.; Goodfellow, I.; Harp, A.; Irving, G.; Isard, M.; Jia, Y.; Jozefowicz, R.; Kaiser, L.; Kudlur, M.; Levenberg, J.; Mane, D.; Monga, R.; Moore, S.; Murray, D.; Olah, C.; Schuster, M.; Shlens, J.; Steiner, B.; Sutskever, I.; Talwar, K.; Tucker, P.; Vanhoucke, V.; Vasudevan, V.; Viegas, F.; Vinyals, O.; Warden, P.; Wattenberg, M.; Wicke, M.; Yu, Y.; Zheng, X. TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems. ArXiv160304467 Cs 2016.
(35) van Merriënboer, B.; Bahdanau, D.; Dumoulin, V.; Serdyuk, D.; Warde-Farley, D.; Chorowski, J.; Bengio, Y. Blocks and Fuel: Frameworks for Deep Learning. ArXiv150600619 Cs Stat 2015.
(36) Chollet, F. Keras; GitHub, 2015.
(37) Paszke, A.; Gross, S.; Massa, F.; Lerer, A.; Bradbury, J.; Chanan, G.; Killeen, T.; Lin, Z.; Gimelshein, N.; Antiga, L.; Desmaison, A.; Köpf, A.; Yang, E.; DeVito, Z.; Raison, M.; Tejani, A.; Chilamkurthy, S.; Steiner, B.; Fang, L.; Bai, J.; Chintala, S. PyTorch: An Imperative Style, High-Performance Deep Learning Library. ArXiv191201703 Cs Stat 2019.
(38) Fey, M.; Lenssen, J. E. Fast Graph Representation Learning with PyTorch Geometric. ArXiv190302428 Cs Stat 2019.
(39) Wang, M.; Zheng, D.; Ye, Z.; Gan, Q.; Li, M.; Song, X.; Zhou, J.; Ma, C.; Yu, L.; Gai, Y.; Xiao, T.; He, T.; Karypis, G.; Li, J.; Zhang, Z. Deep Graph Library: A Graph-Centric, Highly-Performant Package for Graph Neural Networks. ArXiv190901315 Cs Stat 2020.
(40) Li, Y.; Vinyals, O.; Dyer, C.; Pascanu, R.; Battaglia, P. Learning Deep Generative Models of Graphs. ArXiv180303324 Cs Stat 2018.
(41) Sabour, S.; Frosst, N.; Hinton, G. E. Dynamic Routing Between Capsules. ArXiv171009829 Cs 2017.
(42) Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A. N.; Kaiser, L.; Polosukhin, I. Attention Is All You Need. ArXiv170603762 Cs 2017.
(43) Grattarola, D.; Alippi, C. Graph Neural Networks in TensorFlow and Keras with Spektral. ArXiv200612138 Cs Stat 2020.
(44) CSIRO's Data61. StellarGraph Machine Learning Library; 2018.
(45) Sen, P.; Namata, G.; Bilgic, M.; Getoor, L.; Galligher, B.; Eliassi-Rad, T. Collective Classification in Network Data. AI Mag. 2008, 29 (3), 93-93. https://doi.org/10.1609/aimag.v29i3.2157.
(46) Debnath, A. K.; Lopez de Compadre, R. L.; Debnath, G.; Shusterman, A. J.; Hansch, C. Structure-Activity Relationship of Mutagenic Aromatic and Heteroaromatic Nitro Compounds. Correlation with Molecular Orbital Energies and Hydrophobicity. J. Med. Chem. 1991, 34 (2), 786-797. https://doi.org/10.1021/jm00106a046.
(47) Ramakrishnan, R.; Dral, P. O.; Rupp, M.; von Lilienfeld, O. A. Quantum Chemistry Structures and Properties of 134 Kilo Molecules. Sci. Data 2014, 1 (1), 140022. https://doi.org/10.1038/sdata.2014.22.
(48) Ying, R.; Bourgeois, D.; You, J.; Zitnik, M.; Leskovec, J. GNNExplainer: Generating Explanations for Graph Neural Networks. ArXiv190303894 Cs Stat 2019.
