Dataset and DataLoader Class
Problems with loading the entire dataset at once
1. Memory inefficiency: the whole dataset must sit in memory even though only one batch is needed at a time.
2. Convergence: updating the model on small, shuffled mini-batches generally converges better than one update per pass over the full dataset.
Dataset and DataLoader are core abstractions in PyTorch that decouple how you
define your data from how you efficiently iterate over it in training loops.
Dataset Class
It defines:
• how many samples the dataset contains (__len__)
• how a single sample (and its label) is retrieved by index (__getitem__)
DataLoader Class
It handles batching and iteration:
• the sampler generates indices and groups them into chunks (batches)
• for each index in the chunk, data samples are fetched from the Dataset object
• the fetched samples are combined into a single batch by collate_fn
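A minimal sketch of how the two classes fit together, using a toy in-memory dataset (the tensor shapes and names here are illustrative, not from the original notes):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """A toy Dataset holding features and labels in memory."""
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        # how many samples the dataset contains
        return len(self.features)

    def __getitem__(self, idx):
        # how a single (sample, label) pair is retrieved by index
        return self.features[idx], self.labels[idx]

# 10,000 samples with 20 features each, matching the walkthrough below
features = torch.randn(10_000, 20)
labels = torch.randint(0, 2, (10_000,))

dataset = ToyDataset(features, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch_data, batch_labels in dataloader:
    print(batch_data.shape, batch_labels.shape)  # torch.Size([32, 20]) torch.Size([32])
    break
```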
Imagine the entire data loading and training process for one epoch with num_workers=4:
Assumptions:
• Total samples: 10,000
• Batch size: 32
• Workers (num_workers): 4
• 10,000 / 32 = 312.5, so 312 full batches of 32 plus one final partial batch of 16 per epoch (exactly 312 batches if drop_last=True).
Workflow:
1. Sampler and Batch Creation (Main Process):
Before training starts for the epoch, the DataLoader’s sampler generates a shuffled list of all 10,000 indices. These
are then grouped into 312 batches of 32 indices each. All these batches are queued up, ready to be fetched by
workers.
2. Parallel Data Loading (Workers):
○ At the start of the training epoch, you run a training loop like:
```python
for batch_data, batch_labels in dataloader:
    # training logic
```
○ Under the hood, as soon as you start iterating over dataloader, it dispatches the first four batches of indices
to the four workers:
▪ Worker #1 loads batch 1 (indices [batch_1_indices])
▪ Worker #2 loads batch 2 (indices [batch_2_indices])
▪ Worker #3 loads batch 3 (indices [batch_3_indices])
▪ Worker #4 loads batch 4 (indices [batch_4_indices])
Each worker:
○ Fetches the corresponding samples by calling __getitem__ on the dataset for each index in that batch.
○ Applies any defined transforms and passes the samples through collate_fn to form a single batch tensor.
3. First Batch Returned to Main Process:
○ Whichever worker finishes first sends its fully prepared batch (e.g., batch 1) back to the main process.
○ As soon as the main process gets this first prepared batch, it yields it to your training loop, so your
for batch_data, batch_labels in dataloader: loop receives (batch_data, batch_labels) for the first batch.
4. Model Training on the Main Process:
○ While you are now performing the forward pass, computing loss, and doing backpropagation on the first
batch, the other three workers are still preparing their batches in parallel.
○ By the time you finish updating your model parameters for the first batch, the DataLoader likely has the
second, third, or even more batches ready to go (depending on processing speed and hardware).
5. Continuous Processing:
○ As soon as a worker finishes its batch, it grabs the next batch of indices from the queue.
○ For example, after Worker #1 finishes with batch 1, it immediately starts on batch 5. After Worker #2
finishes batch 2, it takes batch 6, and so forth.
○ This creates a pipeline effect: at any given moment, up to 4 batches are being prepared concurrently.
6. Loop Progression:
○ Your training loop simply sees:
```python
for batch_data, batch_labels in dataloader:
    # forward pass
    # loss computation
    # backward pass
    # optimizer step
```
○ Each iteration, it gets a new, ready-to-use batch without long I/O waits, because the workers have been pre-
loading and processing data in parallel.
7. End of the Epoch:
○ After all 313 iterations (312 full batches plus the final partial one), every batch has been
processed and all indices consumed, so the DataLoader has no more batches to yield.
○ The epoch ends. If shuffle=True, on the next epoch, the sampler reshuffles indices, and the whole process
repeats with workers again loading data in parallel.
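A sketch of the DataLoader configuration assumed in the walkthrough above, reusing the dataset from the earlier sketch; len(dataloader) reports the number of batches per epoch:

```python
from torch.utils.data import DataLoader

# Configuration matching the walkthrough: 10,000 samples, batch size 32, 4 workers
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
print(len(dataloader))  # 313: 312 full batches of 32 plus one final batch of 16

# With drop_last=True the incomplete final batch is discarded
loader_drop = DataLoader(dataset, batch_size=32, shuffle=True,
                         num_workers=4, drop_last=True)
print(len(loader_drop))  # 312
```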
In PyTorch, the sampler in the DataLoader determines the strategy for selecting samples from
the dataset during data loading. It controls how indices of the dataset are drawn for each
batch.
Types of Samplers
PyTorch provides several predefined samplers, and you can create custom ones:
1. SequentialSampler: draws indices in order (0, 1, 2, ...); this is what the DataLoader uses when shuffle=False.
2. RandomSampler: draws indices in a fresh random permutation each epoch; this is what shuffle=True uses under the hood.
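A short sketch of passing the built-in samplers explicitly, again reusing the dataset from the earlier sketch (note that sampler and shuffle are mutually exclusive arguments):

```python
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler

# Equivalent to shuffle=False: indices 0, 1, 2, ... in order
seq_loader = DataLoader(dataset, batch_size=32, sampler=SequentialSampler(dataset))

# Equivalent to shuffle=True: a random permutation of indices each epoch
rand_loader = DataLoader(dataset, batch_size=32, sampler=RandomSampler(dataset))

print(list(SequentialSampler(dataset))[:5])  # [0, 1, 2, 3, 4]
print(list(RandomSampler(dataset))[:5])      # e.g. [8231, 17, 9954, 402, 3368]
```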
The collate_fn in PyTorch's DataLoader is a function that specifies how to combine a list of
samples from a dataset into a single batch. By default, the DataLoader uses a simple batch
collation mechanism, but collate_fn allows you to customize how the data should be
processed and batched.
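A sketch of a custom collate_fn that pads variable-length sequences to a common length within each batch; the toy samples here are illustrative, and pad_sequence comes from torch.nn.utils.rnn:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy samples: (variable-length sequence, label) pairs
samples = [(torch.randn(n), torch.tensor(n % 2)) for n in (5, 3, 7, 2)]

def pad_collate(batch):
    # batch is a list of (sequence, label) tuples produced by __getitem__
    sequences, labels = zip(*batch)
    padded = pad_sequence(sequences, batch_first=True)  # pad to the longest sequence
    lengths = torch.tensor([len(s) for s in sequences])
    return padded, torch.stack(labels), lengths

loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)
for padded, labels, lengths in loader:
    print(padded.shape, labels, lengths)
```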
The DataLoader class in PyTorch comes with several parameters that allow you to customize
how data is loaded, batched, and preprocessed. Some of the most commonly used and
important parameters include:
1. dataset (mandatory):
○ The Dataset from which the DataLoader will pull data.
○ Must be a subclass of torch.utils.data.Dataset that implements __getitem__ and
__len__.
2. batch_size:
○ How many samples per batch to load.
○ Default is 1.
○ Larger batch sizes can speed up training on GPUs but require more memory.
3. shuffle:
○ If True, the DataLoader will shuffle the dataset indices each epoch.
○ Helpful to avoid the model becoming too dependent on the order of samples.
4. num_workers:
○ The number of worker processes used to load data in parallel.
○ Setting num_workers > 0 can speed up data loading by leveraging multiple CPU
cores, especially if I/O or preprocessing is a bottleneck.
5. pin_memory:
○ If True, the DataLoader will copy tensors into pinned (page-locked) memory before
returning them.
○ This can improve GPU transfer speed and thus overall training throughput,
particularly on CUDA systems.
6. drop_last:
○ If True, the DataLoader will drop the last incomplete batch if the total number of
samples is not divisible by the batch size.
○ Useful when exact batch sizes are required (for example, in some batch
normalization scenarios).
7. collate_fn:
○ A callable that processes a list of samples into a batch (the default simply stacks
tensors).
○ Custom collate_fn can handle variable-length sequences, perform custom batching
logic, or handle complex data structures.
8. sampler / batch_sampler:
○ sampler defines the strategy for drawing samples (e.g., for handling imbalanced
classes, or custom sampling strategies).
○ batch_sampler works at the batch level, controlling how batches are formed.
○ Typically, you don’t need to specify these if you are using batch_size and shuffle.
However, they provide lower-level control if you have advanced requirements.
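A sketch that pulls the common parameters together in one call; the specific values below are illustrative defaults, not recommendations:

```python
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,            # 1. a Dataset implementing __getitem__ and __len__
    batch_size=64,      # 2. samples per batch
    shuffle=True,       # 3. reshuffle indices every epoch
    num_workers=4,      # 4. parallel worker processes for loading
    pin_memory=True,    # 5. page-locked memory for faster CPU-to-GPU copies
    drop_last=True,     # 6. discard the final incomplete batch
    # collate_fn=...    # 7. custom batching logic, if needed
    # sampler=...       # 8. custom sampling strategy (mutually exclusive with shuffle)
)
```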