A Knowledge-Distillation-Integrated Pruning Method for Vision Transformer
Bangguo Xu1 , Tiankui Zhang1 , Yapeng Wang2 , Zeren Chen3
1 School of Information and Communication Engineering, Beijing University of Posts and Telecommunications,
Abstract—Vision transformers (ViTs) have made remarkable achievements in various computer vision applications such as image classification, object detection, and image segmentation. Since the self-attention mechanism can model the relationship between all pixels of the input image, the performance of ViT models is significantly improved compared with traditional CNN networks. However, their storage, runtime memory, and computing requirements hinder their deployment on edge devices. This paper proposes a ViT pruning method with knowledge distillation, which can prune the ViT model while avoiding performance loss after pruning. Based on the idea that knowledge distillation allows a student model to improve its performance by learning the unique knowledge of a teacher model, a convolutional neural network (CNN), which has the unique abilities of parameter sharing and local receptive fields, is used as the teacher model to guide the training of the ViT model and enable the ViT model to obtain the same abilities. In addition, some important parts may be cut during pruning, resulting in irreversible loss of model performance. To solve this problem, this paper designs an importance score learning module to guide the pruning work and ensure that pruning removes only the unimportant parts of the model. Finally, this paper compares the pruned model with other methods in terms of accuracy, Floating Point Operations (FLOPs), and model parameters on ImageNet-1K.
Index Terms—knowledge distillation, network pruning, transformer pruning, vision transformer
I. INTRODUCTION
The development of computer vision technology is inseparable from the promotion of convolutional neural networks. At present, the achievements of the CNN model in image recognition are close to a bottleneck, and with the development of the CNN model there has been a considerable accumulation of lightweight work for it. Neural network pruning methods have received more and more attention due to their widespread presence in compression tasks. By pruning a large number of unimportant parameters, or directly removing a large number of channels in each layer of the model structure, the size of the model can be reduced without loss of performance. The trimmed model not only saves a lot of storage space due to the reduction of parameters, but also reduces the computational complexity of the model, saves computing resources, and speeds up the inference speed of the model.

In recent years, with the introduction of the Transformer into the field of computer vision, its powerful global feature extraction ability has led it to surpass traditional CNN models in terms of accuracy. Compression work for the Vision Transformer (ViT) model has only just started. Due to the complex structure introduced by the self-attention mechanism and the low redundancy of the model parameters, how to tailor the ViT model without losing performance has become a challenge. Since the ViT model does not have a convolutional structure, channels cannot be directly deleted to achieve pruning, and using the pruning methods previously applied to CNN models will greatly damage the performance of the model.

Fig. 1. Parameter comparison of ResNet50 and ViT-B/16 models.

The heatmap shown in Fig. 1 illustrates the difference in model parameters between the regular CNN model ResNet50 and the regular ViT-B/16 model. The heatmap consists of 30×30 parameters randomly selected from the weight matrix; the lighter the color, the closer the parameter is to 0. It can be seen that the CNN model contains a large number of parameters close to 0, so it can be tailored without losing model performance. Most of the parameters of the ViT model are not close to 0, indicating that most of the parameters in the ViT model are very important. If the pruning methods designed for CNN models are used to delete them, serious loss of model performance will result, which greatly increases the difficulty of pruning. Therefore, how to prune
This work is supported by Key Technology Research Project of Jiangxi Province (20213AAE01007).
A. Pruning Location Analysis

First, analyze the parameters of the ViT model. In order to obtain the output features in (1), the tokens input to the model first go through the three matrices Q, K and V. The Floating Point Operations (FLOPs) of this calculation are 3 × n × d × (d + (d − 1)). Considering that the fully connected layer has a bias calculation at the end, the FLOPs for this part are 3 × n × d × (2d). Then the calculation of (2) is performed on Q, K and V, and the FLOPs of this part are 4n²d.

After that, as shown in (4) above, the self-attention matrix of dimension (n, d) is multiplied with the fully connected layer matrix of dimension (d, d), whose FLOPs are n × d × (2d), and then the obtained matrix is added to the input tokens, with FLOPs of n × d. In conclusion, the FLOPs of the ViT model in the MHSA part are 8nd² + 4n²d.

After obtaining the output Y of the MHSA part, Y is input to the MLP module, as shown in (5). Considering that the hidden dimension of the fully connected layers is generally 4 times the input dimension, the hidden layer dimension is 4d. It can be obtained by calculation that the FLOPs of this part are 16nd².

As for the normalization layer and the activation function layer, the input is only processed once and their FLOPs are n × d, which can be ignored compared with MHSA and MLP. Therefore, the pruning work of the model is mainly aimed at the MHSA and MLP parts.
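As a quick sanity check on the counting above, the per-block estimates (8nd² + 4n²d for MHSA and 16nd² for the MLP) can be reproduced with a short script. The values n = 197 and d = 768 below are the usual ViT-B/16 token count and embedding width and are only illustrative; they are not taken from the paper.

```python
# Rough per-block FLOPs estimate for one ViT encoder layer, following the
# counting convention used above (hedged sketch; constants are illustrative).
def vit_block_flops(n: int, d: int) -> dict:
    mhsa = 8 * n * d ** 2 + 4 * n ** 2 * d   # Q/K/V and output projections plus attention products
    mlp = 16 * n * d ** 2                    # two fully connected layers with a 4d hidden width
    other = n * d                            # LayerNorm / activation, negligible in comparison
    return {"MHSA": mhsa, "MLP": mlp, "other": other}

if __name__ == "__main__":
    flops = vit_block_flops(n=197, d=768)    # 196 patch tokens + 1 class token (assumed)
    total = sum(flops.values())
    for name, value in flops.items():
        print(f"{name:>5}: {value / 1e9:6.2f} GFLOPs ({100 * value / total:.1f}% of the block)")
```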
B. Distillation Token

The traditional ViT model takes a given image X ∈ R^(H×W×C), where H represents the height of the image, W the width, and C the number of channels, and converts it into N patches. After that, a fully connected layer is used to convert each patch of size 16 × 16 × 3 into a patch token. In addition to the N patch tokens, the ViT model also adds a class token that interacts with the ground truth to predict the classification result. On this basis, this method adds a distillation token that interacts with the output of the teacher model to realize the learning of teacher model knowledge, while also interacting normally with the other tokens to achieve global feature extraction:

T_emb = [T_cls; T_patch; T_dist]    (6)

where T_emb is the input of the ViT model with the distillation module, which extracts global feature information through the encoding layers of the ViT model, T_patch are the tokens converted from the input image, T_cls is the classification token used for prediction in the traditional ViT model, and T_dist is the distillation token introduced by this method.

After obtaining the model input T_emb, self-attention is computed on it. Specifically, T_emb is projected into the query matrix Q, the key matrix K and the value matrix V through fully connected layers:

Q = FC_Q(T_emb), K = FC_K(T_emb), V = FC_V(T_emb)    (7)

After getting the query matrix Q, the key matrix K and the value matrix V, the multi-head attention is calculated:

Attention_i(Q, K, V) = Softmax(QK^T / √d) V    (8)

Attention_total(Q, K, V) = [Attention_1(Q, K, V); ...; Attention_h(Q, K, V)]    (9)

where Attention_i(.) represents the i-th self-attention module in the multi-head attention, Attention_total(.) represents the concatenation of the outputs of the h self-attention modules, and h is the number of heads in the multi-head attention module.

In order to alleviate performance degradation and vanishing gradients when training multi-layer models, the output of the multi-head attention module is passed through a residual structure and a normalization layer:

Attention_output(Q, K, V) = LN(Attention_total(Q, K, V) + T_emb)    (10)

where LN denotes the normalization layer and Attention_output(.) denotes the final output of the multi-head attention module, which is also the input of the subsequent MLP module. The MLP module is composed of two fully connected layers, and its output is represented by Z:

Z = LN(Attention_output + FC_2(FC_1(Attention_output)))    (11)

For a ViT model with L layers, the output feature vector corresponding to the class token of the last layer is passed through the classifier to obtain the predicted distribution:

P_predict = Softmax(FC(Z_predict))    (12)

After getting the predicted distribution of the classifier, it interacts with the true label of the sample, and the cross-entropy loss function is used to obtain the loss:

L_base = − Σ_{i∈N} Σ_{c∈C} 1[y_i = c] · log(P_predict(y_i = c))    (13)

At the same time, the output feature vector corresponding to the distillation token is extracted, and a separate classifier is set for it to obtain its predicted distribution:

P_dist = Softmax(FC(Z_dist))    (14)

The prediction distribution obtained by the above formula interacts with the prediction results of the teacher model, and the cross-entropy loss function is likewise used to construct the loss:

L_dist = − Σ_{i∈N} Σ_{c∈C} 1[y_i = c] · log(P_dist(y_i = c))    (15)

where c here is the label output by the teacher model after the softmax function; its type is the same as the ground truth but its value may be slightly different.
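The token layout and the two prediction heads of (6)-(14) can be sketched as a single encoder block in PyTorch. This is a minimal illustration under our own naming, not the authors' implementation; positional embeddings, the patch-embedding layer, and the stacking of multiple layers are omitted.

```python
import torch
import torch.nn as nn

class DistilledViTBlockSketch(nn.Module):
    """Minimal sketch of Eqs. (6)-(14): class + patch + distillation tokens,
    one MHSA/MLP encoder block, and two separate classifier heads."""
    def __init__(self, d=768, heads=12, num_classes=1000):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, d))
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)    # Eqs. (7)-(9)
        self.norm1 = nn.LayerNorm(d)                                     # Eq. (10)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.norm2 = nn.LayerNorm(d)                                     # Eq. (11)
        self.head_cls = nn.Linear(d, num_classes)                        # Eq. (12)
        self.head_dist = nn.Linear(d, num_classes)                       # Eq. (14)

    def forward(self, patch_tokens):                   # patch_tokens: (B, N, d)
        b = patch_tokens.size(0)
        t_emb = torch.cat([self.cls_token.expand(b, -1, -1),
                           patch_tokens,
                           self.dist_token.expand(b, -1, -1)], dim=1)    # Eq. (6)
        attn_out, _ = self.attn(t_emb, t_emb, t_emb)
        x = self.norm1(attn_out + t_emb)                                 # Eq. (10)
        z = self.norm2(x + self.mlp(x))                                  # Eq. (11)
        logits_cls = self.head_cls(z[:, 0])             # class-token prediction
        logits_dist = self.head_dist(z[:, -1])          # distillation-token prediction
        return logits_cls, logits_dist
```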
Finally, α is used as a hyperparameter to balance the distillation loss and the conventional loss, and the loss function is defined as:

L_global = α L_dist + (1 − α) L_base    (16)

In experiments, we found that the final accuracy of the model obtained by the proposed method as the student model can not only surpass that before distillation, but can even surpass the teacher model. This shows that the student model can learn the unique inductive bias of the CNN model by distilling it, that is, the parameter sharing and local receptive field unique to CNNs, thereby improving the student model's ability to deal with image problems. In the proposed method, if only the output of the classifier of the teacher model is learned, overfitting easily occurs. To solve this problem, we introduce soft distillation into the proposed method and control the influence of the teacher model on the student model by adjusting the distillation temperature T. The formula of soft distillation is as follows:

L_soft = T² KL(Softmax(Z_predict / T), Softmax(Z_teacher / T))    (17)

where KL(.) represents the KL divergence. The proposed method sets the hyperparameter β to control the influence of soft distillation on the overall distillation loss of the model, so the loss function of the final model is expressed as:

L_global = α L_dist + (1 − α) L_base + β L_soft    (18)
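A compact sketch of the overall objective in (13) and (15)-(18) is given below; the values of α, β and T are placeholders rather than the settings used in the paper, and the teacher's hard label is taken here as its argmax prediction.

```python
import torch
import torch.nn.functional as F

def global_distillation_loss(logits_cls, logits_dist, teacher_logits, labels,
                             alpha=0.5, beta=0.5, temperature=3.0):
    """Sketch of Eqs. (13), (15)-(18); alpha/beta/temperature are illustrative."""
    l_base = F.cross_entropy(logits_cls, labels)                      # Eq. (13), true labels
    teacher_labels = teacher_logits.argmax(dim=1)                     # hard teacher label
    l_dist = F.cross_entropy(logits_dist, teacher_labels)             # Eq. (15)
    l_soft = F.kl_div(F.log_softmax(logits_cls / temperature, dim=1),
                      F.softmax(teacher_logits / temperature, dim=1),
                      reduction="batchmean") * temperature ** 2       # Eq. (17)
    return alpha * l_dist + (1 - alpha) * l_base + beta * l_soft      # Eq. (18)
```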
C. Importance Score Learning Module

In order to ensure that the pruned dimensions are the unimportant or even redundant parts of the model, this method introduces an importance score learning module to evaluate the importance of each dimension. The specific method is to add a fully connected layer before the layer to be pruned; the importance score of the corresponding dimension is obtained by learning the parameters of this layer. Dimensions with high scores are retained and the others are deleted. In order to prevent the importance scores from being so close together that it becomes difficult to continue the pruning work, this method sparsely trains the parameters of this layer separately by adding the L1 norm as a penalty term to the loss function. First, the loss function with the penalty term is:

J(ω; X, y) = L(ω; X, y) + α Ω(ω)    (19)

where X is the input image, y is the label corresponding to the image, ω is the parameter of the model, L is the initial loss function of the model, and Ω is the penalty term, Ω(ω) = ‖ω‖₁ = Σᵢ |ωᵢ|. Let ω* be the optimal solution of L, then:

L(ω; X; y) = L(ω*; X; y) + L′(ω*; X; y)(ω − ω*) + ½ L″(ω*; X; y)(ω − ω*)²    (20)

After deduction, we can get:

J(ω; X, y) = L(ω*; X, y) + ½ H(ω − ω*)² + α‖ω‖₁    (21)

where H is the Hessian matrix, equivalent to the second derivative of L(ω*; X; y) with respect to ω*, and the optimal solution of the loss function (21) is:

ωᵢ = sign(ωᵢ*) · max(|ωᵢ*| − α / H_{i,i}, 0)    (22)

It can be clearly seen that the solution of the loss function is sparse after the introduction of the L1 norm. After obtaining the sparse importance score learning module, we also obtain the importance score of each dimension in the pruning layer; the dimensions are then sorted according to their scores, and the dimensions with low scores are regarded as the unimportant part and dropped. The workflow of the module after sparse training is as follows.

First, let the importance scores output by the module be a. A threshold γ is obtained according to the pre-defined pruning rate; the entries of a below the threshold γ are set to zero, and the entries above the threshold γ are set to 1, to obtain the discrete a*.

Multiplying T_emb from (6) by the discrete a* gives the trimmed model input, named T_emb*:

T_emb* = Prune(T_emb)    (23)

Due to the unstructured pruning completed on T_emb, some of the weights of the fully connected layers FC_Q, FC_K and FC_V that operate on it in (7) also lose their meaning (the reason is that the essence of the fully connected layer operation is a matrix product: if some entries of the input of the fully connected layer are set to zero, it is equivalent to setting a row or a column of the matrix to zero), so the meaningless weights of the fully connected layers can be directly set to zero to obtain FC_Q*, FC_K* and FC_V*, followed by the operation:

Q* = FC_Q*(T_emb*), K* = FC_K*(T_emb*), V* = FC_V*(T_emb*)    (24)

After the three matrices Q*, K* and V* are obtained, the operations of (8) to (10) are performed to obtain Attention_output(Q*, K*, V*), and then the importance score learning module located in the MLP part of the model is used to prune FC_1 and FC_2 of (11). Denote the pruned fully connected layers by FC_1* and FC_2*. The output is then:

Z* = LN(Attention_output(Q*, K*, V*) + FC_2*(FC_1*(Attention_output(Q*, K*, V*))))    (25)

The pruned FC_1*, FC_2* in (25) represent the two fully connected layers of the MLP part of the model. Through the above operations, this method realizes the pruning of the MHSA and MLP parts.
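A minimal sketch of the importance score learning module is shown below, assuming the scores are kept as a learnable vector with one entry per embedding dimension: the L1 term of (19) is exposed as a penalty, and the thresholding that produces the discrete a* and the pruned input of (23) is implemented with a pruning-rate-derived threshold γ. All names are ours, not the authors'.

```python
import torch
import torch.nn as nn

class ImportanceScoreModule(nn.Module):
    """Sketch of the importance score learning module (Sec. III-C).
    A learnable score per embedding dimension is trained with an L1 penalty
    and later binarized into a pruning mask; names are illustrative."""
    def __init__(self, dim: int):
        super().__init__()
        self.scores = nn.Parameter(torch.ones(dim))    # importance score a for each dimension

    def l1_penalty(self) -> torch.Tensor:
        return self.scores.abs().sum()                 # Ω(ω) = ||ω||_1 in Eq. (19)

    def binarize(self, prune_ratio: float) -> torch.Tensor:
        # Threshold γ derived from the pre-defined pruning rate:
        # scores below γ become 0, scores above γ become 1 (the discrete a*).
        k = int(prune_ratio * self.scores.numel())
        gamma = torch.kthvalue(self.scores.abs(), max(k, 1)).values
        return (self.scores.abs() > gamma).float()

    def forward(self, t_emb: torch.Tensor, prune_ratio: float) -> torch.Tensor:
        mask = self.binarize(prune_ratio)
        return t_emb * mask                            # Eq. (23): T*_emb = Prune(T_emb)
```

During sparse training the continuous scores would be used directly so that the L1 gradient is meaningful; the binarization step is only applied when a pruning decision is made, consistent with the "after sparse training" workflow described above.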
D. Pruning Strategy

Considering the structural complexity of the ViT model itself and its low parameter redundancy, using the traditional training-pruning-finetuning strategy will lead to irreversible performance loss. Referring to the most advanced pruning strategies currently applied to CNN models, this method uses sparse distillation training as its strategy. Specifically, it first initializes the importance scores of the prunable dimensions of the model, then sorts the dimensions according to the scores obtained from initialization, zeroing the unimportant parameters instead of completely deleting them in order to facilitate the subsequent exploration of sparsity. Then, the distillation loss is introduced to perform distillation training on the model. After the distillation training is completed, the sparsity strategy of the sparsely trained model is explored to obtain a new sparse model. After that, the sparse model is trained by distillation again and the above operations are repeated. The specific process is as follows.

First, initialize the sparse ratio of each layer. Let W = (W^(1), ..., W^(L)) be the parameters of the ViT model, where L represents the total number of layers in the model. Each layer is initialized with a sparse ratio s = (s^(1); ...; s^(L)), where s^(l) represents the ratio of layer l:

s^(l) = ((τ_max + τ_min) / 4) · (1 + cos(t × 2π / T_end))    (26)

where τ_max represents the preset maximum sparsity rate, τ_min represents the preset minimum sparsity rate, t represents the number of iterations of the current sparse exploration and is initialized to 0, and T_end represents the total number of iterations of the sparse exploration.
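Equation (26), as reconstructed here, gives a cyclic per-layer sparse ratio; a direct transcription is sketched below with illustrative values for τ_max, τ_min and T_end.

```python
import math

def layer_sparsity(t: int, t_end: int, tau_max: float, tau_min: float) -> float:
    """Per-layer sparse ratio of Eq. (26); t is the current sparse-exploration iteration."""
    return (tau_max + tau_min) / 4.0 * (1.0 + math.cos(t * 2.0 * math.pi / t_end))

# With these (illustrative) settings the ratio starts at (tau_max + tau_min) / 2,
# decreases to 0 half-way through the exploration, and returns to the start at t = T_end.
for t in range(0, 11):
    print(t, round(layer_sparsity(t, t_end=10, tau_max=0.5, tau_min=0.1), 3))
```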
Considering that the importance score module has not yet been trained in the first iteration, the MHSA and MLP parts are first trimmed according to a gradient strategy, and the importance score formula of the MHSA part is as follows:

H^(l,h) = (A^(l,h))^T · ∂L(X^(l)) / ∂A^(l,h)    (27)

where H^(l,h) represents the importance score of the h-th attention head of the l-th layer based on the gradient strategy; the higher the score, the more important the attention head is. X^(l) represents the tokens input to the l-th layer, A^(l,h) represents the output features of the attention head, and L(.) represents the cross-entropy loss function.

The importance score formula for the MLP part is similar:

W^(l,h) = (O^(l,h))^T · ∂L(A^(l,h)) / ∂O^(l,h)    (28)

where W^(l,h) represents the importance score of the h-th dimension of the l-th layer of the MLP part based on the gradient strategy; the higher the score, the more important the dimension is. A^(l,h) denotes the output of the MHSA part of the l-th layer, which is the input of the MLP part of the l-th layer, and O^(l,h) represents the output features of the MLP part.
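The gradient-based scores of (27) and (28) resemble a Taylor-style saliency: the product of an activation with the gradient of the loss with respect to that activation, reduced to one score per head (or per MLP dimension). The reduction over batch and token axes below, and the use of the absolute value, are our assumptions.

```python
import torch

def head_importance(attn_output: torch.Tensor, loss: torch.Tensor) -> torch.Tensor:
    """Gradient-based head importance in the spirit of Eq. (27).

    attn_output: per-head output features A^(l,h) of shape (B, H, N, d_head),
    still attached to the autograd graph that produced `loss`.
    Returns one score per head, shape (H,)."""
    grads = torch.autograd.grad(loss, attn_output, retain_graph=True)[0]   # dL/dA^(l,h)
    scores = (attn_output * grads).sum(dim=(0, 2, 3))                      # inner product per head
    return scores.abs()
```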
After the unstructured pruning is completed, the model is updated using the distillation loss function. The model parameter update process is as follows. First, for a given input image, a series of tokens is obtained through the patch transformation and set as T_patch, and then the classification token and distillation token are added to T_patch as in (6) to form T_emb.

After getting T_emb, the model inputs T_emb to the importance score learning module, which zeros out the tokens with low importance scores and prevents them from interacting with the other tokens. Since the unimportant tokens are zeroed and do not participate in the calculation in this cycle, the corresponding MHSA and MLP parts are also equivalent to being deleted.

Denote the pruned T_emb by T_emb*. After T_emb* passes through equations (6) to (18), the loss function L_global is obtained, which is then used to update the model parameters excluding the importance score learning module:

W = W − η · ∇_W L_global    (29)

where η represents the learning rate.

In order to enforce sparsity of the importance scores, the loss function of the importance score learning module must be J in (19), and this loss function is then used to update the parameters of the importance score learning module:

W = W − η · ∇_W J(W)    (30)

After updating the model parameters, the model computes a new sparsity ratio for each layer according to (26), and performs pruning according to the importance score learning module and distillation training on the model parameters until the number of repetitions reaches the preset cycle number.
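Putting the pieces together, the alternating sparse-distillation cycle of (26), (29) and (30) could be outlined as below. This is only a schematic reading of the procedure: `model` stands for a hypothetical pruned-ViT wrapper that accepts a pruning ratio, the helper functions come from the earlier sketches, and in practice the hard mask of the importance score module would need a straight-through or soft relaxation for gradients to reach the scores.

```python
import torch

def train_one_cycle(model, score_modules, teacher, loader, t, t_end,
                    tau_max=0.5, tau_min=0.1, lr=1e-4, l1_weight=1e-4):
    """One sparse-exploration cycle (schematic; names and hyperparameters are assumptions)."""
    opt_model = torch.optim.AdamW(model.parameters(), lr=lr)
    opt_scores = torch.optim.AdamW([p for m in score_modules for p in m.parameters()], lr=lr)
    ratio = layer_sparsity(t, t_end, tau_max, tau_min)    # Eq. (26): refresh the sparse ratio

    for images, labels in loader:
        with torch.no_grad():
            teacher_logits = teacher(images)
        logits_cls, logits_dist = model(images, prune_ratio=ratio)
        l_global = global_distillation_loss(logits_cls, logits_dist, teacher_logits, labels)
        j = l_global + l1_weight * sum(m.l1_penalty() for m in score_modules)   # Eq. (19)

        opt_model.zero_grad()
        opt_scores.zero_grad()
        j.backward()          # the L1 term has zero gradient w.r.t. backbone weights,
        opt_model.step()      # so this step follows Eq. (29) for the backbone
        opt_scores.step()     # and Eq. (30) for the importance-score parameters
```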
IV. EXPERIMENTS

In order to compare with other ViT pruning methods, the dataset used in this section is ImageNet-1K. The training set and test set are divided as shown in Table I.

TABLE I
DATASET DETAILS

Dataset Name    Number of Classes    Training set    Testing set
ImageNet-1K     1000                 1281167         500000

In the experiments, the training batch size for all datasets is 16 and the optimizer is AdamW. The learning rate of the classifier is initialized to 0.001 and the learning rate of the backbone network to 0.0001. Training first uses a warm-up learning rate of 0.0001 for 30 epochs; after the 31st epoch, the initial learning rate is used and is multiplied by a factor of 0.1 every 8 epochs to continue to converge.
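The warm-up plus step-decay schedule described above could be expressed, under our reading of the text, roughly as follows; the parameter grouping and the exact epoch boundaries are assumptions.

```python
import torch

model = torch.nn.Linear(768, 1000)   # stand-in module; the real model has separate parameter groups
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)   # classifier rate; backbone would use 1e-4

def lr_factor(epoch, base_lr=1e-3, warmup_lr=1e-4, warmup_epochs=30, step=8, gamma=0.1):
    # 0.0001 warm-up for the first 30 epochs, the initial rate from epoch 31,
    # then multiplied by 0.1 every 8 epochs.
    if epoch < warmup_epochs:
        return warmup_lr / base_lr
    return gamma ** ((epoch - warmup_epochs) // step)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)

for epoch in range(60):
    # ... one training epoch would go here ...
    optimizer.step()      # placeholder step so the optimizer/scheduler ordering is valid
    scheduler.step()
```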
Through simulation experiments, the method proposed in this study achieves a classification accuracy of 82.19% on ImageNet-1K after pruning the model, while the ViT-B/16 model used in the experiment reaches only 77.9% on ImageNet-1K before pruning. That is to say, while pruning the model parameters to achieve model compression, the performance of this method is not only not degraded but improved by 4.29%, which surpasses most existing research methods. The experimental results and the comparison with existing research methods on the corresponding dataset are shown in Table II. The table gives the backbone network used and the classification accuracy, which are reproduced according to the officially published results of the original papers or in the same experimental environment as this experiment. Overall, the performance of the method that introduces knowledge distillation into pruning is greatly improved compared with most pruning or knowledge distillation methods. It can be seen that introducing knowledge distillation into the field of pruning is feasible and effective.

TABLE II
IMAGENET-1K DATASET CLASSIFICATION COMPREHENSIVE COMPARISON

Other Methods         Backbone Network   Accuracy (Top-1)   Params   FLOPs
Han et al. [1]        TNT/B              83.6%              65.6M    42.3B
Wang et al. [2]       PVT                81.7%              61.4M    39.6B
Li et al. [8]         T2T-ViT            82.3%              64.1M    41.3B
Hugo et al. [11]      DeiT/B             81.9%              86M      55.5B
Stéphane et al. [12]  ConViT             82.4%              86M      55.5B
Yang et al. [13]      NViT               83.1%              86M      55.5B
Zhu et al. [14]       VTP                80.7%              48M      31.0B
He et al. [15]        SPViT              81.6%              62.3M    40.2B
PDIP (ours)           ViT-B/16           82.2%              56.5M    36.4B
In order to show that changing the model pruning strategy during training does not affect the final convergence of the model, this paper also analyzes the degree of convergence when training the model. As shown in Fig. 2, although the model fluctuated slightly during the training process, the overall trend of the accuracy was a steady increase; it reached its maximum value around the 300th epoch and eventually stabilized at this value. Training the model took about 630 hours on an RTX 3090 graphics card. Although the training cost is large, the inference speed of the pruned model is significantly improved due to the reduction of FLOPs, and a large amount of memory space is saved because a large number of parameters are pruned.
V. CONCLUSION

In this paper, we have proposed a new importance score evaluation index to make sure that only the parts of the model that are not important are pruned. At the same time, we have added distillation to the input part of the model and realized the learning of teacher model knowledge by interacting with the output results of the teacher model through distillation. Finally, a process of pruning while distilling has been designed to realize the combination of pruning and knowledge distillation.

REFERENCES

[1] Han K, Xiao A, Wu E, et al. Transformer in transformer[J]. Advances in Neural Information Processing Systems, 2021, 34.
[2] Wang W, Xie E, Li X, et al. Pyramid vision transformer: A versatile backbone for dense prediction without convolutions[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021: 568-578.
[3] Han S, Pool J, Tran J, et al. Learning both weights and connections for efficient neural network[C]//Advances in Neural Information Processing Systems 28. Curran Associates, Inc., 2015: 1135-1143.
[4] Molchanov P, Mallya A, Tyree S, et al. Importance estimation for neural network pruning[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019: 11264-11272.
[5] LeCun Y, Denker J S, Solla S A. Optimal brain damage[C]//Proceedings of the 2nd International Conference on Neural Information Processing Systems (NIPS). 1989: 598-605.
[6] Liu Z, Li J, Shen Z, et al. Learning efficient convolutional networks through network slimming[C]//Proceedings of the IEEE International Conference on Computer Vision. 2017: 2736-2744.
[7] He Y, Zhang X, Sun J. Channel pruning for accelerating very deep neural networks[C]//Proceedings of the IEEE International Conference on Computer Vision. 2017.
[8] Yuan L, Chen Y, Wang T, et al. Tokens-to-token ViT: Training vision transformers from scratch on ImageNet[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021: 558-567.
[9] Hinton G E, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.
[10] Abnar S, Dehghani M, Zuidema W. Transferring inductive biases through knowledge distillation[J]. arXiv preprint arXiv:2006.00555, 2020.
[11] Touvron H, Cord M, Douze M, et al. Training data-efficient image transformers & distillation through attention[C]//International Conference on Machine Learning. PMLR, 2021: 10347-10357.
[12] d'Ascoli S, Touvron H, Leavitt M L, et al. ConViT: Improving vision transformers with soft convolutional inductive biases[C]//International Conference on Machine Learning. PMLR, 2021: 2286-2296.
[13] Yang H, Yin H, Molchanov P, et al. NViT: Vision transformer compression and parameter redistribution[J]. arXiv preprint arXiv:2110.04869, 2021.
[14] Zhu M, Tang Y, Han K. Vision transformer pruning[J]. arXiv preprint arXiv:2104.08500, 2021.
[15] He H, Liu J, Pan Z, et al. Pruning self-attentions into convolutional layers in single path[J]. arXiv preprint arXiv:2111.11802, 2021.
[16] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. Advances in Neural Information Processing Systems, 2017, 30.