# babi_runner_pytorch.py
# Forked from vinhkhuc/MemN2N-babi-python.
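#
# Example invocations, based on the command-line flags defined below (the
# bAbI data is assumed to be in the default data/tasks_1-20_v1-2/en directory):
#
#   python babi_runner_pytorch.py -t 1    # train and test on task 1
#   python babi_runner_pytorch.py -a      # train and test all 20 tasks, one by one
#   python babi_runner_pytorch.py -j      # jointly train on all 20 tasks, then test each
#   python babi_runner_pytorch.py -s 1    # test a previously saved model on task 1
#   python babi_runner_pytorch.py -k      # test a previously saved model on all tasks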

import glob
import os
import random
import sys
import pickle
import gzip
import argparse

import numpy as np

from config import BabiConfig, BabiConfigJointPytorch
from train_test_pytorch import train, train_linear_start, test
from util_pytorch import build_model_pytorch
from util import parse_babi_task, DataType
from demo_pytorch.qa_pytorch import MemN2N

# Seed Python's and NumPy's RNGs for reproducibility
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
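
# Note: only Python's and NumPy's RNGs are seeded above. If the training code
# in train_test_pytorch draws from PyTorch's own RNG, full reproducibility
# would also need something like the following (an assumption about this
# codebase, not confirmed by it):
#
#     import torch
#     torch.manual_seed(seed_val)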


def run_test(data_dir, task_id, memn2n):
    """
    Test a pre-trained model on a single task.
    """
    print("Testing for task %d ..." % task_id)
    if type(data_dir) is tuple:
        # Separate train/test directories were given: evaluate on the
        # *_valid.txt files in the test directory.
        test_files = glob.glob('%s/qa%d_*_valid.txt' % (data_dir[1], task_id))
    else:
        test_files = glob.glob('%s/qa%d_*_test.txt' % (data_dir, task_id))
    test_story, test_questions, test_qstory = \
        parse_babi_task(test_files, memn2n.general_config.dictionary, False, dt=DataType.PYTORCH)

    # memn2n.load_model() has already restored the memory, model, loss, and
    # general_config from the pickled model file, roughly:
    #     with gzip.open(model_file, "rb") as f:
    #         self.reversed_dict, self.memory, self.model, \
    #             self.loss, self.general_config = pickle.load(f)
    test(test_story, test_questions, test_qstory,
         memn2n.memory, memn2n.model, memn2n.loss, memn2n.general_config)


def run_all_tests(data_dir, memn2n):
    """
    Test a pre-trained model on all 20 tasks, one by one.
    """
    print("Testing for all tasks ...")
    for t in range(1, 21):
        run_test(data_dir, t, memn2n)


def run_task(data_dir, task_id):
    """
    Train and test for a single task.
    """
    print("Train and test for task %d ..." % task_id)

    # Parse data; the shared dictionary maps each token to an integer id.
    train_files = glob.glob('%s/qa%d_*_train.txt' % (data_dir, task_id))
    test_files = glob.glob('%s/qa%d_*_test.txt' % (data_dir, task_id))

    dictionary = {"nil": 0}
    train_story, train_questions, train_qstory = parse_babi_task(train_files, dictionary, False)
    test_story, test_questions, test_qstory = parse_babi_task(test_files, dictionary, False)

    general_config = BabiConfigJointPytorch(train_story, train_questions, dictionary)
    memory, model, loss = build_model_pytorch(general_config)

    if general_config.linear_start:
        # Linear start: begin training with the softmax layers removed,
        # then re-insert them (see Sukhbaatar et al., 2015).
        train_linear_start(train_story, train_questions, train_qstory, memory, model, loss, general_config)
    else:
        train(train_story, train_questions, train_qstory, memory, model, loss, general_config)

    test(test_story, test_questions, test_qstory, memory, model, loss, general_config)
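
# How the shared `dictionary` works (a sketch inferred from its use here and
# from the assert in run_joint_tasks below, not copied from util.py):
# parse_babi_task extends it in place, giving each unseen token the next
# integer id, with "nil" reserved as the padding id 0, roughly:
#
#     for token in tokens:
#         if token not in dictionary:
#             dictionary[token] = len(dictionary)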


def run_all_tasks(data_dir):
    """
    Train and test for all tasks
    """
    print("Training and testing for all tasks ...")
    for t in range(20):
        run_task(data_dir, task_id=t + 1)


def run_joint_tasks(data_dir):
    """
    Train and test for all tasks, with a single model trained jointly on training data from all tasks.
    """
    print("Jointly train and test for all tasks ...")
    tasks = range(20)

    # Parse training data
    train_data_path = []
    for t in tasks:
        train_data_path += glob.glob('%s/qa%d_*_train.txt' % (data_dir, t + 1))
    dictionary = {"nil": 0}
    train_story, train_questions, train_qstory = parse_babi_task(train_data_path, dictionary, False, dt=DataType.PYTORCH)

    # Parse test data for each task so that the dictionary covers all words before training
    for t in tasks:
        test_data_path = glob.glob('%s/qa%d_*_test.txt' % (data_dir, t + 1))
        parse_babi_task(test_data_path, dictionary, False)  # ignore output for now

    general_config = BabiConfigJointPytorch(train_story, train_questions, dictionary)
    memory, model, loss = build_model_pytorch(general_config)

    if general_config.linear_start:
        train_linear_start(train_story, train_questions, train_qstory, memory, model, loss, general_config)
    else:
        train(train_story, train_questions, train_qstory, memory, model, loss, general_config)

    # Test on each task
    for t in tasks:
        print("Testing for task %d ..." % (t + 1))
        test_data_path = glob.glob('%s/qa%d_*_test.txt' % (data_dir, t + 1))

        dc = len(dictionary)
        test_story, test_questions, test_qstory = parse_babi_task(test_data_path, dictionary, False)
        assert dc == len(dictionary)  # make sure the dictionary already covers all words

        test(test_story, test_questions, test_qstory, memory, model, loss, general_config)
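
# The assert above holds because every task's test file was already parsed
# once before build_model_pytorch() was called, so the vocabulary size the
# model was built with presumably covers all test words; an unseen word at
# test time would have no learned representation.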


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data-dir", default="data/tasks_1-20_v1-2/en",
                        help="path to dataset directory (default: %(default)s)")
    parser.add_argument("-m", "--model-file", default="trained_model_pytorch/memn2n_model_pytorch.pklz",
                        help="model file (default: %(default)s)")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-t", "--task", default=0, type=int,
                       help="train and test for a single task (default: %(default)s)")
    group.add_argument("-a", "--all-tasks", action="store_true",
                       help="train and test for all tasks (one by one) (default: %(default)s)")
    group.add_argument("-j", "--joint-tasks", action="store_true",
                       help="train and test for all tasks (all together) (default: %(default)s)")
    group.add_argument("-s", "--test", default=0, type=int,
                       help="test for a single task (default: %(default)s)")
    group.add_argument("-k", "--all-tests", action="store_true",
                       help="test for all tasks (one by one) (default: %(default)s)")
    parser.add_argument("-d2", "--data-dir2", default=None,
                        help="path to a directory containing 'train' and 'test' subdirectories")
    args = parser.parse_args()

    # Check if data is available
    data_dir = args.data_dir
    if not os.path.exists(data_dir):
        print("The data directory '%s' does not exist. Please download it first." % data_dir)
        sys.exit(1)

    if args.data_dir2 is not None:
        if not os.path.exists(args.data_dir2):
            print("The data directory '%s' does not exist." % args.data_dir2)
            sys.exit(1)
        else:
            train_path = os.path.join(args.data_dir2, 'train')
            if not os.path.exists(train_path):
                print("'%s' does not exist." % train_path)
                sys.exit(1)
            test_path = os.path.join(args.data_dir2, 'test')
            if not os.path.exists(test_path):
                print("'%s' does not exist." % test_path)
                sys.exit(1)
            args.data_dir = train_path, test_path

    if type(args.data_dir) is tuple:
        print("Using data from {} and {}".format(args.data_dir[0], args.data_dir[1]))
    else:
        print("Using data from %s" % args.data_dir)

    if args.test or args.all_tests:
        m = MemN2N(args.data_dir, args.model_file)
        m.load_model()
        if args.all_tests:
            # Pass args.data_dir (not the local data_dir), since it may have
            # been replaced by the (train_path, test_path) tuple above.
            run_all_tests(args.data_dir, m)
        else:
            # Use args.test here: with -s, args.task keeps its default of 0.
            run_test(args.data_dir, args.test, m)
    else:
        if args.all_tasks:
            run_all_tasks(data_dir)
        elif args.joint_tasks:
            run_joint_tasks(data_dir)
        else:
            run_task(data_dir, task_id=args.task)
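
# Example layout for --data-dir2 (hypothetical path names, following the bAbI
# file-naming convention assumed by the glob patterns above):
#
#   my_babi_data/
#       train/qa1_single-supporting-fact_train.txt
#       test/qa1_single-supporting-fact_valid.txt
#
# When -d2 is given, args.data_dir becomes the (train_path, test_path) tuple,
# and run_test() globs the *_valid.txt files from the test subdirectory.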