91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
class Encode_Block(nn.Module):
    """Double-convolution block shared by the encoder and the decoder.

    Applies two consecutive stages of
    ``Conv2d(kernel_size=3, padding=1) -> BatchNorm2d -> ReLU``.
    The first stage maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. With a 3x3 kernel, padding 1 and the default
    stride of 1, the spatial size of the input is preserved.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 consumes in_channels, stage 2 consumes out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend(
                (
                    nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
            )
        # One flat Sequential keeps the state_dict keys conv.0 ... conv.5,
        # so checkpoints saved from earlier versions still load.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both convolution stages on ``x`` and return the result."""
        return self.conv(x)
|
||
|
||
class GastroSegNet(nn.Module):
    """A compact U-Net for image segmentation.

    The network has four parts:

    - Encoder: one ``Encode_Block`` per entry in ``features``, each followed
      by a 2x2 max-pool with stride 2.
    - Bottleneck: an ``Encode_Block`` that doubles the deepest width.
    - Decoder: mirrors the encoder. Each level upsamples with a 2x2
      transposed convolution, concatenates the matching skip connection and
      applies an ``Encode_Block``.
    - Head: a final 1x1 convolution that maps to ``out_channels``.

    No activation is applied to the output.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output map.
        features: Channel widths of the encoder stages, shallowest first.
            Any non-empty iterable of ints; defaults to ``(32, 64, 128, 256)``.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=3, features=(32, 64, 128, 256)):
        super().__init__()
        # A tuple default avoids the shared-mutable-default pitfall.
        # Copying also accepts any iterable and decouples from the caller's list.
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Encoder (downsampling path).
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        prev_width = in_channels
        for width in features:
            self.encoder.append(Encode_Block(prev_width, width))
            prev_width = width

        # Bottleneck (deepest level) doubles the last encoder width.
        self.bottleneck = Encode_Block(features[-1], features[-1] * 2)

        # Decoder (upsampling path), deepest level first.
        self.decoder = nn.ModuleList()
        self.upconv = nn.ModuleList()
        # Track the channels actually arriving at each upconv. The old code
        # hard-coded features[-1-i] * 2, which is only correct when every
        # width is double the previous one. For such widths, including the
        # default, the layer shapes are unchanged.
        incoming = features[-1] * 2
        for width in reversed(features):
            # Upsample 2x down to `width` channels, so concatenating the skip
            # connection (also `width` channels) gives `width * 2` channels.
            self.upconv.append(
                nn.ConvTranspose2d(incoming, width, kernel_size=2, stride=2)
            )
            self.decoder.append(Encode_Block(width * 2, width))
            incoming = width

        # Final 1x1 projection to the requested number of output channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the U-Net on ``x``.

        Args:
            x: Input batch. It is assumed to be ``(N, in_channels, H, W)``;
                ``Conv2d`` and bilinear interpolation require a 4-D tensor.

        Returns:
            A tensor with ``out_channels`` channels, resized to the spatial
            size of the shallowest skip connection.
        """
        skips = []

        # Encoder path: keep each block's output (before pooling) as a skip.
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)
            x = self.pool(x)

        # Bottleneck.
        x = self.bottleneck(x)

        # Decoder path: the deepest skip pairs with the first decoder stage.
        for upconv, decode, skip in zip(self.upconv, self.decoder, reversed(skips)):
            x = upconv(x)

            # Pooling floors odd sizes, so the 2x upsample can come out one
            # pixel short; match the skip's spatial size before concatenating.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)

            x = torch.cat([skip, x], dim=1)
            x = decode(x)

        # Final projection.
        return self.final_conv(x)