49 lines
1.5 KiB
Python
Executable File
49 lines
1.5 KiB
Python
Executable File
import torch
|
|
import torch.nn as nn
|
|
|
|
import cirtorch.layers.functional as LF
|
|
|
|
# --------------------------------------
|
|
# Loss/Error layers
|
|
# --------------------------------------
|
|
|
|
class ContrastiveLoss(nn.Module):
|
|
r"""CONTRASTIVELOSS layer that computes contrastive loss for a batch of images:
|
|
Q query tuples, each packed in the form of (q,p,n1,..nN)
|
|
|
|
Args:
|
|
x: tuples arranges in columns as [q,p,n1,nN, ... ]
|
|
label: -1 for query, 1 for corresponding positive, 0 for corresponding negative
|
|
margin: contrastive loss margin. Default: 0.7
|
|
|
|
>>> contrastive_loss = ContrastiveLoss(margin=0.7)
|
|
>>> input = torch.randn(128, 35, requires_grad=True)
|
|
>>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5)
|
|
>>> output = contrastive_loss(input, label)
|
|
>>> output.backward()
|
|
"""
|
|
|
|
def __init__(self, margin=0.7, eps=1e-6):
|
|
super(ContrastiveLoss, self).__init__()
|
|
self.margin = margin
|
|
self.eps = eps
|
|
|
|
def forward(self, x, label):
|
|
return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps)
|
|
|
|
def __repr__(self):
|
|
return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')'
|
|
|
|
|
|
class TripletLoss(nn.Module):
|
|
|
|
def __init__(self, margin=0.1):
|
|
super(TripletLoss, self).__init__()
|
|
self.margin = margin
|
|
|
|
def forward(self, x, label):
|
|
return LF.triplet_loss(x, label, margin=self.margin)
|
|
|
|
def __repr__(self):
|
|
return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')'
|