Files
Stomach_Cancer_Pytorch/Model_Loss/CIOU_Loss.py

315 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import math
class CIOULoss(nn.Module):
    """Complete IoU (CIoU) loss for bounding-box regression.

    CIoU combines three geometric factors between a predicted and a target box:

    1. overlap area (IoU),
    2. squared center distance normalized by the enclosing box diagonal,
    3. aspect-ratio consistency.

    The returned loss is ``1 - CIoU`` averaged over the boxes.

    Args:
        eps: Small constant added to denominators and used as a clamp floor
            for widths/heights to avoid division by zero.
        box_format: Coordinate layout of box inputs: ``"xyxy"`` for
            (x1, y1, x2, y2) (the default, matching the previous behaviour) or
            ``"cxcywh"`` for (center x, center y, width, height). Boxes derived
            from segmentation masks are always corner format regardless.

    Raises:
        ValueError: If ``box_format`` is not one of the supported layouts.
    """

    _BOX_FORMATS = ("xyxy", "cxcywh")

    def __init__(self, eps=1e-7, box_format="xyxy"):
        super(CIOULoss, self).__init__()
        if box_format not in self._BOX_FORMATS:
            raise ValueError(
                f"box_format must be one of {self._BOX_FORMATS}, got {box_format!r}"
            )
        self.eps = eps
        self.box_format = box_format

    def forward(self, pred_boxes, target_boxes):
        """Compute the mean CIoU loss.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` (or a single box ``[4]``) in
                the configured ``box_format``, or segmentation mask logits
                ``[B, 1, H, W]``.
            target_boxes: Target boxes or masks in the same layout as
                ``pred_boxes``.

        Returns:
            Scalar tensor: mean of ``1 - CIoU``. If the inputs do not carry 4
            coordinates per box, a constant ``0.01`` is returned instead
            (best-effort fallback; that constant carries no gradient).
        """
        # Segmentation-mask input: reduce each mask to its bounding box first.
        # NOTE: the boxes are built from integer pixel indices, so no gradient
        # flows back to the mask logits through this path.
        from_masks = pred_boxes.dim() == 4 and pred_boxes.shape[1] == 1
        if from_masks:
            pred_boxes = self._mask_to_boxes(pred_boxes)
            target_boxes = self._mask_to_boxes(target_boxes)

        pred_boxes = pred_boxes.float()
        target_boxes = target_boxes.float()

        # Promote a single box [4] to a batch of one [1, 4].
        if pred_boxes.dim() == 1:
            pred_boxes = pred_boxes.unsqueeze(0)
        if target_boxes.dim() == 1:
            target_boxes = target_boxes.unsqueeze(0)

        if pred_boxes.shape[1] != 4 or target_boxes.shape[1] != 4:
            # Deliberate best-effort fallback kept from the original code.
            return torch.tensor(0.01, device=pred_boxes.device)

        # Mask-derived boxes are already (x1, y1, x2, y2); never re-convert them.
        if not from_masks and self._is_center_format(pred_boxes, target_boxes):
            pred_boxes = self._center_to_corner(pred_boxes)
            target_boxes = self._center_to_corner(target_boxes)

        intersection = self._calculate_intersection(pred_boxes, target_boxes)

        pred_w = pred_boxes[:, 2] - pred_boxes[:, 0]
        pred_h = pred_boxes[:, 3] - pred_boxes[:, 1]
        target_w = target_boxes[:, 2] - target_boxes[:, 0]
        target_h = target_boxes[:, 3] - target_boxes[:, 1]

        union = pred_w * pred_h + target_w * target_h - intersection + self.eps
        iou = intersection / union

        # Squared diagonal of the smallest box enclosing both boxes.
        enclose_w = torch.max(pred_boxes[:, 2], target_boxes[:, 2]) - torch.min(
            pred_boxes[:, 0], target_boxes[:, 0]
        )
        enclose_h = torch.max(pred_boxes[:, 3], target_boxes[:, 3]) - torch.min(
            pred_boxes[:, 1], target_boxes[:, 1]
        )
        enclose_diagonal_sq = enclose_w ** 2 + enclose_h ** 2 + self.eps

        # Squared distance between the two box centers.
        pred_center_x = (pred_boxes[:, 0] + pred_boxes[:, 2]) / 2
        pred_center_y = (pred_boxes[:, 1] + pred_boxes[:, 3]) / 2
        target_center_x = (target_boxes[:, 0] + target_boxes[:, 2]) / 2
        target_center_y = (target_boxes[:, 1] + target_boxes[:, 3]) / 2
        center_distance_sq = (pred_center_x - target_center_x) ** 2 + (
            pred_center_y - target_center_y
        ) ** 2

        # Aspect-ratio term; clamp sizes so w / h never divides by zero.
        pred_w = torch.clamp(pred_w, min=self.eps)
        pred_h = torch.clamp(pred_h, min=self.eps)
        target_w = torch.clamp(target_w, min=self.eps)
        target_h = torch.clamp(target_h, min=self.eps)
        v = (4 / (math.pi ** 2)) * torch.pow(
            torch.atan(target_w / target_h) - torch.atan(pred_w / pred_h), 2
        )

        # alpha is a trade-off weight and is treated as a constant, so it is
        # computed without tracking gradients.
        with torch.no_grad():
            alpha = v / (1 - iou + v + self.eps)

        ciou = iou - (center_distance_sq / enclose_diagonal_sq) - alpha * v
        return (1 - ciou).mean()

    def _is_center_format(self, pred_boxes, target_boxes):
        """Return True when box inputs should be treated as (cx, cy, w, h).

        The decision comes from the ``box_format`` constructor argument, not
        from inspecting values; the box arguments are kept for interface
        compatibility and are unused. Instances without a ``box_format``
        attribute (e.g. unpickled from an older version) fall back to corner
        format, which was the previous hard-wired behaviour.
        """
        return getattr(self, "box_format", "xyxy") == "cxcywh"

    def _center_to_corner(self, boxes):
        """Convert ``[N, 4]`` boxes from (cx, cy, w, h) to (x1, y1, x2, y2)."""
        cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        x1 = cx - w / 2
        y1 = cy - h / 2
        x2 = cx + w / 2
        y2 = cy + h / 2
        return torch.stack([x1, y1, x2, y2], dim=1)

    def _mask_to_boxes(self, masks):
        """Reduce segmentation masks to one (x1, y1, x2, y2) box per sample.

        Each mask is binarized with ``sigmoid(mask) > 0.5``. A 0/1 mask is
        binarized as expected because sigmoid(0) = 0.5 is not > 0.5 and
        sigmoid(1) is. Only channel 0 is used.

        Extents are pixel-inclusive. Foreground spanning columns
        ``x_min..x_max`` gives ``x1 = x_min`` and ``x2 = x_max + 1``. This way
        a one-pixel-wide or one-pixel-tall region still has a non-zero area.

        Args:
            masks: Mask logits ``[B, 1, H, W]``.

        Returns:
            Float tensor ``[B, 4]``. Samples with no foreground pixel get the
            placeholder box (0, 0, 1, 1).
        """
        batch_size = masks.size(0)
        device = masks.device
        foreground = torch.sigmoid(masks) > 0.5
        boxes = torch.zeros(batch_size, 4, device=device)

        for b in range(batch_size):
            ys, xs = torch.nonzero(foreground[b, 0], as_tuple=True)
            if ys.numel() == 0:
                boxes[b] = torch.tensor([0.0, 0.0, 1.0, 1.0], device=device)
                continue
            # +1 turns the inclusive last pixel index into an exclusive edge.
            boxes[b] = torch.stack(
                [xs.min(), ys.min(), xs.max() + 1, ys.max() + 1]
            ).to(boxes.dtype)

        return boxes

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the per-box intersection area of two ``[N, 4]`` corner-format tensors."""
        x1 = torch.max(pred_boxes[:, 0], target_boxes[:, 0])
        y1 = torch.max(pred_boxes[:, 1], target_boxes[:, 1])
        x2 = torch.min(pred_boxes[:, 2], target_boxes[:, 2])
        y2 = torch.min(pred_boxes[:, 3], target_boxes[:, 3])
        # Non-overlapping boxes give a negative extent; clamp it to zero.
        intersection_w = torch.clamp(x2 - x1, min=0)
        intersection_h = torch.clamp(y2 - y1, min=0)
        return intersection_w * intersection_h
class DIoULoss(nn.Module):
    """Distance IoU (DIoU) loss.

    A simplified CIoU: it penalizes overlap (IoU) and the squared center
    distance normalized by the enclosing box diagonal. It has no
    aspect-ratio term. The returned loss is ``1 - DIoU`` averaged over the
    boxes.

    Args:
        eps: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, eps=1e-7):
        super(DIoULoss, self).__init__()
        self.eps = eps

    def forward(self, pred_boxes, target_boxes):
        """Compute the mean DIoU loss.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` as (x1, y1, x2, y2), or a
                single box ``[4]``.
            target_boxes: Target boxes in the same layout.

        Returns:
            Scalar tensor: mean of ``1 - DIoU``.
        """
        pred_boxes = pred_boxes.float()
        target_boxes = target_boxes.float()

        # Accept a single box [4] (as CIOULoss does) by promoting it to [1, 4].
        if pred_boxes.dim() == 1:
            pred_boxes = pred_boxes.unsqueeze(0)
        if target_boxes.dim() == 1:
            target_boxes = target_boxes.unsqueeze(0)

        intersection = self._calculate_intersection(pred_boxes, target_boxes)

        pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
        target_area = (target_boxes[:, 2] - target_boxes[:, 0]) * (target_boxes[:, 3] - target_boxes[:, 1])
        union = pred_area + target_area - intersection + self.eps
        iou = intersection / union

        # Squared diagonal of the smallest box enclosing both boxes.
        enclose_x1 = torch.min(pred_boxes[:, 0], target_boxes[:, 0])
        enclose_y1 = torch.min(pred_boxes[:, 1], target_boxes[:, 1])
        enclose_x2 = torch.max(pred_boxes[:, 2], target_boxes[:, 2])
        enclose_y2 = torch.max(pred_boxes[:, 3], target_boxes[:, 3])
        enclose_diagonal_sq = (enclose_x2 - enclose_x1) ** 2 + (enclose_y2 - enclose_y1) ** 2 + self.eps

        # Squared distance between the two box centers.
        pred_center_x = (pred_boxes[:, 0] + pred_boxes[:, 2]) / 2
        pred_center_y = (pred_boxes[:, 1] + pred_boxes[:, 3]) / 2
        target_center_x = (target_boxes[:, 0] + target_boxes[:, 2]) / 2
        target_center_y = (target_boxes[:, 1] + target_boxes[:, 3]) / 2
        center_distance_sq = (pred_center_x - target_center_x) ** 2 + (pred_center_y - target_center_y) ** 2

        diou = iou - (center_distance_sq / enclose_diagonal_sq)
        return (1 - diou).mean()

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the per-box intersection area of two ``[N, 4]`` corner-format tensors."""
        x1 = torch.max(pred_boxes[:, 0], target_boxes[:, 0])
        y1 = torch.max(pred_boxes[:, 1], target_boxes[:, 1])
        x2 = torch.min(pred_boxes[:, 2], target_boxes[:, 2])
        y2 = torch.min(pred_boxes[:, 3], target_boxes[:, 3])
        # Non-overlapping boxes give a negative extent; clamp it to zero.
        intersection_w = torch.clamp(x2 - x1, min=0)
        intersection_h = torch.clamp(y2 - y1, min=0)
        return intersection_w * intersection_h
class GIoULoss(nn.Module):
    """Generalized IoU (GIoU) loss.

    Extends IoU with a penalty based on the smallest box enclosing both boxes:
    ``GIoU = IoU - (enclose_area - union) / enclose_area``. The returned loss
    is ``1 - GIoU`` averaged over the boxes.

    Args:
        eps: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, eps=1e-7):
        super(GIoULoss, self).__init__()
        self.eps = eps

    def forward(self, pred_boxes, target_boxes):
        """Compute the mean GIoU loss.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` as (x1, y1, x2, y2), or a
                single box ``[4]``.
            target_boxes: Target boxes in the same layout.

        Returns:
            Scalar tensor: mean of ``1 - GIoU``.
        """
        pred_boxes = pred_boxes.float()
        target_boxes = target_boxes.float()

        # Accept a single box [4] (as CIOULoss does) by promoting it to [1, 4].
        if pred_boxes.dim() == 1:
            pred_boxes = pred_boxes.unsqueeze(0)
        if target_boxes.dim() == 1:
            target_boxes = target_boxes.unsqueeze(0)

        intersection = self._calculate_intersection(pred_boxes, target_boxes)

        pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
        target_area = (target_boxes[:, 2] - target_boxes[:, 0]) * (target_boxes[:, 3] - target_boxes[:, 1])
        union = pred_area + target_area - intersection + self.eps
        iou = intersection / union

        # Area of the smallest box enclosing both boxes.
        enclose_x1 = torch.min(pred_boxes[:, 0], target_boxes[:, 0])
        enclose_y1 = torch.min(pred_boxes[:, 1], target_boxes[:, 1])
        enclose_x2 = torch.max(pred_boxes[:, 2], target_boxes[:, 2])
        enclose_y2 = torch.max(pred_boxes[:, 3], target_boxes[:, 3])
        enclose_area = (enclose_x2 - enclose_x1) * (enclose_y2 - enclose_y1) + self.eps

        giou = iou - (enclose_area - union) / enclose_area
        return (1 - giou).mean()

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the per-box intersection area of two ``[N, 4]`` corner-format tensors."""
        x1 = torch.max(pred_boxes[:, 0], target_boxes[:, 0])
        y1 = torch.max(pred_boxes[:, 1], target_boxes[:, 1])
        x2 = torch.min(pred_boxes[:, 2], target_boxes[:, 2])
        y2 = torch.min(pred_boxes[:, 3], target_boxes[:, 3])
        # Non-overlapping boxes give a negative extent; clamp it to zero.
        intersection_w = torch.clamp(x2 - x1, min=0)
        intersection_h = torch.clamp(y2 - y1, min=0)
        return intersection_w * intersection_h