import torch
import torch.nn as nn
import torch.nn.functional as F


class Encode_Block(nn.Module):
    """Basic convolution block: two rounds of Conv2d(3x3) + BatchNorm2d + ReLU.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so the spatial size
    is preserved and only the channel count changes from ``in_channels`` to
    ``out_channels``. Despite its name, this block is used on both the encoder
    and the decoder path of ``GastroSegNet``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute name ``conv`` and the layer order are kept unchanged so
        # that existing state_dict checkpoints keep loading.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv-BN-ReLU stages; (N, C_in, H, W) -> (N, C_out, H, W)."""
        return self.conv(x)


class GastroSegNet(nn.Module):
    """A simple U-Net.

    Encoder: one ``Encode_Block`` per entry of ``features``, each followed by
    2x2 max-pooling (stride 2). Bottleneck: an ``Encode_Block`` that doubles
    the deepest width. Decoder: mirrors the encoder, deepest stage first; each
    stage upsamples 2x with a transposed convolution, concatenates the matching
    encoder output (skip connection) and fuses the result with an
    ``Encode_Block``. A final 1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the per-pixel output map (default 3);
            presumably one score map per segmentation class.
        features: Channel widths of the encoder stages, shallowest first.
            Any positive widths are accepted; they do not need to double from
            one stage to the next.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=3, features=(32, 64, 128, 256)):
        super().__init__()
        # Tuple default + defensive copy: no shared mutable default, and a
        # caller's list is never aliased.
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Encoder (downsampling path): each stage maps the previous width to
        # the next one; the first stage consumes the input image channels.
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        stage_in = in_channels
        for width in features:
            self.encoder.append(Encode_Block(stage_in, width))
            stage_in = width

        # Bottleneck (deepest level): doubles the last encoder width.
        self.bottleneck = Encode_Block(features[-1], features[-1] * 2)

        # Decoder (upsampling path), built deepest stage first. Registration
        # order (decoder, upconv) and per-stage creation order match the
        # original implementation, so parameter names are unchanged.
        self.decoder = nn.ModuleList()
        self.upconv = nn.ModuleList()
        up_in = features[-1] * 2  # channels leaving the bottleneck
        for width in reversed(features):
            # Upsample 2x and reduce to the width of the matching skip
            # connection...
            self.upconv.append(
                nn.ConvTranspose2d(up_in, width, kernel_size=2, stride=2)
            )
            # ...then fuse [skip, upsampled] (2 * width channels) back to width.
            self.decoder.append(Encode_Block(width * 2, width))
            # The next upconv consumes this decoder stage's output. (The old
            # code assumed this equals the next width * 2, which only holds
            # when the widths double at every stage.)
            up_in = width

        # Final 1x1 projection to the requested number of output channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: Input batch, assumed (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W): the output matches the
            input's spatial size even when H or W is not a multiple of
            ``2 ** len(features)``.
        """
        # Encoder path: keep each stage's pre-pooling output for the decoder.
        skips = []
        for encoder_layer in self.encoder:
            x = encoder_layer(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder path: pair each decoder stage with the encoder output of the
        # same resolution (deepest first).
        for upconv, decoder, skip in zip(self.upconv, self.decoder, reversed(skips)):
            x = upconv(x)
            # Max-pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip; resize spatially so the concatenation lines up.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(
                    x, size=skip.shape[2:], mode='bilinear', align_corners=False
                )
            x = torch.cat([skip, x], dim=1)
            x = decoder(x)

        return self.final_conv(x)