
How to Split a Dataset Using PyTorch

Last Updated : 07 Aug, 2024

Splitting a dataset is an important step in training machine learning models. It separates the data into different sets, typically training and validation sets, so we can train our model on one set and validate its performance on another.

In this article, we are going to discuss the process of splitting a dataset using PyTorch, a popular framework for deep learning.

Introduction to Dataset Splitting

When we are working with a dataset, it is important not to use all of it just for training our model.

We should split it into different parts:

  • Training Set: This is the portion of the dataset used to train our model.
  • Validation Set: This set is used to evaluate our model's performance and adjust it accordingly.
  • Test Set: Sometimes, a third set is used to test the model after training and validation are complete.

Splitting the data helps to avoid a common problem known as overfitting, where the model learns the training data too closely, performs very well on it, but does poorly when faced with unseen (new) data.

Splitting Datasets in PyTorch: A Step-by-Step Guide with Random Split

PyTorch provides a simple function called random_split to help us split our dataset. This function divides a dataset into non-overlapping subsets of the sizes we specify.
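
As a minimal illustration (using a plain Python range in place of a real dataset, as the PyTorch documentation itself does), random_split can divide ten items into non-overlapping subsets:

from torch.utils.data import random_split

# Split ten items into non-overlapping subsets of sizes 6, 2 and 2
train_part, val_part, test_part = random_split(range(10), [6, 2, 2])
print(list(train_part), list(val_part), list(test_part))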

Step 1: Import Required Libraries

First, we need to import the necessary libraries for our task. We'll use PyTorch together with scikit-learn's make_blobs utility to generate sample data.

from sklearn import datasets
import torch
from torch.utils.data import Dataset, random_split

Step 2: Generate Sample Data

Next, we will create some sample data that we can work with. We will use make_blobs from sklearn to generate a simple dataset.

# Define the number of samples
total_samples = 1800

# Generate sample data with 2 features, drawn from 2 cluster centers
# (make_blobs infers the number of features from the center coordinates)
X_data, Y_data = datasets.make_blobs(
    n_samples=total_samples,
    centers=[(-2, 5), (3, -4)],
    random_state=42
)

Here, we are generating a dataset with 1800 samples, each having 2 features, drawn from 2 cluster centers. Note that when centers is passed as an array, make_blobs infers the number of features from the centers' coordinates, so an explicit n_features argument would be ignored. This gives us some synthetic data to work with.
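
As a quick sanity check, we can print the shapes of the generated arrays:

# 1800 samples, each with 2 features
print(X_data.shape)  # (1800, 2)
print(Y_data.shape)  # (1800,)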

Step 3: Create a Custom Dataset Class

In PyTorch, it’s common to create a custom Dataset class to handle our data. This class will allow us to manage how data is loaded.

class CustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        # Return a dictionary with 'features' and 'label' as keys
        sample = {
            'features': torch.tensor(self.x[index], dtype=torch.float32),
            'label': torch.tensor(self.y[index], dtype=torch.long)
        }
        return sample

    def __len__(self):
        # Return the total number of samples
        return len(self.x)

This class stores the input data (x) and labels (y) and returns each sample as a dictionary of tensors when indexed. The __len__ method returns the total number of samples in the dataset.

Step 4: Create the Dataset Instance

Now, we will create an instance of our custom dataset and check its length.

# Create the dataset instance
dataset = CustomDataset(X_data, Y_data)

# Print the length of the dataset
print("Total number of samples in the dataset:", len(dataset))

This will print out the total number of samples in our dataset, which should be 1800.
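
We can also index the dataset to inspect a single sample, for example:

# Inspect the first sample: a dictionary with 'features' and 'label'
first_sample = dataset[0]
print(first_sample['features'])  # a float32 tensor with 2 values
print(first_sample['label'])     # tensor(0) or tensor(1)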

Step 5: Split the Dataset

Finally, we can split the dataset into training and validation sets using random_split.

# Split the dataset into training (1200 samples) and validation (600 samples)
train_data, val_data = random_split(dataset, [1200, 600])

# Print the lengths of the train and validation sets
print("Number of training samples:", len(train_data))
print("Number of validation samples:", len(val_data))

Here, we’re splitting the dataset so that 1200 samples go to the training set and 600 to the validation set.
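
If reproducibility matters, random_split also accepts a generator argument, and recent PyTorch versions additionally accept fractional lengths instead of absolute counts. A minimal sketch:

# Reproducible split using a seeded generator; fractions such as
# [0.8, 0.2] are also accepted in recent PyTorch versions
generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(dataset, [1200, 600], generator=generator)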

Example Code:

Below is example code for splitting a dataset using PyTorch:

Python
from sklearn import datasets
import torch
from torch.utils.data import Dataset, random_split

# Generate Sample Data (2 features, inferred from the 2-D centers)
total_samples = 1800
X_data, Y_data = datasets.make_blobs(
    n_samples=total_samples,
    centers=[(-2, 5), (3, -4)],
    random_state=42
)

# Create a Custom Dataset Class
class CustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        sample = {
            'features': torch.tensor(self.x[index], dtype=torch.float32),
            'label': torch.tensor(self.y[index], dtype=torch.long)
        }
        return sample

    def __len__(self):
        return len(self.x)

# Create the Dataset Instance
dataset = CustomDataset(X_data, Y_data)
print("Total number of samples in the dataset:", len(dataset))

# Split the Dataset
train_data, val_data = random_split(dataset, [1200, 600])

print("Number of training samples:", len(train_data))
print("Number of validation samples:", len(val_data))

Output:

Total number of samples in the dataset: 1800
Number of training samples: 1200
Number of validation samples: 600
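
The split subsets can be passed straight to a DataLoader for batching. As a short follow-up sketch (the DataLoader import is an addition to the code above), PyTorch's default collate function stacks the dictionary entries across the batch:

from torch.utils.data import DataLoader

# Batch the training subset; each batch is a dict of stacked tensors
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
batch = next(iter(train_loader))
print(batch['features'].shape)  # torch.Size([64, 2])
print(batch['label'].shape)     # torch.Size([64])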

How to Split CIFAR-10 Dataset for Training and Validation in PyTorch?

Splitting a dataset into training and validation sets is a crucial step in machine learning to ensure that a model is trained on one subset of data and evaluated on another, unseen subset. Now, we'll walk through how to split the CIFAR-10 training set into training and validation subsets using PyTorch.

Step 1: Import Required Libraries

First, we need to import the necessary libraries for data manipulation and model training. PyTorch provides tools for handling datasets and transformations, which we'll use in this example.

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

Step 2: Define Data Transformations

Data transformations are essential to preprocess the CIFAR-10 images before feeding them into the model. We will convert the images to tensors and normalize them.

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
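
With a mean and standard deviation of 0.5 per channel, this normalization maps the pixel values from the [0, 1] range produced by ToTensor into the [-1, 1] range.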

Step 3: Load the CIFAR-10 Dataset

Download and load the CIFAR-10 training split (train=True gives the 50,000 training images). The dataset will be transformed according to the transformations defined earlier.

dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

Step 4: Define Split Ratios

Specify the proportions for training and validation splits. In this example, we'll use an 80-20 split.

train_ratio = 0.8
validation_ratio = 0.2

Step 5: Calculate Sizes for Each Split

Calculate the number of samples for each split based on the specified ratios.

dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size
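
For CIFAR-10's 50,000 training images, this gives train_size = int(0.8 * 50000) = 40000 and validation_size = 50000 - 40000 = 10000. Computing the validation size by subtraction guarantees the two splits sum to the full dataset even when the ratio does not divide it evenly.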

Step 6: Perform the Split

Use the random_split function to divide the dataset into training and validation sets.

train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])

Step 7: Create DataLoaders

DataLoaders are used to load the data in batches, which is useful for training and evaluating models.

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)
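
As a quick optional check, we can pull one batch from the training loader and inspect its shape:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 3, 32, 32])
print(labels.shape)  # torch.Size([64])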

Step 8: Verify the Splits

Finally, print out the sizes of the training and validation datasets to verify the splits.

print(f'Total dataset size: {dataset_size}')
print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(validation_dataset)}')

Complete Code

Python
# Step 1: Import required libraries
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Step 2: Define the transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Step 3: Load the dataset
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Step 4: Define the split ratios
train_ratio = 0.8
validation_ratio = 0.2

# Step 5: Calculate the sizes for each split
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size

# Step 6: Perform the split
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])

# Step 7: Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)

# Step 8: Verify the splits
print(f'Total dataset size: {dataset_size}')
print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(validation_dataset)}')

Output:

Total dataset size: 50000
Training dataset size: 40000
Validation dataset size: 10000

Conclusion

Splitting a dataset is a fundamental step in machine learning, and PyTorch's built-in random_split function makes it straightforward. By following the above steps, we can ensure that our model is trained and validated effectively, leading to better generalization and performance on new data.

