Improving Text Classifiers Through Controlled Text Generation Using Transformer Wasserstein Autoencoder
1 Introduction
Generative models have improved greatly in the past few years. These
models learn the distribution of the original data and generate samples from
that distribution. In the domain of natural language, text generation using a
variational autoencoder (VAE) has proven effective [3]. That architecture used
a recurrent neural network (RNN) based VAE to generate texts similar to the
training dataset. The same architecture, with a few modifications, was used to
perform controlled text generation [6]; this model obtained meaningful sentences
by restricting the sentence length and achieved better accuracy with sentiment
attributes.
2 Background
– Oversampling: The data belonging to the minority class is replicated
randomly to match the number of instances in the majority class. The
disadvantage of this approach is that it can lead the model to overfit on
the training data.
– Undersampling: Instances belonging to the majority class are removed to
match the count of instances in the minority class. There is a chance of
losing important information, which leads to poor generalization of the
model.
– SMOTE: SMOTE generates synthetic data for the minority class. These
instances are created by interpolating between a minority point and its
nearest neighbours (a minimal sketch is given after this list). While SMOTE
has shown some promise on numerical datasets, it does not work very well
on text data.
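A minimal sketch of the SMOTE interpolation step, assuming numerical feature vectors; the function name, the array X_min, and the parameters are illustrative placeholders, not the exact implementation used in the experiments:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def smote_sample(X_min, k=5, n_new=100, seed=0):
    """Generate synthetic minority samples by interpolating between a
    minority instance and one of its k nearest minority neighbours."""
    rng = np.random.default_rng(seed)
    nn = NearestNeighbors(n_neighbors=k + 1).fit(X_min)
    _, idx = nn.kneighbors(X_min)            # idx[:, 0] is the point itself
    synthetic = []
    for _ in range(n_new):
        i = rng.integers(len(X_min))         # pick a minority instance
        j = rng.choice(idx[i, 1:])           # pick one of its neighbours
        gap = rng.random()                   # interpolation factor in [0, 1]
        synthetic.append(X_min[i] + gap * (X_min[j] - X_min[i]))
    return np.vstack(synthetic)
```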
In the Wasserstein autoencoder with adversarial latent matching [13], a discriminator $D_\gamma$ is trained to separate samples $z_i$ drawn from the prior from encoded samples $\hat{z}_i$:
\[
\frac{\lambda}{m} \sum_{i=1}^{m} \left[ \log D_\gamma(z_i) + \log\!\left(1 - D_\gamma(\hat{z}_i)\right) \right] \tag{1}
\]
The above objective is maximized by gradient ascent. The encoder and decoder are trained on the following objective function, which is minimized:
\[
\frac{1}{m} \sum_{i=1}^{m} \left[ c(x_i, G_\theta(\hat{z}_i)) - \lambda \cdot \log D_\gamma(\hat{z}_i) \right] \tag{2}
\]
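A minimal PyTorch sketch of how these two objectives could be computed, assuming the latent discriminator outputs probabilities in (0, 1) and the decoder produces per-token vocabulary logits of shape (batch, sequence, vocabulary); all names and shapes here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def discriminator_loss(d_real, d_fake, lam):
    """Eq. (1): d_real = D_gamma(z_i) on prior samples, d_fake = D_gamma(z_hat_i)
    on encoded samples. Maximising Eq. (1) equals minimising its negated mean."""
    return -lam * (torch.log(d_real) + torch.log(1.0 - d_fake)).mean()

def encoder_decoder_loss(token_logits, target_tokens, d_fake, lam):
    """Eq. (2): negative log-likelihood reconstruction cost c(x_i, G_theta(z_hat_i))
    plus the adversarial penalty -lambda * log D_gamma(z_hat_i)."""
    recon = F.cross_entropy(token_logits.transpose(1, 2), target_tokens)
    return recon - lam * torch.log(d_fake).mean()
```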
2.3 Transformers
The transformer is an architecture for sequence-to-sequence tasks. It owes
its performance to the self-attention mechanism, which learns how much weight
to give each word in a sentence. Self-attention is further enhanced by multi-head
attention, where h heads each perform the self-attention operation in parallel.
This helps the model capture the different meanings of a single sentence. The
self-attention operation can be expressed as follows.
\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right) V \tag{3}
\]
\[
\mathrm{softmax}(y)_i = \frac{e^{y_i / T}}{\sum_{j=1}^{N} e^{y_j / T}} \tag{4}
\]
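A minimal sketch of the scaled dot-product attention in Eq. (3); the tensor shapes are illustrative assumptions:

```python
import torch

def scaled_dot_product_attention(Q, K, V):
    """Eq. (3): softmax(Q K^T / sqrt(d_k)) V for tensors of shape (batch, seq, d_k)."""
    d_k = K.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (batch, seq, seq)
    weights = torch.softmax(scores, dim=-1)         # attention weights per query position
    return weights @ V                              # weighted sum of the values
```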
The Covid dataset has a total of 10201 headlines, of which 9727 are real news
and 474 are fake. In the spam-or-ham dataset, there are a total of 5572 mail
subjects, of which 4825 are not spam and 747 are spam.
For preprocessing the text data, after tokenization, words that do not occur
more than once were replaced by the <unk> token. Numbers were replaced by
the <num> token. To train the Transformer WAE, start-of-sentence <sos> and
end-of-sentence <eos> tokens were appended at the beginning and end of each
text.
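A minimal sketch of this preprocessing, assuming simple whitespace tokenization; the helper below is illustrative, not the exact pipeline used:

```python
from collections import Counter

def preprocess(texts):
    """Replace singleton words with <unk>, numbers with <num>,
    and wrap each text in <sos> ... <eos>."""
    tokenised = [t.lower().split() for t in texts]
    counts = Counter(w for toks in tokenised for w in toks)
    processed = []
    for toks in tokenised:
        toks = ["<num>" if w.isdigit() else w for w in toks]
        toks = [w if counts[w] > 1 or w == "<num>" else "<unk>" for w in toks]
        processed.append(["<sos>"] + toks + ["<eos>"])
    return processed
```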
The Transformer WAE was trained using teacher forcing, with negative log-
likelihood as the reconstruction loss and a divergence loss determined by the
discriminator. Training is performed on the complete dataset. After training,
the encoder is used to train a controlling network: the dataset is downsampled,
the balanced dataset is encoded into latent representations by the encoder, and
the controlling network is trained to distinguish the latent representations by
class.
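A minimal sketch of how the controlling network could be trained on the encoded, balanced data; the encoder, data loader, controller architecture, and dimensions are assumed placeholders:

```python
import torch
import torch.nn as nn

def train_controller(encoder, balanced_loader, latent_dim, num_classes, epochs=10):
    """Train a small feed-forward controller to classify latent codes produced
    by the (frozen) Transformer WAE encoder on the downsampled, balanced data."""
    controller = nn.Sequential(
        nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, num_classes))
    optimizer = torch.optim.Adam(controller.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x_batch, y_batch in balanced_loader:
            with torch.no_grad():
                z = encoder(x_batch)          # latent representation of the batch
            loss = criterion(controller(z), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return controller
```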
For controlled text generation, a random noise vector z is sampled from a unit
Gaussian distribution. This noise is passed to the controller network $C_z$. The
label output by the controller network is set as the expected output, and a
cross-entropy loss is calculated with respect to the noise z. The noise z is then
updated with a gradient step scaled by a factor of η.
\[
L = \frac{1}{m} \sum_{i=1}^{m} y \cdot \log(C_z(z)) \tag{6}
\]
\[
z = z + \eta \cdot \frac{dL}{dz} \tag{7}
\]
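A minimal sketch of the iterative update in Eqs. (6)–(7), written as descent on the cross-entropy (equivalently, ascent on the log-likelihood of Eq. (6)); the step size and iteration count are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def steer_noise(controller, latent_dim, eta=0.1, steps=50):
    """Sample z ~ N(0, I), take the controller's own prediction as the target
    label, and update z until the controller is confident about that label."""
    z = torch.randn(1, latent_dim, requires_grad=True)
    with torch.no_grad():
        target = controller(z).argmax(dim=-1)          # label output by the controller
    for _ in range(steps):
        loss = F.cross_entropy(controller(z), target)  # cross-entropy w.r.t. z
        grad, = torch.autograd.grad(loss, z)
        with torch.no_grad():
            z -= eta * grad                            # scaled gradient update of z
    return z.detach()
```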
The noise is updated iteratively until it reaches a value about which the controller
is confident. A combination of decoding strategies is used while generating the
text: softmax with temperature is applied over the output probabilities and
greedy decoding is used to pick the next word. Top-k sampling is used to find a
replacement word whenever an <unk> token is encountered.
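A minimal sketch of this decoding step, assuming a vector of vocabulary logits at each position; the temperature and k values are illustrative:

```python
import torch

def decode_step(logits, unk_id, temperature=0.8, k=10):
    """Greedy decoding over temperature-scaled probabilities (Eq. (4));
    if the greedy choice is <unk>, fall back to top-k sampling."""
    probs = torch.softmax(logits / temperature, dim=-1)
    token = int(probs.argmax())                        # greedy choice
    if token == unk_id:
        top_p, top_idx = probs.topk(k)                 # restrict to the k best words
        top_p[top_idx == unk_id] = 0.0                 # never sample <unk> itself
        token = int(top_idx[torch.multinomial(top_p, 1)])
    return token
```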
The classifier is first trained on a balanced, downsampled dataset for 100
epochs. The same classifier is then fine-tuned for 5 epochs on a dataset that
combines the downsampled dataset with the generated dataset. The downsampled
dataset is kept in this combination to prevent the model from forgetting the
original data.
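A minimal sketch of this two-stage training, where classifier, downsampled_loader, and combined_loader are assumed placeholders and the epoch counts follow the description above:

```python
import torch
import torch.nn as nn

def run_epochs(model, loader, epochs, lr):
    """Standard supervised training loop with cross-entropy loss."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x_batch, y_batch in loader:
            loss = criterion(model(x_batch), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

# Stage 1: train on the balanced, downsampled data for 100 epochs.
run_epochs(classifier, downsampled_loader, epochs=100, lr=1e-3)
# Stage 2: fine-tune for 5 epochs on downsampled + generated data,
# keeping the original samples in the mix to avoid forgetting them.
run_epochs(classifier, combined_loader, epochs=5, lr=1e-4)
```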
The setup shown in Fig. 4 was modified by swapping out the transformer model
for the decoder of an RNN variational autoencoder (RNN VAE). This was done
to compare how transformer-based text generation affects the classifier.
Lastly, another set of classifiers was trained on a combination of the real
data and synthetic data generated by SMOTE. This was done to understand
how much SMOTE helps in text classification.
4 Results
To compare the different models explained in the previous section, accuracy
and F1-score were chosen as the metrics. The models were tested on a validation
set that was not part of the training set. The validation set was chosen such
that there is no significant skew in the proportion of class instances.
From the results, it can be inferred that fine-tuning the classifier on the
text generated by the transformer-based model produces better results. From
Table 1 and Table 2 it is evident that SMOTE does not function well with text
classification and prevents the classifier from generalizing.
The downside of this approach is that the text generated by the generator
depends on the random noise, so we have no control over what kind of text
will be generated. Another issue is memory consumption: while training the
classifier, at least three models are loaded into memory, i.e. the generator model,
the controller network, and the classifier network. If the models are large, this
can cause an out-of-memory error.
5 Conclusions
References
[1] Anjali, B., Reshma, R., Geetha Lekshmy, V.: Detection of Counterfeit News
Using Machine Learning. 2019 2nd International Conference on Intelligent
Computing, Instrumentation and Control Technologies, ICICICT 2019, pp.
1382–1386 (2019). https://doi.org/10.1109/ICICICT46008.2019.8993330
[2] Banik, S.: Covid fake news dataset (Nov 2020).
https://doi.org/10.5281/zenodo.4282522
[3] Bowman, S.R., Vilnis, L., Vinyals, O., Dai, A.M., Jozefowicz, R., Bengio, S.:
Generating sentences from a continuous space. CoNLL 2016 - 20th SIGNLL
Conference on Computational Natural Language Learning, Proceedings pp.
10–21 (2016). https://doi.org/10.18653/v1/k16-1002
[4] Dathathri, S., Madotto, A., Lan, J., Hung, J., Frank, E., Molino, P., Yosinski,
J., Liu, R.: Plug and play language models: A simple approach to controlled
text generation. arXiv pp. 1–34 (2019)
[5] Hinton, G., Vinyals, O., Dean, J.: Distilling the Knowledge in a Neural
Network pp. 1–9 (2015), http://arxiv.org/abs/1503.02531
[6] Hu, Z., Yang, Z., Liang, X., Salakhutdinov, R., Xing, E.P.: Toward controlled
generation of text. 34th International Conference on Machine Learning,
ICML 2017 4, 2503–2513 (2017)
[7] Keskar, N.S., McCann, B., Varshney, L.R., Xiong, C., Socher, R.: CTRL: A
conditional transformer language model for controllable generation. arXiv
pp. 1–18 (2019)
[8] Kingma, D.P., Welling, M.: Auto-encoding variational bayes. 2nd International
Conference on Learning Representations, ICLR 2014 - Conference Track
Proceedings, pp. 1–14 (2014)
[9] Klimt, B., Yang, Y.: The Enron corpus: A new dataset for email classification
research. pp. 217–226 (2004)
[10] Liu, D., Liu, G.: A Transformer-Based Variational Autoencoder for Sentence
Generation. Proceedings of the International Joint Conference on Neural
Networks (IJCNN), pp. 1–7 (2019).
https://doi.org/10.1109/IJCNN.2019.8852155
[11] Mansourifar, H., Shi, W.: Deep synthetic minority over-sampling technique.
arXiv (2020)
[12] Srinivasan, S., Ravi, V., Alazab, M., Ketha, S., Al-Zoubi, A.M., Kotti
Padannayil, S.: Spam Emails Detection Based on Distributed Word
Embedding with Deep Learning. Studies in Computational Intelligence
919(December), 161–189 (2021)
[13] Tolstikhin, I., Bousquet, O., Gelly, S., Schölkopf, B.: Wasserstein auto-
encoders. arXiv pp. 1–20 (2017)
[14] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N.,
Kaiser, L., Polosukhin, I.: Attention is all you need. Advances in Neural
Information Processing Systems, pp. 5999–6009 (2017)
[15] Vinayakumar, R., Soman, K.P., Poornachandran, P., Akarsh, S.: Application
of deep learning architectures for cyber security. Springer International
Publishing (2019)
[16] Vishagini, V., Rajan, A.K.: An Improved Spam Detection Method with
Weighted Support Vector Machine. 2018 International Conference on Data
Science and Engineering, ICDSE 2018 (2018).
https://doi.org/10.1109/ICDSE.2018.8527737