Model Training
This section describes how the VGG-based multitask model is trained and which hyperparameters are used.
Step 8.1: Training Process
- Training Function: The train_model() function trains the VGG-based multitask model; a minimal sketch appears after this list.
- Early Stopping: Training stops early if validation performance does not improve for a defined number of epochs (the patience).
- Training and Validation Phases: Each epoch consists of a training phase followed by a validation phase.
- Loss Calculation: The total loss is the sum of the losses from the binary and subtype classification tasks.
- Loss Functions: Cross-entropy loss is used for both classification heads.
- Optimizer: The Adam optimizer is used with weight decay, which applies an L2 penalty to regularize the model:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
- Learning Rate Scheduler: A step scheduler decays the learning rate by a factor of gamma=0.1 every step_size=7 epochs:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
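
A minimal sketch of how train_model() could tie these pieces together. The exact signature, the assumption that the model returns a (binary_logits, subtype_logits) pair, and the use of binary-head accuracy as the validation metric are illustrative assumptions rather than the project's confirmed implementation:

import copy
import torch
import torch.nn as nn

def train_model(model, dataloaders, optimizer, scheduler, device,
                num_epochs=25, patience=5):
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0
    best_weights = copy.deepcopy(model.state_dict())
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # Each epoch runs a training phase followed by a validation phase.
        for phase in ("train", "val"):
            if phase == "train":
                model.train()
            else:
                model.eval()
            running_loss, running_correct, n_samples = 0.0, 0, 0

            # Assumed batch layout: images plus one label tensor per task.
            for images, binary_labels, subtype_labels in dataloaders[phase]:
                images = images.to(device)
                binary_labels = binary_labels.to(device)
                subtype_labels = subtype_labels.to(device)

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == "train"):
                    binary_logits, subtype_logits = model(images)
                    # Combined loss: cross-entropy on each classification head.
                    loss = (criterion(binary_logits, binary_labels)
                            + criterion(subtype_logits, subtype_labels))
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * images.size(0)
                running_correct += (binary_logits.argmax(1) == binary_labels).sum().item()
                n_samples += images.size(0)

            epoch_acc = running_correct / n_samples
            print(f"epoch {epoch} {phase}: loss {running_loss / n_samples:.4f} "
                  f"acc {epoch_acc:.4f}")

            if phase == "train":
                scheduler.step()  # step the LR schedule once per epoch
            elif epoch_acc > best_acc:
                # Checkpoint the weights with the best validation accuracy.
                best_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

        if epochs_no_improve >= patience:
            # Early stopping: no validation improvement for `patience` epochs.
            break

    model.load_state_dict(best_weights)
    return model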
Step 8.2: Training Parameters
- Epochs: The model is trained for up to 25 epochs; early stopping may end training sooner.
- Batch Size: A batch size of 16 is used for training and validation.
- Patience: Early stopping halts training after 5 consecutive epochs without validation improvement.
- Checkpointing: The model weights with the best validation accuracy are saved; see the usage sketch below.
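
Putting the Step 8.2 parameters together, a hedged end-to-end sketch; train_ds, val_ds, and the checkpoint filename best_model.pt are placeholder names, not confirmed by the project:

import torch
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # the VGG-based multitask model from the earlier steps

# Batch size of 16 for both training and validation, per Step 8.2.
dataloaders = {
    "train": DataLoader(train_ds, batch_size=16, shuffle=True),
    "val": DataLoader(val_ds, batch_size=16, shuffle=False),
}

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

model = train_model(model, dataloaders, optimizer, scheduler, device,
                    num_epochs=25, patience=5)
torch.save(model.state_dict(), "best_model.pt")  # persist the best checkpoint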