import torch.nn as nn import timm from utils.Stomach_Config import Model_Config class ModifiedXception(nn.Module): def __init__(self, num_classes): super(ModifiedXception, self).__init__() # 加載 Xception 預訓練模型,去掉最後一層 (fc 層) self.base_model = timm.create_model(Model_Config["Model Name"], pretrained=True) self.base_model.fc = nn.Identity() # 移除原來的 fully connected 層 # 新增全局平均池化層、隱藏層和輸出層 self.global_avg_pool = nn.AdaptiveAvgPool1d(Model_Config["GPA Output Nodes"]) # 全局平均池化 self.hidden_layer = nn.Linear(Model_Config["GPA Output Nodes"], Model_Config["Linear Hidden Nodes"]) # 隱藏層,輸入大小取決於 Xception 的輸出大小 self.output_layer = nn.Linear(Model_Config["Linear Hidden Nodes"], Model_Config["Output Linear Nodes"]) # 輸出層,依據分類數目設定 # 激活函數與 dropout self.relu = nn.ReLU() self.softmax = nn.Softmax(1) self.dropout = nn.Dropout(Model_Config["Dropout Rate"]) def forward(self, x): x = self.base_model(x) # Xception 主體 x = self.global_avg_pool(x) # 全局平均池化 x = self.relu(self.hidden_layer(x)) # 隱藏層 + ReLU x = self.dropout(x) # Dropout x = self.output_layer(x) # 輸出層 return x