import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    """Focal loss built on top of cross-entropy.

    Each element's cross-entropy ``ce`` is scaled by ``(1 - p) ** gamma``,
    where ``p = exp(-ce)``, so elements with a low cross-entropy contribute
    less to the final value. With ``gamma == 0`` the result equals plain
    mean cross-entropy.

    Args:
        gamma: Exponent of the ``(1 - p)`` modulating factor.
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        # reduction="none" keeps one loss value per element. With the default
        # "mean" reduction the focal factor would be computed once from the
        # batch-averaged loss instead of separately for every element.
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, input, target):
        """Return the mean focal loss over all elements.

        Args:
            input: Predictions in any layout accepted by
                ``torch.nn.CrossEntropyLoss`` (typically raw logits of shape
                ``(N, C)`` - confirm against the caller).
            target: Targets in any layout accepted by
                ``torch.nn.CrossEntropyLoss`` (typically class indices of
                shape ``(N,)`` - confirm against the caller).

        Returns:
            A scalar tensor: the average of the per-element focal losses.
        """
        # Per-element cross-entropy.
        ce_loss = self.ce(input, target)
        # exp(-ce) recovers the probability of the target class when the
        # targets are class indices.
        p = torch.exp(-ce_loss)
        focal = (1 - p) ** self.gamma * ce_loss
        return focal.mean()