Model Training

Step 8.1: Training Process

  • Training Function: train_model() runs the training loop with support for early stopping.

    • Early Stopping: Stops training when no improvement is observed for a specified number of consecutive epochs (the patience).
    • Phases: Each epoch has two phases: training and validation.
    • Loss Calculation: Combines binary classification loss and subtype classification loss.
  • Loss Functions: Cross-entropy loss for both tasks.

  • Optimizer: Adam optimizer with weight decay for regularization.

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
    
    
  • Learning Rate Scheduler: StepLR scheduler that multiplies the learning rate by 0.1 every 7 epochs.

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
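
To make the effect of step_size=7 and gamma=0.1 concrete, the sketch below reproduces the schedule in plain Python. It assumes scheduler.step() is called once per epoch and that training runs for the full 25 epochs; the helper name lr_at_epoch is illustrative and not part of the project code.

```python
# Closed-form view of StepLR(step_size=7, gamma=0.1) starting from lr=1e-4.
# Assumption: scheduler.step() is called exactly once at the end of each epoch.
BASE_LR = 1e-4
STEP_SIZE = 7
GAMMA = 0.1
NUM_EPOCHS = 25


def lr_at_epoch(epoch: int) -> float:
    """Learning rate in effect during the given zero-based epoch (hypothetical helper)."""
    return BASE_LR * GAMMA ** (epoch // STEP_SIZE)


schedule = [lr_at_epoch(e) for e in range(NUM_EPOCHS)]

# Print one line per plateau of the schedule.
for start in range(0, NUM_EPOCHS, STEP_SIZE):
    end = min(start + STEP_SIZE, NUM_EPOCHS) - 1
    print(f"epochs {start}-{end}: lr = {schedule[start]:.0e}")
```

Under these assumptions the rate drops three times (at zero-based epochs 7, 14, and 21) and ends near 1e-7, so most of the learning happens in the first two plateaus.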
    
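The combined loss described in this step (cross-entropy for the binary task plus cross-entropy for the subtype task) might be computed roughly as follows. This is a minimal sketch: the unweighted sum, the number of subtypes, the function name combined_loss, and the random tensors standing in for model outputs are all assumptions, not details taken from train_model().

```python
import torch
import torch.nn as nn

# Assumed class counts for illustration only.
NUM_BINARY_CLASSES = 2
NUM_SUBTYPES = 4

# Cross-entropy loss for both tasks.
criterion_binary = nn.CrossEntropyLoss()
criterion_subtype = nn.CrossEntropyLoss()


def combined_loss(binary_logits, subtype_logits, binary_labels, subtype_labels):
    """Sum the two cross-entropy terms (equal weighting is an assumption)."""
    loss_binary = criterion_binary(binary_logits, binary_labels)
    loss_subtype = criterion_subtype(subtype_logits, subtype_labels)
    return loss_binary + loss_subtype


# Dummy batch of 16 samples standing in for the outputs of a two-headed model.
torch.manual_seed(0)
binary_logits = torch.randn(16, NUM_BINARY_CLASSES, requires_grad=True)
subtype_logits = torch.randn(16, NUM_SUBTYPES, requires_grad=True)
binary_labels = torch.randint(0, NUM_BINARY_CLASSES, (16,))
subtype_labels = torch.randint(0, NUM_SUBTYPES, (16,))

loss = combined_loss(binary_logits, subtype_logits, binary_labels, subtype_labels)
loss.backward()  # a single backward pass sends gradients to both heads
print(f"combined loss: {loss.item():.4f}")
```

If one task dominates training, a weighting factor on either term is a common adjustment, though nothing here indicates that this project uses one.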

Step 8.2: Training Parameters

  • Epochs: Train for 25 epochs.
  • Batch Size: Use a batch size of 16.
  • Patience: Early stopping patience set to 5 epochs.
  • Checkpointing: Save the model state with the best validation accuracy.
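
The way the epoch limit, patience, and checkpointing interact can be sketched without any model at all. The example below is a simplified illustration, not the actual train_model(): it assumes early stopping monitors validation accuracy, treats only a strict increase as an improvement, and uses a hypothetical get_state callback in place of model.state_dict().

```python
import copy

NUM_EPOCHS = 25  # maximum number of epochs
PATIENCE = 5     # epochs to wait without improvement before stopping


def run_with_early_stopping(val_accuracies, get_state,
                            num_epochs=NUM_EPOCHS, patience=PATIENCE):
    """Return (best_acc, best_state, epochs_run) for a sequence of validation accuracies."""
    best_acc = float("-inf")
    best_state = None
    epochs_without_improvement = 0
    epochs_run = 0

    for epoch in range(min(num_epochs, len(val_accuracies))):
        epochs_run = epoch + 1
        val_acc = val_accuracies[epoch]  # stands in for the validation phase

        if val_acc > best_acc:
            # Checkpoint: keep a copy of the best state seen so far.
            best_acc = val_acc
            best_state = copy.deepcopy(get_state(epoch))
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping

    return best_acc, best_state, epochs_run


# Simulated validation accuracies; the final 0.80 is never reached.
simulated = [0.60, 0.68, 0.72, 0.71, 0.72, 0.70, 0.69, 0.71, 0.80]
best_acc, best_state, epochs_run = run_with_early_stopping(
    simulated, get_state=lambda epoch: {"epoch": epoch}
)
print(best_acc, best_state, epochs_run)
```

In this simulated run the best accuracy is reached at zero-based epoch 2, the next five epochs bring no strict improvement, and training stops after 8 epochs. In a real loop the checkpoint would typically be a deep copy of the model's state dict, restored or saved to disk once training ends.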