Splitting a dataset is an important step in training machine learning models. It separates the data into different sets, typically training and validation, so we can train our model on one set and validate its performance on another.
In this article, we are going to discuss the process of splitting a dataset using PyTorch, a popular framework for deep learning.
Introduction to Dataset Splitting
When we are working with a dataset, it is important not to use all of it just for training our model.
We should split it into different parts:
- Training Set: This is the portion of the dataset used to train our model.
- Validation Set: This set is used to evaluate our model's performance and adjust it accordingly.
- Test Set: Sometimes, a third set is used to test the model after training and validation are complete.
Splitting the data helps avoid overfitting, a common problem where the model learns the training data too closely, performing very well on it but poorly when faced with unseen (new) data.
Splitting Datasets in PyTorch: A Step-by-Step Guide with Random Split
PyTorch provides a simple function called random_split to help us split our dataset. This function divides the data into non-overlapping chunks based on the lengths or proportions we specify.
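To get a feel for how it works before building a full example, here is a minimal, self-contained sketch using a toy TensorDataset (the sizes here are arbitrary):

import torch
from torch.utils.data import TensorDataset, random_split

# A toy dataset of 10 random samples, each with 4 features
toy_dataset = TensorDataset(torch.randn(10, 4))

# Split it into 8 training samples and 2 validation samples
train_part, val_part = random_split(toy_dataset, [8, 2])
print(len(train_part), len(val_part))  # prints: 8 2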
Step 1: Import Required Libraries
First, we need to import the necessary libraries for our task. We’ll use PyTorch together with make_blobs from the sklearn library to generate sample data.
from sklearn import datasets
import torch
from torch.utils.data import Dataset, random_split
Step 2: Generate Sample Data
Next, we will create some sample data that we can work with. We will use make_blobs from sklearn to generate a simple dataset.
# Define the number of samples
total_samples = 1800

# Generate 2-D sample data around 2 cluster centers
X_data, Y_data = datasets.make_blobs(n_samples=total_samples, n_features=2, centers=[(-2, 5), (3, -4)], random_state=42)
Here, we are generating a dataset of 1800 samples drawn from 2 cluster centers. Because the centers are given as 2-D coordinates, each sample has 2 features; make_blobs infers the number of features from the dimensionality of the centers. This gives us some synthetic data to work with.
Step 3: Create a Custom Dataset Class
In PyTorch, it’s common to create a custom Dataset class to handle our data. This class will allow us to manage how data is loaded.
class CustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        # Return a dictionary with 'features' and 'label' as keys
        sample = {
            'features': torch.tensor(self.x[index], dtype=torch.float32),
            'label': torch.tensor(self.y[index], dtype=torch.long)
        }
        return sample

    def __len__(self):
        # Return the total number of samples
        return len(self.x)
This class takes the input data (x) and labels (y) and returns each sample as a dictionary when accessed. The __len__ method returns the number of samples in the dataset.
Step 4: Create the Dataset Instance
Now, we will create an instance of our custom dataset and check its length.
# Create the dataset instance
dataset = CustomDataset(X_data, Y_data)
# Print the length of the dataset
print("Total number of samples in the dataset:", len(dataset))
This will print out the total number of samples in our dataset, which should be 1800.
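Since __getitem__ returns a dictionary, we can inspect an individual sample simply by indexing into the dataset:

# Look at the first sample: a dictionary with 'features' and 'label' keys
first_sample = dataset[0]
print(first_sample['features'])   # a float32 tensor with 2 values
print(first_sample['label'])      # a scalar long tensor (0 or 1)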
Step 5: Split the Dataset
Finally, we can split the dataset into training and validation sets using random_split.
# Split the dataset into training (1200 samples) and validation (600 samples)
train_data, val_data = random_split(dataset, [1200, 600])
# Print the lengths of the train and validation sets
print("Number of training samples:", len(train_data))
print("Number of validation samples:", len(val_data))
Here, we’re splitting the dataset so that 1200 samples go to the training set and 600 to the validation set.
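Note that random_split returns torch.utils.data.Subset objects that index into the original dataset rather than copying it, and it shuffles the indices randomly, so the split changes between runs. If you need a reproducible split, you can pass a seeded generator (the seed value 42 here is an arbitrary choice):

# Reproducible split: the same samples land in each subset on every run
generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(dataset, [1200, 600], generator=generator)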
Example Code:
Below is example code for splitting a dataset using PyTorch:
from sklearn import datasets
import torch
from torch.utils.data import Dataset, random_split
# Generate Sample Data
total_samples = 1800
X_data, Y_data = datasets.make_blobs(n_samples=total_samples, n_features=2, centers=[(-2, 5), (3, -4)], random_state=42)
# Create a Custom Dataset Class
class CustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        sample = {
            'features': torch.tensor(self.x[index], dtype=torch.float32),
            'label': torch.tensor(self.y[index], dtype=torch.long)
        }
        return sample

    def __len__(self):
        return len(self.x)
# Create the Dataset Instance
dataset = CustomDataset(X_data, Y_data)
print("Total number of samples in the dataset:", len(dataset))
# Split the Dataset
train_data, val_data = random_split(dataset, [1200, 600])
print("Number of training samples:", len(train_data))
print("Number of validation samples:", len(val_data))
Output:
Total number of samples in the dataset: 1800
Number of training samples: 1200
Number of validation samples: 600
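Since the Subset objects returned by random_split work directly with DataLoader, a typical next step would look something like the sketch below (the batch size of 32 is an arbitrary choice):

from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

# The default collate function batches our dictionary samples by key
batch = next(iter(train_loader))
print(batch['features'].shape)   # torch.Size([32, 2])
print(batch['label'].shape)      # torch.Size([32])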
How to Split CIFAR-10 Dataset for Training and Validation in PyTorch?
The same technique applies to real-world datasets. Now, we’ll walk through how to split the CIFAR-10 dataset into training and validation sets using PyTorch, so that a model can be trained on one subset and evaluated on another, unseen subset.
Step 1: Import Required Libraries
First, we need to import the necessary libraries for data manipulation and model training. PyTorch provides tools for handling datasets and transformations, which we'll use in this example.
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
Step 2: Define Data Transformations
Data transformations are essential to preprocess the CIFAR-10 images before feeding them into the model. We will convert the images to tensors and normalize them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
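With a mean and standard deviation of 0.5 for each channel, this normalization remaps the [0, 1] pixel values produced by ToTensor to the range [-1, 1], since (x - 0.5) / 0.5 = 2x - 1.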
Step 3: Load the CIFAR-10 Dataset
Download and load the CIFAR-10 dataset. The dataset will be transformed according to the transformations defined earlier.
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
Step 4: Define Split Ratios
Specify the proportions for training and validation splits. In this example, we'll use an 80-20 split.
train_ratio = 0.8
validation_ratio = 0.2
Step 5: Calculate Sizes for Each Split
Calculate the number of samples for each split based on the specified ratios.
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size
Step 6: Perform the Split
Use the random_split function to divide the dataset into training and validation sets.
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])
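As a side note, recent versions of PyTorch (1.13 and later) also let random_split accept fractional lengths that sum to 1, so the sizes in Step 5 don’t have to be computed by hand:

# Equivalent split using fractional lengths (PyTorch 1.13+)
train_dataset, validation_dataset = random_split(dataset, [0.8, 0.2])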
Step 7: Create DataLoaders
DataLoaders are used to load the data in batches, which is useful for training and evaluating models.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)
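To confirm the loaders work as expected, we can pull a single batch and check its shape. CIFAR-10 images are 3x32x32, so a batch of 64 should have shape 64x3x32x32:

# Fetch one batch from the training loader and verify its shape
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([64, 3, 32, 32])
print(labels.shape)   # torch.Size([64])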
Step 8: Verify the Splits
Finally, print out the sizes of the training and validation datasets to verify the splits.
print(f'Total dataset size: {dataset_size}')
print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(validation_dataset)}')
Complete Code
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# Step 1: Define the transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Step 2: Load the dataset
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Step 3: Define the split ratios
train_ratio = 0.8
validation_ratio = 0.2
# Step 4: Calculate the sizes for each split
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size
# Step 5: Perform the split
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])
# Step 6: Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)
# Verify the splits
print(f'Total dataset size: {dataset_size}')
print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(validation_dataset)}')
Output:
Total dataset size: 50000
Training dataset size: 40000
Validation dataset size: 10000
Conclusion
Splitting a dataset is a fundamental step in machine learning, and PyTorch’s built-in random_split function makes it easy. By following the steps above, we can ensure that our model is trained and validated effectively, leading to better generalization and performance on new data.