Stomach_Cancer_Pytorch/experiments/pytorch_Model.py

import torch.nn as nn
import timm


class ModifiedXception(nn.Module):
    def __init__(self, num_classes):
        super(ModifiedXception, self).__init__()
        # Load the pre-trained Xception model (full model, not just the feature extractor)
        self.base_model = timm.create_model(
            'xception',
            pretrained=True,
            drop_rate=0.0,  # Optional: adjust dropout if needed
        )
        # Replace the default global pooling with AdaptiveAvgPool2d
        self.base_model.global_pool = nn.AdaptiveAvgPool2d(output_size=1)  # Output size of 1x1 spatially
        # Replace the final fully connected layer with Identity to expose the raw features
        self.base_model.fc = nn.Identity()  # Output will be 2048 (Xception's default feature size)
        # Custom head: Linear from 2048 to 1025, ReLU, dropout, then a final layer to num_classes
        self.custom_head = nn.Sequential(
            nn.Linear(2048, 1025),         # From Xception's 2048 features to 1025
            nn.ReLU(),                     # Activation
            nn.Dropout(0.6),               # Dropout for regularization
            nn.Linear(1025, num_classes),  # Final output layer
            nn.Sigmoid()                   # Sigmoid for binary/multi-label classification
        )

    def forward(self, x):
        # Pass through the base Xception model (up to global pooling)
        x = self.base_model.forward_features(x)  # Get feature maps
        x = self.base_model.global_pool(x)       # Apply AdaptiveAvgPool2d (output: [B, 2048, 1, 1])
        x = x.flatten(1)                         # Flatten to [B, 2048]
        x = self.base_model.fc(x)                # Identity layer (still [B, 2048])
        output = self.custom_head(x)             # Custom head processing
        return output
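
A minimal usage sketch, not part of the original file: it instantiates ModifiedXception and runs a dummy forward pass to check the output shape. The batch size, the 299x299 input resolution (Xception's usual input size), and num_classes=1 are assumptions for illustration; pretrained=True downloads the Xception weights on first use.

import torch

if __name__ == "__main__":
    # Hypothetical binary-classification setup (num_classes=1 is an assumption)
    model = ModifiedXception(num_classes=1)
    model.eval()

    # Dummy batch of 2 RGB images at 299x299: [batch, channels, height, width]
    dummy_input = torch.randn(2, 3, 299, 299)

    with torch.no_grad():
        preds = model(dummy_input)

    print(preds.shape)  # Expected: torch.Size([2, 1]), values in (0, 1) from the Sigmoid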