# attention.py (from a fork of Refefer/cloverleaf)
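"""Interactively examine the attention matrices of a cloverleaf Attention
embedder.

Given a set of feature embeddings and an embedder spec, this CLI reads
space-separated feature names from a prompt and, for each attention head,
prints the scaled dot-product attention matrix over those features, a
normalized per-term attention summary, and the head-averaged term ranking.
"""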
import argparse

import cloverleaf
import numpy as np
import tabulate


def softmax(values):
    # Row-wise softmax; subtracting the row max keeps np.exp from overflowing.
    max_v = np.max(values, axis=1).reshape((-1, 1))
    numers = np.exp(values - max_v)
    return numers / np.sum(numers, axis=1).reshape((-1, 1))
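# Quick sanity check of the row-wise behavior:
#   softmax(np.array([[0., 1.], [2., 2.]]))  ->  [[0.269, 0.731], [0.5, 0.5]]
# (approximately); each row sums to 1.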


# Each feature embedding is one flat vector holding every head's projections:
# all per-head queries first, then all per-head keys, then all per-head values.
def get_query_key_value(emb, head_num, num_heads, d_k):
    query_start = head_num * d_k
    key_start = num_heads * d_k + head_num * d_k
    value_start = num_heads * d_k * 2
    values = emb[value_start:]
    value_size = int(len(values) / num_heads)
    value = values[head_num * value_size:(head_num + 1) * value_size]
    return emb[query_start:query_start + d_k], emb[key_start:key_start + d_k], value
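# Example of the assumed flat layout for num_heads = 2, d_k = 3, and a
# 22-dimensional embedding (value_size = 5 per head):
#   emb = [ q0:emb[0:3] | q1:emb[3:6] | k0:emb[6:9] | k1:emb[9:12]
#           | v0:emb[12:17] | v1:emb[17:22] ]
# so get_query_key_value(emb, 1, 2, 3) returns emb[3:6], emb[9:12], emb[17:22].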


def get_attention(embs, feats, head_num, num_heads, d_k, context_window):
    terms, query, key, value = [], [], [], []
    for f in feats:
        if embs.contains(('feat', f)):
            q, k, v = get_query_key_value(embs.get_embedding(('feat', f)),
                                          head_num, num_heads, d_k)
            terms.append(f)
            query.append(q)
            key.append(k)
            value.append(v)

    if not query:
        raise ValueError("None of the input terms were found in the embeddings")

    qs = np.vstack(query)
    keys = np.vstack(key)
    values = np.vstack(value)
    rows = [[] for _ in range(len(qs))]
    for i in range(len(qs)):
        if context_window is None:
            start, stop = 0, len(qs)
        else:
            start, stop = max(i - context_window, 0), min(i + 1 + context_window, len(qs))
        for j in range(len(qs)):
            if start <= j < stop:
                rows[i].append(qs[i].dot(keys[j]))
            else:
                # Mask positions outside the context window with -inf so they
                # receive zero weight after the softmax.
                rows[i].append(-np.inf)
    attention = np.array(rows)

    # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)).
    sm = softmax(attention / np.sqrt(qs[0].shape[0]))
    return terms, sm, (values * sm.sum(axis=0).reshape((-1, 1))).mean(axis=0)
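
# A vectorized sketch of the score loop above (not used by the CLI; included
# for illustration). Masked positions are set to -inf so that they vanish
# after the softmax, matching the masking in get_attention.
def attention_scores(qs, keys, context_window):
    scores = qs @ keys.T
    if context_window is not None:
        i, j = np.indices(scores.shape)
        scores[np.abs(i - j) > context_window] = -np.inf
    return scores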


# Cosine similarity between two vectors (not used by the CLI below).
def cosine(e1, e2):
    return e1.dot(e2) / (e1.dot(e1) ** 0.5 * e2.dot(e2) ** 0.5)
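# e.g. cosine(np.array([1., 0.]), np.array([2., 0.])) == 1.0, while two
# orthogonal vectors score 0.0.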


def format_row(row):
    return [round(v, 3) for v in row]


def parse_embedder(fname):
    with open(fname) as f:
        etype = f.readline().strip()
        if etype != 'Attention':
            raise TypeError("Embedder type is not Attention!")
        num_heads = int(f.readline().strip())
        d_k = int(f.readline().strip())
        window = int(f.readline().strip())

    # A window of 0 means full attention (no sliding window).
    if window == 0:
        window = None

    return num_heads, d_k, window
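# The embedder spec is assumed to be a small plain-text file matching the
# reads above, e.g. (values are hypothetical):
#   Attention
#   4    <- num_heads
#   16   <- d_k
#   0    <- sliding window size; 0 means full attention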


def main(args):
    embs = cloverleaf.NodeEmbeddings.load(args.features, cloverleaf.Distance.Cosine)
    num_heads, d_k, context_window = parse_embedder(args.embedder)
    print(f"Num heads: {num_heads}, d_k: {d_k}, sliding window: {context_window}")
    while True:
        try:
            terms = input("> ").split()
        except (EOFError, KeyboardInterrupt):
            break

        actual_terms = None
        sms = []
        for head_num in range(num_heads):
            nterms, mat, embedding = get_attention(embs, terms, head_num, num_heads, d_k, context_window)
            if actual_terms is None:
                actual_terms = nterms[:]

            rows = [[term] + format_row(row) for term, row in zip(nterms, mat)]
            rows.append(tabulate.SEPARATING_LINE)

            # Total attention mass each term receives, renormalized to sum to 1.
            summed = mat.sum(axis=0)
            sm = summed / summed.sum()
            rows.append(['Softmax'] + format_row(sm))
            sms.append(sm)
            rows.append(tabulate.SEPARATING_LINE)

            idxs = np.argsort(sm)[::-1]
            rows.append(['Sorted'] + [nterms[i] for i in idxs])
            headers = ['Head {}'.format(head_num)] + actual_terms
            print(tabulate.tabulate(rows, headers=headers, tablefmt="fancy_grid"))
            print()

        # Average the per-head distributions and rank terms by descending weight.
        avg = np.mean(sms, axis=0)
        idxs = np.argsort(avg)[::-1]
        header = ['Terms'] + [actual_terms[i] for i in idxs]
        sorted_scores = ['Average'] + format_row([avg[i] for i in idxs])
        print(tabulate.tabulate([sorted_scores], headers=header, tablefmt="fancy_grid"))
        print()


def build_arg_parser():
    parser = argparse.ArgumentParser(
        description='Examine Attention Matrix',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("features",
        help="Path to feature embeddings.")

    parser.add_argument("embedder",
        help="Path to embedder spec.")

    return parser.parse_args()


if __name__ == '__main__':
    main(build_arg_parser())
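# Example session (file names are hypothetical):
#   $ python attention.py feats.embeddings attention.embedder
#   > term_a term_b term_c
# This prints one attention table per head, then the head-averaged ranking.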