Files
Stomach_Cancer_Pytorch/experiments/pytorch_Model.py

81 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch.nn as nn
import timm
# class ModifiedXception(nn.Module):
# def __init__(self, num_classes):
# super(ModifiedXception, self).__init__()
# # Load Xception pre-trained model (full model, not just features)
# self.base_model = timm.create_model(
# 'xception',
# pretrained=True
# )
# # Replace the default global pooling with AdaptiveAvgPool2d
# self.base_model.global_pool = nn.AdaptiveAvgPool2d(output_size=1) # Output size of 1x1 spatially
# # Replace the final fully connected layer with Identity to get features
# self.base_model.fc = nn.Identity() # Output will be 2048 (Xception's default feature size)
# # Custom head: Linear from 2048 to 1025, ReLU, Dropout, then Linear to num_classes
# self.custom_head = nn.Sequential(
# nn.Linear(2048, 1025), # From Xception's 2048 features to 1025
# nn.ReLU(), # Activation
# nn.Dropout(0.6), # Dropout for regularization
# nn.Linear(1025, num_classes) # Final output layer
# # nn.Softmax(dim = 1) # Softmax over the class dimension (disabled: the head outputs raw logits)
# )
# def forward(self, x):
# # Pass through the base Xception model (up to global pooling)
# x = self.base_model.forward_features(x) # Get feature maps
# x = self.base_model.global_pool(x) # Apply AdaptiveAvgPool2d (output: [B, 2048, 1, 1])
# x = x.flatten(1) # Flatten to [B, 2048]
# # x = self.base_model.fc(x) # Identity layer (still [B, 2048])
# output = self.custom_head(x) # Custom head processing
# return output
class ModifiedXception(nn.Module):
    """Xception backbone from timm with a small fully connected classifier head.

    The backbone's own classifier (``fc``) is replaced by ``nn.Identity`` so the
    backbone returns feature vectors, which then go through
    ``Linear -> ReLU -> Dropout -> Linear``.

    ``forward`` returns raw logits. ``self.softmax`` is kept as an attribute but
    is not applied inside ``forward``.

    Args:
        num_classes: Number of output logits.
        hidden_features: Width of the hidden layer (default 1025).
        dropout: Dropout probability applied after the hidden layer (default 0.6).
        pretrained: Passed to ``timm.create_model``; when True the pretrained
            Xception weights are loaded (default True).
    """

    def __init__(self, num_classes, *, hidden_features=1025, dropout=0.6,
                 pretrained=True):
        super().__init__()

        # Size of the feature vector the Xception backbone produces once its
        # final classifier has been removed.
        feature_dim = 2048

        # Load Xception and drop its final fully connected layer.
        self.base_model = timm.create_model('xception', pretrained=pretrained)
        self.base_model.fc = nn.Identity()

        # timm's Xception already applies global pooling before `fc`, so the
        # backbone output is (batch, feature_dim) and needs no further pooling.
        # The attribute is kept as a pass-through so code that refers to
        # `global_avg_pool` keeps working; it holds no parameters, so
        # checkpoints are unaffected.
        self.global_avg_pool = nn.Identity()

        # Classifier head.
        self.hidden_layer = nn.Linear(feature_dim, hidden_features)
        self.output_layer = nn.Linear(hidden_features, num_classes)

        # Activations and regularisation.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(1)  # not used by forward(); kept for callers
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        x = self.base_model(x)               # backbone features
        x = self.global_avg_pool(x)          # pass-through (see __init__)
        x = self.relu(self.hidden_layer(x))  # hidden layer + ReLU
        x = self.dropout(x)
        x = self.output_layer(x)             # logits, no softmax applied
        return x
class Model_module(nn.Module):
    """Small CNN: three (conv -> ReLU -> 2x2 max-pool) stages and two linear layers.

    NOTE(review): the previous version was an unfinished stub. It called
    ``nn.Linear()`` without sizes (a ``TypeError`` on construction), did not
    subclass ``nn.Module`` and had an empty ``forward``. The layer order used
    here and the fully connected sizes are assumptions; confirm them against
    the intended architecture.

    Args:
        num_classes: Number of outputs of the final layer (default 1).
        hidden_features: Width of the first fully connected layer (default 128).
    """

    def __init__(self, num_classes=1, hidden_features=128):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        # conv2 consumes conv1's output, so its in_channels must equal conv1's
        # out_channels (it was 64, which cannot accept a 32-channel input).
        # NOTE(review): alternatively conv1 could emit 64 channels - confirm.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.max_Pool = nn.MaxPool2d(2, 2)
        # The flattened feature size depends on the input resolution, which
        # is not fixed here, so fc1 infers its in_features on the first
        # forward pass.
        self.fc1 = nn.LazyLinear(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, input):
        """Return sigmoid outputs of shape (batch, num_classes) for a 3-channel image batch."""
        x = self.max_Pool(self.relu(self.conv1(input)))
        x = self.max_Pool(self.relu(self.conv2(x)))
        x = self.max_Pool(self.relu(self.conv3(x)))
        x = x.flatten(1)  # (batch, channels * height * width)
        x = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(x))