{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# First step is to import the needed libraries\n", "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import random\n", "import numpy as np\n", "import pickle\n", "from tqdm import tqdm\n", "%matplotlib inline\n", "from sklearn.metrics import f1_score,accuracy_score\n", "import math\n", "import re\n", "from torch.utils.data import Dataset, DataLoader" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "You are using cuda. Good!\n" ] } ], "source": [ "# in this section we define static values and variables for ease of access and testing\n", "_fn=\"final\" # file unique id for saving and loading models\n", "bert_base='./bert-base-uncased/'\n", "bert_large='./bert-large-uncased/'\n", "\n", "snips_train=\"./dataset/snips_train.iob\"\n", "snips_test=\"./dataset/snips_test.iob\"\n", "atis_train=\"./dataset/atis.train.w-intent.iob\"\n", "atis_test=\"./dataset/atis.test.w-intent.iob\"\n", "#ENV variables directly affect the model's behaviour\n", "ENV_DATASET_TRAIN=atis_train\n", "ENV_DATASET_TEST=atis_test\n", "\n", "ENV_BERT_ID_CLS=False # use cls token for id classification\n", "ENV_EMBEDDING_SIZE=768# dimention of embbeding, bertbase=768,bertlarge&elmo=1024\n", "ENV_BERT_ADDR=bert_base\n", "ENV_SEED=1331\n", "ENV_CNN_FILTERS=128\n", "ENV_CNN_KERNELS=4\n", "ENV_HIDDEN_SIZE=ENV_CNN_FILTERS*ENV_CNN_KERNELS\n", "\n", "#these are related to training\n", "BATCH_SIZE=16\n", "LENGTH=60\n", "STEP_SIZE=50\n", "\n", "# you must use cuda to run this code. if this returns false, you can not proceed.\n", "USE_CUDA = torch.cuda.is_available()\n", "if USE_CUDA:\n", " print(\"You are using cuda. Good!\")\n", "else:\n", " print('You are NOT using cuda! 
Some problems may occur.')\n", "\n", "torch.manual_seed(ENV_SEED)\n", "random.seed(ENV_SEED)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "implement dataloader" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "\n", "#this function converts tokens to ids and then to a tensor\n", "def prepare_sequence(seq, to_ix):\n", " idxs = list(map(lambda w: to_ix[w] if w in to_ix.keys() else to_ix[\"\"], seq))\n", " tensor = Variable(torch.LongTensor(idxs)).cuda() if USE_CUDA else Variable(torch.LongTensor(idxs))\n", " return tensor\n", "# this function turns class text to id\n", "def prepare_intent(intent, to_ix):\n", " idxs = to_ix[intent] if intent in to_ix.keys() else to_ix[\"UNKNOWN\"]\n", " return idxs\n", "# converts numbers to TAG\n", "def number_to_tag(txt):\n", " return \"\" if txt.isdecimal() else txt\n", "\n", "# Here we remove multiple spaces and punctuation which cause errors in tokenization for bert & elmo.\n", "def remove_punc(mlist):\n", " mlist = [re.sub(\" +\",\" \",t.split(\"\\t\")[0][4:-4]) for t in mlist] # remove spaces down to 1\n", " temp_train_tokens = []\n", " # punct remove example: play samuel-el jackson from 2009 - 2010 > play samuelel jackson from 2009 - 2010\n", " for row in mlist:\n", " tokens = row.split(\" \")\n", " newtokens = []\n", " for token in tokens:\n", " newtoken = re.sub(r\"[.,'\\\"\\\\/\\-:&’—=–官方杂志¡…“”~%]\",r\"\",token) # remove punc\n", " newtoken = re.sub(r\"[楽園追放�]\",r\"A\",newtoken)\n", " newtokens.append(newtoken if len(token)>1 else token)\n", " if newtokens[-1]==\"\":\n", " newtokens.pop(-1)\n", " if newtokens[0]==\"\":\n", " newtokens.pop(0)\n", " temp_train_tokens.append(\" \".join(newtokens))\n", " return temp_train_tokens\n", "# this function returns the main tokens so that we can apply tagging on them. 
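{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "A minimal data-preparation sketch standing in for the original loading cells, not the original implementation. It assumes the `w-intent` IOB layout that `remove_punc` above expects (`BOS w1 ... wn EOS\\tO t1 ... tn intent`), and that the local checkpoint folder exposes a `tokenizer` entry point through `torch.hub`, mirroring how `BertLayer` loads the model below. The helper name `parse_iob` is ours." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Minimal data-preparation sketch. Assumptions: w-intent IOB files as parsed by\n", "# remove_punc above, and a local torch.hub checkpoint exposing a 'tokenizer' entry point.\n", "train_lines = [l.strip() for l in open(ENV_DATASET_TRAIN, encoding=\"utf-8\")]\n", "test_lines = [l.strip() for l in open(ENV_DATASET_TEST, encoding=\"utf-8\")]\n", "train_tokens = remove_punc(train_lines) # cleaned sentence strings for BERT\n", "test_tokens = remove_punc(test_lines)\n", "\n", "def parse_iob(lines, cleaned): # helper name is ours, not from the original pipeline\n", "    seq_in, seq_out, intents = [], [], []\n", "    for raw, sent in zip(lines, cleaned):\n", "        tags = raw.split(\"\\t\")[1].split(\" \")\n", "        seq_in.append([number_to_tag(w) for w in sent.split(\" \")])\n", "        seq_out.append(tags[1:-1]) # drop the BOS-aligned tag; the last item is the intent\n", "        intents.append(tags[-1])\n", "    return seq_in, seq_out, intents\n", "\n", "seq_in, seq_out, intent = parse_iob(train_lines, train_tokens)\n", "seq_in_test, seq_out_test, intent_test = parse_iob(test_lines, test_tokens)\n", "vocab = set(w for s in seq_in for w in s)\n", "slot_tag = set(t for s in seq_out for t in s)\n", "intent_tag = set(intent)\n", "# raw triples kept for inspection; add_paddings below pads the inner lists in place\n", "train = list(zip(seq_in, seq_out, intent))\n", "test = list(zip(seq_in_test, seq_out_test, intent_test))\n", "# subword tokenization (mirrors how BertLayer loads the model from the same path)\n", "bert_tokenizer = torch.hub.load(ENV_BERT_ADDR, 'tokenizer', ENV_BERT_ADDR, source=\"local\")\n", "train_toks = bert_tokenizer(train_tokens, padding=\"max_length\", truncation=True, max_length=LENGTH, return_tensors=\"pt\")\n", "test_toks = bert_tokenizer(test_tokens, padding=\"max_length\", truncation=True, max_length=LENGTH, return_tensors=\"pt\")\n", "train_subtoken_mask = get_subtoken_mask(train_tokens,bert_tokenizer)\n", "test_subtoken_mask = get_subtoken_mask(test_tokens,bert_tokenizer)" ] }, 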
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# pads every input and output token sequence to the fixed LENGTH\n", "def add_paddings(seq_in,seq_out):\n", "    sin=[]\n", "    sout=[]\n", "    for i in range(len(seq_in)):\n", "        temp = seq_in[i]\n", "        if len(temp)<LENGTH:\n", "            while len(temp)<LENGTH:\n", "                temp.append('<PAD>')\n", "        else:\n", "            temp = temp[:LENGTH]\n", "        sin.append(temp)\n", "        # add padding inside output tokens\n", "        temp = seq_out[i]\n", "        if len(temp)<LENGTH:\n", "            while len(temp)<LENGTH:\n", "                temp.append('<PAD>')\n", "        else:\n", "            temp = temp[:LENGTH]\n", "        sout.append(temp)\n", "    return sin,sout\n", "sin,sout=add_paddings(seq_in,seq_out)\n", "sin_test,sout_test=add_paddings(seq_in_test,seq_out_test)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# making the token dictionary (token:id), with the special tokens first\n", "word2index = {'<PAD>': 0, '<UNK>':1,'<BOS>':2,'<EOS>':3,'<NUM>':4}\n", "# add the rest of the token list to the dictionary\n", "for token in vocab:\n", "    if token not in word2index.keys():\n", "        word2index[token]=len(word2index)\n", "# make the reverse id-to-token mapping\n", "index2word = {v:k for k,v in word2index.items()}\n", "\n", "# initial tag2index dictionary\n", "tag2index = {'<PAD>' : 0,'<BOS>':2,'<UNK>':1,'<EOS>':3}\n", "# add the rest of the tag tokens to the dictionary\n", "for tag in slot_tag:\n", "    if tag not in tag2index.keys():\n", "        tag2index[tag] = len(tag2index)\n", "# making index to tag\n", "index2tag = {v:k for k,v in tag2index.items()}\n", "\n", "# initialize intent to index\n", "intent2index={'UNKNOWN':0}\n", "for ii in intent_tag:\n", "    if ii not in intent2index.keys():\n", "        intent2index[ii] = len(intent2index)\n", "index2intent = {v:k for k,v in intent2index.items()}" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Building Datasets and DataLoaders" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# defining datasets.\n", "def remove_values_from_list(the_list, val):\n", "    return [value for value in the_list if value != val]\n", "\n", "class NLUDataset(Dataset):\n", "    def __init__(self, sin,sout,intent,input_ids,attention_mask,token_type_ids,subtoken_mask):\n", "        self.sin = [prepare_sequence(temp,word2index) for temp in sin]\n", "        self.sout = [prepare_sequence(temp,tag2index) for temp in sout]\n", "        self.intent = Variable(torch.LongTensor([prepare_intent(temp,intent2index) for temp in intent])).cuda()\n", "        self.input_ids=input_ids.cuda()\n", "        self.attention_mask=attention_mask.cuda()\n", "        self.token_type_ids=token_type_ids.cuda()\n", "        self.subtoken_mask=subtoken_mask.cuda()\n", "        self.x_mask = [Variable(torch.BoolTensor(tuple(map(lambda s: s ==0, t )))).cuda() for t in self.sin]\n", "    def __len__(self):\n", "        return len(self.intent)\n", "    def __getitem__(self, idx):\n", "        sample = self.sin[idx],self.sout[idx],self.intent[idx],self.input_ids[idx],self.attention_mask[idx],self.token_type_ids[idx],self.subtoken_mask[idx],self.x_mask[idx]\n", "        return sample\n", "# making a single list\n", "train_data=NLUDataset(sin,sout,intent,train_toks['input_ids'],train_toks['attention_mask'],train_toks['token_type_ids'],train_subtoken_mask)\n", "test_data=NLUDataset(sin_test,sout_test,intent_test,test_toks['input_ids'],test_toks['attention_mask'],test_toks['token_type_ids'],test_subtoken_mask)\n", "train_data = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)\n", "test_data = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)" ] }, 
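{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "A quick sanity check (illustrative, not part of the original pipeline): draw one batch from the `train_data` loader and confirm the tensor shapes the model below expects." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# illustrative only: inspect one batch produced by the DataLoader\n", "x, tag_target, intent_target, input_ids, attention_mask, token_type_ids, subtoken_mask, x_mask = next(iter(train_data))\n", "print(x.shape)             # (BATCH_SIZE, LENGTH) token ids\n", "print(tag_target.shape)    # (BATCH_SIZE, LENGTH) slot-tag ids\n", "print(intent_target.shape) # (BATCH_SIZE,) intent ids\n", "print(input_ids.shape)     # (BATCH_SIZE, LENGTH) BERT subword ids" ] }, 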
"cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# we put all tags inside of the batch in a flat array for F1 measure.\n", "# we use masking so that we only non PAD tokens are counted in f1 measurement\n", "def mask_important_tags(predictions,tags,masks):\n", " result_tags=[]\n", " result_preds=[]\n", " for pred,tag,mask in zip(predictions.tolist(),tags.tolist(),masks.tolist()):\n", " #index [0] is to get the data\n", " for p,t,m in zip(pred,tag,mask):\n", " if not m:\n", " result_tags.append(p)\n", " result_preds.append(t)\n", " #result_tags.pop()\n", " #result_preds.pop()\n", " return result_preds,result_tags\n" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Modeling" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# generates transformer mask\n", "def generate_square_subsequent_mask(sz: int) :\n", " \"\"\"Generates an upper-triangular matrix of -inf, with zeros on diag.\"\"\"\n", " return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)\n", "def generate_square_diagonal_mask(sz: int) :\n", " \"\"\"Generates a matrix which there are zeros on diag and other indexes are -inf.\"\"\"\n", " return torch.triu(torch.ones(sz,sz)-float('inf'), diagonal=1)+torch.tril(torch.ones(sz,sz)-float('inf'), diagonal=-1)\n", "# positional embedding used in transformers\n", "class PositionalEncoding(nn.Module):\n", "\n", " def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n", " super().__init__()\n", " self.dropout = nn.Dropout(p=dropout)\n", "\n", " position = torch.arange(max_len).unsqueeze(1)\n", " div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n", " pe = torch.zeros(max_len, 1, d_model)\n", " pe[:, 0, 0::2] = torch.sin(position * div_term)\n", " pe[:, 0, 1::2] = torch.cos(position * div_term)\n", " self.register_buffer('pe', pe)\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " Args:\n", " x: Tensor, shape [seq_len, batch_size, embedding_dim]\n", " \"\"\"\n", " x = x + self.pe[:x.size(0)]\n", " return self.dropout(x)\n", "\n", "\n", "#start of the shared encoder\n", "class BertLayer(nn.Module):\n", " def __init__(self):\n", " super(BertLayer, self).__init__()\n", " self.bert_model = torch.hub.load(ENV_BERT_ADDR, 'model', ENV_BERT_ADDR,source=\"local\")\n", "\n", " def forward(self, bert_info=None):\n", " (bert_tokens, bert_mask, bert_tok_typeid) = bert_info\n", " bert_encodings = self.bert_model(bert_tokens, bert_mask, bert_tok_typeid)\n", " bert_last_hidden = bert_encodings['last_hidden_state']\n", " bert_pooler_output = bert_encodings['pooler_output']\n", " return bert_last_hidden, bert_pooler_output\n", "\n", "\n", "class Encoder(nn.Module):\n", " def __init__(self, p_dropout=0.5):\n", " super(Encoder, self).__init__()\n", " self.filter_number = ENV_CNN_FILTERS\n", " self.kernel_number = ENV_CNN_KERNELS # tedad size haye filter : 2,3,5 = 3\n", " self.embedding_size = ENV_EMBEDDING_SIZE\n", " self.activation = nn.ReLU()\n", " self.p_dropout = p_dropout\n", " self.softmax = nn.Softmax(dim=1)\n", " self.conv1 = nn.Conv1d(in_channels=self.embedding_size, out_channels=self.filter_number, kernel_size=(2,),\n", " padding=\"same\", padding_mode=\"zeros\")\n", " self.conv2 = nn.Conv1d(in_channels=self.embedding_size, out_channels=self.filter_number, kernel_size=(3,),\n", " padding=\"same\", 
{ "cell_type": "code", "execution_count": 12, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Middle: a transformer encoder applied on top of the CNN features\n", "class Middle(nn.Module):\n", "    def __init__(self ,p_dropout=0.5):\n", "        super(Middle, self).__init__()\n", "        self.activation = nn.ReLU()\n", "        self.p_dropout = p_dropout\n", "        self.softmax = nn.Softmax(dim=1)\n", "        # Transformer\n", "        nlayers = 2 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n", "        self.pos_encoder = PositionalEncoding(ENV_HIDDEN_SIZE, dropout=0.1)\n", "        encoder_layers = nn.TransformerEncoderLayer(ENV_HIDDEN_SIZE, nhead=2,batch_first=True, dim_feedforward=2048 ,activation=\"relu\", dropout=0.1)\n", "        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers,enable_nested_tensor=False)\n", "        self.transformer_mask = generate_square_subsequent_mask(LENGTH).cuda()\n", "\n", "    def forward(self, fromencoder,input_masking):\n", "        src = fromencoder * math.sqrt(ENV_HIDDEN_SIZE)\n", "        src = self.pos_encoder(src)\n", "        output = self.transformer_encoder(src,src_key_padding_mask=input_masking) # contextualized token representations\n", "        return output" ] }, 
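{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Before defining the decoder, it helps to look at the two attention masks it consumes: the causal `tgt_mask` (step *i* may attend to steps up to *i*) and the diagonal `memory_mask` (step *i* may attend only to encoder position *i*, which keeps slot decoding aligned with the input tokens). An illustrative print-out:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# illustrative: the two attention masks used by the decoder below\n", "print(generate_square_subsequent_mask(4)) # causal mask\n", "print(generate_square_diagonal_mask(4))   # aligned (diagonal) mask" ] }, 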
{ "cell_type": "code", "execution_count": 13, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# start of the decoder\n", "class Decoder(nn.Module):\n", "\n", "    def __init__(self,slot_size,intent_size,dropout_p=0.5):\n", "        super(Decoder, self).__init__()\n", "        self.slot_size = slot_size\n", "        self.intent_size = intent_size\n", "        self.dropout_p = dropout_p\n", "        self.softmax = nn.Softmax(dim=1)\n", "        # Define the layers\n", "        self.embedding = nn.Embedding(self.slot_size, ENV_HIDDEN_SIZE)\n", "        self.activation = nn.ReLU()\n", "        self.dropout1 = nn.Dropout(self.dropout_p)\n", "        self.dropout2 = nn.Dropout(self.dropout_p)\n", "        self.dropout3 = nn.Dropout(self.dropout_p)\n", "        self.slot_trans = nn.Linear(ENV_HIDDEN_SIZE, self.slot_size)\n", "        self.intent_out = nn.Linear(ENV_HIDDEN_SIZE,self.intent_size)\n", "        self.intent_out_cls = nn.Linear(ENV_EMBEDDING_SIZE,self.intent_size) # dim of bert\n", "        self.decoder_layer = nn.TransformerDecoderLayer(d_model=ENV_HIDDEN_SIZE, nhead=2,batch_first=True,dim_feedforward=300 ,activation=\"relu\")\n", "        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=2)\n", "        self.transformer_mask = generate_square_subsequent_mask(LENGTH).cuda()\n", "        self.transformer_diagonal_mask = generate_square_diagonal_mask(LENGTH).cuda()\n", "        self.pos_encoder = PositionalEncoding(ENV_HIDDEN_SIZE, dropout=0.1)\n", "        self.self_attention = nn.MultiheadAttention(embed_dim=ENV_HIDDEN_SIZE\n", "                                                    ,num_heads=8,dropout=0.1\n", "                                                    ,batch_first=True)\n", "        self.layer_norm = nn.LayerNorm(ENV_HIDDEN_SIZE)\n", "\n", "\n", "    def forward(self, input,encoder_outputs,encoder_maskings,bert_subtoken_maskings=None,infer=False):\n", "        # encoder outputs: (BATCH, LENGTH, ENV_HIDDEN_SIZE)\n", "        batch_size = encoder_outputs.shape[0]\n", "        length = encoder_outputs.size(1) # for every token in the batch\n", "        embedded = self.embedding(input)\n", "\n", "        # intent detection: self-attention over the encoder outputs, then masked mean-pooling\n", "        encoder_outputs2=encoder_outputs\n", "        context,attn_weight = self.self_attention(encoder_outputs2,encoder_outputs2,encoder_outputs2\n", "                                                  ,key_padding_mask=encoder_maskings)\n", "        encoder_outputs2 = self.layer_norm(self.dropout2(context))+encoder_outputs2\n", "        sum_mask = (~encoder_maskings).sum(1).unsqueeze(1)\n", "        sum_encoder = (encoder_outputs2*(~encoder_maskings).unsqueeze(2)).sum(1)\n", "        intent_score = self.intent_out(self.dropout1(sum_encoder/sum_mask)) # B,D\n", "\n", "        # re-align BERT outputs so that only the first subtoken of each word remains\n", "        newtensor = torch.cuda.FloatTensor(batch_size, length,ENV_HIDDEN_SIZE).fill_(0.) # same size as the original\n", "        for i in range(batch_size): # per batch\n", "            newtensor_index=0\n", "            for j in range(length): # for each token\n", "                if bert_subtoken_maskings[i][j].item()==1:\n", "                    newtensor[i][newtensor_index] = encoder_outputs[i][j]\n", "                    newtensor_index+=1\n", "\n", "        if infer==False:\n", "            embedded=embedded*math.sqrt(ENV_HIDDEN_SIZE)\n", "            embedded = self.pos_encoder(embedded)\n", "            zol = self.transformer_decoder(tgt=embedded,memory=newtensor\n", "                                           ,memory_mask=self.transformer_diagonal_mask\n", "                                           ,tgt_mask=self.transformer_mask)\n", "\n", "            scores = self.slot_trans(self.dropout3(zol))\n", "            slot_scores = F.log_softmax(scores,dim=2)\n", "        else:\n", "            bos = Variable(torch.LongTensor([[tag2index['<BOS>']]*batch_size])).cuda().transpose(1,0)\n", "            bos = self.embedding(bos)\n", "            tokens=bos\n", "            for i in range(length):\n", "                temp_embedded=tokens*math.sqrt(ENV_HIDDEN_SIZE)\n", "                temp_embedded = self.pos_encoder(temp_embedded)\n", "                zol = self.transformer_decoder(tgt=temp_embedded,\n", "                                               memory=newtensor,\n", "                                               tgt_mask=self.transformer_mask[:i+1,:i+1],\n", "                                               memory_mask=self.transformer_diagonal_mask[:i+1,:]\n", "                                               )\n", "                scores = self.slot_trans(self.dropout3(zol))\n", "                softmaxed = F.log_softmax(scores,dim=2)\n", "                # the newly predicted tag is appended to the decoder input for the next step\n", "                _,input = torch.max(softmaxed,2)\n", "                newtok = self.embedding(input)\n", "                tokens=torch.cat((bos,newtok),dim=1)\n", "            slot_scores = softmaxed\n", "\n", "        return slot_scores.view(input.size(0)*length,-1), intent_score" ] }, 
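{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "During training the decoder is fed `<BOS>` followed by the gold tags shifted right by one position (teacher forcing); with `infer=True` it instead re-feeds its own argmax predictions one step at a time. A toy illustration of the shift used by the training loop below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# illustrative: teacher-forcing input = <BOS> followed by the gold tags shifted right\n", "gold = torch.LongTensor([[5, 6, 7, 8]])        # (B=1, 4) made-up slot-tag ids\n", "bos = torch.LongTensor([[tag2index['<BOS>']]]) # (B=1, 1)\n", "print(torch.cat((bos, gold[:, :-1]), dim=1))   # tensor([[2, 5, 6, 7]])" ] }, 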
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Training\n", "\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at ./bert-base-uncased/ were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "bert_layer = BertLayer()\n", "encoder = Encoder()\n", "middle = Middle()\n", "decoder = Decoder(len(tag2index),len(intent2index))\n", "if USE_CUDA:\n", "    encoder = encoder.cuda()\n", "    decoder = decoder.cuda()\n", "    middle = middle.cuda()\n", "    bert_layer.cuda()\n", "\n", "loss_function_1 = nn.CrossEntropyLoss(ignore_index=0)\n", "loss_function_2 = nn.CrossEntropyLoss()\n", "dec_optim = optim.AdamW(decoder.parameters(),lr=0.0001)\n", "enc_optim = optim.AdamW(encoder.parameters(),lr=0.001)\n", "ber_optim = optim.AdamW(bert_layer.parameters(),lr=0.0001)\n", "mid_optim = optim.AdamW(middle.parameters(), lr=0.0001)\n", "enc_scheduler = torch.optim.lr_scheduler.StepLR(enc_optim, 1, gamma=0.96)\n", "dec_scheduler = torch.optim.lr_scheduler.StepLR(dec_optim, 1, gamma=0.96)\n", "mid_scheduler = torch.optim.lr_scheduler.StepLR(mid_optim, 1, gamma=0.96)\n", "ber_scheduler = torch.optim.lr_scheduler.StepLR(ber_optim, 1, gamma=0.96)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "  0%|          | 0/50 [00:00<?, ?it/s]" ] } ], "source": [ "losses=[]\n", "sf_f1=[]\n", "id_precision=[]\n", "max_sf_f1=0\n", "max_id_prec=0\n", "max_sf_f1_both=0\n", "max_id_prec_both=0\n", "for step in tqdm(range(STEP_SIZE)):\n", "    losses=[]\n", "    sf_f1=[]\n", "    id_precision=[]\n", "    #### TRAIN\n", "    encoder.train()\n", "    middle.train()\n", "    decoder.train()\n", "    bert_layer.train()\n", "    for i,(x,tag_target,intent_target,bert_tokens,bert_mask,bert_toktype,subtoken_mask,x_mask) in enumerate(train_data):\n", "        batch_size=tag_target.size(0)\n", "        encoder.zero_grad()\n", "        middle.zero_grad()\n", "        decoder.zero_grad()\n", "        bert_layer.zero_grad()\n", "        bert_hidden,bert_pooler = bert_layer(bert_info=(bert_tokens,bert_mask,bert_toktype))\n", "        encoder_output = encoder(bert_last_hidden=bert_hidden)\n", "        output = middle(encoder_output,bert_mask==0)\n", "        # teacher forcing: <BOS> followed by the gold tags shifted right by one\n", "        start_decode = Variable(torch.LongTensor([[tag2index['<BOS>']]*batch_size])).cuda().transpose(1,0)\n", "        start_decode = torch.cat((start_decode,tag_target[:,:-1]),dim=1)\n", "        tag_score, intent_score = decoder(start_decode,output,bert_mask==0,bert_subtoken_maskings=subtoken_mask)\n", "        loss_1 = loss_function_1(tag_score,tag_target.view(-1))\n", "        loss_2 = loss_function_2(intent_score,intent_target)\n", "        loss = loss_1+loss_2\n", "        losses.append(loss.item())\n", "        loss.backward()\n", "        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 0.5)\n", "        torch.nn.utils.clip_grad_norm_(middle.parameters(), 0.5)\n", "        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.5)\n", "        torch.nn.utils.clip_grad_norm_(bert_layer.parameters(), 0.5)\n", "        enc_optim.step()\n", "        mid_optim.step()\n", "        dec_optim.step()\n", "        ber_optim.step()\n", "        id_precision.append(accuracy_score(intent_target.detach().cpu(),torch.argmax(intent_score,dim=1).detach().cpu()))\n", "        pred_list,target_list=mask_important_tags(torch.argmax(tag_score,dim=1).view(batch_size,LENGTH),tag_target,x_mask)\n", "        sf_f1.append(f1_score(target_list,pred_list,average=\"micro\",zero_division=0))\n", "    # print report\n", "    print(\"Step\",step,\" batches\",i,\" :\")\n", "    print(\"Train-\")\n", "    print(f\"loss:{round(float(np.mean(losses)),4)}\")\n", "    print(f\"SlotFilling F1:{round(float(np.mean(sf_f1)),3)}\")\n", "    print(f\"IntentDet Prec:{round(float(np.mean(id_precision)),3)}\")\n", "    losses=[]\n", "    sf_f1=[]\n", "    id_precision=[]\n", "\n", "    #### TEST\n", "    encoder.eval() # set to evaluation mode\n", "    middle.eval()\n", "    decoder.eval()\n", "    bert_layer.eval()\n", "    with torch.no_grad(): # turn off gradient computation\n", "        for i,(x,tag_target,intent_target,bert_tokens,bert_mask,bert_toktype,subtoken_mask,x_mask) in enumerate(test_data):\n", "            batch_size=tag_target.size(0)\n", "            encoder.zero_grad()\n", "            middle.zero_grad()\n", 
" decoder.zero_grad()\n", " bert_layer.zero_grad()\n", " bert_hidden,bert_pooler = bert_layer(bert_info=(bert_tokens,bert_mask,bert_toktype))\n", " encoder_output = encoder(bert_last_hidden=bert_hidden)\n", " output = middle(encoder_output,bert_mask==0,training=True)\n", " start_decode = Variable(torch.LongTensor([[tag2index['']]*batch_size])).cuda().transpose(1,0)\n", " tag_score, intent_score = decoder(start_decode,output,bert_mask==0,bert_subtoken_maskings=subtoken_mask,infer=True)\n", " loss_1 = loss_function_1(tag_score,tag_target.view(-1))\n", " loss_2 = loss_function_2(intent_score,intent_target)\n", " loss = loss_1+loss_2\n", " losses.append(loss.data.cpu().numpy() if USE_CUDA else loss.data.numpy()[0])\n", " id_precision.append(accuracy_score(intent_target.detach().cpu(),torch.argmax(intent_score,dim=1).detach().cpu()))\n", " pred_list,target_list=mask_important_tags(torch.argmax(tag_score,dim=1).view(batch_size,LENGTH),tag_target,x_mask)\n", " sf_f1.append(f1_score(pred_list,target_list,average=\"micro\",zero_division=0))\n", " print(\"Test-\")\n", " print(f\"loss:{round(float(np.mean(losses)),4)}\")\n", " print(f\"SlotFilling F1:{round(float(np.mean(sf_f1)),4)}\")\n", " print(f\"IntentDet Prec:{round(float(np.mean(id_precision)),4)}\")\n", " print(\"--------------\")\n", " max_sf_f1 = max_sf_f1 if round(float(np.mean(sf_f1)),4)<=max_sf_f1 else round(float(np.mean(sf_f1)),4)\n", " max_id_prec = max_id_prec if round(float(np.mean(id_precision)),4)<=max_id_prec else round(float(np.mean(id_precision)),4)\n", " if max_sf_f1_both<=round(float(np.mean(sf_f1)),4) and max_id_prec_both<=round(float(np.mean(id_precision)),4):\n", " max_sf_f1_both=round(float(np.mean(sf_f1)),4)\n", " max_id_prec_both=round(float(np.mean(id_precision)),4)\n", " torch.save(bert_layer,f\"models/ctran{_fn}-bertlayer.pkl\")\n", " torch.save(encoder,f\"models/ctran{_fn}-encoder.pkl\")\n", " torch.save(middle,f\"models/ctran{_fn}-middle.pkl\")\n", " torch.save(decoder,f\"models/ctran{_fn}-decoder.pkl\")\n", " enc_scheduler.step()\n", " dec_scheduler.step()\n", " mid_scheduler.step()\n", " ber_scheduler.step()\n", "print(f\"max single SF F1: {max_sf_f1}\")\n", "print(f\"max single ID PR: {max_id_prec}\")\n", "print(f\"max mutual SF:{max_sf_f1_both} PR: {max_id_prec_both}\")\n" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Test\n", "\n", "The following cells is for reviewing the performance of CTran." 
] }, { "cell_type": "code", "execution_count": 15, "outputs": [], "source": [], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 16, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# This cell reloads the best model during training from hard-drive.\n", "bert_layer.load_state_dict(torch.load(f'models/ctran{_fn}-bertlayer.pkl').state_dict())\n", "encoder.load_state_dict(torch.load(f'models/ctran{_fn}-encoder.pkl').state_dict())\n", "middle.load_state_dict(torch.load(f'models/ctran{_fn}-middle.pkl').state_dict())\n", "decoder.load_state_dict(torch.load(f'models/ctran{_fn}-decoder.pkl').state_dict())\n", "if USE_CUDA:\n", " bert_layer = bert_layer.cuda()\n", " encoder = encoder.cuda()\n", " middle = middle.cuda()\n", " decoder = decoder.cuda()\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 17, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "global clipindex\n", "clipindex=0\n", "def removepads(toks,clip=False):\n", " global clipindex\n", " result = toks.copy()\n", " for i,t in enumerate(toks):\n", " if t==\"\":\n", " result.remove(t)\n", " elif t==\"\":\n", " result.remove(t)\n", " if not clip:\n", " clipindex=i\n", " if clip:\n", " result=result[:clipindex]\n", " return result" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Example of model prediction on test dataset\n", "Sentence : show me first class flights from new york to miami round trip\n", "Tag Truth : O O B-class_type I-class_type O O B-fromloc.city_name I-fromloc.city_name O B-toloc.city_name B-round_trip I-round_trip\n", "Tag Prediction : O O B-class_type I-class_type O O B-fromloc.city_name I-fromloc.city_name O B-toloc.city_name B-round_trip I-round_trip\n", "Intent Truth : atis_flight\n", "Intent Prediction : atis_flight\n" ] } ], "source": [ "print(\"Example of model prediction on test dataset\")\n", "encoder.eval()\n", "middle.eval()\n", "decoder.eval()\n", "bert_layer.eval()\n", "with torch.no_grad():\n", " index = random.choice(range(len(test)))\n", " test_raw = test[index][0]\n", " bert_tokens = test_toks['input_ids'][index].unsqueeze(0).cuda()\n", " bert_mask = test_toks['attention_mask'][index].unsqueeze(0).cuda()\n", " bert_toktype = test_toks['token_type_ids'][index].unsqueeze(0).cuda()\n", " subtoken_mask = test_subtoken_mask[index].unsqueeze(0).cuda()\n", " test_in = prepare_sequence(test_raw,word2index)\n", " test_mask = Variable(torch.BoolTensor(tuple(map(lambda s: s ==0, test_in.data)))).cuda() if USE_CUDA else Variable(torch.ByteTensor(tuple(map(lambda s: s ==0, test_in.data)))).view(1,-1)\n", " start_decode = Variable(torch.LongTensor([[word2index['']]*1])).cuda().transpose(1,0) if USE_CUDA else Variable(torch.LongTensor([[word2index['']]*1])).transpose(1,0)\n", " test_raw = [removepads(test_raw)]\n", " bert_hidden,bert_pooler = bert_layer(bert_info=(bert_tokens,bert_mask,bert_toktype))\n", " encoder_output = encoder(bert_last_hidden=bert_hidden)\n", " output = middle(encoder_output,bert_mask==0)\n", " tag_score, intent_score = decoder(start_decode,output,bert_mask==0,bert_subtoken_maskings=subtoken_mask,infer=True)\n", "\n", " v,i = torch.max(tag_score,1)\n", " print(\"Sentence : \",*test_raw[0])\n", " print(\"Tag Truth : \", 
*test[index][1][:len(test_raw[0])])\n", " print(\"Tag Prediction : \",*(list(map(lambda ii:index2tag[ii],i.data.tolist()))[:len(test_raw[0])]))\n", " v,i = torch.max(intent_score,1)\n", " print(\"Intent Truth : \", test[index][2])\n", " print(\"Intent Prediction : \",index2intent[i.data.tolist()[0]])" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Instances where model predicted intent wrong\n", "Sentence : show flight and prices kansas city to chicago on next wednesday arriving in chicago by 7 pm\n", "Tag Truth : O O O O B-fromloc.city_name I-fromloc.city_name O B-toloc.city_name O B-depart_date.date_relative B-depart_date.day_name O O B-toloc.city_name B-arrive_time.time_relative B-arrive_time.time I-arrive_time.time\n", "Tag Prediction : O O O O B-fromloc.city_name I-fromloc.city_name O B-toloc.city_name O B-depart_date.date_relative B-depart_date.day_name O O B-toloc.city_name B-arrive_time.time_relative B-arrive_time.time I-arrive_time.time\n", "Intent Truth : atis_flight#atis_airfare\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : what day of the week do flights from nashville to tacoma fly on\n", "Tag Truth : O O O O O O O O B-fromloc.city_name O B-toloc.city_name O O\n", "Tag Prediction : O O O O O O O O B-fromloc.city_name O B-toloc.city_name O O\n", "Intent Truth : atis_day_name\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : what days of the week do flights from san jose to nashville fly on\n", "Tag Truth : O O O O O O O O B-fromloc.city_name I-fromloc.city_name O B-toloc.city_name O O\n", "Tag Prediction : O O O O O O O O B-fromloc.city_name I-fromloc.city_name O B-toloc.city_name O O\n", "Intent Truth : atis_day_name\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : does the airport at burbank have a flight that comes in from kansas city\n", "Tag Truth : O O O O B-toloc.city_name O O O O O O O B-fromloc.city_name I-fromloc.city_name\n", "Tag Prediction : O O O O B-fromloc.city_name O O O O O O O B-fromloc.city_name I-fromloc.city_name\n", "Intent Truth : atis_flight\n", "Intent Prediction : atis_airport\n", "--------------------------------------\n", "Sentence : show me the connecting flights between boston and denver and the types of aircraft used\n", "Tag Truth : O O O B-connect O O B-fromloc.city_name O B-toloc.city_name O O O O O O\n", "Tag Prediction : O O O B-connect O O B-fromloc.city_name O B-toloc.city_name O O O O O O\n", "Intent Truth : atis_flight\n", "Intent Prediction : atis_aircraft\n", "--------------------------------------\n", "Sentence : list the airfare for american airlines flight 19 from jfk to lax\n", "Tag Truth : O O O O B-airline_name I-airline_name O B-flight_number O B-fromloc.airport_code O B-toloc.airport_code\n", "Tag Prediction : O O O O B-airline_name I-airline_name O B-flight_number O B-fromloc.airport_code O B-toloc.city_name\n", "Intent Truth : atis_airfare#atis_flight\n", "Intent Prediction : atis_airfare\n", "--------------------------------------\n", "Sentence : i need a round trip flight from san diego to washington dc and the fares\n", "Tag Truth : O O O B-round_trip I-round_trip O O B-fromloc.city_name I-fromloc.city_name O B-toloc.city_name B-toloc.state_code O O O\n", "Tag Prediction : O O O B-round_trip I-round_trip O O B-fromloc.city_name I-fromloc.city_name O 
B-toloc.city_name B-toloc.state_code O O O\n", "Intent Truth : atis_flight#atis_airfare\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : i need a round trip from atlanta to washington dc and the fares leaving in the morning\n", "Tag Truth : O O O B-round_trip I-round_trip O B-fromloc.city_name O B-toloc.city_name B-toloc.state_code O O O O O O B-depart_time.period_of_day\n", "Tag Prediction : O O O B-round_trip I-round_trip O B-fromloc.city_name O B-toloc.city_name B-toloc.state_code O O O O O O B-depart_time.period_of_day\n", "Intent Truth : atis_flight#atis_airfare\n", "Intent Prediction : atis_airfare\n", "--------------------------------------\n", "Sentence : i need a round trip from phoenix to washington dc and the fare leaving in the morning\n", "Tag Truth : O O O B-round_trip I-round_trip O B-fromloc.city_name O B-toloc.city_name B-toloc.state_code O O O O O O B-depart_time.period_of_day\n", "Tag Prediction : O O O B-round_trip I-round_trip O B-fromloc.city_name O B-toloc.city_name B-toloc.state_code O O O O O O B-depart_time.period_of_day\n", "Intent Truth : atis_flight#atis_airfare\n", "Intent Prediction : atis_airfare\n", "--------------------------------------\n", "Sentence : i need flight and airline information for a flight from denver to salt lake city on monday departing after 5 pm\n", "Tag Truth : O O O O O O O O O O B-fromloc.city_name O B-toloc.city_name I-toloc.city_name I-toloc.city_name O B-depart_date.day_name O B-depart_time.time_relative B-depart_time.time I-depart_time.time\n", "Tag Prediction : O O O O O O O O O O B-fromloc.city_name O B-toloc.city_name I-toloc.city_name I-toloc.city_name O B-depart_date.day_name O B-depart_time.time_relative B-depart_time.time I-depart_time.time\n", "Intent Truth : atis_flight#atis_airline\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : i need flight and fare information for thursday departing prior to 9 am from oakland going to salt lake city\n", "Tag Truth : O O O O O O O B-depart_date.day_name O B-depart_time.time_relative I-depart_time.time_relative B-depart_time.time I-depart_time.time O B-fromloc.city_name O O B-toloc.city_name I-toloc.city_name I-toloc.city_name\n", "Tag Prediction : O O O O O O O B-depart_date.day_name O B-depart_time.time_relative O B-depart_time.time I-depart_time.time O B-fromloc.city_name O O B-toloc.city_name I-toloc.city_name I-toloc.city_name\n", "Intent Truth : atis_flight#atis_airfare\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : i need flight and fare information departing from oakland to salt lake city on thursday before 8 am\n", "Tag Truth : O O O O O O O O B-fromloc.city_name O B-toloc.city_name I-toloc.city_name I-toloc.city_name O B-depart_date.day_name B-depart_time.time_relative B-depart_time.time I-depart_time.time\n", "Tag Prediction : O O O O O O O O B-fromloc.city_name O B-toloc.city_name I-toloc.city_name I-toloc.city_name O B-depart_date.day_name B-depart_time.time_relative B-depart_time.time I-depart_time.time\n", "Intent Truth : atis_flight#atis_airfare\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : i need flight numbers and airlines for flights departing from oakland to salt lake city on thursday departing before 8 am\n", "Tag Truth : O O O O O O O O O O B-fromloc.city_name O B-toloc.city_name I-toloc.city_name I-toloc.city_name O B-depart_date.day_name O B-depart_time.time_relative 
B-depart_time.time I-depart_time.time\n", "Tag Prediction : O O O O O O O O O O B-fromloc.city_name O B-toloc.city_name I-toloc.city_name I-toloc.city_name O B-depart_date.day_name O B-depart_time.time_relative B-depart_time.time I-depart_time.time\n", "Intent Truth : atis_flight_no#atis_airline\n", "Intent Prediction : atis_flight_no\n", "--------------------------------------\n", "Sentence : what does the restriction ap58 mean\n", "Tag Truth : O O O O B-restriction_code O\n", "Tag Prediction : O O O O B-restriction_code O\n", "Intent Truth : atis_abbreviation\n", "Intent Prediction : atis_restriction\n", "--------------------------------------\n", "Sentence : list la\n", "Tag Truth : O B-city_name\n", "Tag Prediction : O B-city_name\n", "Intent Truth : atis_city\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : list la\n", "Tag Truth : O B-city_name\n", "Tag Prediction : O B-city_name\n", "Intent Truth : atis_city\n", "Intent Prediction : atis_flight\n", "--------------------------------------\n", "Sentence : give me the flights and fares for a trip to cleveland from miami on wednesday\n", "Tag Truth : O O O O O O O O O O B-toloc.city_name O B-fromloc.city_name O B-depart_date.day_name\n", "Tag Prediction : O O O O O O O O O O B-toloc.city_name O B-fromloc.city_name O B-depart_date.day_name\n", "Intent Truth : atis_flight\n", "Intent Prediction : atis_flight#atis_airfare\n", "--------------------------------------\n", "Sentence : how many northwest flights leave st. paul\n", "Tag Truth : O O B-airline_name O O B-fromloc.city_name I-fromloc.city_name\n", "Tag Prediction : O O B-airline_name O O B-fromloc.city_name I-fromloc.city_name\n", "Intent Truth : atis_flight\n", "Intent Prediction : atis_quantity\n", "--------------------------------------\n", "Sentence : how many northwest flights leave washington dc\n", "Tag Truth : O O B-airline_name O O B-fromloc.city_name B-fromloc.state_code\n", "Tag Prediction : O O B-airline_name O O B-fromloc.city_name B-fromloc.state_code\n", "Intent Truth : atis_flight\n", "Intent Prediction : atis_quantity\n", "--------------------------------------\n", "Sentence : how many flights does northwest have leaving dulles\n", "Tag Truth : O O O O B-airline_name O O B-fromloc.airport_name\n", "Tag Prediction : O O O O B-airline_name O O B-fromloc.airport_name\n", "Intent Truth : atis_flight\n", "Intent Prediction : atis_quantity\n", "--------------------------------------\n", "Sentence : how many flights does alaska airlines have to burbank\n", "Tag Truth : O O O O B-airline_name I-airline_name O O B-toloc.city_name\n", "Tag Prediction : O O O O B-airline_name I-airline_name O O B-toloc.city_name\n", "Intent Truth : atis_flight\n", "Intent Prediction : atis_quantity\n", "--------------------------------------\n", "Total instances of wrong intent prediction is 21\n" ] } ], "source": [ "print(\"Instances where model predicted intent wrong\")\n", "encoder.eval()\n", "middle.eval()\n", "decoder.eval()\n", "bert_layer.eval()\n", "total_wrong_predicted_intents = 0\n", "with torch.no_grad():\n", " for i in range(len(test)):\n", " index = i\n", " test_raw = test[index][0]\n", " bert_tokens = test_toks['input_ids'][index].unsqueeze(0).cuda()\n", " bert_mask = test_toks['attention_mask'][index].unsqueeze(0).cuda()\n", " bert_toktype = test_toks['token_type_ids'][index].unsqueeze(0).cuda()\n", " subtoken_mask = test_subtoken_mask[index].unsqueeze(0).cuda()\n", " test_in = prepare_sequence(test_raw,word2index)\n", " test_mask 
= Variable(torch.BoolTensor(tuple(map(lambda s: s ==0, test_in.data)))).cuda() if USE_CUDA else Variable(torch.ByteTensor(tuple(map(lambda s: s ==0, test_in.data)))).view(1,-1)\n", " # print(removepads(test_raw))\n", " start_decode = Variable(torch.LongTensor([[word2index['<BOS>']]*1])).cuda().transpose(1,0) if USE_CUDA else Variable(torch.LongTensor([[word2index['<BOS>']]*1])).transpose(1,0)\n", " test_raw = [removepads(test_raw)]\n", " bert_hidden,bert_pooler = bert_layer(bert_info=(bert_tokens,bert_mask,bert_toktype))\n", " encoder_output = encoder(bert_last_hidden=bert_hidden)\n", " output = middle(encoder_output,bert_mask==0)\n", " tag_score, intent_score = decoder(start_decode,output,bert_mask==0,bert_subtoken_maskings=subtoken_mask,infer=True)\n", "\n", " v,i = torch.max(intent_score,1)\n", " if test[index][2]!=index2intent[i.data.tolist()[0]]:\n", "  v,i = torch.max(tag_score,1)\n", "  print(\"Sentence : \",*test_raw[0])\n", "  print(\"Tag Truth : \", *test[index][1][:len(test_raw[0])])\n", "  print(\"Tag Prediction : \",*list(map(lambda ii:index2tag[ii],i.data.tolist()))[:len(test_raw[0])])\n", "  v,i = torch.max(intent_score,1)\n", "  print(\"Intent Truth : \", test[index][2])\n", "  print(\"Intent Prediction : \",index2intent[i.data.tolist()[0]])\n", "  print(\"--------------------------------------\")\n", "  total_wrong_predicted_intents+=1\n", "\n", "print(\"Total instances of wrong intent prediction is \",total_wrong_predicted_intents)" ] } ], "metadata": { "kernelspec": { "name": "python3", "language": "python", "display_name": "t2s kernel" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 2 }