# vit_branch.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Tuple, Optional


class PatchEmbed(nn.Module):
    """Split a CNN feature map into patch tokens."""
    def __init__(self, in_chs: int = 728, embed_dim: int = 728, patch_size: int = 4):
        super().__init__()
        self.proj = nn.Conv2d(in_chs, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        B, C, H, W = x.shape
        x = self.proj(x)                  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        x = self.norm(x)
        return x, (H, W)


class PatchReconstruction(nn.Module):
    """Fold patch tokens back into a feature map of the original spatial size."""
    def __init__(self, embed_dim=728, out_chs=728, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_chs = out_chs
        self.proj = nn.Linear(embed_dim, patch_size * patch_size * out_chs)

    def forward(self, x, orig_size):
        B, N, _ = x.shape
        H, W = orig_size
        p = self.patch_size
        h = H // p
        w = W // p
        assert h * w == N, f"N mismatch: {h*w} vs {N}"

        x = self.proj(x)                        # (B, N, p*p*C)
        x = x.view(B, h, w, p*p, self.out_chs)  # (B, h, w, p*p, C)
        x = rearrange(
            x,
            'b h w (p1 p2) c -> b c (h p1) (w p2)',
            p1=p, p2=p
        )
        # When H or W is not divisible by patch_size (e.g. a 19x19 feature map
        # with patch_size=4), the folded map is smaller than (H, W); resize it
        # back so the output can be fused with the original CNN feature map.
        if x.shape[-2:] != (H, W):
            x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
        return x


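# Shape round trip (illustrative, assuming a patch-aligned 16x16 input):
#   PatchEmbed:          (B, 728, 16, 16) --4x4 patches--> tokens (B, 16, 728)
#   PatchReconstruction: tokens (B, 16, 728) ------------> (B, 728, 16, 16)

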
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention and an MLP, each with a residual connection."""
    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0,
                 qkv_bias: bool = True, drop: float = 0., attn_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop,
                                          bias=qkv_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize once and reuse the result as query, key, and value.
        x_norm = self.norm1(x)
        y, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + y
        x = x + self.mlp(self.norm2(x))
        return x


class ViTBranch(nn.Module):
    """
    ViT branch: takes a CNN feature map and returns enhanced features of the
    same spatial size.
    Purpose: fusion with the Xception Middle Flow.
    """
    def __init__(
        self,
        in_chs: int = 728,
        embed_dim: int = 728,
        patch_size: int = 4,
        depth: int = 3,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.patch_embed = PatchEmbed(in_chs, embed_dim, patch_size)
        self.patch_recon = PatchReconstruction(embed_dim, in_chs, patch_size)

        # Learnable pos embed, initialised for an approximate patch count and
        # resized dynamically in forward() to match the actual patch grid.
        num_patches_approx = (299 // 16) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches_approx, embed_dim, **factory_kwargs))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def _resize_pos_embed(self, pos_embed: torch.Tensor, target_hw: Tuple[int, int]) -> torch.Tensor:
        """Resize the positional embedding to the current patch grid."""
        B = pos_embed.shape[0]
        # Treat the stored embedding as a square grid, interpolate it to the
        # target grid, then flatten it back into a token sequence.
        sqrt_N = int(pos_embed.shape[1] ** 0.5)
        pos_embed = pos_embed.reshape(1, sqrt_N, sqrt_N, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed, size=target_hw, mode='bilinear', align_corners=False)
        pos_embed = pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        return pos_embed.expand(B, -1, -1)

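    # Illustrative sizes (assumed): the default table is (1, 324, 728), i.e. an
    # 18x18 grid; for a 19x19 feature map with patch_size=4 it is resized to a
    # 4x4 grid and flattened to (1, 16, 728) before being added to the tokens.
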
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        x, orig_size = self.patch_embed(x)  # (B, N, D), orig_size=(H, W)
        N = x.shape[1]

        # Resize pos_embed dynamically to match the current patch grid.
        target_grid = (H // self.patch_embed.proj.stride[0], W // self.patch_embed.proj.stride[1])
        if self.pos_embed.shape[1] != N:
            pos_embed = self._resize_pos_embed(self.pos_embed, target_grid)
        else:
            pos_embed = self.pos_embed.expand(B, -1, -1)

        x = x + pos_embed

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        x = self.patch_recon(x, orig_size)
        return x
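

# Minimal usage sketch (illustrative; the 19x19 spatial size is an assumed
# Xception middle-flow resolution for a 299x299 input, not part of this module).
if __name__ == "__main__":
    branch = ViTBranch(in_chs=728, embed_dim=728, patch_size=4, depth=3, num_heads=8)
    feats = torch.randn(2, 728, 19, 19)   # dummy CNN feature map
    out = branch(feats)
    print(out.shape)                      # expected: torch.Size([2, 728, 19, 19])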