How to define a simple artificial neural network in PyTorch?



To define a simple artificial neural network (ANN) in PyTorch, we can use the following steps −

Steps

  • First, we import the required libraries. In all the following examples, the only required Python library is torch, so make sure it is installed before you begin.

import torch
import torch.nn as nn
  • Our next step is to build a simple ANN model. Here, we use the nn package to implement the model. For this, we define a class MyNetwork that subclasses nn.Module, that is, nn.Module is passed as the base class.

class MyNetwork(nn.Module):
  • We need to define two methods inside the class to get our model ready: __init__() and forward(). Within __init__(), we call super().__init__() to initialize the parent nn.Module and then define the different layers. In forward(), we specify how an input tensor passes through those layers.

  • We need to instantiate the class before we can train the model on a dataset. Instantiating the class runs __init__(), which creates the layers; forward() is executed later, each time the model is called on an input, as in model(x).

model = MyNetwork()
  • Print the model to see the different layers.

print(model)
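Putting the steps above together, a minimal sketch of MyNetwork might look like the following. The layer sizes (10 inputs, 20 hidden units, 2 outputs) and the random input batch are assumptions chosen only for illustration. The comments mark when __init__() and forward() actually run.

```python
import torch
import torch.nn as nn


class MyNetwork(nn.Module):
    def __init__(self):
        # __init__() runs once, when the class is instantiated
        super().__init__()
        # hypothetical layer sizes: 10 input features -> 20 hidden -> 2 outputs
        self.fc1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        # forward() runs each time the model is called on an input tensor
        x = self.relu(self.fc1(x))
        return self.fc2(x)


model = MyNetwork()      # calls __init__(), not forward()
print(model)             # lists fc1, relu and fc2

x = torch.randn(3, 10)   # a batch of 3 samples with 10 features each
out = model(x)           # nn.Module.__call__ dispatches to forward()
print(out.shape)         # torch.Size([3, 2])
```

Calling model(x) rather than model.forward(x) is the usual idiom, because nn.Module.__call__ also runs any registered hooks before and after forward().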

Example 1

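In the following example, we create a simple artificial neural network with nn.Sequential. The model has four layers (two Linear layers, each followed by an activation function) and needs no forward() method, because nn.Sequential passes the input through its modules in order.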

# Import the required libraries
import torch
from torch import nn

# define a simple sequential model
model = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Sigmoid()
)

# print the model
print(model)

Output

Sequential(
   (0): Linear(in_features=32, out_features=128, bias=True)
   (1): ReLU()
   (2): Linear(in_features=128, out_features=10, bias=True)
   (3): Sigmoid()
)
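Example 1 only prints the model's structure. As a rough sketch of actually using it, the snippet below passes a random batch through the same nn.Sequential model and counts its trainable parameters; the batch size of 4 is an arbitrary assumption.

```python
import torch
from torch import nn

# same model as Example 1
model = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Sigmoid(),
)

# a hypothetical batch of 4 samples, each with 32 features
x = torch.randn(4, 32)
y = model(x)
print(y.shape)  # torch.Size([4, 10])

# the final Sigmoid maps every output value into the range 0 to 1
print(bool(((y >= 0) & (y <= 1)).all()))  # True

# trainable parameters: (32*128 + 128) + (128*10 + 10) = 4224 + 1290 = 5514
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 5514
```

The ReLU and Sigmoid modules have no parameters, so only the two Linear layers contribute to the count.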

Example 2

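The following Python program shows a different way to build a simple neural network: subclassing nn.Module, defining the layers in __init__(), and defining the computation in forward().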

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 16)
        self.fc3 = nn.Linear(16, 4)
        self.fc4 = nn.Linear(4, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return torch.sigmoid(self.fc4(x))

model = MyNet()
print(model)

Output

MyNet(
   (fc1): Linear(in_features=4, out_features=8, bias=True)
   (fc2): Linear(in_features=8, out_features=16, bias=True)
   (fc3): Linear(in_features=16, out_features=4, bias=True)
   (fc4): Linear(in_features=4, out_features=1, bias=True)
)
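Similarly, the MyNet model from Example 2 can be called on any tensor whose last dimension is 4. The sketch below also runs a single training step; the synthetic data, the choice of nn.BCELoss (which expects probabilities, matching the sigmoid output), the SGD optimizer, the learning rate of 0.1, and the random seed are all assumptions for illustration, not part of the original example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyNet(nn.Module):
    # same architecture as Example 2
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 16)
        self.fc3 = nn.Linear(16, 4)
        self.fc4 = nn.Linear(4, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return torch.sigmoid(self.fc4(x))


torch.manual_seed(0)
model = MyNet()

# synthetic data (assumed): 16 samples with 4 features each, and binary
# labels of shape (16, 1) to match the single sigmoid output
x = torch.randn(16, 4)
y = torch.randint(0, 2, (16, 1)).float()

loss_fn = nn.BCELoss()  # assumed loss: binary cross-entropy on probabilities
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed optimizer

# one training step
optimizer.zero_grad()    # clear gradients from any previous step
pred = model(x)          # forward pass, shape (16, 1)
loss = loss_fn(pred, y)
loss.backward()          # compute gradients
optimizer.step()         # update the weights

print(pred.shape)        # torch.Size([16, 1])
# parameters: (4*8 + 8) + (8*16 + 16) + (16*4 + 4) + (4*1 + 1) = 257
print(sum(p.numel() for p in model.parameters()))  # 257
```

In practice these four lines (zero_grad, forward, backward, step) would be repeated over many batches of a real dataset.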
Updated on: 2022-01-25T08:39:11+05:30
