initial project version!

This commit is contained in:
王庆刚
2024-05-20 20:01:06 +08:00
commit d6f3693d3f
483 changed files with 60345 additions and 0 deletions

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO
__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import

Binary file not shown.

View File

@ -0,0 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import FastSAM
from .predict import FastSAMPredictor
from .prompt import FastSAMPrompt
from .val import FastSAMValidator
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'

View File

@ -0,0 +1,33 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
from ultralytics.engine.model import Model
from .predict import FastSAMPredictor
from .val import FastSAMValidator
class FastSAM(Model):
"""
FastSAM model interface.
Example:
```python
from ultralytics import FastSAM
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
```
"""
def __init__(self, model='FastSAM-x.pt'):
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
if str(model) == 'FastSAM.pt':
model = 'FastSAM-x.pt'
assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
super().__init__(model=model, task='segment')
@property
def task_map(self):
return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}

View File

@ -0,0 +1,51 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.results import Results
from ultralytics.models.fastsam.utils import bbox_iou
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
class FastSAMPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
if critical_iou_index.numel() != 0:
full_box[0][4] = p[0][critical_iou_index][:, 4]
full_box[0][6:] = p[0][critical_iou_index][:, 6:]
p[0][critical_iou_index] = full_box
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, pred in enumerate(p):
orig_img = orig_imgs[i]
img_path = self.batch[0][i]
if not len(pred): # save empty boxes
masks = None
elif self.args.retina_masks:
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results

View File

@ -0,0 +1,307 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from ultralytics.utils import TQDM
class FastSAMPrompt:
def __init__(self, source, results, device='cuda') -> None:
self.device = device
self.results = results
self.source = source
# Import and assign clip
try:
import clip # for linear_assignment
except ImportError:
from ultralytics.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git')
import clip
self.clip = clip
@staticmethod
def _segment_image(image, bbox):
image_array = np.array(image)
segmented_image_array = np.zeros_like(image_array)
x1, y1, x2, y2 = bbox
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
black_image = Image.new('RGB', image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
transparency_mask[y1:y2, x1:x2] = 255
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
@staticmethod
def _format_results(result, filter=0):
annotations = []
n = len(result.masks.data) if result.masks is not None else 0
for i in range(n):
mask = result.masks.data[i] == 1.0
if torch.sum(mask) >= filter:
annotation = {
'id': i,
'segmentation': mask.cpu().numpy(),
'bbox': result.boxes.data[i],
'score': result.boxes.conf[i]}
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations
@staticmethod
def _get_bbox_from_mask(mask):
mask = mask.astype(np.uint8)
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
x1, y1, w, h = cv2.boundingRect(contours[0])
x2, y2 = x1 + w, y1 + h
if len(contours) > 1:
for b in contours:
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
x1 = min(x1, x_t)
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
y2 = max(y2, y_t + h_t)
return [x1, y1, x2, y2]
def plot(self,
annotations,
output,
bbox=None,
points=None,
point_label=None,
mask_random_color=True,
better_quality=True,
retina=False,
with_contours=True):
pbar = TQDM(annotations, total=len(annotations))
for ann in pbar:
result_name = os.path.basename(ann.path)
image = ann.orig_img
original_h, original_w = ann.orig_shape
# for macOS only
# plt.switch_backend('TkAgg')
plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(image)
if ann.masks is not None:
masks = ann.masks.data
if better_quality:
if isinstance(masks[0], torch.Tensor):
masks = np.array(masks.cpu())
for i, mask in enumerate(masks):
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
self.fast_show_mask(
masks,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
if with_contours:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(masks):
mask = mask.astype(np.uint8)
if not retina:
mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
contour_all.extend(iter(contours))
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
plt.axis('off')
fig = plt.gcf()
# Check if the canvas has been drawn
if fig.canvas.get_renderer() is None: # macOS requires this or tests fail
fig.canvas.draw()
save_path = Path(output) / result_name
save_path.parent.mkdir(exist_ok=True, parents=True)
image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
image.save(save_path)
plt.close()
pbar.set_description(f'Saving {result_name} to {save_path}')
@staticmethod
def fast_show_mask(
annotation,
ax,
random_color=False,
bbox=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
n, h, w = annotation.shape # batch, height, width
areas = np.sum(annotation, axis=(1, 2))
annotation = annotation[np.argsort(areas)]
index = (annotation != 0).argmax(axis=0)
if random_color:
color = np.random.random((n, 1, 1, 3))
else:
color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
transparency = np.ones((n, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((h, w, 4))
h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
show[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# Draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show)
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = self.clip.tokenize([search_text]).to(device)
stacked_images = torch.stack(preprocessed_images)
image_features = model.encode_image(stacked_images)
text_features = model.encode_text(tokenized_text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
probs = 100.0 * image_features @ text_features.T
return probs[:, 0].softmax(dim=0)
def _crop_image(self, format_results):
if os.path.isdir(self.source):
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
ori_w, ori_h = image.size
annotations = format_results
mask_h, mask_w = annotations[0]['segmentation'].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
cropped_images = []
not_crop = []
filter_id = []
for _, mask in enumerate(annotations):
if np.sum(mask['segmentation']) <= 100:
filter_id.append(_)
continue
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片
cropped_images.append(bbox) # 保存裁剪的图片的bbox
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
def box_prompt(self, bbox):
if self.results[0].masks is not None:
assert (bbox[2] != 0 and bbox[3] != 0)
if os.path.isdir(self.source):
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
masks = self.results[0].masks.data
target_height, target_width = self.results[0].orig_shape
h = masks.shape[1]
w = masks.shape[2]
if h != target_height or w != target_width:
bbox = [
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height), ]
bbox[0] = max(round(bbox[0]), 0)
bbox[1] = max(round(bbox[1]), 0)
bbox[2] = min(round(bbox[2]), w)
bbox[3] = min(round(bbox[3]), h)
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
IoUs = masks_area / union
max_iou_index = torch.argmax(IoUs)
self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
return self.results
def point_prompt(self, points, pointlabel): # numpy 处理
if self.results[0].masks is not None:
if os.path.isdir(self.source):
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
masks = self._format_results(self.results[0], 0)
target_height, target_width = self.results[0].orig_shape
h = masks[0]['segmentation'].shape[0]
w = masks[0]['segmentation'].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
for annotation in masks:
mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask += mask
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
onemask -= mask
onemask = onemask >= 1
self.results[0].masks.data = torch.tensor(np.array([onemask]))
return self.results
def text_prompt(self, text):
if self.results[0].masks is not None:
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
max_idx += sum(np.array(filter_id) <= int(max_idx))
self.results[0].masks.data = torch.tensor(np.array([ann['segmentation'] for ann in annotations]))
return self.results
def everything_prompt(self):
return self.results

View File

@ -0,0 +1,67 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes (torch.Tensor): (n, 4)
image_shape (tuple): (height, width)
threshold (int): pixel threshold
Returns:
adjusted_boxes (torch.Tensor): adjusted bounding boxes
"""
# Image dimensions
h, w = image_shape
# Adjust boxes
boxes[boxes[:, 0] < threshold, 0] = 0 # x1
boxes[boxes[:, 1] < threshold, 1] = 0 # y1
boxes[boxes[:, 2] > w - threshold, 2] = w # x2
boxes[boxes[:, 3] > h - threshold, 3] = h # y2
return boxes
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
"""
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1 (torch.Tensor): (4, )
boxes (torch.Tensor): (n, 4)
iou_thres (float): IoU threshold
image_shape (tuple): (height, width)
raw_output (bool): If True, return the raw IoU values instead of the indices
Returns:
high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
"""
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
y1 = torch.max(box1[1], boxes[:, 1])
x2 = torch.min(box1[2], boxes[:, 2])
y2 = torch.min(box1[3], boxes[:, 3])
# compute the area of intersection
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
# compute the area of both individual boxes
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# compute the area of union
union = box1_area + box2_area - intersection
# compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
return 0 if iou.numel() == 0 else iou
# return indices of boxes with IoU > thres
return torch.nonzero(iou > iou_thres).flatten()

View File

@ -0,0 +1,14 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo.segment import SegmentationValidator
from ultralytics.utils.metrics import SegmentMetrics
class FastSAMValidator(SegmentationValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import NAS
from .predict import NASPredictor
from .val import NASValidator
__all__ = 'NASPredictor', 'NASValidator', 'NAS'

Binary file not shown.

View File

@ -0,0 +1,61 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
YOLO-NAS model interface.
Example:
```python
from ultralytics import NAS
model = NAS('yolo_nas_s')
results = model.predict('ultralytics/assets/bus.jpg')
```
"""
from pathlib import Path
import torch
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor
from .val import NASValidator
class NAS(Model):
def __init__(self, model='yolo_nas_s.pt') -> None:
assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
super().__init__(model, task='detect')
@smart_inference_mode()
def _load(self, weights: str, task: str):
# Load or create new NAS model
import super_gradients
suffix = Path(weights).suffix
if suffix == '.pt':
self.model = torch.load(weights)
elif suffix == '':
self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = weights # for export()
self.model.task = 'detect' # for export()
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
@property
def task_map(self):
return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}

View File

@ -0,0 +1,35 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
class NASPredictor(BasePredictor):
def postprocess(self, preds_in, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects."""
# Cat boxes and class scores
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i]
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
img_path = self.batch[0][i]
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
return results

View File

@ -0,0 +1,24 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
__all__ = ['NASValidator']
class NASValidator(DetectionValidator):
def postprocess(self, preds_in):
"""Apply Non-maximum suppression to prediction outputs."""
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=False,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
max_time_img=0.5)

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import RTDETR
from .predict import RTDETRPredictor
from .val import RTDETRValidator
__all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR'

View File

@ -0,0 +1,30 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
RT-DETR model interface
"""
from ultralytics.engine.model import Model
from ultralytics.nn.tasks import RTDETRDetectionModel
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator
class RTDETR(Model):
"""
RTDETR model interface.
"""
def __init__(self, model='rtdetr-l.pt') -> None:
if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
super().__init__(model=model, task='detect')
@property
def task_map(self):
return {
'detect': {
'predictor': RTDETRPredictor,
'validator': RTDETRValidator,
'trainer': RTDETRTrainer,
'model': RTDETRDetectionModel}}

View File

@ -0,0 +1,62 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on an RT-DETR detection model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.rtdetr import RTDETRPredictor
args = dict(model='rtdetr-l.pt', source=ASSETS)
predictor = RTDETRPredictor(overrides=args)
predictor.predict_cli()
```
"""
def postprocess(self, preds, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects."""
nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
score, cls = scores[i].max(-1, keepdim=True) # (300, 1)
idx = score.squeeze(-1) > self.args.conf # (300, )
if self.args.classes is not None:
idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
pred = torch.cat([bbox, score, cls], dim=-1)[idx] # filter
orig_img = orig_imgs[i]
oh, ow = orig_img.shape[:2]
pred[..., [0, 2]] *= ow
pred[..., [1, 3]] *= oh
img_path = self.batch[0][i]
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
def pre_transform(self, im):
"""Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Notes: The size must be square(640) and scaleFilled.
Returns:
(list): A list of transformed imgs.
"""
letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
return [letterbox(image=x) for x in im]

View File

@ -0,0 +1,72 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
Notes:
- F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Example:
```python
from ultralytics.models.rtdetr.train import RTDETRTrainer
args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
```
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def build_dataset(self, img_path, mode='val', batch=None):
"""Build RTDETR Dataset
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=mode == 'train', # no augmentation
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
prefix=colorstr(f'{mode}: '),
data=self.data)
def get_validator(self):
"""Returns a DetectionValidator for RTDETR model validation."""
self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch = super().preprocess_batch(batch)
bs = len(batch['img'])
batch_idx = batch['batch_idx']
gt_bbox, gt_class = [], []
for i in range(bs):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch

View File

@ -0,0 +1,141 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
import torch
from ultralytics.data import YOLODataset
from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops
__all__ = 'RTDETRValidator', # tuple or list
# TODO: Temporarily RT-DETR does not need padding.
class RTDETRDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
# NOTE: add stretch version load_image for rtdetr mosaic
def load_image(self, i, rect_mode=False):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
return super().load_image(i=i, rect_mode=rect_mode)
def build_transforms(self, hyp=None):
"""Temporary, only for evaluation."""
if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
else:
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
transforms = Compose([])
transforms.append(
Format(bbox_format='xywh',
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask))
return transforms
class RTDETRValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on an RT-DETR detection model.
Example:
```python
from ultralytics.models.rtdetr import RTDETRValidator
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
validator = RTDETRValidator(args=args)
validator()
```
"""
def build_dataset(self, img_path, mode='val', batch=None):
"""
Build an RTDETR Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=False, # no augmentation
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
prefix=colorstr(f'{mode}: '),
data=self.data)
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
score, cls = scores[i].max(-1) # (300, )
# Do not need threshold for evaluation as only got 300 boxes here.
# idx = score > self.args.conf
pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
# sort by confidence to correctly get internal metrics.
pred = pred[score.argsort(descending=True)]
outputs[i] = pred # [idx]
return outputs
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
predn[..., [0, 2]] *= shape[1] / self.args.imgsz # native-space pred
predn[..., [1, 3]] *= shape[0] / self.args.imgsz # native-space pred
# Evaluate
if nl:
tbox = ops.xywh2xyxy(bbox) # target boxes
tbox[..., [0, 2]] *= shape[1] # native-space pred
tbox[..., [1, 3]] *= shape[0] # native-space pred
labelsn = torch.cat((cls, tbox), 1) # native-space labels
# NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type.
correct_bboxes = self._process_batch(predn.float(), labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
if self.args.save_txt:
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, shape, file)

View File

@ -0,0 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import SAM
from .predict import Predictor
# from .build import build_sam
__all__ = 'SAM', 'Predictor' # tuple or list

Binary file not shown.

View File

@ -0,0 +1,180 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
from itertools import product
from typing import Any, Generator, List, Tuple
import numpy as np
import torch
def is_box_near_crop_edge(boxes: torch.Tensor,
crop_box: List[int],
orig_box: List[int],
atol: float = 20.0) -> torch.Tensor:
"""Return a boolean tensor indicating if boxes are near the crop edge."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
"""Yield batches of data from the input arguments."""
assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
the predicted mask logits at high and low values.
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1,
dtype=torch.int32))
unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
return intersections / unions
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
"""Generate point grids for all crop layers."""
return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
"""Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer."""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
# Original image
crop_boxes.append([0, 0, im_w, im_h])
layer_idxs.append(0)
def crop_len(orig_len, n_crops, overlap):
"""Crops bounding boxes to the size of the input image."""
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
for i_layer in range(n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_w = crop_len(im_w, n_crops_per_side, overlap)
crop_h = crop_len(im_h, n_crops_per_side, overlap)
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
# Crops in XYWH format
for x0, y0 in product(crop_box_x0, crop_box_y0):
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
"""Uncrop bounding boxes by adding the crop box offset."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
return boxes + offset
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
"""Uncrop points by adding the crop box offset."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
if len(points.shape) == 3:
offset = offset.unsqueeze(1)
return points + offset
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
"""Uncrop masks by padding them to the original image size."""
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
pad = (x0, pad_x - x0, y0, pad_y - y0)
return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
import cv2 # type: ignore
assert mode in {'holes', 'islands'}
correct_holes = mode == 'holes'
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if not small_regions:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
# If every region is below threshold, keep largest
fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to CxHxW
shape = masks.shape
h, w = shape[-2:]
masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + h * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + w * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]

View File

@ -0,0 +1,158 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import torch
from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder
from .modules.sam import Sam
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) h-size model."""
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)
def build_sam_vit_l(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) l-size model."""
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
)
def build_sam_vit_b(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) b-size model."""
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def build_mobile_sam(checkpoint=None):
"""Build and return Mobile Segment Anything Model (Mobile-SAM)."""
return _build_sam(
encoder_embed_dim=[64, 128, 160, 320],
encoder_depth=[2, 2, 6, 2],
encoder_num_heads=[2, 4, 5, 10],
encoder_global_attn_indexes=None,
mobile_sam=True,
checkpoint=checkpoint,
)
def _build_sam(encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
mobile_sam=False):
"""Builds the selected SAM model architecture."""
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder = (TinyViT(
img_size=1024,
in_chans=3,
num_classes=1000,
embed_dims=encoder_embed_dim,
depths=encoder_depth,
num_heads=encoder_num_heads,
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=0.8,
) if mobile_sam else ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
))
sam = Sam(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, 'rb') as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
sam.eval()
# sam.load_state_dict(torch.load(checkpoint), strict=True)
# sam.eval()
return sam
sam_model_map = {
'sam_h.pt': build_sam_vit_h,
'sam_l.pt': build_sam_vit_l,
'sam_b.pt': build_sam_vit_b,
'mobile_sam.pt': build_mobile_sam, }
def build_sam(ckpt='sam_b.pt'):
"""Build a SAM model specified by ckpt."""
model_builder = None
for k in sam_model_map.keys():
if ckpt.endswith(k):
model_builder = sam_model_map.get(k)
if not model_builder:
raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}')
return model_builder(ckpt)

View File

@ -0,0 +1,51 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
SAM model interface
"""
from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor
class SAM(Model):
"""
SAM model interface.
"""
def __init__(self, model='sam_b.pt') -> None:
if model and Path(model).suffix not in ('.pt', '.pth'):
raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
super().__init__(model=model, task='segment')
def _load(self, weights: str, task=None):
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
kwargs.update(overrides)
prompts = dict(bboxes=bboxes, points=points, labels=labels)
return super().predict(source, stream, prompts=prompts, **kwargs)
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
@property
def task_map(self):
return {'segment': {'predictor': Predictor}}

View File

@ -0,0 +1 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

View File

@ -0,0 +1,159 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import List, Tuple, Type
import torch
from torch import nn
from torch.nn import functional as F
from ultralytics.nn.modules import LayerNorm2d
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a transformer architecture.
Args:
transformer_dim (int): the channel dimension of the transformer module
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
activation (nn.Module): the type of activation to use when upscaling masks
iou_head_depth (int): the depth of the MLP used to predict mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList([
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Args:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = [
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
class MLP(nn.Module):
"""
Lightly adapted from
https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.sigmoid_output = sigmoid_output
def forward(self, x):
"""Executes feedforward within the neural network module and applies activation."""
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = torch.sigmoid(x)
return x

View File

@ -0,0 +1,568 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import Any, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.nn.modules import LayerNorm2d, MLPBlock
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
return self.neck(x.permute(0, 3, 1, 2))
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
image_embedding_size: Tuple[int, int],
input_image_size: Tuple[int, int],
mask_in_chans: int,
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Args:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
input_image_size (int): The padded size of the image as input
to the image encoder, as (H, W).
mask_in_chans (int): The number of hidden channels used for
encoding input masks.
activation (nn.Module): The activation to use when encoding
input masks.
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points the shape of the image encoding.
Returns:
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embeds mask inputs."""
return self.mask_downscaling(masks)
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
boxes (torch.Tensor, None): boxes to embed
masks (torch.Tensor, None): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
by the number of input points and boxes.
torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
1).expand(bs, -1, self.image_embedding_size[0],
self.image_embedding_size[1])
return sparse_embeddings, dense_embeddings
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int), None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
return x + self.mlp(self.norm2(x))
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int), None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
return self.proj(x)
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
hw: Tuple[int, int]) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode='linear',
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C

View File

@ -0,0 +1,49 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import torch
from torch import nn
from .decoders import MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
class Sam(nn.Module):
mask_threshold: float = 0.0
image_format: str = 'RGB'
def __init__(
self,
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = (123.675, 116.28, 103.53),
pixel_std: List[float] = (58.395, 57.12, 57.375)
) -> None:
"""
SAM predicts object masks from an image and input prompts.
Note:
All forward() operations moved to SAMPredictor.
Args:
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)

View File

@ -0,0 +1,582 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# --------------------------------------------------------
# TinyViT Model Architecture
# Copyright (c) 2022 Microsoft
# Adapted from LeViT and Swin Transformer
# LeViT: (https://github.com/facebookresearch/levit)
# Swin: (https://github.com/microsoft/swin-transformer)
# Build the TinyViT Model
# --------------------------------------------------------
import itertools
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from ultralytics.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
class PatchEmbed(nn.Module):
def __init__(self, in_chans, embed_dim, resolution, activation):
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
n = embed_dim
self.seq = nn.Sequential(
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
activation(),
Conv2d_BN(n // 2, n, 3, 2, 1),
)
def forward(self, x):
return self.seq(x)
class MBConv(nn.Module):
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
super().__init__()
self.in_chans = in_chans
self.hidden_chans = int(in_chans * expand_ratio)
self.out_chans = out_chans
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
self.act1 = activation()
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
self.act2 = activation()
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
self.act3 = activation()
# NOTE: `DropPath` is needed only for training.
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop_path(x)
x += shortcut
return self.act3(x)
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, out_dim, activation):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.out_dim = out_dim
self.act = activation()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
stride_c = 1 if out_dim in [320, 448, 576] else 2
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x):
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
# (B, C, H, W)
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
return x.flatten(2).transpose(1, 2)
class ConvLayer(nn.Module):
def __init__(
self,
dim,
input_resolution,
depth,
activation,
drop_path=0.,
downsample=None,
use_checkpoint=False,
out_dim=None,
conv_expand_ratio=4.,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
MBConv(
dim,
dim,
conv_expand_ratio,
activation,
drop_path[i] if isinstance(drop_path, list) else drop_path,
) for i in range(depth)])
# patch merging layer
self.downsample = None if downsample is None else downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
def forward(self, x):
for blk in self.blocks:
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm = nn.LayerNorm(in_features)
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.act = act_layer()
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
return self.drop(x)
class Attention(torch.nn.Module):
def __init__(
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=(14, 14),
):
super().__init__()
# (h, w)
assert isinstance(resolution, tuple) and len(resolution) == 2
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.norm = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, h)
self.proj = nn.Linear(self.dh, dim)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,N,C)
B, N, _ = x.shape
# Normalization
x = self.norm(x)
qkv = self.qkv(x)
# (B, N, num_heads, d)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
# (B, num_heads, N, d)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
self.ab = self.ab.to(self.attention_biases.device)
attn = ((q @ k.transpose(-2, -1)) * self.scale +
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
return self.proj(x)
class TinyViTBlock(nn.Module):
"""
TinyViT Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int, int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
activation (torch.nn): the activation function. Default: nn.GELU
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
local_conv_size=3,
activation=nn.GELU,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
assert window_size > 0, 'window_size must be greater than 0'
self.window_size = window_size
self.mlp_ratio = mlp_ratio
# NOTE: `DropPath` is needed only for training.
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
head_dim = dim // num_heads
window_resolution = (window_size, window_size)
self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
mlp_hidden_dim = int(dim * mlp_ratio)
mlp_activation = activation
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
pad = local_conv_size // 2
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
else:
x = x.view(B, H, W, C)
pad_b = (self.window_size - H % self.window_size) % self.window_size
pad_r = (self.window_size - W % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0
if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
pH, pW = H + pad_b, W + pad_r
nH = pH // self.window_size
nW = pW // self.window_size
# window partition
x = x.view(B, nH, self.window_size, nW, self.window_size,
C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
x = self.attn(x)
# window reverse
x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
if padding:
x = x[:, :H, :W].contiguous()
x = x.view(B, L, C)
x = res_x + self.drop_path(x)
x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.local_conv(x)
x = x.view(B, C, L).transpose(1, 2)
return x + self.drop_path(self.mlp(x))
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
class BasicLayer(nn.Module):
"""
A basic TinyViT layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
activation (torch.nn): the activation function. Default: nn.GELU
out_dim (int | optional): the output dimension of the layer. Default: None
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
downsample=None,
use_checkpoint=False,
local_conv_size=3,
activation=nn.GELU,
out_dim=None,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
TinyViTBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
local_conv_size=local_conv_size,
activation=activation,
) for i in range(depth)])
# patch merging layer
self.downsample = None if downsample is None else downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
def forward(self, x):
for blk in self.blocks:
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
return self.weight[:, None, None] * x + self.bias[:, None, None]
class TinyViT(nn.Module):
def __init__(
self,
img_size=224,
in_chans=3,
num_classes=1000,
embed_dims=[96, 192, 384, 768],
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=1.0,
):
super().__init__()
self.img_size = img_size
self.num_classes = num_classes
self.depths = depths
self.num_layers = len(depths)
self.mlp_ratio = mlp_ratio
activation = nn.GELU
self.patch_embed = PatchEmbed(in_chans=in_chans,
embed_dim=embed_dims[0],
resolution=img_size,
activation=activation)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
kwargs = dict(
dim=embed_dims[i_layer],
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(i_layer + 1,
len(embed_dims) - 1)],
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
else:
layer = BasicLayer(num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
local_conv_size=local_conv_size,
**kwargs)
self.layers.append(layer)
# Classifier head
self.norm_head = nn.LayerNorm(embed_dims[-1])
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
# init weights
self.apply(self._init_weights)
self.set_layer_lr_decay(layer_lr_decay)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dims[-1],
256,
kernel_size=1,
bias=False,
),
LayerNorm2d(256),
nn.Conv2d(
256,
256,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(256),
)
def set_layer_lr_decay(self, layer_lr_decay):
decay_rate = layer_lr_decay
# layers -> blocks (depth)
depth = sum(self.depths)
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
def _set_lr_scale(m, scale):
for p in m.parameters():
p.lr_scale = scale
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
i = 0
for layer in self.layers:
for block in layer.blocks:
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
i += 1
if layer.downsample is not None:
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
assert i == depth
for m in [self.norm_head, self.head]:
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
for k, p in self.named_parameters():
p.param_name = k
def _check_lr_scale(m):
for p in m.parameters():
assert hasattr(p, 'lr_scale'), p.param_name
self.apply(_check_lr_scale)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'attention_biases'}
def forward_features(self, x):
# x: (N, C, H, W)
x = self.patch_embed(x)
x = self.layers[0](x)
start_i = 1
for i in range(start_i, len(self.layers)):
layer = self.layers[i]
x = layer(x)
B, _, C = x.size()
x = x.view(B, 64, 64, C)
x = x.permute(0, 3, 1, 2)
return self.neck(x)
def forward(self, x):
return self.forward_features(x)

View File

@ -0,0 +1,230 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
from typing import Tuple, Type
import torch
from torch import Tensor, nn
from ultralytics.nn.modules import MLPBlock
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
))
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
(torch.Tensor): the processed point_embedding
(torch.Tensor): the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Args:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
"""Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
"""Separate the input tensor into the specified number of attention heads."""
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: Tensor) -> Tensor:
"""Recombine the separated attention heads into a single tensor."""
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
"""Compute the attention output given the input query, key, and value tensors."""
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
return self.out_proj(out)

View File

@ -0,0 +1,408 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.torch_utils import select_device
from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
from .build import build_sam
class Predictor(BasePredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
if overrides is None:
overrides = {}
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
super().__init__(cfg, overrides, _callbacks)
# SAM needs retina_masks=True, or the results would be a mess.
self.args.retina_masks = True
# Args for set_image
self.im = None
self.features = None
# Args for set_prompts
self.prompts = {}
# Args for segment everything
self.segment_all = False
def preprocess(self, im):
"""Prepares input image before inference.
Args:
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
"""
if self.im is not None:
return self.im
not_tensor = not isinstance(im, torch.Tensor)
if not_tensor:
im = np.stack(self.pre_transform(im))
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np.ascontiguousarray(im) # contiguous
im = torch.from_numpy(im)
im = im.to(self.device)
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
if not_tensor:
im = (im - self.mean) / self.std
return im
def pre_transform(self, im):
"""
Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Returns:
(list): A list of transformed images.
"""
assert len(im) == 1, 'SAM model does not currently support batched inference'
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
return [letterbox(image=x) for x in im]
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
"""
Predict masks for the given input prompts, using the currently set image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
1 indicates a foreground point and 0 indicates a background point.
masks (np.ndarray, None): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form (N, H, W), where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
# Get prompts from self.prompts first
bboxes = self.prompts.pop('bboxes', bboxes)
points = self.prompts.pop('points', points)
masks = self.prompts.pop('masks', masks)
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
"""
Predict masks for the given input prompts, using the currently set image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
1 indicates a foreground point and 0 indicates a background point.
masks (np.ndarray, None): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form (N, H, W), where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
features = self.model.image_encoder(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = np.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None, :], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes *= r
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
points = (points, labels) if points is not None else None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=bboxes,
masks=masks,
)
# Predict masks
pred_masks, pred_scores = self.model.mask_decoder(
image_embeddings=features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def generate(self,
im,
crop_n_layers=0,
crop_overlap_ratio=512 / 1500,
crop_downscale_factor=1,
point_grids=None,
points_stride=32,
points_batch_size=64,
conf_thres=0.88,
stability_score_thresh=0.95,
stability_score_offset=0.95,
crop_nms_thresh=0.7):
"""Segment the whole image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray), None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
points_stride (int, None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_batch_size (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
conf_thres (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculated the stability score.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
"""
self.segment_all = True
ih, iw = im.shape[2:]
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
if point_grids is None:
point_grids = build_all_layer_point_grids(
points_stride,
crop_n_layers,
crop_downscale_factor,
)
pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
for crop_region, layer_idx in zip(crop_regions, layer_idxs):
x1, y1, x2, y2 = crop_region
w, h = x2 - x1, y2 - y1
area = torch.tensor(w * h, device=im.device)
points_scale = np.array([[w, h]]) # w, h
# Crop image and interpolate to input size
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
# (num_points, 2)
points_for_image = point_grids[layer_idx] * points_scale
crop_masks, crop_scores, crop_bboxes = [], [], []
for (points, ) in batch_iterator(points_batch_size, points_for_image):
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
# Interpolate predicted masks to input size
pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
idx = pred_score > conf_thres
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
stability_score_offset)
idx = stability_score > stability_score_thresh
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
# Bool type is much more memory-efficient.
pred_mask = pred_mask > self.model.mask_threshold
# (N, 4)
pred_bbox = batched_mask_to_box(pred_mask).float()
keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
if not torch.all(keep_mask):
pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]
crop_masks.append(pred_mask)
crop_bboxes.append(pred_bbox)
crop_scores.append(pred_score)
# Do nms within this crop
crop_masks = torch.cat(crop_masks)
crop_bboxes = torch.cat(crop_bboxes)
crop_scores = torch.cat(crop_scores)
keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS
crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
crop_scores = crop_scores[keep]
pred_masks.append(crop_masks)
pred_bboxes.append(crop_bboxes)
pred_scores.append(crop_scores)
region_areas.append(area.expand(len(crop_masks)))
pred_masks = torch.cat(pred_masks)
pred_bboxes = torch.cat(pred_bboxes)
pred_scores = torch.cat(pred_scores)
region_areas = torch.cat(region_areas)
# Remove duplicate masks between crops
if len(crop_regions) > 1:
scores = 1 / region_areas
keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]
return pred_masks, pred_scores, pred_bboxes
def setup_model(self, model, verbose=True):
"""Set up YOLO model with specified thresholds and device."""
device = select_device(self.args.device, verbose=verbose)
if model is None:
model = build_sam(self.args.model)
model.eval()
self.model = model.to(device)
self.device = device
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
# TODO: Temporary settings for compatibility
self.model.pt = False
self.model.triton = False
self.model.stride = 32
self.model.fp16 = False
self.done_warmup = True
def postprocess(self, preds, img, orig_imgs):
"""Post-processes inference output predictions to create detection masks for objects."""
# (N, 1, H, W), (N, 1)
pred_masks, pred_scores = preds[:2]
pred_bboxes = preds[2] if self.segment_all else None
names = dict(enumerate(str(i) for i in range(len(pred_masks))))
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, masks in enumerate([pred_masks]):
orig_img = orig_imgs[i]
if pred_bboxes is not None:
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
masks = masks > self.model.mask_threshold # to bool
img_path = self.batch[0][i]
results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
# Reset segment-all mode.
self.segment_all = False
return results
def setup_source(self, source):
"""Sets up source and inference mode."""
if source is not None:
super().setup_source(source)
def set_image(self, image):
"""Set image in advance.
Args:
image (str | np.ndarray): image file path or np.ndarray image by cv2.
"""
if self.model is None:
model = build_sam(self.args.model)
self.setup_model(model)
self.setup_source(image)
assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.model.image_encoder(im)
self.im = im
break
def set_prompts(self, prompts):
"""Set prompts in advance."""
self.prompts = prompts
def reset_image(self):
self.im = None
self.features = None
@staticmethod
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates. Requires open-cv as a dependency.
Args:
masks (torch.Tensor): Masks, (N, H, W).
min_area (int): Minimum area threshold.
nms_thresh (float): NMS threshold.
Returns:
new_masks (torch.Tensor): New Masks, (N, H, W).
keep (List[int]): The indices of the new masks, which can be used to filter
the corresponding boxes.
"""
if len(masks) == 0:
return masks
# Filter small disconnected regions and holes
new_masks = []
scores = []
for mask in masks:
mask = mask.cpu().numpy().astype(np.uint8)
mask, changed = remove_small_regions(mask, min_area, mode='holes')
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode='islands')
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
new_masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(new_masks)
keep = torchvision.ops.nms(
boxes.float(),
torch.as_tensor(scores),
nms_thresh,
)
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep

View File

@ -0,0 +1 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

View File

@ -0,0 +1,297 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou
from .ops import HungarianMatcher
class DETRLoss(nn.Module):
def __init__(self,
nc=80,
loss_gain=None,
aux_loss=True,
use_fl=True,
use_vfl=False,
use_uni_match=False,
uni_match_ind=0):
"""
DETR loss function.
Args:
nc (int): The number of classes.
loss_gain (dict): The coefficient of loss.
aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
use_vfl (bool): Use VarifocalLoss or not.
use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
uni_match_ind (int): The fixed indices of a layer.
"""
super().__init__()
if loss_gain is None:
loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
self.nc = nc
self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
self.loss_gain = loss_gain
self.aux_loss = aux_loss
self.fl = FocalLoss() if use_fl else None
self.vfl = VarifocalLoss() if use_vfl else None
self.use_uni_match = use_uni_match
self.uni_match_ind = uni_match_ind
self.device = None
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = f'loss_class{postfix}'
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
one_hot = one_hot[..., :-1]
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
if self.fl:
if num_gts and self.vfl:
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
else:
loss_cls = self.fl(pred_scores, one_hot.float())
loss_cls /= max(num_gts, 1) / nq
else:
loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = f'loss_bbox{postfix}'
name_giou = f'loss_giou{postfix}'
loss = {}
if len(gt_bboxes) == 0:
loss[name_bbox] = torch.tensor(0., device=self.device)
loss[name_giou] = torch.tensor(0., device=self.device)
return loss
loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
return {k: v.squeeze() for k, v in loss.items()}
# This function is for future RT-DETR Segment models
# def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
# # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
# name_mask = f'loss_mask{postfix}'
# name_dice = f'loss_dice{postfix}'
#
# loss = {}
# if sum(len(a) for a in gt_mask) == 0:
# loss[name_mask] = torch.tensor(0., device=self.device)
# loss[name_dice] = torch.tensor(0., device=self.device)
# return loss
#
# num_gts = len(gt_mask)
# src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
# src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
# # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
# loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
# torch.tensor([num_gts], dtype=torch.float32))
# loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
# return loss
# This function is for future RT-DETR Segment models
# @staticmethod
# def _dice_loss(inputs, targets, num_gts):
# inputs = F.sigmoid(inputs).flatten(1)
# targets = targets.flatten(1)
# numerator = 2 * (inputs * targets).sum(1)
# denominator = inputs.sum(-1) + targets.sum(-1)
# loss = 1 - (numerator + 1) / (denominator + 1)
# return loss.sum() / num_gts
def _get_loss_aux(self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
match_indices=None,
postfix='',
masks=None,
gt_mask=None):
"""Get auxiliary losses"""
# NOTE: loss class, bbox, giou, mask, dice
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
if match_indices is None and self.use_uni_match:
match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
pred_scores[self.uni_match_ind],
gt_bboxes,
gt_cls,
gt_groups,
masks=masks[self.uni_match_ind] if masks is not None else None,
gt_mask=gt_mask)
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
aux_masks = masks[i] if masks is not None else None
loss_ = self._get_loss(aux_bboxes,
aux_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=aux_masks,
gt_mask=gt_mask,
postfix=postfix,
match_indices=match_indices)
loss[0] += loss_[f'loss_class{postfix}']
loss[1] += loss_[f'loss_bbox{postfix}']
loss[2] += loss_[f'loss_giou{postfix}']
# if masks is not None and gt_mask is not None:
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
# loss[3] += loss_[f'loss_mask{postfix}']
# loss[4] += loss_[f'loss_dice{postfix}']
loss = {
f'loss_class_aux{postfix}': loss[0],
f'loss_bbox_aux{postfix}': loss[1],
f'loss_giou_aux{postfix}': loss[2]}
# if masks is not None and gt_mask is not None:
# loss[f'loss_mask_aux{postfix}'] = loss[3]
# loss[f'loss_dice_aux{postfix}'] = loss[4]
return loss
@staticmethod
def _get_index(match_indices):
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
src_idx = torch.cat([src for (src, _) in match_indices])
dst_idx = torch.cat([dst for (_, dst) in match_indices])
return (batch_idx, src_idx), dst_idx
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
pred_assigned = torch.cat([
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (I, _) in zip(pred_bboxes, match_indices)])
gt_assigned = torch.cat([
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, J) in zip(gt_bboxes, match_indices)])
return pred_assigned, gt_assigned
def _get_loss(self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=None,
gt_mask=None,
postfix='',
match_indices=None):
"""Get losses"""
if match_indices is None:
match_indices = self.matcher(pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=masks,
gt_mask=gt_mask)
idx, gt_idx = self._get_index(match_indices)
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
bs, nq = pred_scores.shape[:2]
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
targets[idx] = gt_cls[gt_idx]
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
if len(gt_bboxes):
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
loss = {}
loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
# if masks is not None and gt_mask is not None:
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
return loss
def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
"""
Args:
pred_bboxes (torch.Tensor): [l, b, query, 4]
pred_scores (torch.Tensor): [l, b, query, num_classes]
batch (dict): A dict includes:
gt_cls (torch.Tensor) with shape [num_gts, ],
gt_bboxes (torch.Tensor): [num_gts, 4],
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
postfix (str): postfix of loss name.
"""
self.device = pred_bboxes.device
match_indices = kwargs.get('match_indices', None)
gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
total_loss = self._get_loss(pred_bboxes[-1],
pred_scores[-1],
gt_bboxes,
gt_cls,
gt_groups,
postfix=postfix,
match_indices=match_indices)
if self.aux_loss:
total_loss.update(
self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
postfix))
return total_loss
class RTDETRDetectionLoss(DETRLoss):
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
pred_bboxes, pred_scores = preds
total_loss = super().forward(pred_bboxes, pred_scores, batch)
if dn_meta is not None:
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
assert len(batch['gt_groups']) == len(dn_pos_idx)
# Denoising match indices
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
# Compute denoising training loss
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
total_loss.update(dn_loss)
else:
total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
return total_loss
@staticmethod
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
"""
Get the match indices for denoising.
Args:
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
dn_num_group (int): The number of groups of denoising.
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
Returns:
dn_match_indices (List(tuple)): Matched indices.
"""
dn_match_indices = []
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
for i, num_gt in enumerate(gt_groups):
if num_gt > 0:
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
gt_idx = gt_idx.repeat(dn_num_group)
assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
dn_match_indices.append((dn_pos_idx[i], gt_idx))
else:
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
return dn_match_indices

View File

@ -0,0 +1,261 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from ultralytics.utils.metrics import bbox_iou
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
class HungarianMatcher(nn.Module):
"""
A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in
an end-to-end fashion.
HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
Attributes:
cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
with_mask (bool): Indicates whether the model makes mask predictions.
num_sample_points (int): The number of sample points used in mask cost calculation.
alpha (float): The alpha factor in Focal Loss calculation.
gamma (float): The gamma factor in Focal Loss calculation.
Methods:
forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the
assignment between predictions and ground truths for a batch.
_cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
"""
def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
super().__init__()
if cost_gain is None:
cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
self.cost_gain = cost_gain
self.use_fl = use_fl
self.with_mask = with_mask
self.num_sample_points = num_sample_points
self.alpha = alpha
self.gamma = gamma
def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
"""
Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
(classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching
between predictions and ground truth based on these costs.
Args:
pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes].
gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ].
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4].
gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
each image.
masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width].
Defaults to None.
gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
Defaults to None.
Returns:
(List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
- index_i is the tensor of indices of the selected predictions (in order)
- index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, nq, nc = pred_scores.shape
if sum(gt_groups) == 0:
return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
# We flatten to compute the cost matrices in a batch
# [batch_size * num_queries, num_classes]
pred_scores = pred_scores.detach().view(-1, nc)
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
# [batch_size * num_queries, 4]
pred_bboxes = pred_bboxes.detach().view(-1, 4)
# Compute the classification cost
pred_scores = pred_scores[:, gt_cls]
if self.use_fl:
neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
cost_class = pos_cost_class - neg_cost_class
else:
cost_class = -pred_scores
# Compute the L1 cost between boxes
cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
# Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
# Final cost matrix
C = self.cost_gain['class'] * cost_class + \
self.cost_gain['bbox'] * cost_bbox + \
self.cost_gain['giou'] * cost_giou
# Compute the mask cost and dice cost
if self.with_mask:
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
C = C.view(bs, nq, -1).cpu()
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
# (idx for queries, idx for gt)
return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
for k, (i, j) in enumerate(indices)]
# This function is for future RT-DETR Segment models
# def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
# assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
# # all masks share the same set of points for efficient matching
# sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
# sample_points = 2.0 * sample_points - 1.0
#
# out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
# out_mask = out_mask.flatten(0, 1)
#
# tgt_mask = torch.cat(gt_mask).unsqueeze(1)
# sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
# tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
#
# with torch.cuda.amp.autocast(False):
# # binary cross entropy cost
# pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
# neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
# cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
# cost_mask /= self.num_sample_points
#
# # dice cost
# out_mask = F.sigmoid(out_mask)
# numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
# denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
# cost_dice = 1 - (numerator + 1) / (denominator + 1)
#
# C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
# return C
def get_cdn_group(batch,
num_classes,
num_queries,
class_embed,
num_dn=100,
cls_noise_ratio=0.5,
box_noise_scale=1.0,
training=False):
"""
Get contrastive denoising training group. This function creates a contrastive denoising training group with
positive and negative samples from the ground truths (gt). It applies noise to the class labels and bounding
box coordinates, and returns the modified labels, bounding boxes, attention mask and meta information.
Args:
batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
(torch.Tensor with shape [num_gts, 4]), 'gt_groups' (List(int)) which is a list of batch size length
indicating the number of gts of each image.
num_classes (int): Number of classes.
num_queries (int): Number of queries.
class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
num_dn (int, optional): Number of denoising. Defaults to 100.
cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
training (bool, optional): If it's in training mode. Defaults to False.
Returns:
(Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
is less than or equal to 0, the function returns None for all elements in the tuple.
"""
if (not training) or num_dn <= 0:
return None, None, None, None
gt_groups = batch['gt_groups']
total_num = sum(gt_groups)
max_nums = max(gt_groups)
if max_nums == 0:
return None, None, None, None
num_group = num_dn // max_nums
num_group = 1 if num_group == 0 else num_group
# pad gt to max_num of a batch
bs = len(gt_groups)
gt_cls = batch['cls'] # (bs*num, )
gt_bbox = batch['bboxes'] # bs*num, 4
b_idx = batch['batch_idx']
# each group has positive and negative queries.
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
# positive and negative mask
# (bs*num*num_group, ), the second total_num*num_group part as negative samples
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
if cls_noise_ratio > 0:
# half of bbox prob
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
idx = torch.nonzero(mask).squeeze(-1)
# randomly put a new one here
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
dn_cls[idx] = new_label
if box_noise_scale > 0:
known_bbox = xywh2xyxy(dn_bbox)
diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
rand_part = torch.rand_like(dn_bbox)
rand_part[neg_idx] += 1.0
rand_part *= rand_sign
known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0)
dn_bbox = xyxy2xywh(known_bbox)
dn_bbox = inverse_sigmoid(dn_bbox)
# total denoising queries
num_dn = int(max_nums * 2 * num_group)
# class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
tgt_size = num_dn + num_queries
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
# match query cannot see the reconstruct
attn_mask[num_dn:, :num_dn] = True
# reconstruct cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
if i == num_group - 1:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
else:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
dn_meta = {
'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
'dn_num_group': num_group,
'dn_num_split': [num_dn, num_queries]}
return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
class_embed.device), dn_meta
def inverse_sigmoid(x, eps=1e-6):
"""Inverse sigmoid function."""
x = x.clip(min=0., max=1.)
return torch.log(x / (1 - x + eps) + eps)

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo import classify, detect, pose, segment
from .model import YOLO
__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO'

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator
__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator'

View File

@ -0,0 +1,49 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
class ClassificationPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a classification model.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.classify import ClassificationPredictor
args = dict(model='yolov8n-cls.pt', source=ASSETS)
predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'classify'
def preprocess(self, img):
"""Converts input image to model-compatible data type."""
if not isinstance(img, torch.Tensor):
img = torch.stack([self.transforms(im) for im in img], dim=0)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions to return Results objects."""
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i]
img_path = self.batch[0][i]
results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
return results

View File

@ -0,0 +1,150 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torchvision
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
class ClassificationTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a classification model.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
Example:
```python
from ultralytics.models.yolo.classify import ClassificationTrainer
args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
trainer = ClassificationTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
if overrides is None:
overrides = {}
overrides['task'] = 'classify'
if overrides.get('imgsz') is None:
overrides['imgsz'] = 224
super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
self.model.names = self.data['names']
def get_model(self, cfg=None, weights=None, verbose=True):
"""Returns a modified PyTorch model configured for training YOLO."""
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
for m in model.modules():
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model
def setup_model(self):
"""Load, create or download model for any task."""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model, ckpt = str(self.model), None
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith('.pt'):
self.model, ckpt = attempt_load_one_weight(model, device='cpu')
for p in self.model.parameters():
p.requires_grad = True # for training
elif model.split('.')[-1] in ('yaml', 'yml'):
self.model = self.get_model(cfg=model)
elif model in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
else:
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return ckpt
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode)
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
# Attach inference transforms
if mode != 'train':
if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
self.model.transforms = loader.dataset.torch_transforms
return loader
def preprocess_batch(self, batch):
"""Preprocesses a batch of images and classes."""
batch['img'] = batch['img'].to(self.device)
batch['cls'] = batch['cls'].to(self.device)
return batch
def progress_string(self):
"""Returns a formatted string showing training progress."""
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def get_validator(self):
"""Returns an instance of ClassificationValidator for validation."""
self.loss_names = ['loss']
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
segmentation & detection
"""
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items))
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
def final_eval(self):
"""Evaluate trained model and save validation results."""
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
LOGGER.info(f'\nValidating {f}...')
self.validator.args.data = self.args.data
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None)
self.run_callbacks('on_fit_epoch_end')
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(
images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)

View File

@ -0,0 +1,110 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a classification model.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
Example:
```python
from ultralytics.models.yolo.classify import ClassificationValidator
args = dict(model='yolov8n-cls.pt', data='imagenet10')
validator = ClassificationValidator(args=args)
validator()
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.targets = None
self.pred = None
self.args.task = 'classify'
self.metrics = ClassifyMetrics()
def get_desc(self):
"""Returns a formatted string summarizing classification metrics."""
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
def init_metrics(self, model):
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
self.names = model.names
self.nc = len(model.names)
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify')
self.pred = []
self.targets = []
def preprocess(self, batch):
"""Preprocesses input batch and returns it."""
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
batch['cls'] = batch['cls'].to(self.device)
return batch
def update_metrics(self, preds, batch):
"""Updates running metrics with model predictions and batch targets."""
n5 = min(len(self.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
self.targets.append(batch['cls'])
def finalize_metrics(self, *args, **kwargs):
"""Finalizes metrics of the model such as confusion_matrix and speed."""
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
self.metrics.save_dir = self.save_dir
def get_stats(self):
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict
def build_dataset(self, img_path):
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
def get_dataloader(self, dataset_path, batch_size):
"""Builds and returns a data loader for classification tasks with given parameters."""
dataset = self.build_dataset(dataset_path)
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
def print_results(self):
"""Prints evaluation metrics for YOLO object detection model."""
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(
images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator
__all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator'

View File

@ -0,0 +1,41 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
class DetectionPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a detection model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.detect import DetectionPredictor
args = dict(model='yolov8n.pt', source=ASSETS)
predictor = DetectionPredictor(overrides=args)
predictor.predict_cli()
```
"""
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i]
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
img_path = self.batch[0][i]
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
return results

View File

@ -0,0 +1,116 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
import numpy as np
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
class DetectionTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionTrainer
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
trainer = DetectionTrainer(overrides=args)
trainer.train()
```
"""
def build_dataset(self, img_path, mode='train', batch=None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Construct and return dataloader."""
assert mode in ['train', 'val']
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode, batch_size)
shuffle = mode == 'train'
if getattr(dataset, 'rect', False) and shuffle:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
workers = self.args.workers if mode == 'train' else self.args.workers * 2
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
return batch
def set_model_attributes(self):
"""nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data['nc'] # attach number of classes to model
self.model.names = self.data['names'] # attach class names to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
segmentation & detection
"""
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
else:
return keys
def progress_string(self):
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'],
batch_idx=batch['batch_idx'],
cls=batch['cls'].squeeze(-1),
bboxes=batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)

View File

@ -0,0 +1,268 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
from pathlib import Path
import numpy as np
import torch
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import output_to_target, plot_images
from ultralytics.utils.torch_utils import de_parallel
class DetectionValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionValidator
args = dict(model='yolov8n.pt', data='coco8.yaml')
validator = DetectionValidator(args=args)
validator()
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize detection model with necessary variables and settings."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.nt_per_class = None
self.is_coco = False
self.class_map = None
self.args.task = 'detect'
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
self.lb = [] # for autolabelling
def preprocess(self, batch):
"""Preprocesses batch of images for YOLO training."""
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
for k in ['batch_idx', 'cls', 'bboxes']:
batch[k] = batch[k].to(self.device)
if self.args.save_hybrid:
height, width = batch['img'].shape[2:]
nb = len(batch['img'])
bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device)
self.lb = [
torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1)
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
return batch
def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, '') # validation path
self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
self.names = model.names
self.nc = len(model.names)
self.metrics.names = self.names
self.metrics.plot = self.args.plots
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
self.seen = 0
self.jdict = []
self.stats = []
def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det)
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
if self.args.save_txt:
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, shape, file)
def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any():
self.metrics.process(*stats)
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
return self.metrics.results_dict
def print_results(self):
"""Prints training/validation set metrics per class."""
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
LOGGER.warning(
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index):
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
def _process_batch(self, detections, labels):
"""
Return correct prediction matrix.
Args:
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class.
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
Each label is of the format: class, x1, y1, x2, y2.
Returns:
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
"""
iou = box_iou(labels[:, 1:], detections[:, :4])
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
def build_dataset(self, img_path, mode='val', batch=None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=gs)
def get_dataloader(self, dataset_path, batch_size):
"""Construct and return dataloader."""
dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val')
return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
*output_to_target(preds, max_det=self.args.max_det),
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
for *xyxy, conf, cls in predn.tolist():
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(file, 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5)})
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, 'bbox')
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats

View File

@ -0,0 +1,36 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.model import Model
from ultralytics.models import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
class YOLO(Model):
"""
YOLO (You Only Look Once) object detection model.
"""
@property
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes"""
return {
'classify': {
'model': ClassificationModel,
'trainer': yolo.classify.ClassificationTrainer,
'validator': yolo.classify.ClassificationValidator,
'predictor': yolo.classify.ClassificationPredictor, },
'detect': {
'model': DetectionModel,
'trainer': yolo.detect.DetectionTrainer,
'validator': yolo.detect.DetectionValidator,
'predictor': yolo.detect.DetectionPredictor, },
'segment': {
'model': SegmentationModel,
'trainer': yolo.segment.SegmentationTrainer,
'validator': yolo.segment.SegmentationValidator,
'predictor': yolo.segment.SegmentationPredictor, },
'pose': {
'model': PoseModel,
'trainer': yolo.pose.PoseTrainer,
'validator': yolo.pose.PoseValidator,
'predictor': yolo.pose.PosePredictor, }, }

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator
__all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor'

View File

@ -0,0 +1,52 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
class PosePredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a pose model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.pose import PosePredictor
args = dict(model='yolov8n-pose.pt', source=ASSETS)
predictor = PosePredictor(overrides=args)
predictor.predict_cli()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'pose'
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
def postprocess(self, preds, img, orig_imgs):
"""Return detection results for a given input image or list of images."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes,
nc=len(self.model.names))
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i]
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
img_path = self.batch[0][i]
results.append(
Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
return results

View File

@ -0,0 +1,73 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.plotting import plot_images, plot_results
class PoseTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a pose model.
Example:
```python
from ultralytics.models.yolo.pose import PoseTrainer
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
trainer = PoseTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a PoseTrainer object with specified configurations and overrides."""
if overrides is None:
overrides = {}
overrides['task'] = 'pose'
super().__init__(cfg, overrides, _callbacks)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
def get_model(self, cfg=None, weights=None, verbose=True):
"""Get pose estimation model with specified configuration and weights."""
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
if weights:
model.load(weights)
return model
def set_model_attributes(self):
"""Sets keypoints shape attribute of PoseModel."""
super().set_model_attributes()
self.model.kpt_shape = self.data['kpt_shape']
def get_validator(self):
"""Returns an instance of the PoseValidator class for validation."""
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def plot_training_samples(self, batch, ni):
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
images = batch['img']
kpts = batch['keypoints']
cls = batch['cls'].squeeze(-1)
bboxes = batch['bboxes']
paths = batch['im_file']
batch_idx = batch['batch_idx']
plot_images(images,
batch_idx,
cls,
bboxes,
kpts=kpts,
paths=paths,
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png

View File

@ -0,0 +1,215 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class PoseValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a pose model.
Example:
```python
from ultralytics.models.yolo.pose import PoseValidator
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
validator = PoseValidator(args=args)
validator()
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.sigma = None
self.kpt_shape = None
self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
batch = super().preprocess(batch)
batch['keypoints'] = batch['keypoints'].to(self.device).float()
return batch
def get_desc(self):
"""Returns description of evaluation metrics in string format."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Apply non-maximum suppression and return detections with high confidence scores."""
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
def init_metrics(self, model):
"""Initiate pose estimation metrics for YOLO model."""
super().init_metrics(model)
self.kpt_shape = self.data['kpt_shape']
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
kpts = batch['keypoints'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
nk = kpts.shape[1] # number of keypoints
shape = batch['ori_shape'][si]
correct_kpts = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, correct_kpts, *torch.zeros(
(2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
pred_kpts = predn[:, 6:].view(npr, nk, -1)
ops.scale_coords(batch['img'][si].shape[1:], pred_kpts, shape, ratio_pad=batch['ratio_pad'][si])
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
tkpts = kpts.clone()
tkpts[..., 0] *= width
tkpts[..., 1] *= height
tkpts = ops.scale_coords(batch['img'][si].shape[1:], tkpts, shape, ratio_pad=batch['ratio_pad'][si])
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn[:, :6], labelsn)
correct_kpts = self._process_batch(predn[:, :6], labelsn, pred_kpts, tkpts)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
# Append correct_masks, correct_boxes, pconf, pcls, tcls
self.stats.append((correct_bboxes, correct_kpts, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
"""
Return correct prediction matrix.
Args:
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class.
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
Each label is of the format: class, x1, y1, x2, y2.
pred_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing predicted keypoints.
51 corresponds to 17 keypoints each with 3 values.
gt_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing ground truth keypoints.
Returns:
torch.Tensor: Correct prediction matrix of shape [N, 10] for 10 IoU levels.
"""
if pred_kpts is not None and gt_kpts is not None:
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
area = ops.xyxy2xywh(labels[:, 1:])[:, 2:].prod(1) * 0.53
iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
def plot_val_samples(self, batch, ni):
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
kpts=batch['keypoints'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
plot_images(batch['img'],
*output_to_target(preds, max_det=self.args.max_det),
kpts=pred_kpts,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'keypoints': p[6:],
'score': round(p[4], 5)})
def eval_json(self, stats):
"""Evaluates object detection model using COCO JSON format."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator
__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'

View File

@ -0,0 +1,53 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
class SegmentationPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a segmentation model.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.segment import SegmentationPredictor
args = dict(model='yolov8n-seg.pt', source=ASSETS)
predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, pred in enumerate(p):
orig_img = orig_imgs[i]
img_path = self.batch[0][i]
if not len(pred): # save empty boxes
masks = None
elif self.args.retina_masks:
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results

View File

@ -0,0 +1,58 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.utils.plotting import plot_images, plot_results
class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a segmentation model.
Example:
```python
from ultralytics.models.yolo.segment import SegmentationTrainer
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
trainer = SegmentationTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a SegmentationTrainer object with given arguments."""
if overrides is None:
overrides = {}
overrides['task'] = 'segment'
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return SegmentationModel initialized with specified config and weights."""
model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return an instance of SegmentationValidator for validation of YOLO model."""
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png

View File

@ -0,0 +1,247 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from multiprocessing.pool import ThreadPool
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class SegmentationValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a segmentation model.
Example:
```python
from ultralytics.models.yolo.segment import SegmentationValidator
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
validator = SegmentationValidator(args=args)
validator()
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.plot_masks = None
self.process = None
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch)
batch['masks'] = batch['masks'].to(self.device).float()
return batch
def init_metrics(self, model):
"""Initialize metrics and select mask processing function based on save_json flag."""
super().init_metrics(model)
self.plot_masks = []
if self.args.save_json:
check_requirements('pycocotools>=2.0.6')
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Post-processes YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
def update_metrics(self, preds, batch):
"""Metrics."""
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
(2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Masks
midx = [si] if self.args.overlap_mask else idx
gt_masks = batch['masks'][midx]
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
correct_masks = self._process_batch(predn,
labelsn,
pred_masks,
gt_masks,
overlap=self.args.overlap_mask,
masks=True)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
# Append correct_masks, correct_boxes, pconf, pcls, tcls
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# Save
if self.args.save_json:
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
shape,
ratio_pad=batch['ratio_pad'][si])
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs):
"""Sets speed and confusion matrix for evaluation metrics."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix
Args:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (array[N, 10]), for 10 IoU levels
"""
if masks:
if overlap:
nl = len(labels)
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
if gt_masks.shape[1:] != pred_masks.shape[1:]:
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(
batch['img'],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
"""Save one JSON result."""
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
from pycocotools.mask import encode # noqa
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
rle['counts'] = rle['counts'].decode('utf-8')
return rle
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
pred_masks = np.transpose(pred_masks, (2, 0, 1))
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5),
'segmentation': rles[i]})
def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats