-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: text_augmentation.py
52 lines (47 loc) · 3.13 KB
/
text_augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import BertTokenizer
from augmentation.eda import eda
from sentence_transformers import SentenceTransformer,util
import torch
import time
class TextAugmentation:
    """Produce augmented variants of the sentences in a batch.

    For every sentence, candidate paraphrases are generated (PEGASUS
    beam search and/or EDA perturbations, per ``type_txt_augm``), ranked
    by cosine similarity to the original via a sentence-embedding model,
    and the candidate at similarity rank ``current_epoch`` is kept — so
    later epochs pick progressively less-similar augmentations.

    Expected config keys: ``num_beams``, ``num_return_sequences``,
    ``max_text_len``, ``type_txt_augm`` (container holding "PEGASUS"
    and/or "EDA"), ``tokenizer`` (BERT tokenizer name/path).
    """

    def __init__(self, config):
        super().__init__()
        self.num_beams = config["num_beams"]
        self.num_return_sequences = config["num_return_sequences"]
        self.max_length = config["max_text_len"]
        self.type_txt_augm = config["type_txt_augm"]
        # Tokenizer producing the ids/masks consumed downstream of augmentation().
        self.tokenizer = BertTokenizer.from_pretrained(config["tokenizer"])
        # PEGASUS fine-tuned for paraphrase generation.
        self.tokenizer_pegasus = PegasusTokenizer.from_pretrained('tuner007/pegasus_paraphrase')
        self.pegasus = PegasusForConditionalGeneration.from_pretrained('tuner007/pegasus_paraphrase')
        # Embedding model used to rank candidates by similarity to the original.
        self.model_sentence_embedding = SentenceTransformer('paraphrase-MiniLM-L6-v2')

    def augmentation(self, pl_module, batch):
        """Replace each sentence in ``batch["text"]`` with one augmentation.

        Args:
            pl_module: Lightning-style module providing ``current_epoch``
                and ``device``.
            batch: dict with key "text" (list of str); mutated in place.

        Returns:
            The same ``batch`` dict with "text" holding the selected
            augmented sentences and "text_ids" / "text_masks" holding the
            matching BERT tokenization on ``pl_module.device``.
        """
        epoch = pl_module.current_epoch
        self.pegasus = self.pegasus.to(pl_module.device)
        final_sentences = []
        augmented_pegasus = None
        if "PEGASUS" in self.type_txt_augm:
            batch_pegasus = self.tokenizer_pegasus(
                batch["text"], truncation=True, padding='longest',
                return_tensors="pt").to(pl_module.device)
            # generate() returns ids on the model's device already; no extra .to() needed.
            translated = self.pegasus.generate(
                **batch_pegasus, max_length=self.max_length,
                num_beams=self.num_beams,
                num_return_sequences=self.num_return_sequences,
                temperature=1.5)
            decoded = self.tokenizer_pegasus.batch_decode(translated, skip_special_tokens=True)
            # Regroup the flat decode output into num_return_sequences candidates per input.
            augmented_pegasus = [decoded[i:i + self.num_return_sequences]
                                 for i in range(0, len(decoded), self.num_return_sequences)]
        for i, sentence in enumerate(batch["text"]):
            augmented_text = []
            if augmented_pegasus is not None:
                augmented_text.extend(augmented_pegasus[i])
            if "EDA" in self.type_txt_augm:
                augmented_text.extend(eda(sentence, alpha_sr=0.1, alpha_ri=0.1,
                                          alpha_rs=0.1, p_rd=0.1,
                                          num_aug=self.num_return_sequences))
            if not augmented_text:
                # No augmentation method enabled for this config:
                # keep the original sentence instead of crashing on encode([]).
                final_sentences.append(sentence)
                continue
            augmented_text_embeddings = self.model_sentence_embedding.encode(
                augmented_text, show_progress_bar=False)
            original_text_embeddings = self.model_sentence_embedding.encode(
                sentence, show_progress_bar=False)
            # cosine_scores has shape (1, num_candidates).
            cosine_scores = util.pytorch_cos_sim(original_text_embeddings,
                                                 augmented_text_embeddings)
            values, indices = torch.sort(cosine_scores, descending=True)
            # Wrap the epoch so epochs beyond the candidate count do not
            # raise IndexError (original code crashed there).
            rank = epoch % indices.shape[1]
            final_sentences.append(augmented_text[int(indices[0][rank])])
        outputs = self.tokenizer(final_sentences, truncation=True,
                                 padding=True, max_length=self.max_length)
        # BUG FIX: the original stored the LAST iteration's candidate list
        # (augmented_text) here, leaving batch["text"] inconsistent with the
        # ids/masks tokenized from final_sentences above.
        batch["text"] = final_sentences
        batch["text_ids"] = torch.tensor(outputs["input_ids"]).to(pl_module.device)
        batch["text_masks"] = torch.tensor(outputs["attention_mask"]).to(pl_module.device)
        return batch