ResNet50 Model

Step 7.1: Multitask Model without BatchNorm1d

  • Base Model (ResNet50):
    • Layers: ResNet50 is a 50-layer network built from convolutional layers grouped into residual blocks, plus pooling layers and a final fully connected layer.
    • Residual Learning: Residual (skip) connections help train very deep networks by mitigating the vanishing gradient problem.
    • Pretrained Model: A pretrained ResNet50 is used as a feature extractor, so its learned features can be reused for medical images.

Model Modifications

  • Replace Final Layer: The final fully connected (fc) layer of ResNet50 is removed, and the features that fed it (num_ftrs values per image) go to two separate heads:

    1. Binary Classifier Head:
      • Architecture: Linear layer (256 units) -> ReLU -> Dropout -> Linear layer (2 output units).
      • Purpose: Classify images as Benign or Malignant.
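        self.binary_classifier = nn.Sequential(
            nn.Linear(num_ftrs, 256),  # backbone features -> 256 hidden units
            nn.ReLU(),
            nn.Dropout(0.5),           # active only in training mode
            nn.Linear(256, 2),         # logits: Benign vs. Malignant
        )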
    
    2. Subtype Classifier Head:
      • Architecture: Linear layer (256 units) -> ReLU -> Dropout -> Linear layer (3 output units).
      • Purpose: Classify malignant images into subtypes (Pre-B, Pro-B, early Pre-B).
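        self.subtype_classifier = nn.Sequential(
            nn.Linear(num_ftrs, 256),  # backbone features -> 256 hidden units
            nn.ReLU(),
            nn.Dropout(0.5),           # active only in training mode
            nn.Linear(256, 3),         # logits: one per malignant subtype
        )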
    

Multitask Learning

  • Simultaneous Learning: The model simultaneously learns to classify images as Benign/Malignant and, if Malignant, classify the subtype.
    • Shared Representations: Both heads are trained on the same ResNet50 features, so each task can benefit from what the other learns, which can improve overall performance.
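
One way to train both tasks at once is to sum a loss per head. The sketch below is an assumption, not the original training code: it uses cross-entropy for each head, counts the subtype loss only for samples whose true label is Malignant, and assumes class index 1 means Malignant. The names multitask_loss and subtype_weight are hypothetical.

```python
import torch
import torch.nn.functional as F

MALIGNANT = 1  # assumed index of the Malignant class in the binary head


def multitask_loss(binary_out, subtype_out, binary_target, subtype_target,
                   subtype_weight=1.0):
    """Binary loss on all samples plus subtype loss on malignant samples only."""
    binary_loss = F.cross_entropy(binary_out, binary_target)

    malignant = binary_target == MALIGNANT
    if malignant.any():
        subtype_loss = F.cross_entropy(subtype_out[malignant],
                                       subtype_target[malignant])
    else:
        # No malignant samples in this batch: contribute zero, but keep the
        # subtype head attached to the graph so backward() still works.
        subtype_loss = subtype_out.sum() * 0.0

    return binary_loss + subtype_weight * subtype_loss
```

Masking benign samples out of the subtype loss means their subtype target can be any placeholder value within the valid range; it never affects the gradient.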

Freezing Early Layers

  • Frozen Layers: Freeze the early layers of ResNet50 to prevent them from being updated during training.
    • Purpose: The early layers typically learn general features (e.g., edges, textures) applicable to a wide range of images.
    • Benefits: Reduces computational cost and helps limit overfitting, especially with limited data.

Forward Pass

  • Feature Extraction: Pass the input image through the base ResNet50 model to extract features.

  • Binary Classification: Use the binary classifier head to determine if the image is Benign or Malignant.

  • Subtype Classification: Pass the same features through the subtype classifier head to determine the specific subtype. Both heads run on every image; the subtype output is meaningful only when the image is classified as Malignant.

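        def forward(self, x):
            # Shared features from the ResNet50 backbone
            features = self.base_model(x)
            # Both heads read the same features
            binary_out = self.binary_classifier(features)
            subtype_out = self.subtype_classifier(features)
            return binary_out, subtype_out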
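
Because forward returns both outputs for every image, the "subtype only if Malignant" rule has to be applied when reading the predictions. A minimal sketch follows, assuming class index 1 means Malignant and that the subtype logits follow the order Pre-B, Pro-B, early Pre-B; neither ordering is stated in the notes, and decode_predictions is a hypothetical helper.

```python
import torch

MALIGNANT = 1                                  # assumed binary class index
SUBTYPES = ("Pre-B", "Pro-B", "early Pre-B")   # assumed subtype logit order


def decode_predictions(binary_out, subtype_out):
    """Turn the two logit tensors into one label string per image."""
    is_malignant = binary_out.argmax(dim=1) == MALIGNANT
    subtype_idx = subtype_out.argmax(dim=1)
    labels = []
    for malignant, idx in zip(is_malignant.tolist(), subtype_idx.tolist()):
        # The subtype prediction is ignored for images predicted Benign.
        labels.append(SUBTYPES[idx] if malignant else "Benign")
    return labels
```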
    

Overfitting Prevention

  • Dropout Layers: Dropout layers in the classifier heads randomly zero a fraction of activations (here 50%) during training and are inactive at evaluation time.
    • Purpose: Helps prevent overfitting, which is crucial when working with medical data that often has limited samples.
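
The difference between training and evaluation behavior can be seen on a standalone nn.Dropout layer. This is an illustrative sketch using the same p = 0.5 as the heads; the variable names are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(0.5)
x = torch.ones(1000)

drop.train()
# Training mode: each element is zeroed with probability 0.5 and the
# survivors are scaled by 1 / (1 - 0.5) = 2 to keep the expected value.
train_out = drop(x)

drop.eval()
# Evaluation mode: dropout is a no-op and the input passes through unchanged.
eval_out = drop(x)
```

This is why model.train() should be called before training and model.eval() before validation or inference.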