add yolo v10 and modify pipeline
This commit is contained in:
@ -3,6 +3,4 @@
|
||||
from .model import SAM
|
||||
from .predict import Predictor
|
||||
|
||||
# from .build import build_sam
|
||||
|
||||
__all__ = 'SAM', 'Predictor' # tuple or list
|
||||
__all__ = "SAM", "Predictor" # tuple or list
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -8,10 +8,9 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def is_box_near_crop_edge(boxes: torch.Tensor,
|
||||
crop_box: List[int],
|
||||
orig_box: List[int],
|
||||
atol: float = 20.0) -> torch.Tensor:
|
||||
def is_box_near_crop_edge(
|
||||
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
|
||||
) -> torch.Tensor:
|
||||
"""Return a boolean tensor indicating if boxes are near the crop edge."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
@ -24,23 +23,25 @@ def is_box_near_crop_edge(boxes: torch.Tensor,
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
||||
"""Yield batches of data from the input arguments."""
|
||||
assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
|
||||
assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
|
||||
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
|
||||
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
|
||||
|
||||
|
||||
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
|
||||
"""
|
||||
Computes the stability score for a batch of masks. The stability
|
||||
score is the IoU between the binary masks obtained by thresholding
|
||||
the predicted mask logits at high and low values.
|
||||
Computes the stability score for a batch of masks.
|
||||
|
||||
The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
|
||||
and low values.
|
||||
|
||||
Notes:
|
||||
- One mask is always contained inside the other.
|
||||
- Save memory by preventing unnecessary cast to torch.int64
|
||||
"""
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecessary cast to torch.int64
|
||||
intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1,
|
||||
dtype=torch.int32))
|
||||
unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
|
||||
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
return intersections / unions
|
||||
|
||||
|
||||
@ -55,12 +56,17 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
|
||||
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
|
||||
"""Generate point grids for all crop layers."""
|
||||
return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
|
||||
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
|
||||
|
||||
|
||||
def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
|
||||
overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
|
||||
"""Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer."""
|
||||
def generate_crop_boxes(
|
||||
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
|
||||
) -> Tuple[List[List[int]], List[int]]:
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes.
|
||||
|
||||
Each layer has (2**i)**2 boxes for the ith layer.
|
||||
"""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_h, im_w = im_size
|
||||
short_side = min(im_h, im_w)
|
||||
@ -127,8 +133,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
|
||||
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in {'holes', 'islands'}
|
||||
correct_holes = mode == 'holes'
|
||||
assert mode in {"holes", "islands"}
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
@ -145,8 +151,9 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
|
||||
|
||||
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
|
||||
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
|
||||
Calculates boxes in XYXY format around masks.
|
||||
|
||||
Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
if torch.numel(masks) == 0:
|
||||
|
@ -11,7 +11,6 @@ from functools import partial
|
||||
import torch
|
||||
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
from .modules.decoders import MaskDecoder
|
||||
from .modules.encoders import ImageEncoderViT, PromptEncoder
|
||||
from .modules.sam import Sam
|
||||
@ -64,46 +63,47 @@ def build_mobile_sam(checkpoint=None):
|
||||
)
|
||||
|
||||
|
||||
def _build_sam(encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
encoder_global_attn_indexes,
|
||||
checkpoint=None,
|
||||
mobile_sam=False):
|
||||
def _build_sam(
|
||||
encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
|
||||
):
|
||||
"""Builds the selected SAM model architecture."""
|
||||
prompt_embed_dim = 256
|
||||
image_size = 1024
|
||||
vit_patch_size = 16
|
||||
image_embedding_size = image_size // vit_patch_size
|
||||
image_encoder = (TinyViT(
|
||||
img_size=1024,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dims=encoder_embed_dim,
|
||||
depths=encoder_depth,
|
||||
num_heads=encoder_num_heads,
|
||||
window_sizes=[7, 7, 14, 7],
|
||||
mlp_ratio=4.0,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
use_checkpoint=False,
|
||||
mbconv_expand_ratio=4.0,
|
||||
local_conv_size=3,
|
||||
layer_lr_decay=0.8,
|
||||
) if mobile_sam else ImageEncoderViT(
|
||||
depth=encoder_depth,
|
||||
embed_dim=encoder_embed_dim,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=encoder_num_heads,
|
||||
patch_size=vit_patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=encoder_global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=prompt_embed_dim,
|
||||
))
|
||||
image_encoder = (
|
||||
TinyViT(
|
||||
img_size=1024,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dims=encoder_embed_dim,
|
||||
depths=encoder_depth,
|
||||
num_heads=encoder_num_heads,
|
||||
window_sizes=[7, 7, 14, 7],
|
||||
mlp_ratio=4.0,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
use_checkpoint=False,
|
||||
mbconv_expand_ratio=4.0,
|
||||
local_conv_size=3,
|
||||
layer_lr_decay=0.8,
|
||||
)
|
||||
if mobile_sam
|
||||
else ImageEncoderViT(
|
||||
depth=encoder_depth,
|
||||
embed_dim=encoder_embed_dim,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=encoder_num_heads,
|
||||
patch_size=vit_patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=encoder_global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=prompt_embed_dim,
|
||||
)
|
||||
)
|
||||
sam = Sam(
|
||||
image_encoder=image_encoder,
|
||||
prompt_encoder=PromptEncoder(
|
||||
@ -129,7 +129,7 @@ def _build_sam(encoder_embed_dim,
|
||||
)
|
||||
if checkpoint is not None:
|
||||
checkpoint = attempt_download_asset(checkpoint)
|
||||
with open(checkpoint, 'rb') as f:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)
|
||||
sam.load_state_dict(state_dict)
|
||||
sam.eval()
|
||||
@ -139,20 +139,22 @@ def _build_sam(encoder_embed_dim,
|
||||
|
||||
|
||||
sam_model_map = {
|
||||
'sam_h.pt': build_sam_vit_h,
|
||||
'sam_l.pt': build_sam_vit_l,
|
||||
'sam_b.pt': build_sam_vit_b,
|
||||
'mobile_sam.pt': build_mobile_sam, }
|
||||
"sam_h.pt": build_sam_vit_h,
|
||||
"sam_l.pt": build_sam_vit_l,
|
||||
"sam_b.pt": build_sam_vit_b,
|
||||
"mobile_sam.pt": build_mobile_sam,
|
||||
}
|
||||
|
||||
|
||||
def build_sam(ckpt='sam_b.pt'):
|
||||
def build_sam(ckpt="sam_b.pt"):
|
||||
"""Build a SAM model specified by ckpt."""
|
||||
model_builder = None
|
||||
ckpt = str(ckpt) # to allow Path ckpt types
|
||||
for k in sam_model_map.keys():
|
||||
if ckpt.endswith(k):
|
||||
model_builder = sam_model_map.get(k)
|
||||
|
||||
if not model_builder:
|
||||
raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}')
|
||||
raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
|
||||
|
||||
return model_builder(ckpt)
|
||||
|
@ -1,51 +1,114 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
SAM model interface
|
||||
SAM model interface.
|
||||
|
||||
This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image
|
||||
segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,
|
||||
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
|
||||
image distributions and tasks without prior knowledge.
|
||||
|
||||
Key Features:
|
||||
- Promptable segmentation
|
||||
- Real-time performance
|
||||
- Zero-shot transfer capabilities
|
||||
- Trained on SA-1B dataset
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.utils.torch_utils import model_info
|
||||
|
||||
from .build import build_sam
|
||||
from .predict import Predictor
|
||||
|
||||
|
||||
class SAM(Model):
|
||||
"""
|
||||
SAM model interface.
|
||||
SAM (Segment Anything Model) interface class.
|
||||
|
||||
SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
|
||||
bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
|
||||
dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, model='sam_b.pt') -> None:
|
||||
if model and Path(model).suffix not in ('.pt', '.pth'):
|
||||
raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
|
||||
super().__init__(model=model, task='segment')
|
||||
def __init__(self, model="sam_b.pt") -> None:
|
||||
"""
|
||||
Initializes the SAM model with a pre-trained model file.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||
"""
|
||||
if model and Path(model).suffix not in (".pt", ".pth"):
|
||||
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
|
||||
super().__init__(model=model, task="segment")
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
"""
|
||||
Loads the specified weights into the SAM model.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the weights file.
|
||||
task (str, optional): Task name. Defaults to None.
|
||||
"""
|
||||
self.model = build_sam(weights)
|
||||
|
||||
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""Predicts and returns segmentation masks for given image or video source."""
|
||||
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
|
||||
"""
|
||||
Performs segmentation prediction on the given image or video source.
|
||||
|
||||
Args:
|
||||
source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
|
||||
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
|
||||
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
|
||||
points (list, optional): List of points for prompted segmentation. Defaults to None.
|
||||
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
(list): The model predictions.
|
||||
"""
|
||||
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
|
||||
kwargs.update(overrides)
|
||||
prompts = dict(bboxes=bboxes, points=points, labels=labels)
|
||||
return super().predict(source, stream, prompts=prompts, **kwargs)
|
||||
|
||||
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
"""
|
||||
Alias for the 'predict' method.
|
||||
|
||||
Args:
|
||||
source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
|
||||
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
|
||||
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
|
||||
points (list, optional): List of points for prompted segmentation. Defaults to None.
|
||||
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
(list): The model predictions.
|
||||
"""
|
||||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
Logs information about the SAM model.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
|
||||
verbose (bool, optional): If True, displays information on the console. Defaults to True.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the model's information.
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose)
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
return {'segment': {'predictor': Predictor}}
|
||||
"""
|
||||
Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
|
||||
"""
|
||||
return {"segment": {"predictor": Predictor}}
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -10,6 +10,21 @@ from ultralytics.nn.modules import LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
"""
|
||||
Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict
|
||||
masks given image and prompt embeddings.
|
||||
|
||||
Attributes:
|
||||
transformer_dim (int): Channel dimension for the transformer module.
|
||||
transformer (nn.Module): The transformer module used for mask prediction.
|
||||
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
|
||||
iou_token (nn.Embedding): Embedding for the IoU token.
|
||||
num_mask_tokens (int): Number of mask tokens.
|
||||
mask_tokens (nn.Embedding): Embedding for the mask tokens.
|
||||
output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
|
||||
output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
|
||||
iou_prediction_head (nn.Module): MLP for predicting mask quality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -49,8 +64,9 @@ class MaskDecoder(nn.Module):
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_hypernetworks_mlps = nn.ModuleList([
|
||||
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
|
||||
|
||||
@ -98,10 +114,14 @@ class MaskDecoder(nn.Module):
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts masks. See 'forward' for more details."""
|
||||
"""
|
||||
Predicts masks.
|
||||
|
||||
See 'forward' for more details.
|
||||
"""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
@ -113,13 +133,14 @@ class MaskDecoder(nn.Module):
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, 0, :]
|
||||
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
|
||||
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
hyper_in_list: List[torch.Tensor] = [
|
||||
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
|
||||
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
|
||||
]
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
@ -132,7 +153,7 @@ class MaskDecoder(nn.Module):
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""
|
||||
Lightly adapted from
|
||||
MLP (Multi-Layer Perceptron) model lightly adapted from
|
||||
https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
|
||||
"""
|
||||
|
||||
@ -144,6 +165,16 @@ class MLP(nn.Module):
|
||||
num_layers: int,
|
||||
sigmoid_output: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the MLP (Multi-Layer Perceptron) model.
|
||||
|
||||
Args:
|
||||
input_dim (int): The dimensionality of the input features.
|
||||
hidden_dim (int): The dimensionality of the hidden layers.
|
||||
output_dim (int): The dimensionality of the output layer.
|
||||
num_layers (int): The number of hidden layers.
|
||||
sigmoid_output (bool, optional): Apply a sigmoid activation to the output layer. Defaults to False.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
|
@ -10,27 +10,41 @@ import torch.nn.functional as F
|
||||
from ultralytics.nn.modules import LayerNorm2d, MLPBlock
|
||||
|
||||
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
class ImageEncoderViT(nn.Module):
|
||||
"""
|
||||
An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
|
||||
encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
|
||||
The encoded patches are then processed through a neck to generate the final encoded representation.
|
||||
|
||||
This class and its supporting functions below lightly adapted from the ViTDet backbone available at
|
||||
https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Dimension of input images, assumed to be square.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
|
||||
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
|
||||
neck (nn.Sequential): Neck module to further process the output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: Tuple[int, ...] = (),
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: Tuple[int, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -100,6 +114,9 @@ class ImageEncoderViT(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes input through patch embedding, applies positional embedding if present, and passes through blocks
|
||||
and neck.
|
||||
"""
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
@ -109,6 +126,22 @@ class ImageEncoderViT(nn.Module):
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
"""
|
||||
Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
|
||||
produces both sparse and dense embeddings for the input prompts.
|
||||
|
||||
Attributes:
|
||||
embed_dim (int): Dimension of the embeddings.
|
||||
input_image_size (Tuple[int, int]): Size of the input image as (H, W).
|
||||
image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
|
||||
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
|
||||
num_point_embeddings (int): Number of point embeddings for different types of points.
|
||||
point_embeddings (nn.ModuleList): List of point embeddings.
|
||||
not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
|
||||
mask_input_size (Tuple[int, int]): Size of the input mask.
|
||||
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
|
||||
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -157,20 +190,15 @@ class PromptEncoder(nn.Module):
|
||||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
|
||||
image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
pad: bool,
|
||||
) -> torch.Tensor:
|
||||
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
|
||||
"""Embeds point prompts."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
@ -204,9 +232,7 @@ class PromptEncoder(nn.Module):
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""
|
||||
Gets the batch size of the output given the batch size of the input prompts.
|
||||
"""
|
||||
"""Gets the batch size of the output given the batch size of the input prompts."""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
@ -217,6 +243,7 @@ class PromptEncoder(nn.Module):
|
||||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
"""Returns the device of the first point embedding's weight tensor."""
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
@ -251,23 +278,22 @@ class PromptEncoder(nn.Module):
|
||||
if masks is not None:
|
||||
dense_embeddings = self._embed_masks(masks)
|
||||
else:
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
|
||||
1).expand(bs, -1, self.image_embedding_size[0],
|
||||
self.image_embedding_size[1])
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
||||
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
||||
)
|
||||
|
||||
return sparse_embeddings, dense_embeddings
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
Positional encoding using random spatial frequencies.
|
||||
"""
|
||||
"""Positional encoding using random spatial frequencies."""
|
||||
|
||||
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
||||
"""Initializes a position embedding using random spatial frequencies."""
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
|
||||
self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
|
||||
|
||||
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
|
||||
torch.use_deterministic_algorithms(False)
|
||||
@ -275,11 +301,11 @@ class PositionEmbeddingRandom(nn.Module):
|
||||
|
||||
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
"""Positionally encode points that are normalized to [0,1]."""
|
||||
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
||||
# Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
||||
coords = 2 * coords - 1
|
||||
coords = coords @ self.positional_encoding_gaussian_matrix
|
||||
coords = 2 * np.pi * coords
|
||||
# outputs d_1 x ... x d_n x C shape
|
||||
# Outputs d_1 x ... x d_n x C shape
|
||||
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
||||
|
||||
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||
@ -304,7 +330,7 @@ class PositionEmbeddingRandom(nn.Module):
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -351,6 +377,7 @@ class Block(nn.Module):
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
@ -380,6 +407,8 @@ class Attention(nn.Module):
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Attention module.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
@ -391,19 +420,20 @@ class Attention(nn.Module):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.use_rel_pos = use_rel_pos
|
||||
if self.use_rel_pos:
|
||||
assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
|
||||
# initialize relative positional embeddings
|
||||
assert input_size is not None, "Input size must be provided if using relative positional encoding."
|
||||
# Initialize relative positional embeddings
|
||||
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
||||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@ -444,10 +474,12 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
|
||||
hw: Tuple[int, int]) -> torch.Tensor:
|
||||
def window_unpartition(
|
||||
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Window unpartition into original sequences and removing padding.
|
||||
|
||||
Args:
|
||||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
||||
window_size (int): window size.
|
||||
@ -470,8 +502,8 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[in
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get relative positional embeddings according to the relative positions of
|
||||
query and key sizes.
|
||||
Get relative positional embeddings according to the relative positions of query and key sizes.
|
||||
|
||||
Args:
|
||||
q_size (int): size of query q.
|
||||
k_size (int): size of key k.
|
||||
@ -487,7 +519,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode='linear',
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
@ -510,8 +542,9 @@ def add_decomposed_rel_pos(
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
||||
Calculate decomposed Relative Positional Embeddings from mvitv2 paper at
|
||||
https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py.
|
||||
|
||||
Args:
|
||||
attn (Tensor): attention map.
|
||||
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
||||
@ -530,29 +563,30 @@ def add_decomposed_rel_pos(
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
|
||||
rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
|
||||
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
|
||||
B, q_h * q_w, k_h * k_w)
|
||||
B, q_h * q_w, k_h * k_w
|
||||
)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
"""Image to Patch Embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize PatchEmbed module.
|
||||
|
||||
Args:
|
||||
kernel_size (Tuple): kernel size of the projection layer.
|
||||
stride (Tuple): stride of the projection layer.
|
||||
@ -565,4 +599,5 @@ class PatchEmbed(nn.Module):
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes patch embedding by applying convolution and transposing resulting tensor."""
|
||||
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|
||||
|
@ -16,8 +16,23 @@ from .encoders import ImageEncoderViT, PromptEncoder
|
||||
|
||||
|
||||
class Sam(nn.Module):
|
||||
"""
|
||||
Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
|
||||
embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
|
||||
decoder to predict object masks.
|
||||
|
||||
Attributes:
|
||||
mask_threshold (float): Threshold value for mask prediction.
|
||||
image_format (str): Format of the input image, default is 'RGB'.
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
|
||||
pixel_mean (List[float]): Mean pixel values for image normalization.
|
||||
pixel_std (List[float]): Standard deviation values for image normalization.
|
||||
"""
|
||||
|
||||
mask_threshold: float = 0.0
|
||||
image_format: str = 'RGB'
|
||||
image_format: str = "RGB"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -25,25 +40,26 @@ class Sam(nn.Module):
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: List[float] = (123.675, 116.28, 103.53),
|
||||
pixel_std: List[float] = (58.395, 57.12, 57.375)
|
||||
pixel_std: List[float] = (58.395, 57.12, 57.375),
|
||||
) -> None:
|
||||
"""
|
||||
SAM predicts object masks from an image and input prompts.
|
||||
Initialize the Sam class to predict object masks from an image and input prompts.
|
||||
|
||||
Note:
|
||||
All forward() operations moved to SAMPredictor.
|
||||
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
|
||||
efficient mask prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
||||
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to
|
||||
(123.675, 116.28, 103.53).
|
||||
pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to
|
||||
(58.395, 57.12, 57.375).
|
||||
"""
|
||||
super().__init__()
|
||||
self.image_encoder = image_encoder
|
||||
self.prompt_encoder = prompt_encoder
|
||||
self.mask_decoder = mask_decoder
|
||||
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||
|
@ -21,19 +21,27 @@ from ultralytics.utils.instance import to_2tuple
|
||||
|
||||
|
||||
class Conv2d_BN(torch.nn.Sequential):
|
||||
"""A sequential container that performs 2D convolution followed by batch normalization."""
|
||||
|
||||
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||
"""Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
|
||||
drop path.
|
||||
"""
|
||||
super().__init__()
|
||||
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
bn = torch.nn.BatchNorm2d(b)
|
||||
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
||||
torch.nn.init.constant_(bn.bias, 0)
|
||||
self.add_module('bn', bn)
|
||||
self.add_module("bn", bn)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Embeds images into patches and projects them into a specified embedding dimension."""
|
||||
|
||||
def __init__(self, in_chans, embed_dim, resolution, activation):
|
||||
"""Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
|
||||
function.
|
||||
"""
|
||||
super().__init__()
|
||||
img_size: Tuple[int, int] = to_2tuple(resolution)
|
||||
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
||||
@ -48,12 +56,17 @@ class PatchEmbed(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
|
||||
return self.seq(x)
|
||||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
"""Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
|
||||
|
||||
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
|
||||
"""Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
|
||||
function.
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.hidden_chans = int(in_chans * expand_ratio)
|
||||
@ -73,6 +86,7 @@ class MBConv(nn.Module):
|
||||
self.drop_path = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""Implements the forward pass for the model architecture."""
|
||||
shortcut = x
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
@ -85,8 +99,12 @@ class MBConv(nn.Module):
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
"""Merges neighboring patches in the feature map and projects to a new dimension."""
|
||||
|
||||
def __init__(self, input_resolution, dim, out_dim, activation):
|
||||
"""Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
|
||||
optional parameters.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_resolution = input_resolution
|
||||
@ -99,6 +117,7 @@ class PatchMerging(nn.Module):
|
||||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
|
||||
if x.ndim == 3:
|
||||
H, W = self.input_resolution
|
||||
B = len(x)
|
||||
@ -115,6 +134,11 @@ class PatchMerging(nn.Module):
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
"""
|
||||
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
|
||||
|
||||
Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -122,41 +146,69 @@ class ConvLayer(nn.Module):
|
||||
input_resolution,
|
||||
depth,
|
||||
activation,
|
||||
drop_path=0.,
|
||||
drop_path=0.0,
|
||||
downsample=None,
|
||||
use_checkpoint=False,
|
||||
out_dim=None,
|
||||
conv_expand_ratio=4.,
|
||||
conv_expand_ratio=4.0,
|
||||
):
|
||||
"""
|
||||
Initializes the ConvLayer with the given dimensions and settings.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input image.
|
||||
depth (int): The number of MBConv layers in the block.
|
||||
activation (Callable): Activation function applied after each convolution.
|
||||
drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
|
||||
downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
|
||||
conv_expand_ratio (float): Expansion ratio for the MBConv layers.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
MBConv(
|
||||
dim,
|
||||
dim,
|
||||
conv_expand_ratio,
|
||||
activation,
|
||||
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
) for i in range(depth)])
|
||||
# Build blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
MBConv(
|
||||
dim,
|
||||
dim,
|
||||
conv_expand_ratio,
|
||||
activation,
|
||||
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# patch merging layer
|
||||
self.downsample = None if downsample is None else downsample(
|
||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
# Patch merging layer
|
||||
self.downsample = (
|
||||
None
|
||||
if downsample is None
|
||||
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Processes the input through a series of convolutional layers and returns the activated output."""
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""
|
||||
Multi-layer Perceptron (MLP) for transformer architectures.
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
@ -167,6 +219,7 @@ class Mlp(nn.Module):
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies operations on input x and returns modified x, runs downsample if not None."""
|
||||
x = self.norm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
@ -176,20 +229,41 @@ class Mlp(nn.Module):
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
"""
|
||||
Multi-head attention module with support for spatial awareness, applying attention biases based on spatial
|
||||
resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution
|
||||
grid.
|
||||
|
||||
Attributes:
|
||||
ab (Tensor, optional): Cached attention biases for inference, deleted during training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=(14, 14),
|
||||
self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=(14, 14),
|
||||
):
|
||||
"""
|
||||
Initializes the Attention module.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
key_dim (int): The dimensionality of the keys and queries.
|
||||
num_heads (int, optional): Number of attention heads. Default is 8.
|
||||
attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
|
||||
resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).
|
||||
|
||||
Raises:
|
||||
AssertionError: If `resolution` is not a tuple of length 2.
|
||||
"""
|
||||
super().__init__()
|
||||
# (h, w)
|
||||
|
||||
assert isinstance(resolution, tuple) and len(resolution) == 2
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim ** -0.5
|
||||
self.scale = key_dim**-0.5
|
||||
self.key_dim = key_dim
|
||||
self.nh_kd = nh_kd = key_dim * num_heads
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
@ -212,18 +286,20 @@ class Attention(torch.nn.Module):
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
|
||||
self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
"""Sets the module in training mode and handles attribute 'ab' based on the mode."""
|
||||
super().train(mode)
|
||||
if mode and hasattr(self, 'ab'):
|
||||
if mode and hasattr(self, "ab"):
|
||||
del self.ab
|
||||
else:
|
||||
self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
|
||||
def forward(self, x): # x (B,N,C)
|
||||
B, N, _ = x.shape
|
||||
def forward(self, x): # x
|
||||
"""Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values."""
|
||||
B, N, _ = x.shape # B, N, C
|
||||
|
||||
# Normalization
|
||||
x = self.norm(x)
|
||||
@ -237,28 +313,16 @@ class Attention(torch.nn.Module):
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
self.ab = self.ab.to(self.attention_biases.device)
|
||||
|
||||
attn = ((q @ k.transpose(-2, -1)) * self.scale +
|
||||
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale + (
|
||||
self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
|
||||
)
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class TinyViTBlock(nn.Module):
|
||||
"""
|
||||
TinyViT Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int, int]): Input resolution.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
"""
|
||||
"""TinyViT Block that applies self-attention and a local convolution to the input."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -266,17 +330,35 @@ class TinyViTBlock(nn.Module):
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
mlp_ratio=4.0,
|
||||
drop=0.0,
|
||||
drop_path=0.0,
|
||||
local_conv_size=3,
|
||||
activation=nn.GELU,
|
||||
):
|
||||
"""
|
||||
Initializes the TinyViTBlock.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int, optional): Window size for attention. Default is 7.
|
||||
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
|
||||
drop (float, optional): Dropout rate. Default is 0.
|
||||
drop_path (float, optional): Stochastic depth rate. Default is 0.
|
||||
local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
|
||||
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `window_size` is not greater than 0.
|
||||
AssertionError: If `dim` is not divisible by `num_heads`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.num_heads = num_heads
|
||||
assert window_size > 0, 'window_size must be greater than 0'
|
||||
assert window_size > 0, "window_size must be greater than 0"
|
||||
self.window_size = window_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
@ -284,7 +366,7 @@ class TinyViTBlock(nn.Module):
|
||||
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.drop_path = nn.Identity()
|
||||
|
||||
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
|
||||
assert dim % num_heads == 0, "dim must be divisible by num_heads"
|
||||
head_dim = dim // num_heads
|
||||
|
||||
window_resolution = (window_size, window_size)
|
||||
@ -298,9 +380,12 @@ class TinyViTBlock(nn.Module):
|
||||
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies attention-based transformation or padding to input 'x' before passing it through a local
|
||||
convolution.
|
||||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
res_x = x
|
||||
if H == self.window_size and W == self.window_size:
|
||||
x = self.attn(x)
|
||||
@ -316,11 +401,14 @@ class TinyViTBlock(nn.Module):
|
||||
pH, pW = H + pad_b, W + pad_r
|
||||
nH = pH // self.window_size
|
||||
nW = pW // self.window_size
|
||||
# window partition
|
||||
x = x.view(B, nH, self.window_size, nW, self.window_size,
|
||||
C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
|
||||
# Window partition
|
||||
x = (
|
||||
x.view(B, nH, self.window_size, nW, self.window_size, C)
|
||||
.transpose(2, 3)
|
||||
.reshape(B * nH * nW, self.window_size * self.window_size, C)
|
||||
)
|
||||
x = self.attn(x)
|
||||
# window reverse
|
||||
# Window reverse
|
||||
x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
|
||||
|
||||
if padding:
|
||||
@ -337,29 +425,17 @@ class TinyViTBlock(nn.Module):
|
||||
return x + self.drop_path(self.mlp(x))
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
|
||||
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
|
||||
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
|
||||
attentions heads, window size, and MLP ratio.
|
||||
"""
|
||||
return (
|
||||
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
||||
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
|
||||
)
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
"""
|
||||
A basic TinyViT layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
out_dim (int | optional): the output dimension of the layer. Default: None
|
||||
"""
|
||||
"""A basic TinyViT layer for one stage in a TinyViT architecture."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -368,57 +444,90 @@ class BasicLayer(nn.Module):
|
||||
depth,
|
||||
num_heads,
|
||||
window_size,
|
||||
mlp_ratio=4.,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
mlp_ratio=4.0,
|
||||
drop=0.0,
|
||||
drop_path=0.0,
|
||||
downsample=None,
|
||||
use_checkpoint=False,
|
||||
local_conv_size=3,
|
||||
activation=nn.GELU,
|
||||
out_dim=None,
|
||||
):
|
||||
"""
|
||||
Initializes the BasicLayer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
depth (int): Number of TinyViT blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
|
||||
drop (float, optional): Dropout rate. Default is 0.
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
|
||||
local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
|
||||
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
|
||||
out_dim (int | None, optional): The output dimension of the layer. Default is None.
|
||||
|
||||
Raises:
|
||||
ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
TinyViTBlock(
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
local_conv_size=local_conv_size,
|
||||
activation=activation,
|
||||
) for i in range(depth)])
|
||||
# Build blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
TinyViTBlock(
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
local_conv_size=local_conv_size,
|
||||
activation=activation,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# patch merging layer
|
||||
self.downsample = None if downsample is None else downsample(
|
||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
# Patch merging layer
|
||||
self.downsample = (
|
||||
None
|
||||
if downsample is None
|
||||
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
|
||||
"""Returns a string representation of the extra_repr function with the layer's parameters."""
|
||||
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
"""A PyTorch implementation of Layer Normalization in 2D."""
|
||||
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
"""Initialize LayerNorm2d with the number of channels and an optional epsilon."""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Perform a forward pass, normalizing the input tensor."""
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
@ -426,6 +535,30 @@ class LayerNorm2d(nn.Module):
|
||||
|
||||
|
||||
class TinyViT(nn.Module):
|
||||
"""
|
||||
The TinyViT architecture for vision tasks.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Input image size.
|
||||
in_chans (int): Number of input channels.
|
||||
num_classes (int): Number of classification classes.
|
||||
embed_dims (List[int]): List of embedding dimensions for each layer.
|
||||
depths (List[int]): List of depths for each layer.
|
||||
num_heads (List[int]): List of number of attention heads for each layer.
|
||||
window_sizes (List[int]): List of window sizes for each layer.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop_rate (float): Dropout rate for drop layers.
|
||||
drop_path_rate (float): Drop path rate for stochastic depth.
|
||||
use_checkpoint (bool): Use checkpointing for efficient memory usage.
|
||||
mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
|
||||
local_conv_size (int): Local convolution kernel size.
|
||||
layer_lr_decay (float): Layer-wise learning rate decay.
|
||||
|
||||
Note:
|
||||
This implementation is generalized to accept a list of depths, attention heads,
|
||||
embedding dimensions and window sizes, which allows you to create a
|
||||
"stack" of TinyViT models of varying configurations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -436,14 +569,33 @@ class TinyViT(nn.Module):
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_sizes=[7, 7, 14, 7],
|
||||
mlp_ratio=4.,
|
||||
drop_rate=0.,
|
||||
mlp_ratio=4.0,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
use_checkpoint=False,
|
||||
mbconv_expand_ratio=4.0,
|
||||
local_conv_size=3,
|
||||
layer_lr_decay=1.0,
|
||||
):
|
||||
"""
|
||||
Initializes the TinyViT model.
|
||||
|
||||
Args:
|
||||
img_size (int, optional): The input image size. Defaults to 224.
|
||||
in_chans (int, optional): Number of input channels. Defaults to 3.
|
||||
num_classes (int, optional): Number of classification classes. Defaults to 1000.
|
||||
embed_dims (List[int], optional): List of embedding dimensions for each layer. Defaults to [96, 192, 384, 768].
|
||||
depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2].
|
||||
num_heads (List[int], optional): List of number of attention heads for each layer. Defaults to [3, 6, 12, 24].
|
||||
window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7].
|
||||
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4.
|
||||
drop_rate (float, optional): Dropout rate. Defaults to 0.
|
||||
drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False.
|
||||
mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0.
|
||||
local_conv_size (int, optional): Local convolution kernel size. Defaults to 3.
|
||||
layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0.
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.num_classes = num_classes
|
||||
@ -453,50 +605,52 @@ class TinyViT(nn.Module):
|
||||
|
||||
activation = nn.GELU
|
||||
|
||||
self.patch_embed = PatchEmbed(in_chans=in_chans,
|
||||
embed_dim=embed_dims[0],
|
||||
resolution=img_size,
|
||||
activation=activation)
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
|
||||
)
|
||||
|
||||
patches_resolution = self.patch_embed.patches_resolution
|
||||
self.patches_resolution = patches_resolution
|
||||
|
||||
# stochastic depth
|
||||
# Stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
# Build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
kwargs = dict(
|
||||
dim=embed_dims[i_layer],
|
||||
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
|
||||
input_resolution=(
|
||||
patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
),
|
||||
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
||||
# patches_resolution[1] // (2 ** i_layer)),
|
||||
depth=depths[i_layer],
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
use_checkpoint=use_checkpoint,
|
||||
out_dim=embed_dims[min(i_layer + 1,
|
||||
len(embed_dims) - 1)],
|
||||
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
|
||||
activation=activation,
|
||||
)
|
||||
if i_layer == 0:
|
||||
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
|
||||
else:
|
||||
layer = BasicLayer(num_heads=num_heads[i_layer],
|
||||
window_size=window_sizes[i_layer],
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
drop=drop_rate,
|
||||
local_conv_size=local_conv_size,
|
||||
**kwargs)
|
||||
layer = BasicLayer(
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_sizes[i_layer],
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
drop=drop_rate,
|
||||
local_conv_size=local_conv_size,
|
||||
**kwargs,
|
||||
)
|
||||
self.layers.append(layer)
|
||||
|
||||
# Classifier head
|
||||
self.norm_head = nn.LayerNorm(embed_dims[-1])
|
||||
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
|
||||
|
||||
# init weights
|
||||
# Init weights
|
||||
self.apply(self._init_weights)
|
||||
self.set_layer_lr_decay(layer_lr_decay)
|
||||
self.neck = nn.Sequential(
|
||||
@ -518,13 +672,15 @@ class TinyViT(nn.Module):
|
||||
)
|
||||
|
||||
def set_layer_lr_decay(self, layer_lr_decay):
|
||||
"""Sets the learning rate decay for each layer in the TinyViT model."""
|
||||
decay_rate = layer_lr_decay
|
||||
|
||||
# layers -> blocks (depth)
|
||||
# Layers -> blocks (depth)
|
||||
depth = sum(self.depths)
|
||||
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
||||
|
||||
def _set_lr_scale(m, scale):
|
||||
"""Sets the learning rate scale for each layer in the model based on the layer's depth."""
|
||||
for p in m.parameters():
|
||||
p.lr_scale = scale
|
||||
|
||||
@ -544,12 +700,14 @@ class TinyViT(nn.Module):
|
||||
p.param_name = k
|
||||
|
||||
def _check_lr_scale(m):
|
||||
"""Checks if the learning rate scale attribute is present in module's parameters."""
|
||||
for p in m.parameters():
|
||||
assert hasattr(p, 'lr_scale'), p.param_name
|
||||
assert hasattr(p, "lr_scale"), p.param_name
|
||||
|
||||
self.apply(_check_lr_scale)
|
||||
|
||||
def _init_weights(self, m):
|
||||
"""Initializes weights for linear layers and layer normalization in the given module."""
|
||||
if isinstance(m, nn.Linear):
|
||||
# NOTE: This initialization is needed only for training.
|
||||
# trunc_normal_(m.weight, std=.02)
|
||||
@ -561,11 +719,12 @@ class TinyViT(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay_keywords(self):
|
||||
return {'attention_biases'}
|
||||
"""Returns a dictionary of parameter names where weight decay should not be applied."""
|
||||
return {"attention_biases"}
|
||||
|
||||
def forward_features(self, x):
|
||||
# x: (N, C, H, W)
|
||||
x = self.patch_embed(x)
|
||||
"""Runs the input through the model layers and returns the transformed output."""
|
||||
x = self.patch_embed(x) # x input is (N, C, H, W)
|
||||
|
||||
x = self.layers[0](x)
|
||||
start_i = 1
|
||||
@ -573,10 +732,11 @@ class TinyViT(nn.Module):
|
||||
for i in range(start_i, len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
x = layer(x)
|
||||
B, _, C = x.size()
|
||||
B, _, C = x.shape
|
||||
x = x.view(B, 64, 64, C)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
return self.neck(x)
|
||||
|
||||
def forward(self, x):
|
||||
"""Executes a forward pass on the input tensor through the constructed model layers."""
|
||||
return self.forward_features(x)
|
||||
|
@ -10,6 +10,21 @@ from ultralytics.nn.modules import MLPBlock
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
"""
|
||||
A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
|
||||
serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
|
||||
is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
|
||||
processing.
|
||||
|
||||
Attributes:
|
||||
depth (int): The number of layers in the transformer.
|
||||
embedding_dim (int): The channel dimension for the input embeddings.
|
||||
num_heads (int): The number of heads for multihead attention.
|
||||
mlp_dim (int): The internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
|
||||
final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
|
||||
norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -21,8 +36,7 @@ class TwoWayTransformer(nn.Module):
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer decoder that attends to an input image using
|
||||
queries whose positional embedding is supplied.
|
||||
A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
|
||||
|
||||
Args:
|
||||
depth (int): number of layers in the transformer
|
||||
@ -48,7 +62,8 @@ class TwoWayTransformer(nn.Module):
|
||||
activation=activation,
|
||||
attention_downsample_rate=attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
||||
@ -99,6 +114,23 @@ class TwoWayTransformer(nn.Module):
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
|
||||
keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
|
||||
of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
|
||||
sparse inputs.
|
||||
|
||||
Attributes:
|
||||
self_attn (Attention): The self-attention layer for the queries.
|
||||
norm1 (nn.LayerNorm): Layer normalization following the first attention block.
|
||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
|
||||
norm2 (nn.LayerNorm): Layer normalization following the second attention block.
|
||||
mlp (MLPBlock): MLP block that transforms the query embeddings.
|
||||
norm3 (nn.LayerNorm): Layer normalization following the MLP block.
|
||||
norm4 (nn.LayerNorm): Layer normalization following the third attention block.
|
||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
|
||||
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -171,8 +203,7 @@ class TwoWayAttentionBlock(nn.Module):
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
"""An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
values.
|
||||
"""
|
||||
|
||||
@ -182,24 +213,37 @@ class Attention(nn.Module):
|
||||
num_heads: int,
|
||||
downsample_rate: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the Attention model with the given dimensions and settings.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): The dimensionality of the input embeddings.
|
||||
num_heads (int): The number of attention heads.
|
||||
downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
|
||||
|
||||
Raises:
|
||||
AssertionError: If 'num_heads' does not evenly divide the internal dimension (embedding_dim / downsample_rate).
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
self.num_heads = num_heads
|
||||
assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
|
||||
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
|
||||
|
||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
||||
|
||||
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
||||
@staticmethod
|
||||
def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
|
||||
"""Separate the input tensor into the specified number of attention heads."""
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
def _recombine_heads(self, x: Tensor) -> Tensor:
|
||||
@staticmethod
|
||||
def _recombine_heads(x: Tensor) -> Tensor:
|
||||
"""Recombine the separated attention heads into a single tensor."""
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
|
@ -1,4 +1,12 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Generate predictions using the Segment Anything Model (SAM).
|
||||
|
||||
SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance.
|
||||
This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
|
||||
using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image
|
||||
segmentation tasks.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -10,129 +18,155 @@ from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.utils.torch_utils import select_device
|
||||
|
||||
from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
|
||||
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
|
||||
from .amg import (
|
||||
batch_iterator,
|
||||
batched_mask_to_box,
|
||||
build_all_layer_point_grids,
|
||||
calculate_stability_score,
|
||||
generate_crop_boxes,
|
||||
is_box_near_crop_edge,
|
||||
remove_small_regions,
|
||||
uncrop_boxes_xyxy,
|
||||
uncrop_masks,
|
||||
)
|
||||
from .build import build_sam
|
||||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
"""
|
||||
Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
|
||||
|
||||
The class provides an interface for model inference tailored to image segmentation tasks.
|
||||
With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
|
||||
mask generation. The class is capable of working with various types of prompts such as bounding boxes,
|
||||
points, and low-resolution masks.
|
||||
|
||||
Attributes:
|
||||
cfg (dict): Configuration dictionary specifying model and task-related parameters.
|
||||
overrides (dict): Dictionary containing values that override the default configuration.
|
||||
_callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
|
||||
args (namespace): Namespace to hold command-line arguments or other operational variables.
|
||||
im (torch.Tensor): Preprocessed input image tensor.
|
||||
features (torch.Tensor): Extracted image features used for inference.
|
||||
prompts (dict): Collection of various prompt types, such as bounding boxes and points.
|
||||
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the Predictor with configuration, overrides, and callbacks.
|
||||
|
||||
The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
|
||||
initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
|
||||
|
||||
Args:
|
||||
cfg (dict): Configuration dictionary.
|
||||
overrides (dict, optional): Dictionary of values to override default configuration.
|
||||
_callbacks (dict, optional): Dictionary of callback functions to customize behavior.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
|
||||
overrides.update(dict(task="segment", mode="predict", imgsz=1024))
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
# SAM needs retina_masks=True, or the results would be a mess.
|
||||
self.args.retina_masks = True
|
||||
# Args for set_image
|
||||
self.im = None
|
||||
self.features = None
|
||||
# Args for set_prompts
|
||||
self.prompts = {}
|
||||
# Args for segment everything
|
||||
self.segment_all = False
|
||||
|
||||
def preprocess(self, im):
|
||||
"""Prepares input image before inference.
|
||||
"""
|
||||
Preprocess the input image for model inference.
|
||||
|
||||
The method prepares the input image by applying transformations and normalization.
|
||||
It supports both torch.Tensor and list of np.ndarray as input formats.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
|
||||
im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The preprocessed image tensor.
|
||||
"""
|
||||
if self.im is not None:
|
||||
return self.im
|
||||
not_tensor = not isinstance(im, torch.Tensor)
|
||||
if not_tensor:
|
||||
im = np.stack(self.pre_transform(im))
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2))
|
||||
im = np.ascontiguousarray(im)
|
||||
im = torch.from_numpy(im)
|
||||
|
||||
im = im.to(self.device)
|
||||
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
|
||||
im = im.half() if self.model.fp16 else im.float()
|
||||
if not_tensor:
|
||||
im = (im - self.mean) / self.std
|
||||
return im
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
Pre-transform input image before inference.
|
||||
Perform initial transformations on the input image for preprocessing.
|
||||
|
||||
The method applies transformations such as resizing to prepare the image for further preprocessing.
|
||||
Currently, batched inference is not supported; hence the list length should be 1.
|
||||
|
||||
Args:
|
||||
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
im (List[np.ndarray]): List containing images in HWC numpy array format.
|
||||
|
||||
Returns:
|
||||
(list): A list of transformed images.
|
||||
(List[np.ndarray]): List of transformed images.
|
||||
"""
|
||||
assert len(im) == 1, 'SAM model does not currently support batched inference'
|
||||
assert len(im) == 1, "SAM model does not currently support batched inference"
|
||||
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
|
||||
return [letterbox(image=x) for x in im]
|
||||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
Perform image segmentation inference based on the given input cues, using the currently loaded image. This
|
||||
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
|
||||
mask decoder for real-time and promptable segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
|
||||
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
|
||||
1 indicates a foreground point and 0 indicates a background point.
|
||||
masks (np.ndarray, None): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form (N, H, W), where
|
||||
for SAM, H=W=256.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
|
||||
masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
|
||||
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in CxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(np.ndarray): An array of length C containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(np.ndarray): An array of shape CxHxW, where C is the number
|
||||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
(tuple): Contains the following three elements.
|
||||
- np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
|
||||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
|
||||
"""
|
||||
# Get prompts from self.prompts first
|
||||
bboxes = self.prompts.pop('bboxes', bboxes)
|
||||
points = self.prompts.pop('points', points)
|
||||
masks = self.prompts.pop('masks', masks)
|
||||
# Override prompts if any stored in self.prompts
|
||||
bboxes = self.prompts.pop("bboxes", bboxes)
|
||||
points = self.prompts.pop("points", points)
|
||||
masks = self.prompts.pop("masks", masks)
|
||||
|
||||
if all(i is None for i in [bboxes, points, masks]):
|
||||
return self.generate(im, *args, **kwargs)
|
||||
|
||||
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
|
||||
|
||||
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
|
||||
Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
|
||||
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
|
||||
1 indicates a foreground point and 0 indicates a background point.
|
||||
masks (np.ndarray, None): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form (N, H, W), where
|
||||
for SAM, H=W=256.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
|
||||
masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
|
||||
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in CxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(np.ndarray): An array of length C containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(np.ndarray): An array of shape CxHxW, where C is the number
|
||||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
(tuple): Contains the following three elements.
|
||||
- np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
|
||||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
|
||||
"""
|
||||
features = self.model.image_encoder(im) if self.features is None else self.features
|
||||
|
||||
@ -158,11 +192,7 @@ class Predictor(BasePredictor):
|
||||
|
||||
points = (points, labels) if points is not None else None
|
||||
# Embed prompts
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
||||
points=points,
|
||||
boxes=bboxes,
|
||||
masks=masks,
|
||||
)
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
|
||||
|
||||
# Predict masks
|
||||
pred_masks, pred_scores = self.model.mask_decoder(
|
||||
@ -177,58 +207,50 @@ class Predictor(BasePredictor):
|
||||
# `d` could be 1 or 3 depends on `multimask_output`.
|
||||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
||||
|
||||
def generate(self,
|
||||
im,
|
||||
crop_n_layers=0,
|
||||
crop_overlap_ratio=512 / 1500,
|
||||
crop_downscale_factor=1,
|
||||
point_grids=None,
|
||||
points_stride=32,
|
||||
points_batch_size=64,
|
||||
conf_thres=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
stability_score_offset=0.95,
|
||||
crop_nms_thresh=0.7):
|
||||
"""Segment the whole image.
|
||||
def generate(
|
||||
self,
|
||||
im,
|
||||
crop_n_layers=0,
|
||||
crop_overlap_ratio=512 / 1500,
|
||||
crop_downscale_factor=1,
|
||||
point_grids=None,
|
||||
points_stride=32,
|
||||
points_batch_size=64,
|
||||
conf_thres=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
stability_score_offset=0.95,
|
||||
crop_nms_thresh=0.7,
|
||||
):
|
||||
"""
|
||||
Perform image segmentation using the Segment Anything Model (SAM).
|
||||
|
||||
This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
|
||||
and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
crop_n_layers (int): If >0, mask prediction will be run again on
|
||||
crops of the image. Sets the number of layers to run, where each
|
||||
layer has 2**i_layer number of image crops.
|
||||
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
||||
In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
crop_downscale_factor (int): The number of points-per-side
|
||||
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
point_grids (list(np.ndarray), None): A list over explicit grids
|
||||
of points used for sampling, normalized to [0,1]. The nth grid in the
|
||||
list is used in the nth crop layer. Exclusive with points_per_side.
|
||||
points_stride (int, None): The number of points to be sampled
|
||||
along one side of the image. The total number of points is
|
||||
points_per_side**2. If None, 'point_grids' must provide explicit
|
||||
point sampling.
|
||||
points_batch_size (int): Sets the number of points run simultaneously
|
||||
by the model. Higher numbers may be faster but use more GPU memory.
|
||||
conf_thres (float): A filtering threshold in [0,1], using the
|
||||
model's predicted mask quality.
|
||||
stability_score_thresh (float): A filtering threshold in [0,1], using
|
||||
the stability of the mask under changes to the cutoff used to binarize
|
||||
the model's mask predictions.
|
||||
stability_score_offset (float): The amount to shift the cutoff when
|
||||
calculated the stability score.
|
||||
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks between different crops.
|
||||
im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
|
||||
crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
|
||||
Each layer produces 2**i_layer number of image crops.
|
||||
crop_overlap_ratio (float): Determines the extent of overlap between crops. Scaled down in subsequent layers.
|
||||
crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
|
||||
point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
|
||||
Used in the nth crop layer.
|
||||
points_stride (int, optional): Number of points to sample along each side of the image.
|
||||
Exclusive with 'point_grids'.
|
||||
points_batch_size (int): Batch size for the number of points processed simultaneously.
|
||||
conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
|
||||
stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
|
||||
stability_score_offset (float): Offset value for calculating stability score.
|
||||
crop_nms_thresh (float): IoU cutoff for Non-Maximum Suppression (NMS) to remove duplicate masks between crops.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
|
||||
"""
|
||||
self.segment_all = True
|
||||
ih, iw = im.shape[2:]
|
||||
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
|
||||
if point_grids is None:
|
||||
point_grids = build_all_layer_point_grids(
|
||||
points_stride,
|
||||
crop_n_layers,
|
||||
crop_downscale_factor,
|
||||
)
|
||||
point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)
|
||||
pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
|
||||
for crop_region, layer_idx in zip(crop_regions, layer_idxs):
|
||||
x1, y1, x2, y2 = crop_region
|
||||
@ -236,19 +258,20 @@ class Predictor(BasePredictor):
|
||||
area = torch.tensor(w * h, device=im.device)
|
||||
points_scale = np.array([[w, h]]) # w, h
|
||||
# Crop image and interpolate to input size
|
||||
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
|
||||
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
|
||||
# (num_points, 2)
|
||||
points_for_image = point_grids[layer_idx] * points_scale
|
||||
crop_masks, crop_scores, crop_bboxes = [], [], []
|
||||
for (points, ) in batch_iterator(points_batch_size, points_for_image):
|
||||
for (points,) in batch_iterator(points_batch_size, points_for_image):
|
||||
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
|
||||
# Interpolate predicted masks to input size
|
||||
pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
|
||||
pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
|
||||
idx = pred_score > conf_thres
|
||||
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
|
||||
|
||||
stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
|
||||
stability_score_offset)
|
||||
stability_score = calculate_stability_score(
|
||||
pred_mask, self.model.mask_threshold, stability_score_offset
|
||||
)
|
||||
idx = stability_score > stability_score_thresh
|
||||
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
|
||||
# Bool type is much more memory-efficient.
|
||||
@ -291,7 +314,22 @@ class Predictor(BasePredictor):
|
||||
return pred_masks, pred_scores, pred_bboxes
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Set up YOLO model with specified thresholds and device."""
|
||||
"""
|
||||
Initializes the Segment Anything Model (SAM) for inference.
|
||||
|
||||
This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
|
||||
parameters for image normalization and other Ultralytics compatibility settings.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
|
||||
verbose (bool): If True, prints selected device information.
|
||||
|
||||
Attributes:
|
||||
model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
|
||||
device (torch.device): The device to which the model and tensors are allocated.
|
||||
mean (torch.Tensor): The mean values for image normalization.
|
||||
std (torch.Tensor): The standard deviation values for image normalization.
|
||||
"""
|
||||
device = select_device(self.args.device, verbose=verbose)
|
||||
if model is None:
|
||||
model = build_sam(self.args.model)
|
||||
@ -300,7 +338,8 @@ class Predictor(BasePredictor):
|
||||
self.device = device
|
||||
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
|
||||
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
|
||||
# TODO: Temporary settings for compatibility
|
||||
|
||||
# Ultralytics compatibility settings
|
||||
self.model.pt = False
|
||||
self.model.triton = False
|
||||
self.model.stride = 32
|
||||
@ -308,7 +347,20 @@ class Predictor(BasePredictor):
|
||||
self.done_warmup = True
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Post-processes inference output predictions to create detection masks for objects."""
|
||||
"""
|
||||
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
|
||||
|
||||
The method scales masks and boxes to the original image size and applies a threshold to the mask predictions. The
|
||||
SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
|
||||
|
||||
Args:
|
||||
preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
|
||||
img (torch.Tensor): The processed input image tensor.
|
||||
orig_imgs (list | torch.Tensor): The original, unprocessed images.
|
||||
|
||||
Returns:
|
||||
(list): List of Results objects containing detection masks, bounding boxes, and other metadata.
|
||||
"""
|
||||
# (N, 1, H, W), (N, 1)
|
||||
pred_masks, pred_scores = preds[:2]
|
||||
pred_bboxes = preds[2] if self.segment_all else None
|
||||
@ -334,21 +386,36 @@ class Predictor(BasePredictor):
|
||||
return results
|
||||
|
||||
def setup_source(self, source):
|
||||
"""Sets up source and inference mode."""
|
||||
"""
|
||||
Sets up the data source for inference.
|
||||
|
||||
This method configures the data source from which images will be fetched for inference. The source could be a
|
||||
directory, a video file, or other types of image data sources.
|
||||
|
||||
Args:
|
||||
source (str | Path): The path to the image data source for inference.
|
||||
"""
|
||||
if source is not None:
|
||||
super().setup_source(source)
|
||||
|
||||
def set_image(self, image):
|
||||
"""Set image in advance.
|
||||
Args:
|
||||
"""
|
||||
Preprocesses and sets a single image for inference.
|
||||
|
||||
image (str | np.ndarray): image file path or np.ndarray image by cv2.
|
||||
This function sets up the model if not already initialized, configures the data source to the specified image,
|
||||
and preprocesses the image for feature extraction. Only one image can be set at a time.
|
||||
|
||||
Args:
|
||||
image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
|
||||
|
||||
Raises:
|
||||
AssertionError: If more than one image is set.
|
||||
"""
|
||||
if self.model is None:
|
||||
model = build_sam(self.args.model)
|
||||
self.setup_model(model)
|
||||
self.setup_source(image)
|
||||
assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
|
||||
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
|
||||
for batch in self.dataset:
|
||||
im = self.preprocess(batch[1])
|
||||
self.features = self.model.image_encoder(im)
|
||||
@ -360,23 +427,27 @@ class Predictor(BasePredictor):
|
||||
self.prompts = prompts
|
||||
|
||||
def reset_image(self):
|
||||
"""Resets the image and its features to None."""
|
||||
self.im = None
|
||||
self.features = None
|
||||
|
||||
@staticmethod
|
||||
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
|
||||
"""
|
||||
Removes small disconnected regions and holes in masks, then reruns
|
||||
box NMS to remove any new duplicates. Requires open-cv as a dependency.
|
||||
Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
|
||||
function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
|
||||
Suppression (NMS) to eliminate any newly created duplicate boxes.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): Masks, (N, H, W).
|
||||
min_area (int): Minimum area threshold.
|
||||
nms_thresh (float): NMS threshold.
|
||||
masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
|
||||
the number of masks, H is height, and W is width.
|
||||
min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
|
||||
nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.
|
||||
|
||||
Returns:
|
||||
new_masks (torch.Tensor): New Masks, (N, H, W).
|
||||
keep (List[int]): The indices of the new masks, which can be used to filter
|
||||
the corresponding boxes.
|
||||
(tuple([torch.Tensor, List[int]])):
|
||||
- new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
|
||||
- keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
|
||||
"""
|
||||
if len(masks) == 0:
|
||||
return masks
|
||||
@ -386,23 +457,18 @@ class Predictor(BasePredictor):
|
||||
scores = []
|
||||
for mask in masks:
|
||||
mask = mask.cpu().numpy().astype(np.uint8)
|
||||
mask, changed = remove_small_regions(mask, min_area, mode='holes')
|
||||
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
||||
unchanged = not changed
|
||||
mask, changed = remove_small_regions(mask, min_area, mode='islands')
|
||||
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
||||
unchanged = unchanged and not changed
|
||||
|
||||
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
||||
# Give score=0 to changed masks and score=1 to unchanged masks
|
||||
# so NMS will prefer ones that didn't need postprocessing
|
||||
# Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing
|
||||
scores.append(float(unchanged))
|
||||
|
||||
# Recalculate boxes and remove any new duplicates
|
||||
new_masks = torch.cat(new_masks, dim=0)
|
||||
boxes = batched_mask_to_box(new_masks)
|
||||
keep = torchvision.ops.nms(
|
||||
boxes.float(),
|
||||
torch.as_tensor(scores),
|
||||
nms_thresh,
|
||||
)
|
||||
keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
|
||||
|
||||
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
|
||||
|
Reference in New Issue
Block a user