PyTorch collate_fn with different data types [TIL #1]
Today I Learned (TIL) is a series of short and actionable articles on topics related to programming, machine learning, deep learning, data science, etc…
A list of all articles is available in the today-i-learned repository
Introduction
- In this article, we investigate how PyTorch DataLoader collates different types of data and collections by default
- We focus on dictionaries, but the same steps allow us to understand the behavior of any type of data
- We also learn how to implement a custom collate function for our specific data format
TL;DR
- Default collate behavior in PyTorch DataLoader depends on the type of the object/collection returned from the PyTorch Dataset
- By default, DataLoader uses the default_collate function to collate lists of samples into batches
- To check how different data types are handled by default_collate, we can investigate the examples in the docstring of this function
- It is also possible to write a custom collate_fn - examples in sections 1.3 and 2.3 below
1. General examples
Code is also available as a Colab Notebook.
The examples below all use 2x2 RGB images for simplicity - e.g. torch.rand(3, 2, 2).
1.1 Returning image Tensor and int target
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class TensorIntDataset(Dataset):
def __init__(self, n_samples):
self.imgs = [torch.rand(3, 2, 2) for i in range(n_samples)]
self.targets = np.random.randint(0, 9, size=n_samples)
def __getitem__(self, idx):
return self.imgs[idx], self.targets[idx]
def __len__(self):
return len(self.targets)
ti_dataset = TensorIntDataset(10)
ti_dataloader = DataLoader(ti_dataset, batch_size=2)
A single sample from this dataset is an image tensor in CHW format and an int target.
img, target = ti_dataset[0]
print(img.shape)
# torch.Size([3, 2, 2])
print(target)
# 3
So, as we might expect, the DataLoader will return the batch of images in NCHW format and a tensor with the targets.
imgs, targets = next(iter(ti_dataloader))
print(imgs.shape)
# torch.Size([2, 3, 2, 2])
print(targets)
# tensor([3, 5])
1.2 Returning image Tensor and dictionary target
class TensorDictDataset(Dataset):
def __init__(self, n_samples):
self.imgs = [torch.rand(3, 2, 2) for i in range(n_samples)]
self.targets = [
{
"label": np.random.randint(0, 9),
"other_value": np.random.randint(0, 9),
}
for i in range(n_samples)
]
def __getitem__(self, idx):
return self.imgs[idx], self.targets[idx]
def __len__(self):
return len(self.targets)
A single sample from this dataset is an image tensor in CHW format and a target dictionary.
td_dataset = TensorDictDataset(10)
img, target = td_dataset[0]
print(img.shape)
# torch.Size([3, 2, 2])
print(target)
# {'label': 1, 'other_value': 2}
So, based on example 1.1, we might expect that the DataLoader will return the targets as a list of dictionaries - for example:
targets = [
{
"label": 4,
"other_value": 0,
},
{
"label": 2,
"other_value": 6,
},
]
However, this is not the case!
In fact, the DataLoader will return the batch of images in NCHW format and a single target dictionary that collates the targets for all the samples.
td_dataloader = DataLoader(td_dataset, batch_size=2)
imgs, targets = next(iter(td_dataloader))
print(imgs.shape)
# torch.Size([2, 3, 2, 2])
print(targets)
# {'label': tensor([1, 5]), 'other_value': tensor([2, 3])}
By default, DataLoader uses the default_collate function to collate lists of samples into batches.
To check how different data types are handled by default_collate, we can investigate the docstring of this function - for example, the behavior for Mapping is described there, and we can see that it matches the output format we obtained above.
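We can also reproduce this Mapping behavior in isolation by calling default_collate directly on a small, hand-written list of samples (a minimal check with made-up values):
from torch.utils.data import default_collate
# default_collate recurses into each Mapping key and collates the values
# across samples, so a list of dicts becomes a single dict of tensors
batch = [
    {"label": 1, "other_value": 2},
    {"label": 5, "other_value": 3},
]
print(default_collate(batch))
# {'label': tensor([1, 5]), 'other_value': tensor([2, 3])}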
1.3 Custom collate_fn
To modify the collate behavior for our specific needs, we can write a custom collate function based on the hints from the docstring.
def custom_collate(batch):
    if (
        isinstance(batch, list)
        and len(batch[0]) == 2
        and isinstance(batch[0][1], dict)
    ):
        # Stack images into a single NCHW tensor,
        # but keep the targets as a list of dictionaries
        imgs = torch.stack([img for img, target in batch])
        targets = [target for img, target in batch]
        return imgs, targets
    else:  # Fall back to `default_collate`
        return torch.utils.data.default_collate(batch)
td_dataloader = DataLoader(td_dataset, batch_size=2, collate_fn=custom_collate)
imgs, targets = next(iter(td_dataloader))
print(imgs.shape)
# torch.Size([2, 3, 2, 2])
print(targets)
# [{'label': 1, 'other_value': 2}, {'label': 5, 'other_value': 3}]
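For reference, the batch argument that DataLoader passes to collate_fn is simply a list of whatever Dataset.__getitem__ returns, so we can also test our function directly on hand-picked samples (a quick sanity check, outside of any DataLoader):
# DataLoader builds the batch as [dataset[i] for i in sampled_indices]
# and hands that list to collate_fn
batch = [td_dataset[0], td_dataset[1]]
imgs, targets = custom_collate(batch)
print(imgs.shape)
# torch.Size([2, 3, 2, 2])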
2. Real-life example - torchvision object detection
2.1 Motivation
Let’s imagine the following situation: we work with the fasterrcnn_resnet50_fpn object detection model from the torchvision library.
During training, the model expects both the input tensors and a list of target dictionaries containing ground-truth boxes and labels in the following format.
def rand_boxes(n):
"""
Generate "random" bounding boxes ensuring x2>x1 and y2>y1
Only for presentation purposes
"""
xy1 = 0.9 * torch.rand(n, 2)
xy2 = xy1 + 0.1
return torch.cat([xy1, xy2], dim=1)
imgs = [torch.rand(3, 2, 2), torch.rand(3, 2, 2)] # 2 RGB images (2x2 size)
targets = [
{
# Ground-truth for the first image
# 5 boxes with [x1, y1, x2, y2] coordinates and 5 COCO class labels
"boxes": rand_boxes(5), # torch.Size([5, 4])
"labels": torch.randint(low=0, high=91, size=(5,)),
},
{
# Ground-truth for the second image
# 7 boxes with [x1, y1, x2, y2] coordinates and 7 COCO class labels
"boxes": rand_boxes(7), # torch.Size([7, 4])
"labels": torch.randint(low=0, high=91, size=(7,)),
},
]
from torchvision.models.detection import (
fasterrcnn_resnet50_fpn,
FasterRCNN_ResNet50_FPN_Weights,
)
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
# Test if model accepts the data format
model.train()
model(imgs, targets) # OK - valid format
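In training mode, torchvision detection models return a dictionary of losses rather than predictions, so we can additionally inspect the returned keys to confirm the call succeeded (the key names below are specific to Faster R-CNN):
losses = model(imgs, targets)
print(sorted(losses.keys()))
# ['loss_box_reg', 'loss_classifier', 'loss_objectness', 'loss_rpn_box_reg']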
2.2 Initial implementation
Let’s prepare a PyTorch dataset that will return the data in this format:
def rand_target():
n_objects = np.random.randint(1, 10)
target = {
"boxes": rand_boxes(n_objects),
"labels": torch.randint(low=0, high=91, size=(n_objects,)),
}
return target
class DetectionDataset(torch.utils.data.Dataset):
def __init__(self, n_samples):
self.imgs = [torch.rand(3, 2, 2) for i in range(n_samples)]
self.targets = [rand_target() for i in range(n_samples)]
def __getitem__(self, idx):
return self.imgs[idx], self.targets[idx]
def __len__(self):
return len(self.targets)
detection_dataset = DetectionDataset(10)
img, target = detection_dataset[0]
print(img.shape)
# torch.Size([3, 2, 2])
print(target)
# {'boxes': tensor([[0.5160, 0.4331, 0.6160, 0.5331],
# [0.0248, 0.1283, 0.1248, 0.2283],
# [0.1320, 0.6417, 0.2320, 0.7417],
# [0.5979, 0.1503, 0.6979, 0.2503],
# [0.4593, 0.6163, 0.5593, 0.7163],
# [0.2641, 0.8964, 0.3641, 0.9964],
# [0.3178, 0.7418, 0.4178, 0.8418]]),
# 'labels': tensor([12, 56, 19, 37, 83, 23, 79])}
A single sample returned from the dataset matches the format required by the torchvision model.
However, if we use the DataLoader with the default collate function, the format of the batched data will be incorrect, or we might even encounter a RuntimeError when the number of targets differs between samples.
detection_dataloader = DataLoader(detection_dataset, batch_size=2)
# This code will throw RuntimeError or TypeError due to the data format problems
# imgs, targets = next(iter(detection_dataloader))
# model.train()
# model(imgs, targets)
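We can reproduce the failure in isolation: default_collate tries to torch.stack the boxes tensors across samples, which requires equal shapes. Below is a minimal sketch with two hand-built samples (the exact error message may vary between PyTorch versions):
# Two samples with a different number of boxes (5 vs 7)
sample_a = (torch.rand(3, 2, 2), {"boxes": rand_boxes(5), "labels": torch.randint(0, 91, (5,))})
sample_b = (torch.rand(3, 2, 2), {"boxes": rand_boxes(7), "labels": torch.randint(0, 91, (7,))})
try:
    torch.utils.data.default_collate([sample_a, sample_b])
except RuntimeError as e:
    print(e)
# stack expects each tensor to be equal size, but got [5, 4] at entry 0 and [7, 4] at entry 1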
2.3 Solution
The solution is to use a custom collate function similar to the custom_collate we introduced in section 1.3.
def custom_detection_collate(batch):
    if (
        isinstance(batch, list)
        and len(batch[0]) == 2
        and isinstance(batch[0][1], dict)
    ):
        # Unlike `custom_collate` from section 1.3, we do NOT stack the
        # images - the detection model expects a list of image tensors
        imgs = [img for img, target in batch]
        targets = [target for img, target in batch]
        return imgs, targets
    else:  # Fall back to `default_collate`
        return torch.utils.data.default_collate(batch)
detection_dataloader = DataLoader(
detection_dataset, batch_size=2, collate_fn=custom_detection_collate
)
imgs, targets = next(iter(detection_dataloader))
print(f"{len(imgs)} images of size: {imgs[0].shape}")
# 2 images of size: torch.Size([3, 2, 2])
print(targets)
# [{'boxes': tensor([[0.5160, 0.4331, 0.6160, 0.5331],
# [0.0248, 0.1283, 0.1248, 0.2283],
# [0.1320, 0.6417, 0.2320, 0.7417],
# [0.5979, 0.1503, 0.6979, 0.2503],
# [0.4593, 0.6163, 0.5593, 0.7163],
# [0.2641, 0.8964, 0.3641, 0.9964],
# [0.3178, 0.7418, 0.4178, 0.8418]]),
# 'labels': tensor([12, 56, 19, 37, 83, 23, 79])},
# {'boxes': tensor([[0.2821, 0.0579, 0.3821, 0.1579],
# [0.7718, 0.6386, 0.8718, 0.7386]]),
# 'labels': tensor([81, 7])}]
model.train()
model(imgs, targets) # OK - valid format
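As a side note, a more compact variant of this collate function - used, for example, in the torchvision detection reference scripts - simply transposes the batch with zip. It returns tuples instead of lists, which the model also accepts (the helper name below is ours):
def simple_detection_collate(batch):
    # Transpose [(img1, tgt1), (img2, tgt2)] into ((img1, img2), (tgt1, tgt2))
    return tuple(zip(*batch))
detection_dataloader = DataLoader(
    detection_dataset, batch_size=2, collate_fn=simple_detection_collate
)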