import torch.nn as nn
import timm

# Earlier version, kept for reference: the same backbone with the head
# expressed as an nn.Sequential.
# class ModifiedXception(nn.Module):
#     def __init__(self, num_classes):
#         super(ModifiedXception, self).__init__()
#
#         # Load the pre-trained Xception model (full model, not just features)
#         self.base_model = timm.create_model(
#             'xception',
#             pretrained=True
#         )
#
#         # Replace the default global pooling with AdaptiveAvgPool2d
#         self.base_model.global_pool = nn.AdaptiveAvgPool2d(output_size=1)  # 1x1 spatial output
#
#         # Replace the final fully connected layer with Identity to get features
#         self.base_model.fc = nn.Identity()  # Output will be 2048 (Xception's default feature size)
#
#         # Custom head: Linear from 2048 to 1025, then to num_classes
#         self.custom_head = nn.Sequential(
#             nn.Linear(2048, 1025),        # From Xception's 2048 features to 1025
#             nn.ReLU(),                    # Activation
#             nn.Dropout(0.6),              # Dropout for regularization
#             nn.Linear(1025, num_classes)  # Final output layer
#             # nn.Softmax(dim=1)           # Omitted: CrossEntropyLoss expects raw logits
#         )
#
#     def forward(self, x):
#         # Pass through the base Xception model (up to global pooling)
#         x = self.base_model.forward_features(x)  # Feature maps [B, 2048, H, W]
#         x = self.base_model.global_pool(x)       # AdaptiveAvgPool2d -> [B, 2048, 1, 1]
#         x = x.flatten(1)                         # Flatten to [B, 2048]
#         # x = self.base_model.fc(x)              # Identity layer (still [B, 2048])
#         output = self.custom_head(x)             # Custom head processing
#         return output

class ModifiedXception(nn.Module):
    def __init__(self, num_classes):
        super(ModifiedXception, self).__init__()

        # Load the pre-trained Xception model and drop its final fc layer
        self.base_model = timm.create_model('xception', pretrained=True)
        self.base_model.fc = nn.Identity()  # remove the original fully connected layer

        # timm's xception already applies global average pooling before fc,
        # so base_model(x) returns pooled features of shape [B, 2048];
        # no extra pooling layer is needed here.
        self.hidden_layer = nn.Linear(2048, 1025)  # hidden layer; input size is Xception's 2048-dim feature vector
        self.output_layer = nn.Linear(1025, num_classes)  # output layer, sized by the number of classes

        # Activation and dropout. Softmax is intentionally not applied in
        # forward, since nn.CrossEntropyLoss expects raw logits.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.6)

    def forward(self, x):
        x = self.base_model(x)               # Xception backbone -> [B, 2048]
        x = self.relu(self.hidden_layer(x))  # hidden layer + ReLU
        x = self.dropout(x)                  # dropout
        x = self.output_layer(x)             # output layer (logits)
        return x
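
# A minimal sanity check for ModifiedXception, as a sketch: it assumes torch
# and timm are installed, that pretrained weights can be downloaded, and that
# inputs are 299x299 RGB images (timm's default size for 'xception').
#
#     import torch
#     model = ModifiedXception(num_classes=2)
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 299, 299))  # [batch, channels, H, W]
#     print(logits.shape)  # expected: torch.Size([1, 2])
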
class Model_module(nn.Module):
    def __init__(self, num_classes=1):  # num_classes is an assumed parameter; the stub had no output size
        super(Model_module, self).__init__()
        # conv1 outputs 32 channels, so conv2 must take 32 in (the stub's 64 was a shape mismatch)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.max_pool = nn.MaxPool2d(2, 2)

        # The stub's nn.Linear() calls had no sizes; LazyLinear infers its input
        # size on the first forward pass, and 256 is an assumed hidden width
        self.fc1 = nn.LazyLinear(256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, input):
        x = self.max_pool(self.relu(self.conv1(input)))
        x = self.max_pool(self.relu(self.conv2(x)))
        x = self.max_pool(self.relu(self.conv3(x)))
        x = x.flatten(1)                   # flatten feature maps to [B, features]
        x = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(x))   # sigmoid for binary/multi-label output
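
# Minimal smoke test, a sketch assuming 224x224 RGB inputs; Model_module's
# LazyLinear sizes itself on this first forward pass.
if __name__ == "__main__":
    import torch

    model = Model_module(num_classes=1)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 224, 224))  # [batch, channels, H, W]
    print(out.shape)  # expected: torch.Size([2, 1])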