BLDD VIT ResNet50v2 CustomCNN
BLDD VIT ResNet50v2 CustomCNN
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torchinfo import summary
from torchvision.models import vit_b_16, ViT_B_16_Weights
from sklearn.metrics import accuracy_score, confusion_matrix
import os
from pathlib import Path
from tqdm.auto import tqdm
from collections import OrderedDict
import random
import warnings
warnings.filterwarnings("ignore")
# Every entry directly under IMAGE_PATH is taken as one class name; sorting
# makes the class order independent of the order os.listdir returns.
classes = sorted(os.listdir(IMAGE_PATH))

print('==' * 20)
print(' ' * 10, f'Total Classes = {len(classes)}')
print('==' * 20)

# Report how many .jpg files sit directly inside each class directory.
for c in classes:
    total_images_class = list(Path(os.path.join(IMAGE_PATH, c)).glob("*.jpg"))
    print(f'* {c}: {len(total_images_class)} images')
def __len__(self):
    """Return the number of entries in ``self.df``."""
    return len(self.df)
# Training batches are drawn in a new random order each epoch.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation only measures the model, so a fixed batch order is used: it keeps
# evaluation runs repeatable and avoids a needless reshuffle every epoch.
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# Visualize a batch
batch_images, batch_labels = next(iter(train_dataloader))
print(
    "Batch images shape:", batch_images.shape,
    "and Batch labels shape:", batch_labels.shape,
)
# Load the ViT-B/16 model and replace its last layer so it outputs one logit
# per class in `classes`.
model = vit_b_16(weights=weights)

# Identical torchinfo settings are used before and after the head swap so the
# two summaries can be compared line by line.
summary_kwargs = dict(
    input_size=[8, 3, 224, 224],
    col_width=15,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
)
summary(model=model, **summary_kwargs)

output_shape = len(classes)

# Take the classifier's input width from the loaded model instead of
# hard-coding 768, so the replacement head always matches the backbone.
head_in_features = model.heads.head.in_features
model.heads = nn.Sequential(
    OrderedDict(
        [('head', nn.Linear(in_features=head_in_features, out_features=output_shape))]
    )
)
summary(model=model, **summary_kwargs)
loss_metric_curve_plot(MODEL_RESULTS)

# First argument: the test labels passed through label_map (ground truth);
# second argument: the model's predictions converted to a NumPy array.
confusion_matrix_test = confusion_matrix(
    df_test["label"].map(label_map), y_pred_test.numpy()
)

fig, ax = plt.subplots(figsize=(10, 4.5))
sns.heatmap(
    confusion_matrix_test,
    cmap='Oranges',
    annot=True,
    annot_kws={"fontsize": 9, "fontweight": "bold"},
    linewidths=1.2,
    fmt=' ',
    linecolor="white",
    square=True,
    xticklabels=classes,
    yticklabels=classes,
    cbar=False,
    ax=ax,
)
ax.set_title(
    "Confusion Matrix Test", fontsize=15, fontweight="bold", color="darkblue"
)
ax.tick_params('x', rotation=90)
# plt.show() is what the other plotting code in this file uses.
plt.show()
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 210MB/s]
{"model_id":"e97a19f7213640598339fe3c82118a19","version_major":2,"version_minor":0}
{"model_id":"3fbdaf5bc856475697b142497520467b","version_major":2,"version_minor":0}
Classification Report:
              precision    recall  f1-score   support

    accuracy                           0.95        62
   macro avg       0.93      0.96      0.94        62
weighted avg       0.96      0.95      0.95        62
# 1. ResNet50V2 Implementation
# Import necessary libraries
import torch
from torch import nn
import torchvision.models as models
from collections import OrderedDict
# Second block
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# Third block
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# Fourth block
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
accuracy = accuracy_score(df_test["label"].map(label_map),
y_pred_test.numpy())
print(f'{model_name} Test Accuracy = {round(accuracy, 4)}')
conf_matrix = confusion_matrix(df_test["label"].map(label_map),
y_pred_test.numpy())
fig, ax = plt.subplots(figsize=(10, 4.5))
sns.heatmap(conf_matrix, cmap='Blues', annot=True,
annot_kws={"fontsize": 9, "fontweight": "bold"},
linewidths=1.2, fmt=' ', linecolor="white",
square=True,
xticklabels=classes, yticklabels=classes, cbar=False,
ax=ax)
ax.set_title(f"{model_name} Confusion Matrix", fontsize=15,
fontweight="bold", color="darkblue")
ax.tick_params('x', rotation=90)
plt.tight_layout()
plt.show()
# Evaluate ResNet50V2 on the test loader
resnet_accuracy, resnet_preds = evaluate_model(
    resnet_model, test_dataloader, "ResNet50V2"
)

# Plot comparison: one bar per model, with names, accuracies and colours
# listed in the same order.
model_names = ['ViT-B/16', 'ResNet50V2', 'Custom CNN']
accuracies = [vit_accuracy, resnet_accuracy, cnn_accuracy]
bar_colors = ['orange', 'skyblue', 'lightgreen']

plt.figure(figsize=(10, 6))
bars = plt.bar(model_names, accuracies, color=bar_colors)
def _plot_metric_panel(position, metric_key, title, y_label):
    """Draw one panel of the 2x2 model-comparison grid.

    Args:
        position: 1-based subplot index within the 2x2 grid.
        metric_key: Key looked up in each results dictionary
            (e.g. ``"train_loss"``).
        title: Panel title.
        y_label: Label for the y axis.
    """
    plt.subplot(2, 2, position)
    # Same colour and legend label per model in every panel.
    for results, color, label in (
        (MODEL_RESULTS, 'orange', 'ViT'),
        (RESNET_RESULTS, 'skyblue', 'ResNet50V2'),
        (CNN_RESULTS, 'lightgreen', 'Custom CNN'),
    ):
        plt.plot(results[metric_key], 'o-', color=color, label=label)
    plt.title(title, fontsize=12, fontweight="bold")
    plt.xlabel("Epochs", fontsize=10)
    plt.ylabel(y_label, fontsize=10)
    plt.legend()
    plt.grid(True, alpha=0.3)


# Losses on the top row, accuracies on the bottom row.
_plot_metric_panel(1, "train_loss", "Training Loss", "Loss")
_plot_metric_panel(2, "valid_loss", "Validation Loss", "Loss")
_plot_metric_panel(3, "train_accuracy", "Training Accuracy", "Accuracy")
_plot_metric_panel(4, "valid_accuracy", "Validation Accuracy", "Accuracy")

plt.tight_layout()
plt.show()
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 178MB/s]
==================================================
Training ResNet50V2 model
==================================================
{"model_id":"949cdceee7e340a38b38db2cf219f103","version_major":2,"version_minor":0}
{"model_id":"afb7427359e84896b6b6ead4b9b8de68","version_major":2,"version_minor":0}
{"model_id":"d1e86ab2cc074314b9cf63fce9314e88","version_major":2,"version_minor":0}
# ResNet50V2: headline metrics first, then the per-class report.
resnet_metric_lines = [
    f"Accuracy: {accuracy_resnet:.4f}",
    f"Precision: {precision_resnet:.4f}",
    f"Recall: {recall_resnet:.4f}",
    f"F1 Score: {f1_resnet:.4f}",
    "\nClassification Report:",
]
print("\n".join(resnet_metric_lines))
resnet_report = classification_report(
    true_labels, resnet_pred, target_names=classes
)
print(resnet_report)
# Custom CNN: headline metrics first, then the per-class report.
cnn_metric_lines = [
    f"Accuracy: {accuracy_cnn:.4f}",
    f"Precision: {precision_cnn:.4f}",
    f"Recall: {recall_cnn:.4f}",
    f"F1 Score: {f1_cnn:.4f}",
    "\nClassification Report:",
]
print("\n".join(cnn_metric_lines))
cnn_report = classification_report(
    true_labels, cnn_pred, target_names=classes
)
print(cnn_report)
Classification Report:
              precision    recall  f1-score   support

    accuracy                           0.90        62
   macro avg       0.88      0.91      0.89        62
weighted avg       0.91      0.90      0.90        62
Custom CNN Evaluation Metrics:
Accuracy: 0.6290
Precision: 0.6514
Recall: 0.6237
F1 Score: 0.6073
Classification Report:
precision recall f1-score
support