PyTorch comes with powerful data loading capabilities out of the box. But with great power comes great responsibility, and that makes data loading in PyTorch a fairly advanced topic.
One of the best ways to learn advanced topics is to start with the happy path. Then add complexity when you find out you need it. Let’s run through a quick start example.
What is a PyTorch DataLoader?
The PyTorch DataLoader class gives you an iterable over a Dataset. It’s useful because it can parallelize data loading and automatically shuffle and batch individual samples, all out of the box. This sets you up for a very simple training loop.
PyTorch Dataset
But to create a DataLoader, you have to start with a Dataset, the class responsible for actually reading samples into memory. When you’re implementing a DataLoader, the Dataset is where almost all of the interesting logic will go.
There are two styles of Dataset class: map-style and iterable-style. Map-style Datasets are more common and more straightforward, so we’ll focus on them here, but you can read more about iterable-style datasets in the docs.
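For context, an iterable-style dataset implements __iter__() and yields samples one at a time instead of supporting random access by index. A minimal sketch (the ToyIterableDataset name and its contents are just for illustration, not something we’ll use below) might look like this:

import torch


class ToyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, size: int):
        self.size = size

    def __iter__(self):
        # Yield samples lazily; there is no indexing into this dataset.
        for index in range(self.size):
            yield {"x": index, "y": index}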
To create a map-style Dataset class, you need to implement two methods: __getitem__() and __len__(). The __len__() method returns the total number of samples in the dataset, and the __getitem__() method takes an index and returns the sample at that index.
PyTorch Dataset objects are very flexible: they can return any kind of tensor(s) you want. But supervised training datasets should usually return an input tensor and a label. For illustration purposes, let’s create a dataset where the input tensor is a 3×3 matrix with the index along the diagonal. The label will be the index.
It should look like this:
dataset[3]
# Expected result
# {'x': array([[3., 0., 0.],
#        [0., 3., 0.],
#        [0., 0., 3.]]),
#  'y': 3}
Remember, all we have to implement are __getitem__() and __len__():
from typing import Dict, Union

import numpy as np
import torch


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, size: int):
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> Dict[str, Union[int, np.ndarray]]:
        return dict(
            x=np.eye(3) * index,
            y=index,
        )
Very simple. We can instantiate the class and start accessing individual samples:
dataset = ToyDataset(10)
dataset[3]
# Expected result
# {'x': array([[3., 0., 0.],
#        [0., 3., 0.],
#        [0., 0., 3.]]),
#  'y': 3}
If you happen to be working with image data, __getitem__() may be a good place to put your TorchVision transforms.
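As a rough sketch of what that might look like (the ImageDataset class, its paths and labels arguments, and the specific transforms below are assumptions for illustration, not part of the toy example above):

# Hypothetical image dataset: paths, labels, and the transform pipeline
# are illustrative choices, not something prescribed by this post.
from typing import List

import torch
from PIL import Image
from torchvision import transforms


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, paths: List[str], labels: List[int]):
        self.paths = paths
        self.labels = labels
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),  # PIL image -> (C, H, W) float tensor
        ])

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int):
        # Load the image lazily and apply the transforms per sample.
        image = Image.open(self.paths[index]).convert("RGB")
        return {"x": self.transform(image), "y": self.labels[index]}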
At this point, a sample is a dict with "x" as a matrix with shape (3, 3) and "y" as a Python integer. But what we want are batches of data: "x" should be a PyTorch tensor with shape (batch_size, 3, 3) and "y" should be a tensor with shape (batch_size,). This is where DataLoader comes back in.
PyTorch DataLoader
To iterate through batches of samples, pass your Dataset object to a DataLoader:
torch.manual_seed(1234)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=3,
    shuffle=True,
    num_workers=2,
)
for batch in loader:
    print(batch["x"].shape, batch["y"])

# Expected result
# torch.Size([3, 3, 3]) tensor([2, 1, 3])
# torch.Size([3, 3, 3]) tensor([6, 7, 9])
# torch.Size([3, 3, 3]) tensor([5, 4, 8])
# torch.Size([1, 3, 3]) tensor([0])
Notice a few things that are happening here:
- Both the NumPy arrays and the Python integers are getting converted to PyTorch tensors.
- Although we’re fetching individual samples in ToyDataset, the DataLoader is automatically batching them for us, with the batch size we request. This works even though the individual samples are in dict structures. This also works if you return tuples.
- The samples are randomly shuffled. We maintain reproducibility by setting torch.manual_seed(1234).
- The samples are read in parallel across processes. In fact, this code will fail if you run it in a Jupyter notebook. To get it to work, you need to put it underneath an if __name__ == "__main__": check in a Python script (see the sketch after this list).
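For example, a minimal sketch of the same loop wrapped in that check, assuming the ToyDataset definition lives in the same script:

# Guarding the loop lets the worker processes re-import this module safely.
if __name__ == "__main__":
    torch.manual_seed(1234)
    dataset = ToyDataset(10)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=3, shuffle=True, num_workers=2
    )
    for batch in loader:
        print(batch["x"].shape, batch["y"])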
There’s one other thing that I’m not doing in this sample, but you should be aware of it. If you need to use your tensors on a GPU (and you probably do for non-trivial PyTorch problems), then you should set pin_memory=True in the DataLoader. This will speed things up by letting the DataLoader allocate space in page-locked memory. You can read more about it here.
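As a rough sketch, assuming a CUDA device is available, combining pin_memory=True with non-blocking copies onto the GPU might look like this:

# Sketch: pinned-memory loading plus asynchronous host-to-device copies.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=3,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
device = torch.device("cuda")
for batch in loader:
    # non_blocking=True can overlap the copy with compute because the
    # source tensors live in page-locked memory.
    x = batch["x"].to(device, non_blocking=True)
    y = batch["y"].to(device, non_blocking=True)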
Summary
To review: the interesting part of custom PyTorch data loaders is the Dataset class you implement. From there, you get lots of nice features to simplify your data loop. If you need something more advanced, like custom batching logic, check out the API docs. Happy training!