
Deep Learning Implementation Workflow with PyTorch

  1. Verify preprocessing, postprocessing, and network model output
  2. Create a Dataset
  3. Create a DataLoader
  4. Create a network model
  5. Define the forward pass
  6. Define the loss function
  7. Configure the optimization method
  8. Perform training and validation
  9. Run inference on test data

Addendum

If you are going to build deep learning models with PyTorch from now on, consider learning PyTorch Lightning alongside it.

About Dataset and DataLoader

  • Dataset class
    • A class that holds input data paired with its labels, etc.
    • You can provide an instance of a preprocessing class so that preprocessing is automatically applied when loading data files
  • DataLoader class
    • A class that configures how data is fetched from a Dataset
    • Makes it easy to extract mini-batches from a Dataset
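The Dataset/DataLoader relationship above can be sketched with a minimal custom Dataset. The class and variable names here (`PairDataset`, the toy tensors) are hypothetical, not from the original text:

```python
import torch
from torch.utils.data import Dataset

class PairDataset(Dataset):
    """Minimal Dataset: holds inputs paired with labels and
    applies an optional preprocessing transform on access."""
    def __init__(self, inputs, labels, transform=None):
        self.inputs = inputs
        self.labels = labels
        self.transform = transform  # preprocessing applied when an item is loaded

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        x = self.inputs[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.labels[idx]

# Toy data: 10 samples with 4 features each, labeled 0..9
ds = PairDataset(torch.randn(10, 4), torch.arange(10))
x, y = ds[3]
```

A DataLoader can then wrap this Dataset to fetch shuffled mini-batches, as shown later in the document.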

About Network Models

  • Creating a network model
    • Building everything from scratch
    • Loading and using a pretrained model
    • Modifying a pretrained model as a base
  • In applied deep learning methods, it is common to modify a pretrained model as a base

About the Forward Pass

  • In applied deep learning methods, the network model often branches midway, so the forward pass tends to be complex
  • Simple network models flow straight from input to output, but since applied models often do not, the forward function should be defined explicitly
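A toy example of a forward pass that branches midway and then merges, in the spirit described above. The architecture and names (`TwoBranchNet`, layer sizes) are invented for illustration:

```python
import torch
import torch.nn as nn

class TwoBranchNet(nn.Module):
    """Toy model whose forward pass branches and merges (hypothetical)."""
    def __init__(self):
        super().__init__()
        self.stem = nn.Linear(8, 16)
        self.branch_a = nn.Linear(16, 16)
        self.branch_b = nn.Linear(16, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        h = torch.relu(self.stem(x))
        a = torch.relu(self.branch_a(h))  # branch 1
        b = torch.relu(self.branch_b(h))  # branch 2
        return self.head(a + b)           # merge the branches before the head

out = TwoBranchNet()(torch.randn(4, 8))
```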

About the Loss Function

  • Defines the error whose gradient is propagated backward through the network
  • For simple deep learning methods, a simple function like mean squared error is used, but more complex functions are used in applied deep learning methods

About the Optimization Method

  • Used when training the connection parameters of the network model
  • Backpropagation computes the gradient of the error with respect to the connection parameters; the optimization method determines how that gradient is converted into an update amount for each parameter
  • Examples include Momentum SGD, Adam, etc.
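The two optimizers named above are both available in `torch.optim`. A minimal sketch of constructing them and performing one update step (the toy linear model and learning rates are illustrative choices, not prescriptions):

```python
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 2)  # toy model standing in for a real network

# Momentum SGD: accumulates a velocity term across updates
opt_sgd = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Adam: per-parameter adaptive learning rates
opt_adam = optim.Adam(net.parameters(), lr=0.001)

# One update step: zero old gradients, backpropagate, apply the update
loss = net(torch.randn(8, 4)).pow(2).mean()
opt_sgd.zero_grad()
loss.backward()
opt_sgd.step()
```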

About Training, Validation, and Inference

  • Basically, you check performance on both training data and validation data at each epoch
  • Once validation performance stops improving, further training will lead to overfitting on the training data, so training is usually stopped at that point
    • early stopping
  • After training is complete, inference is performed on the test data
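The per-epoch train/validate loop with early stopping described above can be sketched as follows. The data is random and the patience value is arbitrary; this only illustrates the loop structure:

```python
import torch
import torch.nn as nn

# Toy regression setup (hypothetical data) to illustrate the loop structure
net = nn.Linear(4, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x_train, y_train = torch.randn(32, 4), torch.randn(32, 1)
x_val, y_val = torch.randn(16, 4), torch.randn(16, 1)

best_val, patience, bad_epochs = float("inf"), 3, 0
for epoch in range(100):
    net.train()                      # training mode
    optimizer.zero_grad()
    loss = loss_fn(net(x_train), y_train)
    loss.backward()
    optimizer.step()

    net.eval()                       # evaluation mode
    with torch.no_grad():            # no gradients needed for validation
        val_loss = loss_fn(net(x_val), y_val).item()

    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:   # early stopping
            break
```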

Implementing Transfer Learning

  1. Create a Dataset from image data
  2. Create a DataLoader from the Dataset
  3. Modify the output layer of a pretrained model to the desired shape
  4. Train only the connection parameters of the output layer to implement transfer learning
  • Transfer Learning
    • A method where a pretrained model is used as a base, and only the final output layer is replaced and trained
    • The final output layer is replaced with one that corresponds to your own data, and the connection parameters to the replaced output layer are retrained with a small amount of your own data
    • Since a pretrained model is used as the base, deep learning with good performance can be achieved even with a small amount of your own data
  • Fine-tuning
    • A method where the output layer and other parts of a pretrained model are modified, and the connection parameters of the entire neural network model are trained with your own data
    • The pretrained model's parameters are used as initial values for the connection parameters
    • Unlike transfer learning, not only the output layer and layers near it, but all layers' parameters are retrained
    • It is common to set a smaller learning rate for parameters near the input layer and a larger learning rate for parameters near the output layer
    • Like transfer learning, it has the advantage of achieving good deep learning performance even with a small amount of your own data
    • The optimization method configuration differs from transfer learning

Creating a Dataset

  • When creating a Dataset, using the torchvision.datasets.ImageFolder class is an easy approach
  • The above method is simple, but you can also create a Dataset yourself
  • Data Augmentation
    • A technique that applies random image transformations to training data to augment it. The following classes are commonly used:
    • RandomResizedCrop: A class that crops a given PIL image to a random size and aspect ratio
    • (Usage example) RandomResizedCrop(resize, scale=(0.5, 1.0))
      • Randomly crops a region covering between 50% and 100% of the original image area
      • Also varies the aspect ratio between 3/4 and 4/3, stretching the image horizontally or vertically
      • Finally resizes the cropped region to the size specified by resize
    • RandomHorizontalFlip: A class that randomly flips a given PIL image horizontally with a specified probability
    • (Usage example) RandomHorizontalFlip()
      • Flips the image left-right with a 50% probability
    • These transform classes are provided in torchvision.transforms
    • albumentations is also a great choice

By augmenting data and training on diverse data, performance on test data (generalization performance) tends to improve!

Creating a DataLoader
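A minimal sketch of wrapping a Dataset in a DataLoader. The dummy tensors and configuration values (batch size, worker count) are illustrative assumptions:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy dataset: 100 samples of 4 features with integer class labels
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 3, (100,)))

# The DataLoader configures batching, shuffling, and (optionally)
# parallel loading via num_workers
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

# Each iteration yields one mini-batch of inputs and labels
xb, yb = next(iter(train_loader))
```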

Defining the Loss Function

  • For standard classification, the cross-entropy loss function is used
  • Cross-entropy loss function
    • Applies the softmax function to the output from the fully connected layer, then computes the negative log likelihood loss for classification
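The decomposition described above can be verified directly: `nn.CrossEntropyLoss` equals a log-softmax followed by the negative log likelihood loss. The logits and target here are arbitrary example values:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])  # raw fully connected output
target = torch.tensor([0])                 # correct class index

# CrossEntropyLoss = log-softmax + negative log likelihood loss
ce = nn.CrossEntropyLoss()(logits, target)
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
```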

Configuring the Optimization Method

  • Optimization Algorithms for Deep Learning
  • requires_grad
    • Controls whether to compute gradients for automatic differentiation
  • requires_grad = True
    • Gradients are computed during backpropagation, and values change during training (automatic differentiation is performed)
  • requires_grad = False
    • Used when you want to freeze parameters and prevent updates (automatic differentiation is not performed)
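The requires_grad behavior above in a minimal form: after backpropagation, the tracked tensor has a gradient while the frozen one does not (the tensors are arbitrary examples):

```python
import torch

w = torch.randn(3, requires_grad=True)        # trainable: gradients tracked
frozen = torch.randn(3, requires_grad=False)  # frozen: no gradients

loss = (w * 2).sum() + frozen.sum()
loss.backward()
# w.grad is now populated (d(loss)/dw = 2 for each element);
# frozen.grad remains None because no gradient was computed for it
```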

Performing Training and Validation

  • Dropout and gradient computation are typically only used during training, not during prediction
    • Therefore, the network is switched between training mode and evaluation mode
    • (Example) net.train(), net.eval()
  • Since gradients do not need to be computed during validation, conditional branching is used (e.g. wrapping the validation step in torch.no_grad())
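The mode switching and gradient suppression described above can be sketched as follows; the toy network with a dropout layer is hypothetical:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))
x = torch.randn(1, 4)

net.train()   # training mode: dropout active
net.eval()    # evaluation mode: dropout disabled

with torch.no_grad():  # skip gradient bookkeeping during validation
    out = net(x)
```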

Implementing Fine-tuning

  • The optimization method differs from transfer learning
    • Configure the optimizer so that all layers' parameters can be trained