Files
ieemoo-ai-searchv2/cirtorch/layers/loss.py
2022-11-22 15:32:06 +08:00

49 lines
1.5 KiB
Python
Executable File

import torch
import torch.nn as nn
import cirtorch.layers.functional as LF
# --------------------------------------
# Loss/Error layers
# --------------------------------------
class ContrastiveLoss(nn.Module):
    r"""Contrastive-loss layer for a batch of image query tuples.

    The batch holds Q query tuples. Each tuple is packed column-wise as
    ``(q, p, n1, ..., nN)``: one query, its positive, then N negatives.

    Args:
        margin: contrastive-loss margin. Default: 0.7
        eps: forwarded unchanged to ``LF.contrastive_loss`` (presumably a
            numerical-stability term -- confirm there). Default: 1e-6

    Forward args:
        x: descriptors, one tuple member per column, ordered as
            ``[q, p, n1, ..., nN, q, p, ...]``
        label: one entry per column of ``x``: -1 for a query, 1 for its
            positive, 0 for its negatives

    Example::

        >>> contrastive_loss = ContrastiveLoss(margin=0.7)
        >>> descriptors = torch.randn(128, 35, requires_grad=True)
        >>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5)
        >>> output = contrastive_loss(descriptors, label)
        >>> output.backward()
    """

    def __init__(self, margin=0.7, eps=1e-6):
        super().__init__()
        self.margin = margin
        self.eps = eps

    def forward(self, x, label):
        # All of the loss math lives in the functional helper; this layer
        # only stores the hyper-parameters.
        return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps)

    def __repr__(self):
        # Only the margin is reported (eps is intentionally not shown).
        return '{}(margin={:.4f})'.format(type(self).__name__, self.margin)
class TripletLoss(nn.Module):
    r"""Triplet-loss layer; delegates the computation to ``LF.triplet_loss``.

    Args:
        margin: triplet-loss margin. Default: 0.1

    Forward args:
        x: batch of descriptors. Presumably this uses the same column-wise
            ``(q, p, n1, ..., nN)`` tuple packing as the contrastive loss
            (TODO confirm against ``LF.triplet_loss``).
        label: per-column labels for ``x``. The format is defined by
            ``LF.triplet_loss``.
    """

    def __init__(self, margin=0.1):
        super().__init__()
        self.margin = margin

    def forward(self, x, label):
        # Thin wrapper: the functional helper does all the work.
        return LF.triplet_loss(x, label, margin=self.margin)

    def __repr__(self):
        return '{}(margin={:.4f})'.format(type(self).__name__, self.margin)