23 lines
771 B
Python
23 lines
771 B
Python
from multiprocessing import Value
|
|
import pstats
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision.models as models
|
|
from torchvision import transforms
|
|
from Model_Loss.CIOU_Loss import CIOULoss
|
|
from Model_Loss.Perceptual_Loss import VGGPerceptualLoss
|
|
|
|
class Segmentation_Loss(nn.Module):
    """Combined loss: a ``VGGPerceptualLoss`` term plus a ``CIOULoss`` term.

    Both sub-losses are evaluated on the same (prediction, ground truth) pair
    and their results are summed, each scaled by an optional weight.
    """

    def __init__(
        self,
        *,
        perceptual_weight: float = 1.0,
        ciou_weight: float = 1.0,
    ) -> None:
        """Instantiate the two sub-losses.

        Args:
            perceptual_weight: Multiplier applied to the perceptual-loss term.
                The default of 1.0 reproduces the original unweighted sum.
            ciou_weight: Multiplier applied to the CIOU-loss term. The default
                of 1.0 reproduces the original unweighted sum.
        """
        super().__init__()
        # Attribute names are kept unchanged on purpose. If these objects are
        # nn.Module instances (presumably — TODO confirm in Model_Loss), the
        # names become state_dict() key prefixes, so renaming them would break
        # loading of existing checkpoints.
        self.Perceptual_Loss = VGGPerceptualLoss()
        self.CIOU = CIOULoss()
        self.perceptual_weight = perceptual_weight
        self.ciou_weight = ciou_weight

    def forward(self, Output_Result, GroundTruth_Result):
        """Return the weighted sum of the perceptual and CIOU losses.

        Args:
            Output_Result: Model prediction, passed unchanged to both
                sub-losses. NOTE(review): the expected shape/dtype is whatever
                VGGPerceptualLoss and CIOULoss accept — confirm there.
            GroundTruth_Result: Target, passed unchanged to both sub-losses.

        Returns:
            ``perceptual_weight * perceptual + ciou_weight * ciou``, with the
            type produced by the sub-losses (presumably a scalar tensor).
        """
        Perceptual_Loss = self.Perceptual_Loss(Output_Result, GroundTruth_Result)
        CIOU_Loss = self.CIOU(Output_Result, GroundTruth_Result)
        return self.perceptual_weight * Perceptual_Loss + self.ciou_weight * CIOU_Loss
|