Stomach_Cancer_Pytorch/experiments/Models/ViT_Model.py

142 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# vit_branch.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Tuple, Optional
class PatchEmbed(nn.Module):
    """Split a feature map into non-overlapping patches and embed each one.

    A Conv2d whose kernel_size and stride both equal ``patch_size`` maps every
    ``patch_size x patch_size`` patch to an ``embed_dim`` vector. The resulting
    grid is flattened row-major into a token sequence and layer-normalised.

    Args:
        in_chs: Channels of the input feature map.
        embed_dim: Width of each output token.
        patch_size: Side length of the square patches.
    """

    def __init__(self, in_chs: int = 728, embed_dim: int = 728, patch_size: int = 4):
        super().__init__()
        self.proj = nn.Conv2d(in_chs, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Embed ``x`` of shape (B, in_chs, H, W).

        Returns:
            ``(tokens, (H, W))``. ``tokens`` has shape (B, N, embed_dim) with
            N = (H // patch_size) * (W // patch_size). ``(H, W)`` is the size
            of the input, not of the patch grid.
        """
        # Unpacking exactly four dims keeps rejecting non-4-D input with ValueError.
        _, _, height, width = x.shape
        patch_grid = self.proj(x)                                 # (B, D, H // p, W // p)
        tokens = patch_grid.flatten(start_dim=2).transpose(1, 2)  # (B, N, D)
        return self.norm(tokens), (height, width)
# vit_branch.py
class PatchReconstruction(nn.Module):
    """Map a token sequence back to a (B, out_chs, H, W) feature map.

    Each token is linearly projected to a ``patch_size x patch_size x out_chs``
    patch. The patches are tiled row-major on the (H // p, W // p) grid. When H
    or W is not a multiple of ``patch_size``, the tiled map is only
    (H // p * p, W // p * p), so it is bilinearly resized to ``orig_size``.
    The output therefore always has the requested spatial size.

    Args:
        embed_dim: Width of the incoming tokens.
        out_chs: Channels of the reconstructed feature map.
        patch_size: Side length of each square patch.
    """

    def __init__(self, embed_dim=728, out_chs=728, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_chs = out_chs
        self.proj = nn.Linear(embed_dim, patch_size * patch_size * out_chs)

    def forward(self, x: torch.Tensor, orig_size: Tuple[int, int]) -> torch.Tensor:
        """Reconstruct a feature map from tokens.

        Args:
            x: Tokens of shape (B, N, embed_dim), row-major over the patch grid.
            orig_size: ``(H, W)`` of the feature map the tokens were cut from.

        Returns:
            Tensor of shape (B, out_chs, H, W).

        Raises:
            ValueError: if N != (H // patch_size) * (W // patch_size).
        """
        B, N, _ = x.shape
        H, W = orig_size
        p = self.patch_size
        h, w = H // p, W // p
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if h * w != N:
            raise ValueError(f"N mismatch: {h*w} vs {N}")

        x = self.proj(x)  # (B, N, p*p*C)
        # Split p*p into (p1, p2) with p1 the outer factor, then interleave the
        # patch rows/cols with the grid rows/cols:
        # (B, h, w, p1, p2, C) -> (B, C, h, p1, w, p2) -> (B, C, h*p, w*p).
        x = x.view(B, h, w, p, p, self.out_chs)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, self.out_chs, h * p, w * p)

        # The patch grid dropped the H % p / W % p remainder; restore orig_size.
        if (h * p, w * p) != (H, W):
            x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
        return x
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x = x + MHSA(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.

    Args:
        dim: Token width.
        num_heads: Attention heads (nn.MultiheadAttention requires
            ``dim % num_heads == 0``).
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to nn.MultiheadAttention's ``bias``, which toggles
            the bias of both the q/k/v input projection and the output projection.
        drop: Dropout probability after the MLP activation and after its output layer.
        attn_drop: Dropout probability on the attention weights.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0,
                 qkv_bias: bool = True, drop: float = 0., attn_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop,
                                          bias=qkv_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape (B, N, dim); returns the same shape."""
        # Normalise once and reuse it as query, key and value rather than
        # recomputing the identical LayerNorm for each argument.
        normed = self.norm1(x)
        # The attention weights are discarded, so don't compute them;
        # need_weights=False also lets PyTorch use its optimized
        # scaled_dot_product_attention path.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
class ViTBranch(nn.Module):
    """
    ViT branch: takes a CNN feature map and returns enhanced features with the
    same channel count and spatial size.

    Intended use: fusion with the Xception middle flow.

    Pipeline:
        1. PatchEmbed.
        2. Add the learnable positional embedding, resized to the token grid
           when needed.
        3. ``depth`` TransformerBlocks, then LayerNorm.
        4. PatchReconstruction.
        5. Bilinear resize back to the input (H, W) if the reconstruction
           came out smaller because H or W is not a multiple of
           ``patch_size``.

    Args:
        in_chs: Channels of the input (and output) feature map.
        embed_dim: Token width; nn.MultiheadAttention requires it to be
            divisible by ``num_heads``.
        patch_size: Side of the square, non-overlapping patches.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        drop_rate: Dropout probability inside each block's MLP.
        attn_drop_rate: Dropout probability on the attention weights.
        device: If given, every parameter of the branch is placed on it.
        dtype: If given, every floating-point parameter is cast to it.
        pos_grid_size: Side of the square grid on which the learnable
            positional embedding is stored. The default, ``299 // 16 == 18``
            (324 positions), reproduces the previous hard-coded shape, so
            existing checkpoints keep loading.
    """

    def __init__(
        self,
        in_chs: int = 728,
        embed_dim: int = 728,
        patch_size: int = 4,
        depth: int = 3,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        pos_grid_size: int = 299 // 16,
    ):
        super().__init__()
        if pos_grid_size < 1:
            raise ValueError(f"pos_grid_size must be >= 1, got {pos_grid_size}")
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.patch_embed = PatchEmbed(in_chs, embed_dim, patch_size)
        self.patch_recon = PatchReconstruction(embed_dim, in_chs, patch_size)

        # Learnable positional embedding on a pos_grid_size x pos_grid_size grid
        # (row-major); forward() resizes it whenever the token grid differs.
        num_pos = pos_grid_size ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos, embed_dim, **factory_kwargs))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # The submodules above are built with default device/dtype; move the
        # whole branch so it is consistent with pos_embed.
        if device is not None or dtype is not None:
            self.to(device=device, dtype=dtype)

    def _resize_pos_embed(self, pos_embed: torch.Tensor, target_hw: Tuple[int, int]) -> torch.Tensor:
        """Bilinearly resize a square-grid positional embedding.

        Args:
            pos_embed: (B, N, D) embedding whose N positions lie row-major on a
                sqrt(N) x sqrt(N) grid (B is 1 for this module's own parameter).
            target_hw: ``(h, w)`` token grid to resize to.

        Returns:
            (B, h * w, D) embedding, row-major over the new grid.

        Raises:
            ValueError: if N is not a perfect square.
        """
        batch, num_pos, dim = pos_embed.shape
        side = round(num_pos ** 0.5)
        if side * side != num_pos:
            raise ValueError(f"pos_embed has {num_pos} positions, which is not a square grid")
        grid = pos_embed.reshape(batch, side, side, dim).permute(0, 3, 1, 2)  # (B, D, side, side)
        grid = F.interpolate(grid, size=target_hw, mode='bilinear', align_corners=False)
        return grid.permute(0, 2, 3, 1).flatten(1, 2)  # (B, h*w, D)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Enhance ``x`` of shape (B, in_chs, H, W); returns (B, in_chs, H, W)."""
        _, _, H, W = x.shape
        tokens, orig_size = self.patch_embed(x)  # (B, N, D), orig_size == (H, W)

        # Token grid produced by the patch conv (kernel == stride == patch_size).
        stride_h, stride_w = self.patch_embed.proj.stride
        target_grid = (H // stride_h, W // stride_w)
        side = round(self.pos_embed.shape[1] ** 0.5)
        # Compare grids rather than token counts: a non-square grid with the
        # same N as the stored square grid still needs resizing.
        if target_grid == (side, side):
            pos_embed = self.pos_embed
        else:
            pos_embed = self._resize_pos_embed(self.pos_embed, target_grid)
        tokens = tokens + pos_embed  # (1, N, D) broadcasts over the batch

        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        out = self.patch_recon(tokens, orig_size)

        # The patch conv drops the H % p / W % p remainder, so the reconstruction
        # may be smaller than the input (e.g. 19x19 with p=4 -> 16x16). Resize
        # back so the output can be fused with the input-sized CNN features.
        if out.shape[-2:] != (H, W):
            out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)
        return out