# Definition of ArcFace loss and CosFace loss

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcFace(nn.Module):

    def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
        """ArcFace adds an additive angular margin m to the target logit:

            cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)

        Note that theta + m must stay inside [0, Pi]. If theta + m >= Pi,
        then theta >= Pi - m, and since cosine is decreasing on [0, Pi]
        this means

            cos(theta) <= cos(Pi - m)

        so cos(Pi - m) can be used as a threshold to check whether
        theta + m falls outside [0, Pi].

        Args:
            embedding_size: embedding dimension, usually 128, 256, 512, ...
            class_num: number of identities (classes) in the training set
            s: scale factor, see NormFace https://arxiv.org/abs/1704.06369
            m: angular margin, see the SphereFace, CosFace, and ArcFace papers
        """
        super().__init__()
        self.in_features = embedding_size
        self.out_features = class_num
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(class_num, embedding_size))
        nn.init.xavier_uniform_(self.weight)

        # Constants for applying the margin and for the out-of-range fallback.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)      # threshold cos(Pi - m)
        self.mm = math.sin(math.pi - m) * m  # fallback penalty m * sin(m)
    def forward(self, input, label):
        # Cosine similarity between L2-normalised embeddings and class weights.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = ((1.0 - cosine.pow(2)).clamp(0, 1)).sqrt()
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        # If theta + m would leave [0, Pi], fall back to a CosFace-style penalty.
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Replace the target-class logit with the margin-penalised value.
        output = cosine * 1.0  # new tensor, so the in-place update keeps autograd intact
        batch_size = len(output)
        output[range(batch_size), label] = phi[range(batch_size), label]
        return output * self.s
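
# Illustrative sketch (not from the original source): a typical training step
# that combines an embedding backbone with a margin head such as ArcFace (or
# CosFace below). `backbone`, `optimizer`, `images`, and `labels` are
# hypothetical arguments; the head returns margin-adjusted logits already
# scaled by s, which feed a standard cross-entropy loss.
def _example_margin_head_step(backbone, head, optimizer, images, labels):
    embeddings = backbone(images)          # (batch, embedding_size)
    logits = head(embeddings, labels)      # (batch, class_num)
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

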
class CosFace(nn.Module):

    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        """CosFace applies an additive cosine margin: the target logit becomes
        cos(theta) - m before scaling.

        Args:
            in_features: embedding dimension, usually 128, 256, 512, ...
            out_features: number of identities (classes) in the training set
            s: scale factor, see NormFace https://arxiv.org/abs/1704.06369
            m: cosine margin, see the SphereFace, CosFace, and ArcFace papers
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
    def forward(self, input, label):
        # Cosine similarity between L2-normalised embeddings and class weights.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Additive cosine margin on the target logit: cos(theta) - m.
        phi = cosine - self.m
        output = cosine * 1.0  # new tensor, so the in-place update keeps autograd intact
        batch_size = len(output)
        output[range(batch_size), label] = phi[range(batch_size), label]
        return output * self.s


class Distillation(nn.Module):

    def __init__(self, in_features, out_features, T=1.0):
        super().__init__()
        self.T = T  # distillation temperature
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, labels):
        # Placeholder: the distillation forward pass is left unimplemented here.
        pass
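

# Minimal smoke test (illustration only, not part of the original losses):
# feed random embeddings through the ArcFace and CosFace heads and check that
# the scaled logits work with a standard cross-entropy loss. All sizes below
# are arbitrary.
if __name__ == "__main__":
    embedding_size, class_num, batch_size = 512, 1000, 8
    embeddings = torch.randn(batch_size, embedding_size, requires_grad=True)
    labels = torch.randint(0, class_num, (batch_size,))

    for head in (ArcFace(embedding_size, class_num), CosFace(embedding_size, class_num)):
        logits = head(embeddings, labels)      # (batch_size, class_num)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        print(type(head).__name__, tuple(logits.shape), float(loss))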