
Improving Text Classifiers through Controlled Text Generation using Transformer Wasserstein Autoencoder*

Harikrishnan C and Dhanya N M

Amrita Vishwa Vidyapeetham, Amritanagar, Ettimadai, Tamil Nadu 641112
[email protected]
nm [email protected]

Abstract. Training good classifiers on imbalanced datasets has always
been a challenge, especially if the classifier has to work with textual
data. Natural language is one area where imbalanced datasets are abundant,
for example in spam filtering, fake news detection, and toxic comment
classification. Techniques for generating synthetic data, such as the
Synthetic Minority Over-sampling Technique, fail to train effective
classifiers. This paper proposes a technique for controlled text generation
using a transformer-based Wasserstein autoencoder, which helps improve such
classifiers. The paper compares the results with classifiers trained on data
generated by other synthetic data generators. Furthermore, the potential
issues of using the proposed model to train classifiers are discussed.

Keywords: Text Classification · Natural Language Generation · Transformers

1 Introduction

A dataset is said to be imbalanced when there is a skew in class proportions. This
skew is reflected in the classifiers: when they are trained on skewed data, their
predictions are also biased towards the class with the higher proportion. There are
different approaches to balancing the data, such as oversampling, undersampling,
generating synthetic data with techniques like the Synthetic Minority Over-sampling
Technique (SMOTE) [11], and assigning weights to classes [16]. However, each of
these approaches has its downsides. There have been impressive works on spam
detection and fake news classification in the last few years [15] [12] [1], which
showed that deep learning works better than traditional machine learning algorithms.
* Supported by Amrita Vishwa Vidyapeetham

Generative models have shown great improvements in the past few years. These
models can learn the distribution of the original data and generate samples from
that distribution. In the domain of natural language, text generation using a
variational autoencoder (VAE) was proven to be effective [3]. This architecture
used a recurrent neural network (RNN) based VAE to generate texts that are similar
to the training dataset. The same architecture, with a few modifications, was used
to perform controlled text generation [6]; that model was able to obtain meaningful
sentences by restricting the sentence length and achieved better accuracy on
sentiment attributes.

After the introduction of the transformer architecture [14], which proved to be
much better at natural language tasks, most systems moved from RNN-based
architectures to transformers. Naturally, this included generic text generation [10]
and controlled text generation [7] [4] using transformers.

The contribution of this paper is a transformer-based Wasserstein autoencoder
used for controlled text generation, which in turn is used to train a classifier.

2 Background

2.1 Strategies for balancing datasets

– Oversampling: The data belonging to the minority class is replicated
randomly to match the number of instances in the majority class. The
disadvantage of this approach is that it can lead the model to overfit
on the training data.
– Undersampling: Instances belonging to the majority class are removed to
match the count of instances in the minority class. There is a chance of
losing important information, which leads to poor generalization of the model.
– SMOTE: SMOTE generates synthetic data for the minority class. These
instances are generated by interpolating points between nearest neighbours.
While SMOTE has shown some promise on numerical datasets, it does not work
very well on text data. A small sketch of these strategies is given after this list.
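
The following is a minimal sketch of the three balancing strategies, using scikit-learn and the imbalanced-learn package; the toy texts, labels, and parameter choices are illustrative assumptions, not the paper's pipeline.

```python
# Illustrative sketch of random oversampling, random undersampling and SMOTE.
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from imblearn.over_sampling import RandomOverSampler, SMOTE
from imblearn.under_sampling import RandomUnderSampler

texts = ["win cash now", "claim free prize",                # spam (minority)
         "team meeting at noon", "project status update",   # ham (majority)
         "lunch tomorrow?", "monthly report attached"]
y = np.array([1, 1, 0, 0, 0, 0])

# SMOTE interpolates in feature space, so text must first be vectorised.
X = TfidfVectorizer().fit_transform(texts)

X_over, y_over = RandomOverSampler(random_state=0).fit_resample(X, y)       # duplicate minority rows
X_under, y_under = RandomUnderSampler(random_state=0).fit_resample(X, y)    # drop majority rows
X_smote, y_smote = SMOTE(k_neighbors=1, random_state=0).fit_resample(X, y)  # interpolate minority points
```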

2.2 Wasserstein Autoencoder

The Wasserstein autoencoder (WAE) [13] uses the same architecture as the variational
autoencoder [8]. While the VAE uses the Kullback-Leibler divergence to minimize the
distance between the prior and the posterior distribution, the WAE replaces this with
a discriminator network that assigns a score to how closely the posterior distribution
resembles the prior. This is achieved through a min-max game played between the
encoder and the discriminator. The discriminator is trained on the following
objective function:

\frac{\lambda}{m} \sum_{i=1}^{m} \left[ \log(D_\gamma(z_i)) + \log(1 - D_\gamma(\hat{z}_i)) \right]   (1)
The above objective is maximized by gradient ascent. The encoder is trained on the
following objective function, which is minimized:
\frac{1}{m} \sum_{i=1}^{m} \left[ c(x_i, G_\theta(\hat{z}_i)) - \lambda \cdot \log(D_\gamma(\hat{z}_i)) \right]   (2)
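
Below is a minimal PyTorch sketch of the two objectives in Eqs. (1) and (2); the module names (discriminator, logits), the assumption that the discriminator outputs a probability in (0, 1), and the value of λ are illustrative, not the paper's exact implementation.

```python
# Sketch of the WAE-GAN losses from Eqs. (1) and (2).
import torch
import torch.nn.functional as F

def discriminator_loss(discriminator, z_prior, z_posterior, lam=10.0):
    # Eq. (1): score prior samples high and posterior samples low;
    # gradient ascent is done by minimizing the negated objective.
    d_prior = discriminator(z_prior)              # probabilities in (0, 1)
    d_post = discriminator(z_posterior.detach())  # do not update the encoder here
    return -lam * (torch.log(d_prior) + torch.log(1.0 - d_post)).mean()

def encoder_decoder_loss(logits, targets, discriminator, z_posterior, lam=10.0):
    # Eq. (2): reconstruction cost c(x, G(z)) (negative log-likelihood here)
    # plus a term that pushes the posterior towards the prior.
    recon = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return recon - lam * torch.log(discriminator(z_posterior)).mean()
```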

2.3 Transformers
The transformer is an architecture for sequence-to-sequence tasks. It owes its
performance to the self-attention mechanism, which learns the weightage of each
word in a sentence. Self-attention is further enhanced through multi-head attention,
where there are h heads and each head performs the self-attention operation. This
helps in interpreting the different meanings of a single sentence. The self-attention
operation can be expressed by the following mathematical expression:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{QK^{T}}{\sqrt{d_K}} \right) V   (3)
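
The following is a small PyTorch sketch of the scaled dot-product attention in Eq. (3); the tensor shapes are assumptions chosen only for illustration.

```python
# Sketch of scaled dot-product attention (Eq. 3).
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    d_k = K.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq_q, seq_k)
    weights = torch.softmax(scores, dim=-1)            # attention weights per query
    return weights @ V                                 # weighted sum of value vectors

Q = torch.randn(2, 8, 64)   # (batch, seq_len, d_k), hypothetical sizes
K = torch.randn(2, 8, 64)
V = torch.randn(2, 8, 64)
out = scaled_dot_product_attention(Q, K, V)            # shape (2, 8, 64)
```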

2.4 Decoding Strategies


When generating text from a model, the diversity of the generated text depends
on the decoding strategy used. Some of the decoding strategies are:
– Greedy Decoding: This is one of the simplest decoding strategies. While
generating the text, the next word is chosen by picking the word with the
highest probability. This process continues until the maximum number of words
is reached or the end-of-sentence tag is encountered.
– Top-k Sampling: This decoding strategy takes the top k probabilities and
samples a word from them. It helps to introduce words that do not come up
often in sentences.
– Softmax with Temperature: Here a temperature parameter T is used to
manipulate the output probabilities of the model [5]. The logits are divided
by T before the exponential operation in the softmax, as shown in Eq. (4);
a small sketch of these strategies follows the equation.

\mathrm{softmax}(y)_i = \frac{e^{y_i / T}}{\sum_{j=1}^{N} e^{y_j / T}}   (4)
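
Below is a minimal PyTorch sketch of the three decoding choices applied to a single step of next-word logits; the vocabulary size, logits, temperature, and k are hypothetical.

```python
# Sketch of greedy decoding, softmax with temperature (Eq. 4) and top-k sampling.
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # scores for a toy 4-word vocabulary

# Greedy decoding: always pick the most probable word.
greedy_id = torch.argmax(logits).item()

# Softmax with temperature: T < 1 sharpens, T > 1 flattens the distribution.
T = 0.7
probs = torch.softmax(logits / T, dim=-1)

# Top-k sampling: keep the k most probable words and sample among them.
k = 2
top_probs, top_ids = torch.topk(probs, k)
sampled_id = top_ids[torch.multinomial(top_probs / top_probs.sum(), 1)].item()
```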

3 Method and Experiments


To evaluate the approach, two imbalanced natural language datasets were chosen:
1. Covid Fake News Dataset [2]
2. Spam Identification [9]

The Covid dataset has a total of 10201 headlines, of which 9727 are real news and
474 are fake. In the spam-or-ham dataset, there are a total of 5572 mail subjects,
of which 4825 are not spam and 747 are spam.

Fig. 1. The proportion of imbalance in a) Covid dataset b) Spam dataset

For preprocessing the text data, after tokenization, words that occur only once
were replaced by the <unk> token. Numbers were replaced by the <num> token. To
train the transformer WAE, start-of-sentence <sos> and end-of-sentence <eos>
tokens were appended at the beginning and end of the text.
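
The preprocessing steps described above can be sketched as follows; the whitespace tokenizer and the handling of the frequency threshold are assumptions, not the paper's exact code.

```python
# Sketch of the described preprocessing: <num>, <unk>, <sos>, <eos> tokens.
from collections import Counter

def preprocess(sentences):
    tokenized = [s.lower().split() for s in sentences]
    counts = Counter(tok for sent in tokenized for tok in sent)
    processed = []
    for sent in tokenized:
        toks = []
        for tok in sent:
            if tok.isdigit():
                toks.append("<num>")       # replace numbers
            elif counts[tok] <= 1:
                toks.append("<unk>")       # replace words that occur only once
            else:
                toks.append(tok)
        processed.append(["<sos>"] + toks + ["<eos>"])  # mark sentence boundaries
    return processed

print(preprocess(["covid cases rise by 500", "covid vaccine update released"]))
```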

Fig. 2. Transformer Wasserstein Autoencoder



The transformer WAE was trained using teacher forcing, with negative log-likelihood
as the reconstruction loss and a divergence loss determined by the discriminator.
The training is performed on the complete dataset. After training, the encoder is
used to train a controlling network. The dataset is downsampled, and the balanced
dataset is encoded into latent representations by the encoder. The controlling
network is trained to distinguish the latent representations by class.
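
A minimal sketch of training the controller network on latent codes from the frozen WAE encoder is shown below; the layer sizes, optimizer, and the name of the encoder function are assumptions.

```python
# Sketch of the controller network trained on encoder latents.
import torch
import torch.nn as nn

latent_dim, num_classes = 128, 2          # hypothetical sizes

controller = nn.Sequential(
    nn.Linear(latent_dim, 64),
    nn.ReLU(),
    nn.Linear(64, num_classes),
)
optimizer = torch.optim.Adam(controller.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

def train_step(encoder, batch, labels):
    with torch.no_grad():                 # the trained WAE encoder stays frozen
        z = encoder(batch)                # latent representation of the balanced batch
    logits = controller(z)
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```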

Fig. 3. Training Controller Network

For controlled text generation, a random noise vector z is sampled from a unit
Gaussian distribution. This noise is then passed to the controller network Cz.
The label output by the controller network is set as the expected output, and a
cross-entropy loss is calculated with respect to the noise z. The noise z is then
updated with the gradient of this loss, scaled by a factor of η.

y = \arg\max(C_z(z))   (5)

L = \frac{1}{m} \sum_{i=1}^{m} y \cdot \log(C_z(z))   (6)

z = z + \eta \cdot \frac{dL}{dz}   (7)
The noise is updated iteratively to move it towards a value that the controller is
confident about. A combination of decoding strategies is used while generating the
text: softmax with temperature is applied over the output probabilities and greedy
decoding is performed. Top-k sampling is used to find a replacement word whenever
an <unk> token is encountered.
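
The iterative noise update of Eqs. (5)-(7) can be sketched as follows; the number of steps and the step size η are assumptions.

```python
# Sketch of refining a Gaussian noise vector with the controller (Eqs. 5-7).
import torch

def refine_noise(controller, latent_dim=128, eta=0.1, steps=50):
    z = torch.randn(1, latent_dim, requires_grad=True)    # noise from a unit Gaussian
    for _ in range(steps):
        log_probs = torch.log_softmax(controller(z), dim=-1)
        y = log_probs.argmax(dim=-1).item()               # Eq. (5): label predicted by the controller
        objective = log_probs[0, y]                       # Eq. (6): log-probability of that label
        objective.backward()
        with torch.no_grad():
            z += eta * z.grad                             # Eq. (7): push z towards higher confidence
        z.grad.zero_()
    return z.detach()                                     # decode this latent with the WAE decoder
```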
The classifier is first trained on a balanced downsampled dataset for 100
epochs. The same classifier is then fine-tuned for 5 epochs on a dataset that
combines the downsampled dataset with the generated dataset; the downsampled data
is included in this combination to prevent the model from forgetting the original data.

Fig. 4. Training the classifier

Two different types of classifiers were chosen to validate this approach: a plain
RNN classifier and an RNN classifier with an attention decoder. The embedding
dimension was set to 64 and the hidden size to 256. For the first classifier, two
LSTM layers were stacked together. For the second classifier, one LSTM encoder and
two LSTM decoders formed the architecture. A scaled dot product was computed
between the output of the encoder and the output of the first decoder, and the
result was fed into the second decoder.
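
The attention-based classifier can be sketched as below; the exact wiring of the encoder, the two decoders, and the classification head is an assumption based on the description above, not the authors' code.

```python
# Sketch of the RNN classifier with an attention decoder described above.
import math
import torch
import torch.nn as nn

class AttentionRNNClassifier(nn.Module):
    def __init__(self, vocab_size, emb_dim=64, hidden=256, num_classes=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden, batch_first=True)
        self.decoder1 = nn.LSTM(hidden, hidden, batch_first=True)
        self.decoder2 = nn.LSTM(hidden, hidden, batch_first=True)
        self.out = nn.Linear(hidden, num_classes)

    def forward(self, token_ids):
        enc, _ = self.encoder(self.emb(token_ids))          # (batch, seq, hidden)
        dec1, _ = self.decoder1(enc)
        # Scaled dot product between encoder output and first decoder output.
        scores = enc @ dec1.transpose(1, 2) / math.sqrt(enc.size(-1))
        attn = torch.softmax(scores, dim=-1)
        dec2, _ = self.decoder2(attn @ dec1)
        return self.out(dec2[:, -1])                        # classify from the last time step
```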

The setup shown in Fig. 4 was modified by swapping the transformer model with the
decoder of an RNN variational autoencoder (RNN VAE). This was done to compare how
transformer-based text generation affects the classifier.
Lastly, another set of classifiers was trained on a combination of the real data
and synthetic data generated by SMOTE, to understand how much SMOTE helps in text
classification.

4 Results
For comparing the different models explained in the previous section, accuracy
and F1-score were chosen as the metrics. The models were tested on a validation
set that was not part of the training set. The validation set was chosen such
that there is no significant skew among the class proportions.
From the results, it can be inferred that fine-tuning the classifier on text
generated by the transformer-based model produces better results. From Table 1
and Table 2, it is evident that SMOTE does not work well for text classification
and prevents the classifier from generalizing.
The downside of this approach is that the generated text depends on the random
noise, so there is no control over exactly what kind of text will be generated.
Another issue is memory consumption: while training the classifier, at least three
models are loaded into memory, i.e. the generator model, the controller network,
and the classifier network. If the models are large, this can cause an
out-of-memory error.

Table 1. Covid Fake news detection results

Model                                             Accuracy  F1 Score
RNN classifier                                    0.8105    0.8392
RNN with attention                                0.8315    0.8446
RNN classifier on SMOTE                           0.4315    0.1147
RNN with attention on SMOTE                       0.4578    0.1889
RNN classifier trained on RNN VAE                 0.7736    0.7902
RNN with attention trained on RNN VAE             0.8157    0.8292
RNN classifier trained on Transformer WAE         0.8421    0.8514
RNN with attention trained on Transformer WAE     0.8421    0.8543

Table 2. Spam detection results

Model                                             Accuracy  F1 Score
RNN classifier                                    0.8695    0.8849
RNN with attention                                0.9364    0.9396
RNN classifier on SMOTE                           0.4682    0.1928
RNN with attention on SMOTE                       0.4715    0.2882
RNN classifier trained on RNN VAE                 0.8695    0.8876
RNN with attention trained on RNN VAE             0.9096    0.9184
RNN classifier trained on Transformer WAE         0.8862    0.8950
RNN with attention trained on Transformer WAE     0.9397    0.9407

5 Conclusions

This paper proposes a new approach to training better models on imbalanced
datasets. The experiments showed that existing synthetic data generation
techniques such as SMOTE are ineffective in the natural language domain, while
the proposed approach is quite effective. The paper also discusses the
disadvantages of the proposed approach.
Bibliography

[1] Anjali, B., Reshma, R., Geetha Lekshmy, V.: Detection of Counterfeit News
Using Machine Learning. 2019 2nd International Conference on Intelligent
Computing, Instrumentation and Control Technologies, ICICICT 2019 pp.
1382–1386 (2019). https://doi.org/10.1109/ICICICT46008.2019.8993330
[2] Banik, S.: Covid fake news dataset (Nov 2020).
https://doi.org/10.5281/zenodo.4282522
[3] Bowman, S.R., Vilnis, L., Vinyals, O., Dai, A.M., Jozefowicz, R., Bengio, S.:
Generating sentences from a continuous space. CoNLL 2016 - 20th SIGNLL
Conference on Computational Natural Language Learning, Proceedings pp.
10–21 (2016). https://doi.org/10.18653/v1/k16-1002
[4] Dathathri, S., Madotto, A., Lan, J., Hung, J., Frank, E., Molino, P., Yosin-
ski, J., Liu, R.: Plug and play language models: A simple approach to con-
trolled text generation. arXiv pp. 1–34 (2019)
[5] Hinton, G., Vinyals, O., Dean, J.: Distilling the Knowledge in a Neural
Network pp. 1–9 (2015), http://arxiv.org/abs/1503.02531
[6] Hu, Z., Yang, Z., Liang, X., Salakhutdinov, R., Xing, E.P.: Toward con-
trolled generation of text. 34th International Conference on Machine Learn-
ing, ICML 2017 4, 2503–2513 (2017)
[7] Keskar, N.S., McCann, B., Varshney, L.R., Xiong, C., Socher, R.: CTRL: A
conditional transformer language model for controllable generation. arXiv
pp. 1–18 (2019)
[8] Kingma, D.P., Welling, M.: Auto-encoding variational bayes. 2nd Interna-
tional Conference on Learning Representations, ICLR 2014 - Conference
Track Proceedings (Ml), 1–14 (2014)
[9] Klimt, B., Yang, Y.: The enron corpus: A new dataset for email classification
research pp. 217–226 (2004)
[10] Liu, D., Liu, G.: A Transformer-Based Variational Autoencoder for Sentence
Generation. Proceedings of the International Joint Conference on Neural
Networks, 1–7 (2019). https://doi.org/10.1109/IJCNN.2019.8852155
[11] Mansourifar, H., Shi, W.: Deep synthetic minority over-sampling technique.
arXiv 16, 321–357 (2020)
[12] Srinivasan, S., Ravi, V., Alazab, M., Ketha, S., Al-Zoubi, A.M., Kotti
Padannayil, S.: Spam Emails Detection Based on Distributed Word
Embedding with Deep Learning. Studies in Computational Intelligence
919(December), 161–189 (2021)
[13] Tolstikhin, I., Bousquet, O., Gelly, S., Schölkopf, B.: Wasserstein auto-
encoders. arXiv pp. 1–20 (2017)
[14] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N.,
Kaiser, L., Polosukhin, I.: Attention is all you need. Advances in Neural
Information Processing Systems, 5999–6009 (2017)

[15] Vinayakumar, R., Soman, K.P., Poornachandran, P., Akarsh, S.: Applica-
tion of deep learning architectures for cyber security. No. June, Springer
International Publishing (2019)
[16] Vishagini, V., Rajan, A.K.: An Improved Spam Detection Method
with Weighted Support Vector Machine. 2018 International Con-
ference on Data Science and Engineering, ICDSE 2018 (2018).
https://doi.org/10.1109/ICDSE.2018.8527737
