
Graph Neural Networks

David Sahner, M.D.


Senior Advisor, National Center for Advancing Translational Sciences, and
Executive Advisor, Data Science and Translational Medicine, Axle Research and Technologies
Presentation Outline
• Overview of Graph Neural Net (GNN) input, output, goals and underlying “intuition”

• Description of a message-passing GNN

• Message passing, objective function, and optimization

• Major types of graph neural nets

• GNN caveats, insights, and recent innovations

• Applications in drug/biologic discovery and biomedical research

• Conclusions
GNNs: Input, output & goals
• INPUT: GNNs take as input a graph with nodes (examples) and edges.

• OUTPUT & GOALS: GNNs embed (encode) examples as vectors in vector space such that similar or related examples tend to cluster. This enables us to classify unlabeled examples or predict links (a small illustration of this input format follows below).
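Below is a minimal illustration of the kind of input described above, assuming one common convention (a node feature matrix plus an edge list); the array names and toy values are illustrative only.

```python
# Minimal illustration of GNN input (assumed convention: node feature matrix + edge list).
import numpy as np

# 4 nodes (examples), each described by a 3-dimensional feature vector
X = np.array([[1.0, 0.2, 0.0],
              [0.9, 0.1, 0.1],
              [0.0, 1.0, 0.8],
              [0.1, 0.9, 0.7]])

# Undirected edges given as (u, v) index pairs
edges = [(0, 1), (1, 2), (2, 3)]

# Dense adjacency matrix built from the edge list
A = np.zeros((len(X), len(X)))
for u, v in edges:
    A[u, v] = A[v, u] = 1.0

# A GNN maps (X, A) to embeddings Z (one row per node) such that similar or
# related nodes end up close together, enabling node classification or link prediction.
```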
GNNs: The Intuition
• A node (example) “asks” its immediate neighbors about their features and becomes more like them in vector space.

• This is iterated layer-wise, incorporating information passed to the (“evolving”) target node from neighbor nodes residing at a progressively increasing number of hops away.

• Locations of nodes in vector space are updated in this manner layer by layer.

• Learned embeddings in this vector space reflect some measure of node similarity through use of an appropriate loss function and, e.g., stochastic gradient descent.

• “Homogenization of neighbors” by a GNN may improve the ability to classify new examples.
“Layer 0” Vector Embeddings
[Figure: target node surrounded by concentric circles marking neighbors at k = 1, 2, and 3 hops; a few node feature vectors (x1, x2, x3) shown for first-, second-, and third-degree neighbors]

• Target node = blue node

• Neighbors = purple nodes

• First-, second- and third-degree neighbors are 1, 2, and 3 hops away from target node, respectively

• All nodes represented by vectors (only several low-dimensional vectors depicted here to minimize clutter)

• Vector representations in layer 0 of the GNN = node features
Layer 1 Embedding of Target Node
[Figure: the same concentric-circle diagram, with messages flowing from one-hop neighbors to the target node]

• Target node receives information through “message passing” from all nodes within the smallest concentric circle (i.e., one hop away)

• Note that the target node also receives information from itself

• Vector representation of target node modified to reflect the aggregated information it has received
Layer 2 Embedding of Target Node
[Figure: the same concentric-circle diagram, now with messages flowing from neighbors up to two hops away]

• Target node now receives information from nodes that are up to two hops away

• Messages (information) passed from this expanded range of sources are aggregated and used to update the embedding of the target node.

• This pattern continues at each subsequent layer of the GNN. That is, in layer 3, the target node receives information from neighbors up to 3 hops away, etc.

• Thus, the target node’s embedding is updated in every layer of the GNN
How is “information” aggregated and
used to update the target node embedding?

A common form of the update combines the target node’s own representation with the
element-wise expectation value (mean) of its neighbors’ vectors in the prior layer:

ht(l+1) = σ( Wt(l) ht(l) + WN(l) · (1/|N|) ∑n∈N hn(l) )

where ht(l+1) is the vector representation of the target node (t) in layer l+1; ht(l) is the
vector representation of the target node in layer l; N is the set of neighbor nodes in
layer l with cardinality |N|; n is an element of N with vector representation hn(l) in
layer l; WN(l) and Wt(l) are learned layer-specific weight matrices; and σ is an
element-wise nonlinearity.
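A minimal sketch of this update for a single target node, assuming mean (“expectation value”) aggregation with a ReLU nonlinearity; the function and variable names are illustrative, not from the slides.

```python
# Sketch of the layer-wise update above for one target node, assuming mean
# aggregation and a ReLU nonlinearity; names (W_self, W_neigh) are illustrative.
import numpy as np

def update_target_embedding(h_target, h_neighbors, W_self, W_neigh):
    """h_target: (d_in,) = h_t^(l); h_neighbors: (|N|, d_in), rows = h_n^(l)."""
    # Element-wise expectation value (mean) of the neighbors' layer-l vectors
    neighbor_mean = h_neighbors.mean(axis=0)
    # W_t^(l) h_t^(l) + W_N^(l) * (mean of neighbors), then a nonlinearity
    pre_activation = W_self @ h_target + W_neigh @ neighbor_mean
    return np.maximum(pre_activation, 0.0)       # h_t^(l+1), shape (d_out,)

# Toy usage: one target node with three neighbors, 4-dim input, 2-dim output
rng = np.random.default_rng(0)
h_t, h_N = rng.normal(size=4), rng.normal(size=(3, 4))
W_self, W_neigh = rng.normal(size=(2, 4)), rng.normal(size=(2, 4))
print(update_target_embedding(h_t, h_N, W_self, W_neigh))
```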
How does learning take place in
this message-passing GNN (I)?
• “Similar” nodes in the original graph should have similar (final) embeddings (z) in
the vector space after all layer-wise transformations in the GNN.

• Similarity in the embedding space is usually assessed with the dot product (equivalent
to cosine similarity when the embeddings are L2-normalized)

• Similarity of nodes in an original graph may be defined in several ways, such as:
• Node u has a high likelihood of being visited during a random walk starting at node v
(this can be a truly random walk, or a biased walk geared toward preferential capture of
either local or global aspects of network topology; see the sketch after this list)

• Nodes u and v are connected

• Nodes u and v have overlapping neighborhoods

• Nodes u and v share the same label
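
A hedged sketch of the random-walk notion of similarity from the list above: pairs of nodes that co-occur on short walks are treated as “similar” (y = 1). The walk length, number of walks per node, and function names are assumptions for illustration.

```python
# Sketch of random-walk-based similarity: v is "similar" to u if v is often
# visited on short walks started at u. Parameters here are illustrative.
import random

def random_walk(adj_list, start, walk_length, rng):
    """Uniform random walk over an adjacency-list graph."""
    walk = [start]
    for _ in range(walk_length - 1):
        neighbors = adj_list[walk[-1]]
        if not neighbors:
            break
        walk.append(rng.choice(neighbors))
    return walk

def positive_pairs(adj_list, walk_length=5, walks_per_node=10, seed=0):
    """Collect (u, v) pairs that co-occur on walks; these receive y_uv = 1."""
    rng = random.Random(seed)
    pairs = set()
    for u in adj_list:
        for _ in range(walks_per_node):
            for v in random_walk(adj_list, u, walk_length, rng)[1:]:
                pairs.add((u, v))
    return pairs

# Toy usage on a 4-node path graph
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
print(sorted(positive_pairs(adj)))
```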


How does learning take place in
this message-passing GNN (II)?
What is the objective (loss) function we optimize during gradient
descent? A simple cross-entropy (CE) loss function:

Loss = ∑u,v CE(yu,v, zu∙zv)

where yu,v is based on the similarity between nodes u and v in the
original graph and zu∙zv is the dot product of the GNN embeddings in
vector space.

Add stochastic gradient descent, and you’re in business!
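
A minimal sketch of this pairwise loss, assuming binary similarity labels yu,v ∈ {0, 1} and a sigmoid applied to the dot product zu∙zv to obtain a probability (both are assumptions); written in PyTorch for concreteness.

```python
# Sketch of the pairwise embedding loss above: cross-entropy between graph-based
# similarity labels y_uv and sigmoid(z_u · z_v). Binary labels and the sigmoid
# squashing are assumptions for illustration.
import torch

def pairwise_embedding_loss(z, pairs, labels):
    """z: (num_nodes, d) final embeddings; pairs: (P, 2) node index pairs;
    labels: (P,) similarity targets y_uv in {0, 1}."""
    z_u, z_v = z[pairs[:, 0]], z[pairs[:, 1]]
    dot = (z_u * z_v).sum(dim=1)                 # z_u · z_v for each pair
    # Binary cross-entropy on sigmoid(dot), summed over all sampled pairs
    return torch.nn.functional.binary_cross_entropy_with_logits(
        dot, labels.float(), reduction="sum")

# Toy usage: 5 nodes, 8-dim embeddings, 3 labeled pairs
z = torch.randn(5, 8, requires_grad=True)
pairs = torch.tensor([[0, 1], [1, 2], [3, 4]])
labels = torch.tensor([1, 0, 1])
loss = pairwise_embedding_loss(z, pairs, labels)
loss.backward()   # gradients flow back into the embeddings / GNN parameters
```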


Major flavors of graph neural nets
• Prototypical message-passing GNN (covered in prior slides)

• Graph convolutional network (GCN) – example forthcoming in this module

• Graph attention networks (GATs) – leverage self-attention over node features (see LLM
module for in-depth discussion of attention, but, in short, in GATs, nodes “attend” to
their neighbors; a minimal sketch follows this list)
  • Each neighbor passes a vector of attentional coefficients to a target node (one per
    attention head).
  • Aggregation of linearly transformed neighbor features that reflect attention as well
    (somewhat analogous to message-passing GNNs)
  • Dropout of attentional coefficients as a helpful regularizer.
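
A minimal single-head sketch of a GAT-style layer, following the standard GAT formulation (LeakyReLU attention scores, softmax over neighbors, dropout on the coefficients); the exact scoring function, parameter names, and dropout rate are assumptions here, and real GATs typically use multiple heads.

```python
# Single-head GAT-style layer sketch: attention-weighted aggregation of linearly
# transformed neighbor features, with dropout on the attention coefficients.
import torch
import torch.nn.functional as F

def gat_layer(h, adj, W, a, drop_p=0.6, training=True):
    """h: (N, F) node features; adj: (N, N) 0/1 adjacency with self-loops;
    W: (F, F') weight matrix; a: (2*F',) attention parameter vector."""
    Wh = h @ W                                    # linearly transformed features
    Fp = Wh.size(1)
    # Unnormalized score e_ij = LeakyReLU(a[:F'] . Wh_i + a[F':] . Wh_j)
    e = F.leaky_relu((Wh @ a[:Fp]).unsqueeze(1) + (Wh @ a[Fp:]).unsqueeze(0))
    # Keep scores only where an edge exists, then softmax over each node's neighbors
    alpha = torch.softmax(e.masked_fill(adj == 0, float("-inf")), dim=1)
    # Dropout of attention coefficients as a regularizer (as noted above)
    alpha = F.dropout(alpha, p=drop_p, training=training)
    # Aggregate the transformed neighbor features, weighted by attention
    return alpha @ Wh

# Toy usage: 4 nodes, 3 input features, 2 output features, single attention head
h = torch.randn(4, 3)
adj = torch.tensor([[1., 1., 0., 1.],
                    [1., 1., 1., 0.],
                    [0., 1., 1., 0.],
                    [1., 0., 0., 1.]])
out = gat_layer(h, adj, W=torch.randn(3, 2), a=torch.randn(4), training=False)
```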


GNNs: Caveats, insights and
innovations (I)
• GNNs have been criticized for focusing on node-level information, leaving
representational information captured by high-level graph topology unexploited

• Ai et al. (2024) have proposed melding node-level and higher-level (subgraph)
structural information to produce a more comprehensive representation. Subgraphs are
represented as “super-nodes.”

• Typically, the number of labeled nodes used to train a GNN is far smaller than the
number of unlabeled nodes whose labels one wishes to predict (overfitting risk),
and the distribution of labeled training nodes and test nodes may be different (e.g.,
different degrees, dissimilar proportions of neighbors with different labels),
exacerbating the risk of poor prediction generalization

• Fan et al. (2024) propose a new variable-decorrelation regularizer to mitigate the
effect of this distribution shift while maintaining sample size (this complex approach is
covered in supplementary slides of the regularization module).
GNNs: Caveats, insights and
innovations (II)
• GNNs may promote “over-homogenization” of nodes
(oversmoothing), especially with increasing depth of the network.
• This makes intuitive sense as GNNs induce target nodes to become more
like their neighbors – including, potentially, nodes belonging to other
classes. The oversmoothing issue is aggravated by more layers, which
allow passing of messages from more distant neighbors.
• This can compromise classification performance by GNNs. Inter-class
edges in the graph fed to a GCN are thought to shoulder blame here.
• Wang et al. (2024) proposed “GUIded Dropout over Edges” (GUIDE) to
mitigate this problem. Edge strength (tied to number of times an edge lies
on the shortest paths between all node pairs) used as a proxy for “inter-
class” edges, which are preferentially removed.
GUIDE in Graph-based semi-supervised
learning (GSSL) – Wang et al., 2024

• Goal of GSSL is to predict labels for unlabeled nodes in a graph given known labels for a subset of nodes, feature matrix, and graph topology

• GUIDE uses edge strength as a surrogate for inter-class edges to target for removal.
Edge strength related to
inter-class linkages
• Most edges in Cora dataset have low edge strength.

• Ratio of intra- to inter-class edges rises with diminishing edge strength.

• Orange dots represent intra- to inter-class ratio of edges for each edge strength quintile

Fig. from Wang et al., 2024


Probabilistic excision of edges
by GUIDE
• Top Equation: Denominator sums over all edge strengths, resulting in normalization of the specific edge strength. Probability of dropping an edge is related to this normalized edge strength.

• Bottom Equation: Key values generated for all edges. A random real number (r) between 0 and 1 (exclusive) is selected for each edge. The greater the strength of the edge, the smaller the exponent. The smaller the exponent, the larger the key value. Edges with the largest key values are excised to create a new graph adjacency matrix (see the sketch below).

• Different edge key values for each training epoch provide the benefit of dataset augmentation!
Wang et al. 2024
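
A hedged sketch of the probabilistic excision just described, assuming normalized edge strengths as drop probabilities and keys of the form key = r^(1/p), so that stronger edges get larger keys and are preferentially excised; the drop fraction and function names are illustrative, not taken from Wang et al.

```python
# GUIDE-style probabilistic edge excision sketch: normalize edge strengths,
# draw keys key_e = r^(1/p_e), and excise the edges with the largest keys.
import numpy as np

def guide_edge_drop(edge_strength, drop_fraction, rng):
    """edge_strength: (E,) array of strengths; returns a boolean keep-mask over edges."""
    # "Top equation": normalize each strength by the sum over all edge strengths
    p = edge_strength / edge_strength.sum()
    # "Bottom equation": r in (0, 1); the exponent shrinks as strength grows,
    # so stronger (likely inter-class) edges receive larger keys
    r = rng.uniform(low=1e-12, high=1.0, size=len(p))
    keys = r ** (1.0 / np.maximum(p, 1e-12))
    # Excise the edges with the largest keys
    num_drop = int(drop_fraction * len(p))
    keep = np.ones(len(p), dtype=bool)
    keep[np.argsort(keys)[len(keys) - num_drop:]] = False
    return keep

# Toy usage: 6 edges; re-sampling keys each epoch gives a data-augmentation effect
rng = np.random.default_rng(0)
print(guide_edge_drop(np.array([0.5, 3.0, 1.2, 8.0, 0.7, 2.5]), drop_fraction=0.3, rng=rng))
```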
The GCN-based learner module
evaluated by Wang et al.
• Predictions calculated using Softmax output at right. A-hat is the adjacency matrix after selective excision of edges (a hedged sketch of this type of learner follows below).

• Cross-entropy loss over labeled nodes used as objective function.

• Note that GUIDE can be used to prep graphs for various types of learning algorithms (see next slide).
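
A rough sketch of a two-layer GCN learner of the kind described on this slide; the exact equation is in Wang et al. (2024), and the common form softmax(Â · ReLU(Â X W0) · W1) with cross-entropy over labeled nodes only is assumed here.

```python
# Two-layer GCN learner sketch with semi-supervised cross-entropy loss;
# A_hat stands in for the normalized adjacency after GUIDE's edge excision.
import torch

def gcn_forward(A_hat, X, W0, W1):
    """A_hat: (N, N) normalized adjacency with self-loops; X: (N, F) node features."""
    H = torch.relu(A_hat @ X @ W0)                 # first graph-convolution layer
    return torch.softmax(A_hat @ H @ W1, dim=1)    # per-node class probabilities

def labeled_node_loss(probs, labels, labeled_mask):
    """Cross-entropy over labeled nodes only, as used in the objective above."""
    return torch.nn.functional.nll_loss(
        torch.log(probs[labeled_mask] + 1e-12), labels[labeled_mask])

# Toy usage: 4 nodes, 3 features, 2 classes, first two nodes labeled
A_hat = torch.eye(4) * 0.5 + 0.125                 # stand-in for a normalized adjacency
X = torch.randn(4, 3)
W0 = torch.randn(3, 8, requires_grad=True)
W1 = torch.randn(8, 2, requires_grad=True)
probs = gcn_forward(A_hat, X, W0, W1)
loss = labeled_node_loss(probs, torch.tensor([0, 1, 0, 1]),
                         torch.tensor([True, True, False, False]))
loss.backward()
```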
GUIDE Benchmark Results
• Results from Wang et al., 2024
• Note that GAUG and DropEdge
are other edge dropping or edge
demotion tools
• See Wang et al. (2024) paper for
GUIDE parameter settings and
parameter sensitivity data
GNNs: Potential applications in drug/biologic
discovery and biomedical research
• Applications include prediction of node labels and edges in graphical models

• GNN-learned representations of molecules can be used to help predict the biological
activity (“label”) of an unknown molecule (structure-activity relationship prediction [e.g.,
Wong et al., 2024])

• A bipartite graph, with nodes consisting of patients and chronic diseases, was used to
create a patient network with weights reflecting number of shared diseases. This was fed
to a GNN, with patient features for each node, to make chronic disease predictions (Lu
and Uddin, 2021 – see https://pubmed.ncbi.nlm.nih.gov/34799627/ and supplementary
slide)

• In a novel recent application, GCN leveraging high-dimensional data and a patchwork of
sparse graphs outperformed regularized Cox Proportional Hazards in making survival
predictions (Ling, Liu and Xue, 2024). See https://pubmed.ncbi.nlm.nih.gov/35862325/
and back-up slides.
Case Study: Discovery of a structural class of antibiotics with
explainable deep learning (Wong et al. 2024)
• Feature vector for each atom and bond created
• Bond-based message-passing neural network* used to create molecular
representation, which is passed to a neural net to predict the compound’s
activity against Staphylococcus aureus.
• Model optimizations:
• Additional molecular features added to GNN-based representation
• Grid search to optimize hyperparameters
• Ensemble of individual models created using different splits of training data

• Over 39K compounds, with 20% held-out for testing of the ensemble after it was
trained on all training data
• Predictions then made for over 12 million compounds

*See Yang et al. for details on bond-based message passing (https://pubs.acs.org/doi/10.1021/acs.jcim.9b00237). Note
that atom-based features are concatenated with bond features prior to message passing. Skip connections with the original
feature vector are used. At the end, Yang et al. return to an atom-based representation by summing incoming bond
messages and concatenating with atomic features. Images from Yang et al.
Case Study (continued)
• Wong and colleagues further filtered hits, removing those thought to
have unfavorable medicinal chemistry properties

• For molecules with high predicted activity, Monte Carlo tree search
used to identify subgraphs (i.e., a portion of the molecule, referred to in
the paper as a “rationale”) thought to be responsible for activity

• The team found molecules with presumed “rationales” devoid of
scaffolds known to be associated with antibacterial activity (quinolone
bicyclic core or β-lactam ring). Not all predicted hits had rationales.

• They focused on rationales that were conserved across predicted hits,
identifying five types of potential molecular scaffolds

• 4 of 9 molecules with one of these rationales had anti-staphylococcal
activity and two of these (right) appeared to be drug-like with favorable
predicted bioavailability. They also appeared to have a novel MOA &
absence of cross-resistance with existing classes of antibiotics.
Conclusions
• Various types of graph neural nets have become powerful
additions to our machine learning arsenal
• Applications in discovery and biomedical research are growing,
and include classification and prediction leveraging graph topology
• Deep GNNs can “over-homogenize” samples, impairing
classification accuracy, but innovation in the field is rife and
promising, including very recent work that may mitigate the
risk of oversmoothing.
• In general, however, it is best to use shallower GNN architectures to
minimize this risk of oversmoothing.
Supplementary Slides
Prediction of chronic disease
Figure from Lu and Uddin,
2021
Has Cox been outfoxed?
Ling, Liu and Xue, 2024
• If a GCN is to be leveraged in survival analysis, one would like the graph to align well with
survival time (i.e., linked patients [neighbors] have similar survival times)

• Authors found that GCNs fed sparse graphs more precisely identified patients with similar
survival times compared with denser graphs, but since sparse graphs may miss neighbors,
the team “knitted together” multiple sparse continuous k-nearest neighbor graphs to improve
sensitivity

• Two-layer GCN with final fully connected layer outputting risk scores

• After training, create new graph incorporating an unlabeled example, update degree and
adjacency matrices, and use the following to obtain the risk score for the unlabeled example:
Weaving together sparse graphs.
Figure from Ling, Liu and Xue, 2024. Random subsets of features used to create the sparse
graphs. Assess alignment of GCN-generated output with survival time for each. Start with
sparse graph that performs best. Sequentially add sparse graphs until composite graph, when
fed to GCN, does not generate improved alignment using Harrell’s CI.
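
A minimal sketch of Harrell's concordance index (CI), the alignment metric named above: among comparable pairs, the fraction where higher predicted risk accompanies the shorter survival time. Tied times and tied risks are ignored here for brevity; variable names are illustrative.

```python
# Harrell's concordance index sketch: fraction of comparable pairs where the
# subject with the shorter observed (event) time also has the higher predicted risk.
import itertools

def harrell_c_index(times, events, risks):
    """times: observed times; events: 1 if event observed, 0 if censored;
    risks: model-predicted risk scores (higher = worse prognosis)."""
    concordant, comparable = 0, 0
    for i, j in itertools.combinations(range(len(times)), 2):
        # Order the pair so that i has the shorter observed time
        if times[j] < times[i]:
            i, j = j, i
        # Pair is comparable only if the shorter observed time ended in an event
        if times[i] < times[j] and events[i] == 1:
            comparable += 1
            if risks[i] > risks[j]:
                concordant += 1
    return concordant / comparable if comparable else float("nan")

# Toy usage: perfectly ranked risks give a CI of 1.0
print(harrell_c_index(times=[2, 5, 7, 10], events=[1, 1, 0, 1],
                      risks=[0.9, 0.6, 0.4, 0.2]))
```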
Results of the
weaving exercise
• 90%-10% train-test split
• Ten-fold cross-validation using
training set
• Grid search for hyperparameter
optimization
• Adam optimizer
• Dropout 0.1
• Early stopping to avoid overfitting

Figure and table from Ling, Liu and Xue, 2024. Tabular data
from cross-validation. AGGSurv is the test method.
Rankings in figure presumably represent test set data (not
specified in paper).