update
This commit is contained in:
7
ytracking/ultralytics/models/__init__.py
Normal file
7
ytracking/ultralytics/models/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .rtdetr import RTDETR
|
||||
from .sam import SAM
|
||||
from .yolo import YOLO
|
||||
|
||||
__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import
|
BIN
ytracking/ultralytics/models/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
ytracking/ultralytics/models/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
ytracking/ultralytics/models/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
ytracking/ultralytics/models/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
8
ytracking/ultralytics/models/fastsam/__init__.py
Normal file
8
ytracking/ultralytics/models/fastsam/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import FastSAM
|
||||
from .predict import FastSAMPredictor
|
||||
from .prompt import FastSAMPrompt
|
||||
from .val import FastSAMValidator
|
||||
|
||||
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
33
ytracking/ultralytics/models/fastsam/model.py
Normal file
33
ytracking/ultralytics/models/fastsam/model.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.engine.model import Model
|
||||
|
||||
from .predict import FastSAMPredictor
|
||||
from .val import FastSAMValidator
|
||||
|
||||
|
||||
class FastSAM(Model):
|
||||
"""
|
||||
FastSAM model interface.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics import FastSAM
|
||||
|
||||
model = FastSAM('last.pt')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, model='FastSAM-x.pt'):
|
||||
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
|
||||
if str(model) == 'FastSAM.pt':
|
||||
model = 'FastSAM-x.pt'
|
||||
assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
|
||||
super().__init__(model=model, task='segment')
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
|
51
ytracking/ultralytics/models/fastsam/predict.py
Normal file
51
ytracking/ultralytics/models/fastsam/predict.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.models.fastsam.utils import bbox_iou
|
||||
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
|
||||
|
||||
class FastSAMPredictor(DetectionPredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=len(self.model.names),
|
||||
classes=self.args.classes)
|
||||
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
|
||||
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
|
||||
full_box = full_box.view(1, -1)
|
||||
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
|
||||
if critical_iou_index.numel() != 0:
|
||||
full_box[0][4] = p[0][critical_iou_index][:, 4]
|
||||
full_box[0][6:] = p[0][critical_iou_index][:, 6:]
|
||||
p[0][critical_iou_index] = full_box
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
results = []
|
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
||||
for i, pred in enumerate(p):
|
||||
orig_img = orig_imgs[i]
|
||||
img_path = self.batch[0][i]
|
||||
if not len(pred): # save empty boxes
|
||||
masks = None
|
||||
elif self.args.retina_masks:
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
|
||||
else:
|
||||
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
|
||||
return results
|
307
ytracking/ultralytics/models/fastsam/prompt.py
Normal file
307
ytracking/ultralytics/models/fastsam/prompt.py
Normal file
@ -0,0 +1,307 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.utils import TQDM
|
||||
|
||||
|
||||
class FastSAMPrompt:
|
||||
|
||||
def __init__(self, source, results, device='cuda') -> None:
|
||||
self.device = device
|
||||
self.results = results
|
||||
self.source = source
|
||||
|
||||
# Import and assign clip
|
||||
try:
|
||||
import clip # for linear_assignment
|
||||
except ImportError:
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
check_requirements('git+https://github.com/openai/CLIP.git')
|
||||
import clip
|
||||
self.clip = clip
|
||||
|
||||
@staticmethod
|
||||
def _segment_image(image, bbox):
|
||||
image_array = np.array(image)
|
||||
segmented_image_array = np.zeros_like(image_array)
|
||||
x1, y1, x2, y2 = bbox
|
||||
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
|
||||
segmented_image = Image.fromarray(segmented_image_array)
|
||||
black_image = Image.new('RGB', image.size, (255, 255, 255))
|
||||
# transparency_mask = np.zeros_like((), dtype=np.uint8)
|
||||
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
|
||||
transparency_mask[y1:y2, x1:x2] = 255
|
||||
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
|
||||
black_image.paste(segmented_image, mask=transparency_mask_image)
|
||||
return black_image
|
||||
|
||||
@staticmethod
|
||||
def _format_results(result, filter=0):
|
||||
annotations = []
|
||||
n = len(result.masks.data) if result.masks is not None else 0
|
||||
for i in range(n):
|
||||
mask = result.masks.data[i] == 1.0
|
||||
if torch.sum(mask) >= filter:
|
||||
annotation = {
|
||||
'id': i,
|
||||
'segmentation': mask.cpu().numpy(),
|
||||
'bbox': result.boxes.data[i],
|
||||
'score': result.boxes.conf[i]}
|
||||
annotation['area'] = annotation['segmentation'].sum()
|
||||
annotations.append(annotation)
|
||||
return annotations
|
||||
|
||||
@staticmethod
|
||||
def _get_bbox_from_mask(mask):
|
||||
mask = mask.astype(np.uint8)
|
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
x1, y1, w, h = cv2.boundingRect(contours[0])
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
if len(contours) > 1:
|
||||
for b in contours:
|
||||
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
||||
x1 = min(x1, x_t)
|
||||
y1 = min(y1, y_t)
|
||||
x2 = max(x2, x_t + w_t)
|
||||
y2 = max(y2, y_t + h_t)
|
||||
return [x1, y1, x2, y2]
|
||||
|
||||
def plot(self,
|
||||
annotations,
|
||||
output,
|
||||
bbox=None,
|
||||
points=None,
|
||||
point_label=None,
|
||||
mask_random_color=True,
|
||||
better_quality=True,
|
||||
retina=False,
|
||||
with_contours=True):
|
||||
pbar = TQDM(annotations, total=len(annotations))
|
||||
for ann in pbar:
|
||||
result_name = os.path.basename(ann.path)
|
||||
image = ann.orig_img
|
||||
original_h, original_w = ann.orig_shape
|
||||
# for macOS only
|
||||
# plt.switch_backend('TkAgg')
|
||||
plt.figure(figsize=(original_w / 100, original_h / 100))
|
||||
# Add subplot with no margin.
|
||||
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
||||
plt.margins(0, 0)
|
||||
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
||||
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
||||
plt.imshow(image)
|
||||
|
||||
if ann.masks is not None:
|
||||
masks = ann.masks.data
|
||||
if better_quality:
|
||||
if isinstance(masks[0], torch.Tensor):
|
||||
masks = np.array(masks.cpu())
|
||||
for i, mask in enumerate(masks):
|
||||
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
|
||||
masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
|
||||
|
||||
self.fast_show_mask(
|
||||
masks,
|
||||
plt.gca(),
|
||||
random_color=mask_random_color,
|
||||
bbox=bbox,
|
||||
points=points,
|
||||
pointlabel=point_label,
|
||||
retinamask=retina,
|
||||
target_height=original_h,
|
||||
target_width=original_w,
|
||||
)
|
||||
|
||||
if with_contours:
|
||||
contour_all = []
|
||||
temp = np.zeros((original_h, original_w, 1))
|
||||
for i, mask in enumerate(masks):
|
||||
mask = mask.astype(np.uint8)
|
||||
if not retina:
|
||||
mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contour_all.extend(iter(contours))
|
||||
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
||||
color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
|
||||
contour_mask = temp / 255 * color.reshape(1, 1, -1)
|
||||
plt.imshow(contour_mask)
|
||||
|
||||
plt.axis('off')
|
||||
fig = plt.gcf()
|
||||
|
||||
# Check if the canvas has been drawn
|
||||
if fig.canvas.get_renderer() is None: # macOS requires this or tests fail
|
||||
fig.canvas.draw()
|
||||
|
||||
save_path = Path(output) / result_name
|
||||
save_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
||||
image.save(save_path)
|
||||
plt.close()
|
||||
pbar.set_description(f'Saving {result_name} to {save_path}')
|
||||
|
||||
@staticmethod
|
||||
def fast_show_mask(
|
||||
annotation,
|
||||
ax,
|
||||
random_color=False,
|
||||
bbox=None,
|
||||
points=None,
|
||||
pointlabel=None,
|
||||
retinamask=True,
|
||||
target_height=960,
|
||||
target_width=960,
|
||||
):
|
||||
n, h, w = annotation.shape # batch, height, width
|
||||
|
||||
areas = np.sum(annotation, axis=(1, 2))
|
||||
annotation = annotation[np.argsort(areas)]
|
||||
|
||||
index = (annotation != 0).argmax(axis=0)
|
||||
if random_color:
|
||||
color = np.random.random((n, 1, 1, 3))
|
||||
else:
|
||||
color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
|
||||
transparency = np.ones((n, 1, 1, 1)) * 0.6
|
||||
visual = np.concatenate([color, transparency], axis=-1)
|
||||
mask_image = np.expand_dims(annotation, -1) * visual
|
||||
|
||||
show = np.zeros((h, w, 4))
|
||||
h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
|
||||
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
||||
|
||||
show[h_indices, w_indices, :] = mask_image[indices]
|
||||
if bbox is not None:
|
||||
x1, y1, x2, y2 = bbox
|
||||
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
|
||||
# Draw point
|
||||
if points is not None:
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
s=20,
|
||||
c='y',
|
||||
)
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
s=20,
|
||||
c='m',
|
||||
)
|
||||
|
||||
if not retinamask:
|
||||
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
|
||||
ax.imshow(show)
|
||||
|
||||
@torch.no_grad()
|
||||
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
|
||||
preprocessed_images = [preprocess(image).to(device) for image in elements]
|
||||
tokenized_text = self.clip.tokenize([search_text]).to(device)
|
||||
stacked_images = torch.stack(preprocessed_images)
|
||||
image_features = model.encode_image(stacked_images)
|
||||
text_features = model.encode_text(tokenized_text)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
probs = 100.0 * image_features @ text_features.T
|
||||
return probs[:, 0].softmax(dim=0)
|
||||
|
||||
def _crop_image(self, format_results):
|
||||
if os.path.isdir(self.source):
|
||||
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
|
||||
image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
|
||||
ori_w, ori_h = image.size
|
||||
annotations = format_results
|
||||
mask_h, mask_w = annotations[0]['segmentation'].shape
|
||||
if ori_w != mask_w or ori_h != mask_h:
|
||||
image = image.resize((mask_w, mask_h))
|
||||
cropped_boxes = []
|
||||
cropped_images = []
|
||||
not_crop = []
|
||||
filter_id = []
|
||||
for _, mask in enumerate(annotations):
|
||||
if np.sum(mask['segmentation']) <= 100:
|
||||
filter_id.append(_)
|
||||
continue
|
||||
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
|
||||
cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片
|
||||
cropped_images.append(bbox) # 保存裁剪的图片的bbox
|
||||
|
||||
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
|
||||
|
||||
def box_prompt(self, bbox):
|
||||
if self.results[0].masks is not None:
|
||||
assert (bbox[2] != 0 and bbox[3] != 0)
|
||||
if os.path.isdir(self.source):
|
||||
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
|
||||
masks = self.results[0].masks.data
|
||||
target_height, target_width = self.results[0].orig_shape
|
||||
h = masks.shape[1]
|
||||
w = masks.shape[2]
|
||||
if h != target_height or w != target_width:
|
||||
bbox = [
|
||||
int(bbox[0] * w / target_width),
|
||||
int(bbox[1] * h / target_height),
|
||||
int(bbox[2] * w / target_width),
|
||||
int(bbox[3] * h / target_height), ]
|
||||
bbox[0] = max(round(bbox[0]), 0)
|
||||
bbox[1] = max(round(bbox[1]), 0)
|
||||
bbox[2] = min(round(bbox[2]), w)
|
||||
bbox[3] = min(round(bbox[3]), h)
|
||||
|
||||
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
||||
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
||||
|
||||
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
|
||||
orig_masks_area = torch.sum(masks, dim=(1, 2))
|
||||
|
||||
union = bbox_area + orig_masks_area - masks_area
|
||||
IoUs = masks_area / union
|
||||
max_iou_index = torch.argmax(IoUs)
|
||||
|
||||
self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
|
||||
return self.results
|
||||
|
||||
def point_prompt(self, points, pointlabel): # numpy 处理
|
||||
if self.results[0].masks is not None:
|
||||
if os.path.isdir(self.source):
|
||||
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
|
||||
masks = self._format_results(self.results[0], 0)
|
||||
target_height, target_width = self.results[0].orig_shape
|
||||
h = masks[0]['segmentation'].shape[0]
|
||||
w = masks[0]['segmentation'].shape[1]
|
||||
if h != target_height or w != target_width:
|
||||
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
||||
onemask = np.zeros((h, w))
|
||||
for annotation in masks:
|
||||
mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
|
||||
for i, point in enumerate(points):
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
||||
onemask += mask
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
|
||||
onemask -= mask
|
||||
onemask = onemask >= 1
|
||||
self.results[0].masks.data = torch.tensor(np.array([onemask]))
|
||||
return self.results
|
||||
|
||||
def text_prompt(self, text):
|
||||
if self.results[0].masks is not None:
|
||||
format_results = self._format_results(self.results[0], 0)
|
||||
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
|
||||
clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
|
||||
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
|
||||
max_idx = scores.argsort()
|
||||
max_idx = max_idx[-1]
|
||||
max_idx += sum(np.array(filter_id) <= int(max_idx))
|
||||
self.results[0].masks.data = torch.tensor(np.array([ann['segmentation'] for ann in annotations]))
|
||||
return self.results
|
||||
|
||||
def everything_prompt(self):
|
||||
return self.results
|
67
ytracking/ultralytics/models/fastsam/utils.py
Normal file
67
ytracking/ultralytics/models/fastsam/utils.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
|
||||
"""
|
||||
Adjust bounding boxes to stick to image border if they are within a certain threshold.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): (n, 4)
|
||||
image_shape (tuple): (height, width)
|
||||
threshold (int): pixel threshold
|
||||
|
||||
Returns:
|
||||
adjusted_boxes (torch.Tensor): adjusted bounding boxes
|
||||
"""
|
||||
|
||||
# Image dimensions
|
||||
h, w = image_shape
|
||||
|
||||
# Adjust boxes
|
||||
boxes[boxes[:, 0] < threshold, 0] = 0 # x1
|
||||
boxes[boxes[:, 1] < threshold, 1] = 0 # y1
|
||||
boxes[boxes[:, 2] > w - threshold, 2] = w # x2
|
||||
boxes[boxes[:, 3] > h - threshold, 3] = h # y2
|
||||
return boxes
|
||||
|
||||
|
||||
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
|
||||
"""
|
||||
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
|
||||
|
||||
Args:
|
||||
box1 (torch.Tensor): (4, )
|
||||
boxes (torch.Tensor): (n, 4)
|
||||
iou_thres (float): IoU threshold
|
||||
image_shape (tuple): (height, width)
|
||||
raw_output (bool): If True, return the raw IoU values instead of the indices
|
||||
|
||||
Returns:
|
||||
high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
|
||||
"""
|
||||
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
|
||||
# obtain coordinates for intersections
|
||||
x1 = torch.max(box1[0], boxes[:, 0])
|
||||
y1 = torch.max(box1[1], boxes[:, 1])
|
||||
x2 = torch.min(box1[2], boxes[:, 2])
|
||||
y2 = torch.min(box1[3], boxes[:, 3])
|
||||
|
||||
# compute the area of intersection
|
||||
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
|
||||
|
||||
# compute the area of both individual boxes
|
||||
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
# compute the area of union
|
||||
union = box1_area + box2_area - intersection
|
||||
|
||||
# compute the IoU
|
||||
iou = intersection / union # Should be shape (n, )
|
||||
if raw_output:
|
||||
return 0 if iou.numel() == 0 else iou
|
||||
|
||||
# return indices of boxes with IoU > thres
|
||||
return torch.nonzero(iou > iou_thres).flatten()
|
14
ytracking/ultralytics/models/fastsam/val.py
Normal file
14
ytracking/ultralytics/models/fastsam/val.py
Normal file
@ -0,0 +1,14 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.models.yolo.segment import SegmentationValidator
|
||||
from ultralytics.utils.metrics import SegmentMetrics
|
||||
|
||||
|
||||
class FastSAMValidator(SegmentationValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
|
||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
7
ytracking/ultralytics/models/nas/__init__.py
Normal file
7
ytracking/ultralytics/models/nas/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import NAS
|
||||
from .predict import NASPredictor
|
||||
from .val import NASValidator
|
||||
|
||||
__all__ = 'NASPredictor', 'NASValidator', 'NAS'
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
ytracking/ultralytics/models/nas/__pycache__/val.cpython-38.pyc
Normal file
BIN
ytracking/ultralytics/models/nas/__pycache__/val.cpython-38.pyc
Normal file
Binary file not shown.
BIN
ytracking/ultralytics/models/nas/__pycache__/val.cpython-39.pyc
Normal file
BIN
ytracking/ultralytics/models/nas/__pycache__/val.cpython-39.pyc
Normal file
Binary file not shown.
61
ytracking/ultralytics/models/nas/model.py
Normal file
61
ytracking/ultralytics/models/nas/model.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
YOLO-NAS model interface.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
```
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
||||
|
||||
from .predict import NASPredictor
|
||||
from .val import NASValidator
|
||||
|
||||
|
||||
class NAS(Model):
|
||||
|
||||
def __init__(self, model='yolo_nas_s.pt') -> None:
|
||||
assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
|
||||
super().__init__(model, task='detect')
|
||||
|
||||
@smart_inference_mode()
|
||||
def _load(self, weights: str, task: str):
|
||||
# Load or create new NAS model
|
||||
import super_gradients
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == '.pt':
|
||||
self.model = torch.load(weights)
|
||||
elif suffix == '':
|
||||
self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
|
||||
# Standardize model
|
||||
self.model.fuse = lambda verbose=True: self.model
|
||||
self.model.stride = torch.tensor([32])
|
||||
self.model.names = dict(enumerate(self.model._class_names))
|
||||
self.model.is_fused = lambda: False # for info()
|
||||
self.model.yaml = {} # for info()
|
||||
self.model.pt_path = weights # for export()
|
||||
self.model.task = 'detect' # for export()
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
|
35
ytracking/ultralytics/models/nas/predict.py
Normal file
35
ytracking/ultralytics/models/nas/predict.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import ops
|
||||
|
||||
|
||||
class NASPredictor(BasePredictor):
|
||||
|
||||
def postprocess(self, preds_in, img, orig_imgs):
|
||||
"""Postprocess predictions and returns a list of Results objects."""
|
||||
|
||||
# Cat boxes and class scores
|
||||
boxes = ops.xyxy2xywh(preds_in[0][0])
|
||||
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
|
||||
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
classes=self.args.classes)
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i]
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
img_path = self.batch[0][i]
|
||||
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
|
||||
return results
|
24
ytracking/ultralytics/models/nas/val.py
Normal file
24
ytracking/ultralytics/models/nas/val.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import ops
|
||||
|
||||
__all__ = ['NASValidator']
|
||||
|
||||
|
||||
class NASValidator(DetectionValidator):
|
||||
|
||||
def postprocess(self, preds_in):
|
||||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
boxes = ops.xyxy2xywh(preds_in[0][0])
|
||||
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
|
||||
return ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=False,
|
||||
agnostic=self.args.single_cls,
|
||||
max_det=self.args.max_det,
|
||||
max_time_img=0.5)
|
7
ytracking/ultralytics/models/rtdetr/__init__.py
Normal file
7
ytracking/ultralytics/models/rtdetr/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import RTDETR
|
||||
from .predict import RTDETRPredictor
|
||||
from .val import RTDETRValidator
|
||||
|
||||
__all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR'
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
30
ytracking/ultralytics/models/rtdetr/model.py
Normal file
30
ytracking/ultralytics/models/rtdetr/model.py
Normal file
@ -0,0 +1,30 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
RT-DETR model interface
|
||||
"""
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.nn.tasks import RTDETRDetectionModel
|
||||
|
||||
from .predict import RTDETRPredictor
|
||||
from .train import RTDETRTrainer
|
||||
from .val import RTDETRValidator
|
||||
|
||||
|
||||
class RTDETR(Model):
|
||||
"""
|
||||
RTDETR model interface.
|
||||
"""
|
||||
|
||||
def __init__(self, model='rtdetr-l.pt') -> None:
|
||||
if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
|
||||
raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
|
||||
super().__init__(model=model, task='detect')
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
return {
|
||||
'detect': {
|
||||
'predictor': RTDETRPredictor,
|
||||
'validator': RTDETRValidator,
|
||||
'trainer': RTDETRTrainer,
|
||||
'model': RTDETRDetectionModel}}
|
62
ytracking/ultralytics/models/rtdetr/predict.py
Normal file
62
ytracking/ultralytics/models/rtdetr/predict.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import ops
|
||||
|
||||
|
||||
class RTDETRPredictor(BasePredictor):
|
||||
"""
|
||||
A class extending the BasePredictor class for prediction based on an RT-DETR detection model.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.rtdetr import RTDETRPredictor
|
||||
|
||||
args = dict(model='rtdetr-l.pt', source=ASSETS)
|
||||
predictor = RTDETRPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
"""
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocess predictions and returns a list of Results objects."""
|
||||
nd = preds[0].shape[-1]
|
||||
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
results = []
|
||||
for i, bbox in enumerate(bboxes): # (300, 4)
|
||||
bbox = ops.xywh2xyxy(bbox)
|
||||
score, cls = scores[i].max(-1, keepdim=True) # (300, 1)
|
||||
idx = score.squeeze(-1) > self.args.conf # (300, )
|
||||
if self.args.classes is not None:
|
||||
idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
|
||||
pred = torch.cat([bbox, score, cls], dim=-1)[idx] # filter
|
||||
orig_img = orig_imgs[i]
|
||||
oh, ow = orig_img.shape[:2]
|
||||
pred[..., [0, 2]] *= ow
|
||||
pred[..., [1, 3]] *= oh
|
||||
img_path = self.batch[0][i]
|
||||
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
|
||||
return results
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""Pre-transform input image before inference.
|
||||
|
||||
Args:
|
||||
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
|
||||
Notes: The size must be square(640) and scaleFilled.
|
||||
|
||||
Returns:
|
||||
(list): A list of transformed imgs.
|
||||
"""
|
||||
letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
|
||||
return [letterbox(image=x) for x in im]
|
72
ytracking/ultralytics/models/rtdetr/train.py
Normal file
72
ytracking/ultralytics/models/rtdetr/train.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||
from ultralytics.nn.tasks import RTDETRDetectionModel
|
||||
from ultralytics.utils import RANK, colorstr
|
||||
|
||||
from .val import RTDETRDataset, RTDETRValidator
|
||||
|
||||
|
||||
class RTDETRTrainer(DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
|
||||
|
||||
Notes:
|
||||
- F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
|
||||
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.models.rtdetr.train import RTDETRTrainer
|
||||
|
||||
args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
|
||||
trainer = RTDETRTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
"""
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Return a YOLO detection model."""
|
||||
model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
return model
|
||||
|
||||
def build_dataset(self, img_path, mode='val', batch=None):
|
||||
"""Build RTDETR Dataset
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||
"""
|
||||
return RTDETRDataset(
|
||||
img_path=img_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch,
|
||||
augment=mode == 'train', # no augmentation
|
||||
hyp=self.args,
|
||||
rect=False, # no rect
|
||||
cache=self.args.cache or None,
|
||||
prefix=colorstr(f'{mode}: '),
|
||||
data=self.data)
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns a DetectionValidator for RTDETR model validation."""
|
||||
self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
|
||||
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Preprocesses a batch of images by scaling and converting to float."""
|
||||
batch = super().preprocess_batch(batch)
|
||||
bs = len(batch['img'])
|
||||
batch_idx = batch['batch_idx']
|
||||
gt_bbox, gt_class = [], []
|
||||
for i in range(bs):
|
||||
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
|
||||
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
|
||||
return batch
|
141
ytracking/ultralytics/models/rtdetr/val.py
Normal file
141
ytracking/ultralytics/models/rtdetr/val.py
Normal file
@ -0,0 +1,141 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.data import YOLODataset
|
||||
from ultralytics.data.augment import Compose, Format, v8_transforms
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import colorstr, ops
|
||||
|
||||
__all__ = 'RTDETRValidator', # tuple or list
|
||||
|
||||
|
||||
# TODO: Temporarily RT-DETR does not need padding.
|
||||
class RTDETRDataset(YOLODataset):
|
||||
|
||||
def __init__(self, *args, data=None, **kwargs):
|
||||
super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
|
||||
|
||||
# NOTE: add stretch version load_image for rtdetr mosaic
|
||||
def load_image(self, i, rect_mode=False):
|
||||
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
|
||||
return super().load_image(i=i, rect_mode=rect_mode)
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
"""Temporary, only for evaluation."""
|
||||
if self.augment:
|
||||
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
|
||||
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
|
||||
transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
|
||||
else:
|
||||
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
|
||||
transforms = Compose([])
|
||||
transforms.append(
|
||||
Format(bbox_format='xywh',
|
||||
normalize=True,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
batch_idx=True,
|
||||
mask_ratio=hyp.mask_ratio,
|
||||
mask_overlap=hyp.overlap_mask))
|
||||
return transforms
|
||||
|
||||
|
||||
class RTDETRValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on an RT-DETR detection model.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.models.rtdetr import RTDETRValidator
|
||||
|
||||
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
|
||||
validator = RTDETRValidator(args=args)
|
||||
validator()
|
||||
```
|
||||
"""
|
||||
|
||||
def build_dataset(self, img_path, mode='val', batch=None):
|
||||
"""
|
||||
Build an RTDETR Dataset.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||
"""
|
||||
return RTDETRDataset(
|
||||
img_path=img_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch,
|
||||
augment=False, # no augmentation
|
||||
hyp=self.args,
|
||||
rect=False, # no rect
|
||||
cache=self.args.cache or None,
|
||||
prefix=colorstr(f'{mode}: '),
|
||||
data=self.data)
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
bs, _, nd = preds[0].shape
|
||||
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
||||
bboxes *= self.args.imgsz
|
||||
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
|
||||
for i, bbox in enumerate(bboxes): # (300, 4)
|
||||
bbox = ops.xywh2xyxy(bbox)
|
||||
score, cls = scores[i].max(-1) # (300, )
|
||||
# Do not need threshold for evaluation as only got 300 boxes here.
|
||||
# idx = score > self.args.conf
|
||||
pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
|
||||
# sort by confidence to correctly get internal metrics.
|
||||
pred = pred[score.argsort(descending=True)]
|
||||
outputs[i] = pred # [idx]
|
||||
|
||||
return outputs
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Metrics."""
|
||||
for si, pred in enumerate(preds):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
bbox = batch['bboxes'][idx]
|
||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||
shape = batch['ori_shape'][si]
|
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
self.seen += 1
|
||||
|
||||
if npr == 0:
|
||||
if nl:
|
||||
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
|
||||
continue
|
||||
|
||||
# Predictions
|
||||
if self.args.single_cls:
|
||||
pred[:, 5] = 0
|
||||
predn = pred.clone()
|
||||
predn[..., [0, 2]] *= shape[1] / self.args.imgsz # native-space pred
|
||||
predn[..., [1, 3]] *= shape[0] / self.args.imgsz # native-space pred
|
||||
|
||||
# Evaluate
|
||||
if nl:
|
||||
tbox = ops.xywh2xyxy(bbox) # target boxes
|
||||
tbox[..., [0, 2]] *= shape[1] # native-space pred
|
||||
tbox[..., [1, 3]] *= shape[0] # native-space pred
|
||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||
# NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type.
|
||||
correct_bboxes = self._process_batch(predn.float(), labelsn)
|
||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, labelsn)
|
||||
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
|
||||
|
||||
# Save
|
||||
if self.args.save_json:
|
||||
self.pred_to_json(predn, batch['im_file'][si])
|
||||
if self.args.save_txt:
|
||||
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
|
||||
self.save_one_txt(predn, self.args.save_conf, shape, file)
|
8
ytracking/ultralytics/models/sam/__init__.py
Normal file
8
ytracking/ultralytics/models/sam/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import SAM
|
||||
from .predict import Predictor
|
||||
|
||||
# from .build import build_sam
|
||||
|
||||
__all__ = 'SAM', 'Predictor' # tuple or list
|
Binary file not shown.
Binary file not shown.
BIN
ytracking/ultralytics/models/sam/__pycache__/amg.cpython-38.pyc
Normal file
BIN
ytracking/ultralytics/models/sam/__pycache__/amg.cpython-38.pyc
Normal file
Binary file not shown.
BIN
ytracking/ultralytics/models/sam/__pycache__/amg.cpython-39.pyc
Normal file
BIN
ytracking/ultralytics/models/sam/__pycache__/amg.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
180
ytracking/ultralytics/models/sam/amg.py
Normal file
180
ytracking/ultralytics/models/sam/amg.py
Normal file
@ -0,0 +1,180 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import math
|
||||
from itertools import product
|
||||
from typing import Any, Generator, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def is_box_near_crop_edge(boxes: torch.Tensor,
|
||||
crop_box: List[int],
|
||||
orig_box: List[int],
|
||||
atol: float = 20.0) -> torch.Tensor:
|
||||
"""Return a boolean tensor indicating if boxes are near the crop edge."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
|
||||
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
||||
"""Yield batches of data from the input arguments."""
|
||||
assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
|
||||
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
|
||||
|
||||
|
||||
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
|
||||
"""
|
||||
Computes the stability score for a batch of masks. The stability
|
||||
score is the IoU between the binary masks obtained by thresholding
|
||||
the predicted mask logits at high and low values.
|
||||
"""
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecessary cast to torch.int64
|
||||
intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1,
|
||||
dtype=torch.int32))
|
||||
unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
|
||||
return intersections / unions
|
||||
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
||||
return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
||||
|
||||
|
||||
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
|
||||
"""Generate point grids for all crop layers."""
|
||||
return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
|
||||
|
||||
|
||||
def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
|
||||
overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
|
||||
"""Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer."""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_h, im_w = im_size
|
||||
short_side = min(im_h, im_w)
|
||||
|
||||
# Original image
|
||||
crop_boxes.append([0, 0, im_w, im_h])
|
||||
layer_idxs.append(0)
|
||||
|
||||
def crop_len(orig_len, n_crops, overlap):
|
||||
"""Crops bounding boxes to the size of the input image."""
|
||||
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
|
||||
|
||||
for i_layer in range(n_layers):
|
||||
n_crops_per_side = 2 ** (i_layer + 1)
|
||||
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
||||
|
||||
crop_w = crop_len(im_w, n_crops_per_side, overlap)
|
||||
crop_h = crop_len(im_h, n_crops_per_side, overlap)
|
||||
|
||||
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
|
||||
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
|
||||
|
||||
# Crops in XYWH format
|
||||
for x0, y0 in product(crop_box_x0, crop_box_y0):
|
||||
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
|
||||
crop_boxes.append(box)
|
||||
layer_idxs.append(i_layer + 1)
|
||||
|
||||
return crop_boxes, layer_idxs
|
||||
|
||||
|
||||
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
"""Uncrop bounding boxes by adding the crop box offset."""
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
|
||||
# Check if boxes has a channel dimension
|
||||
if len(boxes.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return boxes + offset
|
||||
|
||||
|
||||
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
"""Uncrop points by adding the crop box offset."""
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0]], device=points.device)
|
||||
# Check if points has a channel dimension
|
||||
if len(points.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return points + offset
|
||||
|
||||
|
||||
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
|
||||
"""Uncrop masks by padding them to the original image size."""
|
||||
x0, y0, x1, y1 = crop_box
|
||||
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
||||
return masks
|
||||
# Coordinate transform masks
|
||||
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
|
||||
pad = (x0, pad_x - x0, y0, pad_y - y0)
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
|
||||
|
||||
|
||||
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
|
||||
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in {'holes', 'islands'}
|
||||
correct_holes = mode == 'holes'
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if not small_regions:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
# If every region is below threshold, keep largest
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
|
||||
|
||||
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
|
||||
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
||||
# Normalize shape to CxHxW
|
||||
shape = masks.shape
|
||||
h, w = shape[-2:]
|
||||
masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
|
||||
# Get top and bottom edges
|
||||
in_height, _ = torch.max(masks, dim=-1)
|
||||
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
|
||||
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
||||
in_height_coords = in_height_coords + h * (~in_height)
|
||||
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = torch.max(masks, dim=-2)
|
||||
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
|
||||
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
||||
in_width_coords = in_width_coords + w * (~in_width)
|
||||
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
||||
out = out * (~empty_filter).unsqueeze(-1)
|
||||
|
||||
# Return to original shape
|
||||
return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
|
158
ytracking/ultralytics/models/sam/build.py
Normal file
158
ytracking/ultralytics/models/sam/build.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
from .modules.decoders import MaskDecoder
|
||||
from .modules.encoders import ImageEncoderViT, PromptEncoder
|
||||
from .modules.sam import Sam
|
||||
from .modules.tiny_encoder import TinyViT
|
||||
from .modules.transformer import TwoWayTransformer
|
||||
|
||||
|
||||
def build_sam_vit_h(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) h-size model."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_depth=32,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[7, 15, 23, 31],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_l(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) l-size model."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1024,
|
||||
encoder_depth=24,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[5, 11, 17, 23],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_b(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) b-size model."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=12,
|
||||
encoder_global_attn_indexes=[2, 5, 8, 11],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_mobile_sam(checkpoint=None):
|
||||
"""Build and return Mobile Segment Anything Model (Mobile-SAM)."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=[64, 128, 160, 320],
|
||||
encoder_depth=[2, 2, 6, 2],
|
||||
encoder_num_heads=[2, 4, 5, 10],
|
||||
encoder_global_attn_indexes=None,
|
||||
mobile_sam=True,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def _build_sam(encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
encoder_global_attn_indexes,
|
||||
checkpoint=None,
|
||||
mobile_sam=False):
|
||||
"""Builds the selected SAM model architecture."""
|
||||
prompt_embed_dim = 256
|
||||
image_size = 1024
|
||||
vit_patch_size = 16
|
||||
image_embedding_size = image_size // vit_patch_size
|
||||
image_encoder = (TinyViT(
|
||||
img_size=1024,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dims=encoder_embed_dim,
|
||||
depths=encoder_depth,
|
||||
num_heads=encoder_num_heads,
|
||||
window_sizes=[7, 7, 14, 7],
|
||||
mlp_ratio=4.0,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
use_checkpoint=False,
|
||||
mbconv_expand_ratio=4.0,
|
||||
local_conv_size=3,
|
||||
layer_lr_decay=0.8,
|
||||
) if mobile_sam else ImageEncoderViT(
|
||||
depth=encoder_depth,
|
||||
embed_dim=encoder_embed_dim,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=encoder_num_heads,
|
||||
patch_size=vit_patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=encoder_global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=prompt_embed_dim,
|
||||
))
|
||||
sam = Sam(
|
||||
image_encoder=image_encoder,
|
||||
prompt_encoder=PromptEncoder(
|
||||
embed_dim=prompt_embed_dim,
|
||||
image_embedding_size=(image_embedding_size, image_embedding_size),
|
||||
input_image_size=(image_size, image_size),
|
||||
mask_in_chans=16,
|
||||
),
|
||||
mask_decoder=MaskDecoder(
|
||||
num_multimask_outputs=3,
|
||||
transformer=TwoWayTransformer(
|
||||
depth=2,
|
||||
embedding_dim=prompt_embed_dim,
|
||||
mlp_dim=2048,
|
||||
num_heads=8,
|
||||
),
|
||||
transformer_dim=prompt_embed_dim,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
),
|
||||
pixel_mean=[123.675, 116.28, 103.53],
|
||||
pixel_std=[58.395, 57.12, 57.375],
|
||||
)
|
||||
if checkpoint is not None:
|
||||
checkpoint = attempt_download_asset(checkpoint)
|
||||
with open(checkpoint, 'rb') as f:
|
||||
state_dict = torch.load(f)
|
||||
sam.load_state_dict(state_dict)
|
||||
sam.eval()
|
||||
# sam.load_state_dict(torch.load(checkpoint), strict=True)
|
||||
# sam.eval()
|
||||
return sam
|
||||
|
||||
|
||||
sam_model_map = {
|
||||
'sam_h.pt': build_sam_vit_h,
|
||||
'sam_l.pt': build_sam_vit_l,
|
||||
'sam_b.pt': build_sam_vit_b,
|
||||
'mobile_sam.pt': build_mobile_sam, }
|
||||
|
||||
|
||||
def build_sam(ckpt='sam_b.pt'):
|
||||
"""Build a SAM model specified by ckpt."""
|
||||
model_builder = None
|
||||
for k in sam_model_map.keys():
|
||||
if ckpt.endswith(k):
|
||||
model_builder = sam_model_map.get(k)
|
||||
|
||||
if not model_builder:
|
||||
raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}')
|
||||
|
||||
return model_builder(ckpt)
|
51
ytracking/ultralytics/models/sam/model.py
Normal file
51
ytracking/ultralytics/models/sam/model.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
SAM model interface
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.utils.torch_utils import model_info
|
||||
|
||||
from .build import build_sam
|
||||
from .predict import Predictor
|
||||
|
||||
|
||||
class SAM(Model):
|
||||
"""
|
||||
SAM model interface.
|
||||
"""
|
||||
|
||||
def __init__(self, model='sam_b.pt') -> None:
|
||||
if model and Path(model).suffix not in ('.pt', '.pth'):
|
||||
raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
|
||||
super().__init__(model=model, task='segment')
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
self.model = build_sam(weights)
|
||||
|
||||
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""Predicts and returns segmentation masks for given image or video source."""
|
||||
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
|
||||
kwargs.update(overrides)
|
||||
prompts = dict(bboxes=bboxes, points=points, labels=labels)
|
||||
return super().predict(source, stream, prompts=prompts, **kwargs)
|
||||
|
||||
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose)
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
return {'segment': {'predictor': Predictor}}
|
1
ytracking/ultralytics/models/sam/modules/__init__.py
Normal file
1
ytracking/ultralytics/models/sam/modules/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
159
ytracking/ultralytics/models/sam/modules/decoders.py
Normal file
159
ytracking/ultralytics/models/sam/modules/decoders.py
Normal file
@ -0,0 +1,159 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ultralytics.nn.modules import LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
) -> None:
|
||||
"""
|
||||
Predicts masks given an image and prompt embeddings, using a transformer architecture.
|
||||
|
||||
Args:
|
||||
transformer_dim (int): the channel dimension of the transformer module
|
||||
transformer (nn.Module): the transformer used to predict masks
|
||||
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
|
||||
activation (nn.Module): the type of activation to use when upscaling masks
|
||||
iou_head_depth (int): the depth of the MLP used to predict mask quality
|
||||
iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_hypernetworks_mlps = nn.ModuleList([
|
||||
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
|
||||
|
||||
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
||||
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
||||
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
||||
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
||||
multimask_output (bool): Whether to return multiple masks or a single mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: batched predicted masks
|
||||
torch.Tensor: batched predictions of mask quality
|
||||
"""
|
||||
masks, iou_pred = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
|
||||
masks = masks[:, mask_slice, :, :]
|
||||
iou_pred = iou_pred[:, mask_slice]
|
||||
|
||||
# Prepare output
|
||||
return masks, iou_pred
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts masks. See 'forward' for more details."""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
src = src + dense_prompt_embeddings
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, 0, :]
|
||||
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
hyper_in_list: List[torch.Tensor] = [
|
||||
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
|
||||
return masks, iou_pred
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""
|
||||
Lightly adapted from
|
||||
https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
sigmoid_output: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
self.sigmoid_output = sigmoid_output
|
||||
|
||||
def forward(self, x):
|
||||
"""Executes feedforward within the neural network module and applies activation."""
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
if self.sigmoid_output:
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
568
ytracking/ultralytics/models/sam/modules/encoders.py
Normal file
568
ytracking/ultralytics/models/sam/modules/encoders.py
Normal file
@ -0,0 +1,568 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from typing import Any, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.nn.modules import LayerNorm2d, MLPBlock
|
||||
|
||||
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
class ImageEncoderViT(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: Tuple[int, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
img_size (int): Input image size.
|
||||
patch_size (int): Patch size.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
depth (int): Depth of ViT.
|
||||
num_heads (int): Number of attention heads in each ViT block.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
act_layer (nn.Module): Activation layer.
|
||||
use_abs_pos (bool): If True, use absolute positional embeddings.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
window_size (int): Window size for window attention blocks.
|
||||
global_attn_indexes (list): Indexes for blocks using global attention.
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
kernel_size=(patch_size, patch_size),
|
||||
stride=(patch_size, patch_size),
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.pos_embed: Optional[nn.Parameter] = None
|
||||
if use_abs_pos:
|
||||
# Initialize absolute positional embedding with pretrain image size.
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
block = Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
window_size=window_size if i not in global_attn_indexes else 0,
|
||||
input_size=(img_size // patch_size, img_size // patch_size),
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
self.neck = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dim,
|
||||
out_chans,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
nn.Conv2d(
|
||||
out_chans,
|
||||
out_chans,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
return self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
image_embedding_size: Tuple[int, int],
|
||||
input_image_size: Tuple[int, int],
|
||||
mask_in_chans: int,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
"""
|
||||
Encodes prompts for input to SAM's mask decoder.
|
||||
|
||||
Args:
|
||||
embed_dim (int): The prompts' embedding dimension
|
||||
image_embedding_size (tuple(int, int)): The spatial size of the
|
||||
image embedding, as (H, W).
|
||||
input_image_size (int): The padded size of the image as input
|
||||
to the image encoder, as (H, W).
|
||||
mask_in_chans (int): The number of hidden channels used for
|
||||
encoding input masks.
|
||||
activation (nn.Module): The activation to use when encoding
|
||||
input masks.
|
||||
"""
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.input_image_size = input_image_size
|
||||
self.image_embedding_size = image_embedding_size
|
||||
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
||||
|
||||
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
||||
point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
|
||||
self.point_embeddings = nn.ModuleList(point_embeddings)
|
||||
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
|
||||
self.mask_downscaling = nn.Sequential(
|
||||
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans // 4),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
||||
)
|
||||
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
pad: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Embeds point prompts."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
||||
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
||||
points = torch.cat([points, padding_point], dim=1)
|
||||
labels = torch.cat([labels, padding_label], dim=1)
|
||||
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
||||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds box prompts."""
|
||||
boxes = boxes + 0.5 # Shift to center of pixel
|
||||
coords = boxes.reshape(-1, 2, 2)
|
||||
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
||||
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
||||
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
||||
return corner_embedding
|
||||
|
||||
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds mask inputs."""
|
||||
return self.mask_downscaling(masks)
|
||||
|
||||
def _get_batch_size(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""
|
||||
Gets the batch size of the output given the batch size of the input prompts.
|
||||
"""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
return boxes.shape[0]
|
||||
elif masks is not None:
|
||||
return masks.shape[0]
|
||||
else:
|
||||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Embeds different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Args:
|
||||
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
|
||||
boxes (torch.Tensor, None): boxes to embed
|
||||
masks (torch.Tensor, None): masks to embed
|
||||
|
||||
Returns:
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
|
||||
by the number of input points and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
"""
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
||||
if points is not None:
|
||||
coords, labels = points
|
||||
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
||||
if boxes is not None:
|
||||
box_embeddings = self._embed_boxes(boxes)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
||||
|
||||
if masks is not None:
|
||||
dense_embeddings = self._embed_masks(masks)
|
||||
else:
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
|
||||
1).expand(bs, -1, self.image_embedding_size[0],
|
||||
self.image_embedding_size[1])
|
||||
|
||||
return sparse_embeddings, dense_embeddings
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
Positional encoding using random spatial frequencies.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
|
||||
|
||||
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = False
|
||||
|
||||
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
"""Positionally encode points that are normalized to [0,1]."""
|
||||
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
||||
coords = 2 * coords - 1
|
||||
coords = coords @ self.positional_encoding_gaussian_matrix
|
||||
coords = 2 * np.pi * coords
|
||||
# outputs d_1 x ... x d_n x C shape
|
||||
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
||||
|
||||
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Generate positional encoding for a grid of the specified size."""
|
||||
h, w = size
|
||||
device: Any = self.positional_encoding_gaussian_matrix.device
|
||||
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
||||
y_embed = grid.cumsum(dim=0) - 0.5
|
||||
x_embed = grid.cumsum(dim=1) - 0.5
|
||||
y_embed = y_embed / h
|
||||
x_embed = x_embed / w
|
||||
|
||||
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
||||
return pe.permute(2, 0, 1) # C x H x W
|
||||
|
||||
def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Positionally encode points that are not normalized to [0,1]."""
|
||||
coords = coords_input.clone()
|
||||
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
||||
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
||||
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads in each ViT block.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
act_layer (nn.Module): Activation layer.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
window_size (int): Window size for window attention blocks. If it equals 0, then
|
||||
use global attention.
|
||||
input_size (tuple(int, int), None): Input resolution for calculating the relative
|
||||
positional parameter size.
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
input_size=input_size if window_size == 0 else (window_size, window_size),
|
||||
)
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
|
||||
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
if self.window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, self.window_size)
|
||||
|
||||
x = self.attn(x)
|
||||
# Reverse window partition
|
||||
if self.window_size > 0:
|
||||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
|
||||
x = shortcut + x
|
||||
return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head Attention block with relative position embeddings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
input_size (tuple(int, int), None): Input resolution for calculating the relative
|
||||
positional parameter size.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.use_rel_pos = use_rel_pos
|
||||
if self.use_rel_pos:
|
||||
assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
|
||||
# initialize relative positional embeddings
|
||||
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
||||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
# q, k, v with shape (B * nHead, H * W, C)
|
||||
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
||||
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
"""
|
||||
Partition into non-overlapping windows with padding if needed.
|
||||
Args:
|
||||
x (tensor): input tokens with [B, H, W, C].
|
||||
window_size (int): window size.
|
||||
|
||||
Returns:
|
||||
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
||||
(Hp, Wp): padded height and width before partition
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
|
||||
hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Window unpartition into original sequences and removing padding.
|
||||
Args:
|
||||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
||||
window_size (int): window size.
|
||||
pad_hw (Tuple): padded height and width (Hp, Wp).
|
||||
hw (Tuple): original height and width (H, W) before padding.
|
||||
|
||||
Returns:
|
||||
x: unpartitioned sequences with [B, H, W, C].
|
||||
"""
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get relative positional embeddings according to the relative positions of
|
||||
query and key sizes.
|
||||
Args:
|
||||
q_size (int): size of query q.
|
||||
k_size (int): size of key k.
|
||||
rel_pos (Tensor): relative position embeddings (L, C).
|
||||
|
||||
Returns:
|
||||
Extracted positional embeddings according to relative positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode='linear',
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
|
||||
def add_decomposed_rel_pos(
|
||||
attn: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
||||
Args:
|
||||
attn (Tensor): attention map.
|
||||
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
||||
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
||||
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
||||
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
||||
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
attn (Tensor): attention map with added relative positional embeddings.
|
||||
"""
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
|
||||
rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
|
||||
|
||||
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
|
||||
B, q_h * q_w, k_h * k_w)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
kernel_size (Tuple): kernel size of the projection layer.
|
||||
stride (Tuple): stride of the projection layer.
|
||||
padding (Tuple): padding size of the projection layer.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|
49
ytracking/ultralytics/models/sam/modules/sam.py
Normal file
49
ytracking/ultralytics/models/sam/modules/sam.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .decoders import MaskDecoder
|
||||
from .encoders import ImageEncoderViT, PromptEncoder
|
||||
|
||||
|
||||
class Sam(nn.Module):
|
||||
mask_threshold: float = 0.0
|
||||
image_format: str = 'RGB'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder: ImageEncoderViT,
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: List[float] = (123.675, 116.28, 103.53),
|
||||
pixel_std: List[float] = (58.395, 57.12, 57.375)
|
||||
) -> None:
|
||||
"""
|
||||
SAM predicts object masks from an image and input prompts.
|
||||
|
||||
Note:
|
||||
All forward() operations moved to SAMPredictor.
|
||||
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
|
||||
efficient mask prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
||||
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
||||
"""
|
||||
super().__init__()
|
||||
self.image_encoder = image_encoder
|
||||
self.prompt_encoder = prompt_encoder
|
||||
self.mask_decoder = mask_decoder
|
||||
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
582
ytracking/ultralytics/models/sam/modules/tiny_encoder.py
Normal file
582
ytracking/ultralytics/models/sam/modules/tiny_encoder.py
Normal file
@ -0,0 +1,582 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
# --------------------------------------------------------
|
||||
# TinyViT Model Architecture
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Adapted from LeViT and Swin Transformer
|
||||
# LeViT: (https://github.com/facebookresearch/levit)
|
||||
# Swin: (https://github.com/microsoft/swin-transformer)
|
||||
# Build the TinyViT Model
|
||||
# --------------------------------------------------------
|
||||
|
||||
import itertools
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from ultralytics.utils.instance import to_2tuple
|
||||
|
||||
|
||||
class Conv2d_BN(torch.nn.Sequential):
|
||||
|
||||
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||
super().__init__()
|
||||
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
bn = torch.nn.BatchNorm2d(b)
|
||||
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
||||
torch.nn.init.constant_(bn.bias, 0)
|
||||
self.add_module('bn', bn)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
|
||||
def __init__(self, in_chans, embed_dim, resolution, activation):
|
||||
super().__init__()
|
||||
img_size: Tuple[int, int] = to_2tuple(resolution)
|
||||
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
||||
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
n = embed_dim
|
||||
self.seq = nn.Sequential(
|
||||
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
|
||||
activation(),
|
||||
Conv2d_BN(n // 2, n, 3, 2, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.seq(x)
|
||||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
|
||||
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.hidden_chans = int(in_chans * expand_ratio)
|
||||
self.out_chans = out_chans
|
||||
|
||||
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
|
||||
self.act1 = activation()
|
||||
|
||||
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
|
||||
self.act2 = activation()
|
||||
|
||||
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
|
||||
self.act3 = activation()
|
||||
|
||||
# NOTE: `DropPath` is needed only for training.
|
||||
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.drop_path = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = x
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.act2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.drop_path(x)
|
||||
x += shortcut
|
||||
return self.act3(x)
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
|
||||
def __init__(self, input_resolution, dim, out_dim, activation):
|
||||
super().__init__()
|
||||
|
||||
self.input_resolution = input_resolution
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.act = activation()
|
||||
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
||||
stride_c = 1 if out_dim in [320, 448, 576] else 2
|
||||
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
||||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
if x.ndim == 3:
|
||||
H, W = self.input_resolution
|
||||
B = len(x)
|
||||
# (B, C, H, W)
|
||||
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.act(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.act(x)
|
||||
x = self.conv3(x)
|
||||
return x.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
input_resolution,
|
||||
depth,
|
||||
activation,
|
||||
drop_path=0.,
|
||||
downsample=None,
|
||||
use_checkpoint=False,
|
||||
out_dim=None,
|
||||
conv_expand_ratio=4.,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
MBConv(
|
||||
dim,
|
||||
dim,
|
||||
conv_expand_ratio,
|
||||
activation,
|
||||
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
) for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
self.downsample = None if downsample is None else downsample(
|
||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
|
||||
def forward(self, x):
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.norm = nn.LayerNorm(in_features)
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.act = act_layer()
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
return self.drop(x)
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=(14, 14),
|
||||
):
|
||||
super().__init__()
|
||||
# (h, w)
|
||||
assert isinstance(resolution, tuple) and len(resolution) == 2
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim ** -0.5
|
||||
self.key_dim = key_dim
|
||||
self.nh_kd = nh_kd = key_dim * num_heads
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
self.dh = int(attn_ratio * key_dim) * num_heads
|
||||
self.attn_ratio = attn_ratio
|
||||
h = self.dh + nh_kd * 2
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.qkv = nn.Linear(dim, h)
|
||||
self.proj = nn.Linear(self.dh, dim)
|
||||
|
||||
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
|
||||
N = len(points)
|
||||
attention_offsets = {}
|
||||
idxs = []
|
||||
for p1 in points:
|
||||
for p2 in points:
|
||||
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
||||
if offset not in attention_offsets:
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and hasattr(self, 'ab'):
|
||||
del self.ab
|
||||
else:
|
||||
self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
|
||||
def forward(self, x): # x (B,N,C)
|
||||
B, N, _ = x.shape
|
||||
|
||||
# Normalization
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.qkv(x)
|
||||
# (B, N, num_heads, d)
|
||||
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
|
||||
# (B, num_heads, N, d)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
self.ab = self.ab.to(self.attention_biases.device)
|
||||
|
||||
attn = ((q @ k.transpose(-2, -1)) * self.scale +
|
||||
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class TinyViTBlock(nn.Module):
|
||||
"""
|
||||
TinyViT Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int, int]): Input resolution.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
local_conv_size=3,
|
||||
activation=nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.num_heads = num_heads
|
||||
assert window_size > 0, 'window_size must be greater than 0'
|
||||
self.window_size = window_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
# NOTE: `DropPath` is needed only for training.
|
||||
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.drop_path = nn.Identity()
|
||||
|
||||
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
|
||||
head_dim = dim // num_heads
|
||||
|
||||
window_resolution = (window_size, window_size)
|
||||
self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
|
||||
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
mlp_activation = activation
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
|
||||
|
||||
pad = local_conv_size // 2
|
||||
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
|
||||
|
||||
def forward(self, x):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
res_x = x
|
||||
if H == self.window_size and W == self.window_size:
|
||||
x = self.attn(x)
|
||||
else:
|
||||
x = x.view(B, H, W, C)
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
padding = pad_b > 0 or pad_r > 0
|
||||
|
||||
if padding:
|
||||
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
|
||||
|
||||
pH, pW = H + pad_b, W + pad_r
|
||||
nH = pH // self.window_size
|
||||
nW = pW // self.window_size
|
||||
# window partition
|
||||
x = x.view(B, nH, self.window_size, nW, self.window_size,
|
||||
C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
|
||||
x = self.attn(x)
|
||||
# window reverse
|
||||
x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
|
||||
|
||||
if padding:
|
||||
x = x[:, :H, :W].contiguous()
|
||||
|
||||
x = x.view(B, L, C)
|
||||
|
||||
x = res_x + self.drop_path(x)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, C, H, W)
|
||||
x = self.local_conv(x)
|
||||
x = x.view(B, C, L).transpose(1, 2)
|
||||
|
||||
return x + self.drop_path(self.mlp(x))
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
|
||||
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
"""
|
||||
A basic TinyViT layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
out_dim (int | optional): the output dimension of the layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
input_resolution,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size,
|
||||
mlp_ratio=4.,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
downsample=None,
|
||||
use_checkpoint=False,
|
||||
local_conv_size=3,
|
||||
activation=nn.GELU,
|
||||
out_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
TinyViTBlock(
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
local_conv_size=local_conv_size,
|
||||
activation=activation,
|
||||
) for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
self.downsample = None if downsample is None else downsample(
|
||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
|
||||
def forward(self, x):
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
return self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
|
||||
|
||||
class TinyViT(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dims=[96, 192, 384, 768],
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_sizes=[7, 7, 14, 7],
|
||||
mlp_ratio=4.,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
use_checkpoint=False,
|
||||
mbconv_expand_ratio=4.0,
|
||||
local_conv_size=3,
|
||||
layer_lr_decay=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.num_classes = num_classes
|
||||
self.depths = depths
|
||||
self.num_layers = len(depths)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
activation = nn.GELU
|
||||
|
||||
self.patch_embed = PatchEmbed(in_chans=in_chans,
|
||||
embed_dim=embed_dims[0],
|
||||
resolution=img_size,
|
||||
activation=activation)
|
||||
|
||||
patches_resolution = self.patch_embed.patches_resolution
|
||||
self.patches_resolution = patches_resolution
|
||||
|
||||
# stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
kwargs = dict(
|
||||
dim=embed_dims[i_layer],
|
||||
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
|
||||
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
||||
# patches_resolution[1] // (2 ** i_layer)),
|
||||
depth=depths[i_layer],
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
use_checkpoint=use_checkpoint,
|
||||
out_dim=embed_dims[min(i_layer + 1,
|
||||
len(embed_dims) - 1)],
|
||||
activation=activation,
|
||||
)
|
||||
if i_layer == 0:
|
||||
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
|
||||
else:
|
||||
layer = BasicLayer(num_heads=num_heads[i_layer],
|
||||
window_size=window_sizes[i_layer],
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
drop=drop_rate,
|
||||
local_conv_size=local_conv_size,
|
||||
**kwargs)
|
||||
self.layers.append(layer)
|
||||
|
||||
# Classifier head
|
||||
self.norm_head = nn.LayerNorm(embed_dims[-1])
|
||||
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
|
||||
|
||||
# init weights
|
||||
self.apply(self._init_weights)
|
||||
self.set_layer_lr_decay(layer_lr_decay)
|
||||
self.neck = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dims[-1],
|
||||
256,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(256),
|
||||
nn.Conv2d(
|
||||
256,
|
||||
256,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(256),
|
||||
)
|
||||
|
||||
def set_layer_lr_decay(self, layer_lr_decay):
|
||||
decay_rate = layer_lr_decay
|
||||
|
||||
# layers -> blocks (depth)
|
||||
depth = sum(self.depths)
|
||||
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
||||
|
||||
def _set_lr_scale(m, scale):
|
||||
for p in m.parameters():
|
||||
p.lr_scale = scale
|
||||
|
||||
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
|
||||
i = 0
|
||||
for layer in self.layers:
|
||||
for block in layer.blocks:
|
||||
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
|
||||
i += 1
|
||||
if layer.downsample is not None:
|
||||
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
|
||||
assert i == depth
|
||||
for m in [self.norm_head, self.head]:
|
||||
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
|
||||
|
||||
for k, p in self.named_parameters():
|
||||
p.param_name = k
|
||||
|
||||
def _check_lr_scale(m):
|
||||
for p in m.parameters():
|
||||
assert hasattr(p, 'lr_scale'), p.param_name
|
||||
|
||||
self.apply(_check_lr_scale)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
# NOTE: This initialization is needed only for training.
|
||||
# trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay_keywords(self):
|
||||
return {'attention_biases'}
|
||||
|
||||
def forward_features(self, x):
|
||||
# x: (N, C, H, W)
|
||||
x = self.patch_embed(x)
|
||||
|
||||
x = self.layers[0](x)
|
||||
start_i = 1
|
||||
|
||||
for i in range(start_i, len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
x = layer(x)
|
||||
B, _, C = x.size()
|
||||
x = x.view(B, 64, 64, C)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
return self.neck(x)
|
||||
|
||||
def forward(self, x):
|
||||
return self.forward_features(x)
|
230
ytracking/ultralytics/models/sam/modules/transformer.py
Normal file
230
ytracking/ultralytics/models/sam/modules/transformer.py
Normal file
@ -0,0 +1,230 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import math
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ultralytics.nn.modules import MLPBlock
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth: int,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer decoder that attends to an input image using
|
||||
queries whose positional embedding is supplied.
|
||||
|
||||
Args:
|
||||
depth (int): number of layers in the transformer
|
||||
embedding_dim (int): the channel dimension for the input embeddings
|
||||
num_heads (int): the number of heads for multihead attention. Must
|
||||
divide embedding_dim
|
||||
mlp_dim (int): the channel dimension internal to the MLP block
|
||||
activation (nn.Module): the activation to use in the MLP block
|
||||
"""
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.mlp_dim = mlp_dim
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
for i in range(depth):
|
||||
self.layers.append(
|
||||
TwoWayAttentionBlock(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_dim=mlp_dim,
|
||||
activation=activation,
|
||||
attention_downsample_rate=attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0),
|
||||
))
|
||||
|
||||
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embedding: Tensor,
|
||||
image_pe: Tensor,
|
||||
point_embedding: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): the embedding to add to the query points.
|
||||
Must have shape B x N_points x embedding_dim for any N_points.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): the processed point_embedding
|
||||
(torch.Tensor): the processed image_embedding
|
||||
"""
|
||||
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
||||
bs, c, h, w = image_embedding.shape
|
||||
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
||||
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
||||
|
||||
# Prepare queries
|
||||
queries = point_embedding
|
||||
keys = image_embedding
|
||||
|
||||
# Apply transformer blocks and final layernorm
|
||||
for layer in self.layers:
|
||||
queries, keys = layer(
|
||||
queries=queries,
|
||||
keys=keys,
|
||||
query_pe=point_embedding,
|
||||
key_pe=image_pe,
|
||||
)
|
||||
|
||||
# Apply the final attention layer from the points to the image
|
||||
q = queries + point_embedding
|
||||
k = keys + image_pe
|
||||
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm_final_attn(queries)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int = 2048,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
|
||||
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
||||
inputs.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): the channel dimension of the embeddings
|
||||
num_heads (int): the number of heads in the attention layers
|
||||
mlp_dim (int): the hidden dimension of the mlp block
|
||||
activation (nn.Module): the activation of the mlp block
|
||||
skip_first_layer_pe (bool): skip the PE on the first layer
|
||||
"""
|
||||
super().__init__()
|
||||
self.self_attn = Attention(embedding_dim, num_heads)
|
||||
self.norm1 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||
self.norm2 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
|
||||
self.norm3 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.norm4 = nn.LayerNorm(embedding_dim)
|
||||
self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||
|
||||
self.skip_first_layer_pe = skip_first_layer_pe
|
||||
|
||||
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
|
||||
|
||||
# Self attention block
|
||||
if self.skip_first_layer_pe:
|
||||
queries = self.self_attn(q=queries, k=queries, v=queries)
|
||||
else:
|
||||
q = queries + query_pe
|
||||
attn_out = self.self_attn(q=q, k=q, v=queries)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm1(queries)
|
||||
|
||||
# Cross attention block, tokens attending to image embedding
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm2(queries)
|
||||
|
||||
# MLP block
|
||||
mlp_out = self.mlp(queries)
|
||||
queries = queries + mlp_out
|
||||
queries = self.norm3(queries)
|
||||
|
||||
# Cross attention block, image embedding attending to tokens
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
||||
keys = keys + attn_out
|
||||
keys = self.norm4(keys)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
downsample_rate: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
self.num_heads = num_heads
|
||||
assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
|
||||
|
||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
||||
|
||||
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
||||
"""Separate the input tensor into the specified number of attention heads."""
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
def _recombine_heads(self, x: Tensor) -> Tensor:
|
||||
"""Recombine the separated attention heads into a single tensor."""
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
||||
"""Compute the attention output given the input query, key, and value tensors."""
|
||||
|
||||
# Input projections
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
|
||||
# Separate into heads
|
||||
q = self._separate_heads(q, self.num_heads)
|
||||
k = self._separate_heads(k, self.num_heads)
|
||||
v = self._separate_heads(v, self.num_heads)
|
||||
|
||||
# Attention
|
||||
_, _, _, c_per_head = q.shape
|
||||
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
||||
attn = attn / math.sqrt(c_per_head)
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
# Get output
|
||||
out = attn @ v
|
||||
out = self._recombine_heads(out)
|
||||
return self.out_proj(out)
|
408
ytracking/ultralytics/models/sam/predict.py
Normal file
408
ytracking/ultralytics/models/sam/predict.py
Normal file
@ -0,0 +1,408 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.utils.torch_utils import select_device
|
||||
|
||||
from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
|
||||
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
|
||||
from .build import build_sam
|
||||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
# SAM needs retina_masks=True, or the results would be a mess.
|
||||
self.args.retina_masks = True
|
||||
# Args for set_image
|
||||
self.im = None
|
||||
self.features = None
|
||||
# Args for set_prompts
|
||||
self.prompts = {}
|
||||
# Args for segment everything
|
||||
self.segment_all = False
|
||||
|
||||
def preprocess(self, im):
|
||||
"""Prepares input image before inference.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
|
||||
"""
|
||||
if self.im is not None:
|
||||
return self.im
|
||||
not_tensor = not isinstance(im, torch.Tensor)
|
||||
if not_tensor:
|
||||
im = np.stack(self.pre_transform(im))
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
im = torch.from_numpy(im)
|
||||
|
||||
im = im.to(self.device)
|
||||
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
|
||||
if not_tensor:
|
||||
im = (im - self.mean) / self.std
|
||||
return im
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
Pre-transform input image before inference.
|
||||
|
||||
Args:
|
||||
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
|
||||
Returns:
|
||||
(list): A list of transformed images.
|
||||
"""
|
||||
assert len(im) == 1, 'SAM model does not currently support batched inference'
|
||||
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
|
||||
return [letterbox(image=x) for x in im]
|
||||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
|
||||
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
|
||||
1 indicates a foreground point and 0 indicates a background point.
|
||||
masks (np.ndarray, None): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form (N, H, W), where
|
||||
for SAM, H=W=256.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in CxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(np.ndarray): An array of length C containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(np.ndarray): An array of shape CxHxW, where C is the number
|
||||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
"""
|
||||
# Get prompts from self.prompts first
|
||||
bboxes = self.prompts.pop('bboxes', bboxes)
|
||||
points = self.prompts.pop('points', points)
|
||||
masks = self.prompts.pop('masks', masks)
|
||||
if all(i is None for i in [bboxes, points, masks]):
|
||||
return self.generate(im, *args, **kwargs)
|
||||
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
|
||||
|
||||
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
|
||||
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
|
||||
1 indicates a foreground point and 0 indicates a background point.
|
||||
masks (np.ndarray, None): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form (N, H, W), where
|
||||
for SAM, H=W=256.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in CxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(np.ndarray): An array of length C containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(np.ndarray): An array of shape CxHxW, where C is the number
|
||||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
"""
|
||||
features = self.model.image_encoder(im) if self.features is None else self.features
|
||||
|
||||
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
|
||||
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
|
||||
# Transform input prompts
|
||||
if points is not None:
|
||||
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
|
||||
points = points[None] if points.ndim == 1 else points
|
||||
# Assuming labels are all positive if users don't pass labels.
|
||||
if labels is None:
|
||||
labels = np.ones(points.shape[0])
|
||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||
points *= r
|
||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||
points, labels = points[:, None, :], labels[:, None]
|
||||
if bboxes is not None:
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
|
||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
||||
bboxes *= r
|
||||
if masks is not None:
|
||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
|
||||
points = (points, labels) if points is not None else None
|
||||
# Embed prompts
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
||||
points=points,
|
||||
boxes=bboxes,
|
||||
masks=masks,
|
||||
)
|
||||
|
||||
# Predict masks
|
||||
pred_masks, pred_scores = self.model.mask_decoder(
|
||||
image_embeddings=features,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
|
||||
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
|
||||
# `d` could be 1 or 3 depends on `multimask_output`.
|
||||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
||||
|
||||
def generate(self,
|
||||
im,
|
||||
crop_n_layers=0,
|
||||
crop_overlap_ratio=512 / 1500,
|
||||
crop_downscale_factor=1,
|
||||
point_grids=None,
|
||||
points_stride=32,
|
||||
points_batch_size=64,
|
||||
conf_thres=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
stability_score_offset=0.95,
|
||||
crop_nms_thresh=0.7):
|
||||
"""Segment the whole image.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
crop_n_layers (int): If >0, mask prediction will be run again on
|
||||
crops of the image. Sets the number of layers to run, where each
|
||||
layer has 2**i_layer number of image crops.
|
||||
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
||||
In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
crop_downscale_factor (int): The number of points-per-side
|
||||
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
point_grids (list(np.ndarray), None): A list over explicit grids
|
||||
of points used for sampling, normalized to [0,1]. The nth grid in the
|
||||
list is used in the nth crop layer. Exclusive with points_per_side.
|
||||
points_stride (int, None): The number of points to be sampled
|
||||
along one side of the image. The total number of points is
|
||||
points_per_side**2. If None, 'point_grids' must provide explicit
|
||||
point sampling.
|
||||
points_batch_size (int): Sets the number of points run simultaneously
|
||||
by the model. Higher numbers may be faster but use more GPU memory.
|
||||
conf_thres (float): A filtering threshold in [0,1], using the
|
||||
model's predicted mask quality.
|
||||
stability_score_thresh (float): A filtering threshold in [0,1], using
|
||||
the stability of the mask under changes to the cutoff used to binarize
|
||||
the model's mask predictions.
|
||||
stability_score_offset (float): The amount to shift the cutoff when
|
||||
calculated the stability score.
|
||||
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks between different crops.
|
||||
"""
|
||||
self.segment_all = True
|
||||
ih, iw = im.shape[2:]
|
||||
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
|
||||
if point_grids is None:
|
||||
point_grids = build_all_layer_point_grids(
|
||||
points_stride,
|
||||
crop_n_layers,
|
||||
crop_downscale_factor,
|
||||
)
|
||||
pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
|
||||
for crop_region, layer_idx in zip(crop_regions, layer_idxs):
|
||||
x1, y1, x2, y2 = crop_region
|
||||
w, h = x2 - x1, y2 - y1
|
||||
area = torch.tensor(w * h, device=im.device)
|
||||
points_scale = np.array([[w, h]]) # w, h
|
||||
# Crop image and interpolate to input size
|
||||
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
|
||||
# (num_points, 2)
|
||||
points_for_image = point_grids[layer_idx] * points_scale
|
||||
crop_masks, crop_scores, crop_bboxes = [], [], []
|
||||
for (points, ) in batch_iterator(points_batch_size, points_for_image):
|
||||
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
|
||||
# Interpolate predicted masks to input size
|
||||
pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
|
||||
idx = pred_score > conf_thres
|
||||
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
|
||||
|
||||
stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
|
||||
stability_score_offset)
|
||||
idx = stability_score > stability_score_thresh
|
||||
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
|
||||
# Bool type is much more memory-efficient.
|
||||
pred_mask = pred_mask > self.model.mask_threshold
|
||||
# (N, 4)
|
||||
pred_bbox = batched_mask_to_box(pred_mask).float()
|
||||
keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
|
||||
if not torch.all(keep_mask):
|
||||
pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]
|
||||
|
||||
crop_masks.append(pred_mask)
|
||||
crop_bboxes.append(pred_bbox)
|
||||
crop_scores.append(pred_score)
|
||||
|
||||
# Do nms within this crop
|
||||
crop_masks = torch.cat(crop_masks)
|
||||
crop_bboxes = torch.cat(crop_bboxes)
|
||||
crop_scores = torch.cat(crop_scores)
|
||||
keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS
|
||||
crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
|
||||
crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
|
||||
crop_scores = crop_scores[keep]
|
||||
|
||||
pred_masks.append(crop_masks)
|
||||
pred_bboxes.append(crop_bboxes)
|
||||
pred_scores.append(crop_scores)
|
||||
region_areas.append(area.expand(len(crop_masks)))
|
||||
|
||||
pred_masks = torch.cat(pred_masks)
|
||||
pred_bboxes = torch.cat(pred_bboxes)
|
||||
pred_scores = torch.cat(pred_scores)
|
||||
region_areas = torch.cat(region_areas)
|
||||
|
||||
# Remove duplicate masks between crops
|
||||
if len(crop_regions) > 1:
|
||||
scores = 1 / region_areas
|
||||
keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
|
||||
pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]
|
||||
|
||||
return pred_masks, pred_scores, pred_bboxes
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Set up YOLO model with specified thresholds and device."""
|
||||
device = select_device(self.args.device, verbose=verbose)
|
||||
if model is None:
|
||||
model = build_sam(self.args.model)
|
||||
model.eval()
|
||||
self.model = model.to(device)
|
||||
self.device = device
|
||||
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
|
||||
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
|
||||
# TODO: Temporary settings for compatibility
|
||||
self.model.pt = False
|
||||
self.model.triton = False
|
||||
self.model.stride = 32
|
||||
self.model.fp16 = False
|
||||
self.done_warmup = True
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Post-processes inference output predictions to create detection masks for objects."""
|
||||
# (N, 1, H, W), (N, 1)
|
||||
pred_masks, pred_scores = preds[:2]
|
||||
pred_bboxes = preds[2] if self.segment_all else None
|
||||
names = dict(enumerate(str(i) for i in range(len(pred_masks))))
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
results = []
|
||||
for i, masks in enumerate([pred_masks]):
|
||||
orig_img = orig_imgs[i]
|
||||
if pred_bboxes is not None:
|
||||
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
|
||||
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
|
||||
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
|
||||
|
||||
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
|
||||
masks = masks > self.model.mask_threshold # to bool
|
||||
img_path = self.batch[0][i]
|
||||
results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
|
||||
# Reset segment-all mode.
|
||||
self.segment_all = False
|
||||
return results
|
||||
|
||||
def setup_source(self, source):
|
||||
"""Sets up source and inference mode."""
|
||||
if source is not None:
|
||||
super().setup_source(source)
|
||||
|
||||
def set_image(self, image):
|
||||
"""Set image in advance.
|
||||
Args:
|
||||
|
||||
image (str | np.ndarray): image file path or np.ndarray image by cv2.
|
||||
"""
|
||||
if self.model is None:
|
||||
model = build_sam(self.args.model)
|
||||
self.setup_model(model)
|
||||
self.setup_source(image)
|
||||
assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
|
||||
for batch in self.dataset:
|
||||
im = self.preprocess(batch[1])
|
||||
self.features = self.model.image_encoder(im)
|
||||
self.im = im
|
||||
break
|
||||
|
||||
def set_prompts(self, prompts):
|
||||
"""Set prompts in advance."""
|
||||
self.prompts = prompts
|
||||
|
||||
def reset_image(self):
|
||||
self.im = None
|
||||
self.features = None
|
||||
|
||||
@staticmethod
|
||||
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
|
||||
"""
|
||||
Removes small disconnected regions and holes in masks, then reruns
|
||||
box NMS to remove any new duplicates. Requires open-cv as a dependency.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): Masks, (N, H, W).
|
||||
min_area (int): Minimum area threshold.
|
||||
nms_thresh (float): NMS threshold.
|
||||
Returns:
|
||||
new_masks (torch.Tensor): New Masks, (N, H, W).
|
||||
keep (List[int]): The indices of the new masks, which can be used to filter
|
||||
the corresponding boxes.
|
||||
"""
|
||||
if len(masks) == 0:
|
||||
return masks
|
||||
|
||||
# Filter small disconnected regions and holes
|
||||
new_masks = []
|
||||
scores = []
|
||||
for mask in masks:
|
||||
mask = mask.cpu().numpy().astype(np.uint8)
|
||||
mask, changed = remove_small_regions(mask, min_area, mode='holes')
|
||||
unchanged = not changed
|
||||
mask, changed = remove_small_regions(mask, min_area, mode='islands')
|
||||
unchanged = unchanged and not changed
|
||||
|
||||
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
||||
# Give score=0 to changed masks and score=1 to unchanged masks
|
||||
# so NMS will prefer ones that didn't need postprocessing
|
||||
scores.append(float(unchanged))
|
||||
|
||||
# Recalculate boxes and remove any new duplicates
|
||||
new_masks = torch.cat(new_masks, dim=0)
|
||||
boxes = batched_mask_to_box(new_masks)
|
||||
keep = torchvision.ops.nms(
|
||||
boxes.float(),
|
||||
torch.as_tensor(scores),
|
||||
nms_thresh,
|
||||
)
|
||||
|
||||
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
|
1
ytracking/ultralytics/models/utils/__init__.py
Normal file
1
ytracking/ultralytics/models/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
297
ytracking/ultralytics/models/utils/loss.py
Normal file
297
ytracking/ultralytics/models/utils/loss.py
Normal file
@ -0,0 +1,297 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
|
||||
from ultralytics.utils.metrics import bbox_iou
|
||||
|
||||
from .ops import HungarianMatcher
|
||||
|
||||
|
||||
class DETRLoss(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
nc=80,
|
||||
loss_gain=None,
|
||||
aux_loss=True,
|
||||
use_fl=True,
|
||||
use_vfl=False,
|
||||
use_uni_match=False,
|
||||
uni_match_ind=0):
|
||||
"""
|
||||
DETR loss function.
|
||||
|
||||
Args:
|
||||
nc (int): The number of classes.
|
||||
loss_gain (dict): The coefficient of loss.
|
||||
aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
|
||||
use_vfl (bool): Use VarifocalLoss or not.
|
||||
use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
|
||||
uni_match_ind (int): The fixed indices of a layer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if loss_gain is None:
|
||||
loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
|
||||
self.nc = nc
|
||||
self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
|
||||
self.loss_gain = loss_gain
|
||||
self.aux_loss = aux_loss
|
||||
self.fl = FocalLoss() if use_fl else None
|
||||
self.vfl = VarifocalLoss() if use_vfl else None
|
||||
|
||||
self.use_uni_match = use_uni_match
|
||||
self.uni_match_ind = uni_match_ind
|
||||
self.device = None
|
||||
|
||||
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
|
||||
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
|
||||
name_class = f'loss_class{postfix}'
|
||||
bs, nq = pred_scores.shape[:2]
|
||||
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
|
||||
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
|
||||
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
|
||||
one_hot = one_hot[..., :-1]
|
||||
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
|
||||
|
||||
if self.fl:
|
||||
if num_gts and self.vfl:
|
||||
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
|
||||
else:
|
||||
loss_cls = self.fl(pred_scores, one_hot.float())
|
||||
loss_cls /= max(num_gts, 1) / nq
|
||||
else:
|
||||
loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
|
||||
|
||||
return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
|
||||
|
||||
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
|
||||
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
|
||||
name_bbox = f'loss_bbox{postfix}'
|
||||
name_giou = f'loss_giou{postfix}'
|
||||
|
||||
loss = {}
|
||||
if len(gt_bboxes) == 0:
|
||||
loss[name_bbox] = torch.tensor(0., device=self.device)
|
||||
loss[name_giou] = torch.tensor(0., device=self.device)
|
||||
return loss
|
||||
|
||||
loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
|
||||
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
|
||||
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
|
||||
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
|
||||
return {k: v.squeeze() for k, v in loss.items()}
|
||||
|
||||
# This function is for future RT-DETR Segment models
|
||||
# def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
|
||||
# # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
|
||||
# name_mask = f'loss_mask{postfix}'
|
||||
# name_dice = f'loss_dice{postfix}'
|
||||
#
|
||||
# loss = {}
|
||||
# if sum(len(a) for a in gt_mask) == 0:
|
||||
# loss[name_mask] = torch.tensor(0., device=self.device)
|
||||
# loss[name_dice] = torch.tensor(0., device=self.device)
|
||||
# return loss
|
||||
#
|
||||
# num_gts = len(gt_mask)
|
||||
# src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
|
||||
# src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
|
||||
# # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
|
||||
# loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
|
||||
# torch.tensor([num_gts], dtype=torch.float32))
|
||||
# loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
|
||||
# return loss
|
||||
|
||||
# This function is for future RT-DETR Segment models
|
||||
# @staticmethod
|
||||
# def _dice_loss(inputs, targets, num_gts):
|
||||
# inputs = F.sigmoid(inputs).flatten(1)
|
||||
# targets = targets.flatten(1)
|
||||
# numerator = 2 * (inputs * targets).sum(1)
|
||||
# denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
# loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
# return loss.sum() / num_gts
|
||||
|
||||
def _get_loss_aux(self,
|
||||
pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
match_indices=None,
|
||||
postfix='',
|
||||
masks=None,
|
||||
gt_mask=None):
|
||||
"""Get auxiliary losses"""
|
||||
# NOTE: loss class, bbox, giou, mask, dice
|
||||
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
|
||||
if match_indices is None and self.use_uni_match:
|
||||
match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
|
||||
pred_scores[self.uni_match_ind],
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=masks[self.uni_match_ind] if masks is not None else None,
|
||||
gt_mask=gt_mask)
|
||||
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
|
||||
aux_masks = masks[i] if masks is not None else None
|
||||
loss_ = self._get_loss(aux_bboxes,
|
||||
aux_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=aux_masks,
|
||||
gt_mask=gt_mask,
|
||||
postfix=postfix,
|
||||
match_indices=match_indices)
|
||||
loss[0] += loss_[f'loss_class{postfix}']
|
||||
loss[1] += loss_[f'loss_bbox{postfix}']
|
||||
loss[2] += loss_[f'loss_giou{postfix}']
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
|
||||
# loss[3] += loss_[f'loss_mask{postfix}']
|
||||
# loss[4] += loss_[f'loss_dice{postfix}']
|
||||
|
||||
loss = {
|
||||
f'loss_class_aux{postfix}': loss[0],
|
||||
f'loss_bbox_aux{postfix}': loss[1],
|
||||
f'loss_giou_aux{postfix}': loss[2]}
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss[f'loss_mask_aux{postfix}'] = loss[3]
|
||||
# loss[f'loss_dice_aux{postfix}'] = loss[4]
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def _get_index(match_indices):
|
||||
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
|
||||
src_idx = torch.cat([src for (src, _) in match_indices])
|
||||
dst_idx = torch.cat([dst for (_, dst) in match_indices])
|
||||
return (batch_idx, src_idx), dst_idx
|
||||
|
||||
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
|
||||
pred_assigned = torch.cat([
|
||||
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (I, _) in zip(pred_bboxes, match_indices)])
|
||||
gt_assigned = torch.cat([
|
||||
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (_, J) in zip(gt_bboxes, match_indices)])
|
||||
return pred_assigned, gt_assigned
|
||||
|
||||
def _get_loss(self,
|
||||
pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=None,
|
||||
gt_mask=None,
|
||||
postfix='',
|
||||
match_indices=None):
|
||||
"""Get losses"""
|
||||
if match_indices is None:
|
||||
match_indices = self.matcher(pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=masks,
|
||||
gt_mask=gt_mask)
|
||||
|
||||
idx, gt_idx = self._get_index(match_indices)
|
||||
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
|
||||
|
||||
bs, nq = pred_scores.shape[:2]
|
||||
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
|
||||
targets[idx] = gt_cls[gt_idx]
|
||||
|
||||
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
|
||||
if len(gt_bboxes):
|
||||
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
|
||||
|
||||
loss = {}
|
||||
loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
|
||||
loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
|
||||
return loss
|
||||
|
||||
def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
|
||||
"""
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): [l, b, query, 4]
|
||||
pred_scores (torch.Tensor): [l, b, query, num_classes]
|
||||
batch (dict): A dict includes:
|
||||
gt_cls (torch.Tensor) with shape [num_gts, ],
|
||||
gt_bboxes (torch.Tensor): [num_gts, 4],
|
||||
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||
postfix (str): postfix of loss name.
|
||||
"""
|
||||
self.device = pred_bboxes.device
|
||||
match_indices = kwargs.get('match_indices', None)
|
||||
gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
|
||||
|
||||
total_loss = self._get_loss(pred_bboxes[-1],
|
||||
pred_scores[-1],
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
postfix=postfix,
|
||||
match_indices=match_indices)
|
||||
|
||||
if self.aux_loss:
|
||||
total_loss.update(
|
||||
self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
|
||||
postfix))
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
class RTDETRDetectionLoss(DETRLoss):
|
||||
|
||||
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
|
||||
pred_bboxes, pred_scores = preds
|
||||
total_loss = super().forward(pred_bboxes, pred_scores, batch)
|
||||
|
||||
if dn_meta is not None:
|
||||
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
|
||||
assert len(batch['gt_groups']) == len(dn_pos_idx)
|
||||
|
||||
# Denoising match indices
|
||||
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
|
||||
|
||||
# Compute denoising training loss
|
||||
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
|
||||
total_loss.update(dn_loss)
|
||||
else:
|
||||
total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
|
||||
|
||||
return total_loss
|
||||
|
||||
@staticmethod
|
||||
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
|
||||
"""
|
||||
Get the match indices for denoising.
|
||||
|
||||
Args:
|
||||
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
|
||||
dn_num_group (int): The number of groups of denoising.
|
||||
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||
|
||||
Returns:
|
||||
dn_match_indices (List(tuple)): Matched indices.
|
||||
"""
|
||||
dn_match_indices = []
|
||||
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||
for i, num_gt in enumerate(gt_groups):
|
||||
if num_gt > 0:
|
||||
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
|
||||
gt_idx = gt_idx.repeat(dn_num_group)
|
||||
assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
|
||||
f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
|
||||
dn_match_indices.append((dn_pos_idx[i], gt_idx))
|
||||
else:
|
||||
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
|
||||
return dn_match_indices
|
261
ytracking/ultralytics/models/utils/ops.py
Normal file
261
ytracking/ultralytics/models/utils/ops.py
Normal file
@ -0,0 +1,261 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from ultralytics.utils.metrics import bbox_iou
|
||||
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
|
||||
|
||||
|
||||
class HungarianMatcher(nn.Module):
|
||||
"""
|
||||
A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in
|
||||
an end-to-end fashion.
|
||||
|
||||
HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
|
||||
function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
|
||||
|
||||
Attributes:
|
||||
cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
|
||||
use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
|
||||
with_mask (bool): Indicates whether the model makes mask predictions.
|
||||
num_sample_points (int): The number of sample points used in mask cost calculation.
|
||||
alpha (float): The alpha factor in Focal Loss calculation.
|
||||
gamma (float): The gamma factor in Focal Loss calculation.
|
||||
|
||||
Methods:
|
||||
forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the
|
||||
assignment between predictions and ground truths for a batch.
|
||||
_cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
|
||||
"""
|
||||
|
||||
def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
|
||||
super().__init__()
|
||||
if cost_gain is None:
|
||||
cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
|
||||
self.cost_gain = cost_gain
|
||||
self.use_fl = use_fl
|
||||
self.with_mask = with_mask
|
||||
self.num_sample_points = num_sample_points
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
|
||||
def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
|
||||
"""
|
||||
Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
|
||||
(classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching
|
||||
between predictions and ground truth based on these costs.
|
||||
|
||||
Args:
|
||||
pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
|
||||
pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes].
|
||||
gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ].
|
||||
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4].
|
||||
gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
|
||||
each image.
|
||||
masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width].
|
||||
Defaults to None.
|
||||
gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
(List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
|
||||
- index_i is the tensor of indices of the selected predictions (in order)
|
||||
- index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
|
||||
For each batch element, it holds:
|
||||
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
|
||||
bs, nq, nc = pred_scores.shape
|
||||
|
||||
if sum(gt_groups) == 0:
|
||||
return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
# [batch_size * num_queries, num_classes]
|
||||
pred_scores = pred_scores.detach().view(-1, nc)
|
||||
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
|
||||
# [batch_size * num_queries, 4]
|
||||
pred_bboxes = pred_bboxes.detach().view(-1, 4)
|
||||
|
||||
# Compute the classification cost
|
||||
pred_scores = pred_scores[:, gt_cls]
|
||||
if self.use_fl:
|
||||
neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
|
||||
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
|
||||
cost_class = pos_cost_class - neg_cost_class
|
||||
else:
|
||||
cost_class = -pred_scores
|
||||
|
||||
# Compute the L1 cost between boxes
|
||||
cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
|
||||
|
||||
# Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
|
||||
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
|
||||
|
||||
# Final cost matrix
|
||||
C = self.cost_gain['class'] * cost_class + \
|
||||
self.cost_gain['bbox'] * cost_bbox + \
|
||||
self.cost_gain['giou'] * cost_giou
|
||||
# Compute the mask cost and dice cost
|
||||
if self.with_mask:
|
||||
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
|
||||
|
||||
C = C.view(bs, nq, -1).cpu()
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
|
||||
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||
# (idx for queries, idx for gt)
|
||||
return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
|
||||
for k, (i, j) in enumerate(indices)]
|
||||
|
||||
# This function is for future RT-DETR Segment models
|
||||
# def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
|
||||
# assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
|
||||
# # all masks share the same set of points for efficient matching
|
||||
# sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
|
||||
# sample_points = 2.0 * sample_points - 1.0
|
||||
#
|
||||
# out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
|
||||
# out_mask = out_mask.flatten(0, 1)
|
||||
#
|
||||
# tgt_mask = torch.cat(gt_mask).unsqueeze(1)
|
||||
# sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
|
||||
# tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
|
||||
#
|
||||
# with torch.cuda.amp.autocast(False):
|
||||
# # binary cross entropy cost
|
||||
# pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
|
||||
# neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
|
||||
# cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
|
||||
# cost_mask /= self.num_sample_points
|
||||
#
|
||||
# # dice cost
|
||||
# out_mask = F.sigmoid(out_mask)
|
||||
# numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
|
||||
# denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
|
||||
# cost_dice = 1 - (numerator + 1) / (denominator + 1)
|
||||
#
|
||||
# C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
|
||||
# return C
|
||||
|
||||
|
||||
def get_cdn_group(batch,
|
||||
num_classes,
|
||||
num_queries,
|
||||
class_embed,
|
||||
num_dn=100,
|
||||
cls_noise_ratio=0.5,
|
||||
box_noise_scale=1.0,
|
||||
training=False):
|
||||
"""
|
||||
Get contrastive denoising training group. This function creates a contrastive denoising training group with
|
||||
positive and negative samples from the ground truths (gt). It applies noise to the class labels and bounding
|
||||
box coordinates, and returns the modified labels, bounding boxes, attention mask and meta information.
|
||||
|
||||
Args:
|
||||
batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
|
||||
(torch.Tensor with shape [num_gts, 4]), 'gt_groups' (List(int)) which is a list of batch size length
|
||||
indicating the number of gts of each image.
|
||||
num_classes (int): Number of classes.
|
||||
num_queries (int): Number of queries.
|
||||
class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
|
||||
num_dn (int, optional): Number of denoising. Defaults to 100.
|
||||
cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
|
||||
box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
|
||||
training (bool, optional): If it's in training mode. Defaults to False.
|
||||
|
||||
Returns:
|
||||
(Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
|
||||
bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
|
||||
is less than or equal to 0, the function returns None for all elements in the tuple.
|
||||
"""
|
||||
|
||||
if (not training) or num_dn <= 0:
|
||||
return None, None, None, None
|
||||
gt_groups = batch['gt_groups']
|
||||
total_num = sum(gt_groups)
|
||||
max_nums = max(gt_groups)
|
||||
if max_nums == 0:
|
||||
return None, None, None, None
|
||||
|
||||
num_group = num_dn // max_nums
|
||||
num_group = 1 if num_group == 0 else num_group
|
||||
# pad gt to max_num of a batch
|
||||
bs = len(gt_groups)
|
||||
gt_cls = batch['cls'] # (bs*num, )
|
||||
gt_bbox = batch['bboxes'] # bs*num, 4
|
||||
b_idx = batch['batch_idx']
|
||||
|
||||
# each group has positive and negative queries.
|
||||
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
|
||||
dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
|
||||
dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
|
||||
|
||||
# positive and negative mask
|
||||
# (bs*num*num_group, ), the second total_num*num_group part as negative samples
|
||||
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
|
||||
|
||||
if cls_noise_ratio > 0:
|
||||
# half of bbox prob
|
||||
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
|
||||
idx = torch.nonzero(mask).squeeze(-1)
|
||||
# randomly put a new one here
|
||||
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
|
||||
dn_cls[idx] = new_label
|
||||
|
||||
if box_noise_scale > 0:
|
||||
known_bbox = xywh2xyxy(dn_bbox)
|
||||
|
||||
diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
|
||||
|
||||
rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
|
||||
rand_part = torch.rand_like(dn_bbox)
|
||||
rand_part[neg_idx] += 1.0
|
||||
rand_part *= rand_sign
|
||||
known_bbox += rand_part * diff
|
||||
known_bbox.clip_(min=0.0, max=1.0)
|
||||
dn_bbox = xyxy2xywh(known_bbox)
|
||||
dn_bbox = inverse_sigmoid(dn_bbox)
|
||||
|
||||
# total denoising queries
|
||||
num_dn = int(max_nums * 2 * num_group)
|
||||
# class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
|
||||
dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
|
||||
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
|
||||
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
|
||||
|
||||
map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
|
||||
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
|
||||
|
||||
map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
|
||||
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
|
||||
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
|
||||
|
||||
tgt_size = num_dn + num_queries
|
||||
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
|
||||
# match query cannot see the reconstruct
|
||||
attn_mask[num_dn:, :num_dn] = True
|
||||
# reconstruct cannot see each other
|
||||
for i in range(num_group):
|
||||
if i == 0:
|
||||
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
|
||||
if i == num_group - 1:
|
||||
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
|
||||
else:
|
||||
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
|
||||
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
|
||||
dn_meta = {
|
||||
'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
|
||||
'dn_num_group': num_group,
|
||||
'dn_num_split': [num_dn, num_queries]}
|
||||
|
||||
return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
|
||||
class_embed.device), dn_meta
|
||||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-6):
|
||||
"""Inverse sigmoid function."""
|
||||
x = x.clip(min=0., max=1.)
|
||||
return torch.log(x / (1 - x + eps) + eps)
|
7
ytracking/ultralytics/models/yolo/__init__.py
Normal file
7
ytracking/ultralytics/models/yolo/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.models.yolo import classify, detect, pose, segment
|
||||
|
||||
from .model import YOLO
|
||||
|
||||
__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO'
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
7
ytracking/ultralytics/models/yolo/classify/__init__.py
Normal file
7
ytracking/ultralytics/models/yolo/classify/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.models.yolo.classify.predict import ClassificationPredictor
|
||||
from ultralytics.models.yolo.classify.train import ClassificationTrainer
|
||||
from ultralytics.models.yolo.classify.val import ClassificationValidator
|
||||
|
||||
__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator'
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
49
ytracking/ultralytics/models/yolo/classify/predict.py
Normal file
49
ytracking/ultralytics/models/yolo/classify/predict.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
|
||||
|
||||
class ClassificationPredictor(BasePredictor):
|
||||
"""
|
||||
A class extending the BasePredictor class for prediction based on a classification model.
|
||||
|
||||
Notes:
|
||||
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.utils import ASSETS
|
||||
from ultralytics.models.yolo.classify import ClassificationPredictor
|
||||
|
||||
args = dict(model='yolov8n-cls.pt', source=ASSETS)
|
||||
predictor = ClassificationPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'classify'
|
||||
|
||||
def preprocess(self, img):
|
||||
"""Converts input image to model-compatible data type."""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
img = torch.stack([self.transforms(im) for im in img], dim=0)
|
||||
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
|
||||
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Post-processes predictions to return Results objects."""
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i]
|
||||
img_path = self.batch[0][i]
|
||||
results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
|
||||
return results
|
150
ytracking/ultralytics/models/yolo/classify/train.py
Normal file
150
ytracking/ultralytics/models/yolo/classify/train.py
Normal file
@ -0,0 +1,150 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.engine.trainer import BaseTrainer
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
|
||||
from ultralytics.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
||||
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
"""
|
||||
A class extending the BaseTrainer class for training based on a classification model.
|
||||
|
||||
Notes:
|
||||
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.models.yolo.classify import ClassificationTrainer
|
||||
|
||||
args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
|
||||
trainer = ClassificationTrainer(overrides=args)
|
||||
trainer.train()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides['task'] = 'classify'
|
||||
if overrides.get('imgsz') is None:
|
||||
overrides['imgsz'] = 224
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set the YOLO model's class names from the loaded dataset."""
|
||||
self.model.names = self.data['names']
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Returns a modified PyTorch model configured for training YOLO."""
|
||||
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
for m in model.modules():
|
||||
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
||||
m.p = self.args.dropout # set dropout
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
return model
|
||||
|
||||
def setup_model(self):
|
||||
"""Load, create or download model for any task."""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
model, ckpt = str(self.model), None
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith('.pt'):
|
||||
self.model, ckpt = attempt_load_one_weight(model, device='cpu')
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
elif model.split('.')[-1] in ('yaml', 'yml'):
|
||||
self.model = self.get_model(cfg=model)
|
||||
elif model in torchvision.models.__dict__:
|
||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
|
||||
else:
|
||||
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
||||
|
||||
return ckpt
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = self.build_dataset(dataset_path, mode)
|
||||
|
||||
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
|
||||
# Attach inference transforms
|
||||
if mode != 'train':
|
||||
if is_parallel(self.model):
|
||||
self.model.module.transforms = loader.dataset.torch_transforms
|
||||
else:
|
||||
self.model.transforms = loader.dataset.torch_transforms
|
||||
return loader
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Preprocesses a batch of images and classes."""
|
||||
batch['img'] = batch['img'].to(self.device)
|
||||
batch['cls'] = batch['cls'].to(self.device)
|
||||
return batch
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a formatted string showing training progress."""
|
||||
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
|
||||
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns an instance of ClassificationValidator for validation."""
|
||||
self.loss_names = ['loss']
|
||||
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
|
||||
segmentation & detection
|
||||
"""
|
||||
keys = [f'{prefix}/{x}' for x in self.loss_names]
|
||||
if loss_items is None:
|
||||
return keys
|
||||
loss_items = [round(float(loss_items), 5)]
|
||||
return dict(zip(keys, loss_items))
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots metrics from a CSV file."""
|
||||
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
|
||||
|
||||
def final_eval(self):
|
||||
"""Evaluate trained model and save validation results."""
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if f is self.best:
|
||||
LOGGER.info(f'\nValidating {f}...')
|
||||
self.validator.args.data = self.args.data
|
||||
self.validator.args.plots = self.args.plots
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop('fitness', None)
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples with their annotations."""
|
||||
plot_images(
|
||||
images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user