
Useful Checkpoint Hooks in PyTorch Lightning

A note on the LightningModule hooks that let you save any picklable object in the checkpoint file during training and load it back later.

on_save_checkpoint

def on_save_checkpoint(self, checkpoint):
    # In 99% of use cases you don't need to implement this method.
    # `checkpoint` is the dict Lightning is about to write; add any picklable object to it.
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

With this hook, you can store information needed at inference time in the same checkpoint as the model parameters, so you don't have to save it to a separate file and load it manually at inference time.

on_load_checkpoint

To restore the saved object, implement the matching hook:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method.
    # `checkpoint` is the loaded dict; read back the entry added in on_save_checkpoint.
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Use case

For example, you can save statistics computed from the training data, such as its covariance matrix, together with the model weights, so they are available at inference time.