first push
This commit is contained in:
48
cirtorch/layers/loss.py
Executable file
48
cirtorch/layers/loss.py
Executable file
@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import cirtorch.layers.functional as LF
|
||||
|
||||
# --------------------------------------
|
||||
# Loss/Error layers
|
||||
# --------------------------------------
|
||||
|
||||
class ContrastiveLoss(nn.Module):
|
||||
r"""CONTRASTIVELOSS layer that computes contrastive loss for a batch of images:
|
||||
Q query tuples, each packed in the form of (q,p,n1,..nN)
|
||||
|
||||
Args:
|
||||
x: tuples arranges in columns as [q,p,n1,nN, ... ]
|
||||
label: -1 for query, 1 for corresponding positive, 0 for corresponding negative
|
||||
margin: contrastive loss margin. Default: 0.7
|
||||
|
||||
>>> contrastive_loss = ContrastiveLoss(margin=0.7)
|
||||
>>> input = torch.randn(128, 35, requires_grad=True)
|
||||
>>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5)
|
||||
>>> output = contrastive_loss(input, label)
|
||||
>>> output.backward()
|
||||
"""
|
||||
|
||||
def __init__(self, margin=0.7, eps=1e-6):
|
||||
super(ContrastiveLoss, self).__init__()
|
||||
self.margin = margin
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x, label):
|
||||
return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')'
|
||||
|
||||
|
||||
class TripletLoss(nn.Module):
|
||||
|
||||
def __init__(self, margin=0.1):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, x, label):
|
||||
return LF.triplet_loss(x, label, margin=self.margin)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')'
|
Reference in New Issue
Block a user