This commit is contained in:
lee
2025-06-11 15:23:50 +08:00
commit 37ecef40f7
79 changed files with 26981 additions and 0 deletions

18
model/loss.py Normal file
View File

@ -0,0 +1,18 @@
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
def __init__(self, gamma=2):
super().__init__()
self.gamma = gamma
self.ce = torch.nn.CrossEntropyLoss()
def forward(self, input, target):
#print(f'theta {input.shape, input[0]}, target {target.shape, target}')
logp = self.ce(input, target)
p = torch.exp(-logp)
loss = (1 - p) ** self.gamma * logp
return loss.mean()