add yolo v10 and modify pipeline
This commit is contained in:
24
ultralytics/models/yolov10/val.py
Normal file
24
ultralytics/models/yolov10/val.py
Normal file
@ -0,0 +1,24 @@
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import ops
|
||||
import torch
|
||||
|
||||
class YOLOv10DetectionValidator(DetectionValidator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.args.save_json |= self.is_coco
|
||||
|
||||
def postprocess(self, preds):
|
||||
if isinstance(preds, dict):
|
||||
preds = preds["one2one"]
|
||||
|
||||
if isinstance(preds, (list, tuple)):
|
||||
preds = preds[0]
|
||||
|
||||
# Acknowledgement: Thanks to sanha9999 in #190 and #181!
|
||||
if preds.shape[-1] == 6:
|
||||
return preds
|
||||
else:
|
||||
preds = preds.transpose(-1, -2)
|
||||
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
|
||||
bboxes = ops.xywh2xyxy(boxes)
|
||||
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
|
Reference in New Issue
Block a user