Lab 6
1. Environment setup
2. Understand basic operations in the Transformer.
References:
1. https://uvadlc-notebooks.readthedocs.io
2. https://github.com/phlippe/uvadlc_notebooks
## Standard libraries
import os
import numpy as np
import random
import math
import json
from functools import partial
## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
## Torchvision
import torchvision
from torchvision.datasets import CIFAR100
from torchvision import transforms
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError:
    # Google Colab does not have PyTorch Lightning installed by default; install it if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial6"
# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Device: cuda:0
Two pre-trained models are downloaded below. Make sure you have adjusted your CHECKPOINT_PATH before running this code, if you have not already done so.
import urllib.request
from urllib.error import HTTPError
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/"
# Files to download
pretrained_files = ["ReverseTask.ckpt", "SetAnomalyTask.ckpt"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("Something went wrong. Please try to download the file from the GDrive folder,"
                  " or contact the author with the full output including the following error:\n", e)
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/ReverseTask.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/SetAnomalyTask.ckpt...
What is Attention?
The weights of the average are calculated by a softmax over all score function outputs. Hence, value vectors whose corresponding keys are most similar to the query receive a higher weight. If we try to describe it with pseudo-math, we can write:
$$
\alpha_i = \frac{\exp\left(f_{attn}\left(\text{key}_i, \text{query}\right)\right)}{\sum_j \exp\left(f_{attn}\left(\text{key}_j, \text{query}\right)\right)}, \hspace{5mm} \text{out} = \sum_i \alpha_i \cdot \text{value}_i
$$
For every word, we have one key and one value vector. The query is
compared to all keys with a score function (in this case the dot
product) to determine the weights. The softmax is not visualized for
simplicity. Finally, the value vectors of all words are averaged using
the attention weights.
Most attention mechanisms differ in terms of what queries they use, how the key and value vectors are defined, and what score function is used. The attention applied inside the Transformer architecture is called self-attention. In self-attention, each sequence element provides a key, value, and query. For each element, we perform an attention operation where, based on its query, we check the similarity to the keys of all sequence elements and return a different, averaged value vector for each element. We will now go into a bit more detail by first looking at the specific attention mechanism used in the Transformer: the scaled dot product attention.
$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
The matrix multiplication QK^T performs the dot product for every possible pair of queries and keys, resulting in a matrix of shape T × T. Each row represents the attention logits for a specific element i with respect to all other elements in the sequence. On these we apply a softmax and multiply with the value vectors to obtain a weighted mean (the weights being determined by the attention). Another perspective on this attention mechanism is offered by the computation graph, which is visualized below (figure credit: Vaswani et al., 2017).
The dot product of a query and a key is a sum of d_k products, so if the elements of both vectors have variance σ², the logits end up with variance d_k · σ⁴; dividing by √d_k scales this back down. If we do not scale the variance back to ∼ σ², the softmax over the logits will already saturate to 1 for one random element and 0 for all others. The gradients through the softmax will then be close to zero, so that we cannot learn the parameters appropriately. Note that the extra factor of σ², i.e., having σ⁴ instead of σ², is usually not an issue, since we keep the original variance σ² close to 1 anyway.
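To see this effect concretely, the short standalone check below (a sketch, not part of the lab code; the sizes are arbitrary) compares the variance of the raw dot-product logits with the variance after scaling by 1/√d_k for unit-variance inputs:

import math
import torch

d_k = 256
q = torch.randn(1000, d_k)                    # queries with element variance ~1
k = torch.randn(1000, d_k)                    # keys with element variance ~1
raw_logits = q @ k.T                          # variance grows roughly like d_k
scaled_logits = raw_logits / math.sqrt(d_k)   # variance brought back to ~1
print(raw_logits.var().item(), scaled_logits.var().item())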
The block Mask (opt.) in the diagram above represents the optional
masking of specific entries in the attention matrix. This is for
instance used if we stack multiple sequences with different lengths into
a batch. To still benefit from parallelization in PyTorch, we pad the
sentences to the same length and mask out the padding tokens during the
calculation of the attention values. This is usually done by setting the
respective attention logits to a very low value.
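As a small illustration (a sketch with made-up sequence lengths, not part of the lab code), a padding mask can be built from the lengths and applied by filling the masked logits with a large negative value before the softmax:

import torch
import torch.nn.functional as F

seq_lengths = torch.tensor([3, 5])                  # two sequences padded to length 5
max_len = 5
# mask[b, j] is True for real tokens and False for padding
mask = torch.arange(max_len)[None, :] < seq_lengths[:, None]
attn_logits = torch.randn(2, max_len, max_len)      # [batch, query position, key position]
attn_logits = attn_logits.masked_fill(~mask[:, None, :], -9e15)
attention = F.softmax(attn_logits, dim=-1)          # padded keys receive ~0 weight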
After we have discussed the details of the scaled dot product attention
block, we can write a function below which computes the output features
given the triple of queries, keys, and values:
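The cell below shows one way to write this function, following the formula above; the optional mask argument implements the masking just discussed by filling masked-out logits with a very low value:

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    # Dot product between every query and every key, scaled by sqrt(d_k)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Set masked-out positions to a very low value before the softmax
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention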
Note that our code above supports any additional dimensionality in front
of the sequence length so that we can also use it for batches. However,
for a better understanding, let's generate a few random queries, keys,
and value vectors, and calculate the attention outputs:
seq_len, d_k = 3, 2
pl.seed_everything(42)
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_k)
values, attention = scaled_dot_product(q, k, v)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("Values\n", values)
print("Attention\n", attention)
INFO:lightning_fabric.utilities.seed:Seed set to 42
Q
tensor([[ 0.3367, 0.1288],
[ 0.2345, 0.2303],
[-1.1229, -0.1863]])
K
tensor([[ 2.2082, -0.6380],
[ 0.4617, 0.2674],
[ 0.5349, 0.8094]])
V
tensor([[ 1.1103, -1.6898],
[-0.9890, 0.9580],
[ 1.3221, 0.8172]])
Values
tensor([[ 0.5698, -0.1520],
[ 0.5379, -0.0265],
[ 0.2246, 0.5556]])
Attention
tensor([[0.4028, 0.2886, 0.3086],
[0.3538, 0.3069, 0.3393],
[0.1303, 0.4630, 0.4067]])
Before continuing, make sure you can follow the calculation of the
specific values here, and also check it by hand. It is important to
fully understand how the scaled dot product attention is calculated.
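If you also want to verify the numbers programmatically, the short sketch below recomputes the attention matrix and output values directly from the q, k, and v above; it should reproduce the printed tensors up to rounding:

attn_check = torch.softmax((q @ k.T) / math.sqrt(d_k), dim=-1)
values_check = attn_check @ v
print(torch.allclose(attn_check, attention), torch.allclose(values_check, values))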
Multi-Head Attention
$$
\begin{split}
\text{Multihead}(Q,K,V) & = \text{Concat}(\text{head}_1,...,\text{head}_h)W^{O}\\
\text{where } \text{head}_i & = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
\end{split}
$$
# http://jalammar.github.io/illustrated-transformer/
class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Stack the weight matrices of all heads together for efficiency
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x)
        # Separate Q, K, V and split into heads: [Batch, Head, SeqLen, Dims]
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)
        if return_attention:
            return o, attention
        else:
            return o
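As a quick sanity check (a sketch with arbitrary sizes), the module can be applied to a random batch of sequences; with return_attention=True it also returns the per-head attention maps:

mha = MultiheadAttention(input_dim=64, embed_dim=64, num_heads=4)
x = torch.randn(2, 5, 64)   # [batch, sequence length, input_dim]
out, attn = mha(x, return_attention=True)
print(out.shape)    # torch.Size([2, 5, 64])
print(attn.shape)   # torch.Size([2, 4, 5, 5]) -> [batch, heads, query pos, key pos]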
Transformer Encoder
Next, we will look at how to apply the multi-head attention block inside
the Transformer architecture.
$$
\begin{split}
\text{FFN}(x) & = \max(0, xW_1+b_1)W_2 + b_2\\
x & = \text{LayerNorm}(x + \text{FFN}(x))
\end{split}
$$
This MLP adds extra complexity to the model and allows transformations on each sequence element separately. You can think of it as allowing the model to "post-process" the new information added by the previous Multi-Head Attention, and to prepare it for the next attention block. Usually, the inner dimensionality of the MLP is 2-8× larger than d_model, i.e., the dimensionality of the original input x. The general advantage of a wider layer instead of a narrow, multi-layer MLP is faster, parallelizable execution.
class EncoderBlock(nn.Module):
    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        super().__init__()
        # Attention layer
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)
        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )
        # Layers applied in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention part
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        # MLP part
        linear_out = self.linear_net(x)
        x = x + self.dropout(linear_out)
        x = self.norm2(x)
        return x
Based on this block, we can implement a module for the full Transformer
encoder.
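A minimal sketch of such a module simply stacks num_layers EncoderBlocks and applies them in sequence; the get_attention_maps helper is an optional convenience for visualization and is not required by anything above:

class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        # Return the attention matrices of every layer for a given input
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = layer(x, mask=mask)
        return attention_maps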
Positional encoding
$$
PE_{(pos,i)} = \begin{cases}
\sin\left(\frac{pos}{10000^{i/d_{\text{model}}}}\right) & \text{if } i \text{ mod } 2 = 0\\
\cos\left(\frac{pos}{10000^{(i-1)/d_{\text{model}}}}\right) & \text{otherwise}\\
\end{cases}
$$
The positional encoding is implemented below. The code is taken from the
PyTorch tutorial about Transformers on NLP and adjusted for our
purposes.
class PositionalEncoding(nn.Module):
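    # A minimal sketch of the sinusoidal encoding defined by the formula above.
    # (max_len and the log-space computation of the frequencies are conventional choices.)
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Precompute the encodings for all positions up to max_len
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
        pe = pe.unsqueeze(0)                           # [1, max_len, d_model]
        # register_buffer => saved with the module and moved to the right device, but not trained
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        # x: [batch, seq_len, d_model]; add the encoding for the first seq_len positions
        x = x + self.pe[:, :x.size(1)]
        return x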
[Figure: positional encoding values over position (columns) and hidden dimension (rows)]
You can clearly see the sine and cosine waves with different wavelengths
that encode the position in the hidden dimensions. Specifically, we can
look at the sine/cosine wave for each hidden dimension separately, to
get a better intuition of the pattern. Below we visualize the positional
encoding for the hidden dimensions 1, 2, 3 and 4.
import matplotlib.pyplot as plt
import seaborn as sns

# `pe` is assumed to hold the precomputed encodings as a [d_model, seq_len] array,
# e.g. pe = PositionalEncoding(d_model=48, max_len=96).pe.squeeze().T.cpu().numpy()
sns.set_theme()
fig, ax = plt.subplots(2, 2, figsize=(12, 4))
ax = [a for a_list in ax for a in a_list]
for i in range(len(ax)):
    ax[i].plot(np.arange(1, 17), pe[i, :16], color=f'C{i}', marker="o", markersize=6, markeredgecolor="black")
    ax[i].set_title(f"Encoding in hidden dimension {i+1}")
    ax[i].set_xlabel("Position in sequence", fontsize=10)
    ax[i].set_ylabel("Positional encoding", fontsize=10)
    ax[i].set_xticks(np.arange(1, 17))
    ax[i].tick_params(axis='both', which='major', labelsize=10)
    ax[i].tick_params(axis='both', which='minor', labelsize=8)
    ax[i].set_ylim(-1.2, 1.2)
fig.subplots_adjust(hspace=0.8)
sns.reset_orig()
plt.show()
[Figure: sine/cosine positional encoding curves for hidden dimensions 1-4 over the first 16 positions]
As we can see, the patterns between the hidden dimension 1 and 2 only
differ in the starting angle. The wavelength is 2π, hence the repetition
after position 6. The hidden dimensions 2 and 3 have about twice the
wavelength.
## Standard libraries
import os
import numpy as np
import random
import math
import json
from functools import partial
from PIL import Image
## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
## Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
## Imports for plotting
import matplotlib.pyplot as plt
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError:
    # Google Colab does not have PyTorch Lightning installed by default; install it if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# Import tensorboard
%load_ext tensorboard
# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial15"
# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
INFO:lightning_fabric.utilities.seed:Seed set to 42
Device: cuda:0
import urllib.request
from urllib.error import HTTPError
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/"
# Files to download
pretrained_files = ["tutorial15/ViT.ckpt",
"tutorial15/tensorboards/ViT/events.out.tfevents.ViT",
"tutorial5/tensorboards/ResNet/events.out.tfevents.resnet"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
file_path = os.path.join(CHECKPOINT_PATH, file_name.split("/",1)[1])
if "/" in file_name.split("/",1)[1]:
os.makedirs(file_path.rsplit("/",1)[0], exist_ok=True)
if not os.path.isfile(file_path):
file_url = base_url + file_name
print(f"Downloading {file_url}...")
try:
urllib.request.urlretrieve(file_url, file_path)
except HTTPError as e:
print("Something went wrong. Please try to download the file from
the GDrive folder, or contact the author with the full output including the
following error:\n", e)
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial15/ViT.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial15/tensorboards/ViT/events.out.tfevents.ViT...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/tensorboards/ResNet/events.out.tfevents.resnet...
We load the CIFAR10 dataset below. We use the same setup of the datasets
and data augmentations as for the CNNs. The constants in the
transforms.Normalize correspond to the values that scale and shift the
data to a zero mean and standard deviation of one.
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize([0.49139968, 0.48215841, 0.44653091],
                                                          [0.24703223, 0.24348513, 0.26158784])
                                     ])
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.49139968, 0.48215841, 0.44653091],
                                                           [0.24703223, 0.24348513, 0.26158784])
                                      ])
# Loading the training dataset. We need to split it into a training and validation part.
# We need to do a little trick because the validation set should not use the augmentation.
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
pl.seed_everything(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
pl.seed_everything(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])
# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)
# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
# Visualize some examples
NUM_IMAGES = 4
CIFAR_images = torch.stack([val_set[idx][0] for idx in range(NUM_IMAGES)], dim=0)
img_grid = torchvision.utils.make_grid(CIFAR_images, nrow=4, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)
plt.figure(figsize=(8, 8))
plt.title("Image examples of the CIFAR10 dataset")
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()
[Figure: image examples of the CIFAR10 dataset]
We will walk step by step through the Vision Transformer and implement all parts ourselves. First, let's implement the image preprocessing: an image of size N × N has to be split into (N/M)² patches of size M × M. These represent the input words to the Transformer.
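A small helper along these lines performs that reshaping (a sketch; the name img_to_patch and the flatten_channels flag are the choices used here, and the function below is reused in the Vision Transformer later on):

def img_to_patch(x, patch_size, flatten_channels=True):
    """
    x: tensor of shape [B, C, H, W]
    Returns [B, H*W/patch_size^2, C*patch_size^2] if flatten_channels is True,
    otherwise [B, num_patches, C, patch_size, patch_size].
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)   # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)               # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)           # [B, H'*W', C*p_H*p_W]
    return x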
Let's take a look at how that works for our CIFAR examples above. For
our images of size 32 × 32, we choose a patch size of 4. Hence, we
obtain sequences of 64 patches of size 4 × 4. We visualize them below:
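The visualization can be produced along these lines (a sketch that reuses CIFAR_images from above and the img_to_patch helper; the grid layout parameters are arbitrary):

img_patches = img_to_patch(CIFAR_images, patch_size=4, flatten_channels=False)
fig, ax = plt.subplots(CIFAR_images.shape[0], 1, figsize=(14, 3))
fig.suptitle("Images as input sequences of patches")
for i in range(CIFAR_images.shape[0]):
    img_grid = torchvision.utils.make_grid(img_patches[i], nrow=64, normalize=True, pad_value=0.9)
    ax[i].imshow(img_grid.permute(1, 2, 0))
    ax[i].axis('off')
plt.show()
plt.close()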
[Figure: the CIFAR10 example images represented as sequences of 4 × 4 patches]
After we have looked at the preprocessing, we can now start building the
Transformer model. Since we have discussed the fundamentals of
Multi-Head Attention, we will use the PyTorch module
nn.MultiheadAttention here. Further, we use the Pre-Layer Normalization
version of the Transformer blocks proposed by Ruibin Xiong et al. in
2020. The idea is to apply Layer Normalization not in between residual
blocks, but instead as a first layer in the residual blocks. This
reorganization of the layers supports better gradient flow and removes
the necessity of a warm-up stage. A visualization of the difference
between the standard Post-LN and the Pre-LN version is shown below.
[Figure: pre_layer_norm.svg]
class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Pre-LN: normalize first, then attention / MLP, then residual connection
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x
Now we have all modules ready to build our own Vision Transformer. Besides the Transformer encoder, we need the following modules:
- A linear projection layer that maps the flattened input patches to feature vectors of size embed_dim.
- A learnable classification (CLS) token that is prepended to the patch sequence.
- Learnable positional embeddings that are added to the CLS and patch tokens.
- An MLP head that takes the output feature vector of the CLS token,
  and maps it to a classification prediction. This is usually
  implemented by a small feed-forward network or even a single linear
  layer.
class VisionTransformer(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers,
                 num_classes, patch_size, num_patches, dropout=0.0):
        super().__init__()
        self.patch_size = patch_size
        # Layers/Networks
        self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout)
                                           for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)
        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1+num_patches, embed_dim))

    def forward(self, x):
        # Preprocess input: split the image into flattened patches
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)
        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, :T+1]
        # Apply Transformer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)
        # Perform classification prediction on the CLS token
        cls = x[0]
        out = self.mlp_head(cls)
        return out
class ViT(pl.LightningModule):
    # Minimal LightningModule wrapper around the VisionTransformer above
    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]
    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log(f'{mode}_loss', loss)
        self.log(f'{mode}_acc', acc)
        return loss
    def training_step(self, batch, batch_idx): return self._calculate_loss(batch, mode="train")
    def validation_step(self, batch, batch_idx): self._calculate_loss(batch, mode="val")
    def test_step(self, batch, batch_idx): self._calculate_loss(batch, mode="test")
ViT Inference
import time
def test_inference_time(model, input_tensor, num_runs=100,
                        device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Test the inference runtime of the VisionTransformer model.
    Parameters:
        model: VisionTransformer model instance
        input_tensor: Input tensor, shape (batch_size, num_channels, height, width)
        num_runs: Number of inference runs to compute average time
        device: Device to run on ('cuda' or 'cpu')
    Returns:
        avg_time: Average inference time per run (seconds)
    """
    # Move model and input to the specified device
    model = model.to(device)
    input_tensor = input_tensor.to(device)
    model.eval()
    # Timing sketch: a few warm-up passes, then average wall-clock time over num_runs runs
    with torch.no_grad():
        for _ in range(10):
            _ = model(input_tensor)
        if device == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(num_runs):
            _ = model(input_tensor)
        if device == 'cuda':
            torch.cuda.synchronize()
    avg_time = (time.time() - start) / num_runs
    return avg_time
# Model parameters
model_kwargs = {
'embed_dim': 256,
'hidden_dim': 512,
'num_channels': 3,
'num_heads': 8,
'num_layers': 3,
'num_classes': 10,
'patch_size': 4,
'num_patches': 64,
'dropout': 0.2
}
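With these parameters, the timing function can be exercised as follows (a usage sketch; the batch size of 1 and the random input are arbitrary choices):

model = VisionTransformer(**model_kwargs)
input_tensor = torch.randn(1, 3, 32, 32)   # one CIFAR10-sized image
avg_time = test_inference_time(model, input_tensor, num_runs=100)
print(f"Average inference time per run: {avg_time*1000:.3f} ms")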
# Note: the function name and the sweep loop below are an illustrative sketch around
# the documented parameters; only the docstring and the latencies list were given.
def test_num_layers_latency(model_kwargs, input_tensor, layer_range, num_runs=100,
                            device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Parameters:
        model_kwargs: Dictionary of initialization parameters for VisionTransformer
        input_tensor: Input tensor, shape (batch_size, num_channels, height, width)
        layer_range: Range of layer counts to test (list or range)
        num_runs: Number of inference runs per test
        device: Device to run on ('cuda' or 'cpu')
    Returns:
        latencies: List of average inference times for each layer count
    """
    latencies = []
    for num_layers in layer_range:
        kwargs = dict(model_kwargs, num_layers=num_layers)
        model = VisionTransformer(**kwargs)
        latencies.append(test_inference_time(model, input_tensor, num_runs=num_runs, device=device))
    return latencies
[Figure: average inference latency vs. number of Transformer layers]
# Note: the function name, sweep loop, and plotting call below are an illustrative sketch
# around the documented parameters and the valid-head filtering that was given.
def test_num_heads_latency(model_kwargs, input_tensor, head_range, num_runs=100,
                           device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Parameters:
        model_kwargs: Dictionary of initialization parameters for VisionTransformer
        input_tensor: Input tensor, shape (batch_size, num_channels, height, width)
        head_range: Range of attention head counts to test (list or range)
        num_runs: Number of inference runs per test
        device: Device to run on ('cuda' or 'cpu')
    Returns:
        latencies: List of average inference times for each head count
        valid_heads: List of valid head counts tested
    """
    embed_dim = model_kwargs['embed_dim']
    # Filter valid num_heads to ensure embed_dim is divisible by num_heads
    valid_heads = [h for h in head_range if embed_dim % h == 0]
    if not valid_heads:
        raise ValueError(f"No valid num_heads in {head_range} can divide embed_dim={embed_dim}")
    latencies = []
    for num_heads in valid_heads:
        kwargs = dict(model_kwargs, num_heads=num_heads)
        model = VisionTransformer(**kwargs)
        latencies.append(test_inference_time(model, input_tensor, num_runs=num_runs, device=device))
    return latencies, valid_heads

# Example sweep and plot (the tested head counts are chosen for illustration)
latencies, valid_heads = test_num_heads_latency(model_kwargs, torch.randn(1, 3, 32, 32), head_range=range(1, 17))
plt.plot(valid_heads, latencies, marker="o")
plt.xlabel("Number of attention heads")
plt.ylabel("Average inference time (s)")
plt.show()
[Figure: average inference latency vs. number of attention heads]
ViT Training
def train_model(**kwargs):
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=180,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                                    LearningRateMonitor("epoch")])
    trainer.logger._log_graph = True          # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need
    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = ViT.load_from_checkpoint(pretrained_filename)  # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(42)  # To be reproducible
        model = ViT(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)  # Load best checkpoint after training
    # Evaluate best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}
    return model, result
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically
upgraded your loaded checkpoint from v1.6.4 to v2.5.1. To apply the upgrade to your
files permanently, run `python -m
pytorch_lightning.utilities.upgrade_checkpoint ../saved_models/tutorial15/ViT.ckpt`
/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py:624:
UserWarning: This DataLoader will create 4 worker processes in total. Our suggested
max number of worker in current system is 2, which is smaller than what this
DataLoader is going to create. Please be aware that excessive worker creation might
get DataLoader running slow or even freeze, lower the worker number to avoid
potential slowness/freeze if necessary.
warnings.warn(
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES:
[0]
{"model_id":"bbea62e005eb49bdab2884d048ebd2bb","version_major":2,"version_minor":0}
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES:
[0]
{"model_id":"958bcbc0e61c4e728dff9280fcd1a70f","version_major":2,"version_minor":0}
Change the parameters and retrain the ViT. Below is an example that varies the patch size and number of patches.
# always train
def train_model(**kwargs):
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=180,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                                    LearningRateMonitor("epoch")])
    trainer.logger._log_graph = True          # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need
    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
    # if os.path.isfile(pretrained_filename):
    #     print(f"Found pretrained model at {pretrained_filename}, loading...")
    #     model = ViT.load_from_checkpoint(pretrained_filename)  # Automatically loads the model with the saved hyperparameters
    # else:
    pl.seed_everything(42)  # To be reproducible
    model = ViT(**kwargs)
    trainer.fit(model, train_loader, val_loader)
    model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)  # Load best checkpoint after training
    # Evaluate best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}
    return model, result
embed_dim_list = [128, 256]
hidden_dim = [256, 512]
num_heads_list = [4, 8]
num_layers = [4, 8]
patch_num = [[2, 256], [8, 16]]
for i in range(len(patch_num)):
    model, results = train_model(model_kwargs={
                                     'embed_dim': 256,
                                     'hidden_dim': 512,
                                     'num_heads': 8,
                                     'num_layers': 6,
                                     'patch_size': patch_num[i][0],
                                     'num_channels': 3,
                                     'num_patches': patch_num[i][1],
                                     'num_classes': 10,
                                     'dropout': 0.2
                                 },
                                 lr=3e-4)
    print(f"patch size: {patch_num[i][0]}, patch number: {patch_num[i][1]}, ViT results", results)
  | Name  | Type              | Params | Mode  | In sizes         | Out sizes
-----------------------------------------------------------------------------------
0 | model | VisionTransformer | 3.2 M  | train | [128, 3, 32, 32] | [128, 10]
-----------------------------------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.940    Total estimated model params size (MB)
73        Modules in train mode
0         Modules in eval mode
{"model_id":"a25b7b0eec2641d6ac6166223071454d","version_major":2,"version_minor":0}
{"model_id":"445c6475e74840d88038f795fded2f71","version_major":2,"version_minor":0}
{"model_id":"a1325543bb104a00b3d05f686850edeb","version_major":2,"version_minor":0}
{"model_id":"57caedea2676489a878d535e97fde89f","version_major":2,"version_minor":0}
{"model_id":"04d680079a294fa3a9bf7ee19a259fbd","version_major":2,"version_minor":0}
{"model_id":"651db07634164257abfc158f585559ef","version_major":2,"version_minor":0}
{"model_id":"86c6d9e87d2e4a6d9fe8859e67b96790","version_major":2,"version_minor":0}
{"model_id":"e36d62529b1247ebbc1b6d566ff1028c","version_major":2,"version_minor":0}
{"model_id":"426d15a05c6e4cf693e4977f852ba7ee","version_major":2,"version_minor":0}
{"model_id":"170599fa1cba4f3eb35a4f5dbfd0c1dd","version_major":2,"version_minor":0}
{"model_id":"62edb6c9927c45ec8800ca74d4fd22df","version_major":2,"version_minor":0}
{"model_id":"7efcc7f4fbb84e56bb4115ec1b141974","version_major":2,"version_minor":0}
{"model_id":"d04f33f07f794cf4815add6e0c4e0237","version_major":2,"version_minor":0}
{"model_id":"01187947a6db4acbb3b27a25ff98541a","version_major":2,"version_minor":0}
{"model_id":"bdfef0bbd020467184fa33b8b5eb3e9d","version_major":2,"version_minor":0}