Swin-Unet: Unet-Like Pure Transformer For Medical Image Segmentation
1 Introduction
patches. Each patch is treated as a token and fed into the Transformer-based encoder to learn deep feature representations. The extracted context features are then up-sampled by the decoder with a patch expanding layer and fused with the multi-scale features from the encoder via skip connections, so as to restore the spatial resolution of the feature maps and further perform segmentation prediction. Extensive experiments on multi-organ and cardiac segmentation datasets indicate that the proposed method has excellent segmentation accuracy and robust generalization ability. Concretely, our contributions can be summarized as:
(1) Based on the Swin Transformer block, we build a symmetric Encoder-Decoder architecture with skip connections. In the encoder, self-attention from local to global is realized; in the decoder, the global features are up-sampled to the input resolution for corresponding pixel-level segmentation prediction. (2) A patch expanding layer is developed to achieve up-sampling and feature dimension increase without using convolution or interpolation operations. (3) It is found in the experiments that skip connections are also effective for Transformers, so a pure Transformer-based U-shaped Encoder-Decoder architecture with skip connections, named Swin-Unet, is finally constructed.
2 Related work
CNN-based methods: Early medical image segmentation methods are mainly contour-based and traditional machine learning-based algorithms [20,21]. With the development of deep CNNs, U-Net was proposed in [3] for medical image segmentation. Due to the simplicity and superior performance of the U-shaped structure, various Unet-like methods have constantly emerged, such as Res-UNet [7], Dense-UNet [22], U-Net++ [8] and UNet3+ [9]. U-shaped architectures have also been introduced into the field of 3D medical image segmentation, such as 3D-Unet [6] and V-Net [23]. At present, CNN-based methods have achieved tremendous success in the field of medical image segmentation due to their powerful representation ability.
Vision transformers: Transformer was first proposed for the machine translation task in [15]. In the NLP domain, Transformer-based methods have achieved state-of-the-art performance in various tasks [24]. Driven by Transformer's success, researchers introduced a pioneering vision transformer (ViT) in [17], which achieved an impressive speed-accuracy trade-off on image recognition tasks. Compared with CNN-based methods, the drawback of ViT is that it requires pre-training on a large dataset. To alleviate the difficulty of training ViT, DeiT [18] describes several training strategies that allow ViT to train well on ImageNet. Recently, several excellent works have been done based on ViT [25,26,19]. It is worth mentioning that an efficient and effective hierarchical vision Transformer, called Swin Transformer, is proposed as a vision backbone in [19]. Based on the shifted windows mechanism, Swin Transformer achieved state-of-the-art performance on various vision tasks including image classification, object detection and semantic segmentation. In this work, we attempt to use the Swin Transformer block as the basic unit to build a U-shaped Encoder-Decoder
architecture with skip connections for medical image segmentation, thus providing a benchmark comparison for the development of Transformers in the medical image field.
3 Method
3.1 Architecture overview
The overall architecture of the proposed Swin-Unet is presented in Fig. 1. Swin-Unet consists of an encoder, a bottleneck, a decoder and skip connections. The basic unit of Swin-Unet is the Swin Transformer block [19]. For the encoder, to transform the inputs into sequence embeddings, the medical images are split into non-overlapping patches with a patch size of 4 × 4. With this partition approach, the feature dimension of each patch becomes 4 × 4 × 3 = 48. Furthermore, a linear embedding layer is applied to project the feature dimension into an arbitrary dimension (denoted as C). The transformed patch tokens pass through several Swin Transformer blocks and patch merging layers to generate hierarchical feature representations. Specifically, the patch merging layer is responsible for down-sampling and increasing the dimension, and the Swin Transformer block is responsible for feature representation learning. Inspired by U-Net [3], we design a symmetric transformer-based decoder. The decoder is composed of Swin Transformer blocks and patch expanding layers. The extracted context features are fused with multi-scale features from the encoder via skip connections to complement the loss of spatial information caused by down-sampling. In contrast to the patch merging layer, a patch expanding layer is specially designed to perform up-sampling. The patch expanding layer reshapes feature maps of adjacent dimensions into a larger feature map with 2× up-sampling of resolution. In the end, the last patch expanding layer is used to perform 4× up-sampling to restore the resolution of the feature maps to the input resolution (W × H), and then a linear projection layer is applied on these up-sampled features to output the pixel-level segmentation predictions. We will elaborate on each block in the following.
3.2 Swin Transformer block
Different from the conventional multi-head self-attention (MSA) module, the Swin Transformer block [19] is constructed based on shifted windows. In Fig. 2, two consecutive Swin Transformer blocks are presented. Each Swin Transformer block is composed of a LayerNorm (LN) layer, a multi-head self-attention module, a residual connection and a 2-layer MLP with GELU non-linearity. The window-based multi-head self-attention (W-MSA) module and the shifted window-based multi-head self-attention (SW-MSA) module are applied in the two successive transformer blocks, respectively. Within each window, self-attention is computed as:
\[
\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\left(\frac{QK^{T}}{\sqrt{d}} + B\right)V, \tag{5}
\]
where $Q, K, V \in \mathbb{R}^{M^{2} \times d}$ denote the query, key and value matrices; $M^{2}$ and $d$ represent the number of patches in a window and the dimension of the query or key, respectively. The values in $B$ are taken from the bias matrix $\hat{B} \in \mathbb{R}^{(2M-1)\times(2M-1)}$.
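A minimal PyTorch sketch of window-based self-attention with the learned relative position bias B of Eq. (5) is given below; the class and argument names are our own, and window partitioning/shifting is assumed to be handled outside this module:

```python
import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    """Multi-head self-attention over one M x M window with a learned relative
    position bias B, as in Eq. (5). Window partitioning is assumed external."""
    def __init__(self, dim: int, window_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        m = window_size
        # Bias table B_hat with (2M-1)*(2M-1) entries per head; B is gathered from it.
        self.bias_table = nn.Parameter(torch.zeros((2 * m - 1) ** 2, num_heads))
        coords = torch.stack(torch.meshgrid(torch.arange(m), torch.arange(m), indexing="ij"))
        coords = coords.flatten(1)                           # (2, M*M)
        rel = coords[:, :, None] - coords[:, None, :]        # (2, M*M, M*M)
        rel = rel.permute(1, 2, 0) + (m - 1)                 # shift offsets to [0, 2M-2]
        self.register_buffer("bias_index", rel[..., 0] * (2 * m - 1) + rel[..., 1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_windows*B, M*M, dim) -- the tokens of each window as a sequence.
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]                     # each (b, heads, n, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale        # QK^T / sqrt(d)
        bias = self.bias_table[self.bias_index].permute(2, 0, 1)  # (heads, n, n)
        attn = (attn + bias.unsqueeze(0)).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, c)
        return self.proj(out)

# Example: 4 windows of 7x7 = 49 tokens, dim 96, 3 heads.
out = WindowAttention(dim=96, window_size=7, num_heads=3)(torch.randn(4, 49, 96))
print(out.shape)  # torch.Size([4, 49, 96])
```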
3.3 Encoder
Patch merging layer: The input patches are divided into 4 parts and concatenated together by the patch merging layer. With such processing, the feature resolution is down-sampled by 2×. And, since the concatenation operation results in the feature dimension increasing by 4×, a linear layer is applied on the concatenated features to project the feature dimension to 2× the original dimension.
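A minimal PyTorch sketch of such a patch merging layer is shown below (names are our own; we assume tokens are stored as a (B, H·W, C) sequence with known H and W):

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Group each 2x2 neighbourhood of patches, concatenate their features
    (dimension 4C), and linearly project to 2C: resolution /2, dimension x2."""
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        b, _, c = x.shape
        x = x.view(b, h, w, c)
        # Take the four interleaved sub-grids and concatenate along channels.
        x = torch.cat([x[:, 0::2, 0::2], x[:, 1::2, 0::2],
                       x[:, 0::2, 1::2], x[:, 1::2, 1::2]], dim=-1)  # (B, H/2, W/2, 4C)
        x = x.view(b, -1, 4 * c)                                     # (B, H/2 * W/2, 4C)
        return self.reduction(self.norm(x))                          # (B, H/2 * W/2, 2C)

# Example: 56x56 tokens with C = 96 become 28x28 tokens with dimension 192.
out = PatchMerging(96)(torch.randn(1, 56 * 56, 96), 56, 56)
print(out.shape)  # torch.Size([1, 784, 192])
```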
3.4 Bottleneck
Since Transformer is too deep to be converged [33], only two successive Swin
Transformer blocks are used to constructed the bottleneck to learn the deep
feature representation. In the bottleneck, the feature dimension and resolution
are kept unchanged.
3.5 Decoder
Corresponding to the encoder, the symmetric decoder is built based on the Swin Transformer block. To this end, in contrast to the patch merging layer used in the encoder, we use a patch expanding layer in the decoder to up-sample the extracted deep features. The patch expanding layer reshapes the feature maps of adjacent dimensions into a higher-resolution feature map (2× up-sampling) and accordingly reduces the feature dimension to half of the original dimension.
Patch expanding layer: Take the first patch expanding layer as an example: before up-sampling, a linear layer is applied on the input features (W/32 × H/32 × 8C) to increase the feature dimension to 2× the original dimension (W/32 × H/32 × 16C). Then, a rearrange operation is used to expand the resolution of the input features to 2× the input resolution and reduce the feature dimension to a quarter of the input dimension (W/32 × H/32 × 16C → W/16 × H/16 × 4C). We will discuss the impact of using the patch expanding layer to perform up-sampling in Section 4.5.
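Assuming the einops library provides the rearrange operation mentioned above, a minimal PyTorch sketch of the patch expanding layer could look as follows (names and the token layout are our assumptions):

```python
import torch
import torch.nn as nn
from einops import rearrange

class PatchExpanding(nn.Module):
    """First project C -> 2C with a linear layer, then rearrange the 2C channels
    into a 2x2 spatial block of C/2 channels: resolution x2, dimension /2."""
    def __init__(self, dim: int):
        super().__init__()
        self.expand = nn.Linear(dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(dim // 2)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        x = self.expand(x)  # (B, H*W, C) -> (B, H*W, 2C)
        # Split the 2C channels into a 2x2 spatial block of C/2 channels each.
        x = rearrange(x, "b (h w) (p1 p2 c) -> b (h p1 w p2) c",
                      h=h, w=w, p1=2, p2=2)                  # (B, 2H * 2W, C/2)
        return self.norm(x)

# Example: a W/32 x H/32 x 8C map (7x7x768 for C = 96 and 224x224 input)
# becomes a W/16 x H/16 x 4C map (14x14x384).
out = PatchExpanding(768)(torch.randn(1, 7 * 7, 768), 7, 7)
print(out.shape)  # torch.Size([1, 196, 384])
```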
4 Experiments
4.1 Datasets
Synapse multi-organ segmentation dataset (Synapse): the dataset includes 30 cases with 3779 axial abdominal clinical CT images. Following [2,34],
18 samples are divided into the training set and 12 samples into the testing set. The average Dice Similarity Coefficient (DSC) and average Hausdorff Distance (HD) are used as evaluation metrics to evaluate our method on 8 abdominal organs (aorta, gallbladder, left kidney, right kidney, liver, pancreas, spleen, stomach).
Automated cardiac diagnosis challenge dataset (ACDC): the ACDC dataset is collected from different patients using MRI scanners. For each patient's MR image, the left ventricle (LV), right ventricle (RV) and myocardium (MYO) are labeled. The dataset is split into 70 training samples, 10 validation samples and 20 testing samples. Similar to [2], only the average DSC is used to evaluate our method on this dataset.
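For reference, the two evaluation metrics can be computed per organ class roughly as in the sketch below (NumPy/SciPy based; the exact HD protocol, e.g. a 95th-percentile variant, follows the benchmark and is not specified here):

```python
import numpy as np
from scipy.spatial.distance import directed_hausdorff

def dice_coefficient(pred: np.ndarray, target: np.ndarray, eps: float = 1e-6) -> float:
    """DSC = 2|P ∩ G| / (|P| + |G|) for binary masks of one organ class."""
    pred, target = pred.astype(bool), target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

def hausdorff_distance(pred: np.ndarray, target: np.ndarray) -> float:
    """Symmetric Hausdorff distance between the foreground point sets of two
    binary masks, in voxel units. Assumes both masks contain foreground."""
    p_pts = np.argwhere(pred.astype(bool))
    g_pts = np.argwhere(target.astype(bool))
    return max(directed_hausdorff(p_pts, g_pts)[0],
               directed_hausdorff(g_pts, p_pts)[0])

# Per-case scores are averaged over the 8 organ classes and over all test cases.
```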
indicates that our approach can achieve better edge predictions. The segmentation results of different methods on the Synapse multi-organ CT dataset are shown in Fig. 3. It can be seen from the figure that CNN-based methods tend to have over-segmentation problems, which may be caused by the locality of convolution operations. In this work, we demonstrate that by integrating Transformer with a U-shaped architecture with skip connections, a pure Transformer approach without convolution can better learn both global and long-range semantic information interactions, resulting in better segmentation results.
Ablation study on the up-sampling approach (average and per-organ DSC, %):
Up-sampling | DSC (avg) | Aorta | Gallbladder | Kidney(L) | Kidney(R) | Liver | Pancreas | Spleen | Stomach
Bilinear interpolation | 76.15 | 81.84 | 66.33 | 80.12 | 73.91 | 93.64 | 55.04 | 86.10 | 72.20
Transposed convolution | 77.63 | 84.81 | 65.96 | 82.66 | 74.61 | 94.39 | 54.81 | 89.42 | 74.41
Patch expand | 79.13 | 85.47 | 66.53 | 83.28 | 79.61 | 94.29 | 56.58 | 90.66 | 76.60
Table 4: Ablation study on the number of skip connections (average and per-organ DSC, %):
Skip connections | DSC (avg) | Aorta | Gallbladder | Kidney(L) | Kidney(R) | Liver | Pancreas | Spleen | Stomach
0 | 72.46 | 78.71 | 53.24 | 77.46 | 75.90 | 92.60 | 46.07 | 84.57 | 71.13
1 | 76.43 | 82.53 | 60.44 | 81.36 | 79.27 | 93.64 | 53.36 | 85.95 | 74.90
2 | 78.93 | 85.82 | 66.27 | 84.70 | 80.32 | 93.94 | 55.32 | 88.35 | 76.71
3 | 79.13 | 85.47 | 66.53 | 83.28 | 79.61 | 94.29 | 56.58 | 90.66 | 76.60
Effect of the number of skip connections: The skip connections of our Swin-Unet are added at the 1/4, 1/8, and 1/16 resolution scales. By changing the number of skip connections to 0, 1, 2 and 3 respectively, we explored the influence of different numbers of skip connections on the segmentation performance of the proposed model. From Table 4, we can see that the segmentation performance of the model increases with the number of skip connections. Therefore, in order to make the model more robust, the number of skip connections is set to 3 in this work.
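The paper fuses encoder and decoder features via skip connections; one simple U-Net-style realization, sketched below as an assumption rather than the released implementation, concatenates same-scale tokens along the channel dimension and projects back with a linear layer:

```python
import torch
import torch.nn as nn

class SkipFusion(nn.Module):
    """Fuse up-sampled decoder tokens with same-resolution encoder tokens by
    channel concatenation followed by a linear projection back to C channels
    (the concrete fusion operator is an assumption of this sketch)."""
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(2 * dim, dim)

    def forward(self, decoder_tokens: torch.Tensor, encoder_tokens: torch.Tensor) -> torch.Tensor:
        # Both inputs: (B, L, C) at the same resolution scale (1/4, 1/8 or 1/16).
        return self.proj(torch.cat([decoder_tokens, encoder_tokens], dim=-1))

# Example at the 1/16 scale of a 224x224 input: L = 14*14 tokens, C = 384.
fused = SkipFusion(384)(torch.randn(1, 196, 384), torch.randn(1, 196, 384))
print(fused.shape)  # torch.Size([1, 196, 384])
```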
Table 5: Effect of the input size (average and per-organ DSC, %):
Input size | DSC (avg) | Aorta | Gallbladder | Kidney(L) | Kidney(R) | Liver | Pancreas | Spleen | Stomach
224 | 79.13 | 85.47 | 66.53 | 83.28 | 79.61 | 94.29 | 56.58 | 90.66 | 76.60
384 | 81.12 | 87.07 | 70.53 | 84.64 | 82.87 | 94.72 | 63.73 | 90.14 | 75.29
Table 6: Effect of the model scale (average and per-organ DSC, %):
Model scale | DSC (avg) | Aorta | Gallbladder | Kidney(L) | Kidney(R) | Liver | Pancreas | Spleen | Stomach
tiny | 79.13 | 85.47 | 66.53 | 83.28 | 79.61 | 94.29 | 56.58 | 90.66 | 76.60
base | 79.25 | 87.16 | 69.19 | 84.61 | 81.99 | 93.86 | 58.10 | 88.44 | 70.65
Effect of input size: The testing results of the proposed Swin-Unet with 224 × 224 and 384 × 384 input resolutions are presented in Table 5. As the input size increases from 224 × 224 to 384 × 384 while the patch size remains 4, the input token sequence of the Transformer becomes longer (from (224/4)² = 3136 to (384/4)² = 9216 tokens), thus improving the segmentation performance of the model. However, although the segmentation accuracy of the model is slightly improved, the computational load of the whole network also increases significantly. In order to ensure the running efficiency of the algorithm, the experiments in this paper use the 224 × 224 resolution scale as the input.
Effect of model scale: Similar to [19], we discuss the effect of network deepening on model performance. It can be seen from Table 6 that increasing the model scale hardly improves the performance of the model, but increases the computational cost of the whole network. Considering the accuracy-speed trade-off, we adopt the Tiny-based model to perform medical image segmentation.
4.6 Discussion
5 Conclusion
References
1. A. Hatamizadeh, D. Yang, H. Roth, and D. Xu, “Unetr: Transformers for 3d med-
ical image segmentation,” 2021.
2. J. Chen, Y. Lu, Q. Yu, X. Luo, E. Adeli, Y. Wang, L. Lu, A. L. Yuille, and Y. Zhou,
“Transunet: Transformers make strong encoders for medical image segmentation,”
CoRR, vol. abs/2102.04306, 2021.
3. O. Ronneberger, P. Fischer, and T. Brox, “U-net: Convolutional networks for
biomedical image segmentation,” in Medical Image Computing and Computer-
Assisted Intervention (MICCAI), ser. LNCS, vol. 9351. Springer, 2015, pp. 234–
241.
4. F. Isensee, P. F. Jaeger, S. A. A. Kohl, J. Petersen, and K. H. Maier-Hein, “nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation,” Nature Methods, vol. 18, no. 2, pp. 203–211, 2021.
5. Q. Jin, Z. Meng, C. Sun, H. Cui, and R. Su, “Ra-unet: A hybrid deep attention-
aware network to extract liver and tumor in ct scans,” Frontiers in Bioengineering
and Biotechnology, vol. 8, p. 1471, 2020.
6. Ö. Çiçek, A. Abdulkadir, S. Lienkamp, T. Brox, and O. Ronneberger, “3d u-net:
Learning dense volumetric segmentation from sparse annotation,” in Medical Image
Computing and Computer-Assisted Intervention (MICCAI), ser. LNCS, vol. 9901.
Springer, Oct 2016, pp. 424–432.
7. X. Xiao, S. Lian, Z. Luo, and S. Li, “Weighted res-unet for high-quality retina vessel
segmentation,” 2018 9th International Conference on Information Technology in
Medicine and Education (ITME), pp. 327–331, 2018.
8. Z. Zhou, M. Rahman Siddiquee, N. Tajbakhsh, and J. Liang, “Unet++: A nested
u-net architecture for medical image segmentation.” Springer Verlag, 2018, pp.
3–11.
9. H. Huang, L. Lin, R. Tong, H. Hu, Q. Zhang, Y. Iwamoto, X. Han, Y.-W. Chen,
and J. Wu, “Unet 3+: A full-scale connected unet for medical image segmentation,”
2020.
10. L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, and A. L. Yuille, “Deeplab:
Semantic image segmentation with deep convolutional nets, atrous convolution,
and fully connected crfs,” IEEE Transactions on Pattern Analysis and Machine
Intelligence, vol. 40, no. 4, pp. 834–848, 2018.
11. Z. Gu, J. Cheng, H. Fu, K. Zhou, H. Hao, Y. Zhao, T. Zhang, S. Gao, and
J. Liu, “Ce-net: Context encoder network for 2d medical image segmentation,”
IEEE Transactions on Medical Imaging, vol. 38, no. 10, pp. 2281–2292, 2019.
12. J. Schlemper, O. Oktay, M. Schaap, M. Heinrich, B. Kainz, B. Glocker, and
D. Rueckert, “Attention gated networks: Learning to leverage salient regions in
medical images,” Medical Image Analysis, vol. 53, pp. 197–207, 2019.
13. X. Wang, R. Girshick, A. Gupta, and K. He, “Non-local neural networks,” in 2018
IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2018, pp.
7794–7803.
14. H. Zhao, J. Shi, X. Qi, X. Wang, and J. Jia, “Pyramid scene parsing network,”
in 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR),
2017, pp. 6230–6239.
30. Y. Xie, J. Zhang, C. Shen, and Y. Xia, “Cotr: Efficiently bridging CNN and
transformer for 3d medical image segmentation,” CoRR, vol. abs/2103.03024,
2021. [Online]. Available: https://arxiv.org/abs/2103.03024
31. H. Hu, J. Gu, Z. Zhang, J. Dai, and Y. Wei, “Relation networks for object detec-
tion,” in 2018 IEEE/CVF Conference on Computer Vision and Pattern Recogni-
tion, 2018, pp. 3588–3597.
32. H. Hu, Z. Zhang, Z. Xie, and S. Lin, “Local relation networks for image recogni-
tion,” in 2019 IEEE/CVF International Conference on Computer Vision (ICCV),
2019, pp. 3463–3472.
33. H. Touvron, M. Cord, A. Sablayrolles, G. Synnaeve, and H. Jégou, “Going deeper
with image transformers,” CoRR, vol. abs/2103.17239, 2021. [Online]. Available:
https://arxiv.org/abs/2103.17239
34. S. Fu, Y. Lu, Y. Wang, Y. Zhou, W. Shen, E. Fishman, and A. Yuille, “Domain
adaptive relational reasoning for 3d multi-organ segmentation,” in Medical Image
Computing and Computer Assisted Intervention – MICCAI 2020, 2020, pp. 656–
666.
35. F. Milletari, N. Navab, and S.-A. Ahmadi, “V-net: Fully convolutional neural net-
works for volumetric medical image segmentation,” in 2016 Fourth International
Conference on 3D Vision (3DV), 2016, pp. 565–571.
36. S. Fu, Y. Lu, Y. Wang, Y. Zhou, W. Shen, E. Fishman, and A. Yuille, “Domain adaptive relational reasoning for 3d multi-organ segmentation,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2020, 2020, pp. 656–666.
37. O. Oktay, J. Schlemper, L. L. Folgoc, M. Lee, M. Heinrich, K. Misawa, K. Mori, S. McDonagh, N. Y. Hammerla, B. Kainz, B. Glocker, and D. Rueckert, “Attention u-net: Learning where to look for the pancreas,” in Medical Imaging with Deep Learning (MIDL), 2018.