ResNet50 Model
Step 7.1: Multitask Model without BatchNorm1d
- Base Model (ResNet50):
  - Layers: ResNet50 has 50 weighted layers (49 convolutional layers and one fully connected layer), together with max-pooling and global average-pooling layers.
  - Residual Learning: Residual (skip) connections add each block's input to its output, which helps train very deep networks by mitigating the vanishing gradient problem.
  - Pretrained Model: A pretrained ResNet50 is used as a feature extractor, leveraging its ability to extract rich features from medical images.
Model Modifications
- Remove Final Layer: The final fully connected (fc) layer of ResNet50 is replaced with two separate heads:
  - Binary Classifier Head:
    - Architecture: Linear layer (256 units) -> ReLU -> Dropout -> Linear layer (2 output units).
    - Purpose: Classify images as Benign or Malignant.

          self.binary_classifier = nn.Sequential(
              nn.Linear(num_ftrs, 256),
              nn.ReLU(),
              nn.Dropout(0.5),
              nn.Linear(256, 2),
          )

  - Subtype Classifier Head:
    - Architecture: Linear layer (256 units) -> ReLU -> Dropout -> Linear layer (3 output units).
    - Purpose: Classify malignant images into subtypes (Pre-B, Pro-B, early Pre-B).

          self.subtype_classifier = nn.Sequential(
              nn.Linear(num_ftrs, 256),
              nn.ReLU(),
              nn.Dropout(0.5),
              nn.Linear(256, 3),
          )
Multitask Learning
- Simultaneous Learning: The model simultaneously learns to classify images as Benign/Malignant and, if Malignant, classify the subtype.
- Shared Representations: Both heads are trained on the same backbone features, so what the network learns for one task can benefit the other and may improve overall performance.
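The source does not show how the two tasks are combined during training. One common option, sketched here purely as an assumption, is to sum a cross-entropy loss per head and compute the subtype loss only on malignant samples. The function name `multitask_loss`, the `subtype_weight` parameter, and the class index `MALIGNANT = 1` are all hypothetical.

```python
import torch
import torch.nn.functional as F

MALIGNANT = 1  # assumed index of the Malignant class in the binary head


def multitask_loss(binary_out, subtype_out, binary_target, subtype_target,
                   subtype_weight=1.0):
    """Binary loss on every sample plus subtype loss on malignant samples only."""
    binary_loss = F.cross_entropy(binary_out, binary_target)

    malignant = binary_target == MALIGNANT
    if malignant.any():
        subtype_loss = F.cross_entropy(subtype_out[malignant],
                                       subtype_target[malignant])
    else:
        # No malignant samples in this batch: contribute zero, keep the graph intact.
        subtype_loss = subtype_out.sum() * 0.0

    return binary_loss + subtype_weight * subtype_loss


# Random logits stand in for the two model outputs.
binary_out = torch.randn(4, 2, requires_grad=True)
subtype_out = torch.randn(4, 3, requires_grad=True)
binary_target = torch.tensor([0, 1, 1, 0])
subtype_target = torch.tensor([0, 2, 1, 0])  # entries for benign rows are ignored

loss = multitask_loss(binary_out, subtype_out, binary_target, subtype_target)
loss.backward()
```

With this masking, benign images send no gradient through the subtype head, while the shared backbone still receives gradients from both tasks on malignant images.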
Freezing Early Layers
- Frozen Layers: Freeze the early layers of ResNet50 to prevent them from being updated during training.
- Purpose: The early layers typically learn general features (e.g., edges, textures) applicable to a wide range of images.
- Benefits: Reduces computational cost, since no gradients are computed for the frozen parameters, and lowers the risk of overfitting, especially with limited data.
Forward Pass
- Feature Extraction: Pass the input image through the base ResNet50 model to extract features.
- Binary Classification: Use the binary classifier head to determine if the image is Benign or Malignant.
- Subtype Classification: Pass the same features through the subtype classifier head to determine the specific subtype. The forward method computes both outputs for every image, so the subtype output should be used only when the image is classified as Malignant.
    def forward(self, x):
        x = self.base_model(x)
        binary_out = self.binary_classifier(x)
        subtype_out = self.subtype_classifier(x)
        return binary_out, subtype_out
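Because the forward pass returns both outputs for every image, the "only if Malignant" rule has to be applied when the predictions are read. A minimal sketch of that step follows. The helper name `decode_predictions`, the class index `MALIGNANT = 1`, and the order of the `SUBTYPES` list are assumptions, not taken from the source.

```python
import torch

MALIGNANT = 1  # assumed index of the Malignant class
SUBTYPES = ["Pre-B", "Pro-B", "early Pre-B"]  # assumed label order


def decode_predictions(binary_out, subtype_out):
    """Turn the two logit tensors into one readable label per image."""
    binary_pred = binary_out.argmax(dim=1).tolist()
    subtype_pred = subtype_out.argmax(dim=1).tolist()

    labels = []
    for is_malignant, subtype in zip(binary_pred, subtype_pred):
        if is_malignant == MALIGNANT:
            labels.append(f"Malignant ({SUBTYPES[subtype]})")
        else:
            labels.append("Benign")  # subtype output is ignored here
    return labels


# Hand-written logits for two images.
binary_out = torch.tensor([[2.0, -1.0], [0.1, 3.0]])
subtype_out = torch.tensor([[0.0, 5.0, 0.0], [0.2, 0.1, 4.0]])
labels = decode_predictions(binary_out, subtype_out)
print(labels)  # ['Benign', 'Malignant (early Pre-B)']
```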
Overfitting Prevention
- Dropout Layers: Dropout layers in the classifier heads randomly disable neurons during training.
- Purpose: Helps prevent overfitting, which is crucial when working with medical data that often has limited samples.
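Dropout behaves differently in training and evaluation mode, which is easy to check on a toy tensor. The snippet below is an illustrative sketch using `nn.Dropout` on its own rather than the full model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()
train_out = drop(x)  # each entry is either 0 or scaled to 1 / (1 - p) = 2.0

drop.eval()
eval_out = drop(x)   # dropout is disabled: output equals input

print(train_out)
print(eval_out)
```

For the same reason, the model should be switched with `model.eval()` before validation or inference and back with `model.train()` before the next training epoch.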