GNN MetaLayer
GNN MetaLayer
July 3, 2024
[1]: import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential
from torch import Tensor
torch.set_default_dtype(torch.float64)
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, to_dense_adj,␣
↪dense_to_sparse, to_undirected
1
from torch_scatter import scatter
from torch_cluster import knn
"""
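Inputs:
dataset_path - Path of the ROOT file that is opened with uproot
tree_name - Name of the tree inside that file holding the jet entries
k - Stored as self.k (presumably the number of nearest neighbours used
when building the graph edges - confirm against the graph-building code)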
"""
super(Jet_Dataset, self).__init__()
self.dataset = uproot.open(dataset_path)
self.tree = self.dataset[tree_name].arrays()
self.num_entries = self.dataset[tree_name].num_entries
self.part_feat = self.dataset[tree_name].keys(filter_name='part_*')
self.jet_feat = self.dataset[tree_name].keys(filter_name='jet_*')
self.labels = self.dataset[tree_name].keys(filter_name='labels_*')
self.k = k
2
#self.pc_dataset = [ self.transform_jet_to_point_cloud(idx) for idx in␣
↪range(self.num_entries-1) ]
npart = self.tree['jet_nparticles'].to_numpy()[idx:idx+1]
part_feat_list = [ak.flatten(self.tree[part_feat][idx:idx+1]).
↪to_numpy() for part_feat in self.part_feat]
jet_pt = self.tree['jet_pt'].to_numpy()[idx:idx+1]
jet_eta = self.tree['jet_eta'].to_numpy()[idx:idx+1]
jet_phi = self.tree['jet_phi'].to_numpy()[idx:idx+1]
jet_energy = self.tree['jet_energy'].to_numpy()[idx:idx+1]
jet_tau21 = self.tree['jet_tau2'].to_numpy()[idx:idx+1]/self.
↪tree['jet_tau1'].to_numpy()[idx:idx+1]
jet_tau32 = self.tree['jet_tau3'].to_numpy()[idx:idx+1]/self.
↪tree['jet_tau2'].to_numpy()[idx:idx+1]
jet_tau43 = self.tree['jet_tau4'].to_numpy()[idx:idx+1]/self.
↪tree['jet_tau3'].to_numpy()[idx:idx+1]
jet_sd_mass = self.tree['jet_sdmass'].to_numpy()[idx:idx+1]
part_feat = np.stack(part_feat_list).T
total_jet_feat[np.isnan(total_jet_feat)] = 0.
# Map the one-hot truth flags of this jet onto an integer class id:
#   0 = QCD background, 1 = hadronic W/Z boson, 2 = top quark.
# -1 is kept when none of the flags checked below is set.
jet_class = -1
if self.tree['label_QCD'].to_numpy()[idx:idx+1] == 1:
    jet_class = 0
if ((self.tree['label_Tbqq'].to_numpy()[idx:idx+1] == 1) or
        (self.tree['label_Tbl'].to_numpy()[idx:idx+1] == 1)):
    jet_class = 2
# W/Z jets previously received class 0, which collided with the QCD class
# and left class 1 unused; they now get their own id.
if ((self.tree['label_Zqq'].to_numpy()[idx:idx+1] == 1) or
        (self.tree['label_Wqq'].to_numpy()[idx:idx+1] == 1)):
    jet_class = 1
data.label = torch.tensor([jet_class])
data.sd_mass = torch.tensor(jet_sd_mass)
data.global_data = torch.tensor(jet_feat)
data.seq_length = torch.tensor(npart)
return data
return self.num_entries#len(self.pc_dataset)
4
return self.transform_jet_to_point_cloud(idx)#self.
↪ pc_dataset[idx]#data_point, data_label
layers = []
layers.append(nn.Linear(inputsize,features[0]))
layers.append(nn.ReLU())
for hidden_i in range(1,len(features)):
if add_batch_norm:
layers.append(nn.BatchNorm1d(features[hidden_i-1]))
layers.append(nn.Linear(features[hidden_i-1],features[hidden_i]))
layers.append(nn.ReLU())
layers.append(nn.Linear(features[-1],outputsize))
if add_activation!=None:
layers.append(add_activation)
return nn.Sequential(*layers)
[4]: Sequential(
(0): Linear(in_features=3, out_features=5, bias=True)
(1): ReLU()
(2): Linear(in_features=5, out_features=6, bias=True)
(3): ReLU()
(4): Linear(in_features=6, out_features=3, bias=True)
(5): ReLU()
(6): Linear(in_features=3, out_features=4, bias=True)
)
jet_dataset = Jet_Dataset(dataset_path=file_name)
[9]: gr_b
5
[9]: DataBatch(x=[149, 16], edge_index=[2, 745], edge_deltaR=[745, 1], label=[5],
sd_mass=[5], global_data=[5, 7], seq_length=[5], batch=[149], ptr=[6])
Args:
edge_model (torch.nn.Module, optional): A callable which updates a
graph's edge features based on its source and target node features,
its current edge features and its global features.
(default: :obj:`None`)
node_model (torch.nn.Module, optional): A callable which updates a
graph's node features based on its current node features, its graph
connectivity, its edge features and its global features.
(default: :obj:`None`)
global_model (torch.nn.Module, optional): A callable which updates a
6
graph's global features based on its node features, its graph
connectivity, its edge features and its current global features.
(default: :obj:`None`)
.. code-block:: python
class EdgeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
class NodeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
class GlobalModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
7
def forward(self, x, edge_index, edge_attr, u, batch):
# x: [N, F_x], where N is the number of nodes.
# edge_index: [2, E] with max entry N - 1.
# edge_attr: [E, F_e]
# u: [B, F_u]
# batch: [N] with max entry B - 1.
out = torch.cat([
u,
scatter(x, batch, dim=0, reduce='mean'),
], dim=1)
return self.global_mlp(out)
#self.reset_parameters()
def reset_parameters(self):
    r"""Re-initialise the learnable parameters of the attached sub-models.

    The node, edge and global models are visited in that order. A
    sub-model is reset only if it exposes a ``reset_parameters``
    attribute; anything else (for example ``None``) is skipped.
    """
    sub_models = (self.node_model, self.edge_model, self.global_model)
    for sub_model in sub_models:
        if not hasattr(sub_model, 'reset_parameters'):
            continue
        sub_model.reset_parameters()
def forward(
    self,
    x: Tensor,
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
    u: Optional[Tensor] = None,
    batch: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    r"""Run one edge -> node -> global update pass over the graph.

    Each of ``self.edge_model``, ``self.node_model`` and
    ``self.global_model`` is applied in that order when it is not
    ``None``; later models see the outputs of the earlier ones.

    Args:
        x (torch.Tensor): The node features.
        edge_index (torch.Tensor): The edge indices.
        edge_attr (torch.Tensor, optional): The edge features.
            (default: :obj:`None`)
        u (torch.Tensor, optional): The global graph features.
            (default: :obj:`None`)
        batch (torch.Tensor, optional): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node to a specific graph. (default: :obj:`None`)

    Returns:
        The tuple ``(x, edge_attr, u)`` holding the (possibly updated)
        node, edge and global features. An entry is returned unchanged
        when the corresponding model is ``None``.
    """
    row = edge_index[0]
    col = edge_index[1]

    if self.edge_model is not None:
        # Edges inherit the graph assignment of their source node.
        edge_batch = batch if batch is None else batch[row]
        edge_attr = self.edge_model(x[row], x[col], edge_attr, u, edge_batch)

    if self.node_model is not None:
        x = self.node_model(x, edge_index, edge_attr, u, batch)

    if self.global_model is not None:
        u = self.global_model(x, edge_index, edge_attr, u, batch)

    return x, edge_attr, u
super(EdgeModel, self).__init__()
self.edge_mlp =␣
↪build_mlp(inputsize=2*node_dim+global_dim+input_edge_dim,␣
↪outputsize=output_edge_dim, features=features)
9
# batch: [E] with max entry B - 1.
out = torch.cat([src, dst, edge_attr, u[edge_batch]], dim=1)
return self.edge_mlp(out)
[14]: (tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4]),
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4]))
[15]: edge_batch
[15]: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
10
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4])
[16]: print('src shape : ', src.shape, ' dst shape : ', dst.shape)
print('edge attribute shape : ', edge_attr.shape)
print('global attribute shape : ', u.shape)
print('batch : ', edge_batch.shape)
print('global data edge replicated shape : ', u[edge_batch].shape)
[ ]:
11
1.3.2 Now let’s build a node_network
super(NodeModel, self).__init__()
self.node_mlp =␣
↪build_mlp(inputsize=input_edge_dim+input_node_dim+global_dim,␣
↪outputsize=output_node_dim, features=features)
[20]: updated_edge.shape[-1]
[20]: 3
↪features=[3,4,5])
super(GlobalModel, self).__init__()
12
self.global_mlp =␣
↪build_mlp(inputsize=input_edge_dim+input_node_dim+input_global_dim,␣
↪outputsize=output_global_dim, features=features)
out = torch.cat([
u,
scatter(x, batch, dim=0, reduce='mean'), #aggrigation over all nodes
scatter(edge_attr, batch[src_idx], dim=0, reduce='mean')␣
↪#aggregation over edges
], dim=1)
return self.global_mlp(out)
[27]: updated_global_data.shape
13
2.1 HW: Build a GNN model by stacking two MetaLayers and set up a model
that identifies whether each node is a hadron or not
[ ]:
[ ]:
14