add yolo v10 and modify pipeline
This commit is contained in:
@ -12,6 +12,7 @@ import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.utils import LOGGER
|
||||
from ultralytics.utils.metrics import batch_probiou
|
||||
|
||||
|
||||
class Profile(contextlib.ContextDecorator):
|
||||
@ -22,22 +23,24 @@ class Profile(contextlib.ContextDecorator):
|
||||
```python
|
||||
from ultralytics.utils.ops import Profile
|
||||
|
||||
with Profile() as dt:
|
||||
with Profile(device=device) as dt:
|
||||
pass # slow operation here
|
||||
|
||||
print(dt) # prints "Elapsed time is 9.5367431640625e-07 s"
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, t=0.0):
|
||||
def __init__(self, t=0.0, device: torch.device = None):
|
||||
"""
|
||||
Initialize the Profile class.
|
||||
|
||||
Args:
|
||||
t (float): Initial time. Defaults to 0.0.
|
||||
device (torch.device): Devices used for model inference. Defaults to None (cpu).
|
||||
"""
|
||||
self.t = t
|
||||
self.cuda = torch.cuda.is_available()
|
||||
self.device = device
|
||||
self.cuda = bool(device and str(device).startswith("cuda"))
|
||||
|
||||
def __enter__(self):
|
||||
"""Start timing."""
|
||||
@ -50,12 +53,13 @@ class Profile(contextlib.ContextDecorator):
|
||||
self.t += self.dt # accumulate dt
|
||||
|
||||
def __str__(self):
|
||||
return f'Elapsed time is {self.t} s'
|
||||
"""Returns a human-readable string representing the accumulated elapsed time in the profiler."""
|
||||
return f"Elapsed time is {self.t} s"
|
||||
|
||||
def time(self):
|
||||
"""Get current time."""
|
||||
if self.cuda:
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.synchronize(self.device)
|
||||
return time.time()
|
||||
|
||||
|
||||
@ -71,18 +75,21 @@ def segment2box(segment, width=640, height=640):
|
||||
Returns:
|
||||
(np.ndarray): the minimum and maximum x and y values of the segment.
|
||||
"""
|
||||
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
||||
x, y = segment.T # segment xy
|
||||
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
|
||||
x, y, = x[inside], y[inside]
|
||||
return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
|
||||
4, dtype=segment.dtype) # xyxy
|
||||
x = x[inside]
|
||||
y = y[inside]
|
||||
return (
|
||||
np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
|
||||
if any(x)
|
||||
else np.zeros(4, dtype=segment.dtype)
|
||||
) # xyxy
|
||||
|
||||
|
||||
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
|
||||
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
|
||||
"""
|
||||
Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
|
||||
(img1_shape) to the shape of a different image (img0_shape).
|
||||
Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
|
||||
specified in (img1_shape) to the shape of a different image (img0_shape).
|
||||
|
||||
Args:
|
||||
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
|
||||
@ -92,24 +99,29 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
|
||||
calculated based on the size difference between the two images.
|
||||
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
|
||||
rescaling.
|
||||
xywh (bool): The box format is xywh or not, default=False.
|
||||
|
||||
Returns:
|
||||
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
|
||||
"""
|
||||
if ratio_pad is None: # calculate from img0_shape
|
||||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
||||
pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
|
||||
(img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) # wh padding
|
||||
pad = (
|
||||
round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
|
||||
round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
|
||||
) # wh padding
|
||||
else:
|
||||
gain = ratio_pad[0][0]
|
||||
pad = ratio_pad[1]
|
||||
|
||||
if padding:
|
||||
boxes[..., [0, 2]] -= pad[0] # x padding
|
||||
boxes[..., [1, 3]] -= pad[1] # y padding
|
||||
boxes[..., 0] -= pad[0] # x padding
|
||||
boxes[..., 1] -= pad[1] # y padding
|
||||
if not xywh:
|
||||
boxes[..., 2] -= pad[0] # x padding
|
||||
boxes[..., 3] -= pad[1] # y padding
|
||||
boxes[..., :4] /= gain
|
||||
clip_boxes(boxes, img0_shape)
|
||||
return boxes
|
||||
return clip_boxes(boxes, img0_shape)
|
||||
|
||||
|
||||
def make_divisible(x, divisor):
|
||||
@ -128,19 +140,41 @@ def make_divisible(x, divisor):
|
||||
return math.ceil(x / divisor) * divisor
|
||||
|
||||
|
||||
def nms_rotated(boxes, scores, threshold=0.45):
|
||||
"""
|
||||
NMS for obbs, powered by probiou and fast-nms.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): (N, 5), xywhr.
|
||||
scores (torch.Tensor): (N, ).
|
||||
threshold (float): IoU threshold.
|
||||
|
||||
Returns:
|
||||
"""
|
||||
if len(boxes) == 0:
|
||||
return np.empty((0,), dtype=np.int8)
|
||||
sorted_idx = torch.argsort(scores, descending=True)
|
||||
boxes = boxes[sorted_idx]
|
||||
ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
|
||||
pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
|
||||
return sorted_idx[pick]
|
||||
|
||||
|
||||
def non_max_suppression(
|
||||
prediction,
|
||||
conf_thres=0.25,
|
||||
iou_thres=0.45,
|
||||
classes=None,
|
||||
agnostic=False,
|
||||
multi_label=False,
|
||||
labels=(),
|
||||
max_det=300,
|
||||
nc=0, # number of classes (optional)
|
||||
max_time_img=0.05,
|
||||
max_nms=30000,
|
||||
max_wh=7680,
|
||||
prediction,
|
||||
conf_thres=0.25,
|
||||
iou_thres=0.45,
|
||||
classes=None,
|
||||
agnostic=False,
|
||||
multi_label=False,
|
||||
labels=(),
|
||||
max_det=300,
|
||||
nc=0, # number of classes (optional)
|
||||
max_time_img=0.05,
|
||||
max_nms=30000,
|
||||
max_wh=7680,
|
||||
in_place=True,
|
||||
rotated=False,
|
||||
):
|
||||
"""
|
||||
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
|
||||
@ -164,7 +198,8 @@ def non_max_suppression(
|
||||
nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
|
||||
max_time_img (float): The maximum time (seconds) for processing one image.
|
||||
max_nms (int): The maximum number of boxes into torchvision.ops.nms().
|
||||
max_wh (int): The maximum box width and height in pixels
|
||||
max_wh (int): The maximum box width and height in pixels.
|
||||
in_place (bool): If True, the input prediction tensor will be modified in place.
|
||||
|
||||
Returns:
|
||||
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
|
||||
@ -173,15 +208,11 @@ def non_max_suppression(
|
||||
"""
|
||||
|
||||
# Checks
|
||||
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
||||
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
||||
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
||||
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
|
||||
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
|
||||
prediction = prediction[0] # select only inference output
|
||||
|
||||
device = prediction.device
|
||||
mps = 'mps' in device.type # Apple MPS
|
||||
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
||||
prediction = prediction.cpu()
|
||||
bs = prediction.shape[0] # batch size
|
||||
nc = nc or (prediction.shape[1] - 4) # number of classes
|
||||
nm = prediction.shape[1] - nc - 4
|
||||
@ -190,11 +221,15 @@ def non_max_suppression(
|
||||
|
||||
# Settings
|
||||
# min_wh = 2 # (pixels) minimum box width and height
|
||||
time_limit = 0.5 + max_time_img * bs # seconds to quit after
|
||||
time_limit = 2.0 + max_time_img * bs # seconds to quit after
|
||||
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
||||
|
||||
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
||||
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
||||
if not rotated:
|
||||
if in_place:
|
||||
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
||||
else:
|
||||
prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy
|
||||
|
||||
t = time.time()
|
||||
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
||||
@ -204,7 +239,7 @@ def non_max_suppression(
|
||||
x = x[xc[xi]] # confidence
|
||||
|
||||
# Cat apriori labels if autolabelling
|
||||
if labels and len(labels[xi]):
|
||||
if labels and len(labels[xi]) and not rotated:
|
||||
lb = labels[xi]
|
||||
v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
|
||||
v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
|
||||
@ -238,8 +273,13 @@ def non_max_suppression(
|
||||
|
||||
# Batched NMS
|
||||
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
||||
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
||||
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
||||
scores = x[:, 4] # scores
|
||||
if rotated:
|
||||
boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr
|
||||
i = nms_rotated(boxes, scores, iou_thres)
|
||||
else:
|
||||
boxes = x[:, :4] + c # boxes (offset by class)
|
||||
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
||||
i = i[:max_det] # limit detections
|
||||
|
||||
# # Experimental
|
||||
@ -247,7 +287,7 @@ def non_max_suppression(
|
||||
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
||||
# # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
||||
# from .metrics import box_iou
|
||||
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
||||
# iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix
|
||||
# weights = iou * scores[None] # box weights
|
||||
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
||||
# redundant = True # require redundant detections
|
||||
@ -255,10 +295,8 @@ def non_max_suppression(
|
||||
# i = i[iou.sum(1) > 1] # require redundancy
|
||||
|
||||
output[xi] = x[i]
|
||||
if mps:
|
||||
output[xi] = output[xi].to(device)
|
||||
if (time.time() - t) > time_limit:
|
||||
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
||||
LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
|
||||
break # time limit exceeded
|
||||
|
||||
return output
|
||||
@ -269,17 +307,21 @@ def clip_boxes(boxes, shape):
|
||||
Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): the bounding boxes to clip
|
||||
shape (tuple): the shape of the image
|
||||
boxes (torch.Tensor): the bounding boxes to clip
|
||||
shape (tuple): the shape of the image
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | numpy.ndarray): Clipped boxes
|
||||
"""
|
||||
if isinstance(boxes, torch.Tensor): # faster individually
|
||||
boxes[..., 0].clamp_(0, shape[1]) # x1
|
||||
boxes[..., 1].clamp_(0, shape[0]) # y1
|
||||
boxes[..., 2].clamp_(0, shape[1]) # x2
|
||||
boxes[..., 3].clamp_(0, shape[0]) # y2
|
||||
if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
|
||||
boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1
|
||||
boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1
|
||||
boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2
|
||||
boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2
|
||||
else: # np.array (faster grouped)
|
||||
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
|
||||
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
|
||||
return boxes
|
||||
|
||||
|
||||
def clip_coords(coords, shape):
|
||||
@ -291,19 +333,20 @@ def clip_coords(coords, shape):
|
||||
shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
|
||||
|
||||
Returns:
|
||||
(None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries.
|
||||
(torch.Tensor | numpy.ndarray): Clipped coordinates
|
||||
"""
|
||||
if isinstance(coords, torch.Tensor): # faster individually
|
||||
coords[..., 0].clamp_(0, shape[1]) # x
|
||||
coords[..., 1].clamp_(0, shape[0]) # y
|
||||
if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
|
||||
coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x
|
||||
coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y
|
||||
else: # np.array (faster grouped)
|
||||
coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x
|
||||
coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y
|
||||
return coords
|
||||
|
||||
|
||||
def scale_image(masks, im0_shape, ratio_pad=None):
|
||||
"""
|
||||
Takes a mask, and resizes it to the original image size
|
||||
Takes a mask, and resizes it to the original image size.
|
||||
|
||||
Args:
|
||||
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
|
||||
@ -321,7 +364,7 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
||||
gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
|
||||
pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
|
||||
else:
|
||||
gain = ratio_pad[0][0]
|
||||
# gain = ratio_pad[0][0]
|
||||
pad = ratio_pad[1]
|
||||
top, left = int(pad[1]), int(pad[0]) # y, x
|
||||
bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
|
||||
@ -347,7 +390,7 @@ def xyxy2xywh(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||
"""
|
||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
||||
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
||||
@ -367,7 +410,7 @@ def xywh2xyxy(x):
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
"""
|
||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||
dw = x[..., 2] / 2 # half-width
|
||||
dh = x[..., 3] / 2 # half-height
|
||||
@ -392,7 +435,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
||||
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
|
||||
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
|
||||
"""
|
||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
||||
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
||||
@ -403,8 +446,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
||||
|
||||
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
||||
"""
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
|
||||
x, y, width and height are normalized to image dimensions
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
|
||||
width and height are normalized to image dimensions.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
@ -417,8 +460,8 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
|
||||
"""
|
||||
if clip:
|
||||
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||
x = clip_boxes(x, (h - eps, w - eps))
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
||||
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
||||
@ -445,7 +488,7 @@ def xywh2ltwh(x):
|
||||
|
||||
def xyxy2ltwh(x):
|
||||
"""
|
||||
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
|
||||
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
|
||||
@ -461,7 +504,7 @@ def xyxy2ltwh(x):
|
||||
|
||||
def ltwh2xywh(x):
|
||||
"""
|
||||
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
|
||||
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): the input tensor
|
||||
@ -477,7 +520,8 @@ def ltwh2xywh(x):
|
||||
|
||||
def xyxyxyxy2xywhr(corners):
|
||||
"""
|
||||
Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].
|
||||
Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
|
||||
expected in degrees from 0 to 90.
|
||||
|
||||
Args:
|
||||
corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).
|
||||
@ -485,66 +529,53 @@ def xyxyxyxy2xywhr(corners):
|
||||
Returns:
|
||||
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
|
||||
"""
|
||||
is_numpy = isinstance(corners, np.ndarray)
|
||||
atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt)
|
||||
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
|
||||
cx = (x1 + x3) / 2
|
||||
cy = (y1 + y3) / 2
|
||||
dx21 = x2 - x1
|
||||
dy21 = y2 - y1
|
||||
|
||||
w = sqrt(dx21 ** 2 + dy21 ** 2)
|
||||
h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
|
||||
|
||||
rotation = atan2(-dy21, dx21)
|
||||
rotation *= 180.0 / math.pi # radians to degrees
|
||||
|
||||
return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1)
|
||||
is_torch = isinstance(corners, torch.Tensor)
|
||||
points = corners.cpu().numpy() if is_torch else corners
|
||||
points = points.reshape(len(corners), -1, 2)
|
||||
rboxes = []
|
||||
for pts in points:
|
||||
# NOTE: Use cv2.minAreaRect to get accurate xywhr,
|
||||
# especially some objects are cut off by augmentations in dataloader.
|
||||
(x, y), (w, h), angle = cv2.minAreaRect(pts)
|
||||
rboxes.append([x, y, w, h, angle / 180 * np.pi])
|
||||
return (
|
||||
torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
|
||||
if is_torch
|
||||
else np.asarray(rboxes, dtype=points.dtype)
|
||||
) # rboxes
|
||||
|
||||
|
||||
def xywhr2xyxyxyxy(center):
|
||||
def xywhr2xyxyxyxy(rboxes):
|
||||
"""
|
||||
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].
|
||||
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
|
||||
be in degrees from 0 to 90.
|
||||
|
||||
Args:
|
||||
center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5).
|
||||
rboxes (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
|
||||
|
||||
Returns:
|
||||
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
|
||||
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
|
||||
"""
|
||||
is_numpy = isinstance(center, np.ndarray)
|
||||
is_numpy = isinstance(rboxes, np.ndarray)
|
||||
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
|
||||
|
||||
cx, cy, w, h, rotation = center.T
|
||||
rotation *= math.pi / 180.0 # degrees to radians
|
||||
|
||||
dx = w / 2
|
||||
dy = h / 2
|
||||
|
||||
cos_rot = cos(rotation)
|
||||
sin_rot = sin(rotation)
|
||||
dx_cos_rot = dx * cos_rot
|
||||
dx_sin_rot = dx * sin_rot
|
||||
dy_cos_rot = dy * cos_rot
|
||||
dy_sin_rot = dy * sin_rot
|
||||
|
||||
x1 = cx - dx_cos_rot - dy_sin_rot
|
||||
y1 = cy + dx_sin_rot - dy_cos_rot
|
||||
x2 = cx + dx_cos_rot - dy_sin_rot
|
||||
y2 = cy - dx_sin_rot - dy_cos_rot
|
||||
x3 = cx + dx_cos_rot + dy_sin_rot
|
||||
y3 = cy - dx_sin_rot + dy_cos_rot
|
||||
x4 = cx - dx_cos_rot + dy_sin_rot
|
||||
y4 = cy + dx_sin_rot + dy_cos_rot
|
||||
|
||||
return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack(
|
||||
(x1, y1, x2, y2, x3, y3, x4, y4), dim=1)
|
||||
ctr = rboxes[..., :2]
|
||||
w, h, angle = (rboxes[..., i : i + 1] for i in range(2, 5))
|
||||
cos_value, sin_value = cos(angle), sin(angle)
|
||||
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
|
||||
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
|
||||
vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1)
|
||||
vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1)
|
||||
pt1 = ctr + vec1 + vec2
|
||||
pt2 = ctr + vec1 - vec2
|
||||
pt3 = ctr - vec1 - vec2
|
||||
pt4 = ctr - vec1 + vec2
|
||||
return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2)
|
||||
|
||||
|
||||
def ltwh2xyxy(x):
|
||||
"""
|
||||
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
||||
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): the input image
|
||||
@ -590,8 +621,9 @@ def resample_segments(segments, n=1000):
|
||||
s = np.concatenate((s, s[0:1, :]), axis=0)
|
||||
x = np.linspace(0, len(s) - 1, n)
|
||||
xp = np.arange(len(s))
|
||||
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
|
||||
dtype=np.float32).reshape(2, -1).T # segment xy
|
||||
segments[i] = (
|
||||
np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
|
||||
) # segment xy
|
||||
return segments
|
||||
|
||||
|
||||
@ -606,7 +638,7 @@ def crop_mask(masks, boxes):
|
||||
Returns:
|
||||
(torch.Tensor): The masks are being cropped to the bounding box.
|
||||
"""
|
||||
n, h, w = masks.shape
|
||||
_, h, w = masks.shape
|
||||
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
|
||||
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
|
||||
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
|
||||
@ -616,8 +648,8 @@ def crop_mask(masks, boxes):
|
||||
|
||||
def process_mask_upsample(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
|
||||
quality but is slower.
|
||||
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
|
||||
but is slower.
|
||||
|
||||
Args:
|
||||
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
||||
@ -630,7 +662,7 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
c, mh, mw = protos.shape # CHW
|
||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
|
||||
masks = crop_mask(masks, bboxes) # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
@ -654,16 +686,18 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
|
||||
c, mh, mw = protos.shape # CHW
|
||||
ih, iw = shape
|
||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
|
||||
width_ratio = mw / iw
|
||||
height_ratio = mh / ih
|
||||
|
||||
downsampled_bboxes = bboxes.clone()
|
||||
downsampled_bboxes[:, 0] *= mw / iw
|
||||
downsampled_bboxes[:, 2] *= mw / iw
|
||||
downsampled_bboxes[:, 3] *= mh / ih
|
||||
downsampled_bboxes[:, 1] *= mh / ih
|
||||
downsampled_bboxes[:, 0] *= width_ratio
|
||||
downsampled_bboxes[:, 2] *= width_ratio
|
||||
downsampled_bboxes[:, 3] *= height_ratio
|
||||
downsampled_bboxes[:, 1] *= height_ratio
|
||||
|
||||
masks = crop_mask(masks, downsampled_bboxes) # CHW
|
||||
if upsample:
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
||||
@ -707,13 +741,13 @@ def scale_masks(masks, shape, padding=True):
|
||||
bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
|
||||
masks = masks[..., top:bottom, left:right]
|
||||
|
||||
masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
|
||||
masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW
|
||||
return masks
|
||||
|
||||
|
||||
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
|
||||
"""
|
||||
Rescale segment coordinates (xy) from img1_shape to img0_shape
|
||||
Rescale segment coordinates (xy) from img1_shape to img0_shape.
|
||||
|
||||
Args:
|
||||
img1_shape (tuple): The shape of the image that the coords are from.
|
||||
@ -739,14 +773,32 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
|
||||
coords[..., 1] -= pad[1] # y padding
|
||||
coords[..., 0] /= gain
|
||||
coords[..., 1] /= gain
|
||||
clip_coords(coords, img0_shape)
|
||||
coords = clip_coords(coords, img0_shape)
|
||||
if normalize:
|
||||
coords[..., 0] /= img0_shape[1] # width
|
||||
coords[..., 1] /= img0_shape[0] # height
|
||||
return coords
|
||||
|
||||
|
||||
def masks2segments(masks, strategy='largest'):
|
||||
def regularize_rboxes(rboxes):
|
||||
"""
|
||||
Regularize rotated boxes in range [0, pi/2].
|
||||
|
||||
Args:
|
||||
rboxes (torch.Tensor): (N, 5), xywhr.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The regularized boxes.
|
||||
"""
|
||||
x, y, w, h, t = rboxes.unbind(dim=-1)
|
||||
# Swap edge and angle if h >= w
|
||||
w_ = torch.where(w > h, w, h)
|
||||
h_ = torch.where(w > h, h, w)
|
||||
t = torch.where(w > h, t, t + math.pi / 2) % math.pi
|
||||
return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
|
||||
|
||||
|
||||
def masks2segments(masks, strategy="largest"):
|
||||
"""
|
||||
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
|
||||
|
||||
@ -758,16 +810,16 @@ def masks2segments(masks, strategy='largest'):
|
||||
segments (List): list of segment masks
|
||||
"""
|
||||
segments = []
|
||||
for x in masks.int().cpu().numpy().astype('uint8'):
|
||||
for x in masks.int().cpu().numpy().astype("uint8"):
|
||||
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
|
||||
if c:
|
||||
if strategy == 'concat': # concatenate all segments
|
||||
if strategy == "concat": # concatenate all segments
|
||||
c = np.concatenate([x.reshape(-1, 2) for x in c])
|
||||
elif strategy == 'largest': # select largest segment
|
||||
elif strategy == "largest": # select largest segment
|
||||
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
|
||||
else:
|
||||
c = np.zeros((0, 2)) # no segments found
|
||||
segments.append(c.astype('float32'))
|
||||
segments.append(c.astype("float32"))
|
||||
return segments
|
||||
|
||||
|
||||
@ -794,4 +846,19 @@ def clean_str(s):
|
||||
Returns:
|
||||
(str): a string with special characters replaced by an underscore _
|
||||
"""
|
||||
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
|
||||
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
|
||||
|
||||
def v10postprocess(preds, max_det, nc=80):
|
||||
assert(4 + nc == preds.shape[-1])
|
||||
boxes, scores = preds.split([4, nc], dim=-1)
|
||||
max_scores = scores.amax(dim=-1)
|
||||
max_scores, index = torch.topk(max_scores, max_det, dim=-1)
|
||||
index = index.unsqueeze(-1)
|
||||
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
|
||||
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
|
||||
|
||||
scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
|
||||
labels = index % nc
|
||||
index = index // nc
|
||||
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
|
||||
return boxes, scores, labels
|
Reference in New Issue
Block a user