0% found this document useful (0 votes)
29 views14 pages

GNN MetaLayer

Generative Neural Network

Uploaded by

MaxImus AlphA
Copyright
© © All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as PDF, TXT or read online on Scribd
0% found this document useful (0 votes)
29 views14 pages

GNN MetaLayer

Generative Neural Network

Uploaded by

MaxImus AlphA
Copyright
© © All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as PDF, TXT or read online on Scribd
You are on page 1/ 14

GNN_MetaLayer

July 3, 2024

[1]: import os
import time
import random
import numpy as np

from scipy.stats import ortho_group

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential
from torch import Tensor

torch.set_default_dtype(torch.float64)

from torch_geometric.typing import (


Adj,
OptPairTensor,
OptTensor,
Size,
SparseTensor,
torch_sparse,
)

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, to_dense_adj,␣
↪dense_to_sparse, to_undirected

from torch_geometric.loader import DataLoader


from torch_geometric.nn import MessagePassing, global_mean_pool, knn_graph
from torch_geometric.datasets import QM9

1
from torch_scatter import scatter
from torch_cluster import knn

import matplotlib.pyplot as plt


import seaborn as sns
import pandas as pd
import uproot
import vector
vector.register_awkward()
import awkward as ak

from IPython.display import HTML

print("PyTorch version {}".format(torch.__version__))


print("PyG version {}".format(torch_geometric.__version__))

PyTorch version 2.3.1.post100


PyG version 2.5.3

1 Let’s create a dataset where we put Jet properties as global data


# In [2]:
class Jet_Dataset(data.Dataset):
    """Map-style dataset that turns every jet of a ROOT tree into a PyG graph.

    Each tree entry becomes one ``Data`` object: the jet constituents are the
    nodes (carrying all ``part_*`` branches as features), nodes are connected
    by a k-nearest-neighbour graph built in the (part_deta, part_dphi) plane,
    every edge stores the angular distance between its end points, and
    jet-level quantities are attached as graph-level ("global") attributes.
    """

    def __init__(self, dataset_path: str, tree_name: str = 'tree', k: int = 5) -> None:
        """
        Inputs:
            dataset_path - Path of the ROOT file, opened with uproot
            tree_name - Name of the tree to read inside that file
            k - Number of nearest neighbours used to build the edges
        """
        super(Jet_Dataset, self).__init__()

        self.dataset = uproot.open(dataset_path)
        self.tree = self.dataset[tree_name].arrays()

        self.num_entries = self.dataset[tree_name].num_entries

        self.part_feat = self.dataset[tree_name].keys(filter_name='part_*')
        self.jet_feat = self.dataset[tree_name].keys(filter_name='jet_*')
        # The label branches read below are named ``label_*`` (e.g.
        # ``label_QCD``); the previous ``labels_*`` filter matched none of them.
        self.labels = self.dataset[tree_name].keys(filter_name='label_*')

        self.k = k

    def transform_jet_to_point_cloud(self, idx: int) -> Data:
        """Build the graph of the ``idx``-th jet.

        Inputs:
            idx - Position of the jet inside the tree (0 <= idx < len(self))

        Returns a ``Data`` object with
            x           - particle features, one row per constituent
            edge_index  - k-NN connectivity in the (deta, dphi) plane
            edge_deltaR - angular distance of every edge, shape [E, 1]
            label       - integer jet class, shape [1]
            sd_mass     - ``jet_sdmass`` of the jet, shape [1]
            global_data - (pt, eta, phi, energy, tau21, tau32, tau43), shape [1, 7]
            seq_length  - ``jet_nparticles`` of the jet, shape [1]

        Raises IndexError if ``idx`` is outside the dataset.
        """
        if not 0 <= idx < self.num_entries:
            raise IndexError(f'index {idx} is out of range for a dataset of '
                             f'{self.num_entries} jets')

        # Slice the single entry *before* converting to NumPy so that only one
        # jet is materialised per call instead of every full column.
        event = self.tree[idx:idx + 1]

        def column(name):
            # Flat (one value per jet) branch of the selected entry, shape (1,).
            return event[name].to_numpy()

        def has_label(name):
            # True when the one-hot label branch ``name`` is set for this jet.
            return bool(column(name)[0] == 1)

        npart = column('jet_nparticles')

        part_feat_list = [ak.flatten(event[name]).to_numpy() for name in self.part_feat]

        jet_pt = column('jet_pt')
        jet_eta = column('jet_eta')
        jet_phi = column('jet_phi')
        jet_energy = column('jet_energy')
        jet_tau21 = column('jet_tau2') / column('jet_tau1')
        jet_tau32 = column('jet_tau3') / column('jet_tau2')
        jet_tau43 = column('jet_tau4') / column('jet_tau3')
        jet_sd_mass = column('jet_sdmass')

        jet_feat = np.stack([jet_pt, jet_eta, jet_phi, jet_energy, jet_tau21,
                             jet_tau32, jet_tau43]).T
        # A 0/0 tau ratio yields NaN; zero it exactly as done for the particle
        # features below so the global features cannot poison the network.
        jet_feat[np.isnan(jet_feat)] = 0.

        # Node features are the particle-level branches only; the jet-level
        # quantities travel separately as ``global_data``.
        total_jet_feat = np.stack(part_feat_list).T
        total_jet_feat[np.isnan(total_jet_feat)] = 0.

        jet_class = -1

        if has_label('label_QCD'):
            jet_class = 0

        if has_label('label_Tbqq') or has_label('label_Tbl'):
            jet_class = 2

        if has_label('label_Zqq') or has_label('label_Wqq'):
            # NOTE(review): W/Z jets share class 0 with QCD, as in the original
            # notebook - confirm this merge is intended and not a typo for 3.
            jet_class = 0

        higgs_labels = ('label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q', 'label_Hqql')
        if any(has_label(name) for name in higgs_labels):
            jet_class = 1

        part_eta = torch.tensor(ak.flatten(event['part_deta']).to_numpy())
        part_phi = torch.tensor(ak.flatten(event['part_dphi']).to_numpy())

        eta_phi_pos = torch.stack([part_eta, part_phi], dim=-1)

        edge_index = torch_geometric.nn.pool.knn_graph(x=eta_phi_pos, k=self.k)

        src, dst = edge_index

        part_del_eta = part_eta[dst] - part_eta[src]
        part_del_phi = part_phi[dst] - part_phi[src]

        # view(-1, 1) turns the [E] vector into an [E, 1] matrix: edge
        # attributes must be 2-D so they can be concatenated with the node
        # features along dim=1 inside the edge model.
        part_del_R = torch.hypot(part_del_eta, part_del_phi).view(-1, 1)

        # Named ``graph`` so the module alias ``data`` (torch.utils.data) is
        # not shadowed.
        graph = Data(x=torch.tensor(total_jet_feat), edge_index=edge_index,
                     edge_deltaR=part_del_R)

        graph.label = torch.tensor([jet_class])
        graph.sd_mass = torch.tensor(jet_sd_mass)
        graph.global_data = torch.tensor(jet_feat)
        graph.seq_length = torch.tensor(npart)

        return graph

    def __len__(self) -> int:
        # Number of jets (tree entries) in the dataset.
        return self.num_entries

    def __getitem__(self, idx: int) -> Data:
        # Return the idx-th jet of the dataset as a graph, built on demand.
        return self.transform_jet_to_point_cloud(idx)

1.1 For later convenience, we write a helper function that builds an MLP


# In [3]:
def build_mlp(inputsize, outputsize, features, add_batch_norm=False, add_activation=None):
    """Build a fully connected network as an ``nn.Sequential``.

    Inputs:
        inputsize - Number of input features
        outputsize - Number of output features
        features - Widths of the hidden layers, in order; an empty sequence
                   gives a single linear map from inputsize to outputsize
        add_batch_norm - If True, insert a BatchNorm1d in front of every
                         hidden Linear layer except the first one
        add_activation - Optional module appended after the output layer
                         (e.g. ``nn.Sigmoid()``); nothing is appended if None

    Every hidden Linear layer is followed by a ReLU; the output layer is not.
    """
    widths = [inputsize, *features]

    layers = []
    for layer_i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
        # The raw inputs are never normalised, only the hidden activations.
        if add_batch_norm and layer_i > 0:
            layers.append(nn.BatchNorm1d(n_in))
        layers.append(nn.Linear(n_in, n_out))
        layers.append(nn.ReLU())

    layers.append(nn.Linear(widths[-1], outputsize))
    if add_activation is not None:
        layers.append(add_activation)
    return nn.Sequential(*layers)

[4]: build_mlp(inputsize=3, outputsize=4, features=[5, 6, 3])

[4]: Sequential(
(0): Linear(in_features=3, out_features=5, bias=True)
(1): ReLU()
(2): Linear(in_features=5, out_features=6, bias=True)
(3): ReLU()
(4): Linear(in_features=6, out_features=3, bias=True)
(5): ReLU()
(6): Linear(in_features=3, out_features=4, bias=True)
)

[6]: dataset_path = '/home/swadhin/miniconda3/MLSCHOOL_IOPB_2024-main (1)/


↪MLSCHOOL_IOPB_2024-main/Lecture2/'

file_name = dataset_path + 'JetClass_example_100k.root' # -- from -- "https://


↪hqu.web.cern.ch/datasets/JetClass/example/" #

jet_dataset = Jet_Dataset(dataset_path=file_name)

[7]: data_loader = DataLoader(dataset=jet_dataset, batch_size=5, shuffle = True)

[8]: gr_b = next(iter(data_loader))

[9]: gr_b

5
[9]: DataBatch(x=[149, 16], edge_index=[2, 745], edge_deltaR=[745, 1], label=[5],
sd_mass=[5], global_data=[5, 7], seq_length=[5], batch=[149], ptr=[6])

1.2 What is graph network?


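Here, we recapitulate the "graph network" (GN) formalism of Battaglia et al. (2018), "Relational inductive biases, deep learning, and graph networks", which generalizes various GNNs and other similar methods. GNs are graph-to-graph mappings, whose output graphs have the same node and edge structure as the input. Formally, a GN block contains three "update" functions, $\phi$, and three "aggregation" functions, $\rho$. The stages of processing in a single GN block are:

$$
\begin{aligned}
\mathbf{e}'_k &= \phi^{e}\left(\mathbf{e}_k, \mathbf{v}_{r_k}, \mathbf{v}_{s_k}, \mathbf{u}\right), &
\bar{\mathbf{e}}'_i &= \rho^{e \to v}\left(E'_i\right), \\
\mathbf{v}'_i &= \phi^{v}\left(\bar{\mathbf{e}}'_i, \mathbf{v}_i, \mathbf{u}\right), &
\bar{\mathbf{e}}' &= \rho^{e \to u}\left(E'\right), \\
\mathbf{u}' &= \phi^{u}\left(\bar{\mathbf{e}}', \bar{\mathbf{v}}', \mathbf{u}\right), &
\bar{\mathbf{v}}' &= \rho^{v \to u}\left(V'\right),
\end{aligned}
$$

where $E'_i = \{(\mathbf{e}'_k, r_k, s_k)\}_{r_k = i,\; k = 1:N^e}$ contains the updated edge features for edges whose receiver node is the $i$th node, $E' = \bigcup_i E'_i = \{(\mathbf{e}'_k, r_k, s_k)\}_{k = 1:N^e}$ is the set of updated edges, and $V' = \{\mathbf{v}'_i\}_{i = 1:N^v}$ is the set of updated nodes.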

1.3 Building a MetaLayer


Adapted from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/meta.html#MetaLayer

# In [10]:
class MetaLayer(torch.nn.Module):
    r"""A meta layer for building any kind of graph network, inspired by the
    `"Relational Inductive Biases, Deep Learning, and Graph Networks"
    <https://arxiv.org/abs/1806.01261>`_ paper.

    A graph network takes a graph as input and returns an updated graph as
    output (with same connectivity).
    The input graph has node features :obj:`x`, edge features :obj:`edge_attr`
    as well as graph-level features :obj:`u`.
    The output graph has the same structure, but updated features.

    Edge features, node features as well as global features are updated by
    calling the modules :obj:`edge_model`, :obj:`node_model` and
    :obj:`global_model`, respectively.

    To allow for batch-wise graph processing, all callable functions take an
    additional argument :obj:`batch`, which determines the assignment of
    edges or nodes to their specific graphs.

    Args:
        edge_model (torch.nn.Module, optional): A callable which updates a
            graph's edge features based on its source and target node features,
            its current edge features and its global features.
            (default: :obj:`None`)
        node_model (torch.nn.Module, optional): A callable which updates a
            graph's node features based on its current node features, its graph
            connectivity, its edge features and its global features.
            (default: :obj:`None`)
        global_model (torch.nn.Module, optional): A callable which updates a
            graph's global features based on its node features, its graph
            connectivity, its edge features and its current global features.
            (default: :obj:`None`)

    .. code-block:: python

        from torch.nn import Sequential as Seq, Linear as Lin, ReLU
        from torch_geometric.utils import scatter
        from torch_geometric.nn import MetaLayer

        class EdgeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

            def forward(self, src, dst, edge_attr, u, batch):
                # src, dst: [E, F_x], where E is the number of edges.
                # edge_attr: [E, F_e]
                # u: [B, F_u], where B is the number of graphs.
                # batch: [E] with max entry B - 1.
                out = torch.cat([src, dst, edge_attr, u[batch]], 1)
                return self.edge_mlp(out)

        class NodeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
                self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

            def forward(self, x, edge_index, edge_attr, u, batch):
                # x: [N, F_x], where N is the number of nodes.
                # edge_index: [2, E] with max entry N - 1.
                # edge_attr: [E, F_e]
                # u: [B, F_u]
                # batch: [N] with max entry B - 1.
                row, col = edge_index
                out = torch.cat([x[row], edge_attr], dim=1)
                out = self.node_mlp_1(out)
                out = scatter(out, col, dim=0, dim_size=x.size(0),
                              reduce='mean')
                out = torch.cat([x, out, u[batch]], dim=1)
                return self.node_mlp_2(out)

        class GlobalModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

            def forward(self, x, edge_index, edge_attr, u, batch):
                # x: [N, F_x], where N is the number of nodes.
                # edge_index: [2, E] with max entry N - 1.
                # edge_attr: [E, F_e]
                # u: [B, F_u]
                # batch: [N] with max entry B - 1.
                out = torch.cat([
                    u,
                    scatter(x, batch, dim=0, reduce='mean'),
                ], dim=1)
                return self.global_mlp(out)

        op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
        x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
    """
    def __init__(
        self,
        edge_model: Optional[torch.nn.Module] = None,
        node_model: Optional[torch.nn.Module] = None,
        global_model: Optional[torch.nn.Module] = None,
    ):
        super(MetaLayer, self).__init__()
        self.edge_model = edge_model
        self.node_model = node_model
        self.global_model = global_model

        # reset_parameters() is deliberately not called on construction (it
        # was commented out in the notebook): the sub-models keep whatever
        # initialisation they were built with.

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for item in [self.node_model, self.edge_model, self.global_model]:
            # Missing (None) sub-models and modules without a
            # reset_parameters method are simply skipped.
            if hasattr(item, 'reset_parameters'):
                item.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Optional[Tensor] = None,
        u: Optional[Tensor] = None,
        batch: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        r"""
        Args:
            x (torch.Tensor): The node features.
            edge_index (torch.Tensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            u (torch.Tensor, optional): The global graph features.
                (default: :obj:`None`)
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific graph. (default: :obj:`None`)

        Returns:
            The tuple ``(x, edge_attr, u)`` after the edge, node and global
            updates, applied in that order. A feature whose model is
            :obj:`None` is returned unchanged.
        """
        row = edge_index[0]
        col = edge_index[1]

        if self.edge_model is not None:
            # Edges inherit the graph assignment of the node at edge_index[0],
            # which turns the per-node batch vector into a per-edge one.
            edge_batch = None if batch is None else batch[row]
            edge_attr = self.edge_model(x[row], x[col], edge_attr, u, edge_batch)

        # The node update sees the already-updated edge features.
        if self.node_model is not None:
            x = self.node_model(x, edge_index, edge_attr, u, batch)

        # The global update sees the updated nodes and edges.
        if self.global_model is not None:
            u = self.global_model(x, edge_index, edge_attr, u, batch)

        return x, edge_attr, u

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(\n'
                f' edge_model={self.edge_model},\n'
                f' node_model={self.node_model},\n'
                f' global_model={self.global_model}\n'
                f')')

1.3.1 Let’s declare the edge_network


# In [11]:
class EdgeModel(nn.Module):
    """Edge update of a graph-network block.

    Each edge is updated by an MLP acting on the concatenation of its two
    end-point node features, its current edge features and the global
    features of the graph it belongs to.

    Inputs:
        input_edge_dim - Number of features per edge before the update
        output_edge_dim - Number of features per edge after the update
        node_dim - Number of features per node
        global_dim - Number of graph-level features
        features - Hidden-layer widths handed to build_mlp
    """

    def __init__(self, input_edge_dim: int, output_edge_dim: int, node_dim: int,
                 global_dim: int, features: list):
        super(EdgeModel, self).__init__()
        # Two end-point nodes + global features + current edge features.
        self.edge_mlp = build_mlp(inputsize=2 * node_dim + global_dim + input_edge_dim,
                                  outputsize=output_edge_dim, features=features)

    def forward(self, src, dst, edge_attr, u, edge_batch):
        # src, dst: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # edge_batch: [E] with max entry B - 1.
        # u[edge_batch] repeats each graph's global features once per edge.
        out = torch.cat([src, dst, edge_attr, u[edge_batch]], dim=1)
        return self.edge_mlp(out)

[12]: gr_b.num_edges, gr_b.num_nodes

[12]: (745, 149)

[13]: src_idx, dst_idx = gr_b.edge_index


src, dst = gr_b.x[src_idx], gr_b.x[dst_idx]
edge_attr = gr_b.edge_deltaR
u = gr_b.global_data
batch = gr_b.batch
node_batch = gr_b.batch
edge_batch = gr_b.batch[src_idx]

[14]: batch, node_batch

[14]: (tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4]),
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4]))

[15]: edge_batch

[15]: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

10
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4])

[16]: print('src shape : ', src.shape, ' dst shape : ', dst.shape)
print('edge attribute shape : ', edge_attr.shape)
print('global attribute shape : ', u.shape)
print('batch : ', edge_batch.shape)
print('global data edge replicated shape : ', u[edge_batch].shape)

src shape : torch.Size([745, 16]) dst shape : torch.Size([745, 16])


edge attribute shape : torch.Size([745, 1])
global attribute shape : torch.Size([5, 7])
batch : torch.Size([745])
global data edge replicated shape : torch.Size([745, 7])

[17]: edge_network = EdgeModel(input_edge_dim=1, output_edge_dim=3, node_dim=16,␣


↪global_dim=7, features=[3,4,2])

[18]: updated_edge = edge_network(src=src, dst=dst, edge_attr=edge_attr, u=u,␣


↪edge_batch=edge_batch)

print('updated_edge shape : ', updated_edge.shape)

updated_edge shape : torch.Size([745, 3])

[ ]:

11
1.3.2 Now let’s build a node_network

# In [19]:
class NodeModel(torch.nn.Module):
    """Node update of a graph-network block.

    For every node the features of its incoming edges are averaged, and an
    MLP then acts on the concatenation of the node's own features, that
    aggregated edge message and the global features of its graph.

    Inputs:
        input_edge_dim - Number of features per edge
        input_node_dim - Number of features per node before the update
        output_node_dim - Number of features per node after the update
        global_dim - Number of graph-level features
        features - Hidden-layer widths handed to build_mlp
    """

    def __init__(self, input_edge_dim: int, input_node_dim: int, output_node_dim: int,
                 global_dim: int, features: list):
        super(NodeModel, self).__init__()
        self.node_mlp = build_mlp(inputsize=input_edge_dim + input_node_dim + global_dim,
                                  outputsize=output_node_dim, features=features)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        col = edge_index[1]

        # Aggregated edge message: mean of the edge features over all edges
        # whose second index is node i. dim_size keeps one row per node even
        # for nodes without such edges.
        out = scatter(edge_attr, col, dim=0, dim_size=x.size(0),
                      reduce='mean')

        # Tutorial diagnostics: printed on every forward pass.
        print('Agrregated out shape : ', out.shape)

        out = torch.cat([x, out, u[batch]], dim=1)
        print('Stacked out shape : ', out.shape)
        # The MLP output is the updated node feature v'_i.
        return self.node_mlp(out)

[20]: updated_edge.shape[-1]

[20]: 3

[21]: node_network = NodeModel(input_edge_dim = updated_edge.shape[-1],␣


↪input_node_dim=gr_b.x.shape[-1], output_node_dim=4, global_dim=u.shape[-1],␣

↪features=[3,4,5])

[23]: updated_node = node_network(gr_b.x, edge_index=gr_b.


↪edge_index,edge_attr=updated_edge, u=u, batch=node_batch)

Agrregated out shape : torch.Size([149, 3])


Stacked out shape : torch.Size([149, 26])

1.4 Finally the global update network


# In [24]:
class GlobalModel(torch.nn.Module):
    """Global (graph-level) update of a graph-network block.

    The node features and the edge features of every graph are averaged, and
    an MLP acts on the concatenation of the current global features with the
    two per-graph means.

    Inputs:
        input_edge_dim - Number of features per edge
        input_node_dim - Number of features per node
        input_global_dim - Number of graph-level features before the update
        output_global_dim - Number of graph-level features after the update
        features - Hidden-layer widths handed to build_mlp
    """

    def __init__(self, input_edge_dim: int, input_node_dim: int, input_global_dim: int,
                 output_global_dim: int, features: list):
        super(GlobalModel, self).__init__()
        self.global_mlp = build_mlp(inputsize=input_edge_dim + input_node_dim + input_global_dim,
                                    outputsize=output_global_dim, features=features)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        src_idx = edge_index[0]

        # Fix the number of output rows to the number of graphs. Without
        # dim_size, scatter sizes its output from the largest index present,
        # so a trailing graph without edges would yield fewer rows than u and
        # make the concatenation below fail.
        num_graphs = u.size(0)

        # Aggregation over all nodes of each graph.
        node_mean = scatter(x, batch, dim=0, dim_size=num_graphs, reduce='mean')
        # Aggregation over all edges of each graph; an edge belongs to the
        # graph of the node at edge_index[0].
        edge_mean = scatter(edge_attr, batch[src_idx], dim=0, dim_size=num_graphs,
                            reduce='mean')

        out = torch.cat([u, node_mean, edge_mean], dim=1)
        return self.global_mlp(out)

[25]: global_network = GlobalModel(input_edge_dim=3, input_node_dim=4,␣


↪input_global_dim=7, output_global_dim=5, features=[3,4,2])

[26]: updated_global_data = global_network(x=updated_node, edge_index=gr_b.


↪edge_index,edge_attr=updated_edge, u=u, batch=batch)

[27]: updated_global_data.shape

[27]: torch.Size([5, 5])

2 The full GNN model at one go


[28]: gnn_layer = MetaLayer(edge_model=edge_network,
node_model=node_network,
global_model=global_network)

2.0.1 Comment: By construction, edge_model, node_model and global_model can be instances of a MessagePassing layer

[29]: x1, edge_attr1, u1 = gnn_layer(x=gr_b.x, edge_index=gr_b.


↪edge_index,edge_attr=gr_b.edge_deltaR, u=u, batch=batch)

Agrregated out shape : torch.Size([149, 3])


Stacked out shape : torch.Size([149, 26])

13
2.1 HW: Build a GNN model by stacking two MetaLayers and set up a model
that identifies whether each node is a hadron or not
[ ]:

[ ]:

14

You might also like