
How to Use PyTorch Lightning Bolts

What a convenient time to be alive.

What is Bolts?

An official library packed with useful code for PyTorch Lightning.

  • Pretrained SOTA models
  • Commonly used model components
  • Forward and backward hooks for the Callback API
  • Loss functions
  • Popular datasets

All of these plug into PyTorch Lightning out of the box, which makes them extremely convenient. Some usage examples follow.

1. Using a Pretrained Model As-Is

For example, when you want to try a recent clustering-based self-supervised method such as SwAV:

from pl_bolts.models.self_supervised import SwAV

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar' # SwAV checkpoint pretrained on ImageNet
swav = SwAV.load_from_checkpoint(weight_path, strict=True)

swav.freeze() # disable gradients and switch to eval mode

After this, swav can be used like any regular nn.Module (a LightningModule is a subclass of nn.Module), for example as a frozen feature extractor.
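For reference, LightningModule.freeze() amounts to turning off gradients for every parameter and switching the module to eval mode. The sketch below reproduces that in plain PyTorch on a small stand-in network; the two-layer model and the helper name freeze are hypothetical, chosen only so the snippet runs without downloading the SwAV weights.

```python
import torch
from torch import nn


def freeze(module: nn.Module) -> nn.Module:
    """Mimic LightningModule.freeze(): no gradients, eval mode."""
    for param in module.parameters():
        param.requires_grad = False
    module.eval()  # BatchNorm uses running stats, Dropout is disabled
    return module


# Hypothetical stand-in for a pretrained backbone such as SwAV's encoder
backbone = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.ReLU())
freeze(backbone)

with torch.no_grad():
    features = backbone(torch.randn(2, 8))  # called like any nn.Module
print(features.shape)  # torch.Size([2, 4])
```

Because gradients are off and normalization layers use their stored statistics, the frozen module behaves deterministically at inference time.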

2. Using Individual Components

You can use pretrained models as backbones, or reuse only their encoder. In this example, we change the number of input channels of a pretrained ResNet152 from 3 to 4 while keeping the pretrained weights.

import torch
from torch import nn
from pl_bolts.models.self_supervised.resnets import resnet152

model = resnet152(pretrained=True)

temp_weight = model.conv1.weight.data.clone() # Save the pretrained weights, shape [64, 3, 7, 7]
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False) # Only the input channel count changes
with torch.no_grad(): # edit weights without recording autograd history
    model.conv1.weight[:, :3] = temp_weight # Reuse the pretrained weights for the first 3 channels
    model.conv1.weight[:, 3] = temp_weight[:, 0] # Initialize the 4th channel from the R channel weights

Conversely, you can take a single layer such as model.conv1 and reuse it in a different model.
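The same weight-copying trick works on any Conv2d, not only the Bolts ResNet. Below is a minimal sketch on a standalone, randomly initialized 3-channel layer that stands in for model.conv1, so nothing is downloaded; the helper name expand_input_channels is hypothetical, and it assumes an ordinary (groups=1) convolution. Note that duplicating the R weights, as above, treats the new channel like a second red channel: the layer's output on the original RGB input is preserved only when the extra channel is zero.

```python
import torch
from torch import nn


def expand_input_channels(conv: nn.Conv2d, new_in: int, src_channel: int = 0) -> nn.Conv2d:
    """Return a copy of `conv` that accepts `new_in` input channels.

    Existing input-channel weights are kept; each extra channel is
    initialized from channel `src_channel` (0 = R for RGB input).
    Assumes groups=1.
    """
    old_in = conv.in_channels
    new_conv = nn.Conv2d(
        new_in, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride,
        padding=conv.padding, dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():  # copy values without tracking gradients
        new_conv.weight[:, :old_in] = conv.weight
        for c in range(old_in, new_in):
            new_conv.weight[:, c] = conv.weight[:, src_channel]
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Stand-in for ResNet's conv1 (7x7, stride 2, padding 3, no bias)
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
conv1_rgbx = expand_input_channels(conv1, new_in=4)

x = torch.randn(1, 4, 224, 224)
print(conv1_rgbx(x).shape)  # torch.Size([1, 64, 112, 112])
```

Building a fresh layer from the old one's hyperparameters keeps stride, padding and dilation consistent, so the rest of the network sees feature maps of the same size as before.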

3. Using Callbacks

Bolts provides a variety of ready-made callbacks for PyTorch Lightning's Callback API. Just instantiate the ones you need and pass them as a list to the Trainer's callbacks argument. Here are two examples.

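  1. Print a table of the logged metrics (such as the loss) at the end of each epoch
import pytorch_lightning as pl
from pl_bolts.callbacks import PrintTableMetricsCallback

print_callback = PrintTableMetricsCallback()
trainer = pl.Trainer(callbacks=[print_callback])
trainer.fit(model) # model: your LightningModule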
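  2. Display images generated by a GAN's forward pass in TensorBoard
import torch
from pytorch_lightning import Trainer
from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler

# Requirements on your GAN LightningModule (called gan here):
# it must have an img_dim attribute...
gan.img_dim = (1, 28, 28)

# ...and its forward must turn latent vectors into images (used for sampling)
z = torch.rand(batch_size, latent_dim)
img_samples = gan(z)

trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()])
trainer.fit(gan)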
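Conceptually, the Trainer simply calls each registered callback's hook methods at fixed points of the training loop. The toy loop below is a hypothetical, heavily simplified model of that dispatch in plain Python. It is not Lightning's code: the hook name on_train_epoch_end does exist in Lightning's Callback API, but there it receives (trainer, pl_module), whereas here it only receives the toy trainer.

```python
class Callback:
    """Base class: every hook is a no-op unless overridden."""

    def on_train_epoch_end(self, trainer):
        pass


class PrintLossCallback(Callback):
    """Toy analogue of a metrics-printing callback (hypothetical)."""

    def __init__(self):
        self.lines = []  # keep what was printed, for inspection

    def on_train_epoch_end(self, trainer):
        line = f"epoch {trainer.epoch}: loss={trainer.metrics['loss']:.4f}"
        self.lines.append(line)
        print(line)


class ToyTrainer:
    def __init__(self, callbacks=None, max_epochs=3):
        self.callbacks = list(callbacks or [])
        self.max_epochs = max_epochs
        self.epoch = 0
        self.metrics = {}

    def fit(self, losses_per_epoch):
        # losses_per_epoch stands in for actually training a model
        for epoch, loss in enumerate(losses_per_epoch[: self.max_epochs]):
            self.epoch = epoch
            self.metrics["loss"] = loss
            for cb in self.callbacks:  # dispatch the hook to every callback
                cb.on_train_epoch_end(self)


printer = PrintLossCallback()
trainer = ToyTrainer(callbacks=[printer])
trainer.fit([0.9, 0.5, 0.3])
# epoch 0: loss=0.9000
# epoch 1: loss=0.5000
# epoch 2: loss=0.3000
```

Keeping the per-epoch logic in callbacks means the training loop itself never changes; adding behavior is just a matter of appending another object to the list.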

4. Using Loss Functions

Bolts implements several loss functions for specific tasks, although the selection is still small. They operate on ordinary PyTorch tensors, so you can simply call them from the loss method of your model class (named lossfun in the example below). Below is the Generalized IoU (GIoU) loss for object detection.

>>> import torch
>>> from pl_bolts.losses.object_detection import giou_loss
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> giou_loss(preds, target)
tensor([[1.0794]])

def lossfun(self, y, t): # method of your model; y, t: boxes as [x_min, y_min, x_max, y_max]
    return giou_loss(y, t)
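To see where the 1.0794 comes from, here is the GIoU definition worked through for this single pair of boxes in plain Python. It is a sketch of the formula, not Bolts' implementation (which works on batched tensors), and the helper names are hypothetical. The intersection is 50 x 50 = 2500, the union is 10000 + 10000 - 2500 = 17500, and the smallest enclosing box is 150 x 150 = 22500, so GIoU = 2500/17500 - (22500 - 17500)/22500 = 1/7 - 2/9 ≈ -0.0794 and the loss 1 - GIoU ≈ 1.0794.

```python
def box_area(box):
    """Area of an [x_min, y_min, x_max, y_max] box; 0 if it is empty."""
    x1, y1, x2, y2 = box
    return max(0.0, x2 - x1) * max(0.0, y2 - y1)


def giou_loss_single(pred, target):
    """1 - GIoU for one pair of [x_min, y_min, x_max, y_max] boxes."""
    # Intersection rectangle (area 0 if the boxes do not overlap)
    ix1, iy1 = max(pred[0], target[0]), max(pred[1], target[1])
    ix2, iy2 = min(pred[2], target[2]), min(pred[3], target[3])
    inter = box_area((ix1, iy1, ix2, iy2))
    union = box_area(pred) + box_area(target) - inter
    iou = inter / union
    # Smallest box enclosing both boxes
    cx1, cy1 = min(pred[0], target[0]), min(pred[1], target[1])
    cx2, cy2 = max(pred[2], target[2]), max(pred[3], target[3])
    enclose = box_area((cx1, cy1, cx2, cy2))
    # GIoU subtracts the fraction of the enclosing box not covered by the union
    giou = iou - (enclose - union) / enclose
    return 1.0 - giou


print(round(giou_loss_single((100, 100, 200, 200), (150, 150, 250, 250)), 4))  # 1.0794
```

Unlike plain 1 - IoU, the enclosing-box term keeps the loss informative when boxes do not overlap at all: the loss then exceeds 1 and shrinks as the boxes move closer.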

5. Using Data Modules

Bolts ships popular datasets already wrapped as LightningDataModules. They download the data to a directory you specify, so you can start testing models immediately, and they work with multi-GPU training.

  • You can customize the data augmentation yourself
from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule('PATH_to_download/to_load')
dm.train_transforms = ... # Pass a Compose object here
dm.test_transforms = ...
dm.val_transforms = ...
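  • You can wrap NumPy X and y arrays directly: SklearnDataset turns them into a PyTorch Dataset, and SklearnDataModule builds a full LightningDataModule from the same arrays. Pretty impressive.
>>> from sklearn.datasets import load_diabetes # load_boston was removed in scikit-learn 1.2
>>> from pl_bolts.datamodules import SklearnDataset
>>> X, y = load_diabetes(return_X_y=True)
>>> dataset = SklearnDataset(X, y) # any X and y with the same number of rows
>>> len(dataset)
442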
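Conceptually, a dataset like SklearnDataset only needs to check that X and y have the same number of rows and return one (x, y) pair per index, which is the __len__/__getitem__ protocol that PyTorch's map-style datasets and DataLoader rely on. The class below is a hypothetical, dependency-free sketch of that idea, not Bolts' source; the name ArrayPairDataset is made up for illustration.

```python
class ArrayPairDataset:
    """Hypothetical minimal map-style dataset over paired X / y sequences."""

    def __init__(self, X, y):
        # Rows of X and entries of y must line up one-to-one
        if len(X) != len(y):
            raise ValueError(f"X has {len(X)} rows but y has {len(y)}")
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


X = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
y = [0, 1, 0]
ds = ArrayPairDataset(X, y)
print(len(ds), ds[1])  # 3 ([0.3, 0.4], 1)
```

Because it implements only the map-style protocol, an object like this can be handed to torch.utils.data.DataLoader for batching and shuffling; the real SklearnDataset additionally deals with converting NumPy arrays to tensors.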