Model Training

Integrations guide for our application.

Step 8.1: Training Process

  • Training Function: The train_model() function trains the VGG-based multitask model and handles the following:
    • Early Stopping: Stops training when no improvement is observed for a set number of consecutive epochs (the patience).
    • Training and Validation Phases: Each epoch runs a training phase followed by a validation phase.
    • Loss Calculation: The binary and subtype classification losses are combined into a single loss that is optimized jointly.
  • Loss Functions: Cross-entropy loss is used for both classification tasks.
  • Optimizer: Adam (lr=1e-4) with weight decay (weight_decay=1e-5), which adds an L2 penalty on the weights to help regularize the model.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
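A minimal PyTorch sketch of the combined loss and one training/validation pass is shown below. TwoHeadNet is a hypothetical stand-in for the VGG-based model (a small linear backbone with a binary head and a subtype head). The four subtypes, the random batch of 16, the unweighted sum in combined_loss(), and measuring validation accuracy on the binary task are all assumptions; train_model() may weight or compute these differently.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the VGG-based multitask model: a shared
# backbone feeding a binary head and a subtype head.
class TwoHeadNet(nn.Module):
    def __init__(self, in_features: int, num_subtypes: int):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_features, 32), nn.ReLU())
        self.binary_head = nn.Linear(32, 2)
        self.subtype_head = nn.Linear(32, num_subtypes)

    def forward(self, x):
        features = self.backbone(x)
        return self.binary_head(features), self.subtype_head(features)

criterion = nn.CrossEntropyLoss()

def combined_loss(binary_logits, subtype_logits, binary_targets, subtype_targets):
    # Assumption: unweighted sum of the two cross-entropy terms.
    return (criterion(binary_logits, binary_targets)
            + criterion(subtype_logits, subtype_targets))

torch.manual_seed(0)
model = TwoHeadNet(in_features=10, num_subtypes=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# One synthetic batch of 16 samples with labels for both tasks.
x = torch.randn(16, 10)
binary_targets = torch.randint(0, 2, (16,))
subtype_targets = torch.randint(0, 4, (16,))

# Training phase: training mode, forward pass, backward pass, parameter update.
model.train()
optimizer.zero_grad()
binary_logits, subtype_logits = model(x)
loss = combined_loss(binary_logits, subtype_logits, binary_targets, subtype_targets)
loss.backward()
optimizer.step()

# Validation phase: eval mode, no gradient tracking.
model.eval()
with torch.no_grad():
    val_binary, val_subtype = model(x)
    val_loss = combined_loss(val_binary, val_subtype, binary_targets, subtype_targets)
    # Assumption: accuracy is measured on the binary task.
    val_acc = (val_binary.argmax(dim=1) == binary_targets).float().mean().item()

print(f"train loss {loss.item():.4f}, val loss {val_loss.item():.4f}, val acc {val_acc:.2f}")
```

In a full train_model() the two phases would iterate over training and validation DataLoaders (batch_size=16) instead of reusing a single synthetic batch.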

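  • Learning Rate Scheduler: StepLR multiplies the learning rate by gamma=0.1 every step_size=7 epochs (assuming scheduler.step() is called once per epoch), so it drops from 1e-4 to 1e-5 after 7 epochs, 1e-6 after 14 and 1e-7 after 21.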
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
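StepLR's schedule has a simple closed form, lr = base_lr * gamma ** (epoch // step_size). The pure-Python sketch below tabulates it for the 25 training epochs, assuming scheduler.step() is called once at the end of every epoch. step_lr is a hypothetical helper for illustration, not part of PyTorch.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 7, gamma: float = 0.1) -> float:
    """Learning rate used during `epoch` (0-indexed) under StepLR, assuming
    scheduler.step() runs once at the end of each epoch."""
    return base_lr * gamma ** (epoch // step_size)

schedule = [step_lr(1e-4, epoch) for epoch in range(25)]

for epoch, lr in enumerate(schedule):
    # Print only the epochs where the learning rate changes.
    if epoch == 0 or lr != schedule[epoch - 1]:
        print(f"epoch {epoch:2d}: lr = {lr:.0e}")
```

With these settings the learning rate is 1e-4 for epochs 0 to 6, 1e-5 for 7 to 13, 1e-6 for 14 to 20 and 1e-7 for 21 to 24 (0-indexed). If scheduler.step() were called per batch instead, step_size would count batches rather than epochs.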

Step 8.2: Training Parameters

  • Epochs: The model is trained for up to 25 epochs; early stopping can end training sooner.
  • Batch Size: A batch size of 16 is used for training and validation.
  • Patience: Early stopping halts training after 5 consecutive epochs without improvement.
  • Checkpointing: The best model weights are saved based on validation accuracy.
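The early-stopping and checkpointing rules above can be sketched without any framework. train_with_early_stopping and its run_epoch callback are hypothetical. The sketch assumes that early stopping monitors the same validation accuracy used for checkpointing (the section does not say which metric the patience counter tracks) and that "improvement" means strictly exceeding the best accuracy seen so far.

```python
import copy

def train_with_early_stopping(run_epoch, num_epochs=25, patience=5):
    """Run up to `num_epochs` epochs, checkpointing the best weights by
    validation accuracy and stopping after `patience` epochs without improvement.

    run_epoch(epoch) -> (val_acc, weights) is a hypothetical callback that
    performs one training + validation pass.
    Returns (best_epoch, best_acc, best_weights, last_epoch_run).
    """
    best_acc = float("-inf")
    best_epoch = -1
    best_weights = None
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        val_acc, weights = run_epoch(epoch)
        if val_acc > best_acc:
            # New best: checkpoint a copy (in PyTorch, typically
            # copy.deepcopy(model.state_dict())).
            best_acc, best_epoch = val_acc, epoch
            best_weights = copy.deepcopy(weights)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, best_acc, best_weights, epoch

    return best_epoch, best_acc, best_weights, num_epochs - 1


# Simulated validation accuracies; the "weights" are a placeholder dict.
val_accs = [0.60, 0.68, 0.72, 0.71, 0.73, 0.72, 0.72, 0.70, 0.73, 0.71, 0.69]
result = train_with_early_stopping(lambda epoch: (val_accs[epoch], {"epoch": epoch}))
print(result)
```

With the simulated accuracies, the best value (0.73) appears at epoch 4; epochs 5 to 9 bring no strict improvement, so training stops after epoch 9 and the epoch-4 weights are the checkpoint. A PyTorch train_model() would then typically restore them with model.load_state_dict(best_weights).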