Stomach_Cancer_Pytorch/experiments/Models/GastroSegNet_Model.py

91 lines
3.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
class Encode_Block(nn.Module):
    """Double convolution block: (Conv2d 3x3 -> BatchNorm2d -> ReLU) x 2.

    Both convolutions use padding=1, so the spatial size is preserved. The first
    convolution maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # One conv/BN/ReLU triple per input width: first in->out, then out->out.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the two conv-BN-ReLU stages to ``x`` and return the result."""
        return self.conv(x)
class GastroSegNet(nn.Module):
    """A simple U-Net for image segmentation.

    Encoder: one ``Encode_Block`` per entry in ``features``, each followed by
    2x2 max-pooling. Bottleneck: an ``Encode_Block`` that doubles the deepest
    width. Decoder: mirrors the encoder; each stage upsamples with a stride-2
    transposed convolution, concatenates the matching encoder output (skip
    connection) and applies an ``Encode_Block``. A final 1x1 convolution maps
    to ``out_channels``; no activation is applied to the output.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Encoder channel widths, shallowest first. Any non-empty
            sequence of positive ints; consecutive widths need not double.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=3, features=(32, 64, 128, 256)):
        super(GastroSegNet, self).__init__()
        # Copy so a caller-supplied list is never aliased, and so tuples work.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Encoder (downsampling path). Attribute names and registration order
        # are kept as-is so existing state_dict checkpoints still load.
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        prev_channels = in_channels
        for width in features:
            self.encoder.append(Encode_Block(prev_channels, width))
            prev_channels = width

        # Bottleneck (deepest level): doubles the deepest encoder width.
        self.bottleneck = Encode_Block(features[-1], features[-1] * 2)

        # Decoder (upsampling path), deepest stage first.
        self.decoder = nn.ModuleList()
        self.upconv = nn.ModuleList()
        prev_channels = features[-1] * 2  # output width of the bottleneck
        for width in reversed(features):
            # The transposed conv must consume whatever the previous stage
            # produced (bottleneck, then each decoder block). Hard-coding
            # ``width * 2`` only matched when every encoder width doubled the
            # previous one; for such configurations the shapes are unchanged.
            self.upconv.append(
                nn.ConvTranspose2d(prev_channels, width, kernel_size=2, stride=2)
            )
            # Input = skip connection (width) concatenated with upsampled (width).
            self.decoder.append(Encode_Block(width * 2, width))
            prev_channels = width

        # Final output layer: 1x1 conv to the requested number of channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Input tensor, assumed to be (N, in_channels, H, W) since all
                layers are 2-D and the skip concatenation uses dim=1.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size as
            the first encoder output (i.e. as ``x``, since padding=1 convs
            preserve size).
        """
        # Store each encoder output (before pooling) for the skip connections.
        skip_connections = []
        for encoder_layer in self.encoder:
            x = encoder_layer(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder stages pair with encoder outputs deepest-first.
        for upconv, decoder, skip in zip(self.upconv, self.decoder, reversed(skip_connections)):
            x = upconv(x)
            # Max-pooling floors odd sizes, so the upsampled map can be smaller
            # than its skip; resize spatially (channels already agree).
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            x = torch.cat([skip, x], dim=1)
            x = decoder(x)

        return self.final_conv(x)