GNN-02-Augmented Notes
Augmented Notes
This is the Python notebook used to train and run the experiments from https://medium.com/@arjunkaranam10/augmenting-your-notes-using-graph-neural-networks-e61f0898033a
The GitHub repository, with more files and information, can be found at: https://github.com/QuantumArjun/Augmented-Notes-GNNs
Augmented Notes is a model we created to suggest Wikipedia articles relevant to a given page of notes. Specifically, we train a Graph Neural Network on a subset of the Wikipedia dataset and then predict which nodes are closest to your note (by predicting which nodes your note is most likely to link to).
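To make the "closest nodes" idea concrete: once the note and every candidate article have embeddings, ranking boils down to scoring the note's vector against each article's vector. The snippet below is only an illustrative sketch of that ranking step using cosine similarity; the actual model scores candidates with a trained GNN link predictor, and the names note_vec and node_vecs are assumptions rather than notebook code.

import numpy as np

def rank_candidates(note_vec, node_vecs, top_k=5):
    """Illustrative only: rank articles by cosine similarity to the note embedding."""
    def cosine(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9))

    scores = {title: cosine(note_vec, vec) for title, vec in node_vecs.items()}
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]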
Dataset Creation
First, we need to create our dataset. Since the Wikipedia dataset (and even the subset we used for our project) is too big to upload to Colab, we'll create a mini version here just to demonstrate how it works.
# Data manipulation
import pandas as pd
import random
import json
import os
import pickle
import time

# Doc2Vec
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
import nltk
nltk.download('punkt')
# Wikipedia API
!pip install wikipedia
import wikipedia as wp
from wikipedia.exceptions import DisambiguationError, PageError
Successfully installed wikipedia-1.4.0
# Plotting
import networkx as nx
import matplotlib.pyplot as plt
# Parsing args
import argparse
# Converting to PyG
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-geometric
import numpy as np
import torch
from torch_geometric.utils.convert import from_networkx
import warnings
warnings.filterwarnings('ignore')
This is the function that takes the data fetched from Wikipedia (using the RelationshipGenerator defined below) and puts it into an nx graph!
def create_graph(topics=["tests"], depth=20, max_size=20, simplify=False, plot=False, save_dir=None, max_nodes=None):
    rg = RelationshipGenerator(save_dir=save_dir)

    for topic in topics:
        rg.scan(topic, max_nodes=max_nodes)

    print(f"Created {len(rg.links)} links with {rg.rank_terms().shape[0]} nodes.")

    links = rg.links
    links = remove_self_references(links)

    node_data = rg.rank_terms()
    nodes = node_data.index.tolist()
    node_weights = node_data.values.tolist()
    node_weights = [nw * 100 for nw in node_weights]
    nodelist = nodes

    G = nx.DiGraph()

    # Add nodes, with Doc2Vec embeddings of each page's content as node features
    G.add_nodes_from(nodes)
    feature_vectors, model = doc2vec(nodes, rg)
    nx.set_node_attributes(G, feature_vectors, name="features")

    # Add edges
    G.add_weighted_edges_from(links)
    return G, nodelist, node_weights, model
def remove_self_references(l):
    return [i for i in l if i[0] != i[1]]
Doc2Vec Function - This function takes our nodes, links, and the contents of each page, and trains a Doc2Vec model to turn the page contents
into page features
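The doc2vec function itself isn't reproduced in this section. As a rough guide to its shape, here is a minimal sketch, assuming the signature create_graph uses (the node list plus the RelationshipGenerator, whose features dict maps page titles to page content) and gensim's standard Doc2Vec API; the vector size and epoch count are illustrative assumptions.

from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from nltk.tokenize import word_tokenize

def doc2vec(nodes, rg, vector_size=64, epochs=40):
    # Tag each page's content with its node name so its vector can be looked up later;
    # pages that were never fetched fall back to using the node name as their text.
    documents = [
        TaggedDocument(words=word_tokenize(rg.features.get(node, node).lower()), tags=[node])
        for node in nodes
    ]
    model = Doc2Vec(documents, vector_size=vector_size, min_count=1, epochs=epochs)

    # Map each node to its learned document vector
    feature_vectors = {node: model.dv[node] for node in nodes}
    return feature_vectors, model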
This is a big block of code, but in summary: it takes a starter word and uses a modified BFS to fetch articles from Wikipedia, storing the edges and page content along the way. As you'll notice, we calculate an edge weight, which is used to favor exploring articles with higher edge weights. These weights are not used in the final model.
class RelationshipGenerator():
    """Generates relationships between terms, based on Wikipedia links"""

    def __init__(self, save_dir):
        self.links = []       # [start, end, weight]
        self.features = {}    # {page: page_content}
        self.page_links = {}  # {page: list of pages it links to}
        self.save_dir = save_dir

    def scan(self, start=None, repeat=0, max_nodes=None):
        """Start scanning from a specific word, or from the internal database.

        Args:
            start (str): the term to start searching from; can be None to let
                the algorithm decide where to start
            repeat (int): the number of times to repeat the scan
            max_nodes (int): stop after this many new pages have been fetched
        """
        print("On depth: ", repeat)
        nodes_visited = 0
        while repeat >= 0:
            if max_nodes is not None and nodes_visited == max_nodes:
                return

            # TODO: check that the start page exists and hasn't already been scanned
            term_search = start is not None

            # If a start isn't defined, we should find one
            if start is None:
                start = self.find_starting_point()

            # Scan the starting point specified for links
            print(f"Scanning page {start}...")
            try:
                # Fetch the page through the Wikipedia API
                page = wp.page(start)
                self.features[start] = page.content
                links = list(set(page.links))

                # Ignore some uninteresting terms
                links = [l for l in links if not self.ignore_term(l)]

                # Weight every outgoing link, then normalize the weights
                link_weights = []
                for link in links:
                    weight = self.weight_link(page, link)
                    link_weights.append(weight)

                link_weights = [w / max(link_weights) for w in link_weights]

                # Add the links
                for i, link in enumerate(links):
                    if max_nodes is not None and nodes_visited == max_nodes:
                        return

                    # Access all the pages that link to the links that have been added
                    try:
                        link = link.lower()
                        if link not in self.features or link not in self.page_links:
                            time.sleep(np.random.randint(0, 10))
                            link_page = wp.page(link)
                            self.features[link] = link_page.content
                            self.page_links[link] = [l.lower() for l in link_page.links]
                            print("Page Accessed: ", link)
                            nodes_visited += 1
                        else:
                            print("Page has previously been accessed: ", link)

                        # Connect this page to every already-seen node it links to
                        total_nodes = set([l[1].lower() for l in self.links])
                        for links_to in set([l.lower() for l in self.page_links[link]]).intersection(total_nodes):
                            self.links.append([link, links_to, 0.1])
                        self.links.append([start, link, link_weights[i] + 2 * int(term_search)])  # 3 works pretty well

                    except (DisambiguationError, PageError):
                        print("Page not found: ", link)

                # Print some data to the user on progress
                explored_nodes = set([l[0] for l in self.links])
                explored_nodes_count = len(explored_nodes)
                total_nodes = set([l[1] for l in self.links])
                total_nodes_count = len(total_nodes)
                new_nodes = [l.lower() for l in links if l not in total_nodes]
                new_nodes_count = len(new_nodes)
                print(f"New nodes added: {new_nodes_count}, Total Nodes: {total_nodes_count}, Explored Nodes: {explored_nodes_count}")

            except (DisambiguationError, PageError):
                # This happens if the page has a disambiguation or doesn't exist.
                # We just ignore the page for now; this could be improved.
                print("ERROR, I DID NOT GET THIS PAGE")

            repeat -= 1
            start = None

    def find_starting_point(self):
        """Find the best place to start when no input is given"""
        # Need some links to work with.
        if len(self.links) == 0:
            raise Exception("Unable to start, no start defined or existing links")

        # Get top terms
        res = self.rank_terms()
        sorted_links = list(zip(res.index, res.values))
        all_starts = set([l[0] for l in self.links])

        # Remove identifiers (these are on many Wikipedia pages)
        all_starts = [l for l in all_starts if '(identifier)' not in l]

        # Iterate over the top links, until we find a new one
        for i in range(len(sorted_links)):
            if sorted_links[i][0] not in all_starts and len(sorted_links[i][0]) > 0:
                return sorted_links[i][0]

        # No link found
        raise Exception("No starting point found within links")

    @staticmethod
    def weight_link(page, link):
        """Weight an outgoing link for a given source page

        Args:
            page (obj): the source Wikipedia page object
            link (str): the outgoing link of interest

        Returns:
            (float): the raw weight (normalized across a page's links in scan)
        """
        weight = 0.1

        # How often the linked term appears in the page body
        link_counts = page.content.lower().count(link.lower())
        weight += link_counts

        # Terms mentioned in the summary are likely more important
        if link.lower() in page.summary.lower():
            weight += 3

        return weight

    def rank_terms(self, with_start=True):
        # We can use graph theory here!
        df = pd.DataFrame(self.links, columns=["start", "end", "weight"])

        if with_start:
            # Count each edge in both directions (pd.concat replaces the deprecated DataFrame.append)
            df = pd.concat([df, df.rename(columns={"end": "start", "start": "end"})])

        return df.groupby("end").weight.sum().sort_values(ascending=False)

    def get_key_terms(self, n=20):
        return "'" + "', '".join([t for t in self.rank_terms().head(n).index.tolist() if "(identifier)" not in t]) + "'"

    @staticmethod
    def ignore_term(term):
        """Return True for terms that should be ignored"""
        if "(identifier)" in term or term == "doi":
            return True
        return False
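The cell that produced the output below isn't shown in this printout. Judging by the prints, it was a pair of calls roughly like this sketch; the topic and max_nodes values are inferred from the output and should be read as assumptions:

# Hypothetical reconstruction of the calls behind the output below
G_small, nodelist, node_weights, d2v_model = create_graph(
    topics=["political philosophy"], max_nodes=2, save_dir=None)

G, nodelist, node_weights, d2v_model = create_graph(
    topics=["political philosophy"], max_nodes=4, save_dir=None)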
On depth: 0
Scanning page political philosophy...
Page Accessed: transactionalism
Page Accessed: encyclopédie
Created 2 links with 3 nodes.
On depth: 0
Scanning page political philosophy...
Page Accessed: transactionalism
Page Accessed: encyclopédie
Page Accessed: ibn khaldun
Page Accessed: ethical naturalism
Created 4 links with 5 nodes.
There we go! We've created a small part of the Wikipedia dataset! Now imagine doing this for thousands of nodes...
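From here, the nx graph gets converted into a PyTorch Geometric object for training; that's what the from_networkx import near the top is for. Here is a minimal sketch of that conversion, assuming PyG's standard from_networkx helper and the "features" node attribute set in create_graph (the sanity-check prints are just illustrative):

from torch_geometric.utils.convert import from_networkx

# Sketch: pack the NetworkX graph into a PyG Data object, stacking each node's
# "features" attribute (its Doc2Vec vector) into the data.x feature matrix.
data = from_networkx(G, group_node_attrs=["features"])

print(data.num_nodes, data.num_edges)  # basic sanity check
print(data.x.shape)                    # [num_nodes, doc2vec_dimension]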