Model Training
Step 8.1: Training Process
- Training Function: train_model() runs the training loop with support for early stopping.
- Early Stopping: Stops training if no improvement is observed after a specified number of epochs.
- Phases: Each epoch has two phases - training and validation.
- Loss Calculation: Combines binary classification loss and subtype classification loss.
- Loss Functions: Cross-entropy loss for both tasks.
- Optimizer: Adam optimizer with weight decay for regularization.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
- Learning Rate Scheduler: Step learning rate scheduler that lowers the learning rate at fixed epoch intervals.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
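The pieces above can be sketched as a single training function. This is a minimal, illustrative version, not the document's actual implementation: it assumes the model returns two heads (binary logits, subtype logits), that each batch yields (inputs, binary labels, subtype labels), and that the combined loss is the simple sum of the two cross-entropy terms; all names here are hypothetical.

```python
import torch
import torch.nn as nn


def train_model(model, loaders, optimizer, scheduler, num_epochs=25, patience=5):
    """Train/validation phases each epoch, early stopping on validation accuracy.

    loaders: dict with "train" and "val" DataLoaders yielding
             (inputs, binary_labels, subtype_labels) batches.
    """
    criterion = nn.CrossEntropyLoss()
    best_acc, best_state, epochs_no_improve = 0.0, None, 0

    for epoch in range(num_epochs):
        for phase in ("train", "val"):
            model.train() if phase == "train" else model.eval()
            correct, total = 0, 0
            for inputs, bin_labels, sub_labels in loaders[phase]:
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == "train"):
                    bin_logits, sub_logits = model(inputs)
                    # Combined loss: binary task + subtype task (sum assumed)
                    loss = (criterion(bin_logits, bin_labels)
                            + criterion(sub_logits, sub_labels))
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                preds = bin_logits.argmax(dim=1)
                correct += (preds == bin_labels).sum().item()
                total += bin_labels.size(0)
            if phase == "train":
                scheduler.step()  # decay LR once per epoch
            else:
                val_acc = correct / total
                if val_acc > best_acc:
                    # Keep the best-performing weights for checkpointing
                    best_acc, epochs_no_improve = val_acc, 0
                    best_state = {k: v.clone() for k, v in model.state_dict().items()}
                else:
                    epochs_no_improve += 1
        if epochs_no_improve >= patience:
            break  # early stopping: no improvement for `patience` epochs
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_acc
```

The function restores the best validation-accuracy weights before returning, which matches the checkpointing behavior described in Step 8.2.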
Step 8.2: Training Parameters
- Epochs: Train for 25 epochs.
- Batch Size: Use a batch size of 16.
- Patience: Early stopping patience set to 5 epochs.
- Checkpointing: Saves the model state with the best validation accuracy.
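Wiring these parameters together might look like the following sketch. The model, file name, and the stand-in validation accuracy are all hypothetical; only the Adam/StepLR hyperparameters come from Step 8.1.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the document's two-head network.
model = nn.Linear(8, 2)

NUM_EPOCHS = 25   # total training epochs
BATCH_SIZE = 16   # DataLoader batch size
PATIENCE = 5      # early-stopping patience (epochs without improvement)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Checkpointing: save the state with the best validation accuracy so far.
best_acc = 0.0
val_acc = 0.75  # stand-in for a validation accuracy computed during training
if val_acc > best_acc:
    best_acc = val_acc
    torch.save({"model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "val_acc": best_acc},
               "best_checkpoint.pt")

# Restoring the best checkpoint later:
ckpt = torch.load("best_checkpoint.pt")
model.load_state_dict(ckpt["model_state"])
```

Saving the optimizer state alongside the model weights is optional but lets training resume from the checkpoint rather than only running inference.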