# NOTE(review): two stray "Code" placeholder lines replaced — artifact of the
# copy/paste, not part of the program.
import os
import pickle
import time
from os.path import join
import torch
import torch.nn as nn
from torch.nn import functional as F
import utils
from torch.autograd import Variable
import numpy as np
from tqdm import tqdm
import random
import copy
from base_model import FindCntfImages
from itertools import cycle
# Cycle endlessly over the external image loader so it never exhausts
# relative to train_loader.
img_loader = cycle(ext_loader)

# Pull the mode-specific debiasing hyper-parameters out of args.
# (Indentation of the if/elif bodies was lost in the paste and is restored here.)
if mode == 'q_debias':
    topq = args.topq              # how many salient question words to mask
    keep_qtype = args.keep_qtype  # True: mask only question-type words
elif mode == 'v_debias':
    topv = args.topv              # how many salient visual objects to mask
    top_hint = args.top_hint      # candidate objects taken from the hint scores
elif mode == 'q_v_debias':
    topv = args.topv
    top_hint = args.top_hint
    topq = args.topq
    keep_qtype = args.keep_qtype
    qvp = args.qvp                # per-batch q-path vs v-path split threshold
t = time.time()
# NOTE(review): the batch-loop header arrived garbled — two fragments of the
# tqdm(...) call (`ncols=100,` and `total=len(train_loader)):`) were left
# uncommented, which is a syntax error. The whole header is kept commented
# out here, consistently; confirm against the original training script
# before re-enabling.
# for i, (data1, data2) in tqdm(enumerate(zip(train_loader, ext_loader)),
#                               ncols=100,
#                               desc="Epoch %d" % (epoch + 1),
#                               total=len(train_loader)):
#     v, q, a, b, hintscore, type_mask, notype_mask, q_mask = data1
#     img_batch, _ = data2
total_step += 1
#########################################
# Move the batch tensors to the GPU. v gets requires_grad_() because the
# v_debias/q_v_debias paths take autograd gradients w.r.t. the image features.
v = Variable(v).cuda().requires_grad_()
q = Variable(q).cuda()
q_mask = Variable(q_mask).cuda()
a = Variable(a).cuda()
b = Variable(b).cuda()
hintscore = Variable(hintscore).cuda()
type_mask = Variable(type_mask).float().cuda()
notype_mask = Variable(notype_mask).float().cuda()
#########################################
# Per-batch training step, dispatched on the debiasing mode.
# NOTE(review): all nesting below is reconstructed — the paste stripped every
# indent. Statements are kept in their original order; places where code is
# clearly missing are flagged rather than invented.
if mode == 'updn':
    # Plain UpDn baseline: one forward/backward/step per batch.
    pred, loss, _ = model(v, q, a, b, None)
    if (loss != loss).any():  # NaN-in-loss guard (NaN != NaN)
        raise ValueError("NaN loss")
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optim.step()
    optim.zero_grad()

elif mode == 'q_debias':
    # sen_mask marks the words eligible for masking: either only the
    # question-type words, or everything except the question-type words.
    if keep_qtype == True:
        sen_mask = type_mask
    else:
        sen_mask = notype_mask

    ## first train
    pred, loss, word_emb = model(v, q, a, b, None)
    # NOTE(review): `word_grad` is consumed below but never computed in this
    # branch — a torch.autograd.grad((pred * (a > 0).float()).sum(),
    # word_emb, create_graph=True)[0] call (present in the q_v_debias branch)
    # appears to have been dropped from this paste; restore it before running.
    if (loss != loss).any():
        raise ValueError("NaN loss")
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optim.step()
    optim.zero_grad()

    ## second train
    # Word-level saliency: sum the embedding gradient over the feature
    # dimension, then restrict it to the maskable words.
    word_grad_cam = word_grad.sum(2)
    # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
    word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
    word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
    # NOTE(review): `w_ind` (top-`topq` salient word indices; cf. the
    # q_v_debias branch) is used below but its computation is missing here.
    q2 = copy.deepcopy(q_mask)
    m1 = copy.deepcopy(sen_mask)  ## [0,0,0...0,1,1,1,1]
    m1.scatter_(1, w_ind, 0)      ## [0,0,0...0,0,1,1,0]
    m2 = 1 - m1                   ## [1,1,1...1,1,0,0,1]
    if dataset == 'cpv1':
        m3 = m1 * 18330           # presumably the cp-v1 mask-token id — TODO confirm
    else:
        m3 = m1 * 18455           ## [0,0,0...0,0,18455,18455,0]
    # Replace the selected salient words with the mask-token id.
    q2 = q2 * m2.long() + m3.long()

    ## third train
    # NOTE(review): the forward pass that recomputes `loss` on the masked
    # question q2 is missing; as written this re-backwards the stale
    # first-pass loss, which PyTorch will reject (its graph is already freed).
    if (loss != loss).any():
        raise ValueError("NaN loss")
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optim.step()
    optim.zero_grad()

elif mode == 'v_debias':
    ## first train
    pred, loss, _ = model(v, q, a, b, None)
    # Saliency of each image feature w.r.t. the ground-truth answer scores.
    visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v,
                                      create_graph=True)[0]
    if (loss != loss).any():
        raise ValueError("NaN loss")
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optim.step()
    optim.zero_grad()

    ## second train
    v_mask = torch.zeros(v.shape[0], 36).cuda()  # one slot per region proposal
    visual_grad_cam = visual_grad.sum(2)
    hint_sort, hint_ind = hintscore.sort(1, descending=True)
    v_ind = hint_ind[:, :top_hint]             # candidate objects by hint score
    v_grad = visual_grad_cam.gather(1, v_ind)  # saliency of those candidates

    if topv == -1:
        # Adaptive top-k: pick the smallest prefix of candidates whose
        # softmaxed saliency mass stays within 65%.
        v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
        v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
        v_grad_sum = torch.cumsum(v_grad_score, dim=1)
        v_grad_mask = (v_grad_sum <= 0.65).long()
        v_grad_mask[:, 0] = 1  # always mask at least one object
        v_mask_ind = v_grad_mask * v_ind
        for x in range(a.shape[0]):
            num = len(torch.nonzero(v_grad_mask[x]))
            v_mask[x].scatter_(0, v_mask_ind[x, :num], 1)
    else:
        # Fixed top-k: mask exactly the topv most salient candidates.
        v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
        v_star = v_ind.gather(1, v_grad_ind)
        v_mask.scatter_(1, v_star, 1)
    v_mask = 1 - v_mask  # keep every object except the salient ones

    # NOTE(review): the forward pass that recomputes `loss` with v_mask
    # applied is missing; see the note in the q_debias branch.
    if (loss != loss).any():
        raise ValueError("NaN loss")
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optim.step()
    optim.zero_grad()

elif mode == 'q_v_debias':
    # Randomly choose, per batch, between the question-debias path
    # (random_num <= qvp) and the visual-debias path.
    random_num = random.randint(1, 10)
    if keep_qtype == True:
        sen_mask = type_mask
    else:
        sen_mask = notype_mask

    if random_num <= qvp:
        ## first train
        pred, loss, word_emb = model(v, q, a, b, None)
        word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(),
                                        word_emb, create_graph=True)[0]
        if (loss != loss).any():
            raise ValueError("NaN loss")
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optim.step()
        optim.zero_grad()

        ## second train
        word_grad_cam = word_grad.sum(2)
        # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
        word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
        word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
        w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]
        q2 = copy.deepcopy(q_mask)
        m1 = copy.deepcopy(sen_mask)  ## [0,0,0...0,1,1,1,1]
        m1.scatter_(1, w_ind, 0)      ## [0,0,0...0,0,1,1,0]
        m2 = 1 - m1                   ## [1,1,1...1,1,0,0,1]
        if dataset == 'cpv1':
            m3 = m1 * 18330
        else:
            m3 = m1 * 18455           ## [0,0,0...0,0,18455,18455,0]
        q2 = q2 * m2.long() + m3.long()

        ## third train
        # NOTE(review): the forward pass on q2 that recomputes `loss` is
        # missing here (see q_debias branch).
        if (loss != loss).any():
            raise ValueError("NaN loss")
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optim.step()
        optim.zero_grad()

        total_loss += loss.item() * q.size(0)
        # NOTE(review): duplicated optimisation block with no intervening
        # forward pass — likely a paste artifact; confirm against the
        # original script.
        if (loss != loss).any():
            raise ValueError("NaN loss")
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optim.step()
        optim.zero_grad()
    else:
        # NOTE(review): this `else:` header is inferred — the line was
        # missing from the paste, but everything below mirrors the
        # visual-debias path of the v_debias branch. The first forward pass
        # and the setup of v_mask / visual_grad_cam / v_ind / v_grad are
        # also missing here; restore them before running.
        if topv == -1:
            v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
            # Build an importance distribution over the candidate objects.
            v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
            # Cumulative importance of the top-ranked objects.
            v_grad_sum = torch.cumsum(v_grad_score, dim=1)
            # Select the objects within 65% of the saliency mass.
            v_grad_mask = (v_grad_sum <= 0.65).long()
            # Guard: never end up with zero objects to mask.
            v_grad_mask[:, 0] = 1
            # Indices of the important objects.
            v_mask_ind = v_grad_mask * v_ind
            for x in range(a.shape[0]):
                # How many objects were selected for this sample.
                num = len(torch.nonzero(v_grad_mask[x]))
                # Set those num indices to 1 (mask them).
                v_mask[x].scatter_(0, v_mask_ind[x, :num], 1)
        else:
            v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
            v_star = v_ind.gather(1, v_grad_ind)
            v_mask.scatter_(1, v_star, 1)

        # NOTE(review): another optimisation block with no preceding forward
        # pass — paste artifact, confirm.
        if (loss != loss).any():
            raise ValueError("NaN loss")
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optim.step()
        optim.zero_grad()

        v_mask = 1 - v_mask  # invert back: keep the non-salient objects

        # Draw one full external image batch to pair with the masked batch.
        while True:
            img, _ = next(img_loader)
            if img.size(0) != 512:
                # NOTE(review): this condition looks inverted — reshaping to
                # (512, 36, 2048) only succeeds for a 512-sample batch, so
                # the test was presumably `== 512`; confirm. `device` is
                # defined outside this chunk.
                img = img.reshape(512, 36, 2048).to(device)
                break

        # NOTE(review): yet another optimisation block with no preceding
        # forward pass — paste artifact, confirm.
        if (loss != loss).any():
            raise ValueError("NaN loss")
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optim.step()
        optim.zero_grad()

        ## last train
        # Synthesize counterfactual images from the masked batch.
        # NOTE(review): `v_new` is not defined anywhere in this chunk —
        # verify its origin against the original file.
        counterfactual_generator = FindCntfImages(model,
                                                  v,
                                                  q,
                                                  a,
                                                  b,
                                                  pred,
                                                  v_mask,
                                                  visual_grad_cam,
                                                  v_new)
        counterfacter_img, _ = counterfactual_generator()
        if (loss != loss).any():
            raise ValueError("NaN loss")
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optim.step()
        optim.zero_grad()
# Epoch-level bookkeeping. The debias modes accumulate loss from two
# optimisation passes per batch, hence the factor of 2 in the denominator.
if mode == 'updn':
    total_loss /= len(train_loader.dataset)
else:
    total_loss /= len(train_loader.dataset) * 2
train_score = 100 * train_score / len(train_loader.dataset)
# Optional end-of-epoch evaluation; the model is switched to eval mode for
# the duration of evaluate() and back to train mode afterwards.
if run_eval:
    model.train(False)
    results = evaluate(model, eval_loader, qid2type)
    results["epoch"] = epoch + 1
    results["step"] = total_step
    results["train_loss"] = total_loss
    results["train_score"] = train_score
    model.train(True)
    eval_score = results["score"]
    bound = results["upper_bound"]
    yn = results['score_yesno']
    other = results['score_other']
    num = results['score_number']
# NOTE(review): this tail looks like the end of a *different* function
# (`evaluate`, given the score/upper_bound names) fused into this chunk by
# the paste: `score`, `upper_bound`, `score_yesno`, `score_other` and
# `score_number` are undefined here, and the `return` has no enclosing
# function in view. Confirm against the original file before running.
results = dict(
score=score,
upper_bound=upper_bound,
score_yesno=score_yesno,
score_other=score_other,
score_number=score_number,
)
return results