.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/basics/quickstart_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_basics_quickstart_tutorial.py:


`Learn the Basics `_ ||
**Quickstart** ||
`Tensors `_ ||
`Datasets & DataLoaders `_ ||
`Transforms `_ ||
`Build Model `_ ||
`Autograd `_ ||
`Optimization `_ ||
`Save & Load Model `_

Quickstart
===================

This section runs through the API for common tasks in machine learning. Refer to the links in each section to dive deeper.

Working with data
-----------------
PyTorch has two `primitives to work with data `_:
``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``.
``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around
the ``Dataset``.

.. GENERATED FROM PYTHON SOURCE LINES 24-31

.. code-block:: default

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor

.. GENERATED FROM PYTHON SOURCE LINES 32-40

PyTorch offers domain-specific libraries such as `TorchText `_,
`TorchVision `_, and `TorchAudio `_,
all of which include datasets. For this tutorial, we will be using a TorchVision dataset.

The ``torchvision.datasets`` module contains ``Dataset`` objects for many real-world vision datasets, such as
CIFAR and COCO (`full list here `_). In this tutorial, we
use the FashionMNIST dataset. Every TorchVision ``Dataset`` accepts two arguments, ``transform`` and
``target_transform``, to modify the samples and labels, respectively.

.. GENERATED FROM PYTHON SOURCE LINES 40-57

.. code-block:: default

    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0%| | 0.00/26.4M [00:00

.. GENERATED FROM PYTHON SOURCE LINES 78-80

--------------

.. GENERATED FROM PYTHON SOURCE LINES 82-89

Creating Models
------------------
To define a neural network in PyTorch, we create a class that inherits
from `nn.Module `_. We define the layers of the network
in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate
operations in the neural network, we move it to an `accelerator `__
such as CUDA, MPS, MTIA, or XPU. If an accelerator is available, we use it; otherwise, we use the CPU.

.. GENERATED FROM PYTHON SOURCE LINES 89-114

.. code-block:: default

    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
    print(f"Using {device} device")

    # Define model
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10)
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits

    model = NeuralNetwork().to(device)
    print(model)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Using cuda device
    NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    )

.. GENERATED FROM PYTHON SOURCE LINES 115-117

Read more about `building neural networks in PyTorch `_.
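
The ``forward`` method maps a batch of images to one raw score (logit) per class. As a quick sanity check, one can pass a dummy batch through the network and inspect what comes out. The sketch below repeats the ``NeuralNetwork`` definition so that it runs on its own, keeps everything on the CPU, and assumes a batch of 64 random 28x28 tensors in place of real FashionMNIST images; ``sketch_model``, ``X_dummy``, and the other variable names are illustrative only.

.. code-block:: default

    import torch
    from torch import nn


    # Same layer stack as NeuralNetwork above, repeated so this sketch is self-contained.
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10)
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits


    sketch_model = NeuralNetwork()  # untrained, kept on the CPU

    # Hypothetical input: 64 single-channel 28x28 "images" of random values,
    # shaped like one FashionMNIST batch of [N, C, H, W].
    X_dummy = torch.rand(64, 1, 28, 28)

    with torch.no_grad():
        logits = sketch_model(X_dummy)       # shape [64, 10]: one score per class
        probs = nn.Softmax(dim=1)(logits)    # each row now sums to 1
        y_pred = probs.argmax(dim=1)         # most likely class index per image

    # Weights plus biases of the three Linear layers:
    # (784*512 + 512) + (512*512 + 512) + (512*10 + 10) = 669,706
    num_params = sum(p.numel() for p in sketch_model.parameters())
    print(logits.shape, y_pred.shape, num_params)
    # prints: torch.Size([64, 10]) torch.Size([64]) 669706

Because the network is untrained at this point, the predicted indices are arbitrary; only the shapes and the parameter count are meaningful.
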

.. GENERATED FROM PYTHON SOURCE LINES 120-122

--------------

.. GENERATED FROM PYTHON SOURCE LINES 125-129

Optimizing the Model Parameters
----------------------------------------
To train a model, we need a `loss function `_
and an `optimizer `_.

.. GENERATED FROM PYTHON SOURCE LINES 129-134

.. code-block:: default

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

.. GENERATED FROM PYTHON SOURCE LINES 135-137

In a single training loop, the model makes predictions on the training dataset (fed to it in batches) and
backpropagates the prediction error to adjust the model's parameters.

.. GENERATED FROM PYTHON SOURCE LINES 137-157

.. code-block:: default

    def train(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if batch % 100 == 0:
                loss, current = loss.item(), (batch + 1) * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

.. GENERATED FROM PYTHON SOURCE LINES 158-159

We also check the model's performance against the test dataset to ensure it is learning.

.. GENERATED FROM PYTHON SOURCE LINES 159-175

.. code-block:: default

    def test(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

.. GENERATED FROM PYTHON SOURCE LINES 176-179

The training process is conducted over several iterations (*epochs*). During each epoch, the model learns
parameters to make better predictions.
We print the model's accuracy and loss at each epoch; we'd like to see the
accuracy increase and the loss decrease with every epoch.

.. GENERATED FROM PYTHON SOURCE LINES 179-187

.. code-block:: default

    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print("Done!")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Epoch 1
    -------------------------------
    loss: 2.303494 [   64/60000]
    loss: 2.294637 [ 6464/60000]
    loss: 2.277102 [12864/60000]
    loss: 2.269977 [19264/60000]
    loss: 2.254234 [25664/60000]
    loss: 2.237145 [32064/60000]
    loss: 2.231056 [38464/60000]
    loss: 2.205036 [44864/60000]
    loss: 2.203239 [51264/60000]
    loss: 2.170890 [57664/60000]
    Test Error:
     Accuracy: 53.9%, Avg loss: 2.168587

    Epoch 2
    -------------------------------
    loss: 2.177784 [   64/60000]
    loss: 2.168083 [ 6464/60000]
    loss: 2.114908 [12864/60000]
    loss: 2.130411 [19264/60000]
    loss: 2.087470 [25664/60000]
    loss: 2.039667 [32064/60000]
    loss: 2.054271 [38464/60000]
    loss: 1.985452 [44864/60000]
    loss: 1.996019 [51264/60000]
    loss: 1.917239 [57664/60000]
    Test Error:
     Accuracy: 60.2%, Avg loss: 1.920371

    Epoch 3
    -------------------------------
    loss: 1.951699 [   64/60000]
    loss: 1.919513 [ 6464/60000]
    loss: 1.808724 [12864/60000]
    loss: 1.846544 [19264/60000]
    loss: 1.740612 [25664/60000]
    loss: 1.698728 [32064/60000]
    loss: 1.708887 [38464/60000]
    loss: 1.614431 [44864/60000]
    loss: 1.646473 [51264/60000]
    loss: 1.524302 [57664/60000]
    Test Error:
     Accuracy: 61.4%, Avg loss: 1.547089

    Epoch 4
    -------------------------------
    loss: 1.612693 [   64/60000]
    loss: 1.570868 [ 6464/60000]
    loss: 1.424729 [12864/60000]
    loss: 1.489538 [19264/60000]
    loss: 1.367247 [25664/60000]
    loss: 1.373463 [32064/60000]
    loss: 1.376742 [38464/60000]
    loss: 1.304958 [44864/60000]
    loss: 1.347153 [51264/60000]
    loss: 1.230657 [57664/60000]
    Test Error:
     Accuracy: 62.7%, Avg loss: 1.260888

    Epoch 5
    -------------------------------
    loss: 1.337799 [   64/60000]
    loss: 1.313273 [ 6464/60000]
    loss: 1.151835 [12864/60000]
    loss: 1.252141 [19264/60000]
    loss: 1.123040 [25664/60000]
    loss: 1.159529 [32064/60000]
    loss: 1.175010 [38464/60000]
    loss: 1.115551 [44864/60000]
    loss: 1.160972 [51264/60000]
    loss: 1.062725 [57664/60000]
    Test Error:
     Accuracy: 64.6%, Avg loss: 1.087372

    Done!

.. GENERATED FROM PYTHON SOURCE LINES 188-190

Read more about `Training your model `_.

.. GENERATED FROM PYTHON SOURCE LINES 192-194

--------------

.. GENERATED FROM PYTHON SOURCE LINES 196-199

Saving Models
-------------
A common way to save a model is to serialize the internal state dictionary (containing the model parameters).

.. GENERATED FROM PYTHON SOURCE LINES 199-205

.. code-block:: default

    torch.save(model.state_dict(), "model.pth")
    print("Saved PyTorch Model State to model.pth")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Saved PyTorch Model State to model.pth

.. GENERATED FROM PYTHON SOURCE LINES 206-211

Loading Models
----------------------------

The process for loading a model includes re-creating the model structure and loading
the state dictionary into it.

.. GENERATED FROM PYTHON SOURCE LINES 211-215

.. code-block:: default

    model = NeuralNetwork().to(device)
    model.load_state_dict(torch.load("model.pth", weights_only=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    <All keys matched successfully>

.. GENERATED FROM PYTHON SOURCE LINES 216-217

This model can now be used to make predictions.

.. GENERATED FROM PYTHON SOURCE LINES 217-240

.. code-block:: default

    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]

    model.eval()
    x, y = test_data[0][0], test_data[0][1]
    with torch.no_grad():
        x = x.to(device)
        pred = model(x)
        predicted, actual = classes[pred[0].argmax(0)], classes[y]
        print(f'Predicted: "{predicted}", Actual: "{actual}"')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Predicted: "Ankle boot", Actual: "Ankle boot"
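
The example above classifies a single test image. The same steps extend to a whole batch: run the model under ``torch.no_grad()``, apply a softmax and take the maximum along dimension 1, and map each index to a name in ``classes``. So that it runs on its own, the sketch below makes two assumptions: a hypothetical untrained ``demo_model`` (one ``nn.Linear`` layer with the same 784 inputs and 10 outputs) stands in for the trained ``model``, and a ``TensorDataset`` of random images and labels stands in for ``test_data``.

.. code-block:: default

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    classes = [
        "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
    ]

    # Hypothetical stand-in for the trained model: same input/output sizes, untrained.
    demo_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    # Hypothetical stand-in for test_data: 10 random 28x28 images with random labels.
    images = torch.rand(10, 1, 28, 28)
    labels = torch.randint(0, 10, (10,))
    # shuffle defaults to False, so batches (of 4, 4 and 2) keep the order of `labels`
    demo_loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    demo_model.eval()
    predicted_names, confidences = [], []
    with torch.no_grad():
        for X, _ in demo_loader:
            probs = torch.softmax(demo_model(X), dim=1)  # [batch, 10], rows sum to 1
            top_prob, top_idx = probs.max(dim=1)          # best class and its probability
            predicted_names += [classes[i] for i in top_idx.tolist()]
            confidences += top_prob.tolist()

    for name, actual, p in zip(predicted_names, labels.tolist(), confidences):
        print(f'Predicted: "{name}" ({p:.2f}), Actual: "{classes[actual]}"')

With the trained ``model`` and a ``DataLoader`` over ``test_data``, the loop body stays the same apart from moving each batch to ``device`` with ``X.to(device)``.
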

.. GENERATED FROM PYTHON SOURCE LINES 241-243

Read more about `Saving & Loading your model `_.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 35.626 seconds)


.. _sphx_glr_download_beginner_basics_quickstart_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: quickstart_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: quickstart_tutorial.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_