"""Modified Xception model: timm Xception backbone with a custom classification head."""
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models
class ModifiedXception(nn.Module):
    """Xception backbone (via timm) with a custom classification head.

    The pretrained Xception is used as a feature extractor only
    (``features_only=True``); a small MLP head with a per-class Sigmoid
    maps the pooled feature map to ``num_classes`` scores.

    NOTE(review): the Sigmoid output implies multi-label / BCELoss-style
    training. If this model is trained with ``nn.CrossEntropyLoss``, the
    Sigmoid should be removed — confirm against the training loop.

    Args:
        num_classes: number of output classes (width of the final Linear).
    """

    def __init__(self, num_classes):
        super().__init__()

        # Pretrained Xception, feature-extraction only (classifier removed).
        # out_indices=[3] selects the 4th feature stage, which for timm's
        # 'xception' has 728 channels. (The original comment claimed 2048 —
        # that is the width of the *final* stage, index 4, not this one.)
        self.base_model = timm.create_model(
            'xception',
            pretrained=True,
            features_only=True,   # keep only the feature-extraction trunk
            out_indices=[3],      # which feature stage(s) to return
        )

        # Derive the head's input width from the backbone instead of
        # hard-coding 728, so a future change to out_indices cannot
        # silently mismatch the first Linear layer. For out_indices=[3]
        # this evaluates to 728, identical to the original wiring.
        in_channels = self.base_model.feature_info.channels()[-1]

        # Classification head: GAP -> flatten -> MLP -> per-class sigmoid.
        self.custom_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),        # global average pool to (N, C, 1, 1)
            nn.Flatten(),                   # (N, C)
            nn.Linear(in_channels, 368),
            nn.ReLU(),
            nn.Linear(368, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-class scores in [0, 1] for a batch of images ``x``."""
        # features_only models return a *list* of feature maps; with a
        # single entry in out_indices there is exactly one element.
        features = self.base_model(x)[0]
        return self.custom_head(features)