class CNN(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=3, out_channels=32,
kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
torch.nn.Conv2d(in_channels=32, out_channels=64,
kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
torch.nn.Conv2d(in_channels=64, out_channels=64,
kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
torch.nn.Flatten(),
torch.nn.Linear(64 * 4 * 4, 512),
torch.nn.ReLU(),
torch.nn.Linear(512, 10)
)
def forward(self, x):
return self.model(x)