18 lines
456 B
Python
18 lines
456 B
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    For each element the loss is ``(1 - p_t) ** gamma * CE``, where ``CE`` is
    that element's cross-entropy and ``p_t = exp(-CE)``. For class-index
    targets, ``p_t`` is the predicted probability of the true class.

    The factor ``(1 - p_t) ** gamma`` down-weights well-classified elements,
    so the loss concentrates on hard ones. With ``gamma == 0`` this reduces
    to ordinary cross-entropy.

    Args:
        gamma: Focusing parameter. Larger values suppress easy elements more.
        reduction: How to reduce the per-element losses. ``"mean"`` (default)
            averages them, ``"sum"`` sums them, and ``"none"`` returns them
            unreduced.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, gamma=2, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.reduction = reduction
        # reduction="none" is essential: the focal weight must be computed
        # from each element's own p_t, not from the batch-averaged loss.
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Unnormalized logits in any layout accepted by
                ``torch.nn.CrossEntropyLoss`` (e.g. ``(N, C)``).
            target: Ground-truth targets in a form accepted by
                ``torch.nn.CrossEntropyLoss`` (typically class indices).

        Returns:
            For ``"mean"`` or ``"sum"`` reduction, a 0-d tensor. For
            ``"none"``, the per-element losses, with the same shape that
            ``CrossEntropyLoss(reduction="none")`` produces.
        """
        # Per-element cross-entropy, i.e. -log p_t (not log p_t).
        ce_loss = self.ce(input, target)
        p_t = torch.exp(-ce_loss)
        loss = (1 - p_t) ** self.gamma * ce_loss
        # NOTE(review): elements whose target equals CrossEntropyLoss's
        # ignore_index get CE == 0, so their focal loss is 0. They are still
        # counted in the "mean" denominator.
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss