How to Use PyTorch Lightning Bolts
What a convenient time to be alive.
What is Bolts?
An official library packed with useful code for PyTorch Lightning.
- Pretrained SOTA models
- Commonly used model components
- Callbacks that hook into PyTorch Lightning's Callback API
- Loss functions
- Popular datasets
All of these are ready to use with PyTorch Lightning right away, making it extremely convenient. Below are some usage examples.
1. Using a Pretrained Model As-Is
For example, when you want to try out a recent self-supervised model (here SwAV, which is built on online clustering):
from pl_bolts.models.self_supervised import SwAV
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar' # weights pretrained on ImageNet
swav = SwAV.load_from_checkpoint(weight_path, strict=True)
swav.freeze()
After this, swav can be used as a regular nn.Module.
2. Using Individual Components
You can use pretrained models as backbones or adopt just the encoder portion.
In this example, we change the input channel count of ResNet152 from 3 to 4.
from torch import nn
from pl_bolts.models.self_supervised.resnets import resnet152
model = resnet152(pretrained=True)
temp_weight = model.conv1.weight.data.clone() # Save existing weights
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False) # Only increase input_channel
model.conv1.weight.data[:, :3] = temp_weight # Use existing weights for the first 3 channels
model.conv1.weight.data[:, 3] = model.conv1.weight.data[:, 0] # Use the R channel weights as the 4th channel weights
Conversely, you can also use just model.conv1 in a different model.
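The weight-reuse trick above is just slice assignment along the input-channel dimension of the weight tensor. A torch-free illustration, with plain Python lists standing in for the per-channel weight slices:

```python
# Illustration only: lists stand in for conv1's weight tensor, whose real
# shape is [out_channels, in_channels, kH, kW]; we model the in_channels axis.
old_in_channels = [[1.0], [2.0], [3.0]]          # pretrained weights for R, G, B

# New layer with 4 input channels, "randomly" initialized.
new_in_channels = [[9.9], [9.9], [9.9], [9.9]]

# Copy the pretrained RGB weights into the first 3 channels ...
new_in_channels[:3] = [c[:] for c in old_in_channels]
# ... and reuse the R-channel weights for the extra 4th channel.
new_in_channels[3] = new_in_channels[0][:]

print(new_in_channels)  # [[1.0], [2.0], [3.0], [1.0]]
```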
3. Using Callbacks
There are various convenient utilities available for PyTorch Lightning's Callback API.
Just declare the necessary Callback objects and pass them as a list to the Trainer.
Here are two examples.
- Display loss at each epoch
import pytorch_lightning as pl
from pl_bolts.callbacks import PrintTableMetricsCallback
print_callback = PrintTableMetricsCallback()
trainer = pl.Trainer(callbacks=[print_callback])
trainer.fit(model)
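There is no magic behind this: the Trainer simply calls well-known hook methods on every registered callback at fixed points in the training loop. A minimal pure-Python sketch of that dispatch (the class and attribute names here are illustrative, not Lightning's real internals):

```python
class PrintMetricsCallback:
    """Toy stand-in for a Lightning Callback: prints the loss each epoch."""
    def on_epoch_end(self, trainer):
        print(f"epoch {trainer.epoch}: loss={trainer.loss:.3f}")

class ToyTrainer:
    """Hypothetical mini-trainer showing how callbacks get invoked."""
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.epoch = 0
        self.loss = 0.0

    def fit(self, epochs=3):
        for self.epoch in range(epochs):
            self.loss = 1.0 / (self.epoch + 1)   # fake training progress
            for cb in self.callbacks:            # hook dispatch point
                cb.on_epoch_end(self)

ToyTrainer([PrintMetricsCallback()]).fit()
```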
- Display images generated during a GAN's forward pass in TensorBoard
from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler
model.img_dim = (1, 28, 28)  # the sampler reads this attribute
# The model's forward must generate images from noise for sampling:
#   z = torch.rand(batch_size, latent_dim)
#   img_samples = model(z)
trainer = pl.Trainer(callbacks=[TensorboardGenerativeModelImageSampler()])
trainer.fit(model)
4. Using Loss Functions
Several functions are implemented for different tasks, though the number is still limited.
Once imported, they behave like regular PyTorch loss functions, so you can call them from your model class's loss method (here called lossfun).
Below is the Generalized IoU (GIoU) loss for object detection.
>>> import torch
>>> from pl_bolts.losses.object_detection import giou_loss
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> giou_loss(preds, target)
tensor([[1.0794]])
def lossfun(self, y, t):  # method of a network (LightningModule)
    return giou_loss(y, t)
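The 1.0794 in the doctest is 1 - GIoU: for corner-format boxes, GIoU = IoU - (C - U)/C, where U is the union area and C is the area of the smallest box enclosing both. A dependency-free sketch of that formula, which reproduces the value above (pl_bolts' giou_loss does the same on batched tensors):

```python
def giou_loss_single(pred, target):
    """GIoU loss (1 - GIoU) for two [x1, y1, x2, y2] boxes.
    Plain-Python sketch of the formula, not the pl_bolts implementation."""
    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])
    # Intersection rectangle (clamped to zero if the boxes don't overlap).
    ix1, iy1 = max(pred[0], target[0]), max(pred[1], target[1])
    ix2, iy2 = min(pred[2], target[2]), min(pred[3], target[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = area(pred) + area(target) - inter
    iou = inter / union
    # Smallest enclosing box.
    cx1, cy1 = min(pred[0], target[0]), min(pred[1], target[1])
    cx2, cy2 = max(pred[2], target[2]), max(pred[3], target[3])
    c = (cx2 - cx1) * (cy2 - cy1)
    giou = iou - (c - union) / c
    return 1 - giou

print(round(giou_loss_single([100, 100, 200, 200], [150, 150, 250, 250]), 4))
# 1.0794
```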
5. Using Data Modules
Datasets that have been converted into LightningDataModule format are available.
They can also download the data to a specified directory, so you can start testing models immediately, and they are multi-GPU compatible.
- You can customize the data augmentation yourself
from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule('PATH_to_download/to_load')
dm.train_transforms = ... # Pass a Compose object here
dm.test_transforms = ...
dm.val_transforms = ...
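These transform slots expect a callable, typically a torchvision.transforms.Compose object. A Compose is just function chaining; a minimal pure-Python stand-in for what it does:

```python
class Compose:
    """Minimal sketch of what a Compose object does: apply each transform
    in order. torchvision's Compose works the same way on images."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

# Toy "augmentations" on a number standing in for an image.
pipeline = Compose([lambda x: x * 2, lambda x: x + 1])
print(pipeline(10))  # 21
```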
- You can create a LightningDataModule just by passing NumPy x and y arrays -- pretty impressive
>>> from sklearn.datasets import load_boston
>>> from pl_bolts.datamodules import SklearnDataset
...
>>> X, y = load_boston(return_X_y=True)
>>> dataset = SklearnDataset(X, y) # Anything with matching shapes can be passed
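What the module does with the arrays is conceptually simple: wrap them in a Dataset that indexes X and y in lockstep. A torch-free sketch of that wrapper (the real pl_bolts class additionally handles train/val/test splits and builds DataLoaders):

```python
class ArrayDataset:
    """Hypothetical mini version of the sklearn-style dataset wrapper:
    pairs up X[i] with y[i] on indexing."""
    def __init__(self, X, y):
        assert len(X) == len(y), "X and y must have matching first dimensions"
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]

ds = ArrayDataset([[0.1, 0.2], [0.3, 0.4]], [0, 1])
print(len(ds), ds[1])  # 2 ([0.3, 0.4], 1)
```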