[PyTorch Lightning] About LightningDataModule
Since I could not find a proper English explanation, I am writing this down as a reference for the future.
Overview
LightningDataModule is the class that encapsulates the DataLoader (and, in some cases, the Dataset as well) when running models with PyTorch Lightning. It is compatible with the corresponding plain PyTorch classes. In total, you need to write:
- LightningModule for the model
- LightningDataModule for the data
- Any other necessary customizations (Callbacks API, LR_FINDER, etc.)
How to Write a LightningDataModule
In addition to __init__, you need to implement three kinds of methods:
- prepare_data (optional; everything works without it)
- setup
- ~_dataloader
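To make the lifecycle concrete, here is a minimal sketch in plain Python (no Lightning installed) of the order in which a Trainer-like object calls these methods. The names ToyDataModule and toy_fit are illustrative stand-ins, not Lightning APIs:

```python
# Sketch of the call order Trainer.fit() follows with a datamodule.
calls = []

class ToyDataModule:
    def prepare_data(self):
        # runs once (downloads, unpacking, etc.)
        calls.append('prepare_data')

    def setup(self, stage=None):
        # runs on every process; stage tells it what to build
        calls.append(f'setup({stage})')

    def train_dataloader(self):
        # returns the training DataLoader in real Lightning code
        calls.append('train_dataloader')

def toy_fit(dm):
    # roughly the sequence Trainer.fit() drives
    dm.prepare_data()
    dm.setup('fit')
    dm.train_dataloader()

toy_fit(ToyDataModule())
print(calls)  # ['prepare_data', 'setup(fit)', 'train_dataloader']
```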
0. __init__
Create the necessary parameters. Note that you are not creating Dataset objects here. In the following example, we assume that the test data and training data are separated into different directories.
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
from torchvision import transforms

class DataModule(pl.LightningDataModule):
    def __init__(self, train_dir='./train', test_dir='./test', batch_size=64):
        super().__init__()
        self.train_dir = train_dir
        self.test_dir = test_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.data_augmentation = transforms.Compose([
            transforms.ToTensor(),
            # ... some data augmentations ...
            transforms.Normalize((0.1307,), (0.3081,))
        ])
Side note: albumentations is a convenient data augmentation library whose transforms can be adapted to torchvision's callable interface, so you can use it here as well.
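Albumentations transforms are called as t(image=array) and return a dict, so they need a small wrapper before they can stand in for a torchvision transform. The adapter below is a sketch of that idea; the stand-in flip function only mimics the albumentations calling convention so the example stays self-contained:

```python
import numpy as np

class AlbumentationsAdapter:
    """Wrap an albumentations-style transform (called as t(image=...),
    returning a dict) so it can be used like a torchvision transform."""

    def __init__(self, albu_transform):
        self.t = albu_transform

    def __call__(self, img):
        # albumentations operates on numpy arrays, not PIL images
        return self.t(image=np.asarray(img))['image']

def flip(image):
    # stand-in transform that follows the albumentations convention
    return {'image': image[:, ::-1]}

adapter = AlbumentationsAdapter(flip)
out = adapter(np.array([[1, 2, 3]]))
print(out.tolist())  # [[3, 2, 1]]
```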
1. prepare_data
This is the first method to be called. Put processing here that should run only once regardless of the number of GPUs, such as downloading data; in a multi-GPU setup, Lightning ensures the download runs on a single process instead of once per GPU. For example, when downloading MNIST:
from torchvision.datasets import MNIST

def prepare_data(self):
    # download (assumes a self.data_dir attribute set in __init__)
    MNIST(self.data_dir, train=True, download=True)
    MNIST(self.data_dir, train=False, download=True)
2. setup
This is the second method to be called.
Write the logic here for providing different Datasets when Trainer.fit() and Trainer.test() are called.
This is also a good place to switch data augmentation on or off.
It is easier to read if you create a separate Dataset class.
- Note: The Trainer passes the current mode to the stage argument as a string, but make sure to handle the case where it is None, since setup can also be called manually.
- Note 2: In multi-GPU setups, this is called once from each GPU.
def setup(self, stage=None):
    if stage == 'fit' or stage is None:
        self.train_set = MyDataset(
            self.train_dir,
            transform=self.data_augmentation
        )
        size = len(self.train_set)
        t, v = int(size * 0.9), int(size * 0.1)  # if using the holdout method
        t += (t + v != size)  # absorb the rounding remainder so t + v == size
        self.train_set, self.valid_set = random_split(self.train_set, [t, v])
    if stage == 'test' or stage is None:
        self.test_set = MyDataset(
            self.test_dir,
            transform=self.transform
        )
3. ~_dataloader
These are the last methods to be called, and each returns a DataLoader. Write three of them: one each for training, validation, and testing.
def train_dataloader(self):
    return DataLoader(
        self.train_set,
        batch_size=self.batch_size,
        shuffle=True  # shuffle only the training data
    )

def val_dataloader(self):
    return DataLoader(
        self.valid_set,
        batch_size=self.batch_size,
    )

def test_dataloader(self):
    return DataLoader(
        self.test_set,
        batch_size=self.batch_size,
    )
That covers all the required methods.
EXTRA: Using a LightningDataModule
In the normal case:
dm = DataModule()
model = Model()
trainer.fit(model, dm)
trainer.test(datamodule=dm)
This will automatically call the above methods and run training for you.
However, in some cases you may need dataset information (such as the number of classes, image size, or number of channels) when creating the model. In that case, write the logic for collecting the necessary information inside setup, then:
dm = DataModule()
dm.prepare_data()
dm.setup('fit') # Make sure attributes are set up to store the information
model = Model(num_classes=dm.num_classes, width=dm.img_size)
trainer.fit(model, dm)
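The attributes read above (num_classes, img_size) are not built into LightningDataModule; you define them yourself inside setup. A hypothetical sketch of that pattern, with stand-in data in place of a real dataset:

```python
class InfoDataModule:
    """Hypothetical datamodule that records dataset facts in setup so
    they can be read before the model is constructed."""

    def setup(self, stage=None):
        labels = [0, 1, 2, 1, 0]       # stand-in for the dataset's labels
        sample_shape = (1, 28, 28)     # stand-in for one sample's shape
        # attributes named for this example, not Lightning built-ins
        self.num_classes = len(set(labels))
        self.img_size = sample_shape[-1]

dm = InfoDataModule()
dm.setup('fit')
print(dm.num_classes, dm.img_size)  # 3 28
```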