add yolo v10 and modify pipeline
This commit is contained in:
@ -1,7 +1,5 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Model head modules
|
||||
"""
|
||||
"""Model head modules."""
|
||||
|
||||
import math
|
||||
|
||||
@ -9,25 +7,28 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.init import constant_, xavier_uniform_
|
||||
|
||||
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors
|
||||
|
||||
from .block import DFL, Proto
|
||||
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
|
||||
from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead
|
||||
from .conv import Conv
|
||||
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
|
||||
from .utils import bias_init_with_prob, linear_init_
|
||||
from .utils import bias_init_with_prob, linear_init
|
||||
import copy
|
||||
from ultralytics.utils import ops
|
||||
|
||||
__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder'
|
||||
__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"
|
||||
|
||||
|
||||
class Detect(nn.Module):
|
||||
"""YOLOv8 Detect head for detection models."""
|
||||
|
||||
dynamic = False # force grid reconstruction
|
||||
export = False # export mode
|
||||
shape = None
|
||||
anchors = torch.empty(0) # init
|
||||
strides = torch.empty(0) # init
|
||||
|
||||
def __init__(self, nc=80, ch=()): # detection layer
|
||||
def __init__(self, nc=80, ch=()):
|
||||
"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""
|
||||
super().__init__()
|
||||
self.nc = nc # number of classes
|
||||
self.nl = len(ch) # number of detection layers
|
||||
@ -36,41 +37,54 @@ class Detect(nn.Module):
|
||||
self.stride = torch.zeros(self.nl) # strides computed during build
|
||||
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
|
||||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
|
||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
|
||||
)
|
||||
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
|
||||
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""Concatenates and returns predicted bounding boxes and class probabilities."""
|
||||
def inference(self, x):
|
||||
# Inference path
|
||||
shape = x[0].shape # BCHW
|
||||
for i in range(self.nl):
|
||||
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
|
||||
if self.training:
|
||||
return x
|
||||
elif self.dynamic or self.shape != shape:
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
if self.dynamic or self.shape != shape:
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
|
||||
box = x_cat[:, :self.reg_max * 4]
|
||||
cls = x_cat[:, self.reg_max * 4:]
|
||||
if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops
|
||||
box = x_cat[:, : self.reg_max * 4]
|
||||
cls = x_cat[:, self.reg_max * 4 :]
|
||||
else:
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
|
||||
if self.export and self.format in ('tflite', 'edgetpu'):
|
||||
# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
|
||||
# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
|
||||
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
|
||||
img_h = shape[2] * self.stride[0]
|
||||
img_w = shape[3] * self.stride[0]
|
||||
img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
|
||||
dbox /= img_size
|
||||
if self.export and self.format in ("tflite", "edgetpu"):
|
||||
# Precompute normalization factor to increase numerical stability
|
||||
# See https://github.com/ultralytics/ultralytics/issues/7371
|
||||
grid_h = shape[2]
|
||||
grid_w = shape[3]
|
||||
grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
|
||||
norm = self.strides / (self.stride[0] * grid_size)
|
||||
dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
|
||||
else:
|
||||
dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
|
||||
|
||||
y = torch.cat((dbox, cls.sigmoid()), 1)
|
||||
return y if self.export else (y, x)
|
||||
|
||||
def forward_feat(self, x, cv2, cv3):
|
||||
y = []
|
||||
for i in range(self.nl):
|
||||
y.append(torch.cat((cv2[i](x[i]), cv3[i](x[i])), 1))
|
||||
return y
|
||||
|
||||
def forward(self, x):
|
||||
"""Concatenates and returns predicted bounding boxes and class probabilities."""
|
||||
y = self.forward_feat(x, self.cv2, self.cv3)
|
||||
|
||||
if self.training:
|
||||
return y
|
||||
|
||||
return self.inference(y)
|
||||
|
||||
def bias_init(self):
|
||||
"""Initialize Detect() biases, WARNING: requires stride availability."""
|
||||
m = self # self.model[-1] # Detect() module
|
||||
@ -78,7 +92,13 @@ class Detect(nn.Module):
|
||||
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
|
||||
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
|
||||
a[-1].bias.data[:] = 1.0 # box
|
||||
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
|
||||
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
|
||||
|
||||
def decode_bboxes(self, bboxes, anchors):
|
||||
"""Decode bounding boxes."""
|
||||
if self.export:
|
||||
return dist2bbox(bboxes, anchors, xywh=False, dim=1)
|
||||
return dist2bbox(bboxes, anchors, xywh=True, dim=1)
|
||||
|
||||
|
||||
class Segment(Detect):
|
||||
@ -107,6 +127,37 @@ class Segment(Detect):
|
||||
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
||||
|
||||
|
||||
class OBB(Detect):
|
||||
"""YOLOv8 OBB detection head for detection with rotation models."""
|
||||
|
||||
def __init__(self, nc=80, ne=1, ch=()):
|
||||
"""Initialize OBB with number of classes `nc` and layer channels `ch`."""
|
||||
super().__init__(nc, ch)
|
||||
self.ne = ne # number of extra parameters
|
||||
self.detect = Detect.forward
|
||||
|
||||
c4 = max(ch[0] // 4, self.ne)
|
||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
|
||||
|
||||
def forward(self, x):
|
||||
"""Concatenates and returns predicted bounding boxes and class probabilities."""
|
||||
bs = x[0].shape[0] # batch size
|
||||
angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
|
||||
# NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
|
||||
angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
|
||||
# angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
|
||||
if not self.training:
|
||||
self.angle = angle
|
||||
x = self.detect(self, x)
|
||||
if self.training:
|
||||
return x, angle
|
||||
return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
|
||||
|
||||
def decode_bboxes(self, bboxes, anchors):
|
||||
"""Decode rotated bounding boxes."""
|
||||
return dist2rbox(bboxes, self.angle, anchors, dim=1)
|
||||
|
||||
|
||||
class Pose(Detect):
|
||||
"""YOLOv8 Pose head for keypoints models."""
|
||||
|
||||
@ -142,7 +193,7 @@ class Pose(Detect):
|
||||
else:
|
||||
y = kpts.clone()
|
||||
if ndim == 3:
|
||||
y[:, 2::3].sigmoid_() # inplace sigmoid
|
||||
y[:, 2::3] = y[:, 2::3].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
|
||||
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
|
||||
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
|
||||
return y
|
||||
@ -151,7 +202,10 @@ class Pose(Detect):
|
||||
class Classify(nn.Module):
|
||||
"""YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
|
||||
"""Initializes YOLOv8 classification head with specified input and output channels, kernel size, stride,
|
||||
padding, and groups.
|
||||
"""
|
||||
super().__init__()
|
||||
c_ = 1280 # efficientnet_b0 size
|
||||
self.conv = Conv(c1, c_, k, s, p, g)
|
||||
@ -167,27 +221,99 @@ class Classify(nn.Module):
|
||||
return x if self.training else x.softmax(1)
|
||||
|
||||
|
||||
class WorldDetect(Detect):
|
||||
def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
|
||||
"""Initialize YOLOv8 detection layer with nc classes and layer channels ch."""
|
||||
super().__init__(nc, ch)
|
||||
c3 = max(ch[0], min(self.nc, 100))
|
||||
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
|
||||
self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
|
||||
|
||||
def forward(self, x, text):
|
||||
"""Concatenates and returns predicted bounding boxes and class probabilities."""
|
||||
for i in range(self.nl):
|
||||
x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
|
||||
if self.training:
|
||||
return x
|
||||
|
||||
# Inference path
|
||||
shape = x[0].shape # BCHW
|
||||
x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2)
|
||||
if self.dynamic or self.shape != shape:
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops
|
||||
box = x_cat[:, : self.reg_max * 4]
|
||||
cls = x_cat[:, self.reg_max * 4 :]
|
||||
else:
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
|
||||
if self.export and self.format in ("tflite", "edgetpu"):
|
||||
# Precompute normalization factor to increase numerical stability
|
||||
# See https://github.com/ultralytics/ultralytics/issues/7371
|
||||
grid_h = shape[2]
|
||||
grid_w = shape[3]
|
||||
grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
|
||||
norm = self.strides / (self.stride[0] * grid_size)
|
||||
dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
|
||||
else:
|
||||
dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
|
||||
|
||||
y = torch.cat((dbox, cls.sigmoid()), 1)
|
||||
return y if self.export else (y, x)
|
||||
|
||||
|
||||
class RTDETRDecoder(nn.Module):
|
||||
"""
|
||||
Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
|
||||
|
||||
This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
|
||||
and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
|
||||
Transformer decoder layers to output the final predictions.
|
||||
"""
|
||||
|
||||
export = False # export mode
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nc=80,
|
||||
ch=(512, 1024, 2048),
|
||||
hd=256, # hidden dim
|
||||
nq=300, # num queries
|
||||
ndp=4, # num decoder points
|
||||
nh=8, # num head
|
||||
ndl=6, # num decoder layers
|
||||
d_ffn=1024, # dim of feedforward
|
||||
dropout=0.,
|
||||
act=nn.ReLU(),
|
||||
eval_idx=-1,
|
||||
# training args
|
||||
nd=100, # num denoising
|
||||
label_noise_ratio=0.5,
|
||||
box_noise_scale=1.0,
|
||||
learnt_init_query=False):
|
||||
self,
|
||||
nc=80,
|
||||
ch=(512, 1024, 2048),
|
||||
hd=256, # hidden dim
|
||||
nq=300, # num queries
|
||||
ndp=4, # num decoder points
|
||||
nh=8, # num head
|
||||
ndl=6, # num decoder layers
|
||||
d_ffn=1024, # dim of feedforward
|
||||
dropout=0.0,
|
||||
act=nn.ReLU(),
|
||||
eval_idx=-1,
|
||||
# Training args
|
||||
nd=100, # num denoising
|
||||
label_noise_ratio=0.5,
|
||||
box_noise_scale=1.0,
|
||||
learnt_init_query=False,
|
||||
):
|
||||
"""
|
||||
Initializes the RTDETRDecoder module with the given parameters.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes. Default is 80.
|
||||
ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).
|
||||
hd (int): Dimension of hidden layers. Default is 256.
|
||||
nq (int): Number of query points. Default is 300.
|
||||
ndp (int): Number of decoder points. Default is 4.
|
||||
nh (int): Number of heads in multi-head attention. Default is 8.
|
||||
ndl (int): Number of decoder layers. Default is 6.
|
||||
d_ffn (int): Dimension of the feed-forward networks. Default is 1024.
|
||||
dropout (float): Dropout rate. Default is 0.
|
||||
act (nn.Module): Activation function. Default is nn.ReLU.
|
||||
eval_idx (int): Evaluation index. Default is -1.
|
||||
nd (int): Number of denoising. Default is 100.
|
||||
label_noise_ratio (float): Label noise ratio. Default is 0.5.
|
||||
box_noise_scale (float): Box noise scale. Default is 1.0.
|
||||
learnt_init_query (bool): Whether to learn initial query embeddings. Default is False.
|
||||
"""
|
||||
super().__init__()
|
||||
self.hidden_dim = hd
|
||||
self.nhead = nh
|
||||
@ -196,7 +322,7 @@ class RTDETRDecoder(nn.Module):
|
||||
self.num_queries = nq
|
||||
self.num_decoder_layers = ndl
|
||||
|
||||
# backbone feature projection
|
||||
# Backbone feature projection
|
||||
self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
|
||||
# NOTE: simplified version but it's not consistent with .pt weights.
|
||||
# self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)
|
||||
@ -205,58 +331,61 @@ class RTDETRDecoder(nn.Module):
|
||||
decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
|
||||
self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)
|
||||
|
||||
# denoising part
|
||||
# Denoising part
|
||||
self.denoising_class_embed = nn.Embedding(nc, hd)
|
||||
self.num_denoising = nd
|
||||
self.label_noise_ratio = label_noise_ratio
|
||||
self.box_noise_scale = box_noise_scale
|
||||
|
||||
# decoder embedding
|
||||
# Decoder embedding
|
||||
self.learnt_init_query = learnt_init_query
|
||||
if learnt_init_query:
|
||||
self.tgt_embed = nn.Embedding(nq, hd)
|
||||
self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)
|
||||
|
||||
# encoder head
|
||||
# Encoder head
|
||||
self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
|
||||
self.enc_score_head = nn.Linear(hd, nc)
|
||||
self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)
|
||||
|
||||
# decoder head
|
||||
# Decoder head
|
||||
self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
|
||||
self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def forward(self, x, batch=None):
|
||||
"""Runs the forward pass of the module, returning bounding box and classification scores for the input."""
|
||||
from ultralytics.models.utils.ops import get_cdn_group
|
||||
|
||||
# input projection and embedding
|
||||
# Input projection and embedding
|
||||
feats, shapes = self._get_encoder_input(x)
|
||||
|
||||
# prepare denoising training
|
||||
dn_embed, dn_bbox, attn_mask, dn_meta = \
|
||||
get_cdn_group(batch,
|
||||
self.nc,
|
||||
self.num_queries,
|
||||
self.denoising_class_embed.weight,
|
||||
self.num_denoising,
|
||||
self.label_noise_ratio,
|
||||
self.box_noise_scale,
|
||||
self.training)
|
||||
# Prepare denoising training
|
||||
dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
|
||||
batch,
|
||||
self.nc,
|
||||
self.num_queries,
|
||||
self.denoising_class_embed.weight,
|
||||
self.num_denoising,
|
||||
self.label_noise_ratio,
|
||||
self.box_noise_scale,
|
||||
self.training,
|
||||
)
|
||||
|
||||
embed, refer_bbox, enc_bboxes, enc_scores = \
|
||||
self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
|
||||
embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
|
||||
|
||||
# decoder
|
||||
dec_bboxes, dec_scores = self.decoder(embed,
|
||||
refer_bbox,
|
||||
feats,
|
||||
shapes,
|
||||
self.dec_bbox_head,
|
||||
self.dec_score_head,
|
||||
self.query_pos_head,
|
||||
attn_mask=attn_mask)
|
||||
# Decoder
|
||||
dec_bboxes, dec_scores = self.decoder(
|
||||
embed,
|
||||
refer_bbox,
|
||||
feats,
|
||||
shapes,
|
||||
self.dec_bbox_head,
|
||||
self.dec_score_head,
|
||||
self.query_pos_head,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
|
||||
if self.training:
|
||||
return x
|
||||
@ -264,29 +393,31 @@ class RTDETRDecoder(nn.Module):
|
||||
y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
|
||||
return y if self.export else (y, x)
|
||||
|
||||
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
||||
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
|
||||
"""Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
|
||||
anchors = []
|
||||
for i, (h, w) in enumerate(shapes):
|
||||
sy = torch.arange(end=h, dtype=dtype, device=device)
|
||||
sx = torch.arange(end=w, dtype=dtype, device=device)
|
||||
grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
|
||||
grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
|
||||
grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
|
||||
|
||||
valid_WH = torch.tensor([h, w], dtype=dtype, device=device)
|
||||
valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
|
||||
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
|
||||
wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
|
||||
wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
|
||||
anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
|
||||
|
||||
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
|
||||
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
|
||||
valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
|
||||
anchors = torch.log(anchors / (1 - anchors))
|
||||
anchors = anchors.masked_fill(~valid_mask, float('inf'))
|
||||
anchors = anchors.masked_fill(~valid_mask, float("inf"))
|
||||
return anchors, valid_mask
|
||||
|
||||
def _get_encoder_input(self, x):
|
||||
# get projection features
|
||||
"""Processes and returns encoder inputs by getting projection features from input and concatenating them."""
|
||||
# Get projection features
|
||||
x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
|
||||
# get encoder inputs
|
||||
# Get encoder inputs
|
||||
feats = []
|
||||
shapes = []
|
||||
for feat in x:
|
||||
@ -301,14 +432,15 @@ class RTDETRDecoder(nn.Module):
|
||||
return feats, shapes
|
||||
|
||||
def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
|
||||
bs = len(feats)
|
||||
# prepare input for decoder
|
||||
"""Generates and prepares the input required for the decoder from the provided features and shapes."""
|
||||
bs = feats.shape[0]
|
||||
# Prepare input for decoder
|
||||
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
|
||||
features = self.enc_output(valid_mask * feats) # bs, h*w, 256
|
||||
|
||||
enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
|
||||
|
||||
# query selection
|
||||
# Query selection
|
||||
# (bs, num_queries)
|
||||
topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
|
||||
# (bs, num_queries)
|
||||
@ -319,7 +451,7 @@ class RTDETRDecoder(nn.Module):
|
||||
# (bs, num_queries, 4)
|
||||
top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)
|
||||
|
||||
# dynamic anchors + static content
|
||||
# Dynamic anchors + static content
|
||||
refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors
|
||||
|
||||
enc_bboxes = refer_bbox.sigmoid()
|
||||
@ -339,20 +471,21 @@ class RTDETRDecoder(nn.Module):
|
||||
|
||||
# TODO
|
||||
def _reset_parameters(self):
|
||||
# class and bbox head init
|
||||
"""Initializes or resets the parameters of the model's various components with predefined weights and biases."""
|
||||
# Class and bbox head init
|
||||
bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
|
||||
# NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
|
||||
# linear_init_(self.enc_score_head)
|
||||
# NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.
|
||||
# linear_init(self.enc_score_head)
|
||||
constant_(self.enc_score_head.bias, bias_cls)
|
||||
constant_(self.enc_bbox_head.layers[-1].weight, 0.)
|
||||
constant_(self.enc_bbox_head.layers[-1].bias, 0.)
|
||||
constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
|
||||
constant_(self.enc_bbox_head.layers[-1].bias, 0.0)
|
||||
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
|
||||
# linear_init_(cls_)
|
||||
# linear_init(cls_)
|
||||
constant_(cls_.bias, bias_cls)
|
||||
constant_(reg_.layers[-1].weight, 0.)
|
||||
constant_(reg_.layers[-1].bias, 0.)
|
||||
constant_(reg_.layers[-1].weight, 0.0)
|
||||
constant_(reg_.layers[-1].bias, 0.0)
|
||||
|
||||
linear_init_(self.enc_output[0])
|
||||
linear_init(self.enc_output[0])
|
||||
xavier_uniform_(self.enc_output[0].weight)
|
||||
if self.learnt_init_query:
|
||||
xavier_uniform_(self.tgt_embed.weight)
|
||||
@ -360,3 +493,43 @@ class RTDETRDecoder(nn.Module):
|
||||
xavier_uniform_(self.query_pos_head.layers[1].weight)
|
||||
for layer in self.input_proj:
|
||||
xavier_uniform_(layer[0].weight)
|
||||
|
||||
class v10Detect(Detect):
|
||||
|
||||
max_det = 300
|
||||
|
||||
def __init__(self, nc=80, ch=()):
|
||||
super().__init__(nc, ch)
|
||||
c3 = max(ch[0], min(self.nc, 100)) # channels
|
||||
self.cv3 = nn.ModuleList(nn.Sequential(nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)), \
|
||||
nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)), \
|
||||
nn.Conv2d(c3, self.nc, 1)) for i, x in enumerate(ch))
|
||||
|
||||
self.one2one_cv2 = copy.deepcopy(self.cv2)
|
||||
self.one2one_cv3 = copy.deepcopy(self.cv3)
|
||||
|
||||
def forward(self, x):
|
||||
one2one = self.forward_feat([xi.detach() for xi in x], self.one2one_cv2, self.one2one_cv3)
|
||||
if not self.export:
|
||||
one2many = super().forward(x)
|
||||
|
||||
if not self.training:
|
||||
one2one = self.inference(one2one)
|
||||
if not self.export:
|
||||
return {"one2many": one2many, "one2one": one2one}
|
||||
else:
|
||||
assert(self.max_det != -1)
|
||||
boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc)
|
||||
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
|
||||
else:
|
||||
return {"one2many": one2many, "one2one": one2one}
|
||||
|
||||
def bias_init(self):
|
||||
super().bias_init()
|
||||
"""Initialize Detect() biases, WARNING: requires stride availability."""
|
||||
m = self # self.model[-1] # Detect() module
|
||||
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
|
||||
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
|
||||
for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
|
||||
a[-1].bias.data[:] = 1.0 # box
|
||||
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
|
||||
|
Reference in New Issue
Block a user