add yolo v10 and modify pipeline
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -10,6 +10,21 @@ from ultralytics.nn.modules import LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
"""
|
||||
Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict
|
||||
masks given image and prompt embeddings.
|
||||
|
||||
Attributes:
|
||||
transformer_dim (int): Channel dimension for the transformer module.
|
||||
transformer (nn.Module): The transformer module used for mask prediction.
|
||||
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
|
||||
iou_token (nn.Embedding): Embedding for the IoU token.
|
||||
num_mask_tokens (int): Number of mask tokens.
|
||||
mask_tokens (nn.Embedding): Embedding for the mask tokens.
|
||||
output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
|
||||
output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
|
||||
iou_prediction_head (nn.Module): MLP for predicting mask quality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -49,8 +64,9 @@ class MaskDecoder(nn.Module):
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_hypernetworks_mlps = nn.ModuleList([
|
||||
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
|
||||
|
||||
@ -98,10 +114,14 @@ class MaskDecoder(nn.Module):
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts masks. See 'forward' for more details."""
|
||||
"""
|
||||
Predicts masks.
|
||||
|
||||
See 'forward' for more details.
|
||||
"""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
@ -113,13 +133,14 @@ class MaskDecoder(nn.Module):
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, 0, :]
|
||||
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
|
||||
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
hyper_in_list: List[torch.Tensor] = [
|
||||
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
|
||||
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
|
||||
]
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
@ -132,7 +153,7 @@ class MaskDecoder(nn.Module):
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""
|
||||
Lightly adapted from
|
||||
MLP (Multi-Layer Perceptron) model lightly adapted from
|
||||
https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
|
||||
"""
|
||||
|
||||
@ -144,6 +165,16 @@ class MLP(nn.Module):
|
||||
num_layers: int,
|
||||
sigmoid_output: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the MLP (Multi-Layer Perceptron) model.
|
||||
|
||||
Args:
|
||||
input_dim (int): The dimensionality of the input features.
|
||||
hidden_dim (int): The dimensionality of the hidden layers.
|
||||
output_dim (int): The dimensionality of the output layer.
|
||||
num_layers (int): The number of hidden layers.
|
||||
sigmoid_output (bool, optional): Apply a sigmoid activation to the output layer. Defaults to False.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
|
@ -10,27 +10,41 @@ import torch.nn.functional as F
|
||||
from ultralytics.nn.modules import LayerNorm2d, MLPBlock
|
||||
|
||||
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
class ImageEncoderViT(nn.Module):
|
||||
"""
|
||||
An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
|
||||
encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
|
||||
The encoded patches are then processed through a neck to generate the final encoded representation.
|
||||
|
||||
This class and its supporting functions below lightly adapted from the ViTDet backbone available at
|
||||
https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Dimension of input images, assumed to be square.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
|
||||
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
|
||||
neck (nn.Sequential): Neck module to further process the output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: Tuple[int, ...] = (),
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: Tuple[int, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -100,6 +114,9 @@ class ImageEncoderViT(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes input through patch embedding, applies positional embedding if present, and passes through blocks
|
||||
and neck.
|
||||
"""
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
@ -109,6 +126,22 @@ class ImageEncoderViT(nn.Module):
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
"""
|
||||
Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
|
||||
produces both sparse and dense embeddings for the input prompts.
|
||||
|
||||
Attributes:
|
||||
embed_dim (int): Dimension of the embeddings.
|
||||
input_image_size (Tuple[int, int]): Size of the input image as (H, W).
|
||||
image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
|
||||
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
|
||||
num_point_embeddings (int): Number of point embeddings for different types of points.
|
||||
point_embeddings (nn.ModuleList): List of point embeddings.
|
||||
not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
|
||||
mask_input_size (Tuple[int, int]): Size of the input mask.
|
||||
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
|
||||
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -157,20 +190,15 @@ class PromptEncoder(nn.Module):
|
||||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
|
||||
image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
pad: bool,
|
||||
) -> torch.Tensor:
|
||||
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
|
||||
"""Embeds point prompts."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
@ -204,9 +232,7 @@ class PromptEncoder(nn.Module):
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""
|
||||
Gets the batch size of the output given the batch size of the input prompts.
|
||||
"""
|
||||
"""Gets the batch size of the output given the batch size of the input prompts."""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
@ -217,6 +243,7 @@ class PromptEncoder(nn.Module):
|
||||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
"""Returns the device of the first point embedding's weight tensor."""
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
@ -251,23 +278,22 @@ class PromptEncoder(nn.Module):
|
||||
if masks is not None:
|
||||
dense_embeddings = self._embed_masks(masks)
|
||||
else:
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
|
||||
1).expand(bs, -1, self.image_embedding_size[0],
|
||||
self.image_embedding_size[1])
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
||||
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
||||
)
|
||||
|
||||
return sparse_embeddings, dense_embeddings
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
Positional encoding using random spatial frequencies.
|
||||
"""
|
||||
"""Positional encoding using random spatial frequencies."""
|
||||
|
||||
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
||||
"""Initializes a position embedding using random spatial frequencies."""
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
|
||||
self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
|
||||
|
||||
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
|
||||
torch.use_deterministic_algorithms(False)
|
||||
@ -275,11 +301,11 @@ class PositionEmbeddingRandom(nn.Module):
|
||||
|
||||
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
"""Positionally encode points that are normalized to [0,1]."""
|
||||
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
||||
# Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
||||
coords = 2 * coords - 1
|
||||
coords = coords @ self.positional_encoding_gaussian_matrix
|
||||
coords = 2 * np.pi * coords
|
||||
# outputs d_1 x ... x d_n x C shape
|
||||
# Outputs d_1 x ... x d_n x C shape
|
||||
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
||||
|
||||
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||
@ -304,7 +330,7 @@ class PositionEmbeddingRandom(nn.Module):
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -351,6 +377,7 @@ class Block(nn.Module):
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
@ -380,6 +407,8 @@ class Attention(nn.Module):
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Attention module.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
@ -391,19 +420,20 @@ class Attention(nn.Module):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.use_rel_pos = use_rel_pos
|
||||
if self.use_rel_pos:
|
||||
assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
|
||||
# initialize relative positional embeddings
|
||||
assert input_size is not None, "Input size must be provided if using relative positional encoding."
|
||||
# Initialize relative positional embeddings
|
||||
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
||||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@ -444,10 +474,12 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
|
||||
hw: Tuple[int, int]) -> torch.Tensor:
|
||||
def window_unpartition(
|
||||
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Window unpartition into original sequences and removing padding.
|
||||
|
||||
Args:
|
||||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
||||
window_size (int): window size.
|
||||
@ -470,8 +502,8 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[in
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get relative positional embeddings according to the relative positions of
|
||||
query and key sizes.
|
||||
Get relative positional embeddings according to the relative positions of query and key sizes.
|
||||
|
||||
Args:
|
||||
q_size (int): size of query q.
|
||||
k_size (int): size of key k.
|
||||
@ -487,7 +519,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode='linear',
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
@ -510,8 +542,9 @@ def add_decomposed_rel_pos(
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
||||
Calculate decomposed Relative Positional Embeddings from mvitv2 paper at
|
||||
https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py.
|
||||
|
||||
Args:
|
||||
attn (Tensor): attention map.
|
||||
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
||||
@ -530,29 +563,30 @@ def add_decomposed_rel_pos(
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
|
||||
rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
|
||||
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
|
||||
B, q_h * q_w, k_h * k_w)
|
||||
B, q_h * q_w, k_h * k_w
|
||||
)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
"""Image to Patch Embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize PatchEmbed module.
|
||||
|
||||
Args:
|
||||
kernel_size (Tuple): kernel size of the projection layer.
|
||||
stride (Tuple): stride of the projection layer.
|
||||
@ -565,4 +599,5 @@ class PatchEmbed(nn.Module):
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes patch embedding by applying convolution and transposing resulting tensor."""
|
||||
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|
||||
|
@ -16,8 +16,23 @@ from .encoders import ImageEncoderViT, PromptEncoder
|
||||
|
||||
|
||||
class Sam(nn.Module):
|
||||
"""
|
||||
Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
|
||||
embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
|
||||
decoder to predict object masks.
|
||||
|
||||
Attributes:
|
||||
mask_threshold (float): Threshold value for mask prediction.
|
||||
image_format (str): Format of the input image, default is 'RGB'.
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
|
||||
pixel_mean (List[float]): Mean pixel values for image normalization.
|
||||
pixel_std (List[float]): Standard deviation values for image normalization.
|
||||
"""
|
||||
|
||||
mask_threshold: float = 0.0
|
||||
image_format: str = 'RGB'
|
||||
image_format: str = "RGB"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -25,25 +40,26 @@ class Sam(nn.Module):
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: List[float] = (123.675, 116.28, 103.53),
|
||||
pixel_std: List[float] = (58.395, 57.12, 57.375)
|
||||
pixel_std: List[float] = (58.395, 57.12, 57.375),
|
||||
) -> None:
|
||||
"""
|
||||
SAM predicts object masks from an image and input prompts.
|
||||
Initialize the Sam class to predict object masks from an image and input prompts.
|
||||
|
||||
Note:
|
||||
All forward() operations moved to SAMPredictor.
|
||||
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
|
||||
efficient mask prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
||||
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to
|
||||
(123.675, 116.28, 103.53).
|
||||
pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to
|
||||
(58.395, 57.12, 57.375).
|
||||
"""
|
||||
super().__init__()
|
||||
self.image_encoder = image_encoder
|
||||
self.prompt_encoder = prompt_encoder
|
||||
self.mask_decoder = mask_decoder
|
||||
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||
|
@ -21,19 +21,27 @@ from ultralytics.utils.instance import to_2tuple
|
||||
|
||||
|
||||
class Conv2d_BN(torch.nn.Sequential):
|
||||
"""A sequential container that performs 2D convolution followed by batch normalization."""
|
||||
|
||||
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||
"""Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
|
||||
drop path.
|
||||
"""
|
||||
super().__init__()
|
||||
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
bn = torch.nn.BatchNorm2d(b)
|
||||
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
||||
torch.nn.init.constant_(bn.bias, 0)
|
||||
self.add_module('bn', bn)
|
||||
self.add_module("bn", bn)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Embeds images into patches and projects them into a specified embedding dimension."""
|
||||
|
||||
def __init__(self, in_chans, embed_dim, resolution, activation):
|
||||
"""Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
|
||||
function.
|
||||
"""
|
||||
super().__init__()
|
||||
img_size: Tuple[int, int] = to_2tuple(resolution)
|
||||
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
||||
@ -48,12 +56,17 @@ class PatchEmbed(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
|
||||
return self.seq(x)
|
||||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
"""Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
|
||||
|
||||
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
|
||||
"""Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
|
||||
function.
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.hidden_chans = int(in_chans * expand_ratio)
|
||||
@ -73,6 +86,7 @@ class MBConv(nn.Module):
|
||||
self.drop_path = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""Implements the forward pass for the model architecture."""
|
||||
shortcut = x
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
@ -85,8 +99,12 @@ class MBConv(nn.Module):
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
"""Merges neighboring patches in the feature map and projects to a new dimension."""
|
||||
|
||||
def __init__(self, input_resolution, dim, out_dim, activation):
|
||||
"""Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
|
||||
optional parameters.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_resolution = input_resolution
|
||||
@ -99,6 +117,7 @@ class PatchMerging(nn.Module):
|
||||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
|
||||
if x.ndim == 3:
|
||||
H, W = self.input_resolution
|
||||
B = len(x)
|
||||
@ -115,6 +134,11 @@ class PatchMerging(nn.Module):
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
"""
|
||||
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
|
||||
|
||||
Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -122,41 +146,69 @@ class ConvLayer(nn.Module):
|
||||
input_resolution,
|
||||
depth,
|
||||
activation,
|
||||
drop_path=0.,
|
||||
drop_path=0.0,
|
||||
downsample=None,
|
||||
use_checkpoint=False,
|
||||
out_dim=None,
|
||||
conv_expand_ratio=4.,
|
||||
conv_expand_ratio=4.0,
|
||||
):
|
||||
"""
|
||||
Initializes the ConvLayer with the given dimensions and settings.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input image.
|
||||
depth (int): The number of MBConv layers in the block.
|
||||
activation (Callable): Activation function applied after each convolution.
|
||||
drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
|
||||
downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
|
||||
conv_expand_ratio (float): Expansion ratio for the MBConv layers.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
MBConv(
|
||||
dim,
|
||||
dim,
|
||||
conv_expand_ratio,
|
||||
activation,
|
||||
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
) for i in range(depth)])
|
||||
# Build blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
MBConv(
|
||||
dim,
|
||||
dim,
|
||||
conv_expand_ratio,
|
||||
activation,
|
||||
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# patch merging layer
|
||||
self.downsample = None if downsample is None else downsample(
|
||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
# Patch merging layer
|
||||
self.downsample = (
|
||||
None
|
||||
if downsample is None
|
||||
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Processes the input through a series of convolutional layers and returns the activated output."""
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""
|
||||
Multi-layer Perceptron (MLP) for transformer architectures.
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
@ -167,6 +219,7 @@ class Mlp(nn.Module):
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies operations on input x and returns modified x, runs downsample if not None."""
|
||||
x = self.norm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
@ -176,20 +229,41 @@ class Mlp(nn.Module):
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
"""
|
||||
Multi-head attention module with support for spatial awareness, applying attention biases based on spatial
|
||||
resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution
|
||||
grid.
|
||||
|
||||
Attributes:
|
||||
ab (Tensor, optional): Cached attention biases for inference, deleted during training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=(14, 14),
|
||||
self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=(14, 14),
|
||||
):
|
||||
"""
|
||||
Initializes the Attention module.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
key_dim (int): The dimensionality of the keys and queries.
|
||||
num_heads (int, optional): Number of attention heads. Default is 8.
|
||||
attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
|
||||
resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).
|
||||
|
||||
Raises:
|
||||
AssertionError: If `resolution` is not a tuple of length 2.
|
||||
"""
|
||||
super().__init__()
|
||||
# (h, w)
|
||||
|
||||
assert isinstance(resolution, tuple) and len(resolution) == 2
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim ** -0.5
|
||||
self.scale = key_dim**-0.5
|
||||
self.key_dim = key_dim
|
||||
self.nh_kd = nh_kd = key_dim * num_heads
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
@ -212,18 +286,20 @@ class Attention(torch.nn.Module):
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
|
||||
self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
"""Sets the module in training mode and handles attribute 'ab' based on the mode."""
|
||||
super().train(mode)
|
||||
if mode and hasattr(self, 'ab'):
|
||||
if mode and hasattr(self, "ab"):
|
||||
del self.ab
|
||||
else:
|
||||
self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
|
||||
def forward(self, x): # x (B,N,C)
|
||||
B, N, _ = x.shape
|
||||
def forward(self, x): # x
|
||||
"""Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values."""
|
||||
B, N, _ = x.shape # B, N, C
|
||||
|
||||
# Normalization
|
||||
x = self.norm(x)
|
||||
@ -237,28 +313,16 @@ class Attention(torch.nn.Module):
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
self.ab = self.ab.to(self.attention_biases.device)
|
||||
|
||||
attn = ((q @ k.transpose(-2, -1)) * self.scale +
|
||||
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale + (
|
||||
self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
|
||||
)
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class TinyViTBlock(nn.Module):
|
||||
"""
|
||||
TinyViT Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int, int]): Input resolution.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
"""
|
||||
"""TinyViT Block that applies self-attention and a local convolution to the input."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -266,17 +330,35 @@ class TinyViTBlock(nn.Module):
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
mlp_ratio=4.0,
|
||||
drop=0.0,
|
||||
drop_path=0.0,
|
||||
local_conv_size=3,
|
||||
activation=nn.GELU,
|
||||
):
|
||||
"""
|
||||
Initializes the TinyViTBlock.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int, optional): Window size for attention. Default is 7.
|
||||
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
|
||||
drop (float, optional): Dropout rate. Default is 0.
|
||||
drop_path (float, optional): Stochastic depth rate. Default is 0.
|
||||
local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
|
||||
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `window_size` is not greater than 0.
|
||||
AssertionError: If `dim` is not divisible by `num_heads`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.num_heads = num_heads
|
||||
assert window_size > 0, 'window_size must be greater than 0'
|
||||
assert window_size > 0, "window_size must be greater than 0"
|
||||
self.window_size = window_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
@ -284,7 +366,7 @@ class TinyViTBlock(nn.Module):
|
||||
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.drop_path = nn.Identity()
|
||||
|
||||
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
|
||||
assert dim % num_heads == 0, "dim must be divisible by num_heads"
|
||||
head_dim = dim // num_heads
|
||||
|
||||
window_resolution = (window_size, window_size)
|
||||
@ -298,9 +380,12 @@ class TinyViTBlock(nn.Module):
|
||||
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies attention-based transformation or padding to input 'x' before passing it through a local
|
||||
convolution.
|
||||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
res_x = x
|
||||
if H == self.window_size and W == self.window_size:
|
||||
x = self.attn(x)
|
||||
@ -316,11 +401,14 @@ class TinyViTBlock(nn.Module):
|
||||
pH, pW = H + pad_b, W + pad_r
|
||||
nH = pH // self.window_size
|
||||
nW = pW // self.window_size
|
||||
# window partition
|
||||
x = x.view(B, nH, self.window_size, nW, self.window_size,
|
||||
C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
|
||||
# Window partition
|
||||
x = (
|
||||
x.view(B, nH, self.window_size, nW, self.window_size, C)
|
||||
.transpose(2, 3)
|
||||
.reshape(B * nH * nW, self.window_size * self.window_size, C)
|
||||
)
|
||||
x = self.attn(x)
|
||||
# window reverse
|
||||
# Window reverse
|
||||
x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
|
||||
|
||||
if padding:
|
||||
@ -337,29 +425,17 @@ class TinyViTBlock(nn.Module):
|
||||
return x + self.drop_path(self.mlp(x))
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
|
||||
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
|
||||
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
|
||||
attentions heads, window size, and MLP ratio.
|
||||
"""
|
||||
return (
|
||||
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
||||
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
|
||||
)
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
"""
|
||||
A basic TinyViT layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
out_dim (int | optional): the output dimension of the layer. Default: None
|
||||
"""
|
||||
"""A basic TinyViT layer for one stage in a TinyViT architecture."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -368,57 +444,90 @@ class BasicLayer(nn.Module):
|
||||
depth,
|
||||
num_heads,
|
||||
window_size,
|
||||
mlp_ratio=4.,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
mlp_ratio=4.0,
|
||||
drop=0.0,
|
||||
drop_path=0.0,
|
||||
downsample=None,
|
||||
use_checkpoint=False,
|
||||
local_conv_size=3,
|
||||
activation=nn.GELU,
|
||||
out_dim=None,
|
||||
):
|
||||
"""
|
||||
Initializes the BasicLayer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
depth (int): Number of TinyViT blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
|
||||
drop (float, optional): Dropout rate. Default is 0.
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
|
||||
local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
|
||||
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
|
||||
out_dim (int | None, optional): The output dimension of the layer. Default is None.
|
||||
|
||||
Raises:
|
||||
ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
TinyViTBlock(
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
local_conv_size=local_conv_size,
|
||||
activation=activation,
|
||||
) for i in range(depth)])
|
||||
# Build blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
TinyViTBlock(
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
local_conv_size=local_conv_size,
|
||||
activation=activation,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# patch merging layer
|
||||
self.downsample = None if downsample is None else downsample(
|
||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
# Patch merging layer
|
||||
self.downsample = (
|
||||
None
|
||||
if downsample is None
|
||||
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
|
||||
"""Returns a string representation of the extra_repr function with the layer's parameters."""
|
||||
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
"""A PyTorch implementation of Layer Normalization in 2D."""
|
||||
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
"""Initialize LayerNorm2d with the number of channels and an optional epsilon."""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Perform a forward pass, normalizing the input tensor."""
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
@ -426,6 +535,30 @@ class LayerNorm2d(nn.Module):
|
||||
|
||||
|
||||
class TinyViT(nn.Module):
|
||||
"""
|
||||
The TinyViT architecture for vision tasks.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Input image size.
|
||||
in_chans (int): Number of input channels.
|
||||
num_classes (int): Number of classification classes.
|
||||
embed_dims (List[int]): List of embedding dimensions for each layer.
|
||||
depths (List[int]): List of depths for each layer.
|
||||
num_heads (List[int]): List of number of attention heads for each layer.
|
||||
window_sizes (List[int]): List of window sizes for each layer.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop_rate (float): Dropout rate for drop layers.
|
||||
drop_path_rate (float): Drop path rate for stochastic depth.
|
||||
use_checkpoint (bool): Use checkpointing for efficient memory usage.
|
||||
mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
|
||||
local_conv_size (int): Local convolution kernel size.
|
||||
layer_lr_decay (float): Layer-wise learning rate decay.
|
||||
|
||||
Note:
|
||||
This implementation is generalized to accept a list of depths, attention heads,
|
||||
embedding dimensions and window sizes, which allows you to create a
|
||||
"stack" of TinyViT models of varying configurations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -436,14 +569,33 @@ class TinyViT(nn.Module):
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_sizes=[7, 7, 14, 7],
|
||||
mlp_ratio=4.,
|
||||
drop_rate=0.,
|
||||
mlp_ratio=4.0,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
use_checkpoint=False,
|
||||
mbconv_expand_ratio=4.0,
|
||||
local_conv_size=3,
|
||||
layer_lr_decay=1.0,
|
||||
):
|
||||
"""
|
||||
Initializes the TinyViT model.
|
||||
|
||||
Args:
|
||||
img_size (int, optional): The input image size. Defaults to 224.
|
||||
in_chans (int, optional): Number of input channels. Defaults to 3.
|
||||
num_classes (int, optional): Number of classification classes. Defaults to 1000.
|
||||
embed_dims (List[int], optional): List of embedding dimensions for each layer. Defaults to [96, 192, 384, 768].
|
||||
depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2].
|
||||
num_heads (List[int], optional): List of number of attention heads for each layer. Defaults to [3, 6, 12, 24].
|
||||
window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7].
|
||||
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4.
|
||||
drop_rate (float, optional): Dropout rate. Defaults to 0.
|
||||
drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False.
|
||||
mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0.
|
||||
local_conv_size (int, optional): Local convolution kernel size. Defaults to 3.
|
||||
layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0.
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.num_classes = num_classes
|
||||
@ -453,50 +605,52 @@ class TinyViT(nn.Module):
|
||||
|
||||
activation = nn.GELU
|
||||
|
||||
self.patch_embed = PatchEmbed(in_chans=in_chans,
|
||||
embed_dim=embed_dims[0],
|
||||
resolution=img_size,
|
||||
activation=activation)
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
|
||||
)
|
||||
|
||||
patches_resolution = self.patch_embed.patches_resolution
|
||||
self.patches_resolution = patches_resolution
|
||||
|
||||
# stochastic depth
|
||||
# Stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
# Build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
kwargs = dict(
|
||||
dim=embed_dims[i_layer],
|
||||
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
|
||||
input_resolution=(
|
||||
patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
),
|
||||
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
||||
# patches_resolution[1] // (2 ** i_layer)),
|
||||
depth=depths[i_layer],
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
use_checkpoint=use_checkpoint,
|
||||
out_dim=embed_dims[min(i_layer + 1,
|
||||
len(embed_dims) - 1)],
|
||||
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
|
||||
activation=activation,
|
||||
)
|
||||
if i_layer == 0:
|
||||
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
|
||||
else:
|
||||
layer = BasicLayer(num_heads=num_heads[i_layer],
|
||||
window_size=window_sizes[i_layer],
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
drop=drop_rate,
|
||||
local_conv_size=local_conv_size,
|
||||
**kwargs)
|
||||
layer = BasicLayer(
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_sizes[i_layer],
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
drop=drop_rate,
|
||||
local_conv_size=local_conv_size,
|
||||
**kwargs,
|
||||
)
|
||||
self.layers.append(layer)
|
||||
|
||||
# Classifier head
|
||||
self.norm_head = nn.LayerNorm(embed_dims[-1])
|
||||
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
|
||||
|
||||
# init weights
|
||||
# Init weights
|
||||
self.apply(self._init_weights)
|
||||
self.set_layer_lr_decay(layer_lr_decay)
|
||||
self.neck = nn.Sequential(
|
||||
@ -518,13 +672,15 @@ class TinyViT(nn.Module):
|
||||
)
|
||||
|
||||
def set_layer_lr_decay(self, layer_lr_decay):
|
||||
"""Sets the learning rate decay for each layer in the TinyViT model."""
|
||||
decay_rate = layer_lr_decay
|
||||
|
||||
# layers -> blocks (depth)
|
||||
# Layers -> blocks (depth)
|
||||
depth = sum(self.depths)
|
||||
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
||||
|
||||
def _set_lr_scale(m, scale):
|
||||
"""Sets the learning rate scale for each layer in the model based on the layer's depth."""
|
||||
for p in m.parameters():
|
||||
p.lr_scale = scale
|
||||
|
||||
@ -544,12 +700,14 @@ class TinyViT(nn.Module):
|
||||
p.param_name = k
|
||||
|
||||
def _check_lr_scale(m):
|
||||
"""Checks if the learning rate scale attribute is present in module's parameters."""
|
||||
for p in m.parameters():
|
||||
assert hasattr(p, 'lr_scale'), p.param_name
|
||||
assert hasattr(p, "lr_scale"), p.param_name
|
||||
|
||||
self.apply(_check_lr_scale)
|
||||
|
||||
def _init_weights(self, m):
|
||||
"""Initializes weights for linear layers and layer normalization in the given module."""
|
||||
if isinstance(m, nn.Linear):
|
||||
# NOTE: This initialization is needed only for training.
|
||||
# trunc_normal_(m.weight, std=.02)
|
||||
@ -561,11 +719,12 @@ class TinyViT(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay_keywords(self):
|
||||
return {'attention_biases'}
|
||||
"""Returns a dictionary of parameter names where weight decay should not be applied."""
|
||||
return {"attention_biases"}
|
||||
|
||||
def forward_features(self, x):
|
||||
# x: (N, C, H, W)
|
||||
x = self.patch_embed(x)
|
||||
"""Runs the input through the model layers and returns the transformed output."""
|
||||
x = self.patch_embed(x) # x input is (N, C, H, W)
|
||||
|
||||
x = self.layers[0](x)
|
||||
start_i = 1
|
||||
@ -573,10 +732,11 @@ class TinyViT(nn.Module):
|
||||
for i in range(start_i, len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
x = layer(x)
|
||||
B, _, C = x.size()
|
||||
B, _, C = x.shape
|
||||
x = x.view(B, 64, 64, C)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
return self.neck(x)
|
||||
|
||||
def forward(self, x):
|
||||
"""Executes a forward pass on the input tensor through the constructed model layers."""
|
||||
return self.forward_features(x)
|
||||
|
@ -10,6 +10,21 @@ from ultralytics.nn.modules import MLPBlock
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
"""
|
||||
A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
|
||||
serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
|
||||
is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
|
||||
processing.
|
||||
|
||||
Attributes:
|
||||
depth (int): The number of layers in the transformer.
|
||||
embedding_dim (int): The channel dimension for the input embeddings.
|
||||
num_heads (int): The number of heads for multihead attention.
|
||||
mlp_dim (int): The internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
|
||||
final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
|
||||
norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -21,8 +36,7 @@ class TwoWayTransformer(nn.Module):
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer decoder that attends to an input image using
|
||||
queries whose positional embedding is supplied.
|
||||
A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
|
||||
|
||||
Args:
|
||||
depth (int): number of layers in the transformer
|
||||
@ -48,7 +62,8 @@ class TwoWayTransformer(nn.Module):
|
||||
activation=activation,
|
||||
attention_downsample_rate=attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
||||
@ -99,6 +114,23 @@ class TwoWayTransformer(nn.Module):
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
|
||||
keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
|
||||
of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
|
||||
sparse inputs.
|
||||
|
||||
Attributes:
|
||||
self_attn (Attention): The self-attention layer for the queries.
|
||||
norm1 (nn.LayerNorm): Layer normalization following the first attention block.
|
||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
|
||||
norm2 (nn.LayerNorm): Layer normalization following the second attention block.
|
||||
mlp (MLPBlock): MLP block that transforms the query embeddings.
|
||||
norm3 (nn.LayerNorm): Layer normalization following the MLP block.
|
||||
norm4 (nn.LayerNorm): Layer normalization following the third attention block.
|
||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
|
||||
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -171,8 +203,7 @@ class TwoWayAttentionBlock(nn.Module):
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
"""An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
values.
|
||||
"""
|
||||
|
||||
@ -182,24 +213,37 @@ class Attention(nn.Module):
|
||||
num_heads: int,
|
||||
downsample_rate: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the Attention model with the given dimensions and settings.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): The dimensionality of the input embeddings.
|
||||
num_heads (int): The number of attention heads.
|
||||
downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
|
||||
|
||||
Raises:
|
||||
AssertionError: If 'num_heads' does not evenly divide the internal dimension (embedding_dim / downsample_rate).
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
self.num_heads = num_heads
|
||||
assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
|
||||
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
|
||||
|
||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
||||
|
||||
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
||||
@staticmethod
|
||||
def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
|
||||
"""Separate the input tensor into the specified number of attention heads."""
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
def _recombine_heads(self, x: Tensor) -> Tensor:
|
||||
@staticmethod
|
||||
def _recombine_heads(x: Tensor) -> Tensor:
|
||||
"""Recombine the separated attention heads into a single tensor."""
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
|
Reference in New Issue
Block a user