add yolo v10 and modify pipeline
@@ -1,4 +1,12 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
"""
|
||||
Generate predictions using the Segment Anything Model (SAM).
|
||||
|
||||
SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance.
|
||||
This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
|
||||
using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image
|
||||
segmentation tasks.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
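For orientation, a minimal usage sketch of the predictor this module implements, via the public SAM wrapper (the sam_b.pt checkpoint and image path are assumptions, not part of this commit):

    from ultralytics import SAM

    model = SAM("sam_b.pt")  # loads the SAM predictor defined below
    results = model("image.jpg", bboxes=[439, 437, 524, 709])  # box prompt in XYXY pixels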
@@ -10,129 +18,155 @@ from ultralytics.engine.predictor import BasePredictor
 from ultralytics.engine.results import Results
 from ultralytics.utils import DEFAULT_CFG, ops
 from ultralytics.utils.torch_utils import select_device

-from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
-                  generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
+from .amg import (
+    batch_iterator,
+    batched_mask_to_box,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    remove_small_regions,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+)
 from .build import build_sam


 class Predictor(BasePredictor):
"""
|
||||
Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
|
||||
|
||||
The class provides an interface for model inference tailored to image segmentation tasks.
|
||||
With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
|
||||
mask generation. The class is capable of working with various types of prompts such as bounding boxes,
|
||||
points, and low-resolution masks.
|
||||
|
||||
Attributes:
|
||||
cfg (dict): Configuration dictionary specifying model and task-related parameters.
|
||||
overrides (dict): Dictionary containing values that override the default configuration.
|
||||
_callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
|
||||
args (namespace): Namespace to hold command-line arguments or other operational variables.
|
||||
im (torch.Tensor): Preprocessed input image tensor.
|
||||
features (torch.Tensor): Extracted image features used for inference.
|
||||
prompts (dict): Collection of various prompt types, such as bounding boxes and points.
|
||||
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
|
||||
"""
|
||||
|
||||
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """
         Initialize the Predictor with configuration, overrides, and callbacks.

         The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
         initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.

         Args:
             cfg (dict): Configuration dictionary.
             overrides (dict, optional): Dictionary of values to override default configuration.
             _callbacks (dict, optional): Dictionary of callback functions to customize behavior.
         """
         if overrides is None:
             overrides = {}
-        overrides.update(dict(task='segment', mode='predict', imgsz=1024))
+        overrides.update(dict(task="segment", mode="predict", imgsz=1024))
         super().__init__(cfg, overrides, _callbacks)
         # SAM needs retina_masks=True, or the results would be a mess.
         self.args.retina_masks = True
         # Args for set_image
         self.im = None
         self.features = None
         # Args for set_prompts
         self.prompts = {}
         # Args for segment everything
         self.segment_all = False

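A sketch of constructing this Predictor directly, mirroring the overrides merge above (the checkpoint name is an assumption):

    from ultralytics.models.sam import Predictor

    # task, mode and imgsz are subsequently forced to segment/predict/1024 by __init__
    predictor = Predictor(overrides=dict(conf=0.25, model="mobile_sam.pt"))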
     def preprocess(self, im):
-        """Prepares input image before inference.
+        """
+        Preprocess the input image for model inference.

         The method prepares the input image by applying transformations and normalization.
         It supports both torch.Tensor and list of np.ndarray as input formats.

         Args:
-            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
+            im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.

         Returns:
             (torch.Tensor): The preprocessed image tensor.
         """
         if self.im is not None:
             return self.im
         not_tensor = not isinstance(im, torch.Tensor)
         if not_tensor:
             im = np.stack(self.pre_transform(im))
-            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
-            im = np.ascontiguousarray(im)  # contiguous
+            im = im[..., ::-1].transpose((0, 3, 1, 2))
+            im = np.ascontiguousarray(im)
             im = torch.from_numpy(im)

         im = im.to(self.device)
-        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
+        im = im.half() if self.model.fp16 else im.float()
         if not_tensor:
             im = (im - self.mean) / self.std
         return im

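The two accepted input formats, sketched under the assumption that `predictor` is set up as above:

    import cv2

    im = cv2.imread("image.jpg")        # HWC, BGR, uint8
    batch = predictor.preprocess([im])  # list input -> RGB, BCHW, mean/std-normalized tensor
    # an already-batched BCHW torch.Tensor is only moved to the device and cast, with no normalization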
     def pre_transform(self, im):
         """
-        Pre-transform input image before inference.
+        Perform initial transformations on the input image for preprocessing.

         The method applies transformations such as resizing to prepare the image for further preprocessing.
         Currently, batched inference is not supported; hence the list length should be 1.

         Args:
-            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
+            im (List[np.ndarray]): List containing images in HWC numpy array format.

         Returns:
-            (list): A list of transformed images.
+            (List[np.ndarray]): List of transformed images.
         """
-        assert len(im) == 1, 'SAM model does not currently support batched inference'
+        assert len(im) == 1, "SAM model does not currently support batched inference"
         letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
         return [letterbox(image=x) for x in im]

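The LetterBox transform used here can be exercised standalone; a small sketch with the imgsz=1024 set in __init__:

    import numpy as np
    from ultralytics.data.augment import LetterBox

    letterbox = LetterBox(1024, auto=False, center=False)  # top-left anchored, pads right/bottom
    square = letterbox(image=np.zeros((480, 640, 3), dtype=np.uint8))  # -> (1024, 1024, 3)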
     def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
         """
-        Predict masks for the given input prompts, using the currently set image.
+        Perform image segmentation inference based on the given input cues, using the currently loaded image. This
+        method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
+        mask decoder for real-time and promptable segmentation tasks.

         Args:
-            im (torch.Tensor): The preprocessed image, (N, C, H, W).
-            bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
-            points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
-            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
-                1 indicates a foreground point and 0 indicates a background point.
-            masks (np.ndarray, None): A low resolution mask input to the model, typically
-                coming from a previous prediction iteration. Has form (N, H, W), where
-                for SAM, H=W=256.
-            multimask_output (bool): If true, the model will return three masks.
-                For ambiguous input prompts (such as a single click), this will often
-                produce better masks than a single prediction. If only a single
-                mask is needed, the model's predicted quality score can be used
-                to select the best mask. For non-ambiguous prompts, such as multiple
-                input prompts, multimask_output=False can give better results.
+            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
+            masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
+            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.

         Returns:
-            (np.ndarray): The output masks in CxHxW format, where C is the
-                number of masks, and (H, W) is the original image size.
-            (np.ndarray): An array of length C containing the model's
-                predictions for the quality of each mask.
-            (np.ndarray): An array of shape CxHxW, where C is the number
-                of masks and H=W=256. These low resolution logits can be passed to
-                a subsequent iteration as mask input.
+            (tuple): Contains the following three elements.
+                - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
+                - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
+                - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
         """
-        # Get prompts from self.prompts first
-        bboxes = self.prompts.pop('bboxes', bboxes)
-        points = self.prompts.pop('points', points)
-        masks = self.prompts.pop('masks', masks)
+        # Override prompts if any stored in self.prompts
+        bboxes = self.prompts.pop("bboxes", bboxes)
+        points = self.prompts.pop("points", points)
+        masks = self.prompts.pop("masks", masks)

         if all(i is None for i in [bboxes, points, masks]):
             return self.generate(im, *args, **kwargs)

         return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)

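Prompts staged via set_prompts take precedence over per-call arguments through the pops above; a hedged sketch (image path assumed):

    predictor.set_image("image.jpg")
    predictor.set_prompts(dict(bboxes=[439, 437, 524, 709]))
    results = predictor()  # the staged box prompt is consumed here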
     def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
         """
-        Predict masks for the given input prompts, using the currently set image.
+        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
+        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.

         Args:
-            im (torch.Tensor): The preprocessed image, (N, C, H, W).
-            bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
-            points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
-            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
-                1 indicates a foreground point and 0 indicates a background point.
-            masks (np.ndarray, None): A low resolution mask input to the model, typically
-                coming from a previous prediction iteration. Has form (N, H, W), where
-                for SAM, H=W=256.
-            multimask_output (bool): If true, the model will return three masks.
-                For ambiguous input prompts (such as a single click), this will often
-                produce better masks than a single prediction. If only a single
-                mask is needed, the model's predicted quality score can be used
-                to select the best mask. For non-ambiguous prompts, such as multiple
-                input prompts, multimask_output=False can give better results.
+            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
+            masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
+            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.

         Returns:
-            (np.ndarray): The output masks in CxHxW format, where C is the
-                number of masks, and (H, W) is the original image size.
-            (np.ndarray): An array of length C containing the model's
-                predictions for the quality of each mask.
-            (np.ndarray): An array of shape CxHxW, where C is the number
-                of masks and H=W=256. These low resolution logits can be passed to
-                a subsequent iteration as mask input.
+            (tuple): Contains the following three elements.
+                - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
+                - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
+                - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
         """
         features = self.model.image_encoder(im) if self.features is None else self.features

@@ -158,11 +192,7 @@ class Predictor(BasePredictor):

         points = (points, labels) if points is not None else None
         # Embed prompts
-        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
-            points=points,
-            boxes=bboxes,
-            masks=masks,
-        )
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)

         # Predict masks
         pred_masks, pred_scores = self.model.mask_decoder(
@@ -177,58 +207,50 @@ class Predictor(BasePredictor):
         # `d` could be 1 or 3 depends on `multimask_output`.
         return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

-    def generate(self,
-                 im,
-                 crop_n_layers=0,
-                 crop_overlap_ratio=512 / 1500,
-                 crop_downscale_factor=1,
-                 point_grids=None,
-                 points_stride=32,
-                 points_batch_size=64,
-                 conf_thres=0.88,
-                 stability_score_thresh=0.95,
-                 stability_score_offset=0.95,
-                 crop_nms_thresh=0.7):
-        """Segment the whole image.
+    def generate(
+        self,
+        im,
+        crop_n_layers=0,
+        crop_overlap_ratio=512 / 1500,
+        crop_downscale_factor=1,
+        point_grids=None,
+        points_stride=32,
+        points_batch_size=64,
+        conf_thres=0.88,
+        stability_score_thresh=0.95,
+        stability_score_offset=0.95,
+        crop_nms_thresh=0.7,
+    ):
+        """
+        Perform image segmentation using the Segment Anything Model (SAM).

+        This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
+        and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
+
         Args:
-            im (torch.Tensor): The preprocessed image, (N, C, H, W).
-            crop_n_layers (int): If >0, mask prediction will be run again on
-                crops of the image. Sets the number of layers to run, where each
-                layer has 2**i_layer number of image crops.
-            crop_overlap_ratio (float): Sets the degree to which crops overlap.
-                In the first crop layer, crops will overlap by this fraction of
-                the image length. Later layers with more crops scale down this overlap.
-            crop_downscale_factor (int): The number of points-per-side
-                sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
-            point_grids (list(np.ndarray), None): A list over explicit grids
-                of points used for sampling, normalized to [0,1]. The nth grid in the
-                list is used in the nth crop layer. Exclusive with points_per_side.
-            points_stride (int, None): The number of points to be sampled
-                along one side of the image. The total number of points is
-                points_per_side**2. If None, 'point_grids' must provide explicit
-                point sampling.
-            points_batch_size (int): Sets the number of points run simultaneously
-                by the model. Higher numbers may be faster but use more GPU memory.
-            conf_thres (float): A filtering threshold in [0,1], using the
-                model's predicted mask quality.
-            stability_score_thresh (float): A filtering threshold in [0,1], using
-                the stability of the mask under changes to the cutoff used to binarize
-                the model's mask predictions.
-            stability_score_offset (float): The amount to shift the cutoff when
-                calculated the stability score.
-            crop_nms_thresh (float): The box IoU cutoff used by non-maximal
-                suppression to filter duplicate masks between different crops.
+            im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
+            crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
+                Each layer produces 2**i_layer number of image crops.
+            crop_overlap_ratio (float): Determines the extent of overlap between crops. Scaled down in subsequent layers.
+            crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
+            point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
+                Used in the nth crop layer.
+            points_stride (int, optional): Number of points to sample along each side of the image.
+                Exclusive with 'point_grids'.
+            points_batch_size (int): Batch size for the number of points processed simultaneously.
+            conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
+            stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
+            stability_score_offset (float): Offset value for calculating stability score.
+            crop_nms_thresh (float): IoU cutoff for Non-Maximum Suppression (NMS) to remove duplicate masks between crops.

         Returns:
             (tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
         """
         self.segment_all = True
         ih, iw = im.shape[2:]
         crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
         if point_grids is None:
-            point_grids = build_all_layer_point_grids(
-                points_stride,
-                crop_n_layers,
-                crop_downscale_factor,
-            )
+            point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)
         pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
         for crop_region, layer_idx in zip(crop_regions, layer_idxs):
             x1, y1, x2, y2 = crop_region
@@ -236,19 +258,20 @@ class Predictor(BasePredictor):
             area = torch.tensor(w * h, device=im.device)
             points_scale = np.array([[w, h]])  # w, h
             # Crop image and interpolate to input size
-            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
+            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
             # (num_points, 2)
             points_for_image = point_grids[layer_idx] * points_scale
             crop_masks, crop_scores, crop_bboxes = [], [], []
-            for (points, ) in batch_iterator(points_batch_size, points_for_image):
+            for (points,) in batch_iterator(points_batch_size, points_for_image):
                 pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
                 # Interpolate predicted masks to input size
-                pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
+                pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
                 idx = pred_score > conf_thres
                 pred_mask, pred_score = pred_mask[idx], pred_score[idx]

-                stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
-                                                            stability_score_offset)
+                stability_score = calculate_stability_score(
+                    pred_mask, self.model.mask_threshold, stability_score_offset
+                )
                 idx = stability_score > stability_score_thresh
                 pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                 # Bool type is much more memory-efficient.
@@ -291,7 +314,22 @@ class Predictor(BasePredictor):
         return pred_masks, pred_scores, pred_bboxes

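With no prompts given, inference() falls back to this segment-everything path, and extra keyword arguments are forwarded on to generate(); a sketch, assuming the predictor constructed earlier:

    # crop_n_layers and points_stride reach generate() through *args/**kwargs
    results = predictor(source="image.jpg", crop_n_layers=1, points_stride=64)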
     def setup_model(self, model, verbose=True):
-        """Set up YOLO model with specified thresholds and device."""
+        """
+        Initializes the Segment Anything Model (SAM) for inference.
+
+        This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
+        parameters for image normalization and other Ultralytics compatibility settings.
+
+        Args:
+            model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
+            verbose (bool): If True, prints selected device information.
+
+        Attributes:
+            model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
+            device (torch.device): The device to which the model and tensors are allocated.
+            mean (torch.Tensor): The mean values for image normalization.
+            std (torch.Tensor): The standard deviation values for image normalization.
+        """
         device = select_device(self.args.device, verbose=verbose)
         if model is None:
             model = build_sam(self.args.model)
@@ -300,7 +338,8 @@ class Predictor(BasePredictor):
         self.device = device
         self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
         self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
-        # TODO: Temporary settings for compatibility
+
+        # Ultralytics compatibility settings
         self.model.pt = False
         self.model.triton = False
         self.model.stride = 32
@@ -308,7 +347,20 @@ class Predictor(BasePredictor):
         self.done_warmup = True

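The mean/std set here are SAM's standard pixel-normalization constants. A sketch of manual setup with an explicitly built backbone (checkpoint name assumed):

    from ultralytics.models.sam.build import build_sam

    predictor.setup_model(model=build_sam("sam_b.pt"), verbose=False)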
     def postprocess(self, preds, img, orig_imgs):
-        """Post-processes inference output predictions to create detection masks for objects."""
+        """
+        Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
+
+        The method scales masks and boxes to the original image size and applies a threshold to the mask predictions. The
+        SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
+
+        Args:
+            preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
+            img (torch.Tensor): The processed input image tensor.
+            orig_imgs (list | torch.Tensor): The original, unprocessed images.
+
+        Returns:
+            (list): List of Results objects containing detection masks, bounding boxes, and other metadata.
+        """
         # (N, 1, H, W), (N, 1)
         pred_masks, pred_scores = preds[:2]
         pred_bboxes = preds[2] if self.segment_all else None
@@ -334,21 +386,36 @@ class Predictor(BasePredictor):
         return results

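Downstream, each Results object exposes the rescaled masks and boxes; a brief sketch:

    for r in results:
        print(r.masks.data.shape)  # (num_masks, H, W) at original image size
        print(r.boxes.xyxy)        # matching boxes in XYXY pixels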
     def setup_source(self, source):
-        """Sets up source and inference mode."""
+        """
+        Sets up the data source for inference.
+
+        This method configures the data source from which images will be fetched for inference. The source could be a
+        directory, a video file, or other types of image data sources.
+
+        Args:
+            source (str | Path): The path to the image data source for inference.
+        """
         if source is not None:
             super().setup_source(source)

     def set_image(self, image):
-        """Set image in advance.
-        Args:
+        """
+        Preprocesses and sets a single image for inference.

-            image (str | np.ndarray): image file path or np.ndarray image by cv2.
+        This function sets up the model if not already initialized, configures the data source to the specified image,
+        and preprocesses the image for feature extraction. Only one image can be set at a time.
+
+        Args:
+            image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
+
+        Raises:
+            AssertionError: If more than one image is set.
         """
         if self.model is None:
             model = build_sam(self.args.model)
             self.setup_model(model)
         self.setup_source(image)
-        assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
+        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
         for batch in self.dataset:
             im = self.preprocess(batch[1])
             self.features = self.model.image_encoder(im)
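Because set_image caches the encoder features, several prompt sets can be run against one image without re-encoding; a sketch:

    predictor.set_image("image.jpg")  # encode once
    r1 = predictor(points=[900, 370], labels=[1])
    r2 = predictor(bboxes=[100, 100, 400, 400])
    predictor.reset_image()  # clear the cached image and features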
@@ -360,23 +427,27 @@ class Predictor(BasePredictor):
         self.prompts = prompts

     def reset_image(self):
         """Resets the image and its features to None."""
         self.im = None
         self.features = None

     @staticmethod
     def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
         """
-        Removes small disconnected regions and holes in masks, then reruns
-        box NMS to remove any new duplicates. Requires open-cv as a dependency.
+        Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
+        function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
+        Suppression (NMS) to eliminate any newly created duplicate boxes.

         Args:
-            masks (torch.Tensor): Masks, (N, H, W).
-            min_area (int): Minimum area threshold.
-            nms_thresh (float): NMS threshold.
+            masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
+                the number of masks, H is height, and W is width.
+            min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
+            nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.

         Returns:
-            new_masks (torch.Tensor): New Masks, (N, H, W).
-            keep (List[int]): The indices of the new masks, which can be used to filter
-                the corresponding boxes.
+            (tuple([torch.Tensor, List[int]])):
+                - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
+                - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
         """
         if len(masks) == 0:
             return masks
@@ -386,23 +457,18 @@ class Predictor(BasePredictor):
         scores = []
         for mask in masks:
             mask = mask.cpu().numpy().astype(np.uint8)
-            mask, changed = remove_small_regions(mask, min_area, mode='holes')
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
             unchanged = not changed
-            mask, changed = remove_small_regions(mask, min_area, mode='islands')
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
             unchanged = unchanged and not changed

             new_masks.append(torch.as_tensor(mask).unsqueeze(0))
-            # Give score=0 to changed masks and score=1 to unchanged masks
-            # so NMS will prefer ones that didn't need postprocessing
+            # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing
             scores.append(float(unchanged))

         # Recalculate boxes and remove any new duplicates
         new_masks = torch.cat(new_masks, dim=0)
         boxes = batched_mask_to_box(new_masks)
-        keep = torchvision.ops.nms(
-            boxes.float(),
-            torch.as_tensor(scores),
-            nms_thresh,
-        )
+        keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)

         return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
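A usage sketch for this static helper; the input masks tensor is assumed to come from a prior prediction:

    # masks: torch.Tensor of shape (N, H, W), e.g. pred_masks from generate()
    new_masks, keep = Predictor.remove_small_regions(masks, min_area=100, nms_thresh=0.7)
    # `keep` indexes the surviving masks and can filter any boxes/scores tracked alongside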