# Definitions of the ArcFace and CosFace margin heads.
#
# Both modules map embeddings to `s`-scaled cosine logits, applying an
# additive margin to the logit of each sample's ground-truth class. The
# returned logits are typically passed to a softmax cross-entropy loss.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def _apply_target_margin(cosine, label, margin_fn):
    """Return a copy of ``cosine`` with each row's target entry replaced.

    Args:
        cosine: ``(batch, class_num)`` cosine-similarity matrix.
        label: ground-truth class index for each row (tensor or sequence
            of length ``batch``).
        margin_fn: callable mapping the 1-D tensor of target cosines to
            their margin-adjusted values.

    Returns:
        A new ``(batch, class_num)`` tensor; non-target entries equal
        ``cosine``.

    Only the target entries are passed through ``margin_fn``. Evaluating the
    margin for every entry would waste work. For ArcFace it would also let
    ``sqrt`` at a non-target cosine of exactly +/-1 inject NaN gradients,
    even though that value is discarded.
    """
    label = torch.as_tensor(label, dtype=torch.long, device=cosine.device)
    rows = torch.arange(cosine.size(0), device=cosine.device)
    # clone() gives autograd a fresh tensor to write into; `cosine` itself is
    # left untouched so its gradient graph stays valid.
    output = cosine.clone()
    output[rows, label] = margin_fn(cosine[rows, label])
    return output


class ArcFace(nn.Module):
    def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
        """ArcFace head: additive angular margin on the target logit.

        The target logit becomes cos(theta + m), expanded as

            cos(theta + m) = cos(m) cos(theta) - sin(m) sin(theta).

        This is only the intended penalty while 0 <= theta + m <= pi, i.e.
        while theta <= pi - m. Since cos is decreasing on [0, pi], that is
        the same as cos(theta) >= cos(pi - m), so cos(pi - m) serves as the
        threshold ``th``. When the target cosine is not above ``th``, the
        head falls back to the CosFace-style logit cos(theta) - ``mm``, with
        ``mm = sin(pi - m) * m``.

        Args:
            embedding_size: dimensionality of the input embeddings
                (usually 128, 256, 512 ...).
            class_num: number of identities (classes) seen in training.
            s: logit scale, see NormFace https://arxiv.org/abs/1704.06369
            m: angular margin in radians, see the SphereFace, CosFace and
                ArcFace papers.
        """
        super().__init__()
        self.in_features = embedding_size
        self.out_features = class_num
        self.s = s
        self.m = m
        # One weight row per class; rows are L2-normalised in forward().
        self.weight = nn.Parameter(
            torch.empty(class_num, embedding_size, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)

        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)  # threshold for theta + m > pi
        self.mm = math.sin(math.pi - m) * m  # CosFace-style fallback margin

    def forward(self, input, label):
        """Return ``s``-scaled cosine logits with the ArcFace margin applied.

        Args:
            input: ``(batch, embedding_size)`` embeddings; rows are
                L2-normalised here.
            label: ground-truth class index for each row of ``input``.

        Returns:
            ``(batch, class_num)`` logits.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))

        def arc_margin(target_cos):
            # clamp() absorbs rounding that would push 1 - cos^2 below 0.
            # NOTE: the gradient of sqrt is still unbounded if a *target*
            # cosine is exactly +/-1.
            sine = (1.0 - target_cos.pow(2)).clamp(0, 1).sqrt()
            phi = target_cos * self.cos_m - sine * self.sin_m
            # theta + m would leave [0, pi]: drop to the CosFace-style logit.
            return torch.where(target_cos > self.th, phi, target_cos - self.mm)

        return _apply_target_margin(cosine, label, arc_margin) * self.s


class CosFace(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        """CosFace head: additive cosine margin on the target logit.

        The target logit becomes cos(theta) - m; the other logits stay
        cos(theta). All logits are then multiplied by ``s``.

        Args:
            in_features: dimensionality of the input embeddings
                (usually 128, 256, 512 ...).
            out_features: number of identities (classes) seen in training.
            s: logit scale, see NormFace https://arxiv.org/abs/1704.06369
            m: cosine margin, see the SphereFace, CosFace and ArcFace papers.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # One weight row per class; rows are L2-normalised in forward().
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        """Return ``s``-scaled cosine logits with the CosFace margin applied.

        Args:
            input: ``(batch, in_features)`` embeddings; rows are
                L2-normalised here.
            label: ground-truth class index for each row of ``input``.

        Returns:
            ``(batch, out_features)`` logits.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        output = _apply_target_margin(
            cosine, label, lambda target_cos: target_cos - self.m
        )
        return output * self.s