add yolo v10 and modify pipeline
This commit is contained in:
@ -4,4 +4,4 @@ from .model import NAS
|
||||
from .predict import NASPredictor
|
||||
from .val import NASValidator
|
||||
|
||||
__all__ = 'NASPredictor', 'NASValidator', 'NAS'
|
||||
__all__ = "NASPredictor", "NASValidator", "NAS"
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -17,26 +17,47 @@ import torch
|
||||
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
||||
|
||||
from .predict import NASPredictor
|
||||
from .val import NASValidator
|
||||
|
||||
|
||||
class NAS(Model):
|
||||
"""
|
||||
YOLO NAS model for object detection.
|
||||
|
||||
def __init__(self, model='yolo_nas_s.pt') -> None:
|
||||
assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
|
||||
super().__init__(model, task='detect')
|
||||
This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
|
||||
It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
```
|
||||
|
||||
Attributes:
|
||||
model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
|
||||
|
||||
Note:
|
||||
YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
|
||||
"""
|
||||
|
||||
def __init__(self, model="yolo_nas_s.pt") -> None:
|
||||
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
|
||||
assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
|
||||
super().__init__(model, task="detect")
|
||||
|
||||
@smart_inference_mode()
|
||||
def _load(self, weights: str, task: str):
|
||||
# Load or create new NAS model
|
||||
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
|
||||
import super_gradients
|
||||
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == '.pt':
|
||||
if suffix == ".pt":
|
||||
self.model = torch.load(weights)
|
||||
elif suffix == '':
|
||||
self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
|
||||
elif suffix == "":
|
||||
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
|
||||
# Standardize model
|
||||
self.model.fuse = lambda verbose=True: self.model
|
||||
self.model.stride = torch.tensor([32])
|
||||
@ -44,7 +65,7 @@ class NAS(Model):
|
||||
self.model.is_fused = lambda: False # for info()
|
||||
self.model.yaml = {} # for info()
|
||||
self.model.pt_path = weights # for export()
|
||||
self.model.task = 'detect' # for export()
|
||||
self.model.task = "detect" # for export()
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
@ -58,4 +79,5 @@ class NAS(Model):
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
|
||||
"""Returns a dictionary mapping tasks to respective predictor and validator classes."""
|
||||
return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
|
||||
|
@ -8,6 +8,29 @@ from ultralytics.utils import ops
|
||||
|
||||
|
||||
class NASPredictor(BasePredictor):
|
||||
"""
|
||||
Ultralytics YOLO NAS Predictor for object detection.
|
||||
|
||||
This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the
|
||||
raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
|
||||
scaling the bounding boxes to fit the original image dimensions.
|
||||
|
||||
Attributes:
|
||||
args (Namespace): Namespace containing various configurations for post-processing.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
predictor = model.predictor
|
||||
# Assumes that raw_preds, img, orig_imgs are available
|
||||
results = predictor.postprocess(raw_preds, img, orig_imgs)
|
||||
```
|
||||
|
||||
Note:
|
||||
Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
|
||||
"""
|
||||
|
||||
def postprocess(self, preds_in, img, orig_imgs):
|
||||
"""Postprocess predictions and returns a list of Results objects."""
|
||||
@ -16,12 +39,14 @@ class NASPredictor(BasePredictor):
|
||||
boxes = ops.xyxy2xywh(preds_in[0][0])
|
||||
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
|
||||
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
classes=self.args.classes)
|
||||
preds = ops.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
classes=self.args.classes,
|
||||
)
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
@ -5,20 +5,46 @@ import torch
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import ops
|
||||
|
||||
__all__ = ['NASValidator']
|
||||
__all__ = ["NASValidator"]
|
||||
|
||||
|
||||
class NASValidator(DetectionValidator):
|
||||
"""
|
||||
Ultralytics YOLO NAS Validator for object detection.
|
||||
|
||||
Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
|
||||
generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
|
||||
ultimately producing the final detections.
|
||||
|
||||
Attributes:
|
||||
args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds.
|
||||
lb (torch.Tensor): Optional tensor for multilabel NMS.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
validator = model.validator
|
||||
# Assumes that raw_preds are available
|
||||
final_preds = validator.postprocess(raw_preds)
|
||||
```
|
||||
|
||||
Note:
|
||||
This class is generally not instantiated directly but is used internally within the `NAS` class.
|
||||
"""
|
||||
|
||||
def postprocess(self, preds_in):
|
||||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
boxes = ops.xyxy2xywh(preds_in[0][0])
|
||||
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
|
||||
return ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=False,
|
||||
agnostic=self.args.single_cls,
|
||||
max_det=self.args.max_det,
|
||||
max_time_img=0.5)
|
||||
return ops.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=False,
|
||||
agnostic=self.args.single_cls,
|
||||
max_det=self.args.max_det,
|
||||
max_time_img=0.5,
|
||||
)
|
||||
|
Reference in New Issue
Block a user