30 lines
1.5 KiB
Python
30 lines
1.5 KiB
Python
import torch.nn as nn
|
|
import timm
|
|
from utils.Stomach_Config import Model_Config
|
|
|
|
class ModifiedXception(nn.Module):
|
|
def __init__(self):
|
|
super(ModifiedXception, self).__init__()
|
|
|
|
# 加載 Xception 預訓練模型,去掉最後一層 (fc 層)
|
|
self.base_model = timm.create_model(Model_Config["Model Name"], pretrained=True, num_classes = 0)
|
|
# self.base_model.fc = nn.Identity() # 移除原來的 fully connected 層
|
|
|
|
# 新增全局平均池化層、隱藏層和輸出層
|
|
in_features = self.base_model.num_features # 自動取得 2048
|
|
self.global_avg_pool = nn.AdaptiveAvgPool1d(Model_Config["GPA Output Nodes"]) # 全局平均池化
|
|
self.hidden_layer = nn.Linear(in_features, Model_Config["Linear Hidden Nodes"]) # 隱藏層,輸入大小取決於 Xception 的輸出大小
|
|
self.output_layer = nn.Linear(Model_Config["Linear Hidden Nodes"], Model_Config["Output Linear Nodes"]) # 輸出層,依據分類數目設定
|
|
|
|
# 激活函數與 dropout
|
|
self.relu = nn.ReLU()
|
|
self.dropout = nn.Dropout(Model_Config["Dropout Rate"])
|
|
|
|
def forward(self, x):
|
|
x = self.base_model(x) # Xception 主體
|
|
# x = self.global_avg_pool(x) # 全局平均池化
|
|
x = self.dropout(x) # Dropout
|
|
x = self.hidden_layer(x)
|
|
x = self.relu(x) # 隱藏層 + ReLU
|
|
x = self.output_layer(x) # 輸出層
|
|
return x |