Assignment AI Platforms Mostafa Hazem
Assignment AI Platforms Mostafa Hazem
import torch.nn as nn
import seaborn as sns
import matplotlib.pyplot as plt
class SimpleNN(nn.Module):
    """Small fully connected network: 3 inputs -> 3 hidden units -> 1 output.

    The hidden layer is followed by a sigmoid and the output layer by a
    tanh, so every prediction lies in the range [-1, 1].
    """

    def __init__(self):
        """Create the two linear layers and their activation modules."""
        super().__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 1)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input tensor whose last dimension is 3 (the ``in_features``
                of ``fc1``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 1.
        """
        # These lines previously sat inside __init__, where `x` does not
        # exist; they belong to the module's forward computation.
        x = self.sigmoid(self.fc1(x))
        x = self.tanh(self.fc2(x))
        return x
# Instantiate the network with freshly initialised weights.
model = SimpleNN()
# Plot the pre-training predictions and title the axes seaborn drew on.
# NOTE(review): `initial_output` is not defined in the visible lines;
# presumably it holds the model's output before training - confirm.
before_ax = sns.lineplot(data=initial_output)
before_ax.set_title('Network Predictions Before Training')
plt.show()
Making the network more trainable.
# Explicitly mark every parameter of the model as trainable.
# nn.Parameter already defaults to requires_grad=True, so for a freshly
# built model this is a defensive reset rather than a change.
for param in model.parameters():
    param.requires_grad = True
Training loop.
# Training loop: repeat the per-epoch steps a fixed number of times.
# `optimizer` and `x` are expected to be defined earlier in the script.
epochs = 100
for epoch in range(epochs):
    # Zero the gradients so values from the previous epoch do not accumulate.
    optimizer.zero_grad()
    # Forward pass
    output = model(x)
    # NOTE(review): no loss computation, backward() call or optimizer.step()
    # is visible in this excerpt, so as shown the weights never change.
    # Confirm against the original notebook and restore those steps.
Predictions before and after Training.
# Detach the predictions from the autograd graph and convert them to NumPy
# so they can be plotted.
final_output = model(x).detach().numpy()
# Draw the "before" and "after" curves on axs[0] and axs[1] respectively.
# NOTE(review): `axs` is not created in the visible lines; presumably it
# comes from a plt.subplots(...) call with at least two axes - confirm.
panels = (
    (initial_output, 'Predictions Before Training'),
    (final_output, 'Predictions After Training'),
)
for idx, (series, heading) in enumerate(panels):
    sns.lineplot(data=series, ax=axs[idx])
    axs[idx].set_title(heading)
plt.show()