# Files
# ieemoo-ai-detecttracking/contrast/feat_extract/model/metric.py
# 2025-04-18 14:41:53 +08:00
#
# 83 lines
# 3.2 KiB
# Python

# Definition of ArcFace loss and CosFace loss
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) head.

    Owns one learnable weight row per class. ``forward`` returns scaled
    cosine similarities between the L2-normalised inputs and the
    L2-normalised class weights, where the entry belonging to each sample's
    label is replaced by its margin-penalised value.
    """

    def __init__(
        self,
        embedding_size: int,
        class_num: int,
        s: float = 30.0,
        m: float = 0.50,
    ) -> None:
        """Create the class-weight matrix and precompute margin constants.

        ArcFace formula:
            cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)

        With theta in [0, Pi], ``theta + m`` leaves [0, Pi] exactly when
        ``theta > Pi - m``, i.e. when ``cos(theta) < cos(Pi - m)``.
        ``cos(Pi - m)`` is therefore stored as a threshold (``self.th``);
        entries that do not satisfy ``cos(theta) > self.th`` fall back to
        the linear penalty ``cos(theta) - self.mm`` in ``forward``.

        Args:
            embedding_size: usually 128, 256, 512 ...
            class_num: num of people when training
            s: scale, see normface https://arxiv.org/abs/1704.06369
            m: margin (radians), see SphereFace, CosFace, and ArcFace paper
        """
        super().__init__()
        self.in_features = embedding_size
        self.out_features = class_num
        self.s = s
        self.m = m

        # One row per class; float32 is pinned to match the historical
        # ``torch.FloatTensor`` allocation. Values are set by Xavier init.
        self.weight = nn.Parameter(
            torch.empty(class_num, embedding_size, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)

        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)          # threshold on cos(theta)
        self.mm = math.sin(math.pi - m) * m      # fallback linear penalty

    def forward(self, input: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Return scaled logits with the angular margin applied to the targets.

        Args:
            input: embeddings of shape ``(batch, embedding_size)``.
            label: integer class indices, one per row of ``input``.

        Returns:
            Tensor of shape ``(batch, class_num)``: ``s * cos(theta)``
            everywhere except at ``[i, label[i]]``, which holds
            ``s * cos(theta + m)`` (or ``s * (cos(theta) - mm)`` when the
            threshold test fails).
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))

        # sin(theta) = sqrt(1 - cos^2(theta)). The derivative of sqrt at 0 is
        # infinite, so an entry with |cos(theta)| == 1 would inject inf/NaN
        # into the backward pass. Route those entries around the sqrt: the
        # forward value is still exactly 0, and their gradient becomes 0.
        sin_sq = (1.0 - cosine.pow(2)).clamp(0, 1)
        differentiable = sin_sq > 0
        guarded = torch.where(differentiable, sin_sq, torch.ones_like(sin_sq))
        sine = torch.where(
            differentiable, guarded.sqrt(), torch.zeros_like(sin_sq)
        )

        phi = cosine * self.cos_m - sine * self.sin_m
        # Outside the valid angular range, drop to a CosFace-style penalty.
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Work on a copy so the in-place write below does not touch a tensor
        # that autograd still needs, then overwrite only the target entries.
        output = cosine.clone()
        rows = torch.arange(output.size(0), device=output.device)
        output[rows, label] = phi[rows, label]
        return output * self.s
class CosFace(nn.Module):
    """Additive cosine margin (CosFace) head.

    Owns one learnable weight row per class. ``forward`` returns scaled
    cosine similarities between the L2-normalised inputs and the
    L2-normalised class weights, with the margin ``m`` subtracted from the
    entry belonging to each sample's label.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float = 30.0,
        m: float = 0.40,
    ) -> None:
        """Create the class-weight matrix.

        Args:
            in_features: embedding size, usually 128, 256, 512 ...
            out_features: number of classes (num of people when training)
            s: scale, see normface https://arxiv.org/abs/1704.06369
            m: margin, see SphereFace, CosFace, and ArcFace paper
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m

        # One row per class; float32 is pinned to match the historical
        # ``torch.FloatTensor`` allocation. Values are set by Xavier init.
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Return scaled logits with the cosine margin applied to the targets.

        Args:
            input: embeddings of shape ``(batch, in_features)``.
            label: integer class indices, one per row of ``input``.

        Returns:
            Tensor of shape ``(batch, out_features)``: ``s * cos(theta)``
            everywhere except at ``[i, label[i]]``, which holds
            ``s * (cos(theta) - m)``.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))

        # Work on a copy so the in-place write does not touch a tensor that
        # autograd still needs; only the target entries receive the margin.
        output = cosine.clone()
        rows = torch.arange(output.size(0), device=output.device)
        output[rows, label] = cosine[rows, label] - self.m
        return output * self.s