39 lines
1.8 KiB
Python
39 lines
1.8 KiB
Python
import torch.nn as nn
|
||
import timm
|
||
|
||
|
||
class ModifiedXception(nn.Module):
    """Xception backbone from timm with a custom sigmoid classification head.

    The pretrained Xception model's global pooling is replaced with
    ``AdaptiveAvgPool2d`` and its final fully connected layer with
    ``Identity``, so the backbone yields a flat 2048-dim feature vector.
    A small MLP head (2048 -> 1025 -> ``num_classes``) with a Sigmoid
    output produces per-class scores for binary / multi-label
    classification.

    Args:
        num_classes: Number of output classes (width of the final layer).
    """

    def __init__(self, num_classes):
        super().__init__()

        # Load the Xception pre-trained model (full model, not just features).
        self.base_model = timm.create_model(
            'xception',
            pretrained=True,
            drop_rate=0.0,  # Optional: adjust backbone dropout if needed.
        )

        # Replace the default global pooling with AdaptiveAvgPool2d
        # (1x1 spatial output).
        self.base_model.global_pool = nn.AdaptiveAvgPool2d(output_size=1)

        # Replace the final fully connected layer with Identity so the
        # backbone emits the raw 2048-dim Xception feature vector.
        self.base_model.fc = nn.Identity()

        # Custom head: Linear from 2048 to a 1025-wide hidden layer,
        # then to num_classes.
        self.custom_head = nn.Sequential(
            nn.Linear(2048, 1025),         # Project Xception's 2048 features to 1025.
            nn.ReLU(),                     # Activation.
            nn.Dropout(0.6),               # Heavy dropout for regularization.
            nn.Linear(1025, num_classes),  # Final output layer.
            nn.Sigmoid(),                  # Sigmoid for binary/multi-label classification.
        )

    def forward(self, x):
        """Run a batch of images through the backbone and the custom head.

        Args:
            x: Input image batch, shape [B, C, H, W] — assumes whatever
               input size the timm 'xception' model expects (typically
               299x299); confirm against the caller.

        Returns:
            Tensor of shape [B, num_classes] with sigmoid-activated scores.
        """
        x = self.base_model.forward_features(x)  # Feature maps from the backbone.
        x = self.base_model.global_pool(x)       # AdaptiveAvgPool2d -> [B, 2048, 1, 1].
        x = x.flatten(1)                         # Flatten to [B, 2048].
        x = self.base_model.fc(x)                # Identity layer (still [B, 2048]).
        output = self.custom_head(x)             # Custom head -> [B, num_classes].
        return output