Deep Learning Implementation Workflow with PyTorch
- Verify preprocessing, postprocessing, and network model output
- Create a Dataset
- Create a DataLoader
- Create a network model
- Define the forward pass
- Define the loss function
- Configure the optimization method
- Perform training and validation
- Run inference on test data
Addendum
If you are going to build deep learning models with PyTorch from now on, consider learning PyTorch Lightning alongside it.
About Dataset and DataLoader
- Dataset class
- A class that holds input data together with the corresponding labels and other annotations
- You can provide an instance of a preprocessing class so that preprocessing is automatically applied when loading data files
- DataLoader class
- A class that configures how data is fetched from a Dataset
- Makes it easy to extract mini-batches from a Dataset
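The two classes above can be sketched as follows; the toy data, class name, and transform hook are illustrative, not from the original text:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Minimal Dataset holding inputs paired with labels.

    A preprocessing callable can be passed in and is applied
    automatically each time an item is loaded.
    """
    def __init__(self, inputs, labels, transform=None):
        self.inputs = inputs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        x = self.inputs[idx]
        if self.transform is not None:
            x = self.transform(x)  # preprocessing applied on load
        return x, self.labels[idx]

dataset = ToyDataset(torch.randn(8, 3), torch.zeros(8, dtype=torch.long))
# The DataLoader handles mini-batch extraction from the Dataset
loader = DataLoader(dataset, batch_size=4, shuffle=True)
x_batch, y_batch = next(iter(loader))
```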
About Network Models
- Creating a network model
- Building everything from scratch
- Loading and using a pretrained model
- Modifying a pretrained model as a base
- In applied deep learning methods, it is common to modify a pretrained model as a base
About the Forward Pass
- In applied deep learning methods, the network model often branches midway, so the forward pass tends to be complex
- Simple network models flow straight from input to output, but applied models often do not, so the forward function should be defined explicitly
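A minimal sketch of a branching forward pass, assuming a hypothetical model with a shared trunk and two output heads:

```python
import torch
import torch.nn as nn

class BranchingNet(nn.Module):
    """Toy model whose forward pass branches midway, e.g. into a
    main output and an auxiliary output, so forward() is explicit."""
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(8, 16)
        self.head_main = nn.Linear(16, 4)
        self.head_aux = nn.Linear(16, 2)

    def forward(self, x):
        h = torch.relu(self.shared(x))
        # The computation branches here into two heads
        return self.head_main(h), self.head_aux(h)

net = BranchingNet()
main_out, aux_out = net(torch.randn(5, 8))
```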
About the Loss Function
- Defined for performing backpropagation
- For simple deep learning methods, a simple function like mean squared error is used, but more complex functions are used in applied deep learning methods
About the Optimization Method
- Used when training the connection parameters of the network model
- Backpropagation computes the gradient of the error with respect to the connection parameters; the optimizer specifies how the parameter update amount is calculated from that gradient
- Examples include Momentum SGD, Adam, etc.
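A short sketch of configuring Momentum SGD and using the backpropagated gradients for one update step; the toy model and data are illustrative:

```python
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 2)
criterion = nn.MSELoss()
# Momentum SGD; optim.Adam(net.parameters(), lr=1e-3) is another common choice
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

x, t = torch.randn(8, 4), torch.randn(8, 2)
optimizer.zero_grad()            # clear gradients from the previous step
loss = criterion(net(x), t)
loss.backward()                  # backpropagation fills in the gradients
optimizer.step()                 # compute the update amount and apply it
```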
About Training, Validation, and Inference
- Basically, you check performance on both training data and validation data at each epoch
- Once validation performance stops improving, further training will lead to overfitting on the training data, so training is usually stopped at that point
- early stopping
- After training is complete, inference is performed on the test data
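The epoch loop with early stopping might be sketched like this, using a hypothetical tiny regression task; the patience value is an assumption:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical tiny regression task to illustrate the epoch loop.
x_train, y_train = torch.randn(32, 4), torch.randn(32, 1)
x_val, y_val = torch.randn(16, 4), torch.randn(16, 1)

net = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.05)

best_val, patience, bad_epochs = float("inf"), 3, 0
for epoch in range(50):
    # training step
    net.train()
    optimizer.zero_grad()
    loss = criterion(net(x_train), y_train)
    loss.backward()
    optimizer.step()

    # validation step (no gradients needed)
    net.eval()
    with torch.no_grad():
        val_loss = criterion(net(x_val), y_val).item()

    # early stopping: stop once validation stops improving
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break
```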
Implementing Transfer Learning
- Create a Dataset from image data
- Create a DataLoader from the Dataset
- Modify the output layer of a pretrained model to the desired shape
- Train only the connection parameters of the output layer to implement transfer learning
- Transfer Learning
- A method where a pretrained model is used as a base, and only the final output layer is replaced and trained
- The final output layer is replaced with one that corresponds to your own data, and the connection parameters to the replaced output layer are retrained with a small amount of your own data
- Since a pretrained model is used as the base, deep learning with good performance can be achieved even with a small amount of your own data
- Fine-tuning
- A method where the output layer and other parts of a pretrained model are modified, and the connection parameters of the entire neural network model are trained with your own data
- The pretrained model's parameters are used as initial values for the connection parameters
- Unlike transfer learning, not only the output layer and layers near it, but all layers' parameters are retrained
- It is common to set a smaller learning rate for parameters near the input layer and a larger learning rate for parameters near the output layer
- Like transfer learning, it has the advantage of achieving good deep learning performance even with a small amount of your own data
- The optimization method configuration differs from transfer learning
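One way to express the per-layer learning rates used in fine-tuning is optimizer parameter groups. The small Sequential model below is a stand-in for a real pretrained network, and the learning rates are illustrative:

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in for a pretrained model: layers near the input, middle
# layers, and a replaced output layer.
net = nn.Sequential(
    nn.Linear(16, 32),  # layers near the input
    nn.ReLU(),
    nn.Linear(32, 32),  # middle layers
    nn.ReLU(),
    nn.Linear(32, 2),   # replaced output layer
)

# Fine-tuning: train all layers, but with a smaller learning rate near
# the input and a larger one near the output.
optimizer = optim.SGD(
    [
        {"params": net[0].parameters(), "lr": 1e-4},
        {"params": net[2].parameters(), "lr": 1e-3},
        {"params": net[4].parameters(), "lr": 1e-2},
    ],
    momentum=0.9,
)
```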
Creating a Dataset
- When creating a Dataset, using the torchvision.datasets.ImageFolder class is an easy approach
- The above method is simple, but you can also create a Dataset yourself
- Data Augmentation
- A technique that applies random image transformations to training data to augment it. The following classes are commonly used:
- RandomResizedCrop: A class that crops a given PIL image to a random size and aspect ratio
- (Usage example) RandomResizedCrop(resize, scale=(0.5, 1.0))
- Crops a random region covering between 50% and 100% of the original image area
- Additionally draws a random aspect ratio between 3/4 and 4/3, stretching the crop horizontally or vertically
- Finally resizes the cropped region to the size specified by resize
- RandomHorizontalFlip: A class that randomly flips a given PIL image horizontally with a specified probability
- (Usage example) RandomHorizontalFlip()
- Flips the image left-right with a 50% probability
- The classes above are provided in torchvision.transforms
- The albumentations library is also a great choice
By augmenting data and training on diverse data, performance on test data (generalization performance) tends to improve!
Creating a DataLoader
- Created using a Dataset
- torch.utils.data.DataLoader
- shuffle=True
- Randomizes the order in which images are fetched
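A minimal DataLoader sketch with shuffle=True, using TensorDataset as a stand-in Dataset:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1),
                        torch.arange(10))

# shuffle=True randomizes the order in which samples are drawn each epoch
loader = DataLoader(dataset, batch_size=5, shuffle=True)
batches = [labels for _, labels in loader]
```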
Defining the Loss Function
- For standard classification, the cross-entropy loss function is used
- Cross-entropy loss function
- Applies the softmax function to the output from the fully connected layer, then computes the negative log likelihood loss for classification
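Note that PyTorch's nn.CrossEntropyLoss takes raw logits and applies (log-)softmax internally; the toy logits and targets below are illustrative:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Raw scores (logits) from a fully connected layer: 4 samples, 3 classes.
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])

# CrossEntropyLoss applies log-softmax to the logits internally and then
# computes the negative log likelihood, so no softmax layer is needed.
loss = criterion(logits, targets)
```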
Configuring the Optimization Method
- Optimization Algorithms for Deep Learning
- requires_grad
- Controls whether to compute gradients for automatic differentiation
- requires_grad = True
- Gradients are computed during backpropagation, and values change during training (automatic differentiation is performed)
- requires_grad = False
- Used when you want to freeze parameters and prevent updates (automatic differentiation is not performed)
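Freezing with requires_grad = False can be sketched as follows; the small model is a stand-in, with only the last layer left trainable as in transfer learning:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze everything except the output layer (transfer learning style).
for param in net.parameters():
    param.requires_grad = False
for param in net[2].parameters():
    param.requires_grad = True

out = net(torch.randn(3, 4)).sum()
out.backward()
frozen_grad = net[0].weight.grad   # stays None: no gradient computed
trained_grad = net[2].weight.grad  # populated by backpropagation
```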
Performing Training and Validation
- Dropout and gradient computation are typically only used during training, not during prediction
- Therefore, the network is switched between training mode and evaluation mode
- (Example) net.train(), net.eval()
- Since gradients do not need to be computed during validation, conditional branching is used
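The mode switch and the conditional gradient branch are often combined in a phase loop like this sketch; the tiny model is illustrative:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 2))
x = torch.randn(3, 4)

for phase in ["train", "val"]:
    if phase == "train":
        net.train()  # dropout active
    else:
        net.eval()   # dropout disabled
    # gradients are only tracked in the training phase
    with torch.set_grad_enabled(phase == "train"):
        out = net(x)
```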
Implementing Fine-tuning
- The optimization method differs from transfer learning
- Configure the optimizer so that all layers' parameters can be trained
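Compared with transfer learning, fine-tuning simply hands every parameter to the optimizer; a minimal sketch with a stand-in model:

```python
import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Unlike transfer learning, pass every parameter to the optimizer so
# all layers are updated during fine-tuning.
optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)
```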