

Augmented Notes
This is the Python notebook used to train and run the experiments from https://medium.com/@arjunkaranam10/augmenting-your-notes-using-graph-neural-networks-e61f0898033a

The GitHub repository with more files and information can be found at https://github.com/QuantumArjun/Augmented-Notes-GNNs

Augmented Notes is a model we created to suggest Wikipedia articles for a given page of notes. Specifically, we train a Graph Neural Network on a subset of the Wikipedia dataset and then predict which nodes are closest to your note (that is, which nodes your note is most likely to link to).
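
To make the prediction step concrete, here is a minimal, hypothetical sketch of GNN-based link prediction in PyTorch Geometric: a two-layer GCN encodes every node (your note included), and a dot-product decoder scores how likely an edge between the note and each article is. The layer sizes and the dot-product decoder are illustrative assumptions, not the exact architecture used in this notebook.

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class LinkPredictor(torch.nn.Module):
    """Minimal GCN encoder + dot-product decoder (illustrative sketch only)."""
    def __init__(self, in_dim=300, hidden_dim=128, out_dim=64):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def encode(self, x, edge_index):
        # Two rounds of message passing over the article graph
        h = F.relu(self.conv1(x, edge_index))
        return self.conv2(h, edge_index)

    def decode(self, z, note_idx):
        # Dot-product score of every node against the note node;
        # higher scores mean "more likely to be linked from the note"
        return z @ z[note_idx]

# Hypothetical usage, where `data` is the PyG graph built from the Wikipedia
# subset and `note_idx` is the index of the node representing your notes:
# model = LinkPredictor()
# z = model.encode(data.x, data.edge_index)
# suggestions = model.decode(z, note_idx).topk(5).indices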

Dataset Creation
First, we need to create our dataset. Because the Wikipedia dataset (and even the subset we used for our project) is too big to upload to Colab, we'll create a mini version here, just to demonstrate how it works.

First, we'll install and import the required libraries!

# Data manipulation
import pandas as pd
import random
import json
import os
import pickle
import time

# DOC2VEC
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data] Unzipping tokenizers/punkt.zip.
True

# Wikipedia API
!pip install wikipedia
import wikipedia as wp
from wikipedia.exceptions import DisambiguationError, PageError

Collecting wikipedia
  Downloading wikipedia-1.4.0.tar.gz (27 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: wikipedia
  Building wheel for wikipedia (setup.py) ... done
Successfully built wikipedia
Installing collected packages: wikipedia
Successfully installed wikipedia-1.4.0

# Plotting
import networkx as nx
import matplotlib.pyplot as plt

# Parsing args
import argparse

# Converting to PyG
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-geometric
import numpy as np
import torch
from torch_geometric.utils.convert import from_networkx

Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
Successfully installed torch-scatter-2.1.2
Collecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
Successfully installed torch-sparse-0.6.18
Collecting torch-geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
Successfully installed torch-geometric-2.4.0

import warnings
warnings.filterwarnings('ignore')

This is the function that takes the data fetched from Wikipedia (using the RelationshipGenerator class defined below) and puts it into a NetworkX graph!

def create_graph(topics=["tests"], depth=20, max_size=20, simplify=False, plot=False, save_dir=None, max_nodes=None):
    rg = RelationshipGenerator(save_dir=save_dir)

    for topic in topics:
        rg.scan(topic, max_nodes=max_nodes)

    print(f"Created {len(rg.links)} links with {rg.rank_terms().shape[0]} nodes.")

    links = rg.links
    links = remove_self_references(links)

    node_data = rg.rank_terms()
    nodes = node_data.index.tolist()
    node_weights = node_data.values.tolist()
    node_weights = [nw * 100 for nw in node_weights]
    nodelist = nodes

    G = nx.DiGraph()  # MultiGraph()

    # G.add_node()
    G.add_nodes_from(nodes)
    feature_vectors, model = doc2vec(nodes, rg)
    nx.set_node_attributes(G, feature_vectors, name="features")

    # Add edges
    G.add_weighted_edges_from(links)
    return G, nodelist, node_weights, model

Helper function to remove self edges

def remove_self_references(l):
    return [i for i in l if i[0] != i[1]]
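
For example (with a hypothetical list of links), a self-link such as ["graph", "graph", 1.0] is dropped, while an ordinary link is kept:

links = [["graph", "graph", 1.0], ["graph", "tree", 0.5]]
remove_self_references(links)  # -> [['graph', 'tree', 0.5]]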

Doc2Vec Function - This function takes our nodes and the contents of each page, and trains a Doc2Vec model to turn the page contents into per-node feature vectors.

def doc2vec(nodes, rg):
    # Dict of page title -> page content, restricted to pages that are nodes in the graph
    features = dict(filter(lambda x: x[0] in nodes, rg.features.items()))
    # Sort the (title, content) pairs so they line up with the node order
    features = sorted(features.items(), key=lambda key_value: nodes.index(key_value[0]))
    tokenized_docs = [nltk.word_tokenize(' '.join(doc).lower()) for doc in features]
    tagged_docs = [TaggedDocument(words=doc, tags=[str(i)]) for i, doc in enumerate(tokenized_docs)]

    # Model
    model = Doc2Vec(vector_size=300, min_count=1, epochs=50)
    model.build_vocab(tagged_docs)
    model.train(tagged_docs, total_examples=model.corpus_count, epochs=model.epochs)
    feature_vectors = {node: model.infer_vector(tokenized_docs[i]) for i, node in enumerate(nodes)}

    return feature_vectors, model
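
Once trained, the same model can also embed an unseen page of notes into the same 300-dimensional space. A hypothetical usage sketch (note_text is not part of the notebook):

note_text = "Notes on the social contract and the state of nature"
note_vector = model.infer_vector(nltk.word_tokenize(note_text.lower()))
print(note_vector.shape)  # (300,)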

This is a big block of code, but in summary, it takes a starter word and uses a modified BFS to fetch articles from Wikipedia, while storing the edges and page content! As you'll notice, we calculate an edge weight; this is used to decide which articles to explore next, favouring articles with a higher edge weight. These weights are not used in the final model.

class RelationshipGenerator():
    """Generates relationships between terms, based on Wikipedia links"""
    def __init__(self, save_dir):
        self.links = []  # [start, end, weight]
        self.features = {}  # {page: page_content}
        self.page_links = {}

    def scan(self, start=None, repeat=0, max_nodes=None):
        """Start scanning from a specific word, or from the internal database

        Args:
            start (str): the term to start searching from, can be None to let
                the algorithm decide where to start
            repeat (int): the number of times to repeat the scan
        """
        print("On depth: ", repeat)
        nodes_visited = 0
        while repeat >= 0:
            if max_nodes is not None and nodes_visited == max_nodes:
                return
            # should check if start page exists
            # and haven't already scanned
            # if start in [l[0] for l in self.links]:
            #     raise Exception("Already scanned")

            term_search = start is not None

            # If a start isn't defined, we should find one
            if start is None:
                start = self.find_starting_point()

            # Scan the starting point specified for links
            print(f"Scanning page {start}...")
            try:
                # Fetch the page through the Wikipedia API
                page = wp.page(start)
                self.features[start] = page.content
                links = list(set(page.links))

                # Ignore some uninteresting terms
                links = [l for l in links if not self.ignore_term(l)]

                # Weight each outgoing link, then normalise by the largest weight
                link_weights = []
                for link in links:
                    weight = self.weight_link(page, link)
                    link_weights.append(weight)

                link_weights = [w / max(link_weights) for w in link_weights]

                # Add the links
                for i, link in enumerate(links):
                    if max_nodes is not None and nodes_visited == max_nodes:
                        return

                    # Access all the pages that link to the links that have been added
                    try:
                        link = link.lower()
                        if link not in self.features or link not in self.page_links:
                            time.sleep(np.random.randint(0, 10))  # rate-limit requests to the API
                            page = wp.page(link)
                            self.features[link] = page.content
                            self.page_links[link] = [l.lower() for l in page.links]
                            print("Page Accessed: ", link)
                            nodes_visited += 1
                        else:
                            print("Page has previously been accessed: ", link)
                        total_nodes = set([l[1].lower() for l in self.links])
                        for links_to in set([l.lower() for l in self.page_links[link]]).intersection(total_nodes):
                            self.links.append([link, links_to, 0.1])  # 3 works pretty well
                        self.links.append([start, link, link_weights[i] + 2 * int(term_search)])  # 3 works pretty well

                    except (DisambiguationError, PageError):
                        print("Page not found: ", link)

                # Print some data to the user on progress
                explored_nodes = set([l[0] for l in self.links])
                explored_nodes_count = len(explored_nodes)
                total_nodes = set([l[1] for l in self.links])
                total_nodes_count = len(total_nodes)
                new_nodes = [l.lower() for l in links if l not in total_nodes]
                new_nodes_count = len(new_nodes)
                print(f"New nodes added: {new_nodes_count}, Total Nodes: {total_nodes_count}, Explored Nodes: {explored_nodes_count}")

            except (DisambiguationError, PageError):
                # This happens if the page has a disambiguation or doesn't exist
                # We just ignore the page for now, could improve this
                # self.links.append([start, "DISAMBIGUATION", 0])
                print("ERROR, I DID NOT GET THIS PAGE")

            repeat -= 1
            start = None

    def find_starting_point(self):
        """Find the best place to start when no input is given"""
        # Need some links to work with.
        if len(self.links) == 0:
            raise Exception("Unable to start, no start defined or existing links")

        # Get top terms
        res = self.rank_terms()
        sorted_links = list(zip(res.index, res.values))
        all_starts = set([l[0] for l in self.links])

        # Remove identifiers (these are on many Wikipedia pages)
        all_starts = [l for l in all_starts if '(identifier)' not in l]

        # Iterate over the top links, until we find a new one
        for i in range(len(sorted_links)):
            if sorted_links[i][0] not in all_starts and len(sorted_links[i][0]) > 0:
                return sorted_links[i][0]

        # No link found
        raise Exception("No starting point found within links")

    @staticmethod
    def weight_link(page, link):
        """Weight an outgoing link for a given source page

        Args:
            page (obj): the source Wikipedia page
            link (str): the outgoing link of interest

        Returns:
            (float): the raw link weight (normalised against the other links by the caller)
        """
        weight = 0.1

        link_counts = page.content.lower().count(link.lower())
        weight += link_counts

        if link.lower() in page.summary.lower():
            weight += 3

        return weight

    def rank_terms(self, with_start=True):
        # We can use graph theory here!
        df = pd.DataFrame(self.links, columns=["start", "end", "weight"])

        if with_start:
            # DataFrame.append was removed in pandas 2.x; pd.concat is equivalent here
            df = pd.concat([df, df.rename(columns={"end": "start", "start": "end"})])

        return df.groupby("end").weight.sum().sort_values(ascending=False)

    def get_key_terms(self, n=20):
        return "'" + "', '".join([t for t in self.rank_terms().head(n).index.tolist() if "(identifier)" not in t]) + "'"

    @staticmethod
    def ignore_term(term):
        """Return True if a linked term should be ignored"""
        if "(identifier)" in term or term == "doi":
            return True
        return False

G, nodelist, node_weights, model = create_graph(topics=["political philosophy"], max_nodes=2)

On depth: 0
Scanning page political philosophy...
Page Accessed: transactionalism
Page Accessed: encyclopédie
Created 2 links with 3 nodes.

G, nodelist, node_weights, model = create_graph(topics=["political philosophy"], max_nodes=4)

On depth: 0
Scanning page political philosophy...
Page Accessed: transactionalism
Page Accessed: encyclopédie
Page Accessed: ibn khaldun
Page Accessed: ethical naturalism
Created 4 links with 5 nodes.

There we go! We've created a small part of the Wikipedia dataset! Now imagine doing this for thousands of nodes...

Let's draw what we have created

def simplified_plot(G, nodelist, node_weights):
    pos = nx.spring_layout(G, k=1, seed=7)  # positions for all nodes - seed for reproducibility

    fig = plt.figure(figsize=(12, 12))

    nx.draw_networkx_nodes(
        G, pos,
        nodelist=nodelist,
        node_size=node_weights,
        node_color='lightblue',
        alpha=0.7
    )

    widths = nx.get_edge_attributes(G, 'weight')
    nx.draw_networkx_edges(
        G, pos,
        edgelist=widths.keys(),
        width=list(widths.values()),
        edge_color='lightblue',
        alpha=0.6
    )

    nx.draw_networkx_labels(G, pos=pos,
                            labels=dict(zip(nodelist, nodelist)),
                            font_color='black')
    plt.show()

simplified_plot(G, nodelist, node_weights)
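
From here, the imports above suggest the NetworkX graph gets converted into a PyTorch Geometric object for GNN training. A minimal sketch of that step, assuming the "features" node attribute set in create_graph is what becomes the node feature matrix:

# Pack the doc2vec "features" attribute into data.x (shape [num_nodes, 300]);
# exact attribute names and shapes depend on the graph built above.
data = from_networkx(G, group_node_attrs=["features"])
print(data)  # e.g. Data(edge_index=[2, num_edges], weight=[num_edges], x=[num_nodes, 300])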

