diff --git a/__pycache__/pipeline_01.cpython-39.pyc b/__pycache__/pipeline_01.cpython-39.pyc new file mode 100644 index 0000000..595214a Binary files /dev/null and b/__pycache__/pipeline_01.cpython-39.pyc differ diff --git a/__pycache__/track_reid.cpython-39.pyc b/__pycache__/track_reid.cpython-39.pyc index d9e0026..4b57508 100644 Binary files a/__pycache__/track_reid.cpython-39.pyc and b/__pycache__/track_reid.cpython-39.pyc differ diff --git a/contrast/__pycache__/__init__.cpython-39.pyc b/contrast/__pycache__/__init__.cpython-39.pyc index 5c77270..6c5ac4a 100644 Binary files a/contrast/__pycache__/__init__.cpython-39.pyc and b/contrast/__pycache__/__init__.cpython-39.pyc differ diff --git a/contrast/feat_extract/__pycache__/config.cpython-39.pyc b/contrast/feat_extract/__pycache__/config.cpython-39.pyc index f61f73a..b3c7da0 100644 Binary files a/contrast/feat_extract/__pycache__/config.cpython-39.pyc and b/contrast/feat_extract/__pycache__/config.cpython-39.pyc differ diff --git a/contrast/feat_extract/__pycache__/inference.cpython-39.pyc b/contrast/feat_extract/__pycache__/inference.cpython-39.pyc index dfc7104..c7eb02c 100644 Binary files a/contrast/feat_extract/__pycache__/inference.cpython-39.pyc and b/contrast/feat_extract/__pycache__/inference.cpython-39.pyc differ diff --git a/contrast/feat_extract/inference.py b/contrast/feat_extract/inference.py index ab5c7b9..755f049 100644 --- a/contrast/feat_extract/inference.py +++ b/contrast/feat_extract/inference.py @@ -48,7 +48,7 @@ class FeatsInterface: modpath = os.path.join(curpath, conf.test_model) self.model.load_state_dict(torch.load(modpath, map_location=conf.device)) self.model.eval() - print('load model {} '.format(conf.testbackbone)) + # print('load model {} '.format(conf.testbackbone)) def inference(self, images, detections=None): ''' diff --git a/contrast/feat_extract/model/__pycache__/CBAM.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/CBAM.cpython-39.pyc index c47423e..9c425a6 100644 Binary files a/contrast/feat_extract/model/__pycache__/CBAM.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/CBAM.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/Tool.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/Tool.cpython-39.pyc index f009757..0e3a668 100644 Binary files a/contrast/feat_extract/model/__pycache__/Tool.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/Tool.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/__init__.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/__init__.cpython-39.pyc index 50f8f0c..b79395e 100644 Binary files a/contrast/feat_extract/model/__pycache__/__init__.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/fmobilenet.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/fmobilenet.cpython-39.pyc index c8a7647..890c562 100644 Binary files a/contrast/feat_extract/model/__pycache__/fmobilenet.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/fmobilenet.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/lcnet.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/lcnet.cpython-39.pyc index 6a41ee8..fc264b4 100644 Binary files a/contrast/feat_extract/model/__pycache__/lcnet.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/lcnet.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/loss.cpython-39.pyc 
b/contrast/feat_extract/model/__pycache__/loss.cpython-39.pyc index c266cef..b5cb9bd 100644 Binary files a/contrast/feat_extract/model/__pycache__/loss.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/loss.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/metric.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/metric.cpython-39.pyc index d4a705e..ab1598e 100644 Binary files a/contrast/feat_extract/model/__pycache__/metric.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/metric.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/mobilenet_v2.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/mobilenet_v2.cpython-39.pyc index 1680304..9590561 100644 Binary files a/contrast/feat_extract/model/__pycache__/mobilenet_v2.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/mobilenet_v2.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/mobilenet_v3.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/mobilenet_v3.cpython-39.pyc index 30ed788..0cf0132 100644 Binary files a/contrast/feat_extract/model/__pycache__/mobilenet_v3.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/mobilenet_v3.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/mobilevit.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/mobilevit.cpython-39.pyc index 6dc8172..a8f32ba 100644 Binary files a/contrast/feat_extract/model/__pycache__/mobilevit.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/mobilevit.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/resbam.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/resbam.cpython-39.pyc index 55f08d7..053b1ab 100644 Binary files a/contrast/feat_extract/model/__pycache__/resbam.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/resbam.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/resnet_face.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/resnet_face.cpython-39.pyc index 57209a0..a45da2a 100644 Binary files a/contrast/feat_extract/model/__pycache__/resnet_face.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/resnet_face.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/resnet_pre.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/resnet_pre.cpython-39.pyc index e2b7242..9254dcc 100644 Binary files a/contrast/feat_extract/model/__pycache__/resnet_pre.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/resnet_pre.cpython-39.pyc differ diff --git a/contrast/feat_extract/model/__pycache__/utils.cpython-39.pyc b/contrast/feat_extract/model/__pycache__/utils.cpython-39.pyc index d1836e9..5e9e645 100644 Binary files a/contrast/feat_extract/model/__pycache__/utils.cpython-39.pyc and b/contrast/feat_extract/model/__pycache__/utils.cpython-39.pyc differ diff --git a/execute_pipeline.py b/execute_pipeline.py new file mode 100644 index 0000000..6e9fa4f --- /dev/null +++ b/execute_pipeline.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Mar 28 11:35:28 2025 + +@author: ym +""" + +from pipeline_01 import execute_pipeline + + +execute_pipeline(evtdir = r"D:\datasets\ym\后台数据\unzip", + DataType = "raw", # raw, pkl + kk=1, + source_type = "video", # video, image, + save_path = r"D:\work\result_pipeline_V5", + yolo_ver = "V5", # V10, V5 + weight_yolo_v5 = r'./ckpts/best_cls10_0906.pt' , + weight_yolo_v10 = r'./ckpts/best_v10s_width0375_1205.pt', + saveimages = False + ) + 
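Side note: the two execute_pipeline() calls in the new execute_pipeline.py differ only in save_path and yolo_ver. Below is a minimal sketch of an equivalent driver loop, assuming the execute_pipeline() signature defined in pipeline_01.py later in this diff; the dataset, result, and checkpoint paths are the same ones used in the surrounding calls.

# Sketch only: run the same event set through both YOLO versions in one loop.
# Assumes execute_pipeline() from pipeline_01.py as added in this diff.
from pipeline_01 import execute_pipeline

runs = [
    ("V5",  r"D:\work\result_pipeline_V5"),
    ("V10", r"D:\work\result_pipeline_V10"),
]
for yolo_ver, save_path in runs:
    execute_pipeline(
        evtdir=r"D:\datasets\ym\后台数据\unzip",
        DataType="raw",            # raw, pkl
        kk=1,
        source_type="video",       # video, image
        save_path=save_path,
        yolo_ver=yolo_ver,         # V10, V5
        weight_yolo_v5=r"./ckpts/best_cls10_0906.pt",
        weight_yolo_v10=r"./ckpts/best_v10s_width0375_1205.pt",
        saveimages=False,
    )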
+execute_pipeline(evtdir = r"D:\datasets\ym\后台数据\unzip", + DataType = "raw", # raw, pkl + kk=1, + source_type = "video", # video, image, + save_path = r"D:\work\result_pipeline_V10", + yolo_ver = "V10", # V10, V5 + weight_yolo_v5 = r'./ckpts/best_cls10_0906.pt' , + weight_yolo_v10 = r'./ckpts/best_v10s_width0375_1205.pt', + saveimages = False + ) \ No newline at end of file diff --git a/imgs_to_video.py b/imgs_to_video.py new file mode 100644 index 0000000..db016d6 --- /dev/null +++ b/imgs_to_video.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Jan 30 19:15:05 2024 + +@author: ym +""" +import cv2 +import os +import glob +IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes +VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes + + +def for_test(): + save_path = video_path + img_path + + fps, w, h = 10, 1024, 1280 + cap = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + + pathx = path + img_path + imgfiles = [f for f in os.listdir(pathx) if not f.find("_cut") != -1] + + imgfiles.sort(key = lambda x: int(x[:-5])) + imgpaths = [] + for imgfile in imgfiles: + imgpaths.append(os.path.join(pathx, imgfile)) + + center = (1280/2, 1024/2) + rotate_matrix = cv2.getRotationMatrix2D(center=center, angle=-90, scale=1) + k = 0 + for ipath in imgpaths: + img = cv2.imread(ipath) + rotated_image = cv2.warpAffine(src=img, M=rotate_matrix, dsize=(w, h)) + cap.write(rotated_image) + print("Have imgs") + +def test_1(): + + # name = os.path.split(img_path)[-1] + # save_path = video_path + name + '.mp4' + + save_path = video_path + img_path + + + + fps, w, h = 10, 1024, 1280 + cap = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + + pathx = path + img_path + imgfiles = [f for f in os.listdir(pathx) if not f.find("_cut") != -1] + + imgfiles.sort(key = lambda x: int(x[:-5])) + imgpaths = [] + for imgfile in imgfiles: + imgpaths.append(os.path.join(pathx, imgfile)) + + + + + # ipaths = [os.path.join(pathx, f) for f in os.listdir(pathx) if not f.find("_cut") != -1] + # ipaths = [] + # for f in os.listdir(pathx): + # if not f.find('_cut'): + # ipaths.append(os.path.join(pathx, f)) + # ipaths.sort(key = lambda x: int(x.split('_')[-2])) + + + k = 0 + for ipath in imgpaths: + img = cv2.imread(ipath) + cap.write(img) + + + k += 1 + + cap.release() + + print(img_path + f" have imgs: {k}") + +def img2video(imgpath): + if not os.path.isdir(imgpath): + return + + files = [] + files.extend(sorted(glob.glob(os.path.join(imgpath, "*.*")))) + images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS] + + h, w = cv2.imread(images[0]).shape[:2] + fps = 25 + + vidpath = imgpath + '.mp4' + cap = cv2.VideoWriter(vidpath, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + for p in images: + img = cv2.imread(p) + cap.write(img) + cap.release() + + +def main(): + imgpath = r"D:\work\result\202503251112_v10s_result" + + img2video(imgpath) + + + +if __name__ == "__main__": + main() + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/models/__pycache__/experimental.cpython-39.pyc b/models/__pycache__/experimental.cpython-39.pyc index babf431..2423fd1 100644 Binary files a/models/__pycache__/experimental.cpython-39.pyc and b/models/__pycache__/experimental.cpython-39.pyc differ diff --git a/models/experimental.py b/models/experimental.py index 2795871..d0fc839 100644 --- a/models/experimental.py +++ 
b/models/experimental.py @@ -76,7 +76,11 @@ def attempt_load(weights, device=None, inplace=True, fuse=True): model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: - ckpt = torch.load(attempt_download(w), map_location=device, weights_only=False) # load + if torch.__version__ >= '2.6': + ckpt = torch.load(attempt_download(w), map_location=device, weights_only=False) # load + else: + ckpt = torch.load(attempt_download(w), map_location=device) + ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model # Model compatibility updates diff --git a/pipeline.py b/pipeline.py index f13830c..8f63a4f 100644 --- a/pipeline.py +++ b/pipeline.py @@ -11,7 +11,7 @@ import pickle import numpy as np from pathlib import Path from scipy.spatial.distance import cdist -from track_reid import yolo_resnet_tracker +from track_reid import yolo_resnet_tracker, yolov10_resnet_tracker from tracking.dotrack.dotracks_back import doBackTracks from tracking.dotrack.dotracks_front import doFrontTracks @@ -65,22 +65,20 @@ def pipeline( eventpath, savepath, SourceType, - weights + weights, + YoloVersion="V5" ): ''' eventpath: 单个事件的存储路径 ''' + optdict = {} + optdict["weights"] = weights if SourceType == "video": vpaths = get_video_pairs(eventpath) elif SourceType == "image": vpaths = get_image_pairs(eventpath) - - optdict = {} - optdict["weights"] = weights - - event_tracks = [] ## 构造购物事件字典 @@ -101,9 +99,9 @@ def pipeline( savepath_spdict.mkdir(parents=True, exist_ok=True) pf_path = Path(savepath_spdict) / Path(str(evtname)+".pickle") - if pf_path.exists(): - print(f"Pickle file have saved: {evtname}.pickle") - return + # if pf_path.exists(): + # print(f"Pickle file have saved: {evtname}.pickle") + # return '''====================== 构造 ShoppingDict 模块 =======================''' ShoppingDict = {"eventPath": eventpath, @@ -160,12 +158,16 @@ def pipeline( '''================= 3. 
Yolo + Resnet + Tracker =================''' optdict["source"] = vpath optdict["save_dir"] = savepath_pipeline_imgs - optdict["is_save_img"] = False + optdict["is_save_img"] = True optdict["is_save_video"] = True - yrtOut = yolo_resnet_tracker(**optdict) + if YoloVersion == "V5": + yrtOut = yolo_resnet_tracker(**optdict) + elif YoloVersion == "V10": + yrtOut = yolov10_resnet_tracker(**optdict) + yrtOut_save = [] for frdict in yrtOut: fr_dict = {} @@ -285,21 +287,32 @@ def pipeline( trajpath = os.path.join(savepath_pipeline, "trajectory.png") cv2.imwrite(trajpath, img_cat) -def main(): +def execute_pipeline(evtdir = r"D:\datasets\ym\后台数据\unzip", + source_type = "video", # video, image, + save_path = r"D:\work\result_pipeline", + yolo_ver = "V10", # V10, V5 + + weight_yolo_v5 = r'./ckpts/best_cls10_0906.pt' , + weight_yolo_v10 = r'./ckpts/best_v10s_width0375_1205.pt', + k=0 + ): + ''' + 运行函数 pipeline(),遍历事件文件夹,每个文件夹是一个事件 ''' - 函数:pipeline(),遍历事件文件夹,选择类型 image 或 video, - ''' parmDict = {} - evtdir = r"../dataset/backend_20250310" - parmDict["SourceType"] = "video" # video, image - parmDict["savepath"] = r"../dataset/run" - parmDict["weights"] = r'./ckpts/best_cls10_0906.pt' + parmDict["SourceType"] = source_type + parmDict["savepath"] = save_path + parmDict["YoloVersion"] = yolo_ver + if parmDict["YoloVersion"] == "V5": + parmDict["weights"] = weight_yolo_v5 + elif parmDict["YoloVersion"] == "V10": + parmDict["weights"] = weight_yolo_v10 evtdir = Path(evtdir) - k, errEvents = 0, [] + errEvents = [] for item in evtdir.iterdir(): if item.is_dir(): - # item = evtdir/Path("20250303-103058-074_6914973604223_6914973604223") + item = evtdir/Path("20250310-175352-741") parmDict["eventpath"] = item pipeline(**parmDict) # try: @@ -307,19 +320,21 @@ def main(): # except Exception as e: # errEvents.append(str(item)) k+=1 - if k==5: + if k==1: break - errfile = os.path.join(parmDict["savepath"], f'error_events.txt') + errfile = os.path.join(parmDict["savepath"], 'error_events.txt') with open(errfile, 'w', encoding='utf-8') as f: for line in errEvents: f.write(line + '\n') - - if __name__ == "__main__": - main() + execute_pipeline() + # spath_v10 = r"D:\work\result_pipeline_v10" + # spath_v5 = r"D:\work\result_pipeline_v5" + # execute_pipeline(save_path=spath_v10, yolo_ver="V10") + # execute_pipeline(save_path=spath_v5, yolo_ver="V5") diff --git a/pipeline_01.py b/pipeline_01.py new file mode 100644 index 0000000..adeec4e --- /dev/null +++ b/pipeline_01.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +""" +Created on Sun Sep 29 08:59:21 2024 + +@author: ym +""" +import os +# import sys +import cv2 +import pickle +import numpy as np +from pathlib import Path +from scipy.spatial.distance import cdist +from track_reid import yolo_resnet_tracker, yolov10_resnet_tracker + +from tracking.dotrack.dotracks_back import doBackTracks +from tracking.dotrack.dotracks_front import doFrontTracks +from tracking.utils.drawtracks import plot_frameID_y2, draw_all_trajectories +from utils.getsource import get_image_pairs, get_video_pairs +from tracking.utils.read_data import read_similar + +class CameraEvent_: + def __init__(self): + self.cameraType = '', # "front", "back" + self.videoPath = '', + self.imagePaths = [], + self.yoloResnetTracker =[], + self.tracking = None, + +class ShoppingEvent_: + def __init__(self): + self.eventPath = '' + self.eventName = '' + self.barcode = '' + self.eventType = '', # "input", "output", "other" + self.frontCamera = None + self.backCamera = None + self.one2n = [] + + + +def 
save_subimgs(imgdict, boxes, spath, ctype, featdict = None): + ''' + 当前 box 特征和该轨迹前一个 box 特征的相似度,可用于和跟踪序列中的相似度进行比较 + ''' + boxes = boxes[np.argsort(boxes[:, 7])] + for i in range(len(boxes)): + simi = None + tid, fid, bid = int(boxes[i, 4]), int(boxes[i, 7]), int(boxes[i, 8]) + + if i>0: + _, fid0, bid0 = int(boxes[i-1, 4]), int(boxes[i-1, 7]), int(boxes[i-1, 8]) + if f"{fid0}_{bid0}" in featdict.keys() and f"{fid}_{bid}" in featdict.keys(): + feat0 = featdict[f"{fid0}_{bid0}"] + feat1 = featdict[f"{fid}_{bid}"] + simi = 1 - np.maximum(0.0, cdist(feat0[None, :], feat1[None, :], "cosine"))[0][0] + + img = imgdict[f"{fid}_{bid}"] + imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}.png" + if simi is not None: + imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}_sim{simi:.2f}.png" + + cv2.imwrite(imgpath, img) + + +def save_subimgs_1(imgdict, boxes, spath, ctype, simidict = None): + ''' + 当前 box 特征和该轨迹 smooth_feat 特征的相似度, yolo_resnet_tracker 函数中, + 采用该方式记录特征相似度 + ''' + for i in range(len(boxes)): + tid, fid, bid = int(boxes[i, 4]), int(boxes[i, 7]), int(boxes[i, 8]) + + key = f"{fid}_{bid}" + img = imgdict[key] + imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}.png" + if simidict is not None and key in simidict.keys(): + imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}_sim{simidict[key]:.2f}.png" + + cv2.imwrite(imgpath, img) + +def show_result(event_tracks, yrtDict, savepath_pipe): + '''保存 Tracking 输出的运动轨迹子图,并记录相似度''' + + savepath_pipe_subimgs = savepath_pipe / Path("subimgs") + if not savepath_pipe_subimgs.exists(): + savepath_pipe_subimgs.mkdir(parents=True, exist_ok=True) + + + + + for CamerType, vts in event_tracks: + if len(vts.tracks)==0: continue + if CamerType == 'front': + # yolos = ShoppingDict["frontCamera"]["yoloResnetTracker"] + + yolos = yrtDict["frontyrt"] + ctype = 1 + if CamerType == 'back': + # yolos = ShoppingDict["backCamera"]["yoloResnetTracker"] + + yolos = yrtDict["backyrt"] + ctype = 0 + + imgdict, featdict, simidict = {}, {}, {} + for y in yolos: + imgdict.update(y["imgs"]) + featdict.update(y["feats"]) + simidict.update(y["featsimi"]) + + for track in vts.Residual: + if isinstance(track, np.ndarray): + save_subimgs(imgdict, track, savepath_pipe_subimgs, ctype, featdict) + else: + save_subimgs(imgdict, track.slt_boxes, savepath_pipe_subimgs, ctype, featdict) + + '''(3) 轨迹显示与保存''' + illus = [None, None] + for CamerType, vts in event_tracks: + if len(vts.tracks)==0: continue + + if CamerType == 'front': + edgeline = cv2.imread("./tracking/shopcart/cart_tempt/board_ftmp_line.png") + + h, w = edgeline.shape[:2] + # nh, nw = h//2, w//2 + # edgeline = cv2.resize(edgeline, (nw, nh), interpolation=cv2.INTER_AREA) + + img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipe, CamerType, draw5p=True) + illus[0] = img_tracking + + plt = plot_frameID_y2(vts) + plt.savefig(os.path.join(savepath_pipe, "front_y2.png")) + + if CamerType == 'back': + edgeline = cv2.imread("./tracking/shopcart/cart_tempt/edgeline.png") + + h, w = edgeline.shape[:2] + # nh, nw = h//2, w//2 + # edgeline = cv2.resize(edgeline, (nw, nh), interpolation=cv2.INTER_AREA) + + img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipe, CamerType, draw5p=True) + illus[1] = img_tracking + + illus = [im for im in illus if im is not None] + if len(illus): + img_cat = np.concatenate(illus, axis = 1) + if len(illus)==2: + H, W = img_cat.shape[:2] + cv2.line(img_cat, (int(W/2), 0), (int(W/2), int(H)), (128, 128, 255), 3) + + trajpath = os.path.join(savepath_pipe, "trajectory.png") + cv2.imwrite(trajpath, 
img_cat) + + + + +def pipeline(eventpath, + SourceType, + weights, + DataType = "raw", #raw, pkl: images or videos, pkl, pickle file + YoloVersion="V5", + savepath = None, + saveimages = True + ): + + ## 构造购物事件字典 + evtname = Path(eventpath).stem + barcode = evtname.split('_')[-1] if len(evtname.split('_'))>=2 \ + and len(evtname.split('_')[-1])>=8 \ + and evtname.split('_')[-1].isdigit() else '' + + '''事件结果存储文件夹: savepath_pipe, savepath_pkl''' + if not savepath: + savepath = Path(__file__).resolve().parents[0] / "events_result" + savepath_pipe = Path(savepath) / Path("yolos_tracking") / evtname + + + savepath_pkl = Path(savepath) / "shopping_pkl" + if not savepath_pkl.exists(): + savepath_pkl.mkdir(parents=True, exist_ok=True) + pklpath = Path(savepath_pkl) / Path(str(evtname)+".pickle") + + yrtDict = {} + + yrt_out = [] + if DataType == "raw": + ### 不重复执行已经过yolo-resnet-tracker + # if pklpath.exists(): + # print(f"Pickle file have saved: {evtname}.pickle") + # return + + if SourceType == "video": + vpaths = get_video_pairs(eventpath) + elif SourceType == "image": + vpaths = get_image_pairs(eventpath) + + + + for vpath in vpaths: + '''================= 2. 事件结果存储文件夹 =================''' + + + if isinstance(vpath, list): + savepath_pipe_imgs = savepath_pipe / Path("images") + else: + savepath_pipe_imgs = savepath_pipe / Path(str(Path(vpath).stem)) + + if not savepath_pipe_imgs.exists(): + savepath_pipe_imgs.mkdir(parents=True, exist_ok=True) + + optdict = {} + optdict["weights"] = weights + optdict["source"] = vpath + optdict["save_dir"] = savepath_pipe_imgs + optdict["is_save_img"] = saveimages + optdict["is_save_video"] = True + + + if YoloVersion == "V5": + yrtOut = yolo_resnet_tracker(**optdict) + elif YoloVersion == "V10": + yrtOut = yolov10_resnet_tracker(**optdict) + + yrt_out.append((vpath, yrtOut)) + + elif DataType == "pkl": + pass + + else: + return + + + + '''====================== 构造 ShoppingDict 模块 =======================''' + ShoppingDict = {"eventPath": eventpath, + "eventName": evtname, + "barcode": barcode, + "eventType": '', # "input", "output", "other" + "frontCamera": {}, + "backCamera": {}, + "one2n": [] # + } + procpath = Path(eventpath).joinpath('process.data') + if procpath.is_file(): + SimiDict = read_similar(procpath) + ShoppingDict["one2n"] = SimiDict['one2n'] + + event_tracks = [] + for vpath, yrtOut in yrt_out: + '''================= 1. 构造相机事件字典 =================''' + CameraEvent = {"cameraType": '', # "front", "back" + "videoPath": '', + "imagePaths": [], + "yoloResnetTracker": [], + "tracking": [], + } + + if isinstance(vpath, list): + CameraEvent["imagePaths"] = vpath + bname = os.path.basename(vpath[0]) + if not isinstance(vpath, list): + CameraEvent["videoPath"] = vpath + bname = os.path.basename(vpath).split('.')[0] + if bname.split('_')[0] == "0" or bname.find('back')>=0: + CameraEvent["cameraType"] = "back" + if bname.split('_')[0] == "1" or bname.find('front')>=0: + CameraEvent["cameraType"] = "front" + + + '''2种保存方式: (1) save images, (2) no save images''' + ### (1) save images + yrtOut_save = [] + for frdict in yrtOut: + fr_dict = {} + for k, v in frdict.items(): + if k != "imgs": + fr_dict[k]=v + yrtOut_save.append(fr_dict) + CameraEvent["yoloResnetTracker"] = yrtOut_save + + ### (2) no save images + # CameraEvent["yoloResnetTracker"] = yrtOut + + '''================= 4. 
tracking =================''' + '''(1) 生成用于 tracking 模块的 boxes、feats''' + bboxes = np.empty((0, 6), dtype=np.float64) + trackerboxes = np.empty((0, 9), dtype=np.float64) + trackefeats = {} + for frameDict in yrtOut: + tboxes = frameDict["tboxes"] + ffeats = frameDict["feats"] + + boxes = frameDict["bboxes"] + bboxes = np.concatenate((bboxes, np.array(boxes)), axis=0) + trackerboxes = np.concatenate((trackerboxes, np.array(tboxes)), axis=0) + for i in range(len(tboxes)): + fid, bid = int(tboxes[i, 7]), int(tboxes[i, 8]) + trackefeats.update({f"{fid}_{bid}": ffeats[f"{fid}_{bid}"]}) + + + '''(2) tracking, 后摄''' + if CameraEvent["cameraType"] == "back": + vts = doBackTracks(trackerboxes, trackefeats) + vts.classify() + event_tracks.append(("back", vts)) + + CameraEvent["tracking"] = vts + ShoppingDict["backCamera"] = CameraEvent + + yrtDict["backyrt"] = yrtOut + + '''(2) tracking, 前摄''' + if CameraEvent["cameraType"] == "front": + vts = doFrontTracks(trackerboxes, trackefeats) + vts.classify() + event_tracks.append(("front", vts)) + + CameraEvent["tracking"] = vts + ShoppingDict["frontCamera"] = CameraEvent + + yrtDict["frontyrt"] = yrtOut + + '''========================== 保存模块 =================================''' + # 保存 ShoppingDict + with open(str(pklpath), 'wb') as f: + pickle.dump(ShoppingDict, f) + + # 绘制并保存轨迹图 + show_result(event_tracks, yrtDict, savepath_pipe) + + + +def execute_pipeline(evtdir = r"D:\datasets\ym\后台数据\unzip", + DataType = "raw", # raw, pkl + save_path = r"D:\work\result_pipeline", + kk=1, + source_type = "video", # video, image, + yolo_ver = "V10", # V10, V5 + weight_yolo_v5 = r'./ckpts/best_cls10_0906.pt' , + weight_yolo_v10 = r'./ckpts/best_v10s_width0375_1205.pt', + saveimages = True + ): + ''' + 运行函数 pipeline(),遍历事件文件夹,每个文件夹是一个事件 + ''' + parmDict = {} + parmDict["DataType"] = DataType + parmDict["savepath"] = save_path + parmDict["SourceType"] = source_type + + parmDict["YoloVersion"] = yolo_ver + if parmDict["YoloVersion"] == "V5": + parmDict["weights"] = weight_yolo_v5 + elif parmDict["YoloVersion"] == "V10": + parmDict["weights"] = weight_yolo_v10 + + parmDict["saveimages"] = saveimages + + + evtdir = Path(evtdir) + errEvents = [] + k = 0 + for item in evtdir.iterdir(): + if item.is_dir(): + item = evtdir/Path("20250310-175352-741") + parmDict["eventpath"] = item + + pipeline(**parmDict) + # try: + # pipeline(**parmDict) + # except Exception as e: + # errEvents.append(str(item)) + + k+=1 + if kk is not None and k==kk: + break + + errfile = os.path.join(parmDict["savepath"], 'error_events.txt') + with open(errfile, 'w', encoding='utf-8') as f: + for line in errEvents: + f.write(line + '\n') + +if __name__ == "__main__": + execute_pipeline() + + # spath_v10 = r"D:\work\result_pipeline_v10" + # spath_v5 = r"D:\work\result_pipeline_v5" + # execute_pipeline(save_path=spath_v10, yolo_ver="V10") + # execute_pipeline(save_path=spath_v5, yolo_ver="V5") + + + + + \ No newline at end of file diff --git a/track_reid.py b/track_reid.py index 3357650..06d158f 100644 --- a/track_reid.py +++ b/track_reid.py @@ -64,7 +64,10 @@ from hands.hand_inference import hand_pose from contrast.feat_extract.config import config as conf from contrast.feat_extract.inference import FeatsInterface +from ultralytics import YOLOv10 + ReIDEncoder = FeatsInterface(conf) +print(f'load model {conf.testbackbone} in {Path(__file__).stem}') IMG_FORMATS = '.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff', '.webp', '.pfm' # include image suffixes VID_FORMATS = '.asf', '.avi', '.gif', 
'.m4v', '.mkv', '.mov', '.mp4', '.mpeg', '.mpg', '.ts', '.wmv' # include video suffixes @@ -131,12 +134,158 @@ def init_trackers(tracker_yaml = None, bs=1): trackers = [] for _ in range(bs): tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) + if cfg.with_reid: + tracker.encoder = ReIDEncoder + trackers.append(tracker) return trackers +'''=============== used in pipeline.py for Yolov10 ==================''' +def yolov10_resnet_tracker( + weights = ROOT / 'ckpts/best_v10s_width0375_1205.pt', # model path or triton URL + source = '', # file/dir/URL/glob/screen/0(webcam) + save_dir = '', + is_save_img = True, + is_save_video = True, + + tracker_yaml = "./tracking/trackers/cfg/botsort.yaml", + line_thickness=3, # bounding box thickness (pixels) + hide_labels=False, # hide labels + ): + + ## load a custom model + model = YOLOv10(weights) -'''=============== used in pipeline.py ==================''' + custom = {"conf": 0.25, "batch": 1, "save": False, "mode": "predict"} + kwargs = {"save": True, "imgsz": 640, "conf": 0.1} + args = {**model.overrides, **custom, **kwargs} + predictor = model.task_map[model.task]["predictor"](overrides=args, _callbacks=model.callbacks) + + vid_path, vid_writer = None, None + tracker = init_trackers(tracker_yaml)[0] + yoloResnetTracker = [] + for i, result in enumerate(predictor.stream_inference(source)): + datamode = predictor.dataset.mode + + det = result.boxes.data.cpu().numpy() + im0 = result.orig_img + names = result.names + path = result.path + im_array = result.plot() + + + ## to do tracker.update() + det_tracking = Boxes(det, im0.shape) + tracks, outfeats = tracker.update(det_tracking, im0) + + + + if datamode == "video": + frameId = predictor.dataset.frame + elif datamode == "image": + frameId = predictor.dataset.count + annotator = Annotator(im0.copy(), line_width=line_thickness, example=str(names)) + + simdict, simdict1 = {}, {} + for fid, bid, mfeat, cfeat, features in outfeats: + if mfeat is not None and cfeat is not None: + simi = 1 - np.maximum(0.0, cdist(mfeat[None, :], cfeat[None, :], "cosine"))[0][0] + simdict.update({f"{int(frameId)}_{int(bid)}":simi}) + + if cfeat is not None and len(features)>=2: + mfeat = features[-2] + simi = 1 - np.maximum(0.0, cdist(mfeat[None, :], cfeat[None, :], "cosine"))[0][0] + simdict1.update({f"{int(frameId)}_{int(bid)}":simi}) + + + if len(tracks) > 0: + tracks[:, 7] = frameId + # trackerBoxes = np.concatenate([trackerBoxes, tracks], axis=0) + '''================== 1. 存储 dets/subimgs/features Dict =============''' + imgs, features = ReIDEncoder.inference(im0, tracks) + imgdict, featdict = {}, {} + for ii, bid in enumerate(tracks[:, 8]): + featdict.update({f"{int(frameId)}_{int(bid)}": features[ii, :]}) # [f"feat_{int(bid)}"] = features[i, :] + imgdict.update({f"{int(frameId)}_{int(bid)}": imgs[ii]}) + + frameDict = {"path": path, + "fid": int(frameId), + "bboxes": det, + "tboxes": tracks, + "imgs": imgdict, + "feats": featdict, + "featsimi": simdict, # 当前 box 特征和该轨迹 smooth_feat 特征的相似度 + "featsimi1": simdict1 # 当前 box 特征和该轨迹前一个 box 特征的相似度 + } + yoloResnetTracker.append(frameDict) + + # imgs, features = inference_image(im0, tracks) + # TrackerFeats = np.concatenate([TrackerFeats, features], axis=0) + + '''================== 2. 
提取手势位置 ===================''' + for *xyxy, id, conf, cls, fid, bid in reversed(tracks): + name = ('' if id==-1 else f'id:{int(id)} ') + names[int(cls)] + if f"{int(frameId)}_{int(bid)}" in simdict.keys(): + sim = simdict[f"{int(frameId)}_{int(bid)}"] + label = f"{name} {sim:.2f}" + else: + label = None if hide_labels else name + + + # label = None if hide_labels else (name if hide_conf else f'{name} {conf:.1f}') + + if id >=0 and cls==0: + color = colors(int(cls), True) + elif id >=0 and cls!=0: + color = colors(int(id), True) + else: + color = colors(19, True) # 19为调色板的最后一个元素 + annotator.box_label(xyxy, label, color=color) + + '''====== Save results (image and video) ======''' + # save_path = str(save_dir / Path(path).name) # 带有后缀名 + im0 = annotator.result() + if is_save_img: + save_path_img = str(save_dir / Path(path).stem) + if datamode == 'image': + imgpath = save_path_img + ".png" + if datamode == 'video' : + imgpath = save_path_img + f"_{frameId}.png" + cv2.imwrite(Path(imgpath), im0) + + # if dataset.mode == 'video' and is_save_video: + + if is_save_video: + if datamode == 'video': + video_path = str(save_dir / Path(path).stem) + '.mp4' # 带有后缀名 + else: + videoname = str(Path(path).stem).split('_')[0] + '.mp4' + video_path = str(save_dir / videoname) + + if vid_path != video_path: # new video + vid_path = video_path + vid_cap = predictor.dataset.cap + + if isinstance(vid_writer, cv2.VideoWriter): + vid_writer.release() # release previous video writer + if vid_cap: # video + fps = vid_cap.get(cv2.CAP_PROP_FPS) + w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + else: # stream + fps, w, h = 25, im0.shape[1], im0.shape[0] + ## for image rotating in dataloader.LoadImages.__next__() + w, h = im0.shape[1], im0.shape[0] + + video_path = str(Path(video_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos + vid_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + vid_writer.write(im0) + + return yoloResnetTracker + + +'''=============== used in pipeline.py for Yolov5 ==================''' @smart_inference_mode() def yolo_resnet_tracker( weights=ROOT / 'yolov5s.pt', # model path or triton URL @@ -660,8 +809,6 @@ def run( def parse_opt(): modelpath = ROOT / 'ckpts/best_cls10_0906.pt' # 'ckpts/best_15000_0908.pt', 'ckpts/yolov5s.pt', 'ckpts/best_20000_cls30.pt, best_yolov5m_250000' - - '''datapath为视频文件目录或视频文件''' datapath = r"D:/datasets/ym/videos/标记视频/" # ROOT/'data/videos', ROOT/'data/images' images # datapath = r"D:\datasets\ym\highvalue\videos" @@ -714,7 +861,7 @@ def find_video_imgs(root_dir): -def main(): +def main_v5(): ''' run(): 单张图像或单个视频文件的推理,不支持图像序列, ''' @@ -733,10 +880,10 @@ def main(): # p = r"D:\exhibition\images\153112511_0_seek_105.mp4" # p = r"D:\exhibition\images\image" - p = r"D:\全实时\202502\tracker\1_1740891284792.mp4" - optdict["project"] = r"D:\全实时\202502\tracker" - - # optdict["project"] = r"D:\exhibition\result" + p = r"D:\datasets\ym\后台数据\unzip\20250310-175352-741" + optdict["project"] = r"D:\work\result" + + optdict["weights"] = ROOT / 'ckpts/best_cls10_0906.pt' if os.path.isdir(p): files = find_video_imgs(p) k = 0 @@ -745,17 +892,39 @@ def main(): run(**optdict) k += 1 - if k == 1: + if k == 2: break elif os.path.isfile(p): optdict["source"] = p run(**optdict) +def main_v10(): + datapath = r'D:\datasets\ym\后台数据\unzip\20250310-175352-741\0.mp4' + savepath = r'D:\work\result' + savepath = savepath / Path(str(Path(datapath).stem)) + if not savepath.exists(): + 
savepath.mkdir(parents=True, exist_ok=True) + + weightpath = ROOT / 'ckpts/best_v10s_width0375_1205.pt' + + optdict = {} + optdict["weights"] = weightpath + optdict["source"] = datapath + optdict["save_dir"] = savepath + optdict["is_save_img"] = True + optdict["is_save_video"] = True + + yrtOut = yolov10_resnet_tracker(**optdict) + + if __name__ == '__main__': - main() + # main_v5() + + + main_v10() diff --git a/tracking/dotrack/__pycache__/dotracks.cpython-39.pyc b/tracking/dotrack/__pycache__/dotracks.cpython-39.pyc index ce55367..31724ce 100644 Binary files a/tracking/dotrack/__pycache__/dotracks.cpython-39.pyc and b/tracking/dotrack/__pycache__/dotracks.cpython-39.pyc differ diff --git a/tracking/dotrack/__pycache__/dotracks_back.cpython-39.pyc b/tracking/dotrack/__pycache__/dotracks_back.cpython-39.pyc index 7a9a101..9483b06 100644 Binary files a/tracking/dotrack/__pycache__/dotracks_back.cpython-39.pyc and b/tracking/dotrack/__pycache__/dotracks_back.cpython-39.pyc differ diff --git a/tracking/dotrack/__pycache__/dotracks_front.cpython-39.pyc b/tracking/dotrack/__pycache__/dotracks_front.cpython-39.pyc index bacadfe..eaa7a57 100644 Binary files a/tracking/dotrack/__pycache__/dotracks_front.cpython-39.pyc and b/tracking/dotrack/__pycache__/dotracks_front.cpython-39.pyc differ diff --git a/tracking/dotrack/__pycache__/track_back.cpython-39.pyc b/tracking/dotrack/__pycache__/track_back.cpython-39.pyc index 4969ac8..c2923a2 100644 Binary files a/tracking/dotrack/__pycache__/track_back.cpython-39.pyc and b/tracking/dotrack/__pycache__/track_back.cpython-39.pyc differ diff --git a/tracking/dotrack/__pycache__/track_front.cpython-39.pyc b/tracking/dotrack/__pycache__/track_front.cpython-39.pyc index 8baa7a1..523c566 100644 Binary files a/tracking/dotrack/__pycache__/track_front.cpython-39.pyc and b/tracking/dotrack/__pycache__/track_front.cpython-39.pyc differ diff --git a/tracking/trackers/__pycache__/bot_sort.cpython-39.pyc b/tracking/trackers/__pycache__/bot_sort.cpython-39.pyc index b02abbd..7dd0e3c 100644 Binary files a/tracking/trackers/__pycache__/bot_sort.cpython-39.pyc and b/tracking/trackers/__pycache__/bot_sort.cpython-39.pyc differ diff --git a/tracking/trackers/__pycache__/byte_tracker.cpython-39.pyc b/tracking/trackers/__pycache__/byte_tracker.cpython-39.pyc index 18b2f8f..33d3ddd 100644 Binary files a/tracking/trackers/__pycache__/byte_tracker.cpython-39.pyc and b/tracking/trackers/__pycache__/byte_tracker.cpython-39.pyc differ diff --git a/tracking/trackers/bot_sort.py b/tracking/trackers/bot_sort.py index ed0c96d..fb9ddba 100644 --- a/tracking/trackers/bot_sort.py +++ b/tracking/trackers/bot_sort.py @@ -116,11 +116,13 @@ class BOTSORT(BYTETracker): self.proximity_thresh = args.proximity_thresh self.appearance_thresh = args.appearance_thresh - if args.with_reid: - # Haven't supported BoT-SORT(reid) yet - # self.encoder = ReIDInterface(config) + # if args.with_reid: + # # Haven't supported BoT-SORT(reid) yet + # # self.encoder = ReIDInterface(config) - self.encoder = FeatsInterface(conf) + # self.encoder = FeatsInterface(conf) + + # print('load model {} in BOTSORT'.format(conf.testbackbone)) # self.gmc = GMC(method=args.gmc_method) # commented by WQG diff --git a/tracking/tracking_pipeline.py b/tracking/tracking_pipeline.py new file mode 100644 index 0000000..4e2b9d7 --- /dev/null +++ b/tracking/tracking_pipeline.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Mar 27 16:09:07 2025 + +@author: ym +""" +import os +import sys +import cv2 +import pickle +import 
numpy as np +from pathlib import Path +from scipy.spatial.distance import cdist + +from .dotrack.dotracks_back import doBackTracks +from .dotrack.dotracks_front import doFrontTracks +from .utils.drawtracks import plot_frameID_y2, draw_all_trajectories +from .utils.read_data import read_similar + + + + +class CameraEvent_: + def __init__(self): + self.cameraType = '', # "front", "back" + self.videoPath = '', + self.imagePaths = [], + self.yoloResnetTracker =[], + self.tracking = None, + +class ShoppingEvent_: + def __init__(self): + self.eventPath = '' + self.eventName = '' + self.barcode = '' + self.eventType = '', # "input", "output", "other" + self.frontCamera = None + self.backCamera = None + self.one2n = [] + + + + + +def main(): + ''' + 将一个对象读取,修改其中一个属性 + + ''' + + + evt_pkfile = 'path.pickle' + with open(evt_pkfile, 'rb') as f: + ShoppingDict = pickle.load(f) + + savepath = "" + + back_camera = ShoppingDict["backCamera"]["cameraType"] + back_yrt = ShoppingDict["backCamera"]["yoloResnetTracker"] + front_camera = ShoppingDict["frontCamera"]["cameraType"] + front_yrt = ShoppingDict["frontCamera"]["yoloResnetTracker"] + yrts = [(back_camera, back_yrt), (front_camera, front_yrt)] + + + shopping_event = ShoppingEvent_() + shopping_event.eventPath = ShoppingDict["eventPath"] + shopping_event.eventName = ShoppingDict["eventName"] + shopping_event.barcode = ShoppingDict["barcode"] + + yrtDict = {} + event_tracks = [] + for camera_type, yrtOut in yrts: + ''' + inputs: + yrtOut + camera_type + outputs: + CameraEvent + ''' + + camera_event = CameraEvent_() + + + + '''================= 4. tracking =================''' + '''(1) 生成用于 tracking 模块的 boxes、feats''' + bboxes = np.empty((0, 6), dtype=np.float64) + trackerboxes = np.empty((0, 9), dtype=np.float64) + trackefeats = {} + for frameDict in yrtOut: + tboxes = frameDict["tboxes"] + ffeats = frameDict["feats"] + + boxes = frameDict["bboxes"] + bboxes = np.concatenate((bboxes, np.array(boxes)), axis=0) + trackerboxes = np.concatenate((trackerboxes, np.array(tboxes)), axis=0) + for i in range(len(tboxes)): + fid, bid = int(tboxes[i, 7]), int(tboxes[i, 8]) + trackefeats.update({f"{fid}_{bid}": ffeats[f"{fid}_{bid}"]}) + + + '''(2) tracking, 后摄''' + if CameraEvent["cameraType"] == "back": + vts = doBackTracks(trackerboxes, trackefeats) + vts.classify() + event_tracks.append(("back", vts)) + + + + camera_event.camera_type = camera_type + camera_event.yoloResnetTracker = yrtOut + camera_event.tracking = vts + camera_event.videoPath = ShoppingDict["backCamera"]["videoPath"] + camera_event.imagePaths = ShoppingDict["backCamera"]["imagePaths"] + shopping_event.backCamera = camera_event + + yrtDict["backyrt"] = yrtOut + + '''(2) tracking, 前摄''' + if CameraEvent["cameraType"] == "front": + vts = doFrontTracks(trackerboxes, trackefeats) + vts.classify() + event_tracks.append(("front", vts)) + + camera_event.camera_type = camera_type + camera_event.yoloResnetTracker = yrtOut + camera_event.tracking = vts + camera_event.videoPath = ShoppingDict["frontCamera"]["videoPath"] + camera_event.imagePaths = ShoppingDict["frontCamera"]["imagePaths"] + shopping_event.backCamera = camera_event + + yrtDict["frontyrt"] = yrtOut + + + name = Path(evt_pkfile).stem + pf_path = os.path.join(savepath, name+"_new.pickle") + with open(str(pf_path), 'wb') as f: + pickle.dump(shopping_event, f) + + + + illus = [None, None] + for CamerType, vts in event_tracks: + if len(vts.tracks)==0: continue + + if CamerType == 'front': + edgeline = 
cv2.imread("./tracking/shopcart/cart_tempt/board_ftmp_line.png") + + h, w = edgeline.shape[:2] + # nh, nw = h//2, w//2 + # edgeline = cv2.resize(edgeline, (nw, nh), interpolation=cv2.INTER_AREA) + + img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipeline, CamerType, draw5p=True) + illus[0] = img_tracking + + plt = plot_frameID_y2(vts) + plt.savefig(os.path.join(savepath_pipeline, "front_y2.png")) + + if CamerType == 'back': + edgeline = cv2.imread("./tracking/shopcart/cart_tempt/edgeline.png") + + h, w = edgeline.shape[:2] + # nh, nw = h//2, w//2 + # edgeline = cv2.resize(edgeline, (nw, nh), interpolation=cv2.INTER_AREA) + + img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipeline, CamerType, draw5p=True) + illus[1] = img_tracking + + + + + + + + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tracking/utils/__pycache__/annotator.cpython-39.pyc b/tracking/utils/__pycache__/annotator.cpython-39.pyc index b5eb1c9..182245b 100644 Binary files a/tracking/utils/__pycache__/annotator.cpython-39.pyc and b/tracking/utils/__pycache__/annotator.cpython-39.pyc differ diff --git a/tracking/utils/__pycache__/drawtracks.cpython-39.pyc b/tracking/utils/__pycache__/drawtracks.cpython-39.pyc index bccf31b..90325be 100644 Binary files a/tracking/utils/__pycache__/drawtracks.cpython-39.pyc and b/tracking/utils/__pycache__/drawtracks.cpython-39.pyc differ diff --git a/tracking/utils/__pycache__/plotting.cpython-39.pyc b/tracking/utils/__pycache__/plotting.cpython-39.pyc index e0fd814..4de0e4e 100644 Binary files a/tracking/utils/__pycache__/plotting.cpython-39.pyc and b/tracking/utils/__pycache__/plotting.cpython-39.pyc differ diff --git a/tracking/utils/__pycache__/read_data.cpython-39.pyc b/tracking/utils/__pycache__/read_data.cpython-39.pyc index 0d7146f..9748378 100644 Binary files a/tracking/utils/__pycache__/read_data.cpython-39.pyc and b/tracking/utils/__pycache__/read_data.cpython-39.pyc differ diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index a0ae59f..8ff1b4f 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,12 +1,27 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.173' +__version__ = "8.1.34" -from ultralytics.models import RTDETR, SAM, YOLO +from ultralytics.data.explorer.explorer import Explorer +from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10 from ultralytics.models.fastsam import FastSAM from ultralytics.models.nas import NAS -from ultralytics.utils import SETTINGS as settings +from ultralytics.utils import ASSETS, SETTINGS as settings from ultralytics.utils.checks import check_yolo as checks from ultralytics.utils.downloads import download -__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings' +__all__ = ( + "__version__", + "ASSETS", + "YOLO", + "YOLOWorld", + "NAS", + "SAM", + "FastSAM", + "RTDETR", + "checks", + "download", + "settings", + "Explorer", + "YOLOv10" +) diff --git a/ultralytics/__pycache__/__init__.cpython-312.pyc b/ultralytics/__pycache__/__init__.cpython-312.pyc index cc440bf..aebdf61 100644 Binary files a/ultralytics/__pycache__/__init__.cpython-312.pyc and b/ultralytics/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/__pycache__/__init__.cpython-39.pyc b/ultralytics/__pycache__/__init__.cpython-39.pyc index 2dffc61..a34898b 100644 Binary files a/ultralytics/__pycache__/__init__.cpython-39.pyc and b/ultralytics/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 7bc48f2..c34cc17 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -1,34 +1,62 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import contextlib -import re +import os import shutil +import subprocess import sys from pathlib import Path from types import SimpleNamespace from typing import Dict, List, Union +import re -from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, - SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, - yaml_load, yaml_print) +from ultralytics.utils import ( + ASSETS, + DEFAULT_CFG, + DEFAULT_CFG_DICT, + DEFAULT_CFG_PATH, + LOGGER, + RANK, + ROOT, + RUNS_DIR, + SETTINGS, + SETTINGS_YAML, + TESTS_RUNNING, + IterableSimpleNamespace, + __version__, + checks, + colorstr, + deprecation_warn, + yaml_load, + yaml_print, +) # Define valid tasks and modes -MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' -TASKS = 'detect', 'segment', 'classify', 'pose' -TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet10', 'pose': 'coco8-pose.yaml'} +MODES = {"train", "val", "predict", "export", "track", "benchmark"} +TASKS = {"detect", "segment", "classify", "pose", "obb"} +TASK2DATA = { + "detect": "coco8.yaml", + "segment": "coco8-seg.yaml", + "classify": "imagenet10", + "pose": "coco8-pose.yaml", + "obb": "dota8.yaml", +} TASK2MODEL = { - 'detect': 'yolov8n.pt', - 'segment': 'yolov8n-seg.pt', - 'classify': 'yolov8n-cls.pt', - 'pose': 'yolov8n-pose.pt'} + "detect": "yolov8n.pt", + "segment": "yolov8n-seg.pt", + "classify": "yolov8n-cls.pt", + "pose": "yolov8n-pose.pt", + "obb": "yolov8n-obb.pt", +} TASK2METRIC = { - 'detect': 'metrics/mAP50-95(B)', - 'segment': 'metrics/mAP50-95(M)', - 'classify': 'metrics/accuracy_top1', - 'pose': 'metrics/mAP50-95(P)'} + "detect": "metrics/mAP50-95(B)", + "segment": "metrics/mAP50-95(M)", + "classify": "metrics/accuracy_top1", + "pose": "metrics/mAP50-95(P)", + "obb": "metrics/mAP50-95(B)", +} -CLI_HELP_MSG = \ - f""" +CLI_HELP_MSG = f""" Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax: yolo TASK MODE ARGS @@ -42,7 +70,7 @@ CLI_HELP_MSG = \ yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01 2. Predict a YouTube video using a pretrained segmentation model at image size 320: - yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320 + yolo predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 3. Val a pretrained detection model at batch-size 1 and image size 640: yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640 @@ -50,6 +78,9 @@ CLI_HELP_MSG = \ 4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required) yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128 + 6. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API + yolo explorer + 5. 
Run special commands: yolo help yolo checks @@ -64,16 +95,84 @@ CLI_HELP_MSG = \ """ # Define keys for arg type checks -CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear' -CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', - 'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', - 'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0 -CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride', - 'line_width', 'workspace', 'nbs', 'save_period') -CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val', - 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop', - 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras', - 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile') +CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"} +CFG_FRACTION_KEYS = { + "dropout", + "iou", + "lr0", + "lrf", + "momentum", + "weight_decay", + "warmup_momentum", + "warmup_bias_lr", + "label_smoothing", + "hsv_h", + "hsv_s", + "hsv_v", + "translate", + "scale", + "perspective", + "flipud", + "fliplr", + "bgr", + "mosaic", + "mixup", + "copy_paste", + "conf", + "iou", + "fraction", +} # fraction floats 0.0 - 1.0 +CFG_INT_KEYS = { + "epochs", + "patience", + "batch", + "workers", + "seed", + "close_mosaic", + "mask_ratio", + "max_det", + "vid_stride", + "line_width", + "workspace", + "nbs", + "save_period", +} +CFG_BOOL_KEYS = { + "save", + "exist_ok", + "verbose", + "deterministic", + "single_cls", + "rect", + "cos_lr", + "overlap_mask", + "val", + "save_json", + "save_hybrid", + "half", + "dnn", + "plots", + "show", + "save_txt", + "save_conf", + "save_crop", + "save_frames", + "show_labels", + "show_conf", + "visualize", + "augment", + "agnostic_nms", + "retina_masks", + "show_boxes", + "keras", + "optimize", + "int8", + "dynamic", + "simplify", + "nms", + "profile", + "multi_scale", +} def cfg2dict(cfg): @@ -109,53 +208,72 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove # Merge overrides if overrides: overrides = cfg2dict(overrides) - if 'save_dir' not in cfg: - overrides.pop('save_dir', None) # special override keys to ignore + if "save_dir" not in cfg: + overrides.pop("save_dir", None) # special override keys to ignore check_dict_alignment(cfg, overrides) cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides) # Special handling for numeric project/name - for k in 'project', 'name': + for k in "project", "name": if k in cfg and isinstance(cfg[k], (int, float)): cfg[k] = str(cfg[k]) - if cfg.get('name') == 'model': # assign model to 'name' arg - cfg['name'] = cfg.get('model', '').split('.')[0] + if cfg.get("name") == "model": # assign model to 'name' arg + cfg["name"] = cfg.get("model", "").split(".")[0] LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.") # Type and Value checks - for k, v in cfg.items(): - if v is not None: # None values may be from optional args - if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)): - raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " - f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. 
'{k}=0.5')") - elif k in CFG_FRACTION_KEYS: - if not isinstance(v, (int, float)): - raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " - f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')") - if not (0.0 <= v <= 1.0): - raise ValueError(f"'{k}={v}' is an invalid value. " - f"Valid '{k}' values are between 0.0 and 1.0.") - elif k in CFG_INT_KEYS and not isinstance(v, int): - raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " - f"'{k}' must be an int (i.e. '{k}=8')") - elif k in CFG_BOOL_KEYS and not isinstance(v, bool): - raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " - f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')") + check_cfg(cfg) # Return instance return IterableSimpleNamespace(**cfg) +def check_cfg(cfg, hard=True): + """Check Ultralytics configuration argument types and values.""" + for k, v in cfg.items(): + if v is not None: # None values may be from optional args + if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) + cfg[k] = float(v) + elif k in CFG_FRACTION_KEYS: + if not isinstance(v, (int, float)): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) + cfg[k] = v = float(v) + if not (0.0 <= v <= 1.0): + raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.") + elif k in CFG_INT_KEYS and not isinstance(v, int): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')" + ) + cfg[k] = int(v) + elif k in CFG_BOOL_KEYS and not isinstance(v, bool): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"'{k}' must be a bool (i.e. 
'{k}=True' or '{k}=False')" + ) + cfg[k] = bool(v) + + def get_save_dir(args, name=None): """Return save_dir as created from train/val/predict arguments.""" - if getattr(args, 'save_dir', None): + if getattr(args, "save_dir", None): save_dir = args.save_dir else: from ultralytics.utils.files import increment_path - project = args.project or Path(SETTINGS['runs_dir']) / args.task - name = name or args.name or f'{args.mode}' + project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task + name = name or args.name or f"{args.mode}" save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True) return Path(save_dir) @@ -165,23 +283,26 @@ def _handle_deprecation(custom): """Hardcoded function to handle deprecated config keys.""" for key in custom.copy().keys(): - if key == 'hide_labels': - deprecation_warn(key, 'show_labels') - custom['show_labels'] = custom.pop('hide_labels') == 'False' - if key == 'hide_conf': - deprecation_warn(key, 'show_conf') - custom['show_conf'] = custom.pop('hide_conf') == 'False' - if key == 'line_thickness': - deprecation_warn(key, 'line_width') - custom['line_width'] = custom.pop('line_thickness') + if key == "boxes": + deprecation_warn(key, "show_boxes") + custom["show_boxes"] = custom.pop("boxes") + if key == "hide_labels": + deprecation_warn(key, "show_labels") + custom["show_labels"] = custom.pop("hide_labels") == "False" + if key == "hide_conf": + deprecation_warn(key, "show_conf") + custom["show_conf"] = custom.pop("hide_conf") == "False" + if key == "line_thickness": + deprecation_warn(key, "line_width") + custom["line_width"] = custom.pop("line_thickness") return custom def check_dict_alignment(base: Dict, custom: Dict, e=None): """ - This function checks for any mismatched keys between a custom configuration list and a base configuration list. - If any mismatched keys are found, the function prints out similar keys from the base list and exits the program. + This function checks for any mismatched keys between a custom configuration list and a base configuration list. If + any mismatched keys are found, the function prints out similar keys from the base list and exits the program. Args: custom (dict): a dictionary of custom configuration options @@ -194,36 +315,35 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None): if mismatched: from difflib import get_close_matches - string = '' + string = "" for x in mismatched: matches = get_close_matches(x, base_keys) # key list - matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches] - match_str = f'Similar arguments are i.e. {matches}.' if matches else '' + matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches] + match_str = f"Similar arguments are i.e. {matches}." if matches else "" string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n" raise SyntaxError(string + CLI_HELP_MSG) from e def merge_equals_args(args: List[str]) -> List[str]: """ - Merges arguments around isolated '=' args in a list of strings. - The function considers cases where the first argument ends with '=' or the second starts with '=', - as well as when the middle one is an equals sign. + Merges arguments around isolated '=' args in a list of strings. The function considers cases where the first + argument ends with '=' or the second starts with '=', as well as when the middle one is an equals sign. Args: args (List[str]): A list of strings where each element is an argument. 
Returns: - List[str]: A list of strings where the arguments around isolated '=' are merged. + (List[str]): A list of strings where the arguments around isolated '=' are merged. """ new_args = [] for i, arg in enumerate(args): - if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val'] - new_args[-1] += f'={args[i + 1]}' + if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val'] + new_args[-1] += f"={args[i + 1]}" del args[i + 1] - elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val'] - new_args.append(f'{arg}{args[i + 1]}') + elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val'] + new_args.append(f"{arg}{args[i + 1]}") del args[i + 1] - elif arg.startswith('=') and i > 0: # merge ['arg', '=val'] + elif arg.startswith("=") and i > 0: # merge ['arg', '=val'] new_args[-1] += arg else: new_args.append(arg) @@ -247,11 +367,11 @@ def handle_yolo_hub(args: List[str]) -> None: """ from ultralytics import hub - if args[0] == 'login': - key = args[1] if len(args) > 1 else '' + if args[0] == "login": + key = args[1] if len(args) > 1 else "" # Log in to Ultralytics HUB using the provided API key hub.login(key) - elif args[0] == 'logout': + elif args[0] == "logout": # Log out from Ultralytics HUB hub.logout() @@ -271,39 +391,47 @@ def handle_yolo_settings(args: List[str]) -> None: python my_script.py yolo settings reset ``` """ - url = 'https://docs.ultralytics.com/quickstart/#ultralytics-settings' # help URL + url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL try: if any(args): - if args[0] == 'reset': + if args[0] == "reset": SETTINGS_YAML.unlink() # delete the settings file SETTINGS.reset() # create new settings - LOGGER.info('Settings reset successfully') # inform the user that settings have been reset + LOGGER.info("Settings reset successfully") # inform the user that settings have been reset else: # save a new setting new = dict(parse_key_value_pair(a) for a in args) check_dict_alignment(SETTINGS, new) SETTINGS.update(new) - LOGGER.info(f'💡 Learn about settings at {url}') + LOGGER.info(f"💡 Learn about settings at {url}") yaml_print(SETTINGS_YAML) # print the current settings except Exception as e: LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. 
Please see {url} for help.") +def handle_explorer(): + """Open the Ultralytics Explorer GUI.""" + checks.check_requirements("streamlit") + LOGGER.info("💡 Loading Explorer dashboard...") + subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"]) + + def parse_key_value_pair(pair): """Parse one 'key=value' pair and return key and value.""" - re.sub(r' *= *', '=', pair) # remove spaces around equals sign - k, v = pair.split('=', 1) # split on first '=' sign + k, v = pair.split("=", 1) # split on first '=' sign + k, v = k.strip(), v.strip() # remove spaces assert v, f"missing '{k}' value" return k, smart_value(v) def smart_value(v): """Convert a string to an underlying type such as int, float, bool, etc.""" - if v.lower() == 'none': + v_lower = v.lower() + if v_lower == "none": return None - elif v.lower() == 'true': + elif v_lower == "true": return True - elif v.lower() == 'false': + elif v_lower == "false": return False else: with contextlib.suppress(Exception): @@ -311,7 +439,7 @@ def smart_value(v): return v -def entrypoint(debug=''): +def entrypoint(debug=""): """ This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed to the package. @@ -326,135 +454,160 @@ def entrypoint(debug=''): It uses the package's default cfg and initializes it using the passed overrides. Then it calls the CLI function with the composed cfg """ - args = (debug.split(' ') if debug else sys.argv)[1:] + args = (debug.split(" ") if debug else sys.argv)[1:] if not args: # no arguments passed LOGGER.info(CLI_HELP_MSG) return special = { - 'help': lambda: LOGGER.info(CLI_HELP_MSG), - 'checks': checks.check_yolo, - 'version': lambda: LOGGER.info(__version__), - 'settings': lambda: handle_yolo_settings(args[1:]), - 'cfg': lambda: yaml_print(DEFAULT_CFG_PATH), - 'hub': lambda: handle_yolo_hub(args[1:]), - 'login': lambda: handle_yolo_hub(args), - 'copy-cfg': copy_default_cfg} + "help": lambda: LOGGER.info(CLI_HELP_MSG), + "checks": checks.collect_system_info, + "version": lambda: LOGGER.info(__version__), + "settings": lambda: handle_yolo_settings(args[1:]), + "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), + "hub": lambda: handle_yolo_hub(args[1:]), + "login": lambda: handle_yolo_hub(args), + "copy-cfg": copy_default_cfg, + "explorer": lambda: handle_explorer(), + } full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} - # Define common mis-uses of special commands, i.e. -h, -help, --help + # Define common misuses of special commands, i.e. -h, -help, --help special.update({k[0]: v for k, v in special.items()}) # singular - special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular - special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}} + special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular + special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} overrides = {} # basic overrides, i.e. 
imgsz=320 for a in merge_equals_args(args): # merge spaces around '=' sign - if a.startswith('--'): - LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") + if a.startswith("--"): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") a = a[2:] - if a.endswith(','): - LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") + if a.endswith(","): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") a = a[:-1] - if '=' in a: + if "=" in a: try: k, v = parse_key_value_pair(a) - if k == 'cfg': # custom.yaml passed - LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}') - overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'} + if k == "cfg" and v is not None: # custom.yaml passed + LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} else: overrides[k] = v except (NameError, SyntaxError, ValueError, AssertionError) as e: - check_dict_alignment(full_args_dict, {a: ''}, e) + check_dict_alignment(full_args_dict, {a: ""}, e) elif a in TASKS: - overrides['task'] = a + overrides["task"] = a elif a in MODES: - overrides['mode'] = a + overrides["mode"] = a elif a.lower() in special: special[a.lower()]() return elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True elif a in DEFAULT_CFG_DICT: - raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " - f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}") + raise SyntaxError( + f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " + f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" + ) else: - check_dict_alignment(full_args_dict, {a: ''}) + check_dict_alignment(full_args_dict, {a: ""}) # Check keys check_dict_alignment(full_args_dict, overrides) # Mode - mode = overrides.get('mode') + mode = overrides.get("mode") if mode is None: - mode = DEFAULT_CFG.mode or 'predict' - LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") + mode = DEFAULT_CFG.mode or "predict" + LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") elif mode not in MODES: raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") # Task - task = overrides.pop('task', None) + task = overrides.pop("task", None) if task: if task not in TASKS: raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") - if 'model' not in overrides: - overrides['model'] = TASK2MODEL[task] + if "model" not in overrides: + overrides["model"] = TASK2MODEL[task] # Model - model = overrides.pop('model', DEFAULT_CFG.model) + model = overrides.pop("model", DEFAULT_CFG.model) if model is None: - model = 'yolov8n.pt' - LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.") - overrides['model'] = model - if 'rtdetr' in model.lower(): # guess architecture + model = "yolov8n.pt" + LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. 
Using default 'model={model}'.") + overrides["model"] = model + # stem = Path(model).stem.lower() + stem = model.lower() + if "rtdetr" in stem: # guess architecture from ultralytics import RTDETR + model = RTDETR(model) # no task argument - elif 'fastsam' in model.lower(): + elif "fastsam" in stem: from ultralytics import FastSAM + model = FastSAM(model) - elif 'sam' in model.lower(): + elif "sam" in stem: from ultralytics import SAM + model = SAM(model) - else: + elif re.search("v3|v5|v6|v8|v9", stem): from ultralytics import YOLO + model = YOLO(model, task=task) - if isinstance(overrides.get('pretrained'), str): - model.load(overrides['pretrained']) + else: + from ultralytics import YOLOv10 + + # Special case for the HuggingFace Hub + split_path = model.split('/') + if len(split_path) == 2 and (not os.path.exists(model)): + model = YOLOv10.from_pretrained(model) + else: + model = YOLOv10(model) + if isinstance(overrides.get("pretrained"), str): + model.load(overrides["pretrained"]) # Task Update if task != model.task: if task: - LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " - f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.") + LOGGER.warning( + f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." + ) task = model.task # Mode - if mode in ('predict', 'track') and 'source' not in overrides: - overrides['source'] = DEFAULT_CFG.source or ASSETS - LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.") - elif mode in ('train', 'val'): - if 'data' not in overrides and 'resume' not in overrides: - overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) - LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.") - elif mode == 'export': - if 'format' not in overrides: - overrides['format'] = DEFAULT_CFG.format or 'torchscript' - LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.") + if mode in ("predict", "track") and "source" not in overrides: + overrides["source"] = DEFAULT_CFG.source or ASSETS + LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") + elif mode in ("train", "val"): + if "data" not in overrides and "resume" not in overrides: + overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") + elif mode == "export": + if "format" not in overrides: + overrides["format"] = DEFAULT_CFG.format or "torchscript" + LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. 
Using default 'format={overrides['format']}'.") # Run command in python - # getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml getattr(model, mode)(**overrides) # default args from model + # Show help + LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}") + # Special modes -------------------------------------------------------------------------------------------------------- def copy_default_cfg(): """Copy and create a new default configuration file with '_copy' appended to its name.""" - new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml') + new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml") shutil.copy2(DEFAULT_CFG_PATH, new_file) - LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n' - f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8") + LOGGER.info( + f"{DEFAULT_CFG_PATH} copied to {new_file}\n" + f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8" + ) -if __name__ == '__main__': +if __name__ == "__main__": # Example: entrypoint(debug='yolo predict model=yolov8n.pt') - entrypoint(debug='') + entrypoint(debug="") diff --git a/ultralytics/cfg/__pycache__/__init__.cpython-312.pyc b/ultralytics/cfg/__pycache__/__init__.cpython-312.pyc index f668d66..bbf0100 100644 Binary files a/ultralytics/cfg/__pycache__/__init__.cpython-312.pyc and b/ultralytics/cfg/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/cfg/__pycache__/__init__.cpython-39.pyc b/ultralytics/cfg/__pycache__/__init__.cpython-39.pyc index 30ff710..14e2e94 100644 Binary files a/ultralytics/cfg/__pycache__/__init__.cpython-39.pyc and b/ultralytics/cfg/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/cfg/datasets/Argoverse.yaml b/ultralytics/cfg/datasets/Argoverse.yaml index 76255e4..43755f7 100644 --- a/ultralytics/cfg/datasets/Argoverse.yaml +++ b/ultralytics/cfg/datasets/Argoverse.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -# Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI +# Argoverse-HD dataset (ring-front-center camera) https://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI +# Documentation: https://docs.ultralytics.com/datasets/detect/argoverse/ # Example usage: yolo train data=Argoverse.yaml # parent # ├── ultralytics # └── datasets # └── Argoverse ← downloads here (31.5 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
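Every dataset YAML touched in this patch follows the pattern described in the comment above: a `path` root directory plus `train`/`val`/`test` entries given as 1) a directory, 2) a `.txt` file of image paths, or 3) a list of either. A minimal sketch of how such an entry resolves against `path`, assuming PyYAML is installed and a local copy of the file (the filename here is only an example):

```python
from pathlib import Path

import yaml  # assumption: PyYAML is available

cfg = yaml.safe_load(Path("Argoverse.yaml").read_text())
root = Path(cfg["path"])                            # dataset root dir, e.g. ../datasets/Argoverse
train = cfg["train"]                                # a dir, a .txt file, or a list of either
sources = train if isinstance(train, list) else [train]
train_sources = [root / s for s in sources]         # each source is relative to 'path'
print(train_sources)
```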
-path: ../datasets/Argoverse # dataset root dir -train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images -val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images -test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview +path: ../datasets/Argoverse # dataset root dir +train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images +val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images +test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview # Classes names: @@ -24,7 +24,6 @@ names: 6: traffic_light 7: stop_sign - # Download script/URL (optional) --------------------------------------------------------------------------------------- download: | import json @@ -64,7 +63,9 @@ download: | # Download 'https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip' (deprecated S3 link) dir = Path(yaml['path']) # dataset root dir urls = ['https://drive.google.com/file/d/1st9qW3BeIwQsnR0t8mRpvbsSWIo16ACi/view?usp=drive_link'] - download(urls, dir=dir) + print("\n\nWARNING: Argoverse dataset MUST be downloaded manually, autodownload will NOT work.") + print(f"WARNING: Manually download Argoverse dataset '{urls[0]}' to '{dir}' and re-run your command.\n\n") + # download(urls, dir=dir) # Convert annotations_dir = 'Argoverse-HD/annotations/' diff --git a/ultralytics/cfg/datasets/DOTAv2.yaml b/ultralytics/cfg/datasets/DOTAv1.5.yaml similarity index 56% rename from ultralytics/cfg/datasets/DOTAv2.yaml rename to ultralytics/cfg/datasets/DOTAv1.5.yaml index c663bdd..701535f 100644 --- a/ultralytics/cfg/datasets/DOTAv2.yaml +++ b/ultralytics/cfg/datasets/DOTAv1.5.yaml @@ -1,18 +1,19 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -# DOTA 2.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University -# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv2.yaml +# DOTA 1.5 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University +# Documentation: https://docs.ultralytics.com/datasets/obb/dota-v2/ +# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.5.yaml # parent # ├── ultralytics # └── datasets -# └── dota2 ← downloads here (2GB) +# └── dota1.5 ← downloads here (2GB) # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
-path: ../datasets/DOTAv2 # dataset root dir -train: images/train # train images (relative to 'path') 1411 images -val: images/val # val images (relative to 'path') 458 images -test: images/test # test images (optional) 937 images +path: ../datasets/DOTAv1.5 # dataset root dir +train: images/train # train images (relative to 'path') 1411 images +val: images/val # val images (relative to 'path') 458 images +test: images/test # test images (optional) 937 images -# Classes for DOTA 2.0 +# Classes for DOTA 1.5 names: 0: plane 1: ship @@ -30,8 +31,6 @@ names: 13: soccer ball field 14: swimming pool 15: container crane - 16: airport - 17: helipad # Download script/URL (optional) -download: https://github.com/ultralytics/yolov5/releases/download/v1.0/DOTAv2.zip +download: https://github.com/ultralytics/yolov5/releases/download/v1.0/DOTAv1.5.zip diff --git a/ultralytics/cfg/datasets/DOTAv1.yaml b/ultralytics/cfg/datasets/DOTAv1.yaml new file mode 100644 index 0000000..f6364d3 --- /dev/null +++ b/ultralytics/cfg/datasets/DOTAv1.yaml @@ -0,0 +1,35 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# DOTA 1.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University +# Documentation: https://docs.ultralytics.com/datasets/obb/dota-v2/ +# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota1 ← downloads here (2GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/DOTAv1 # dataset root dir +train: images/train # train images (relative to 'path') 1411 images +val: images/val # val images (relative to 'path') 458 images +test: images/test # test images (optional) 937 images + +# Classes for DOTA 1.0 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + +# Download script/URL (optional) +download: https://github.com/ultralytics/yolov5/releases/download/v1.0/DOTAv1.zip diff --git a/ultralytics/cfg/datasets/GlobalWheat2020.yaml b/ultralytics/cfg/datasets/GlobalWheat2020.yaml index 165004f..ae6bfa0 100644 --- a/ultralytics/cfg/datasets/GlobalWheat2020.yaml +++ b/ultralytics/cfg/datasets/GlobalWheat2020.yaml @@ -1,14 +1,14 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -# Global Wheat 2020 dataset http://www.global-wheat.com/ by University of Saskatchewan +# Global Wheat 2020 dataset https://www.global-wheat.com/ by University of Saskatchewan +# Documentation: https://docs.ultralytics.com/datasets/detect/globalwheat2020/ # Example usage: yolo train data=GlobalWheat2020.yaml # parent # ├── ultralytics # └── datasets # └── GlobalWheat2020 ← downloads here (7.0 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
-path: ../datasets/GlobalWheat2020 # dataset root dir +path: ../datasets/GlobalWheat2020 # dataset root dir train: # train images (relative to 'path') 3422 images - images/arvalis_1 - images/arvalis_2 @@ -29,7 +29,6 @@ test: # test images (optional) 1276 images names: 0: wheat_head - # Download script/URL (optional) --------------------------------------------------------------------------------------- download: | from ultralytics.utils.downloads import download diff --git a/ultralytics/cfg/datasets/ImageNet.yaml b/ultralytics/cfg/datasets/ImageNet.yaml index c1aa155..0dc344a 100644 --- a/ultralytics/cfg/datasets/ImageNet.yaml +++ b/ultralytics/cfg/datasets/ImageNet.yaml @@ -1,18 +1,18 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # ImageNet-1k dataset https://www.image-net.org/index.php by Stanford University # Simplified class names from https://github.com/anishathalye/imagenet-simple-labels +# Documentation: https://docs.ultralytics.com/datasets/classify/imagenet/ # Example usage: yolo train task=classify data=imagenet # parent # ├── ultralytics # └── datasets # └── imagenet ← downloads here (144 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/imagenet # dataset root dir -train: train # train images (relative to 'path') 1281167 images -val: val # val images (relative to 'path') 50000 images -test: # test images (optional) +path: ../datasets/imagenet # dataset root dir +train: train # train images (relative to 'path') 1281167 images +val: val # val images (relative to 'path') 50000 images +test: # test images (optional) # Classes names: @@ -2020,6 +2020,5 @@ map: n13133613: ear n15075141: toilet_tissue - # Download script/URL (optional) download: yolo/data/scripts/get_imagenet.sh diff --git a/ultralytics/cfg/datasets/Objects365.yaml b/ultralytics/cfg/datasets/Objects365.yaml index 415eff9..9b11720 100644 --- a/ultralytics/cfg/datasets/Objects365.yaml +++ b/ultralytics/cfg/datasets/Objects365.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # Objects365 dataset https://www.objects365.org/ by Megvii +# Documentation: https://docs.ultralytics.com/datasets/detect/objects365/ # Example usage: yolo train data=Objects365.yaml # parent # ├── ultralytics # └── datasets # └── Objects365 ← downloads here (712 GB = 367G data + 345G zips) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
-path: ../datasets/Objects365 # dataset root dir -train: images/train # train images (relative to 'path') 1742289 images +path: ../datasets/Objects365 # dataset root dir +train: images/train # train images (relative to 'path') 1742289 images val: images/val # val images (relative to 'path') 80000 images -test: # test images (optional) +test: # test images (optional) # Classes names: @@ -381,7 +381,6 @@ names: 363: Curling 364: Table Tennis - # Download script/URL (optional) --------------------------------------------------------------------------------------- download: | from tqdm import tqdm diff --git a/ultralytics/cfg/datasets/SKU-110K.yaml b/ultralytics/cfg/datasets/SKU-110K.yaml index e6deac2..fff1baa 100644 --- a/ultralytics/cfg/datasets/SKU-110K.yaml +++ b/ultralytics/cfg/datasets/SKU-110K.yaml @@ -1,23 +1,22 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail +# Documentation: https://docs.ultralytics.com/datasets/detect/sku-110k/ # Example usage: yolo train data=SKU-110K.yaml # parent # ├── ultralytics # └── datasets # └── SKU-110K ← downloads here (13.6 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/SKU-110K # dataset root dir -train: train.txt # train images (relative to 'path') 8219 images -val: val.txt # val images (relative to 'path') 588 images -test: test.txt # test images (optional) 2936 images +path: ../datasets/SKU-110K # dataset root dir +train: train.txt # train images (relative to 'path') 8219 images +val: val.txt # val images (relative to 'path') 588 images +test: test.txt # test images (optional) 2936 images # Classes names: 0: object - # Download script/URL (optional) --------------------------------------------------------------------------------------- download: | import shutil diff --git a/ultralytics/cfg/datasets/VOC.yaml b/ultralytics/cfg/datasets/VOC.yaml index 6bdcc4f..cd6d5ad 100644 --- a/ultralytics/cfg/datasets/VOC.yaml +++ b/ultralytics/cfg/datasets/VOC.yaml @@ -1,12 +1,12 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford +# Documentation: # Documentation: https://docs.ultralytics.com/datasets/detect/voc/ # Example usage: yolo train data=VOC.yaml # parent # ├── ultralytics # └── datasets # └── VOC ← downloads here (2.8 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
path: ../datasets/VOC train: # train images (relative to 'path') 16551 images @@ -42,7 +42,6 @@ names: 18: train 19: tvmonitor - # Download script/URL (optional) --------------------------------------------------------------------------------------- download: | import xml.etree.ElementTree as ET @@ -81,7 +80,7 @@ download: | urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images - download(urls, dir=dir / 'images', curl=True, threads=3) + download(urls, dir=dir / 'images', curl=True, threads=3, exist_ok=True) # download and unzip over existing paths (required) # Convert path = dir / 'images/VOCdevkit' diff --git a/ultralytics/cfg/datasets/VisDrone.yaml b/ultralytics/cfg/datasets/VisDrone.yaml index a1a4a46..773f0b0 100644 --- a/ultralytics/cfg/datasets/VisDrone.yaml +++ b/ultralytics/cfg/datasets/VisDrone.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University +# Documentation: https://docs.ultralytics.com/datasets/detect/visdrone/ # Example usage: yolo train data=VisDrone.yaml # parent # ├── ultralytics # └── datasets # └── VisDrone ← downloads here (2.3 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/VisDrone # dataset root dir -train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images -val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images -test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images +path: ../datasets/VisDrone # dataset root dir +train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images +val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images +test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images # Classes names: @@ -26,7 +26,6 @@ names: 8: bus 9: motor - # Download script/URL (optional) --------------------------------------------------------------------------------------- download: | import os diff --git a/ultralytics/cfg/datasets/african-wildlife.yaml b/ultralytics/cfg/datasets/african-wildlife.yaml new file mode 100644 index 0000000..af8af36 --- /dev/null +++ b/ultralytics/cfg/datasets/african-wildlife.yaml @@ -0,0 +1,24 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# African-wildlife dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/african-wildlife/ +# Example usage: yolo train data=african-wildlife.yaml +# parent +# ├── ultralytics +# └── datasets +# └── african-wildlife ← downloads here (100 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/african-wildlife # dataset root dir +train: train/images # train images (relative to 'path') 1052 images +val: valid/images # val images (relative to 'path') 225 images +test: test/images # test images (relative to 'path') 227 images + +# Classes +names: + 0: buffalo + 1: elephant + 2: rhino + 3: zebra + +# Download script/URL (optional) +download: https://ultralytics.com/assets/african-wildlife.zip diff --git a/ultralytics/cfg/datasets/brain-tumor.yaml b/ultralytics/cfg/datasets/brain-tumor.yaml new file mode 100644 index 0000000..be61098 --- /dev/null +++ b/ultralytics/cfg/datasets/brain-tumor.yaml @@ -0,0 +1,22 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Brain-tumor dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/brain-tumor/ +# Example usage: yolo train data=brain-tumor.yaml +# parent +# ├── ultralytics +# └── datasets +# └── brain-tumor ← downloads here (4.05 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/brain-tumor # dataset root dir +train: train/images # train images (relative to 'path') 893 images +val: valid/images # val images (relative to 'path') 223 images +test: # test images (relative to 'path') + +# Classes +names: + 0: negative + 1: positive + +# Download script/URL (optional) +download: https://ultralytics.com/assets/brain-tumor.zip diff --git a/ultralytics/cfg/datasets/carparts-seg.yaml b/ultralytics/cfg/datasets/carparts-seg.yaml new file mode 100644 index 0000000..a1c25ba --- /dev/null +++ b/ultralytics/cfg/datasets/carparts-seg.yaml @@ -0,0 +1,43 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Carparts-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/carparts-seg/ +# Example usage: yolo train data=carparts-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── carparts-seg ← downloads here (132 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
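The `kpt_shape: [17, 3]` and `flip_idx` entries in the coco-pose.yaml hunk above describe one (x, y, visibility) triplet per keypoint and the left/right index swap applied when an image is mirrored horizontally. A small illustrative sketch of that swap, assuming NumPy and normalized coordinates (not the library's own augmentation code):

```python
import numpy as np

# flip_idx pairs left/right keypoints: 1<->2 (eyes), 3<->4 (ears), 5<->6 (shoulders), ...
flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

kpts = np.random.rand(17, 3)            # kpt_shape [17, 3]: x, y, visibility per keypoint
flipped = kpts.copy()
flipped[:, 0] = 1.0 - flipped[:, 0]     # mirror the normalized x coordinate
flipped = flipped[flip_idx]             # re-index so left/right keypoints trade places
```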
+path: ../datasets/carparts-seg # dataset root dir +train: train/images # train images (relative to 'path') 3516 images +val: valid/images # val images (relative to 'path') 276 images +test: test/images # test images (relative to 'path') 401 images + +# Classes +names: + 0: back_bumper + 1: back_door + 2: back_glass + 3: back_left_door + 4: back_left_light + 5: back_light + 6: back_right_door + 7: back_right_light + 8: front_bumper + 9: front_door + 10: front_glass + 11: front_left_door + 12: front_left_light + 13: front_light + 14: front_right_door + 15: front_right_light + 16: hood + 17: left_mirror + 18: object + 19: right_mirror + 20: tailgate + 21: trunk + 22: wheel + +# Download script/URL (optional) +download: https://ultralytics.com/assets/carparts-seg.zip diff --git a/ultralytics/cfg/datasets/coco-pose.yaml b/ultralytics/cfg/datasets/coco-pose.yaml index 670d55b..b50b7a5 100644 --- a/ultralytics/cfg/datasets/coco-pose.yaml +++ b/ultralytics/cfg/datasets/coco-pose.yaml @@ -1,20 +1,20 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -# COCO 2017 dataset http://cocodataset.org by Microsoft +# COCO 2017 dataset https://cocodataset.org by Microsoft +# Documentation: https://docs.ultralytics.com/datasets/pose/coco/ # Example usage: yolo train data=coco-pose.yaml # parent # ├── ultralytics # └── datasets # └── coco-pose ← downloads here (20.1 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/coco-pose # dataset root dir -train: train2017.txt # train images (relative to 'path') 118287 images -val: val2017.txt # val images (relative to 'path') 5000 images -test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 +path: ../datasets/coco-pose # dataset root dir +train: train2017.txt # train images (relative to 'path') 118287 images +val: val2017.txt # val images (relative to 'path') 5000 images +test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 # Keypoints -kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] # Classes diff --git a/ultralytics/cfg/datasets/coco.yaml b/ultralytics/cfg/datasets/coco.yaml index 8a70a5b..d0297f7 100644 --- a/ultralytics/cfg/datasets/coco.yaml +++ b/ultralytics/cfg/datasets/coco.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -# COCO 2017 dataset http://cocodataset.org by Microsoft +# COCO 2017 dataset https://cocodataset.org by Microsoft +# Documentation: https://docs.ultralytics.com/datasets/detect/coco/ # Example usage: yolo train data=coco.yaml # parent # ├── ultralytics # └── datasets # └── coco ← downloads here (20.1 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
-path: ../datasets/coco # dataset root dir -train: train2017.txt # train images (relative to 'path') 118287 images -val: val2017.txt # val images (relative to 'path') 5000 images -test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 +path: ../datasets/coco # dataset root dir +train: train2017.txt # train images (relative to 'path') 118287 images +val: val2017.txt # val images (relative to 'path') 5000 images +test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 # Classes names: @@ -96,7 +96,6 @@ names: 78: hair drier 79: toothbrush - # Download script/URL (optional) download: | from ultralytics.utils.downloads import download diff --git a/ultralytics/cfg/datasets/coco128-seg.yaml b/ultralytics/cfg/datasets/coco128-seg.yaml index 8c2e3da..e898a40 100644 --- a/ultralytics/cfg/datasets/coco128-seg.yaml +++ b/ultralytics/cfg/datasets/coco128-seg.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/coco/ # Example usage: yolo train data=coco128.yaml # parent # ├── ultralytics # └── datasets # └── coco128-seg ← downloads here (7 MB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/coco128-seg # dataset root dir -train: images/train2017 # train images (relative to 'path') 128 images -val: images/train2017 # val images (relative to 'path') 128 images -test: # test images (optional) +path: ../datasets/coco128-seg # dataset root dir +train: images/train2017 # train images (relative to 'path') 128 images +val: images/train2017 # val images (relative to 'path') 128 images +test: # test images (optional) # Classes names: @@ -96,6 +96,5 @@ names: 78: hair drier 79: toothbrush - # Download script/URL (optional) download: https://ultralytics.com/assets/coco128-seg.zip diff --git a/ultralytics/cfg/datasets/coco128.yaml b/ultralytics/cfg/datasets/coco128.yaml index 9749ab6..8d47ee0 100644 --- a/ultralytics/cfg/datasets/coco128.yaml +++ b/ultralytics/cfg/datasets/coco128.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/coco/ # Example usage: yolo train data=coco128.yaml # parent # ├── ultralytics # └── datasets # └── coco128 ← downloads here (7 MB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
-path: ../datasets/coco128 # dataset root dir -train: images/train2017 # train images (relative to 'path') 128 images -val: images/train2017 # val images (relative to 'path') 128 images -test: # test images (optional) +path: ../datasets/coco128 # dataset root dir +train: images/train2017 # train images (relative to 'path') 128 images +val: images/train2017 # val images (relative to 'path') 128 images +test: # test images (optional) # Classes names: @@ -96,6 +96,5 @@ names: 78: hair drier 79: toothbrush - # Download script/URL (optional) download: https://ultralytics.com/assets/coco128.zip diff --git a/ultralytics/cfg/datasets/coco8-pose.yaml b/ultralytics/cfg/datasets/coco8-pose.yaml index e6fab8b..4dee5be 100644 --- a/ultralytics/cfg/datasets/coco8-pose.yaml +++ b/ultralytics/cfg/datasets/coco8-pose.yaml @@ -1,20 +1,20 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/pose/coco8-pose/ # Example usage: yolo train data=coco8-pose.yaml # parent # ├── ultralytics # └── datasets # └── coco8-pose ← downloads here (1 MB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/coco8-pose # dataset root dir -train: images/train # train images (relative to 'path') 4 images -val: images/val # val images (relative to 'path') 4 images -test: # test images (optional) +path: ../datasets/coco8-pose # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) # Keypoints -kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] # Classes diff --git a/ultralytics/cfg/datasets/coco8-seg.yaml b/ultralytics/cfg/datasets/coco8-seg.yaml index e6faca1..d8b6ed2 100644 --- a/ultralytics/cfg/datasets/coco8-seg.yaml +++ b/ultralytics/cfg/datasets/coco8-seg.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # COCO8-seg dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/coco8-seg/ # Example usage: yolo train data=coco8-seg.yaml # parent # ├── ultralytics # └── datasets # └── coco8-seg ← downloads here (1 MB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
-path: ../datasets/coco8-seg # dataset root dir -train: images/train # train images (relative to 'path') 4 images -val: images/val # val images (relative to 'path') 4 images -test: # test images (optional) +path: ../datasets/coco8-seg # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) # Classes names: @@ -96,6 +96,5 @@ names: 78: hair drier 79: toothbrush - # Download script/URL (optional) download: https://ultralytics.com/assets/coco8-seg.zip diff --git a/ultralytics/cfg/datasets/coco8.yaml b/ultralytics/cfg/datasets/coco8.yaml index eeb5d9d..2925f81 100644 --- a/ultralytics/cfg/datasets/coco8.yaml +++ b/ultralytics/cfg/datasets/coco8.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # COCO8 dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/coco8/ # Example usage: yolo train data=coco8.yaml # parent # ├── ultralytics # └── datasets # └── coco8 ← downloads here (1 MB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/coco8 # dataset root dir -train: images/train # train images (relative to 'path') 4 images -val: images/val # val images (relative to 'path') 4 images -test: # test images (optional) +path: ../datasets/coco8 # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) # Classes names: @@ -96,6 +96,5 @@ names: 78: hair drier 79: toothbrush - # Download script/URL (optional) download: https://ultralytics.com/assets/coco8.zip diff --git a/ultralytics/cfg/datasets/crack-seg.yaml b/ultralytics/cfg/datasets/crack-seg.yaml new file mode 100644 index 0000000..2054f62 --- /dev/null +++ b/ultralytics/cfg/datasets/crack-seg.yaml @@ -0,0 +1,21 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Crack-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/crack-seg/ +# Example usage: yolo train data=crack-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── crack-seg ← downloads here (91.2 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/crack-seg # dataset root dir +train: train/images # train images (relative to 'path') 3717 images +val: valid/images # val images (relative to 'path') 112 images +test: test/images # test images (relative to 'path') 200 images + +# Classes +names: + 0: crack + +# Download script/URL (optional) +download: https://ultralytics.com/assets/crack-seg.zip diff --git a/ultralytics/cfg/datasets/dota8.yaml b/ultralytics/cfg/datasets/dota8.yaml new file mode 100644 index 0000000..f58b501 --- /dev/null +++ b/ultralytics/cfg/datasets/dota8.yaml @@ -0,0 +1,34 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# DOTA8 dataset 8 images from split DOTAv1 dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/obb/dota8/ +# Example usage: yolo train model=yolov8n-obb.pt data=dota8.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota8 ← downloads here (1MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/dota8 # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images + +# Classes for DOTA 1.0 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + +# Download script/URL (optional) +download: https://github.com/ultralytics/yolov5/releases/download/v1.0/dota8.zip diff --git a/ultralytics/cfg/datasets/open-images-v7.yaml b/ultralytics/cfg/datasets/open-images-v7.yaml index bb1e3ff..d9cad9f 100644 --- a/ultralytics/cfg/datasets/open-images-v7.yaml +++ b/ultralytics/cfg/datasets/open-images-v7.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # Open Images v7 dataset https://storage.googleapis.com/openimages/web/index.html by Google +# Documentation: https://docs.ultralytics.com/datasets/detect/open-images-v7/ # Example usage: yolo train data=open-images-v7.yaml # parent # ├── ultralytics # └── datasets # └── open-images-v7 ← downloads here (561 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/open-images-v7 # dataset root dir -train: images/train # train images (relative to 'path') 1743042 images -val: images/val # val images (relative to 'path') 41620 images -test: # test images (optional) +path: ../datasets/open-images-v7 # dataset root dir +train: images/train # train images (relative to 'path') 1743042 images +val: images/val # val images (relative to 'path') 41620 images +test: # test images (optional) # Classes names: @@ -617,7 +617,6 @@ names: 599: Zebra 600: Zucchini - # Download script/URL (optional) --------------------------------------------------------------------------------------- download: | from ultralytics.utils import LOGGER, SETTINGS, Path, is_ubuntu, get_ubuntu_version diff --git a/ultralytics/cfg/datasets/package-seg.yaml b/ultralytics/cfg/datasets/package-seg.yaml new file mode 100644 index 0000000..44fe550 --- /dev/null +++ b/ultralytics/cfg/datasets/package-seg.yaml @@ -0,0 +1,21 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Package-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/package-seg/ +# Example usage: yolo train data=package-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── package-seg ← downloads here (102 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/package-seg # dataset root dir +train: images/train # train images (relative to 'path') 1920 images +val: images/val # val images (relative to 'path') 89 images +test: test/images # test images (relative to 'path') 188 images + +# Classes +names: + 0: package + +# Download script/URL (optional) +download: https://ultralytics.com/assets/package-seg.zip diff --git a/ultralytics/cfg/datasets/tiger-pose.yaml b/ultralytics/cfg/datasets/tiger-pose.yaml new file mode 100644 index 0000000..d37df04 --- /dev/null +++ b/ultralytics/cfg/datasets/tiger-pose.yaml @@ -0,0 +1,24 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Tiger Pose dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/pose/tiger-pose/ +# Example usage: yolo train data=tiger-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── tiger-pose ← downloads here (75.3 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/tiger-pose # dataset root dir +train: train # train images (relative to 'path') 210 images +val: val # val images (relative to 'path') 53 images + +# Keypoints +kpt_shape: [12, 2] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + +# Classes +names: + 0: tiger + +# Download script/URL (optional) +download: https://ultralytics.com/assets/tiger-pose.zip diff --git a/ultralytics/cfg/datasets/xView.yaml b/ultralytics/cfg/datasets/xView.yaml index bdc2d91..d2e957a 100644 --- a/ultralytics/cfg/datasets/xView.yaml +++ b/ultralytics/cfg/datasets/xView.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # DIUx xView 2018 Challenge https://challenge.xviewdataset.org by U.S. National Geospatial-Intelligence Agency (NGA) # -------- DOWNLOAD DATA MANUALLY and jar xf val_images.zip to 'datasets/xView' before running train command! -------- +# Documentation: https://docs.ultralytics.com/datasets/detect/xview/ # Example usage: yolo train data=xView.yaml # parent # ├── ultralytics # └── datasets # └── xView ← downloads here (20.7 GB) - # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/xView # dataset root dir -train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images -val: images/autosplit_val.txt # train images (relative to 'path') 10% of 847 train images +path: ../datasets/xView # dataset root dir +train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images +val: images/autosplit_val.txt # train images (relative to 'path') 10% of 847 train images # Classes names: @@ -76,7 +76,6 @@ names: 58: Pylon 59: Tower - # Download script/URL (optional) --------------------------------------------------------------------------------------- download: | import json diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml index 12e3708..bd074b1 100644 --- a/ultralytics/cfg/default.yaml +++ b/ultralytics/cfg/default.yaml @@ -1,116 +1,127 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # Default training settings and hyperparameters for medium-augmentation COCO training -task: detect # (str) YOLO task, i.e. detect, segment, classify, pose -mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark +task: detect # (str) YOLO task, i.e. detect, segment, classify, pose +mode: train # (str) YOLO mode, i.e. 
train, val, predict, export, track, benchmark # Train settings ------------------------------------------------------------------------------------------------------- -model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml -data: # (str, optional) path to data file, i.e. coco128.yaml -epochs: 100 # (int) number of epochs to train for -patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training -batch: 16 # (int) number of images per batch (-1 for AutoBatch) -imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes -save: True # (bool) save train checkpoints and predict results +model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml +data: # (str, optional) path to data file, i.e. coco128.yaml +epochs: 100 # (int) number of epochs to train for +time: # (float, optional) number of hours to train for, overrides epochs if supplied +patience: 100 # (int) epochs to wait for no observable improvement for early stopping of training +batch: 16 # (int) number of images per batch (-1 for AutoBatch) +imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes +save: True # (bool) save train checkpoints and predict results save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1) -cache: False # (bool) True/ram, disk or False. Use cache for data loading -device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu -workers: 8 # (int) number of worker threads for data loading (per RANK if DDP) -project: # (str, optional) project name -name: # (str, optional) experiment name, results saved to 'project/name' directory -exist_ok: False # (bool) whether to overwrite existing experiment -pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str) -optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] -verbose: True # (bool) whether to print verbose output -seed: 0 # (int) random seed for reproducibility -deterministic: True # (bool) whether to enable deterministic mode -single_cls: False # (bool) train multi-class data as single-class -rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val' -cos_lr: False # (bool) use cosine learning rate scheduler -close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable) -resume: False # (bool) resume training from last checkpoint -amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check -fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set) -profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers -freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training +val_period: 1 # (int) Validation every x epochs +cache: False # (bool) True/ram, disk or False. Use cache for data loading +device: # (int | str | list, optional) device to run on, i.e. 
cuda device=0 or device=0,1,2,3 or device=cpu +workers: 8 # (int) number of worker threads for data loading (per RANK if DDP) +project: # (str, optional) project name +name: # (str, optional) experiment name, results saved to 'project/name' directory +exist_ok: False # (bool) whether to overwrite existing experiment +pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str) +optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] +verbose: True # (bool) whether to print verbose output +seed: 0 # (int) random seed for reproducibility +deterministic: True # (bool) whether to enable deterministic mode +single_cls: False # (bool) train multi-class data as single-class +rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val' +cos_lr: False # (bool) use cosine learning rate scheduler +close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable) +resume: False # (bool) resume training from last checkpoint +amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check +fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set) +profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers +freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training +multi_scale: False # (bool) Whether to use multiscale during training # Segmentation -overlap_mask: True # (bool) masks should overlap during training (segment train only) -mask_ratio: 4 # (int) mask downsample ratio (segment train only) +overlap_mask: True # (bool) masks should overlap during training (segment train only) +mask_ratio: 4 # (int) mask downsample ratio (segment train only) # Classification -dropout: 0.0 # (float) use dropout regularization (classify train only) +dropout: 0.0 # (float) use dropout regularization (classify train only) # Val/Test settings ---------------------------------------------------------------------------------------------------- -val: True # (bool) validate/test during training -split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train' -save_json: False # (bool) save results to JSON file -save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions) -conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val) -iou: 0.7 # (float) intersection over union (IoU) threshold for NMS -max_det: 300 # (int) maximum number of detections per image -half: False # (bool) use half precision (FP16) -dnn: False # (bool) use OpenCV DNN for ONNX inference -plots: True # (bool) save plots during train/val +val: True # (bool) validate/test during training +split: val # (str) dataset split to use for validation, i.e. 
'val', 'test' or 'train' +save_json: False # (bool) save results to JSON file +save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions) +conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val) +iou: 0.7 # (float) intersection over union (IoU) threshold for NMS +max_det: 300 # (int) maximum number of detections per image +half: False # (bool) use half precision (FP16) +dnn: False # (bool) use OpenCV DNN for ONNX inference +plots: True # (bool) save plots and images during train/val -# Prediction settings -------------------------------------------------------------------------------------------------- -source: # (str, optional) source directory for images or videos -show: False # (bool) show results if possible -save_txt: False # (bool) save results as .txt file -save_conf: False # (bool) save results with confidence scores -save_crop: False # (bool) save cropped images with results -show_labels: True # (bool) show object labels in plots -show_conf: True # (bool) show object confidence scores in plots -vid_stride: 1 # (int) video frame-rate stride -stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False) -line_width: # (int, optional) line width of the bounding boxes, auto if missing -visualize: False # (bool) visualize model features -augment: False # (bool) apply image augmentation to prediction sources -agnostic_nms: False # (bool) class-agnostic NMS -classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3] -retina_masks: False # (bool) use high-resolution segmentation masks -boxes: True # (bool) Show boxes in segmentation predictions +# Predict settings ----------------------------------------------------------------------------------------------------- +source: # (str, optional) source directory for images or videos +vid_stride: 1 # (int) video frame-rate stride +stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False) +visualize: False # (bool) visualize model features +augment: False # (bool) apply image augmentation to prediction sources +agnostic_nms: False # (bool) class-agnostic NMS +classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3] +retina_masks: False # (bool) use high-resolution segmentation masks +embed: # (list[int], optional) return feature vectors/embeddings from given layers + +# Visualize settings --------------------------------------------------------------------------------------------------- +show: False # (bool) show predicted images and videos if environment allows +save_frames: False # (bool) save predicted individual video frames +save_txt: False # (bool) save results as .txt file +save_conf: False # (bool) save results with confidence scores +save_crop: False # (bool) save cropped images with results +show_labels: True # (bool) show prediction labels, i.e. 'person' +show_conf: True # (bool) show prediction confidence, i.e. '0.99' +show_boxes: True # (bool) show prediction boxes +line_width: # (int, optional) line width of the bounding boxes. Scaled to image size if None. 
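The predict and visualize keys above use the same names that the CLI entrypoint forwards as keyword overrides via `getattr(model, mode)(**overrides)`. A hedged sketch of the equivalent Python call, using an example weights file and a hypothetical source path:

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")        # example weights
results = model.predict(
    source="path/to/images",      # hypothetical source directory
    conf=0.25,                    # object confidence threshold
    show_labels=True,             # draw class names on plots
    show_conf=True,               # draw confidence scores on plots
    line_width=2,                 # bounding-box line width
    save_txt=False,               # skip writing .txt results
)
```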
# Export settings ------------------------------------------------------------------------------------------------------ -format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats -keras: False # (bool) use Kera=s -optimize: False # (bool) TorchScript: optimize for mobile -int8: False # (bool) CoreML/TF INT8 quantization -dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes -simplify: False # (bool) ONNX: simplify model -opset: # (int, optional) ONNX: opset version -workspace: 4 # (int) TensorRT: workspace size (GB) -nms: False # (bool) CoreML: add NMS +format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats +keras: False # (bool) use Kera=s +optimize: False # (bool) TorchScript: optimize for mobile +int8: False # (bool) CoreML/TF INT8 quantization +dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes +simplify: False # (bool) ONNX: simplify model using `onnxslim` +opset: # (int, optional) ONNX: opset version +workspace: 4 # (int) TensorRT: workspace size (GB) +nms: False # (bool) CoreML: add NMS # Hyperparameters ------------------------------------------------------------------------------------------------------ -lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3) -lrf: 0.01 # (float) final learning rate (lr0 * lrf) -momentum: 0.937 # (float) SGD momentum/Adam beta1 -weight_decay: 0.0005 # (float) optimizer weight decay 5e-4 -warmup_epochs: 3.0 # (float) warmup epochs (fractions ok) -warmup_momentum: 0.8 # (float) warmup initial momentum -warmup_bias_lr: 0.1 # (float) warmup initial bias lr -box: 7.5 # (float) box loss gain -cls: 0.5 # (float) cls loss gain (scale with pixels) -dfl: 1.5 # (float) dfl loss gain -pose: 12.0 # (float) pose loss gain -kobj: 1.0 # (float) keypoint obj loss gain -label_smoothing: 0.0 # (float) label smoothing (fraction) -nbs: 64 # (int) nominal batch size -hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction) -hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction) -hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction) -degrees: 0.0 # (float) image rotation (+/- deg) -translate: 0.1 # (float) image translation (+/- fraction) -scale: 0.5 # (float) image scale (+/- gain) -shear: 0.0 # (float) image shear (+/- deg) -perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001 -flipud: 0.0 # (float) image flip up-down (probability) -fliplr: 0.5 # (float) image flip left-right (probability) -mosaic: 1.0 # (float) image mosaic (probability) -mixup: 0.0 # (float) image mixup (probability) -copy_paste: 0.0 # (float) segment copy-paste (probability) +lr0: 0.01 # (float) initial learning rate (i.e. 
SGD=1E-2, Adam=1E-3) +lrf: 0.01 # (float) final learning rate (lr0 * lrf) +momentum: 0.937 # (float) SGD momentum/Adam beta1 +weight_decay: 0.0005 # (float) optimizer weight decay 5e-4 +warmup_epochs: 3.0 # (float) warmup epochs (fractions ok) +warmup_momentum: 0.8 # (float) warmup initial momentum +warmup_bias_lr: 0.1 # (float) warmup initial bias lr +box: 7.5 # (float) box loss gain +cls: 0.5 # (float) cls loss gain (scale with pixels) +dfl: 1.5 # (float) dfl loss gain +pose: 12.0 # (float) pose loss gain +kobj: 1.0 # (float) keypoint obj loss gain +label_smoothing: 0.0 # (float) label smoothing (fraction) +nbs: 64 # (int) nominal batch size +hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction) +hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction) +hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction) +degrees: 0.0 # (float) image rotation (+/- deg) +translate: 0.1 # (float) image translation (+/- fraction) +scale: 0.5 # (float) image scale (+/- gain) +shear: 0.0 # (float) image shear (+/- deg) +perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001 +flipud: 0.0 # (float) image flip up-down (probability) +fliplr: 0.5 # (float) image flip left-right (probability) +bgr: 0.0 # (float) image channel BGR (probability) +mosaic: 1.0 # (float) image mosaic (probability) +mixup: 0.0 # (float) image mixup (probability) +copy_paste: 0.0 # (float) segment copy-paste (probability) +auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix) +erasing: 0.4 # (float) probability of random erasing during classification training (0-1) +crop_fraction: 1.0 # (float) image crop fraction for classification evaluation/inference (0-1) # Custom config.yaml --------------------------------------------------------------------------------------------------- -cfg: # (str, optional) for overriding defaults.yaml +cfg: # (str, optional) for overriding defaults.yaml # Tracker settings ------------------------------------------------------------------------------------------------------ -tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml] +tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml] diff --git a/ultralytics/cfg/models/README.md b/ultralytics/cfg/models/README.md index 4749441..c022fb5 100644 --- a/ultralytics/cfg/models/README.md +++ b/ultralytics/cfg/models/README.md @@ -14,8 +14,7 @@ Model `*.yaml` files may be used directly in the Command Line Interface (CLI) wi yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100 ``` -They may also be used directly in a Python environment, and accepts the same -[arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above: +They may also be used directly in a Python environment, and accepts the same [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above: ```python from ultralytics import YOLO diff --git a/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml b/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml index bd20da1..c6eb0b3 100644 --- a/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +++ b/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml @@ -2,49 +2,49 @@ # RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' # [depth, width, max_channels] l: [1.00, 1.00, 1024] backbone: # [from, repeats, module, args] - - [-1, 1, HGStem, [32, 48]] # 0-P2/4 - - [-1, 6, HGBlock, [48, 128, 3]] # stage 1 + - [-1, 1, HGStem, [32, 48]] # 0-P2/4 + - [-1, 6, HGBlock, [48, 128, 3]] # stage 1 - - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 - - [-1, 6, HGBlock, [96, 512, 3]] # stage 2 + - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 + - [-1, 6, HGBlock, [96, 512, 3]] # stage 2 - - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16 - - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut + - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16 + - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut - [-1, 6, HGBlock, [192, 1024, 5, True, True]] - - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3 + - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3 - - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32 - - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4 + - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32 + - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4 head: - - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2 + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2 - [-1, 1, AIFI, [1024, 8]] - - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0 + - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1 - [[-2, -1], 1, Concat, [1]] - - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0 - - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1 + - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0 + - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0 - - [[-2, -1], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1 - - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0 - - [[-1, 17], 1, Concat, [1]] # cat Y4 - - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0 + - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0 + - [[-1, 17], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0 - - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1 - - [[-1, 12], 1, Concat, [1]] # cat Y5 - - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1 + - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1 + - [[-1, 12], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1 - - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) + - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml b/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml new file mode 100644 index 0000000..a68bb5d --- /dev/null +++ b/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml @@ -0,0 +1,42 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# RT-DETR-ResNet101 object detection model with P3-P5 outputs. + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 23]] # 3 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4 + +head: + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 5 + - [-1, 1, AIFI, [1024, 8]] + - [-1, 1, Conv, [256, 1, 1]] # 7 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 9 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [256]] # 11 + - [-1, 1, Conv, [256, 1, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (16), fpn_blocks.1 + + - [-1, 1, Conv, [256, 3, 2]] # 17, downsample_convs.0 + - [[-1, 12], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (19), pan_blocks.0 + + - [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.1 + - [[-1, 7], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (22), pan_blocks.1 + + - [[16, 19, 22], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml b/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml new file mode 100644 index 0000000..7145910 --- /dev/null +++ b/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml @@ -0,0 +1,42 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# RT-DETR-ResNet50 object detection model with P3-P5 outputs. + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 6]] # 3 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4 + +head: + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 5 + - [-1, 1, AIFI, [1024, 8]] + - [-1, 1, Conv, [256, 1, 1]] # 7 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 9 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [256]] # 11 + - [-1, 1, Conv, [256, 1, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (16), fpn_blocks.1 + + - [-1, 1, Conv, [256, 3, 2]] # 17, downsample_convs.0 + - [[-1, 12], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (19), pan_blocks.0 + + - [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.1 + - [[-1, 7], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (22), pan_blocks.1 + + - [[16, 19, 22], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml b/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml index 848cb52..0e819b0 100644 --- a/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +++ b/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml @@ -2,53 +2,53 @@ # RT-DETR-x object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/rtdetr # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' # [depth, width, max_channels] x: [1.00, 1.00, 2048] backbone: # [from, repeats, module, args] - - [-1, 1, HGStem, [32, 64]] # 0-P2/4 - - [-1, 6, HGBlock, [64, 128, 3]] # stage 1 + - [-1, 1, HGStem, [32, 64]] # 0-P2/4 + - [-1, 6, HGBlock, [64, 128, 3]] # stage 1 - - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 + - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 - [-1, 6, HGBlock, [128, 512, 3]] - - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2 + - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2 - - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16 - - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut + - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16 + - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut - [-1, 6, HGBlock, [256, 1024, 5, True, True]] - [-1, 6, HGBlock, [256, 1024, 5, True, True]] - [-1, 6, HGBlock, [256, 1024, 5, True, True]] - - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3 + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3 - - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32 + - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32 - [-1, 6, HGBlock, [512, 2048, 5, True, False]] - - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4 + - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4 head: - - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2 + - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2 - [-1, 1, AIFI, [2048, 8]] - - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0 + - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1 - [[-2, -1], 1, Concat, [1]] - - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0 - - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1 + - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0 + - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0 - - [[-2, -1], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1 - - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0 - - [[-1, 21], 1, Concat, [1]] # cat Y4 - - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0 + - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0 + - [[-1, 21], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0 - - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1 - - [[-1, 16], 1, Concat, [1]] # cat Y5 - - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1 + - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1 + - [[-1, 16], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1 - - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) + - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10b.yaml 
b/ultralytics/cfg/models/v10/yolov10b.yaml new file mode 100644 index 0000000..a9dc721 --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10b.yaml @@ -0,0 +1,40 @@ +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + b: [0.67, 1.00, 512] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10l.yaml b/ultralytics/cfg/models/v10/yolov10l.yaml new file mode 100644 index 0000000..047de26 --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10l.yaml @@ -0,0 +1,40 @@ +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10m.yaml b/ultralytics/cfg/models/v10/yolov10m.yaml new file mode 100644 index 0000000..5bdb5bf --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10m.yaml @@ -0,0 +1,43 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10n.yaml b/ultralytics/cfg/models/v10/yolov10n.yaml new file mode 100644 index 0000000..1ee7437 --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10n.yaml @@ -0,0 +1,40 @@ +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10s.yaml b/ultralytics/cfg/models/v10/yolov10s.yaml new file mode 100644 index 0000000..c61e08c --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10s.yaml @@ -0,0 +1,39 @@ +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + s: [0.33, 0.50, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10x.yaml b/ultralytics/cfg/models/v10/yolov10x.yaml new file mode 100644 index 0000000..ab5fc8f --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10x.yaml @@ -0,0 +1,40 @@ +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + x: [1.00, 1.25, 512] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2fCIB, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v3/yolov3-spp.yaml b/ultralytics/cfg/models/v3/yolov3-spp.yaml index 406e019..6724f4e 100644 --- a/ultralytics/cfg/models/v3/yolov3-spp.yaml +++ b/ultralytics/cfg/models/v3/yolov3-spp.yaml @@ -2,47 +2,45 @@ # YOLOv3-SPP object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 # Parameters -nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple # darknet53 backbone backbone: # [from, number, module, args] - [[-1, 1, Conv, [32, 3, 1]], # 0 - [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 - [-1, 1, Bottleneck, [64]], - [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 - [-1, 2, Bottleneck, [128]], - [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 - [-1, 8, Bottleneck, [256]], - [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 - [-1, 8, Bottleneck, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 - [-1, 4, Bottleneck, [1024]], # 10 - ] + - [-1, 1, Conv, [32, 3, 1]] # 0 + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Bottleneck, [64]] + - [-1, 1, Conv, [128, 3, 2]] # 3-P2/4 + - [-1, 2, Bottleneck, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 5-P3/8 + - [-1, 8, Bottleneck, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 7-P4/16 + - [-1, 8, Bottleneck, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P5/32 + - [-1, 4, Bottleneck, [1024]] # 10 # YOLOv3-SPP head head: - [[-1, 1, Bottleneck, [1024, False]], - [-1, 1, SPP, [512, [5, 9, 13]]], - [-1, 1, Conv, [1024, 3, 1]], - [-1, 1, Conv, [512, 1, 1]], - [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) + - [-1, 1, Bottleneck, [1024, False]] + - [-1, 1, SPP, [512, [5, 9, 13]]] + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] # 15 (P5/32-large) - [-2, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 8], 1, Concat, [1]], # cat backbone P4 - [-1, 1, Bottleneck, [512, False]], - [-1, 1, Bottleneck, [512, False]], - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) + - [-2, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 22 (P4/16-medium) - [-2, 1, Conv, [128, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P3 - [-1, 1, Bottleneck, [256, False]], - [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, Bottleneck, [256, False]] + - [-1, 2, Bottleneck, [256, False]] # 27 (P3/8-small) - [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5) - ] + - [[27, 22, 15], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v3/yolov3-tiny.yaml b/ultralytics/cfg/models/v3/yolov3-tiny.yaml index 69d8e42..f3fe257 100644 --- a/ultralytics/cfg/models/v3/yolov3-tiny.yaml +++ b/ultralytics/cfg/models/v3/yolov3-tiny.yaml @@ -2,38 +2,36 @@ # YOLOv3-tiny object detection model with P4-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 # Parameters -nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple # YOLOv3-tiny backbone backbone: # [from, number, module, args] - [[-1, 1, Conv, [16, 3, 1]], # 0 - [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2 - [-1, 1, Conv, [32, 3, 1]], - [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4 - [-1, 1, Conv, [64, 3, 1]], - [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8 - [-1, 1, Conv, [128, 3, 1]], - [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16 - [-1, 1, Conv, [256, 3, 1]], - [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32 - [-1, 1, Conv, [512, 3, 1]], - [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11 - [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12 - ] + - [-1, 1, Conv, [16, 3, 1]] # 0 + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 1-P1/2 + - [-1, 1, Conv, [32, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 3-P2/4 + - [-1, 1, Conv, [64, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 5-P3/8 + - [-1, 1, Conv, [128, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 7-P4/16 + - [-1, 1, Conv, [256, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 9-P5/32 + - [-1, 1, Conv, [512, 3, 1]] + - [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]] # 11 + - [-1, 1, nn.MaxPool2d, [2, 1, 0]] # 12 # YOLOv3-tiny head head: - [[-1, 1, Conv, [1024, 3, 1]], - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large) + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 15 (P5/32-large) - [-2, 1, Conv, [128, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 8], 1, Concat, [1]], # cat backbone P4 - [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium) + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Conv, [256, 3, 1]] # 19 (P4/16-medium) - [[19, 15], 1, Detect, [nc]], # Detect(P4, P5) - ] + - [[19, 15], 1, Detect, [nc]] # Detect(P4, P5) diff --git a/ultralytics/cfg/models/v3/yolov3.yaml b/ultralytics/cfg/models/v3/yolov3.yaml index 7cc0afa..716866a 100644 --- a/ultralytics/cfg/models/v3/yolov3.yaml +++ b/ultralytics/cfg/models/v3/yolov3.yaml @@ -2,47 +2,45 @@ # YOLOv3 object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 # Parameters -nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple # darknet53 backbone backbone: # [from, number, module, args] - [[-1, 1, Conv, [32, 3, 1]], # 0 - [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 - [-1, 1, Bottleneck, [64]], - [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 - [-1, 2, Bottleneck, [128]], - [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 - [-1, 8, Bottleneck, [256]], - [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 - [-1, 8, Bottleneck, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 - [-1, 4, Bottleneck, [1024]], # 10 - ] + - [-1, 1, Conv, [32, 3, 1]] # 0 + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Bottleneck, [64]] + - [-1, 1, Conv, [128, 3, 2]] # 3-P2/4 + - [-1, 2, Bottleneck, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 5-P3/8 + - [-1, 8, Bottleneck, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 7-P4/16 + - [-1, 8, Bottleneck, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P5/32 + - [-1, 4, Bottleneck, [1024]] # 10 # YOLOv3 head head: - [[-1, 1, Bottleneck, [1024, False]], - [-1, 1, Conv, [512, 1, 1]], - [-1, 1, Conv, [1024, 3, 1]], - [-1, 1, Conv, [512, 1, 1]], - [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) + - [-1, 1, Bottleneck, [1024, False]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] # 15 (P5/32-large) - [-2, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 8], 1, Concat, [1]], # cat backbone P4 - [-1, 1, Bottleneck, [512, False]], - [-1, 1, Bottleneck, [512, False]], - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) + - [-2, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 22 (P4/16-medium) - [-2, 1, Conv, [128, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P3 - [-1, 1, Bottleneck, [256, False]], - [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, Bottleneck, [256, False]] + - [-1, 2, Bottleneck, [256, False]] # 27 (P3/8-small) - [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5) - ] + - [[27, 22, 15], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v5/yolov5-p6.yaml b/ultralytics/cfg/models/v5/yolov5-p6.yaml index d468377..2fd3ac7 100644 --- a/ultralytics/cfg/models/v5/yolov5-p6.yaml +++ b/ultralytics/cfg/models/v5/yolov5-p6.yaml @@ -2,7 +2,7 @@ # YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5 # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -14,48 +14,46 @@ scales: # model compound scaling constants, i.e. 
'model=yolov5n-p6.yaml' will ca # YOLOv5 v6.0 backbone backbone: # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 - [-1, 3, C3, [768]], - [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 11 - ] + - [-1, 1, Conv, [64, 6, 2, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 9, C3, [512]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C3, [768]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C3, [1024]] + - [-1, 1, SPPF, [1024, 5]] # 11 # YOLOv5 v6.0 head head: - [[-1, 1, Conv, [768, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 8], 1, Concat, [1]], # cat backbone P5 - [-1, 3, C3, [768, False]], # 15 + - [-1, 1, Conv, [768, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C3, [768, False]] # 15 - [-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 19 + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3, [512, False]] # 19 - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 23 (P3/8-small) + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3, [256, False]] # 23 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]], - [[-1, 20], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 26 (P4/16-medium) + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 20], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3, [512, False]] # 26 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]], - [[-1, 16], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [768, False]], # 29 (P5/32-large) + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 16], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3, [768, False]] # 29 (P5/32-large) - [-1, 1, Conv, [768, 3, 2]], - [[-1, 12], 1, Concat, [1]], # cat head P6 - [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P6 + - [-1, 3, C3, [1024, False]] # 32 (P6/64-xlarge) - [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) - ] + - [[23, 26, 29, 32], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v5/yolov5.yaml b/ultralytics/cfg/models/v5/yolov5.yaml index 4a3fced..8fdc79e 100644 --- a/ultralytics/cfg/models/v5/yolov5.yaml +++ b/ultralytics/cfg/models/v5/yolov5.yaml @@ -2,7 +2,7 @@ # YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5 # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 'model=yolov5n.yaml' will call yolov5.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -14,37 +14,35 @@ scales: # model compound scaling constants, i.e. 
'model=yolov5n.yaml' will call # YOLOv5 v6.0 backbone backbone: # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] + - [-1, 1, Conv, [64, 6, 2, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 9, C3, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3, [1024]] + - [-1, 1, SPPF, [1024, 5]] # 9 # YOLOv5 v6.0 head head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3, [512, False]] # 13 - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3, [256, False]] # 17 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3, [512, False]] # 20 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3, [1024, False]] # 23 (P5/32-large) - [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5) - ] + - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v6/yolov6.yaml b/ultralytics/cfg/models/v6/yolov6.yaml index cb5e32a..f39dfb4 100644 --- a/ultralytics/cfg/models/v6/yolov6.yaml +++ b/ultralytics/cfg/models/v6/yolov6.yaml @@ -2,8 +2,8 @@ # YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6 # Parameters -nc: 80 # number of classes -activation: nn.ReLU() # (optional) model default activation function +nc: 80 # number of classes +activation: nn.ReLU() # (optional) model default activation function scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -15,39 +15,39 @@ scales: # model compound scaling constants, i.e. 
'model=yolov6n.yaml' will call # YOLOv6-3.0s backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 6, Conv, [128, 3, 1]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 12, Conv, [256, 3, 1]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 18, Conv, [512, 3, 1]] - - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 - [-1, 6, Conv, [1024, 3, 1]] - - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, SPPF, [1024, 5]] # 9 # YOLOv6-3.0s head head: - [-1, 1, Conv, [256, 1, 1]] - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - [-1, 1, Conv, [256, 3, 1]] - - [-1, 9, Conv, [256, 3, 1]] # 14 + - [-1, 9, Conv, [256, 3, 1]] # 14 - [-1, 1, Conv, [128, 1, 1]] - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - [-1, 1, Conv, [128, 3, 1]] - - [-1, 9, Conv, [128, 3, 1]] # 19 + - [-1, 9, Conv, [128, 3, 1]] # 19 - [-1, 1, Conv, [128, 3, 2]] - - [[-1, 15], 1, Concat, [1]] # cat head P4 + - [[-1, 15], 1, Concat, [1]] # cat head P4 - [-1, 1, Conv, [256, 3, 1]] - - [-1, 9, Conv, [256, 3, 1]] # 23 + - [-1, 9, Conv, [256, 3, 1]] # 23 - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [[-1, 10], 1, Concat, [1]] # cat head P5 - [-1, 1, Conv, [512, 3, 1]] - - [-1, 9, Conv, [512, 3, 1]] # 27 + - [-1, 9, Conv, [512, 3, 1]] # 27 - - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5) + - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml b/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml new file mode 100644 index 0000000..6867f88 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml @@ -0,0 +1,25 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0-P1/2 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1-P2/4 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2-P3/8 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 23]] # 3-P4/16 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4-P5/32 + +# YOLOv8.0n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml b/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml new file mode 100644 index 0000000..8ffd111 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml @@ -0,0 +1,25 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0-P1/2 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1-P2/4 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2-P3/8 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 6]] # 3-P4/16 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4-P5/32 + +# YOLOv8.0n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/ultralytics/cfg/models/v8/yolov8-cls.yaml b/ultralytics/cfg/models/v8/yolov8-cls.yaml index 5332f1d..180fc65 100644 --- a/ultralytics/cfg/models/v8/yolov8-cls.yaml +++ b/ultralytics/cfg/models/v8/yolov8-cls.yaml @@ -2,7 +2,7 @@ # YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify # Parameters -nc: 1000 # number of classes +nc: 1000 # number of classes scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -14,16 +14,16 @@ scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will c # YOLOv8.0n backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [1024, True]] # YOLOv8.0n head head: - - [-1, 1, Classify, [nc]] # Classify + - [-1, 1, Classify, [nc]] # Classify diff --git a/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml b/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml new file mode 100644 index 0000000..aee2093 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml @@ -0,0 +1,54 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost-p2 summary: 491 layers, 2033944 parameters, 2033928 gradients, 13.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost-p2 summary: 491 layers, 5562080 parameters, 5562064 gradients, 25.1 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost-p2 summary: 731 layers, 9031728 parameters, 9031712 gradients, 42.8 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost-p2 summary: 971 layers, 12214448 parameters, 12214432 gradients, 69.1 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost-p2 summary: 971 layers, 18664776 parameters, 18664760 gradients, 103.3 GFLOPs + +# YOLOv8.0-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0-ghost-p2 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 15 (P3/8-small) + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 2], 1, Concat, [1]] # cat backbone P2 + - [-1, 3, C3Ghost, [128]] # 18 (P2/4-xsmall) + + - [-1, 1, GhostConv, [128, 3, 2]] + - [[-1, 15], 1, Concat, [1]] # cat head P3 + - [-1, 3, C3Ghost, [256]] # 21 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 24 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [1024]] # 27 (P5/32-large) + + - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml b/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml new file mode 100644 index 0000000..b35f4cd --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml @@ -0,0 +1,56 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost-p6 summary: 529 layers, 2901100 parameters, 2901084 gradients, 5.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost-p6 summary: 529 layers, 9520008 parameters, 9519992 gradients, 16.4 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost-p6 summary: 789 layers, 18002904 parameters, 18002888 gradients, 34.4 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost-p6 summary: 1049 layers, 21227584 parameters, 21227568 gradients, 55.3 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost-p6 summary: 1049 layers, 33057852 parameters, 33057836 gradients, 85.7 GFLOPs + +# YOLOv8.0-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [768, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0-ghost-p6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C3Ghost, [768]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 20 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 23 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [768]] # 26 (P5/32-large) + + - [-1, 1, GhostConv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C3Ghost, [1024]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v8/yolov8-ghost.yaml b/ultralytics/cfg/models/v8/yolov8-ghost.yaml new file mode 100644 index 0000000..adc1802 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-ghost.yaml @@ -0,0 +1,47 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect +# Employs Ghost convolutions and modules proposed in Huawei's GhostNet in https://arxiv.org/abs/1911.11907v2 + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost summary: 403 layers, 1865316 parameters, 1865300 gradients, 5.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost summary: 403 layers, 5960072 parameters, 5960056 gradients, 16.4 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost summary: 603 layers, 10336312 parameters, 10336296 gradients, 32.7 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost summary: 803 layers, 14277872 parameters, 14277856 gradients, 53.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost summary: 803 layers, 22229308 parameters, 22229292 gradients, 83.3 GFLOPs + +# YOLOv8.0n-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 15 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 18 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-obb.yaml b/ultralytics/cfg/models/v8/yolov8-obb.yaml new file mode 100644 index 0000000..7a7f60c --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-obb.yaml @@ -0,0 +1,46 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8 Oriented Bounding Boxes (OBB) model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, OBB, [nc, 1]] # OBB(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-p2.yaml b/ultralytics/cfg/models/v8/yolov8-p2.yaml index 3e286aa..5392774 100644 --- a/ultralytics/cfg/models/v8/yolov8-p2.yaml +++ b/ultralytics/cfg/models/v8/yolov8-p2.yaml @@ -2,7 +2,7 @@ # YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -14,41 +14,41 @@ scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call # YOLOv8.0 backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [1024, True]] - - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, SPPF, [1024, 5]] # 9 # YOLOv8.0-p2 head head: - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, C2f, [512]] # 12 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 3, C2f, [256]] # 15 (P3/8-small) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 2], 1, Concat, [1]] # cat backbone P2 - - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 2], 1, Concat, [1]] # cat backbone P2 + - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall) - [-1, 1, Conv, [128, 3, 2]] - - [[-1, 15], 1, Concat, [1]] # cat head P3 - - [-1, 3, C2f, [256]] # 21 (P3/8-small) + - [[-1, 15], 1, Concat, [1]] # cat head P3 + - [-1, 3, C2f, [256]] # 21 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 12], 1, Concat, [1]] # cat head P4 - - [-1, 3, C2f, [512]] # 24 (P4/16-medium) + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 24 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - - [[-1, 9], 1, Concat, [1]] # cat head P5 - - [-1, 3, C2f, [1024]] # 27 (P5/32-large) + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 27 (P5/32-large) - - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) + - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-p6.yaml b/ultralytics/cfg/models/v8/yolov8-p6.yaml index 3635ed9..2d6d5f9 100644 --- a/ultralytics/cfg/models/v8/yolov8-p6.yaml +++ b/ultralytics/cfg/models/v8/yolov8-p6.yaml @@ -2,7 +2,7 @@ # YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -14,43 +14,43 @@ scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will ca # YOLOv8.0x6 backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [768, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 - [-1, 3, C2f, [1024, True]] - - [-1, 1, SPPF, [1024, 5]] # 11 + - [-1, 1, SPPF, [1024, 5]] # 11 # YOLOv8.0x6 head head: - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 8], 1, Concat, [1]] # cat backbone P5 - - [-1, 3, C2, [768, False]] # 14 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, C2, [512, False]] # 17 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 17], 1, Concat, [1]] # cat head P4 - - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - - [[-1, 14], 1, Concat, [1]] # cat head P5 - - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) - [-1, 1, Conv, [768, 3, 2]] - - [[-1, 11], 1, Concat, [1]] # cat head P6 - - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) - - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) + - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml b/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml index abf0cfc..60007ac 100644 --- a/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +++ b/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml @@ -2,8 +2,8 @@ # YOLOv8-pose-p6 keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose # Parameters -nc: 1 # number of classes -kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +nc: 1 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -15,43 +15,43 @@ scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will ca # YOLOv8.0x6 backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [768, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 - [-1, 3, C2f, [1024, True]] - - [-1, 1, SPPF, [1024, 5]] # 11 + - [-1, 1, SPPF, [1024, 5]] # 11 # YOLOv8.0x6 head head: - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 8], 1, Concat, [1]] # cat backbone P5 - - [-1, 3, C2, [768, False]] # 14 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, C2, [512, False]] # 17 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 17], 1, Concat, [1]] # cat head P4 - - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - - [[-1, 14], 1, Concat, [1]] # cat head P5 - - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) - [-1, 1, Conv, [768, 3, 2]] - - [[-1, 11], 1, Concat, [1]] # cat head P6 - - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) - - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6) + - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v8/yolov8-pose.yaml b/ultralytics/cfg/models/v8/yolov8-pose.yaml index 9f48e1e..60388ef 100644 --- a/ultralytics/cfg/models/v8/yolov8-pose.yaml +++ b/ultralytics/cfg/models/v8/yolov8-pose.yaml @@ -2,8 +2,8 @@ # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose # Parameters -nc: 1 # number of classes -kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +nc: 1 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -15,33 +15,33 @@ scales: # model compound scaling constants, i.e. 
'model=yolov8n-pose.yaml' will # YOLOv8.0n backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [1024, True]] - - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, SPPF, [1024, 5]] # 9 # YOLOv8.0n head head: - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, C2f, [512]] # 12 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 3, C2f, [256]] # 15 (P3/8-small) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 12], 1, Concat, [1]] # cat head P4 - - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - - [[-1, 9], 1, Concat, [1]] # cat head P5 - - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) - - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) + - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml b/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml index a058106..27b790b 100644 --- a/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +++ b/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml @@ -2,45 +2,45 @@ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' # [depth, width, max_channels] - n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs - s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs - m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs - l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs - x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs # YOLOv8.0n backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [1024, True]] - - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, SPPF, [1024, 5]] # 9 # YOLOv8.0n head head: - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, C2f, [512]] # 12 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 3, C2f, [256]] # 15 (P3/8-small) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 12], 1, Concat, [1]] # cat head P4 - - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - - [[-1, 9], 1, Concat, [1]] # cat head P5 - - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) - - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) + - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml b/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml index 5ac0936..78c0444 100644 --- a/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +++ b/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml @@ -2,7 +2,7 @@ # YOLOv8-seg-p6 instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 
'model=yolov8n-seg-p6.yaml' will call yolov8-seg-p6.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -14,43 +14,43 @@ scales: # model compound scaling constants, i.e. 'model=yolov8n-seg-p6.yaml' wil # YOLOv8.0x6 backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [768, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 - [-1, 3, C2f, [1024, True]] - - [-1, 1, SPPF, [1024, 5]] # 11 + - [-1, 1, SPPF, [1024, 5]] # 11 # YOLOv8.0x6 head head: - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 8], 1, Concat, [1]] # cat backbone P5 - - [-1, 3, C2, [768, False]] # 14 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, C2, [512, False]] # 17 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 17], 1, Concat, [1]] # cat head P4 - - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - - [[-1, 14], 1, Concat, [1]] # cat head P5 - - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) - [-1, 1, Conv, [768, 3, 2]] - - [[-1, 11], 1, Concat, [1]] # cat head P6 - - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) - - [[20, 23, 26, 29], 1, Segment, [nc, 32, 256]] # Pose(P3, P4, P5, P6) + - [[20, 23, 26, 29], 1, Segment, [nc, 32, 256]] # Pose(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v8/yolov8-seg.yaml b/ultralytics/cfg/models/v8/yolov8-seg.yaml index fbb08fc..700b795 100644 --- a/ultralytics/cfg/models/v8/yolov8-seg.yaml +++ b/ultralytics/cfg/models/v8/yolov8-seg.yaml @@ -2,7 +2,7 @@ # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n' # [depth, width, max_channels] n: [0.33, 0.25, 1024] @@ -14,33 +14,33 @@ scales: # model compound scaling constants, i.e. 
'model=yolov8n-seg.yaml' will c # YOLOv8.0n backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [1024, True]] - - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, SPPF, [1024, 5]] # 9 # YOLOv8.0n head head: - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, C2f, [512]] # 12 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 3, C2f, [256]] # 15 (P3/8-small) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 12], 1, Concat, [1]] # cat head P4 - - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - - [[-1, 9], 1, Concat, [1]] # cat head P5 - - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) - - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) + - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-world.yaml b/ultralytics/cfg/models/v8/yolov8-world.yaml new file mode 100644 index 0000000..c21a7f0 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-world.yaml @@ -0,0 +1,48 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8-World object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) + + - [[15, 12, 9], 1, ImagePoolingAttn, [256]] # 16 (P3/8-small) + + - [15, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fAttn, [1024, 512, 16]] # 22 (P5/32-large) + + - [[15, 19, 22], 1, WorldDetect, [nc, 512, False]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-worldv2.yaml b/ultralytics/cfg/models/v8/yolov8-worldv2.yaml new file mode 100644 index 0000000..322b97d --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-worldv2.yaml @@ -0,0 +1,46 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8-World-v2 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) + + - [15, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8.yaml b/ultralytics/cfg/models/v8/yolov8.yaml index 2255450..b328e98 100644 --- a/ultralytics/cfg/models/v8/yolov8.yaml +++ b/ultralytics/cfg/models/v8/yolov8.yaml @@ -2,45 +2,45 @@ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect # Parameters -nc: 80 # number of classes +nc: 80 # number of classes scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' # [depth, width, max_channels] - n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs - s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs - m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs - l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs - x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs # YOLOv8.0n backbone backbone: # [from, repeats, module, args] - - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 - - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 - [-1, 3, C2f, [128, True]] - - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 - [-1, 6, C2f, [256, True]] - - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 6, C2f, [512, True]] - - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 - [-1, 3, C2f, [1024, True]] - - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, SPPF, [1024, 5]] # 9 # YOLOv8.0n head head: - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 3, C2f, [512]] # 12 + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 - - [-1, 1, nn.Upsample, [None, 2, 'nearest']] - - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 3, C2f, [256]] # 15 (P3/8-small) + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) - [-1, 1, Conv, [256, 3, 2]] - - [[-1, 12], 1, Concat, [1]] # cat head P4 - - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - - [[-1, 9], 1, Concat, [1]] # cat head P5 - - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) - - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9c.yaml b/ultralytics/cfg/models/v9/yolov9c.yaml new file mode 100644 index 0000000..66c02d6 --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9c.yaml @@ -0,0 +1,36 @@ +# YOLOv9 + +# parameters +nc: 80 # number of classes + +# gelan backbone +backbone: + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]] # 2 + - [-1, 1, ADown, [256]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]] # 4 + - [-1, 1, ADown, [512]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 6 + - [-1, 1, ADown, [512]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 
1]] # 8 + - [-1, 1, SPPELAN, [512, 256]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]] # 15 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 18 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # DDetect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9e.yaml b/ultralytics/cfg/models/v9/yolov9e.yaml new file mode 100644 index 0000000..8e15a42 --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9e.yaml @@ -0,0 +1,60 @@ +# YOLOv9 + +# parameters +nc: 80 # number of classes + +# gelan backbone +backbone: + - [-1, 1, Silence, []] + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 3 + - [-1, 1, ADown, [256]] # 4-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 5 + - [-1, 1, ADown, [512]] # 6-P4/16 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7 + - [-1, 1, ADown, [1024]] # 8-P5/32 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9 + + - [1, 1, CBLinear, [[64]]] # 10 + - [3, 1, CBLinear, [[64, 128]]] # 11 + - [5, 1, CBLinear, [[64, 128, 256]]] # 12 + - [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13 + - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14 + + - [0, 1, Conv, [64, 3, 2]] # 15-P1/2 + - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16 + - [-1, 1, Conv, [128, 3, 2]] # 17-P2/4 + - [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]] # 18 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 19 + - [-1, 1, ADown, [256]] # 20-P3/8 + - [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]] # 21 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 22 + - [-1, 1, ADown, [512]] # 23-P4/16 + - [[13, 14, -1], 1, CBFuse, [[3, 3]]] # 24 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 25 + - [-1, 1, ADown, [1024]] # 26-P5/32 + - [[14, -1], 1, CBFuse, [[4]]] # 27 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 28 + - [-1, 1, SPPELAN, [512, 256]] # 29 + +# gelan head +head: + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 25], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 32 + + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 22], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]] # 35 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 32], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 38 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 29], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large) + + # detect + - [[35, 38, 41], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/trackers/botsort.yaml b/ultralytics/cfg/trackers/botsort.yaml index cbbf348..0c66dc6 100644 --- a/ultralytics/cfg/trackers/botsort.yaml +++ b/ultralytics/cfg/trackers/botsort.yaml @@ -1,17 +1,17 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT -tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] -track_high_thresh: 0.5 # threshold for the first association -track_low_thresh: 0.1 # threshold for the second association 
-new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks -track_buffer: 30 # buffer to calculate the time when to remove tracks -match_thresh: 0.8 # threshold for matching tracks +tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] +track_high_thresh: 0.5 # threshold for the first association +track_low_thresh: 0.1 # threshold for the second association +new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks +track_buffer: 30 # buffer to calculate the time when to remove tracks +match_thresh: 0.8 # threshold for matching tracks # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) # mot20: False # for tracker evaluation(not used for now) # BoT-SORT settings -gmc_method: sparseOptFlow # method of global motion compensation +gmc_method: sparseOptFlow # method of global motion compensation # ReID model related thresh (not supported yet) proximity_thresh: 0.5 appearance_thresh: 0.25 diff --git a/ultralytics/cfg/trackers/bytetrack.yaml b/ultralytics/cfg/trackers/bytetrack.yaml index 5060f92..29d352c 100644 --- a/ultralytics/cfg/trackers/bytetrack.yaml +++ b/ultralytics/cfg/trackers/bytetrack.yaml @@ -1,11 +1,11 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack -tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack'] -track_high_thresh: 0.5 # threshold for the first association -track_low_thresh: 0.1 # threshold for the second association -new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks -track_buffer: 30 # buffer to calculate the time when to remove tracks -match_thresh: 0.8 # threshold for matching tracks +tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack'] +track_high_thresh: 0.5 # threshold for the first association +track_low_thresh: 0.1 # threshold for the second association +new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks +track_buffer: 30 # buffer to calculate the time when to remove tracks +match_thresh: 0.8 # threshold for matching tracks # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) # mot20: False # for tracker evaluation(not used for now) diff --git a/ultralytics/data/__init__.py b/ultralytics/data/__init__.py index 6fa7e84..9f91ce9 100644 --- a/ultralytics/data/__init__.py +++ b/ultralytics/data/__init__.py @@ -4,5 +4,12 @@ from .base import BaseDataset from .build import build_dataloader, build_yolo_dataset, load_inference_source from .dataset import ClassificationDataset, SemanticDataset, YOLODataset -__all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset', - 'build_dataloader', 'load_inference_source') +__all__ = ( + "BaseDataset", + "ClassificationDataset", + "SemanticDataset", + "YOLODataset", + "build_yolo_dataset", + "build_dataloader", + "load_inference_source", +) diff --git a/ultralytics/data/__pycache__/__init__.cpython-312.pyc b/ultralytics/data/__pycache__/__init__.cpython-312.pyc index a315137..79aaa2a 100644 Binary files a/ultralytics/data/__pycache__/__init__.cpython-312.pyc and b/ultralytics/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/data/__pycache__/__init__.cpython-39.pyc b/ultralytics/data/__pycache__/__init__.cpython-39.pyc index 6b73442..f451229 100644 Binary files 
a/ultralytics/data/__pycache__/__init__.cpython-39.pyc and b/ultralytics/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/data/__pycache__/augment.cpython-312.pyc b/ultralytics/data/__pycache__/augment.cpython-312.pyc index dd25a89..0f50d86 100644 Binary files a/ultralytics/data/__pycache__/augment.cpython-312.pyc and b/ultralytics/data/__pycache__/augment.cpython-312.pyc differ diff --git a/ultralytics/data/__pycache__/augment.cpython-39.pyc b/ultralytics/data/__pycache__/augment.cpython-39.pyc index e00500d..bebc23e 100644 Binary files a/ultralytics/data/__pycache__/augment.cpython-39.pyc and b/ultralytics/data/__pycache__/augment.cpython-39.pyc differ diff --git a/ultralytics/data/__pycache__/base.cpython-312.pyc b/ultralytics/data/__pycache__/base.cpython-312.pyc index 977602f..fcacfb4 100644 Binary files a/ultralytics/data/__pycache__/base.cpython-312.pyc and b/ultralytics/data/__pycache__/base.cpython-312.pyc differ diff --git a/ultralytics/data/__pycache__/base.cpython-39.pyc b/ultralytics/data/__pycache__/base.cpython-39.pyc index 0fe3f01..f403862 100644 Binary files a/ultralytics/data/__pycache__/base.cpython-39.pyc and b/ultralytics/data/__pycache__/base.cpython-39.pyc differ diff --git a/ultralytics/data/__pycache__/build.cpython-312.pyc b/ultralytics/data/__pycache__/build.cpython-312.pyc index 88587dd..030a11d 100644 Binary files a/ultralytics/data/__pycache__/build.cpython-312.pyc and b/ultralytics/data/__pycache__/build.cpython-312.pyc differ diff --git a/ultralytics/data/__pycache__/build.cpython-39.pyc b/ultralytics/data/__pycache__/build.cpython-39.pyc index 188bdf0..a437e16 100644 Binary files a/ultralytics/data/__pycache__/build.cpython-39.pyc and b/ultralytics/data/__pycache__/build.cpython-39.pyc differ diff --git a/ultralytics/data/__pycache__/converter.cpython-312.pyc b/ultralytics/data/__pycache__/converter.cpython-312.pyc index 29f7ca5..1ac161e 100644 Binary files a/ultralytics/data/__pycache__/converter.cpython-312.pyc and b/ultralytics/data/__pycache__/converter.cpython-312.pyc differ diff --git a/ultralytics/data/__pycache__/converter.cpython-39.pyc b/ultralytics/data/__pycache__/converter.cpython-39.pyc index e975df1..ff6cf6d 100644 Binary files a/ultralytics/data/__pycache__/converter.cpython-39.pyc and b/ultralytics/data/__pycache__/converter.cpython-39.pyc differ diff --git a/ultralytics/data/__pycache__/dataset.cpython-312.pyc b/ultralytics/data/__pycache__/dataset.cpython-312.pyc index 74ed7db..c3334d9 100644 Binary files a/ultralytics/data/__pycache__/dataset.cpython-312.pyc and b/ultralytics/data/__pycache__/dataset.cpython-312.pyc differ diff --git a/ultralytics/data/__pycache__/dataset.cpython-39.pyc b/ultralytics/data/__pycache__/dataset.cpython-39.pyc index 4e0dfbd..91bfcba 100644 Binary files a/ultralytics/data/__pycache__/dataset.cpython-39.pyc and b/ultralytics/data/__pycache__/dataset.cpython-39.pyc differ diff --git a/ultralytics/data/__pycache__/loaders.cpython-312.pyc b/ultralytics/data/__pycache__/loaders.cpython-312.pyc index a799acd..9766f1f 100644 Binary files a/ultralytics/data/__pycache__/loaders.cpython-312.pyc and b/ultralytics/data/__pycache__/loaders.cpython-312.pyc differ diff --git a/ultralytics/data/__pycache__/loaders.cpython-39.pyc b/ultralytics/data/__pycache__/loaders.cpython-39.pyc index 1f0e4e7..cabbb34 100644 Binary files a/ultralytics/data/__pycache__/loaders.cpython-39.pyc and b/ultralytics/data/__pycache__/loaders.cpython-39.pyc differ diff --git 
a/ultralytics/data/__pycache__/utils.cpython-312.pyc b/ultralytics/data/__pycache__/utils.cpython-312.pyc index 184140b..d6e9800 100644 Binary files a/ultralytics/data/__pycache__/utils.cpython-312.pyc and b/ultralytics/data/__pycache__/utils.cpython-312.pyc differ diff --git a/ultralytics/data/__pycache__/utils.cpython-39.pyc b/ultralytics/data/__pycache__/utils.cpython-39.pyc index ba1df55..1347543 100644 Binary files a/ultralytics/data/__pycache__/utils.cpython-39.pyc and b/ultralytics/data/__pycache__/utils.cpython-39.pyc differ diff --git a/ultralytics/data/annotator.py b/ultralytics/data/annotator.py index b4e08c7..b5b899c 100644 --- a/ultralytics/data/annotator.py +++ b/ultralytics/data/annotator.py @@ -5,7 +5,7 @@ from pathlib import Path from ultralytics import SAM, YOLO -def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): +def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None): """ Automatically annotates images using a YOLO object detection model and a SAM segmentation model. @@ -29,7 +29,7 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', data = Path(data) if not output_dir: - output_dir = data.parent / f'{data.stem}_auto_annotate_labels' + output_dir = data.parent / f"{data.stem}_auto_annotate_labels" Path(output_dir).mkdir(exist_ok=True, parents=True) det_results = det_model(data, stream=True, device=device) @@ -41,10 +41,10 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device) segments = sam_results[0].masks.xyn # noqa - with open(f'{str(Path(output_dir) / Path(result.path).stem)}.txt', 'w') as f: + with open(f"{Path(output_dir) / Path(result.path).stem}.txt", "w") as f: for i in range(len(segments)): s = segments[i] if len(s) == 0: continue segment = map(str, segments[i].reshape(-1).tolist()) - f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n') + f.write(f"{class_ids[i]} " + " ".join(segment) + "\n") diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index f14b82b..aab3e62 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -13,23 +13,41 @@ from ultralytics.utils import LOGGER, colorstr from ultralytics.utils.checks import check_version from ultralytics.utils.instance import Instances from ultralytics.utils.metrics import bbox_ioa -from ultralytics.utils.ops import segment2box - +from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr +from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13 from .utils import polygons2masks, polygons2masks_overlap +DEFAULT_MEAN = (0.0, 0.0, 0.0) +DEFAULT_STD = (1.0, 1.0, 1.0) +DEFAULT_CROP_FTACTION = 1.0 + # TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic class BaseTransform: + """ + Base class for image transformations. + + This is a generic transformation class that can be extended for specific image processing needs. + The class is designed to be compatible with both classification and semantic segmentation tasks. + + Methods: + __init__: Initializes the BaseTransform object. + apply_image: Applies image transformation to labels. + apply_instances: Applies transformations to object instances in labels. + apply_semantic: Applies semantic segmentation to an image. 
+ __call__: Applies all label transformations to an image, instances, and semantic masks. + """ def __init__(self) -> None: + """Initializes the BaseTransform object.""" pass def apply_image(self, labels): - """Applies image transformation to labels.""" + """Applies image transformations to labels.""" pass def apply_instances(self, labels): - """Applies transformations to input 'labels' and returns object instances.""" + """Applies transformations to object instances in labels.""" pass def apply_semantic(self, labels): @@ -37,13 +55,14 @@ class BaseTransform: pass def __call__(self, labels): - """Applies label transformations to an image, instances and semantic masks.""" + """Applies all label transformations to an image, instances, and semantic masks.""" self.apply_image(labels) self.apply_instances(labels) self.apply_semantic(labels) class Compose: + """Class for composing multiple image transformations.""" def __init__(self, transforms): """Initializes the Compose object with a list of transforms.""" @@ -60,18 +79,23 @@ class Compose: self.transforms.append(transform) def tolist(self): - """Converts list of transforms to a standard Python list.""" + """Converts the list of transforms to a standard Python list.""" return self.transforms def __repr__(self): - """Return string representation of object.""" + """Returns a string representation of the object.""" return f"{self.__class__.__name__}({', '.join([f'{t}' for t in self.transforms])})" class BaseMixTransform: - """This implementation is from mmyolo.""" + """ + Class for base mix (MixUp/Mosaic) transformations. + + This implementation is from mmyolo. + """ def __init__(self, dataset, pre_transform=None, p=0.0) -> None: + """Initializes the BaseMixTransform object with dataset, pre_transform, and probability.""" self.dataset = dataset self.pre_transform = pre_transform self.p = p @@ -92,11 +116,11 @@ class BaseMixTransform: if self.pre_transform is not None: for i, data in enumerate(mix_labels): mix_labels[i] = self.pre_transform(data) - labels['mix_labels'] = mix_labels + labels["mix_labels"] = mix_labels # Mosaic or MixUp labels = self._mix_transform(labels) - labels.pop('mix_labels', None) + labels.pop("mix_labels", None) return labels def _mix_transform(self, labels): @@ -124,8 +148,8 @@ class Mosaic(BaseMixTransform): def __init__(self, dataset, imgsz=640, p=1.0, n=4): """Initializes the object with a dataset, image size, probability, and border.""" - assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.' - assert n in (4, 9), 'grid must be equal to 4 or 9.' + assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}." + assert n in (4, 9), "grid must be equal to 4 or 9." super().__init__(dataset=dataset, p=p) self.dataset = dataset self.imgsz = imgsz @@ -141,9 +165,45 @@ class Mosaic(BaseMixTransform): def _mix_transform(self, labels): """Apply mixup transformation to the input image and labels.""" - assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.' - assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.' - return self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels) + assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive." + assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment." 
+ return ( + self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels) + ) # This code is modified for mosaic3 method. + + def _mosaic3(self, labels): + """Create a 1x3 image mosaic.""" + mosaic_labels = [] + s = self.imgsz + for i in range(3): + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] + # Load image + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") + + # Place img in img3 + if i == 0: # center + img3 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 3 tiles + h0, w0 = h, w + c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates + elif i == 1: # right + c = s + w0, s, s + w0 + w, s + h + elif i == 2: # left + c = s - w, s + h0 - h, s, s + h0 + + padw, padh = c[:2] + x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords + + img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax] + # hp, wp = h, w # height, width previous for next iteration + + # Labels assuming imgsz*2 mosaic size + labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1]) + mosaic_labels.append(labels_patch) + final_labels = self._cat_labels(mosaic_labels) + + final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] + return final_labels def _mosaic4(self, labels): """Create a 2x2 image mosaic.""" @@ -151,10 +211,10 @@ class Mosaic(BaseMixTransform): s = self.imgsz yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y for i in range(4): - labels_patch = labels if i == 0 else labels['mix_labels'][i - 1] + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] # Load image - img = labels_patch['img'] - h, w = labels_patch.pop('resized_shape') + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") # Place img in img4 if i == 0: # top left @@ -178,7 +238,7 @@ class Mosaic(BaseMixTransform): labels_patch = self._update_labels(labels_patch, padw, padh) mosaic_labels.append(labels_patch) final_labels = self._cat_labels(mosaic_labels) - final_labels['img'] = img4 + final_labels["img"] = img4 return final_labels def _mosaic9(self, labels): @@ -187,10 +247,10 @@ class Mosaic(BaseMixTransform): s = self.imgsz hp, wp = -1, -1 # height, width previous for i in range(9): - labels_patch = labels if i == 0 else labels['mix_labels'][i - 1] + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] # Load image - img = labels_patch['img'] - h, w = labels_patch.pop('resized_shape') + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") # Place img in img9 if i == 0: # center @@ -218,7 +278,7 @@ class Mosaic(BaseMixTransform): x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords # Image - img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax] + img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax] hp, wp = h, w # height, width previous for next iteration # Labels assuming imgsz*2 mosaic size @@ -226,16 +286,16 @@ class Mosaic(BaseMixTransform): mosaic_labels.append(labels_patch) final_labels = self._cat_labels(mosaic_labels) - final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]] + final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] return final_labels @staticmethod def _update_labels(labels, padw, padh): """Update labels.""" - nh, nw = labels['img'].shape[:2] - 
labels['instances'].convert_bbox(format='xyxy') - labels['instances'].denormalize(nw, nh) - labels['instances'].add_padding(padw, padh) + nh, nw = labels["img"].shape[:2] + labels["instances"].convert_bbox(format="xyxy") + labels["instances"].denormalize(nw, nh) + labels["instances"].add_padding(padw, padh) return labels def _cat_labels(self, mosaic_labels): @@ -246,24 +306,28 @@ class Mosaic(BaseMixTransform): instances = [] imgsz = self.imgsz * 2 # mosaic imgsz for labels in mosaic_labels: - cls.append(labels['cls']) - instances.append(labels['instances']) + cls.append(labels["cls"]) + instances.append(labels["instances"]) + # Final labels final_labels = { - 'im_file': mosaic_labels[0]['im_file'], - 'ori_shape': mosaic_labels[0]['ori_shape'], - 'resized_shape': (imgsz, imgsz), - 'cls': np.concatenate(cls, 0), - 'instances': Instances.concatenate(instances, axis=0), - 'mosaic_border': self.border} # final_labels - final_labels['instances'].clip(imgsz, imgsz) - good = final_labels['instances'].remove_zero_area_boxes() - final_labels['cls'] = final_labels['cls'][good] + "im_file": mosaic_labels[0]["im_file"], + "ori_shape": mosaic_labels[0]["ori_shape"], + "resized_shape": (imgsz, imgsz), + "cls": np.concatenate(cls, 0), + "instances": Instances.concatenate(instances, axis=0), + "mosaic_border": self.border, + } + final_labels["instances"].clip(imgsz, imgsz) + good = final_labels["instances"].remove_zero_area_boxes() + final_labels["cls"] = final_labels["cls"][good] return final_labels class MixUp(BaseMixTransform): + """Class for applying MixUp augmentation to the dataset.""" def __init__(self, dataset, pre_transform=None, p=0.0) -> None: + """Initializes MixUp object with dataset, pre_transform, and probability of applying MixUp.""" super().__init__(dataset=dataset, pre_transform=pre_transform, p=p) def get_indexes(self): @@ -271,36 +335,67 @@ class MixUp(BaseMixTransform): return random.randint(0, len(self.dataset) - 1) def _mix_transform(self, labels): - """Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf.""" + """Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf.""" r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 - labels2 = labels['mix_labels'][0] - labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8) - labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0) - labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0) + labels2 = labels["mix_labels"][0] + labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8) + labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0) + labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0) return labels class RandomPerspective: + """ + Implements random perspective and affine transformations on images and corresponding bounding boxes, segments, and + keypoints. These transformations include rotation, translation, scaling, and shearing. The class also offers the + option to apply these transformations conditionally with a specified probability. + + Attributes: + degrees (float): Degree range for random rotations. + translate (float): Fraction of total width and height for random translation. + scale (float): Scaling factor interval, e.g., a scale factor of 0.1 allows a resize between 90%-110%. + shear (float): Shear intensity (angle in degrees). + perspective (float): Perspective distortion factor. + border (tuple): Tuple specifying mosaic border. 
+ pre_transform (callable): A function/transform to apply to the image before starting the random transformation. + + Methods: + affine_transform(img, border): Applies a series of affine transformations to the image. + apply_bboxes(bboxes, M): Transforms bounding boxes using the calculated affine matrix. + apply_segments(segments, M): Transforms segments and generates new bounding boxes. + apply_keypoints(keypoints, M): Transforms keypoints. + __call__(labels): Main method to apply transformations to both images and their corresponding annotations. + box_candidates(box1, box2): Filters out bounding boxes that don't meet certain criteria post-transformation. + """ + + def __init__( + self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None + ): + """Initializes RandomPerspective object with transformation parameters.""" - def __init__(self, - degrees=0.0, - translate=0.1, - scale=0.5, - shear=0.0, - perspective=0.0, - border=(0, 0), - pre_transform=None): self.degrees = degrees self.translate = translate self.scale = scale self.shear = shear self.perspective = perspective - # Mosaic border - self.border = border + self.border = border # mosaic border self.pre_transform = pre_transform def affine_transform(self, img, border): - """Center.""" + """ + Applies a sequence of affine transformations centered around the image center. + + Args: + img (ndarray): Input image. + border (tuple): Border dimensions. + + Returns: + img (ndarray): Transformed image. + M (ndarray): Transformation matrix. + s (float): Scale factor. + """ + + # Center C = np.eye(3, dtype=np.float32) C[0, 2] = -img.shape[1] / 2 # x translation (pixels) @@ -387,6 +482,8 @@ class RandomPerspective: xy = xy[:, :2] / xy[:, 2:3] segments = xy.reshape(n, -1, 2) bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0) + segments[..., 0] = segments[..., 0].clip(bboxes[:, 0:1], bboxes[:, 2:3]) + segments[..., 1] = segments[..., 1].clip(bboxes[:, 1:2], bboxes[:, 3:4]) return bboxes, segments def apply_keypoints(self, keypoints, M): @@ -419,21 +516,21 @@ class RandomPerspective: Args: labels (dict): a dict of `bboxes`, `segments`, `keypoints`. 
""" - if self.pre_transform and 'mosaic_border' not in labels: + if self.pre_transform and "mosaic_border" not in labels: labels = self.pre_transform(labels) - labels.pop('ratio_pad', None) # do not need ratio pad + labels.pop("ratio_pad", None) # do not need ratio pad - img = labels['img'] - cls = labels['cls'] - instances = labels.pop('instances') + img = labels["img"] + cls = labels["cls"] + instances = labels.pop("instances") # Make sure the coord formats are right - instances.convert_bbox(format='xyxy') + instances.convert_bbox(format="xyxy") instances.denormalize(*img.shape[:2][::-1]) - border = labels.pop('mosaic_border', self.border) + border = labels.pop("mosaic_border", self.border) self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h # M is affine matrix - # scale for func:`box_candidates` + # Scale for func:`box_candidates` img, M, scale = self.affine_transform(img, border) bboxes = self.apply_bboxes(instances.bboxes, M) @@ -446,24 +543,38 @@ class RandomPerspective: if keypoints is not None: keypoints = self.apply_keypoints(keypoints, M) - new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False) + new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False) # Clip new_instances.clip(*self.size) # Filter instances instances.scale(scale_w=scale, scale_h=scale, bbox_only=True) # Make the bboxes have the same scale with new_bboxes - i = self.box_candidates(box1=instances.bboxes.T, - box2=new_instances.bboxes.T, - area_thr=0.01 if len(segments) else 0.10) - labels['instances'] = new_instances[i] - labels['cls'] = cls[i] - labels['img'] = img - labels['resized_shape'] = img.shape[:2] + i = self.box_candidates( + box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10 + ) + labels["instances"] = new_instances[i] + labels["cls"] = cls[i] + labels["img"] = img + labels["resized_shape"] = img.shape[:2] return labels - def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) - # Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio + def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): + """ + Compute box candidates based on a set of thresholds. This method compares the characteristics of the boxes + before and after augmentation to decide whether a box is a candidate for further processing. + + Args: + box1 (numpy.ndarray): The 4,n bounding box before augmentation, represented as [x1, y1, x2, y2]. + box2 (numpy.ndarray): The 4,n bounding box after augmentation, represented as [x1, y1, x2, y2]. + wh_thr (float, optional): The width and height threshold in pixels. Default is 2. + ar_thr (float, optional): The aspect ratio threshold. Default is 100. + area_thr (float, optional): The area ratio threshold. Default is 0.1. + eps (float, optional): A small epsilon value to prevent division by zero. Default is 1e-16. + + Returns: + (numpy.ndarray): A boolean array indicating which boxes are candidates based on the given thresholds. + """ w1, h1 = box1[2] - box1[0], box1[3] - box1[1] w2, h2 = box2[2] - box2[0], box2[3] - box2[1] ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio @@ -471,15 +582,33 @@ class RandomPerspective: class RandomHSV: + """ + This class is responsible for performing random adjustments to the Hue, Saturation, and Value (HSV) channels of an + image. 
+ + The adjustments are random but within limits set by hgain, sgain, and vgain. + """ def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None: + """ + Initialize RandomHSV class with gains for each HSV channel. + + Args: + hgain (float, optional): Maximum variation for hue. Default is 0.5. + sgain (float, optional): Maximum variation for saturation. Default is 0.5. + vgain (float, optional): Maximum variation for value. Default is 0.5. + """ self.hgain = hgain self.sgain = sgain self.vgain = vgain def __call__(self, labels): - """Applies image HSV augmentation""" - img = labels['img'] + """ + Applies random HSV augmentation to an image within the predefined limits. + + The modified image replaces the original image in the input 'labels' dict. + """ + img = labels["img"] if self.hgain or self.sgain or self.vgain: r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) @@ -496,10 +625,23 @@ class RandomHSV: class RandomFlip: - """Applies random horizontal or vertical flip to an image with a given probability.""" + """ + Applies a random horizontal or vertical flip to an image with a given probability. - def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None: - assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}' + Also updates any instances (bounding boxes, keypoints, etc.) accordingly. + """ + + def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None: + """ + Initializes the RandomFlip class with probability and direction. + + Args: + p (float, optional): The probability of applying the flip. Must be between 0 and 1. Default is 0.5. + direction (str, optional): The direction to apply the flip. Must be 'horizontal' or 'vertical'. + Default is 'horizontal'. + flip_idx (array-like, optional): Index mapping for flipping keypoints, if any. + """ + assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}" assert 0 <= p <= 1.0 self.p = p @@ -507,26 +649,35 @@ class RandomFlip: self.flip_idx = flip_idx def __call__(self, labels): - """Resize image and padding for detection, instance segmentation, pose.""" - img = labels['img'] - instances = labels.pop('instances') - instances.convert_bbox(format='xywh') + """ + Applies random flip to an image and updates any instances like bounding boxes or keypoints accordingly. + + Args: + labels (dict): A dictionary containing the keys 'img' and 'instances'. 'img' is the image to be flipped. + 'instances' is an object containing bounding boxes and optionally keypoints. + + Returns: + (dict): The same dict with the flipped image and updated instances under the 'img' and 'instances' keys. 
+ """ + img = labels["img"] + instances = labels.pop("instances") + instances.convert_bbox(format="xywh") h, w = img.shape[:2] h = 1 if instances.normalized else h w = 1 if instances.normalized else w # Flip up-down - if self.direction == 'vertical' and random.random() < self.p: + if self.direction == "vertical" and random.random() < self.p: img = np.flipud(img) instances.flipud(h) - if self.direction == 'horizontal' and random.random() < self.p: + if self.direction == "horizontal" and random.random() < self.p: img = np.fliplr(img) instances.fliplr(w) # For keypoints if self.flip_idx is not None and instances.keypoints is not None: instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :]) - labels['img'] = np.ascontiguousarray(img) - labels['instances'] = instances + labels["img"] = np.ascontiguousarray(img) + labels["instances"] = instances return labels @@ -546,9 +697,9 @@ class LetterBox: """Return updated labels and image with added border.""" if labels is None: labels = {} - img = labels.get('img') if image is None else image + img = labels.get("img") if image is None else image shape = img.shape[:2] # current shape [height, width] - new_shape = labels.pop('rect_shape', self.new_shape) + new_shape = labels.pop("rect_shape", self.new_shape) if isinstance(new_shape, int): new_shape = (new_shape, new_shape) @@ -571,45 +722,72 @@ class LetterBox: if self.center: dw /= 2 # divide padding into 2 sides dh /= 2 - if labels.get('ratio_pad'): - labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation if shape[::-1] != new_unpad: # resize img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1)) left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1)) - img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, - value=(114, 114, 114)) # add border + img = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) + ) # add border + if labels.get("ratio_pad"): + labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation if len(labels): labels = self._update_labels(labels, ratio, dw, dh) - labels['img'] = img - labels['resized_shape'] = new_shape + labels["img"] = img + labels["resized_shape"] = new_shape return labels else: return img def _update_labels(self, labels, ratio, padw, padh): """Update labels.""" - labels['instances'].convert_bbox(format='xyxy') - labels['instances'].denormalize(*labels['img'].shape[:2][::-1]) - labels['instances'].scale(*ratio) - labels['instances'].add_padding(padw, padh) + labels["instances"].convert_bbox(format="xyxy") + labels["instances"].denormalize(*labels["img"].shape[:2][::-1]) + labels["instances"].scale(*ratio) + labels["instances"].add_padding(padw, padh) return labels class CopyPaste: + """ + Implements the Copy-Paste augmentation as described in the paper https://arxiv.org/abs/2012.07177. This class is + responsible for applying the Copy-Paste augmentation on images and their corresponding instances. + """ def __init__(self, p=0.5) -> None: + """ + Initializes the CopyPaste class with a given probability. + + Args: + p (float, optional): The probability of applying the Copy-Paste augmentation. Must be between 0 and 1. + Default is 0.5. 
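The letterbox step in `LetterBox.__call__` reduces to resize-then-pad. A minimal functional sketch, assuming a single image, centered padding, and no rectangular-batch `rect_shape`:

```python
import cv2
import numpy as np

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
    """Resize while keeping aspect ratio, then pad to new_shape; returns image, scale ratio and (dw, dh) padding."""
    h, w = img.shape[:2]
    r = min(new_shape[0] / h, new_shape[1] / w)       # scale so the image fits inside new_shape
    new_unpad = int(round(w * r)), int(round(h * r))  # (width, height) after resize
    dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # split padding onto both sides
    if (w, h) != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return img, r, (dw, dh)

out, r, pad = letterbox(np.zeros((480, 640, 3), dtype=np.uint8))
print(out.shape, r, pad)  # (640, 640, 3) 1.0 (0.0, 80.0)
```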
+ """ self.p = p def __call__(self, labels): - """Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy).""" - im = labels['img'] - cls = labels['cls'] + """ + Applies the Copy-Paste augmentation to the given image and instances. + + Args: + labels (dict): A dictionary containing: + - 'img': The image to augment. + - 'cls': Class labels associated with the instances. + - 'instances': Object containing bounding boxes, and optionally, keypoints and segments. + + Returns: + (dict): Dict with augmented image and updated instances under the 'img', 'cls', and 'instances' keys. + + Notes: + 1. Instances are expected to have 'segments' as one of their attributes for this augmentation to work. + 2. This method modifies the input dictionary 'labels' in place. + """ + im = labels["img"] + cls = labels["cls"] h, w = im.shape[:2] - instances = labels.pop('instances') - instances.convert_bbox(format='xyxy') + instances = labels.pop("instances") + instances.convert_bbox(format="xyxy") instances.denormalize(w, h) if self.p and len(instances.segments): n = len(instances) @@ -632,27 +810,32 @@ class CopyPaste: i = cv2.flip(im_new, 1).astype(bool) im[i] = result[i] - labels['img'] = im - labels['cls'] = cls - labels['instances'] = instances + labels["img"] = im + labels["cls"] = cls + labels["instances"] = instances return labels class Albumentations: - """Albumentations transformations. Optional, uninstall package to disable. - Applies Blur, Median Blur, convert to grayscale, Contrast Limited Adaptive Histogram Equalization, - random change of brightness and contrast, RandomGamma and lowering of image quality by compression.""" + """ + Albumentations transformations. + + Optional, uninstall package to disable. Applies Blur, Median Blur, convert to grayscale, Contrast Limited Adaptive + Histogram Equalization, random change of brightness and contrast, RandomGamma and lowering of image quality by + compression. 
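The core of `CopyPaste.__call__` is rasterizing an instance polygon, mirroring the image, and pasting the mirrored pixels back wherever the mirrored mask is set. A simplified single-instance sketch (the real class also samples only a fraction of instances via `p` and appends the flipped boxes/segments to the labels):

```python
import cv2
import numpy as np

def copy_paste_flipped(im, segment):
    """Illustrative copy-paste: mirror one polygon's pixels onto the image."""
    im_new = np.zeros(im.shape[:2], dtype=np.uint8)
    cv2.drawContours(im_new, [segment.astype(np.int32)], -1, 255, cv2.FILLED)  # rasterize the instance mask
    result = cv2.flip(im, 1)              # mirrored source pixels
    i = cv2.flip(im_new, 1).astype(bool)  # mirrored mask selects where to paste
    out = im.copy()
    out[i] = result[i]
    return out

im = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
poly = np.array([[10, 10], [40, 10], [40, 40], [10, 40]])  # a square instance on the left side
print(copy_paste_flipped(im, poly).shape)  # (100, 100, 3)
```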
+ """ def __init__(self, p=1.0): """Initialize the transform object for YOLO bbox formatted params.""" self.p = p self.transform = None - prefix = colorstr('albumentations: ') + prefix = colorstr("albumentations: ") try: import albumentations as A - check_version(A.__version__, '1.0.3', hard=True) # version requirement + check_version(A.__version__, "1.0.3", hard=True) # version requirement + # Transforms T = [ A.Blur(p=0.01), A.MedianBlur(p=0.01), @@ -660,59 +843,81 @@ class Albumentations: A.CLAHE(p=0.01), A.RandomBrightnessContrast(p=0.0), A.RandomGamma(p=0.0), - A.ImageCompression(quality_lower=75, p=0.0)] # transforms - self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])) + A.ImageCompression(quality_lower=75, p=0.0), + ] + self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"])) - LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p)) + LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p)) except ImportError: # package not installed, skip pass except Exception as e: - LOGGER.info(f'{prefix}{e}') + LOGGER.info(f"{prefix}{e}") def __call__(self, labels): """Generates object detections and returns a dictionary with detection results.""" - im = labels['img'] - cls = labels['cls'] + im = labels["img"] + cls = labels["cls"] if len(cls): - labels['instances'].convert_bbox('xywh') - labels['instances'].normalize(*im.shape[:2][::-1]) - bboxes = labels['instances'].bboxes + labels["instances"].convert_bbox("xywh") + labels["instances"].normalize(*im.shape[:2][::-1]) + bboxes = labels["instances"].bboxes # TODO: add supports of segments and keypoints if self.transform and random.random() < self.p: new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed - if len(new['class_labels']) > 0: # skip update if no bbox in new im - labels['img'] = new['image'] - labels['cls'] = np.array(new['class_labels']) - bboxes = np.array(new['bboxes'], dtype=np.float32) - labels['instances'].update(bboxes=bboxes) + if len(new["class_labels"]) > 0: # skip update if no bbox in new im + labels["img"] = new["image"] + labels["cls"] = np.array(new["class_labels"]) + bboxes = np.array(new["bboxes"], dtype=np.float32) + labels["instances"].update(bboxes=bboxes) return labels # TODO: technically this is not an augmentation, maybe we should put this to another files class Format: + """ + Formats image annotations for object detection, instance segmentation, and pose estimation tasks. The class + standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader. - def __init__(self, - bbox_format='xywh', - normalize=True, - return_mask=False, - return_keypoint=False, - mask_ratio=4, - mask_overlap=True, - batch_idx=True): + Attributes: + bbox_format (str): Format for bounding boxes. Default is 'xywh'. + normalize (bool): Whether to normalize bounding boxes. Default is True. + return_mask (bool): Return instance masks for segmentation. Default is False. + return_keypoint (bool): Return keypoints for pose estimation. Default is False. + mask_ratio (int): Downsample ratio for masks. Default is 4. + mask_overlap (bool): Whether to overlap masks. Default is True. + batch_idx (bool): Keep batch indexes. Default is True. + bgr (float): The probability to return BGR images. Default is 0.0. 
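Used directly, the optional Albumentations pipeline looks like the following usage sketch; it assumes `albumentations` is installed and keeps only a subset of the transforms listed above.

```python
import numpy as np

try:
    import albumentations as A

    transform = A.Compose(
        [A.Blur(p=0.01), A.MedianBlur(p=0.01), A.CLAHE(p=0.01)],
        bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]),
    )
    im = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
    new = transform(image=im, bboxes=[[0.5, 0.5, 0.2, 0.3]], class_labels=[0])  # normalized xywh + class ids
    print(new["image"].shape, new["bboxes"], new["class_labels"])
except ImportError:
    pass  # package not installed, the augmentation is simply skipped (as in the class above)
```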
+ """ + + def __init__( + self, + bbox_format="xywh", + normalize=True, + return_mask=False, + return_keypoint=False, + return_obb=False, + mask_ratio=4, + mask_overlap=True, + batch_idx=True, + bgr=0.0, + ): + """Initializes the Format class with given parameters.""" self.bbox_format = bbox_format self.normalize = normalize self.return_mask = return_mask # set False when training detection only self.return_keypoint = return_keypoint + self.return_obb = return_obb self.mask_ratio = mask_ratio self.mask_overlap = mask_overlap self.batch_idx = batch_idx # keep the batch indexes + self.bgr = bgr def __call__(self, labels): """Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'.""" - img = labels.pop('img') + img = labels.pop("img") h, w = img.shape[:2] - cls = labels.pop('cls') - instances = labels.pop('instances') + cls = labels.pop("cls") + instances = labels.pop("instances") instances.convert_bbox(format=self.bbox_format) instances.denormalize(w, h) nl = len(instances) @@ -722,31 +927,37 @@ class Format: masks, instances, cls = self._format_segments(instances, cls, w, h) masks = torch.from_numpy(masks) else: - masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, - img.shape[1] // self.mask_ratio) - labels['masks'] = masks + masks = torch.zeros( + 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio + ) + labels["masks"] = masks if self.normalize: instances.normalize(w, h) - labels['img'] = self._format_img(img) - labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl) - labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) + labels["img"] = self._format_img(img) + labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl) + labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) if self.return_keypoint: - labels['keypoints'] = torch.from_numpy(instances.keypoints) + labels["keypoints"] = torch.from_numpy(instances.keypoints) + if self.return_obb: + labels["bboxes"] = ( + xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5)) + ) # Then we can use collate_fn if self.batch_idx: - labels['batch_idx'] = torch.zeros(nl) + labels["batch_idx"] = torch.zeros(nl) return labels def _format_img(self, img): - """Format the image for YOLOv5 from Numpy array to PyTorch tensor.""" + """Format the image for YOLO from Numpy array to PyTorch tensor.""" if len(img.shape) < 3: img = np.expand_dims(img, -1) - img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1]) + img = img.transpose(2, 0, 1) + img = np.ascontiguousarray(img[::-1] if random.uniform(0, 1) > self.bgr else img) img = torch.from_numpy(img) return img def _format_segments(self, instances, cls, w, h): - """convert polygon points to bitmap.""" + """Convert polygon points to bitmap.""" segments = instances.segments if self.mask_overlap: masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio) @@ -761,140 +972,281 @@ class Format: def v8_transforms(dataset, imgsz, hyp, stretch=False): """Convert images to a size suitable for YOLOv8 training.""" - pre_transform = Compose([ - Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), - CopyPaste(p=hyp.copy_paste), - RandomPerspective( - degrees=hyp.degrees, - translate=hyp.translate, - scale=hyp.scale, - shear=hyp.shear, - perspective=hyp.perspective, - pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), - )]) - flip_idx = 
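The `_format_img` helper boils down to an HWC-to-CHW transpose plus a probabilistic BGR/RGB channel flip. A standalone sketch (illustrative, not the exact method):

```python
import numpy as np
import torch

def format_img(img, bgr=0.0):
    """HWC (BGR) numpy image -> CHW torch tensor; channels are reversed to RGB unless the BGR branch fires."""
    if img.ndim < 3:
        img = np.expand_dims(img, -1)
    img = img.transpose(2, 0, 1)            # HWC -> CHW
    keep_bgr = np.random.rand() <= bgr      # with probability `bgr`, keep the original BGR order
    img = np.ascontiguousarray(img if keep_bgr else img[::-1])
    return torch.from_numpy(img)

t = format_img(np.zeros((480, 640, 3), dtype=np.uint8))
print(t.shape)  # torch.Size([3, 480, 640])
```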
dataset.data.get('flip_idx', []) # for keypoints augmentation + pre_transform = Compose( + [ + Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), + CopyPaste(p=hyp.copy_paste), + RandomPerspective( + degrees=hyp.degrees, + translate=hyp.translate, + scale=hyp.scale, + shear=hyp.shear, + perspective=hyp.perspective, + pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), + ), + ] + ) + flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation if dataset.use_keypoints: - kpt_shape = dataset.data.get('kpt_shape', None) + kpt_shape = dataset.data.get("kpt_shape", None) if len(flip_idx) == 0 and hyp.fliplr > 0.0: hyp.fliplr = 0.0 LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'") elif flip_idx and (len(flip_idx) != kpt_shape[0]): - raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}') + raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}") - return Compose([ - pre_transform, - MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), - Albumentations(p=1.0), - RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), - RandomFlip(direction='vertical', p=hyp.flipud), - RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms + return Compose( + [ + pre_transform, + MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), + Albumentations(p=1.0), + RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), + RandomFlip(direction="vertical", p=hyp.flipud), + RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx), + ] + ) # transforms # Classification augmentations ----------------------------------------------------------------------------------------- -def classify_transforms(size=224, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): # IMAGENET_MEAN, IMAGENET_STD +def classify_transforms( + size=224, + mean=DEFAULT_MEAN, + std=DEFAULT_STD, + interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + crop_fraction: float = DEFAULT_CROP_FTACTION, +): + """ + Classification transforms for evaluation/inference. Inspired by timm/data/transforms_factory.py. + + Args: + size (int): image size + mean (tuple): mean values of RGB channels + std (tuple): std values of RGB channels + interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR. + crop_fraction (float): fraction of image to crop. default is 1.0. 
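The `flip_idx` array referenced above pairs left/right keypoints so a horizontal flip stays anatomically consistent. The mapping below is the conventional 17-point COCO pose ordering from a dataset YAML, shown purely as an illustration; mirroring the x coordinates themselves is handled by `Instances.fliplr`.

```python
import numpy as np

# Illustrative only: typical left/right index pairing for 17 COCO keypoints (nose, eyes, ears, shoulders, ...)
flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

keypoints = np.random.rand(1, 17, 3)                       # (n_instances, n_keypoints, x/y/visibility)
flipped = np.ascontiguousarray(keypoints[:, flip_idx, :])  # swap left/right joints after a horizontal flip
print(flipped.shape)  # (1, 17, 3)
```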
+ + Returns: + (T.Compose): torchvision transforms + """ + + if isinstance(size, (tuple, list)): + assert len(size) == 2 + scale_size = tuple(math.floor(x / crop_fraction) for x in size) + else: + scale_size = math.floor(size / crop_fraction) + scale_size = (scale_size, scale_size) + + # aspect ratio is preserved, crops center within image, no borders are added, image is lost + if scale_size[0] == scale_size[1]: + # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) + tfl = [T.Resize(scale_size[0], interpolation=interpolation)] + else: + # resize shortest edge to matching target dim for non-square target + tfl = [T.Resize(scale_size)] + tfl += [T.CenterCrop(size)] + + tfl += [ + T.ToTensor(), + T.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std), + ), + ] + + return T.Compose(tfl) + + +# Classification augmentations train --------------------------------------------------------------------------------------- +def classify_augmentations( + size=224, + mean=DEFAULT_MEAN, + std=DEFAULT_STD, + scale=None, + ratio=None, + hflip=0.5, + vflip=0.0, + auto_augment=None, + hsv_h=0.015, # image HSV-Hue augmentation (fraction) + hsv_s=0.4, # image HSV-Saturation augmentation (fraction) + hsv_v=0.4, # image HSV-Value augmentation (fraction) + force_color_jitter=False, + erasing=0.0, + interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, +): + """ + Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py. + + Args: + size (int): image size + scale (tuple): scale range of the image. default is (0.08, 1.0) + ratio (tuple): aspect ratio range of the image. default is (3./4., 4./3.) + mean (tuple): mean values of RGB channels + std (tuple): std values of RGB channels + hflip (float): probability of horizontal flip + vflip (float): probability of vertical flip + auto_augment (str): auto augmentation policy. can be 'randaugment', 'augmix', 'autoaugment' or None. + hsv_h (float): image HSV-Hue augmentation (fraction) + hsv_s (float): image HSV-Saturation augmentation (fraction) + hsv_v (float): image HSV-Value augmentation (fraction) + force_color_jitter (bool): force to apply color jitter even if auto augment is enabled + erasing (float): probability of random erasing + interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR. 
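A hedged torchvision equivalent of the evaluation pipeline above, using ImageNet statistics and a timm-style `crop_fraction=0.875` purely for illustration (the library's own defaults may differ):

```python
import math
import torch
import torchvision.transforms as T
from PIL import Image

size, crop_fraction = 224, 0.875
scale_size = math.floor(size / crop_fraction)  # 256: resize the short edge, then crop the centered 224x224

eval_tf = T.Compose(
    [
        T.Resize(scale_size, interpolation=T.InterpolationMode.BILINEAR),
        T.CenterCrop(size),
        T.ToTensor(),
        T.Normalize(mean=torch.tensor((0.485, 0.456, 0.406)), std=torch.tensor((0.229, 0.224, 0.225))),
    ]
)

print(eval_tf(Image.new("RGB", (320, 240))).shape)  # torch.Size([3, 224, 224])
```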
+ + Returns: + (T.Compose): torchvision transforms + """ # Transforms to apply if albumentations not installed if not isinstance(size, int): - raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)') - if any(mean) or any(std): - return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(mean, std, inplace=True)]) - else: - return T.Compose([CenterCrop(size), ToTensor()]) + raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range + primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] + if hflip > 0.0: + primary_tfl += [T.RandomHorizontalFlip(p=hflip)] + if vflip > 0.0: + primary_tfl += [T.RandomVerticalFlip(p=vflip)] + secondary_tfl = [] + disable_color_jitter = False + if auto_augment: + assert isinstance(auto_augment, str) + # color jitter is typically disabled if AA/RA on, + # this allows override without breaking old hparm cfgs + disable_color_jitter = not force_color_jitter -def hsv2colorjitter(h, s, v): - """Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)""" - return v, v, s, h - - -def classify_albumentations( - augment=True, - size=224, - scale=(0.08, 1.0), - hflip=0.5, - vflip=0.0, - hsv_h=0.015, # image HSV-Hue augmentation (fraction) - hsv_s=0.7, # image HSV-Saturation augmentation (fraction) - hsv_v=0.4, # image HSV-Value augmentation (fraction) - mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN - std=(1.0, 1.0, 1.0), # IMAGENET_STD - auto_aug=False, -): - """YOLOv8 classification Albumentations (optional, only used if package is installed).""" - prefix = colorstr('albumentations: ') - try: - import albumentations as A - from albumentations.pytorch import ToTensorV2 - - check_version(A.__version__, '1.0.3', hard=True) # version requirement - if augment: # Resize and crop - T = [A.RandomResizedCrop(height=size, width=size, scale=scale)] - if auto_aug: - # TODO: implement AugMix, AutoAug & RandAug in albumentations - LOGGER.info(f'{prefix}auto augmentations are currently not supported') + if auto_augment == "randaugment": + if TORCHVISION_0_11: + secondary_tfl += [T.RandAugment(interpolation=interpolation)] else: - if hflip > 0: - T += [A.HorizontalFlip(p=hflip)] - if vflip > 0: - T += [A.VerticalFlip(p=vflip)] - if any((hsv_h, hsv_s, hsv_v)): - T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue - else: # Use fixed crop for eval set (reproducibility) - T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)] - T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor - LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p)) - return A.Compose(T) + LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.') - except ImportError: # package not installed, skip - pass - except Exception as e: - LOGGER.info(f'{prefix}{e}') + elif auto_augment == "augmix": + if TORCHVISION_0_13: + secondary_tfl += [T.AugMix(interpolation=interpolation)] + else: + LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. 
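And a minimal sketch of the training-time counterpart: RandomResizedCrop plus flips, RandAugment when the installed torchvision provides it, color jitter otherwise, then tensor conversion and random erasing. The gate and defaults here are simplified illustrative assumptions, not the library's exact logic.

```python
import torch
import torchvision.transforms as T

size, hflip, hsv_h, hsv_s, hsv_v, erasing = 224, 0.5, 0.015, 0.4, 0.4, 0.4

tfl = [T.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3)), T.RandomHorizontalFlip(p=hflip)]
if hasattr(T, "RandAugment"):  # available in torchvision >= 0.11; the library uses a proper version check
    tfl += [T.RandAugment(interpolation=T.InterpolationMode.BILINEAR)]
else:  # fall back to color jitter when auto-augment is unavailable
    tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
tfl += [
    T.ToTensor(),
    T.Normalize(mean=torch.tensor((0.0, 0.0, 0.0)), std=torch.tensor((1.0, 1.0, 1.0))),
    T.RandomErasing(p=erasing, inplace=True),  # erasing works on tensors, so it must follow ToTensor
]
train_tf = T.Compose(tfl)
```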
Disabling it.') + + elif auto_augment == "autoaugment": + if TORCHVISION_0_10: + secondary_tfl += [T.AutoAugment(interpolation=interpolation)] + else: + LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.') + + else: + raise ValueError( + f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", ' + f'"augmix", "autoaugment" or None' + ) + + if not disable_color_jitter: + secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)] + + final_tfl = [ + T.ToTensor(), + T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), + T.RandomErasing(p=erasing, inplace=True), + ] + + return T.Compose(primary_tfl + secondary_tfl + final_tfl) +# NOTE: keep this class for backward compatibility class ClassifyLetterBox: - """YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])""" + """ + YOLOv8 LetterBox class for image preprocessing, designed to be part of a transformation pipeline, e.g., + T.Compose([LetterBox(size), ToTensor()]). + + Attributes: + h (int): Target height of the image. + w (int): Target width of the image. + auto (bool): If True, automatically solves for short side using stride. + stride (int): The stride value, used when 'auto' is True. + """ def __init__(self, size=(640, 640), auto=False, stride=32): - """Resizes image and crops it to center with max dimensions 'h' and 'w'.""" + """ + Initializes the ClassifyLetterBox class with a target size, auto-flag, and stride. + + Args: + size (Union[int, Tuple[int, int]]): The target dimensions (height, width) for the letterbox. + auto (bool): If True, automatically calculates the short side based on stride. + stride (int): The stride value, used when 'auto' is True. + """ super().__init__() self.h, self.w = (size, size) if isinstance(size, int) else size self.auto = auto # pass max size integer, automatically solve for short side using stride self.stride = stride # used with auto - def __call__(self, im): # im = np.array HWC + def __call__(self, im): + """ + Resizes the image and pads it with a letterbox method. + + Args: + im (numpy.ndarray): The input image as a numpy array of shape HWC. + + Returns: + (numpy.ndarray): The letterboxed and resized image as a numpy array. + """ imh, imw = im.shape[:2] - r = min(self.h / imh, self.w / imw) # ratio of new/old - h, w = round(imh * r), round(imw * r) # resized image - hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w + r = min(self.h / imh, self.w / imw) # ratio of new/old dimensions + h, w = round(imh * r), round(imw * r) # resized image dimensions + + # Calculate padding dimensions + hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w) top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1) - im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype) - im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) + + # Create padded image + im_out = np.full((hs, ws, 3), 114, dtype=im.dtype) + im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) return im_out +# NOTE: keep this class for backward compatibility class CenterCrop: - """YOLOv8 CenterCrop class for image preprocessing, i.e. 
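A short usage sketch of the classification preprocessing helpers above (assumes the `ultralytics` package is importable; shapes shown are for a 480x640 input):

```python
import numpy as np
from ultralytics.data.augment import CenterCrop, ClassifyLetterBox, ToTensor

im = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # HWC BGR image
letterboxed = ClassifyLetterBox(size=(224, 224))(im)            # resized to fit, padded with gray 114
cropped = CenterCrop(224)(im)                                    # alternative: crop the largest centered square
tensor = ToTensor()(cropped)                                     # CHW RGB float tensor scaled to [0, 1]
print(letterboxed.shape, cropped.shape, tensor.shape)            # (224, 224, 3) (224, 224, 3) torch.Size([3, 224, 224])
```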
T.Compose([CenterCrop(size), ToTensor()])""" + """YOLOv8 CenterCrop class for image preprocessing, designed to be part of a transformation pipeline, e.g., + T.Compose([CenterCrop(size), ToTensor()]). + """ def __init__(self, size=640): """Converts an image from numpy array to PyTorch tensor.""" super().__init__() self.h, self.w = (size, size) if isinstance(size, int) else size - def __call__(self, im): # im = np.array HWC + def __call__(self, im): + """ + Resizes and crops the center of the image using a letterbox method. + + Args: + im (numpy.ndarray): The input image as a numpy array of shape HWC. + + Returns: + (numpy.ndarray): The center-cropped and resized image as a numpy array. + """ imh, imw = im.shape[:2] m = min(imh, imw) # min dimension top, left = (imh - m) // 2, (imw - m) // 2 - return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) + return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) +# NOTE: keep this class for backward compatibility class ToTensor: - """YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]).""" + """YOLOv8 ToTensor class for image preprocessing, i.e., T.Compose([LetterBox(size), ToTensor()]).""" def __init__(self, half=False): """Initialize YOLOv8 ToTensor object with optional half-precision support.""" super().__init__() self.half = half - def __call__(self, im): # im = np.array HWC in BGR order + def __call__(self, im): + """ + Transforms an image from a numpy array to a PyTorch tensor, applying optional half-precision and normalization. + + Args: + im (numpy.ndarray): Input image as a numpy array with shape (H, W, C) in BGR order. + + Returns: + (torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized to [0, 1]. + """ im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous im = torch.from_numpy(im) # to torch im = im.half() if self.half else im.float() # uint8 to fp16/32 diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py index 429533d..6af8d3c 100644 --- a/ultralytics/data/base.py +++ b/ultralytics/data/base.py @@ -15,7 +15,6 @@ import psutil from torch.utils.data import Dataset from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM - from .utils import HELP_URL, IMG_FORMATS @@ -47,20 +46,23 @@ class BaseDataset(Dataset): transforms (callable): Image transformation function. 
""" - def __init__(self, - img_path, - imgsz=640, - cache=False, - augment=True, - hyp=DEFAULT_CFG, - prefix='', - rect=False, - batch_size=16, - stride=32, - pad=0.5, - single_cls=False, - classes=None, - fraction=1.0): + def __init__( + self, + img_path, + imgsz=640, + cache=False, + augment=True, + hyp=DEFAULT_CFG, + prefix="", + rect=False, + batch_size=16, + stride=32, + pad=0.5, + single_cls=False, + classes=None, + fraction=1.0, + ): + """Initialize BaseDataset with given configuration and options.""" super().__init__() self.img_path = img_path self.imgsz = imgsz @@ -84,11 +86,11 @@ class BaseDataset(Dataset): self.buffer = [] # buffer size = batch size self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 - # Cache stuff - if cache == 'ram' and not self.check_cache_ram(): + # Cache images + if cache == "ram" and not self.check_cache_ram(): cache = False self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni - self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files] + self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files] if cache: self.cache_images(cache) @@ -102,54 +104,62 @@ class BaseDataset(Dataset): for p in img_path if isinstance(img_path, list) else [img_path]: p = Path(p) # os-agnostic if p.is_dir(): # dir - f += glob.glob(str(p / '**' / '*.*'), recursive=True) + f += glob.glob(str(p / "**" / "*.*"), recursive=True) # F = list(p.rglob('*.*')) # pathlib elif p.is_file(): # file with open(p) as t: t = t.read().strip().splitlines() parent = str(p.parent) + os.sep - f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path + f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path # F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib) else: - raise FileNotFoundError(f'{self.prefix}{p} does not exist') - im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS) + raise FileNotFoundError(f"{self.prefix}{p} does not exist") + im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS) # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib - assert im_files, f'{self.prefix}No images found in {img_path}' + assert im_files, f"{self.prefix}No images found in {img_path}" except Exception as e: - raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e + raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e if self.fraction < 1: - im_files = im_files[:round(len(im_files) * self.fraction)] + # im_files = im_files[: round(len(im_files) * self.fraction)] + num_elements_to_select = round(len(im_files) * self.fraction) + im_files = random.sample(im_files, num_elements_to_select) return im_files def update_labels(self, include_class: Optional[list]): - """include_class, filter labels to include only these classes (optional).""" + """Update labels to include only these classes (optional).""" include_class_array = np.array(include_class).reshape(1, -1) for i in range(len(self.labels)): if include_class is not None: - cls = self.labels[i]['cls'] - bboxes = self.labels[i]['bboxes'] - segments = self.labels[i]['segments'] - keypoints = self.labels[i]['keypoints'] + cls = self.labels[i]["cls"] + bboxes = self.labels[i]["bboxes"] + segments = self.labels[i]["segments"] + keypoints = self.labels[i]["keypoints"] j = (cls 
== include_class_array).any(1) - self.labels[i]['cls'] = cls[j] - self.labels[i]['bboxes'] = bboxes[j] + self.labels[i]["cls"] = cls[j] + self.labels[i]["bboxes"] = bboxes[j] if segments: - self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx] + self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx] if keypoints is not None: - self.labels[i]['keypoints'] = keypoints[j] + self.labels[i]["keypoints"] = keypoints[j] if self.single_cls: - self.labels[i]['cls'][:, 0] = 0 + self.labels[i]["cls"][:, 0] = 0 def load_image(self, i, rect_mode=True): """Loads 1 image from dataset index 'i', returns (im, resized hw).""" im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] if im is None: # not cached in RAM if fn.exists(): # load npy - im = np.load(fn) + try: + im = np.load(fn) + except Exception as e: + LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}") + Path(fn).unlink(missing_ok=True) + im = cv2.imread(f) # BGR else: # read image im = cv2.imread(f) # BGR - if im is None: - raise FileNotFoundError(f'Image Not Found {f}') + if im is None: + raise FileNotFoundError(f"Image Not Found {f}") + h0, w0 = im.shape[:2] # orig hw if rect_mode: # resize long side to imgsz while maintaining aspect ratio r = self.imgsz / max(h0, w0) # ratio @@ -174,17 +184,17 @@ class BaseDataset(Dataset): def cache_images(self, cache): """Cache images to memory or disk.""" b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes - fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image + fcn = self.cache_images_to_disk if cache == "disk" else self.load_image with ThreadPool(NUM_THREADS) as pool: results = pool.imap(fcn, range(self.ni)) pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0) for i, x in pbar: - if cache == 'disk': + if cache == "disk": b += self.npy_files[i].stat().st_size else: # 'ram' self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) b += self.ims[i].nbytes - pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})' + pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})" pbar.close() def cache_images_to_disk(self, i): @@ -200,15 +210,17 @@ class BaseDataset(Dataset): for _ in range(n): im = cv2.imread(random.choice(self.im_files)) # sample image ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio - b += im.nbytes * ratio ** 2 + b += im.nbytes * ratio**2 mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM mem = psutil.virtual_memory() cache = mem_required < mem.available # to cache or not to cache, that is the question if not cache: - LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images ' - f'with {int(safety_margin * 100)}% safety margin but only ' - f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, ' - f"{'caching images ✅' if cache else 'not caching images ⚠️'}") + LOGGER.info( + f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images ' + f'with {int(safety_margin * 100)}% safety margin but only ' + f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, ' + f"{'caching images ✅' if cache else 'not caching images ⚠️'}" + ) return cache def set_rectangle(self): @@ -216,7 +228,7 @@ class BaseDataset(Dataset): bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index nb = bi[-1] + 1 # number of batches - s = np.array([x.pop('shape') for x in self.labels]) # hw + s = 
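The new corrupt-cache handling in `load_image` amounts to: try the cached `.npy`, delete it on failure, and fall back to decoding the original file. A standalone sketch:

```python
from pathlib import Path

import cv2
import numpy as np

def load_cached_or_image(npy_path: Path, img_path: str):
    """Prefer the .npy cache; if it is unreadable, remove it and re-read the source image."""
    if npy_path.exists():
        try:
            return np.load(npy_path)
        except Exception as e:
            print(f"WARNING: removing corrupt cache {npy_path}: {e}")
            npy_path.unlink(missing_ok=True)
    return cv2.imread(img_path)  # BGR, or None if the image itself is missing
```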
np.array([x.pop("shape") for x in self.labels]) # hw ar = s[:, 0] / s[:, 1] # aspect ratio irect = ar.argsort() self.im_files = [self.im_files[i] for i in irect] @@ -243,12 +255,14 @@ class BaseDataset(Dataset): def get_image_and_label(self, index): """Get and return label information from the dataset.""" label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948 - label.pop('shape', None) # shape is for rect, remove it - label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index) - label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0], - label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation + label.pop("shape", None) # shape is for rect, remove it + label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index) + label["ratio_pad"] = ( + label["resized_shape"][0] / label["ori_shape"][0], + label["resized_shape"][1] / label["ori_shape"][1], + ) # for evaluation if self.rect: - label['rect_shape'] = self.batch_shapes[self.batch[index]] + label["rect_shape"] = self.batch_shapes[self.batch[index]] return self.update_labels_info(label) def __len__(self): @@ -256,24 +270,32 @@ class BaseDataset(Dataset): return len(self.labels) def update_labels_info(self, label): - """custom your label format here.""" + """Custom your label format here.""" return label def build_transforms(self, hyp=None): - """Users can custom augmentations here - like: + """ + Users can customize augmentations here. + + Example: + ```python if self.augment: # Training transforms return Compose([]) else: # Val transforms return Compose([]) + ``` """ raise NotImplementedError def get_labels(self): - """Users can custom their own format here. - Make sure your output is a list with each element like below: + """ + Users can customize their own format here. + + Note: + Ensure output is a dictionary with the following keys: + ```python dict( im_file=im_file, shape=shape, # format: (height, width) @@ -284,5 +306,6 @@ class BaseDataset(Dataset): normalized=True, # or False bbox_format="xyxy", # or xywh, ltwh ) + ``` """ raise NotImplementedError diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py index 9d40e5a..6bfb48f 100644 --- a/ultralytics/data/build.py +++ b/ultralytics/data/build.py @@ -9,23 +9,34 @@ import torch from PIL import Image from torch.utils.data import dataloader, distributed -from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor, - SourceTypes, autocast_list) +from ultralytics.data.loaders import ( + LOADERS, + LoadImagesAndVideos, + LoadPilAndNumpy, + LoadScreenshots, + LoadStreams, + LoadTensor, + SourceTypes, + autocast_list, +) from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.utils import RANK, colorstr from ultralytics.utils.checks import check_file - from .dataset import YOLODataset from .utils import PIN_MEMORY class InfiniteDataLoader(dataloader.DataLoader): - """Dataloader that reuses workers. Uses same syntax as vanilla DataLoader.""" + """ + Dataloader that reuses workers. + + Uses same syntax as vanilla DataLoader. 
+ """ def __init__(self, *args, **kwargs): """Dataloader that infinitely recycles workers, inherits from DataLoader.""" super().__init__(*args, **kwargs) - object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) + object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler)) self.iterator = super().__iter__() def __len__(self): @@ -38,7 +49,9 @@ class InfiniteDataLoader(dataloader.DataLoader): yield next(self.iterator) def reset(self): - """Reset iterator. + """ + Reset iterator. + This is useful when we want to modify settings of dataset while training. """ self.iterator = self._get_iterator() @@ -64,49 +77,51 @@ class _RepeatSampler: def seed_worker(worker_id): # noqa """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader.""" - worker_seed = torch.initial_seed() % 2 ** 32 + worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) -def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32): - """Build YOLO Dataset""" +def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32): + """Build YOLO Dataset.""" return YOLODataset( img_path=img_path, imgsz=cfg.imgsz, batch_size=batch, - augment=mode == 'train', # augmentation + augment=mode == "train", # augmentation hyp=cfg, # TODO: probably add a get_hyps_from_cfg function rect=cfg.rect or rect, # rectangular batches cache=cfg.cache or None, single_cls=cfg.single_cls or False, stride=int(stride), - pad=0.0 if mode == 'train' else 0.5, - prefix=colorstr(f'{mode}: '), - use_segments=cfg.task == 'segment', - use_keypoints=cfg.task == 'pose', + pad=0.0 if mode == "train" else 0.5, + prefix=colorstr(f"{mode}: "), + task=cfg.task, classes=cfg.classes, data=data, - fraction=cfg.fraction if mode == 'train' else 1.0) + fraction=cfg.fraction if mode == "train" else 1.0, + ) def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): """Return an InfiniteDataLoader or DataLoader for training or validation set.""" batch = min(batch, len(dataset)) nd = torch.cuda.device_count() # number of CUDA devices - nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers + nw = min([os.cpu_count() // max(nd, 1), workers]) # number of workers sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) generator = torch.Generator() generator.manual_seed(6148914691236517205 + RANK) - return InfiniteDataLoader(dataset=dataset, - batch_size=batch, - shuffle=shuffle and sampler is None, - num_workers=nw, - sampler=sampler, - pin_memory=PIN_MEMORY, - collate_fn=getattr(dataset, 'collate_fn', None), - worker_init_fn=seed_worker, - generator=generator) + return InfiniteDataLoader( + dataset=dataset, + batch_size=batch, + shuffle=shuffle and sampler is None, + num_workers=nw, + sampler=sampler, + pin_memory=PIN_MEMORY, + collate_fn=getattr(dataset, "collate_fn", None), + worker_init_fn=seed_worker, + generator=generator, + ) def check_source(source): @@ -114,10 +129,10 @@ def check_source(source): webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False if isinstance(source, (str, int, Path)): # int for local usb camera source = str(source) - is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) - is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')) - webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) - screenshot = 
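The worker-seeding and generator setup used by `build_dataloader` can be reproduced on any `Dataset`; the sketch below uses a throwaway `TensorDataset` and an arbitrary cap of 8 workers in place of the configured value.

```python
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    """Seed numpy/random in each worker from the torch-provided seed, for reproducible augmentations."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

dataset = TensorDataset(torch.arange(10))
nd = torch.cuda.device_count()
nw = min(os.cpu_count() // max(nd, 1), 8)  # cap workers by CPUs per device (8 is an arbitrary cap here)
generator = torch.Generator()
generator.manual_seed(6148914691236517205)  # fixed base seed (+ RANK in the library) keeps shuffling reproducible
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=nw,
                    worker_init_fn=seed_worker, generator=generator)
```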
source.lower() == 'screen' + is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS) + is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")) + webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file) + screenshot = source.lower() == "screen" if is_url and is_file: source = check_file(source) # download elif isinstance(source, LOADERS): @@ -130,42 +145,42 @@ def check_source(source): elif isinstance(source, torch.Tensor): tensor = True else: - raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict') + raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict") return source, webcam, screenshot, from_img, in_memory, tensor -def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=False): +def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False): """ Loads an inference source for object detection and applies necessary transformations. Args: source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference. - imgsz (int, optional): The size of the image for inference. Default is 640. + batch (int, optional): Batch size for dataloaders. Default is 1. vid_stride (int, optional): The frame interval for video sources. Default is 1. - stream_buffer (bool, optional): Determined whether stream frames will be buffered. Default is False. + buffer (bool, optional): Determined whether stream frames will be buffered. Default is False. Returns: dataset (Dataset): A dataset object for the specified input source. """ - source, webcam, screenshot, from_img, in_memory, tensor = check_source(source) - source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor) + source, stream, screenshot, from_img, in_memory, tensor = check_source(source) + source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor) # Dataloader if tensor: dataset = LoadTensor(source) elif in_memory: dataset = source - elif webcam: - dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, stream_buffer=stream_buffer) + elif stream: + dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer) elif screenshot: - dataset = LoadScreenshots(source, imgsz=imgsz) + dataset = LoadScreenshots(source) elif from_img: - dataset = LoadPilAndNumpy(source, imgsz=imgsz) + dataset = LoadPilAndNumpy(source) else: - dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride) + dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride) # Attach source types to the dataset - setattr(dataset, 'source_type', source_type) + setattr(dataset, "source_type", source_type) return dataset diff --git a/ultralytics/data/converter.py b/ultralytics/data/converter.py index 1e3b429..eff4dac 100644 --- a/ultralytics/data/converter.py +++ b/ultralytics/data/converter.py @@ -1,31 +1,120 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import json -import shutil from collections import defaultdict from pathlib import Path import cv2 import numpy as np -from ultralytics.utils import TQDM +from ultralytics.utils import LOGGER, TQDM +from ultralytics.utils.files import increment_path def coco91_to_coco80_class(): - """Converts 91-index COCO class IDs to 80-index COCO class IDs. + """ + Converts 91-index COCO class IDs to 80-index COCO class IDs. 
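The source-type detection in `check_source` is a handful of string tests. A compact illustration with abbreviated format sets (the real `IMG_FORMATS`/`VID_FORMATS` are larger):

```python
from pathlib import Path

IMG_FORMATS = {"jpg", "jpeg", "png", "bmp"}  # abbreviated for illustration
VID_FORMATS = {"mp4", "avi", "mkv"}

for source in ("bus.jpg", "video.mp4", "0", "rtsp://camera/stream", "screen"):
    is_file = Path(source).suffix[1:].lower() in (IMG_FORMATS | VID_FORMATS)
    is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
    stream = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
    screenshot = source.lower() == "screen"
    print(f"{source}: file={is_file} url={is_url} stream={stream} screenshot={screenshot}")
```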
Returns: (list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the corresponding 91-index class ID. """ return [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None, - None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, - 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, - None, 73, 74, 75, 76, 77, 78, 79, None] + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + None, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + None, + 24, + 25, + None, + None, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + None, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + None, + 60, + None, + None, + 61, + None, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + None, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + None, + ] -def coco80_to_coco91_class(): # +def coco80_to_coco91_class(): """ Converts 80-index (val2014) to 91-index (paper). For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/. @@ -41,16 +130,102 @@ def coco80_to_coco91_class(): # ``` """ return [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, - 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 27, + 28, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 67, + 70, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + ] -def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keypoints=False, cls91to80=True): - """Converts COCO dataset annotations to a format suitable for training YOLOv5 models. +def convert_coco( + labels_dir="../coco/annotations/", + save_dir="coco_converted/", + use_segments=False, + use_keypoints=False, + cls91to80=True, +): + """ + Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models. Args: labels_dir (str, optional): Path to directory containing COCO dataset annotation files. + save_dir (str, optional): Path to directory to save results to. use_segments (bool, optional): Whether to include segmentation masks in the output. use_keypoints (bool, optional): Whether to include keypoint annotations in the output. cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs. 
@@ -67,78 +242,79 @@ def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keyp """ # Create dataset directory - save_dir = Path('yolo_labels') - if save_dir.exists(): - shutil.rmtree(save_dir) # delete dir - for p in save_dir / 'labels', save_dir / 'images': + save_dir = increment_path(save_dir) # increment if save directory already exists + for p in save_dir / "labels", save_dir / "images": p.mkdir(parents=True, exist_ok=True) # make dir # Convert classes coco80 = coco91_to_coco80_class() # Import json - for json_file in sorted(Path(labels_dir).resolve().glob('*.json')): - fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '') # folder name + for json_file in sorted(Path(labels_dir).resolve().glob("*.json")): + fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name fn.mkdir(parents=True, exist_ok=True) with open(json_file) as f: data = json.load(f) # Create image dict - images = {f'{x["id"]:d}': x for x in data['images']} + images = {f'{x["id"]:d}': x for x in data["images"]} # Create image-annotations dict imgToAnns = defaultdict(list) - for ann in data['annotations']: - imgToAnns[ann['image_id']].append(ann) + for ann in data["annotations"]: + imgToAnns[ann["image_id"]].append(ann) # Write labels file - for img_id, anns in TQDM(imgToAnns.items(), desc=f'Annotations {json_file}'): - img = images[f'{img_id:d}'] - h, w, f = img['height'], img['width'], img['file_name'] + for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"): + img = images[f"{img_id:d}"] + h, w, f = img["height"], img["width"], img["file_name"] bboxes = [] segments = [] keypoints = [] for ann in anns: - if ann['iscrowd']: + if ann["iscrowd"]: continue # The COCO box format is [top left x, top left y, width, height] - box = np.array(ann['bbox'], dtype=np.float64) + box = np.array(ann["bbox"], dtype=np.float64) box[:2] += box[2:] / 2 # xy top-left corner to center box[[0, 2]] /= w # normalize x box[[1, 3]] /= h # normalize y if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0 continue - cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1 # class + cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1 # class box = [cls] + box.tolist() if box not in bboxes: bboxes.append(box) - if use_segments and ann.get('segmentation') is not None: - if len(ann['segmentation']) == 0: - segments.append([]) - continue - elif len(ann['segmentation']) > 1: - s = merge_multi_segment(ann['segmentation']) - s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist() - else: - s = [j for i in ann['segmentation'] for j in i] # all segments concatenated - s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist() - s = [cls] + s - if s not in segments: + if use_segments and ann.get("segmentation") is not None: + if len(ann["segmentation"]) == 0: + segments.append([]) + continue + elif len(ann["segmentation"]) > 1: + s = merge_multi_segment(ann["segmentation"]) + s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist() + else: + s = [j for i in ann["segmentation"] for j in i] # all segments concatenated + s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist() + s = [cls] + s segments.append(s) - if use_keypoints and ann.get('keypoints') is not None: - keypoints.append(box + (np.array(ann['keypoints']).reshape(-1, 3) / - np.array([w, h, 1])).reshape(-1).tolist()) + if use_keypoints and ann.get("keypoints") is not None: + keypoints.append( + box + 
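A worked example of the COCO-to-YOLO box conversion performed above: `[top-left x, top-left y, w, h]` in pixels becomes a normalized `[x-center, y-center, w, h]`.

```python
import numpy as np

w, h = 640, 480                             # image size
box = np.array([100.0, 200.0, 50.0, 80.0])  # COCO bbox: top-left x/y, width, height
box[:2] += box[2:] / 2                      # top-left corner -> center: (125, 240)
box[[0, 2]] /= w                            # normalize x and width
box[[1, 3]] /= h                            # normalize y and height
print(box)  # ~[0.195 0.5 0.078 0.167]
```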
(np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist() + ) # Write - with open((fn / f).with_suffix('.txt'), 'a') as file: + with open((fn / f).with_suffix(".txt"), "a") as file: for i in range(len(bboxes)): if use_keypoints: - line = *(keypoints[i]), # cls, box, keypoints + line = (*(keypoints[i]),) # cls, box, keypoints else: - line = *(segments[i] - if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments - file.write(('%g ' * len(line)).rstrip() % line + '\n') + line = ( + *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]), + ) # cls, box or segments + file.write(("%g " * len(line)).rstrip() % line + "\n") + + LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}") def convert_dota_to_yolo_obb(dota_root_path: str): @@ -160,48 +336,52 @@ def convert_dota_to_yolo_obb(dota_root_path: str): Notes: The directory structure assumed for the DOTA dataset: - - DOTA - - images - - train - - val - - labels - - train_original - - val_original - After the function execution, the new labels will be saved in: - DOTA - - labels - - train - - val + ├─ images + │ ├─ train + │ └─ val + └─ labels + ├─ train_original + └─ val_original + + After execution, the function will organize the labels into: + + - DOTA + └─ labels + ├─ train + └─ val """ dota_root_path = Path(dota_root_path) # Class names to indices mapping class_mapping = { - 'plane': 0, - 'ship': 1, - 'storage-tank': 2, - 'baseball-diamond': 3, - 'tennis-court': 4, - 'basketball-court': 5, - 'ground-track-field': 6, - 'harbor': 7, - 'bridge': 8, - 'large-vehicle': 9, - 'small-vehicle': 10, - 'helicopter': 11, - 'roundabout': 12, - 'soccer ball-field': 13, - 'swimming-pool': 14, - 'container-crane': 15, - 'airport': 16, - 'helipad': 17} + "plane": 0, + "ship": 1, + "storage-tank": 2, + "baseball-diamond": 3, + "tennis-court": 4, + "basketball-court": 5, + "ground-track-field": 6, + "harbor": 7, + "bridge": 8, + "large-vehicle": 9, + "small-vehicle": 10, + "helicopter": 11, + "roundabout": 12, + "soccer-ball-field": 13, + "swimming-pool": 14, + "container-crane": 15, + "airport": 16, + "helipad": 17, + } def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir): - orig_label_path = orig_label_dir / f'{image_name}.txt' - save_path = save_dir / f'{image_name}.txt' + """Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory.""" + orig_label_path = orig_label_dir / f"{image_name}.txt" + save_path = save_dir / f"{image_name}.txt" - with orig_label_path.open('r') as f, save_path.open('w') as g: + with orig_label_path.open("r") as f, save_path.open("w") as g: lines = f.readlines() for line in lines: parts = line.strip().split() @@ -211,20 +391,21 @@ def convert_dota_to_yolo_obb(dota_root_path: str): class_idx = class_mapping[class_name] coords = [float(p) for p in parts[:8]] normalized_coords = [ - coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)] - formatted_coords = ['{:.6g}'.format(coord) for coord in normalized_coords] + coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8) + ] + formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords] g.write(f"{class_idx} {' '.join(formatted_coords)}\n") - for phase in ['train', 'val']: - image_dir = dota_root_path / 'images' / phase - orig_label_dir = dota_root_path / 'labels' / f'{phase}_original' - save_dir = dota_root_path / 'labels' / phase + for phase 
in ["train", "val"]: + image_dir = dota_root_path / "images" / phase + orig_label_dir = dota_root_path / "labels" / f"{phase}_original" + save_dir = dota_root_path / "labels" / phase save_dir.mkdir(parents=True, exist_ok=True) image_paths = list(image_dir.iterdir()) - for image_path in TQDM(image_paths, desc=f'Processing {phase} images'): - if image_path.suffix != '.png': + for image_path in TQDM(image_paths, desc=f"Processing {phase} images"): + if image_path.suffix != ".png": continue image_name_without_ext = image_path.stem img = cv2.imread(str(image_path)) @@ -237,8 +418,8 @@ def min_index(arr1, arr2): Find a pair of indexes with the shortest distance between two arrays of 2D points. Args: - arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points. - arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points. + arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points. + arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points. Returns: (tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively. @@ -263,31 +444,30 @@ def merge_multi_segment(segments): segments = [np.array(i).reshape(-1, 2) for i in segments] idx_list = [[] for _ in range(len(segments))] - # record the indexes with min distance between each segment + # Record the indexes with min distance between each segment for i in range(1, len(segments)): idx1, idx2 = min_index(segments[i - 1], segments[i]) idx_list[i - 1].append(idx1) idx_list[i].append(idx2) - # use two round to connect all the segments + # Use two round to connect all the segments for k in range(2): - # forward connection + # Forward connection if k == 0: for i, idx in enumerate(idx_list): - # middle segments have two indexes - # reverse the index of middle segments + # Middle segments have two indexes, reverse the index of middle segments if len(idx) == 2 and idx[0] > idx[1]: idx = idx[::-1] segments[i] = segments[i][::-1, :] segments[i] = np.roll(segments[i], -idx[0], axis=0) segments[i] = np.concatenate([segments[i], segments[i][:1]]) - # deal with the first segment and the last one + # Deal with the first segment and the last one if i in [0, len(idx_list) - 1]: s.append(segments[i]) else: idx = [0, idx[1] - idx[0]] - s.append(segments[i][idx[0]:idx[1] + 1]) + s.append(segments[i][idx[0] : idx[1] + 1]) else: for i in range(len(idx_list) - 1, -1, -1): @@ -296,3 +476,67 @@ def merge_multi_segment(segments): nidx = abs(idx[1] - idx[0]) s.append(segments[i][nidx:]) return s + + +def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"): + """ + Converts existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB) + in YOLO format. Generates segmentation data using SAM auto-annotator as needed. + + Args: + im_dir (str | Path): Path to image directory to convert. + save_dir (str | Path): Path to save the generated labels, labels will be saved + into `labels-segment` in the same directory level of `im_dir` if save_dir is None. Default: None. + sam_model (str): Segmentation model to use for intermediate segmentation data; optional. + + Notes: + The input directory structure assumed for dataset: + + - im_dir + ├─ 001.jpg + ├─ .. + └─ NNN.jpg + - labels + ├─ 001.txt + ├─ .. 
+ └─ NNN.txt + """ + from ultralytics.data import YOLODataset + from ultralytics.utils.ops import xywh2xyxy + from ultralytics.utils import LOGGER + from ultralytics import SAM + from tqdm import tqdm + + # NOTE: add placeholder to pass class index check + dataset = YOLODataset(im_dir, data=dict(names=list(range(1000)))) + if len(dataset.labels[0]["segments"]) > 0: # if it's segment data + LOGGER.info("Segmentation labels detected, no need to generate new ones!") + return + + LOGGER.info("Detection labels detected, generating segment labels by SAM model!") + sam_model = SAM(sam_model) + for l in tqdm(dataset.labels, total=len(dataset.labels), desc="Generating segment labels"): + h, w = l["shape"] + boxes = l["bboxes"] + if len(boxes) == 0: # skip empty labels + continue + boxes[:, [0, 2]] *= w + boxes[:, [1, 3]] *= h + im = cv2.imread(l["im_file"]) + sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False) + l["segments"] = sam_results[0].masks.xyn + + save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / "labels-segment" + save_dir.mkdir(parents=True, exist_ok=True) + for l in dataset.labels: + texts = [] + lb_name = Path(l["im_file"]).with_suffix(".txt").name + txt_file = save_dir / lb_name + cls = l["cls"] + for i, s in enumerate(l["segments"]): + line = (int(cls[i]), *s.reshape(-1)) + texts.append(("%g " * len(line)).rstrip() % line) + if texts: + with open(txt_file, "a") as f: + f.writelines(text + "\n" for text in texts) + LOGGER.info(f"Generated segment labels saved in {save_dir}") diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py index 65fe141..42b7cc1 100644 --- a/ultralytics/data/dataset.py +++ b/ultralytics/data/dataset.py @@ -8,15 +8,16 @@ import cv2 import numpy as np import torch import torchvision +from PIL import Image from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable - -from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms +from ultralytics.utils.ops import resample_segments +from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms from .base import BaseDataset from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8 -DATASET_CACHE_VERSION = '1.0.3' +DATASET_CACHE_VERSION = "1.0.3" class YOLODataset(BaseDataset): @@ -25,40 +26,54 @@ class YOLODataset(BaseDataset): Args: data (dict, optional): A dataset YAML dictionary. Defaults to None. - use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False. - use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False. + task (str): An explicit arg to point current task, Defaults to 'detect'. Returns: (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model. """ - def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs): - self.use_segments = use_segments - self.use_keypoints = use_keypoints + def __init__(self, *args, data=None, task="detect", **kwargs): + """Initializes the YOLODataset with optional configurations for segments and keypoints.""" + self.use_segments = task == "segment" + self.use_keypoints = task == "pose" + self.use_obb = task == "obb" self.data = data - assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.' 
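As a quick illustration of the task-based YOLODataset constructor introduced in this hunk, a minimal usage sketch follows; the image directory and names mapping are hypothetical placeholders and are not part of this patch.

from ultralytics.data.dataset import YOLODataset

# task selects the label pipeline: "detect" (default), "segment", "pose" or "obb"
dataset = YOLODataset(
    img_path="datasets/coco8/images/train",   # hypothetical image directory
    data={"names": {0: "person"}},            # minimal class-name mapping used by the label checks
    augment=False,
    task="segment",                           # sets use_segments=True internally
)
sample = dataset[0]  # transformed sample dict with keys such as "img", "cls", "bboxes", "im_file"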
+ assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." super().__init__(*args, **kwargs) - def cache_labels(self, path=Path('./labels.cache')): - """Cache dataset labels, check images and read shapes. + def cache_labels(self, path=Path("./labels.cache")): + """ + Cache dataset labels, check images and read shapes. + Args: - path (Path): path where to save the cache file (default: Path('./labels.cache')). + path (Path): Path where to save the cache file. Default is Path('./labels.cache'). + Returns: (dict): labels. """ - x = {'labels': []} + x = {"labels": []} nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages - desc = f'{self.prefix}Scanning {path.parent / path.stem}...' + desc = f"{self.prefix}Scanning {path.parent / path.stem}..." total = len(self.im_files) - nkpt, ndim = self.data.get('kpt_shape', (0, 0)) + nkpt, ndim = self.data.get("kpt_shape", (0, 0)) if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)): - raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " - "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'") + raise ValueError( + "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " + "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'" + ) with ThreadPool(NUM_THREADS) as pool: - results = pool.imap(func=verify_image_label, - iterable=zip(self.im_files, self.label_files, repeat(self.prefix), - repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt), - repeat(ndim))) + results = pool.imap( + func=verify_image_label, + iterable=zip( + self.im_files, + self.label_files, + repeat(self.prefix), + repeat(self.use_keypoints), + repeat(len(self.data["names"])), + repeat(nkpt), + repeat(ndim), + ), + ) pbar = TQDM(results, desc=desc, total=total) for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar: nm += nm_f @@ -66,7 +81,7 @@ class YOLODataset(BaseDataset): ne += ne_f nc += nc_f if im_file: - x['labels'].append( + x["labels"].append( dict( im_file=im_file, shape=shape, @@ -75,60 +90,63 @@ class YOLODataset(BaseDataset): segments=segments, keypoints=keypoint, normalized=True, - bbox_format='xywh')) + bbox_format="xywh", + ) + ) if msg: msgs.append(msg) - pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt' + pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" pbar.close() if msgs: - LOGGER.info('\n'.join(msgs)) + LOGGER.info("\n".join(msgs)) if nf == 0: - LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}') - x['hash'] = get_hash(self.label_files + self.im_files) - x['results'] = nf, nm, ne, nc, len(self.im_files) - x['msgs'] = msgs # warnings + LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. 
{HELP_URL}") + x["hash"] = get_hash(self.label_files + self.im_files) + x["results"] = nf, nm, ne, nc, len(self.im_files) + x["msgs"] = msgs # warnings save_dataset_cache_file(self.prefix, path, x) return x def get_labels(self): """Returns dictionary of labels for YOLO training.""" self.label_files = img2label_paths(self.im_files) - cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') + cache_path = Path(self.label_files[0]).parent.with_suffix(".cache") try: cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file - assert cache['version'] == DATASET_CACHE_VERSION # matches current version - assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash + assert cache["version"] == DATASET_CACHE_VERSION # matches current version + assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash except (FileNotFoundError, AssertionError, AttributeError): cache, exists = self.cache_labels(cache_path), False # run cache ops # Display cache - nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total + nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total if exists and LOCAL_RANK in (-1, 0): - d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt' + d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results - if cache['msgs']: - LOGGER.info('\n'.join(cache['msgs'])) # display warnings + if cache["msgs"]: + LOGGER.info("\n".join(cache["msgs"])) # display warnings # Read cache - [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items - labels = cache['labels'] + [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items + labels = cache["labels"] if not labels: - LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}') - self.im_files = [lb['im_file'] for lb in labels] # update im_files + LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}") + self.im_files = [lb["im_file"] for lb in labels] # update im_files # Check if the dataset is all boxes or all segments - lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels) + lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels) len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths)) if len_segments and len_boxes != len_segments: LOGGER.warning( - f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, ' - f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. ' - 'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.') + f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, " + f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " + "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset." + ) for lb in labels: - lb['segments'] = [] + lb["segments"] = [] if len_cls == 0: - LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}') + LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. 
{HELP_URL}") return labels def build_transforms(self, hyp=None): @@ -140,13 +158,18 @@ class YOLODataset(BaseDataset): else: transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)]) transforms.append( - Format(bbox_format='xywh', - normalize=True, - return_mask=self.use_segments, - return_keypoint=self.use_keypoints, - batch_idx=True, - mask_ratio=hyp.mask_ratio, - mask_overlap=hyp.overlap_mask)) + Format( + bbox_format="xywh", + normalize=True, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + return_obb=self.use_obb, + batch_idx=True, + mask_ratio=hyp.mask_ratio, + mask_overlap=hyp.overlap_mask, + bgr=hyp.bgr if self.augment else 0.0, # only affect training. + ) + ) return transforms def close_mosaic(self, hyp): @@ -157,15 +180,28 @@ class YOLODataset(BaseDataset): self.transforms = self.build_transforms(hyp) def update_labels_info(self, label): - """custom your label format here.""" - # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label - # we can make it also support classification and semantic segmentation by add or remove some dict keys there. - bboxes = label.pop('bboxes') - segments = label.pop('segments') - keypoints = label.pop('keypoints', None) - bbox_format = label.pop('bbox_format') - normalized = label.pop('normalized') - label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized) + """ + Custom your label format here. + + Note: + cls is not with bboxes now, classification and semantic segmentation need an independent cls label + Can also support classification and semantic segmentation by adding or removing dict keys there. + """ + bboxes = label.pop("bboxes") + segments = label.pop("segments", []) + keypoints = label.pop("keypoints", None) + bbox_format = label.pop("bbox_format") + normalized = label.pop("normalized") + + # NOTE: do NOT resample oriented boxes + segment_resamples = 100 if self.use_obb else 1000 + if len(segments) > 0: + # list[np.array(1000, 2)] * num_samples + # (N, 1000, 2) + segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0) + else: + segments = np.zeros((0, segment_resamples, 2), dtype=np.float32) + label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized) return label @staticmethod @@ -176,65 +212,75 @@ class YOLODataset(BaseDataset): values = list(zip(*[list(b.values()) for b in batch])) for i, k in enumerate(keys): value = values[i] - if k == 'img': + if k == "img": value = torch.stack(value, 0) - if k in ['masks', 'keypoints', 'bboxes', 'cls']: + if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]: value = torch.cat(value, 0) new_batch[k] = value - new_batch['batch_idx'] = list(new_batch['batch_idx']) - for i in range(len(new_batch['batch_idx'])): - new_batch['batch_idx'][i] += i # add target image index for build_targets() - new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0) + new_batch["batch_idx"] = list(new_batch["batch_idx"]) + for i in range(len(new_batch["batch_idx"])): + new_batch["batch_idx"][i] += i # add target image index for build_targets() + new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0) return new_batch # Classification dataloaders ------------------------------------------------------------------------------------------- class ClassificationDataset(torchvision.datasets.ImageFolder): """ - YOLO Classification Dataset. 
+ Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image + augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep + learning models, with optional image transformations and caching mechanisms to speed up training. - Args: - root (str): Dataset path. + This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images + in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process + to ensure data integrity and consistency. Attributes: - cache_ram (bool): True if images should be cached in RAM, False otherwise. - cache_disk (bool): True if images should be cached on disk, False otherwise. - samples (list): List of samples containing file, index, npy, and im. - torch_transforms (callable): torchvision transforms applied to the dataset. - album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True. + cache_ram (bool): Indicates if caching in RAM is enabled. + cache_disk (bool): Indicates if caching on disk is enabled. + samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache + file (if caching on disk), and optionally the loaded image array (if caching in RAM). + torch_transforms (callable): PyTorch transforms to be applied to the images. """ - def __init__(self, root, args, augment=False, cache=False, prefix=''): + def __init__(self, root, args, augment=False, prefix=""): """ Initialize YOLO object with root, image size, augmentations, and cache settings. Args: - root (str): Dataset path. - args (Namespace): Argument parser containing dataset related settings. - augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False. - cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False. + root (str): Path to the dataset directory where images are stored in a class-specific folder structure. + args (Namespace): Configuration containing dataset-related settings such as image size, augmentation + parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction + of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training), + `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`. + augment (bool, optional): Whether to apply augmentations to the dataset. Default is False. + prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and + debugging. Default is an empty string. 
""" super().__init__(root=root) if augment and args.fraction < 1.0: # reduce training fraction - self.samples = self.samples[:round(len(self.samples) * args.fraction)] - self.prefix = colorstr(f'{prefix}: ') if prefix else '' - self.cache_ram = cache is True or cache == 'ram' - self.cache_disk = cache == 'disk' + self.samples = self.samples[: round(len(self.samples) * args.fraction)] + self.prefix = colorstr(f"{prefix}: ") if prefix else "" + self.cache_ram = args.cache is True or args.cache == "ram" # cache images into RAM + self.cache_disk = args.cache == "disk" # cache images on hard drive as uncompressed *.npy files self.samples = self.verify_images() # filter out bad images - self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im - self.torch_transforms = classify_transforms(args.imgsz) - self.album_transforms = classify_albumentations( - augment=augment, - size=args.imgsz, - scale=(1.0 - args.scale, 1.0), # (0.08, 1.0) - hflip=args.fliplr, - vflip=args.flipud, - hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction) - hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction) - hsv_v=args.hsv_v, # HSV-Value augmentation (fraction) - mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN - std=(1.0, 1.0, 1.0), # IMAGENET_STD - auto_aug=False) if augment else None + self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im + scale = (1.0 - args.scale, 1.0) # (0.08, 1.0) + self.torch_transforms = ( + classify_augmentations( + size=args.imgsz, + scale=scale, + hflip=args.fliplr, + vflip=args.flipud, + erasing=args.erasing, + auto_augment=args.auto_augment, + hsv_h=args.hsv_h, + hsv_s=args.hsv_s, + hsv_v=args.hsv_v, + ) + if augment + else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction) + ) def __getitem__(self, i): """Returns subset of data and targets corresponding to given indices.""" @@ -247,30 +293,30 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): im = np.load(fn) else: # read image im = cv2.imread(f) # BGR - if self.album_transforms: - sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image'] - else: - sample = self.torch_transforms(im) - return {'img': sample, 'cls': j} + # Convert NumPy array to PIL image + im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) + sample = self.torch_transforms(im) + return {"img": sample, "cls": j} def __len__(self) -> int: + """Return the total number of samples in the dataset.""" return len(self.samples) def verify_images(self): """Verify all images in dataset.""" - desc = f'{self.prefix}Scanning {self.root}...' - path = Path(self.root).with_suffix('.cache') # *.cache file path + desc = f"{self.prefix}Scanning {self.root}..." 
+ path = Path(self.root).with_suffix(".cache") # *.cache file path with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError): cache = load_dataset_cache_file(path) # attempt to load a *.cache file - assert cache['version'] == DATASET_CACHE_VERSION # matches current version - assert cache['hash'] == get_hash([x[0] for x in self.samples]) # identical hash - nf, nc, n, samples = cache.pop('results') # found, missing, empty, corrupt, total + assert cache["version"] == DATASET_CACHE_VERSION # matches current version + assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash + nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total if LOCAL_RANK in (-1, 0): - d = f'{desc} {nf} images, {nc} corrupt' + d = f"{desc} {nf} images, {nc} corrupt" TQDM(None, desc=d, total=n, initial=n) - if cache['msgs']: - LOGGER.info('\n'.join(cache['msgs'])) # display warnings + if cache["msgs"]: + LOGGER.info("\n".join(cache["msgs"])) # display warnings return samples # Run scan if *.cache retrieval failed @@ -285,13 +331,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): msgs.append(msg) nf += nf_f nc += nc_f - pbar.desc = f'{desc} {nf} images, {nc} corrupt' + pbar.desc = f"{desc} {nf} images, {nc} corrupt" pbar.close() if msgs: - LOGGER.info('\n'.join(msgs)) - x['hash'] = get_hash([x[0] for x in self.samples]) - x['results'] = nf, nc, len(samples), samples - x['msgs'] = msgs # warnings + LOGGER.info("\n".join(msgs)) + x["hash"] = get_hash([x[0] for x in self.samples]) + x["results"] = nf, nc, len(samples), samples + x["msgs"] = msgs # warnings save_dataset_cache_file(self.prefix, path, x) return samples @@ -299,6 +345,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): def load_dataset_cache_file(path): """Load an Ultralytics *.cache dictionary from path.""" import gc + gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 cache = np.load(str(path), allow_pickle=True).item() # load dict gc.enable() @@ -307,19 +354,29 @@ def load_dataset_cache_file(path): def save_dataset_cache_file(prefix, path, x): """Save an Ultralytics dataset *.cache dictionary x to path.""" - x['version'] = DATASET_CACHE_VERSION # add cache version + x["version"] = DATASET_CACHE_VERSION # add cache version if is_dir_writeable(path.parent): if path.exists(): path.unlink() # remove *.cache file if exists np.save(str(path), x) # save cache for next time - path.with_suffix('.cache.npy').rename(path) # remove .npy suffix - LOGGER.info(f'{prefix}New cache created: {path}') + path.with_suffix(".cache.npy").rename(path) # remove .npy suffix + LOGGER.info(f"{prefix}New cache created: {path}") else: - LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.') + LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.") # TODO: support semantic segmentation class SemanticDataset(BaseDataset): + """ + Semantic Segmentation Dataset. + + This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities + from the BaseDataset class. + + Note: + This class is currently a placeholder and needs to be populated with methods and attributes for supporting + semantic segmentation tasks. 
+ """ def __init__(self): """Initialize a SemanticDataset object.""" diff --git a/ultralytics/data/explorer/__init__.py b/ultralytics/data/explorer/__init__.py new file mode 100644 index 0000000..ce594dc --- /dev/null +++ b/ultralytics/data/explorer/__init__.py @@ -0,0 +1,5 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from .utils import plot_query_result + +__all__ = ["plot_query_result"] diff --git a/ultralytics/data/explorer/__pycache__/__init__.cpython-312.pyc b/ultralytics/data/explorer/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..31a7d18 Binary files /dev/null and b/ultralytics/data/explorer/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/data/explorer/__pycache__/__init__.cpython-39.pyc b/ultralytics/data/explorer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..93c0a20 Binary files /dev/null and b/ultralytics/data/explorer/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/data/explorer/__pycache__/explorer.cpython-312.pyc b/ultralytics/data/explorer/__pycache__/explorer.cpython-312.pyc new file mode 100644 index 0000000..00463e0 Binary files /dev/null and b/ultralytics/data/explorer/__pycache__/explorer.cpython-312.pyc differ diff --git a/ultralytics/data/explorer/__pycache__/explorer.cpython-39.pyc b/ultralytics/data/explorer/__pycache__/explorer.cpython-39.pyc new file mode 100644 index 0000000..c0d300d Binary files /dev/null and b/ultralytics/data/explorer/__pycache__/explorer.cpython-39.pyc differ diff --git a/ultralytics/data/explorer/__pycache__/utils.cpython-312.pyc b/ultralytics/data/explorer/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000..1303eee Binary files /dev/null and b/ultralytics/data/explorer/__pycache__/utils.cpython-312.pyc differ diff --git a/ultralytics/data/explorer/__pycache__/utils.cpython-39.pyc b/ultralytics/data/explorer/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000..c17a0e4 Binary files /dev/null and b/ultralytics/data/explorer/__pycache__/utils.cpython-39.pyc differ diff --git a/ultralytics/data/explorer/explorer.py b/ultralytics/data/explorer/explorer.py new file mode 100644 index 0000000..d21a5c2 --- /dev/null +++ b/ultralytics/data/explorer/explorer.py @@ -0,0 +1,472 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from io import BytesIO +from pathlib import Path +from typing import Any, List, Tuple, Union + +import cv2 +import numpy as np +import torch +from PIL import Image +from matplotlib import pyplot as plt +from pandas import DataFrame +from tqdm import tqdm + +from ultralytics.data.augment import Format +from ultralytics.data.dataset import YOLODataset +from ultralytics.data.utils import check_det_dataset +from ultralytics.models.yolo.model import YOLO +from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR +from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch + + +class ExplorerDataset(YOLODataset): + def __init__(self, *args, data: dict = None, **kwargs) -> None: + super().__init__(*args, data=data, **kwargs) + + def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]: + """Loads 1 image from dataset index 'i' without any resize ops.""" + im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] + if im is None: # not cached in RAM + if fn.exists(): # load npy + im = np.load(fn) + else: # read image + im = cv2.imread(f) # BGR + if im is None: + raise 
FileNotFoundError(f"Image Not Found {f}") + h0, w0 = im.shape[:2] # orig hw + return im, (h0, w0), im.shape[:2] + + return self.ims[i], self.im_hw0[i], self.im_hw[i] + + def build_transforms(self, hyp: IterableSimpleNamespace = None): + """Creates transforms for dataset images without resizing.""" + return Format( + bbox_format="xyxy", + normalize=False, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + batch_idx=True, + mask_ratio=hyp.mask_ratio, + mask_overlap=hyp.overlap_mask, + ) + + +class Explorer: + def __init__( + self, + data: Union[str, Path] = "coco128.yaml", + model: str = "yolov8n.pt", + uri: str = USER_CONFIG_DIR / "explorer", + ) -> None: + # Note duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181 + checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"]) + import lancedb + + self.connection = lancedb.connect(uri) + self.table_name = Path(data).name.lower() + "_" + model.lower() + self.sim_idx_base_name = ( + f"{self.table_name}_sim_idx".lower() + ) # Use this name and append thres and top_k to reuse the table + self.model = YOLO(model) + self.data = data # None + self.choice_set = None + + self.table = None + self.progress = 0 + + def create_embeddings_table(self, force: bool = False, split: str = "train") -> None: + """ + Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it + already exists. Pass force=True to overwrite the existing table. + + Args: + force (bool): Whether to overwrite the existing table or not. Defaults to False. + split (str): Split of the dataset to use. Defaults to 'train'. + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + ``` + """ + if self.table is not None and not force: + LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.") + return + if self.table_name in self.connection.table_names() and not force: + LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.") + self.table = self.connection.open_table(self.table_name) + self.progress = 1 + return + if self.data is None: + raise ValueError("Data must be provided to create embeddings table") + + data_info = check_det_dataset(self.data) + if split not in data_info: + raise ValueError( + f"Split {split} is not found in the dataset. 
Available keys in the dataset are {list(data_info.keys())}" + ) + + choice_set = data_info[split] + choice_set = choice_set if isinstance(choice_set, list) else [choice_set] + self.choice_set = choice_set + dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task) + + # Create the table schema + batch = dataset[0] + vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0] + table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite") + table.add( + self._yield_batches( + dataset, + data_info, + self.model, + exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"], + ) + ) + + self.table = table + + def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]): + """Generates batches of data for embedding, excluding specified keys.""" + for i in tqdm(range(len(dataset))): + self.progress = float(i + 1) / len(dataset) + batch = dataset[i] + for k in exclude_keys: + batch.pop(k, None) + batch = sanitize_batch(batch, data_info) + batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist() + yield [batch] + + def query( + self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25 + ) -> Any: # pyarrow.Table + """ + Query the table for similar images. Accepts a single image or a list of images. + + Args: + imgs (str or list): Path to the image or a list of paths to the images. + limit (int): Number of results to return. + + Returns: + (pyarrow.Table): An arrow table containing the results. Supports converting to: + - pandas dataframe: `result.to_pandas()` + - dict of lists: `result.to_pydict()` + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + similar = exp.query(img='https://ultralytics.com/images/zidane.jpg') + ``` + """ + if self.table is None: + raise ValueError("Table is not created. Please create the table first.") + if isinstance(imgs, str): + imgs = [imgs] + assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}" + embeds = self.model.embed(imgs) + # Get avg if multiple images are passed (len > 1) + embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() + return self.table.search(embeds).limit(limit).to_arrow() + + def sql_query( + self, query: str, return_type: str = "pandas" + ) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table + """ + Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown. + + Args: + query (str): SQL query to run. + return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. + + Returns: + (pyarrow.Table): An arrow table containing the results. + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" + result = exp.sql_query(query) + ``` + """ + assert return_type in { + "pandas", + "arrow", + }, f"Return type should be either `pandas` or `arrow`, but got {return_type}" + import duckdb + + if self.table is None: + raise ValueError("Table is not created. Please create the table first.") + + # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this. + table = self.table.to_arrow() # noqa NOTE: Don't comment this. 
This line is used by DuckDB + if not query.startswith("SELECT") and not query.startswith("WHERE"): + raise ValueError( + f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}" + ) + if query.startswith("WHERE"): + query = f"SELECT * FROM 'table' {query}" + LOGGER.info(f"Running query: {query}") + + rs = duckdb.sql(query) + if return_type == "arrow": + return rs.arrow() + elif return_type == "pandas": + return rs.df() + + def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image: + """ + Plot the results of a SQL-Like query on the table. + Args: + query (str): SQL query to run. + labels (bool): Whether to plot the labels or not. + + Returns: + (PIL.Image): Image containing the plot. + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" + result = exp.plot_sql_query(query) + ``` + """ + result = self.sql_query(query, return_type="arrow") + if len(result) == 0: + LOGGER.info("No results found.") + return None + img = plot_query_result(result, plot_labels=labels) + return Image.fromarray(img) + + def get_similar( + self, + img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, + idx: Union[int, List[int]] = None, + limit: int = 25, + return_type: str = "pandas", + ) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table + """ + Query the table for similar images. Accepts a single image or a list of images. + + Args: + img (str or list): Path to the image or a list of paths to the images. + idx (int or list): Index of the image in the table or a list of indexes. + limit (int): Number of results to return. Defaults to 25. + return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. + + Returns: + (pandas.DataFrame): A dataframe containing the results. + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg') + ``` + """ + assert return_type in { + "pandas", + "arrow", + }, f"Return type should be either `pandas` or `arrow`, but got {return_type}" + img = self._check_imgs_or_idxs(img, idx) + similar = self.query(img, limit=limit) + + if return_type == "arrow": + return similar + elif return_type == "pandas": + return similar.to_pandas() + + def plot_similar( + self, + img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, + idx: Union[int, List[int]] = None, + limit: int = 25, + labels: bool = True, + ) -> Image.Image: + """ + Plot the similar images. Accepts images or indexes. + + Args: + img (str or list): Path to the image or a list of paths to the images. + idx (int or list): Index of the image in the table or a list of indexes. + labels (bool): Whether to plot the labels or not. + limit (int): Number of results to return. Defaults to 25. + + Returns: + (PIL.Image): Image containing the plot. + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg') + ``` + """ + similar = self.get_similar(img, idx, limit, return_type="arrow") + if len(similar) == 0: + LOGGER.info("No results found.") + return None + img = plot_query_result(similar, plot_labels=labels) + return Image.fromarray(img) + + def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame: + """ + Calculate the similarity index of all the images in the table. 
Here, the index will contain the data points that + are max_dist or closer to the image in the embedding space at a given index. + + Args: + max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. + top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running + vector search. Defaults: None. + force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. + + Returns: + (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns + include indices of similar images and their respective distances. + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + sim_idx = exp.similarity_index() + ``` + """ + if self.table is None: + raise ValueError("Table is not created. Please create the table first.") + sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower() + if sim_idx_table_name in self.connection.table_names() and not force: + LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.") + return self.connection.open_table(sim_idx_table_name).to_pandas() + + if top_k and not (1.0 >= top_k >= 0.0): + raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}") + if max_dist < 0.0: + raise ValueError(f"max_dist must be greater than 0. Got {max_dist}") + + top_k = int(top_k * len(self.table)) if top_k else len(self.table) + top_k = max(top_k, 1) + features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict() + im_files = features["im_file"] + embeddings = features["vector"] + + sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite") + + def _yield_sim_idx(): + """Generates a dataframe with similarity indices and distances for images.""" + for i in tqdm(range(len(embeddings))): + sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}") + yield [ + { + "idx": i, + "im_file": im_files[i], + "count": len(sim_idx), + "sim_im_files": sim_idx["im_file"].tolist(), + } + ] + + sim_table.add(_yield_sim_idx()) + self.sim_index = sim_table + return sim_table.to_pandas() + + def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image: + """ + Plot the similarity index of all the images in the table. Here, the index will contain the data points that are + max_dist or closer to the image in the embedding space at a given index. + + Args: + max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. + top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when + running vector search. Defaults to 0.01. + force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. + + Returns: + (PIL.Image): Image containing the plot. 
+ + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + + similarity_idx_plot = exp.plot_similarity_index() + similarity_idx_plot.show() # view image preview + similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file + ``` + """ + sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) + sim_count = sim_idx["count"].tolist() + sim_count = np.array(sim_count) + + indices = np.arange(len(sim_count)) + + # Create the bar plot + plt.bar(indices, sim_count) + + # Customize the plot (optional) + plt.xlabel("data idx") + plt.ylabel("Count") + plt.title("Similarity Count") + buffer = BytesIO() + plt.savefig(buffer, format="png") + buffer.seek(0) + + # Use Pillow to open the image from the buffer + return Image.fromarray(np.array(Image.open(buffer))) + + def _check_imgs_or_idxs( + self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]] + ) -> List[np.ndarray]: + if img is None and idx is None: + raise ValueError("Either img or idx must be provided.") + if img is not None and idx is not None: + raise ValueError("Only one of img or idx must be provided.") + if idx is not None: + idx = idx if isinstance(idx, list) else [idx] + img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"] + + return img if isinstance(img, list) else [img] + + def ask_ai(self, query): + """ + Ask AI a question. + + Args: + query (str): Question to ask. + + Returns: + (pandas.DataFrame): A dataframe containing filtered results to the SQL query. + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + answer = exp.ask_ai('Show images with 1 person and 2 dogs') + ``` + """ + result = prompt_sql_query(query) + try: + df = self.sql_query(result) + except Exception as e: + LOGGER.error("AI generated query is not valid. Please try again with a different prompt") + LOGGER.error(e) + return None + return df + + def visualize(self, result): + """ + Visualize the results of a query. TODO. + + Args: + result (pyarrow.Table): Table containing the results of a query. + """ + pass + + def generate_report(self, result): + """ + Generate a report of the dataset. 
+ + TODO + """ + pass diff --git a/ultralytics/data/explorer/gui/__init__.py b/ultralytics/data/explorer/gui/__init__.py new file mode 100644 index 0000000..9e68dc1 --- /dev/null +++ b/ultralytics/data/explorer/gui/__init__.py @@ -0,0 +1 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license diff --git a/ultralytics/data/explorer/gui/dash.py b/ultralytics/data/explorer/gui/dash.py new file mode 100644 index 0000000..b082d49 --- /dev/null +++ b/ultralytics/data/explorer/gui/dash.py @@ -0,0 +1,268 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import time +from threading import Thread + +import pandas as pd + +from ultralytics import Explorer +from ultralytics.utils import ROOT, SETTINGS +from ultralytics.utils.checks import check_requirements + +check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3")) + +import streamlit as st +from streamlit_select import image_select + + +def _get_explorer(): + """Initializes and returns an instance of the Explorer class.""" + exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model")) + thread = Thread( + target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")} + ) + thread.start() + progress_bar = st.progress(0, text="Creating embeddings table...") + while exp.progress < 1: + time.sleep(0.1) + progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%") + thread.join() + st.session_state["explorer"] = exp + progress_bar.empty() + + +def init_explorer_form(): + """Initializes an Explorer instance and creates embeddings table with progress tracking.""" + datasets = ROOT / "cfg" / "datasets" + ds = [d.name for d in datasets.glob("*.yaml")] + models = [ + "yolov8n.pt", + "yolov8s.pt", + "yolov8m.pt", + "yolov8l.pt", + "yolov8x.pt", + "yolov8n-seg.pt", + "yolov8s-seg.pt", + "yolov8m-seg.pt", + "yolov8l-seg.pt", + "yolov8x-seg.pt", + "yolov8n-pose.pt", + "yolov8s-pose.pt", + "yolov8m-pose.pt", + "yolov8l-pose.pt", + "yolov8x-pose.pt", + ] + with st.form(key="explorer_init_form"): + col1, col2 = st.columns(2) + with col1: + st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml")) + with col2: + st.selectbox("Select model", models, key="model") + st.checkbox("Force recreate embeddings", key="force_recreate_embeddings") + + st.form_submit_button("Explore", on_click=_get_explorer) + + +def query_form(): + """Sets up a form in Streamlit to initialize Explorer with dataset and model selection.""" + with st.form("query_form"): + col1, col2 = st.columns([0.8, 0.2]) + with col1: + st.text_input( + "Query", + "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'", + label_visibility="collapsed", + key="query", + ) + with col2: + st.form_submit_button("Query", on_click=run_sql_query) + + +def ai_query_form(): + """Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection.""" + with st.form("ai_query_form"): + col1, col2 = st.columns([0.8, 0.2]) + with col1: + st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query") + with col2: + st.form_submit_button("Ask AI", on_click=run_ai_query) + + +def find_similar_imgs(imgs): + """Initializes a Streamlit form for AI-based image querying with custom input.""" + exp = st.session_state["explorer"] + similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow") + paths = similar.to_pydict()["im_file"] + st.session_state["imgs"] = paths + st.session_state["res"] = similar + + +def 
similarity_form(selected_imgs): + """Initializes a form for AI-based image querying with custom input in Streamlit.""" + st.write("Similarity Search") + with st.form("similarity_form"): + subcol1, subcol2 = st.columns([1, 1]) + with subcol1: + st.number_input( + "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit" + ) + + with subcol2: + disabled = not len(selected_imgs) + st.write("Selected: ", len(selected_imgs)) + st.form_submit_button( + "Search", + disabled=disabled, + on_click=find_similar_imgs, + args=(selected_imgs,), + ) + if disabled: + st.error("Select at least one image to search.") + + +# def persist_reset_form(): +# with st.form("persist_reset"): +# col1, col2 = st.columns([1, 1]) +# with col1: +# st.form_submit_button("Reset", on_click=reset) +# +# with col2: +# st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True)) + + +def run_sql_query(): + """Executes an SQL query and returns the results.""" + st.session_state["error"] = None + query = st.session_state.get("query") + if query.rstrip().lstrip(): + exp = st.session_state["explorer"] + res = exp.sql_query(query, return_type="arrow") + st.session_state["imgs"] = res.to_pydict()["im_file"] + st.session_state["res"] = res + + +def run_ai_query(): + """Execute SQL query and update session state with query results.""" + if not SETTINGS["openai_api_key"]: + st.session_state["error"] = ( + 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."' + ) + return + st.session_state["error"] = None + query = st.session_state.get("ai_query") + if query.rstrip().lstrip(): + exp = st.session_state["explorer"] + res = exp.ask_ai(query) + if not isinstance(res, pd.DataFrame) or res.empty: + st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it." + return + st.session_state["imgs"] = res["im_file"].to_list() + st.session_state["res"] = res + + +def reset_explorer(): + """Resets the explorer to its initial state by clearing session variables.""" + st.session_state["explorer"] = None + st.session_state["imgs"] = None + st.session_state["error"] = None + + +def utralytics_explorer_docs_callback(): + """Resets the explorer to its initial state by clearing session variables.""" + with st.container(border=True): + st.image( + "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg", + width=100, + ) + st.markdown( + "
This demo is built using Ultralytics Explorer API. Visit API docs to try examples & learn more
", + unsafe_allow_html=True, + help=None, + ) + st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/") + + +def layout(): + """Resets explorer session variables and provides documentation with a link to API docs.""" + st.set_page_config(layout="wide", initial_sidebar_state="collapsed") + st.markdown("
Ultralytics Explorer Demo
", unsafe_allow_html=True) + + if st.session_state.get("explorer") is None: + init_explorer_form() + return + + st.button(":arrow_backward: Select Dataset", on_click=reset_explorer) + exp = st.session_state.get("explorer") + col1, col2 = st.columns([0.75, 0.25], gap="small") + imgs = [] + if st.session_state.get("error"): + st.error(st.session_state["error"]) + else: + if st.session_state.get("imgs"): + imgs = st.session_state.get("imgs") + else: + imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"] + st.session_state["res"] = exp.table.to_arrow() + total_imgs, selected_imgs = len(imgs), [] + with col1: + subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5) + with subcol1: + st.write("Max Images Displayed:") + with subcol2: + num = st.number_input( + "Max Images Displayed", + min_value=0, + max_value=total_imgs, + value=min(500, total_imgs), + key="num_imgs_displayed", + label_visibility="collapsed", + ) + with subcol3: + st.write("Start Index:") + with subcol4: + start_idx = st.number_input( + "Start Index", + min_value=0, + max_value=total_imgs, + value=0, + key="start_index", + label_visibility="collapsed", + ) + with subcol5: + reset = st.button("Reset", use_container_width=False, key="reset") + if reset: + st.session_state["imgs"] = None + st.experimental_rerun() + + query_form() + ai_query_form() + if total_imgs: + labels, boxes, masks, kpts, classes = None, None, None, None, None + task = exp.model.task + if st.session_state.get("display_labels"): + labels = st.session_state.get("res").to_pydict()["labels"][start_idx : start_idx + num] + boxes = st.session_state.get("res").to_pydict()["bboxes"][start_idx : start_idx + num] + masks = st.session_state.get("res").to_pydict()["masks"][start_idx : start_idx + num] + kpts = st.session_state.get("res").to_pydict()["keypoints"][start_idx : start_idx + num] + classes = st.session_state.get("res").to_pydict()["cls"][start_idx : start_idx + num] + imgs_displayed = imgs[start_idx : start_idx + num] + selected_imgs = image_select( + f"Total samples: {total_imgs}", + images=imgs_displayed, + use_container_width=False, + # indices=[i for i in range(num)] if select_all else None, + labels=labels, + classes=classes, + bboxes=boxes, + masks=masks if task == "segment" else None, + kpts=kpts if task == "pose" else None, + ) + + with col2: + similarity_form(selected_imgs) + display_labels = st.checkbox("Labels", value=False, key="display_labels") + utralytics_explorer_docs_callback() + + +if __name__ == "__main__": + layout() diff --git a/ultralytics/data/explorer/utils.py b/ultralytics/data/explorer/utils.py new file mode 100644 index 0000000..d1c4b9b --- /dev/null +++ b/ultralytics/data/explorer/utils.py @@ -0,0 +1,166 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import getpass +from typing import List + +import cv2 +import numpy as np +import pandas as pd + +from ultralytics.data.augment import LetterBox +from ultralytics.utils import LOGGER as logger +from ultralytics.utils import SETTINGS +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.ops import xyxy2xywh +from ultralytics.utils.plotting import plot_images + + +def get_table_schema(vector_size): + """Extracts and returns the schema of a database table.""" + from lancedb.pydantic import LanceModel, Vector + + class Schema(LanceModel): + im_file: str + labels: List[str] + cls: List[int] + bboxes: List[List[float]] + masks: List[List[List[int]]] + keypoints: List[List[List[float]]] + vector: Vector(vector_size) + + return Schema + 
+ +def get_sim_index_schema(): + """Returns a LanceModel schema for a database table with specified vector size.""" + from lancedb.pydantic import LanceModel + + class Schema(LanceModel): + idx: int + im_file: str + count: int + sim_im_files: List[str] + + return Schema + + +def sanitize_batch(batch, dataset_info): + """Sanitizes input batch for inference, ensuring correct format and dimensions.""" + batch["cls"] = batch["cls"].flatten().int().tolist() + box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1]) + batch["bboxes"] = [box for box, _ in box_cls_pair] + batch["cls"] = [cls for _, cls in box_cls_pair] + batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]] + batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]] + batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]] + return batch + + +def plot_query_result(similar_set, plot_labels=True): + """ + Plot images from the similar set. + + Args: + similar_set (list): Pyarrow or pandas object containing the similar data points + plot_labels (bool): Whether to plot labels or not + """ + similar_set = ( + similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict() + ) + empty_masks = [[[]]] + empty_boxes = [[]] + images = similar_set.get("im_file", []) + bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else [] + masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else [] + kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else [] + cls = similar_set.get("cls", []) + + plot_size = 640 + imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], [] + for i, imf in enumerate(images): + im = cv2.imread(imf) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + h, w = im.shape[:2] + r = min(plot_size / h, plot_size / w) + imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1)) + if plot_labels: + if len(bboxes) > i and len(bboxes[i]) > 0: + box = np.array(bboxes[i], dtype=np.float32) + box[:, [0, 2]] *= r + box[:, [1, 3]] *= r + plot_boxes.append(box) + if len(masks) > i and len(masks[i]) > 0: + mask = np.array(masks[i], dtype=np.uint8)[0] + plot_masks.append(LetterBox(plot_size, center=False)(image=mask)) + if len(kpts) > i and kpts[i] is not None: + kpt = np.array(kpts[i], dtype=np.float32) + kpt[:, :, :2] *= r + plot_kpts.append(kpt) + batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i) + imgs = np.stack(imgs, axis=0) + masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8) + kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32) + boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32) + batch_idx = np.concatenate(batch_idx, axis=0) + cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0) + + return plot_images( + imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False + ) + + +def prompt_sql_query(query): + """Plots images with optional labels from a similar data set.""" + check_requirements("openai>=1.6.1") + from openai import OpenAI + + if not SETTINGS["openai_api_key"]: + logger.warning("OpenAI API key not found in settings. 
Please enter your API key below.") + openai_api_key = getpass.getpass("OpenAI API key: ") + SETTINGS.update({"openai_api_key": openai_api_key}) + openai = OpenAI(api_key=SETTINGS["openai_api_key"]) + + messages = [ + { + "role": "system", + "content": """ + You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on + the following schema and a user request. You only need to output the format with fixed selection + statement that selects everything from "'table'", like `SELECT * from 'table'` + + Schema: + im_file: string not null + labels: list not null + child 0, item: string + cls: list not null + child 0, item: int64 + bboxes: list> not null + child 0, item: list + child 0, item: double + masks: list>> not null + child 0, item: list> + child 0, item: list + child 0, item: int64 + keypoints: list>> not null + child 0, item: list> + child 0, item: list + child 0, item: double + vector: fixed_size_list[256] not null + child 0, item: float + + Some details about the schema: + - the "labels" column contains the string values like 'person' and 'dog' for the respective objects + in each image + - the "cls" column contains the integer values on these classes that map them the labels + + Example of a correct query: + request - Get all data points that contain 2 or more people and at least one dog + correct query- + SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1; + """, + }, + {"role": "user", "content": f"{query}"}, + ] + + response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages) + return response.choices[0].message.content diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py index 6656596..4b89770 100644 --- a/ultralytics/data/loaders.py +++ b/ultralytics/data/loaders.py @@ -22,76 +22,114 @@ from ultralytics.utils.checks import check_requirements @dataclass class SourceTypes: - webcam: bool = False + """Class to represent various types of input sources for predictions.""" + + stream: bool = False screenshot: bool = False from_img: bool = False tensor: bool = False class LoadStreams: - """YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`.""" + """ + Stream Loader for various types of video streams, Supports RTSP, RTMP, HTTP, and TCP streams. - def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, stream_buffer=False): + Attributes: + sources (str): The source input paths or URLs for the video streams. + vid_stride (int): Video frame-rate stride, defaults to 1. + buffer (bool): Whether to buffer input streams, defaults to False. + running (bool): Flag to indicate if the streaming thread is running. + mode (str): Set to 'stream' indicating real-time capture. + imgs (list): List of image frames for each stream. + fps (list): List of FPS for each stream. + frames (list): List of total frames for each stream. + threads (list): List of threads for each stream. + shape (list): List of shapes for each stream. + caps (list): List of cv2.VideoCapture objects for each stream. + bs (int): Batch size for processing. + + Methods: + __init__: Initialize the stream loader. + update: Read stream frames in daemon thread. + close: Close stream loader and release resources. + __iter__: Returns an iterator object for the class. + __next__: Returns source paths, transformed, and original images for processing. 
+ __len__: Return the length of the sources object. + + Example: + ```bash + yolo predict source='rtsp://example.com/media.mp4' + ``` + """ + + def __init__(self, sources="file.streams", vid_stride=1, buffer=False): """Initialize instance variables and check for consistent input stream shapes.""" torch.backends.cudnn.benchmark = True # faster for fixed-size inference - self.stream_buffer = stream_buffer # buffer input streams + self.buffer = buffer # buffer input streams self.running = True # running flag for Thread - self.mode = 'stream' - self.imgsz = imgsz + self.mode = "stream" self.vid_stride = vid_stride # video frame-rate stride + sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources] n = len(sources) - self.sources = [ops.clean_str(x) for x in sources] # clean source names for later - self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [None] * n + self.bs = n + self.fps = [0] * n # frames per second + self.frames = [0] * n + self.threads = [None] * n self.caps = [None] * n # video capture objects + self.imgs = [[] for _ in range(n)] # images + self.shape = [[] for _ in range(n)] # image shapes + self.sources = [ops.clean_str(x) for x in sources] # clean source names for later for i, s in enumerate(sources): # index, source # Start thread to read frames from video stream - st = f'{i + 1}/{n}: {s}... ' - if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video - # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc' + st = f"{i + 1}/{n}: {s}... " + if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video + # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4' s = get_best_youtube_url(s) s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam if s == 0 and (is_colab() or is_kaggle()): - raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. " - "Try running 'source=0' in a local environment.") + raise NotImplementedError( + "'source=0' webcam not supported in Colab and Kaggle notebooks. " + "Try running 'source=0' in a local environment." 
+ ) self.caps[i] = cv2.VideoCapture(s) # store video capture object if not self.caps[i].isOpened(): - raise ConnectionError(f'{st}Failed to open {s}') + raise ConnectionError(f"{st}Failed to open {s}") w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float( - 'inf') # infinite stream fallback + "inf" + ) # infinite stream fallback self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback success, im = self.caps[i].read() # guarantee first frame if not success or im is None: - raise ConnectionError(f'{st}Failed to read images from {s}') + raise ConnectionError(f"{st}Failed to read images from {s}") self.imgs[i].append(im) self.shape[i] = im.shape self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True) - LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)') + LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)") self.threads[i].start() - LOGGER.info('') # newline - - # Check for common shapes - self.bs = self.__len__() + LOGGER.info("") # newline def update(self, i, cap, stream): """Read stream `i` frames in daemon thread.""" n, f = 0, self.frames[i] # frame number, frame array while self.running and cap.isOpened() and n < (f - 1): - # Only read a new frame if the buffer is empty - if not self.imgs[i] or not self.stream_buffer: + if len(self.imgs[i]) < 30: # keep a <=30-image buffer n += 1 cap.grab() # .read() = .grab() followed by .retrieve() if n % self.vid_stride == 0: success, im = cap.retrieve() if not success: im = np.zeros(self.shape[i], dtype=np.uint8) - LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.') + LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.") cap.open(stream) # re-open stream if signal was lost - self.imgs[i].append(im) # add image to buffer + if self.buffer: + self.imgs[i].append(im) + else: + self.imgs[i] = [im] else: time.sleep(0.01) # wait until the buffer is empty @@ -105,7 +143,7 @@ class LoadStreams: try: cap.release() # release video capture except Exception as e: - LOGGER.warning(f'WARNING ⚠️ Could not release VideoCapture object: {e}') + LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}") cv2.destroyAllWindows() def __iter__(self): @@ -117,36 +155,62 @@ class LoadStreams: """Returns source paths, transformed and original images for processing.""" self.count += 1 - # Wait until a frame is available in each buffer - while not all(self.imgs): - if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit - self.close() - raise StopIteration - time.sleep(1 / min(self.fps)) + images = [] + for i, x in enumerate(self.imgs): + # Wait until a frame is available in each buffer + while not x: + if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"): # q to quit + self.close() + raise StopIteration + time.sleep(1 / min(self.fps)) + x = self.imgs[i] + if not x: + LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}") - # Get and remove the next frame from imgs buffer - if self.stream_buffer: - images = [x.pop(0) for x in self.imgs] - else: - # Get the latest frame, and clear the rest from the imgs buffer - images = [] - for x in self.imgs: - images.append(x.pop(-1) if x else None) + 
# Get and remove the first frame from imgs buffer + if self.buffer: + images.append(x.pop(0)) + + # Get the last frame, and clear the rest from the imgs buffer + else: + images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8)) x.clear() - return self.sources, images, None, '' + return self.sources, images, [""] * self.bs def __len__(self): """Return the length of the sources object.""" - return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years + return self.bs # 1E12 frames = 32 streams at 30 FPS for 30 years class LoadScreenshots: - """YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`.""" + """ + YOLOv8 screenshot dataloader. - def __init__(self, source, imgsz=640): - """source = [screen_number left top width height] (pixels).""" - check_requirements('mss') + This class manages the loading of screenshot images for processing with YOLOv8. + Suitable for use with `yolo predict source=screen`. + + Attributes: + source (str): The source input indicating which screen to capture. + screen (int): The screen number to capture. + left (int): The left coordinate for screen capture area. + top (int): The top coordinate for screen capture area. + width (int): The width of the screen capture area. + height (int): The height of the screen capture area. + mode (str): Set to 'stream' indicating real-time capture. + frame (int): Counter for captured frames. + sct (mss.mss): Screen capture object from `mss` library. + bs (int): Batch size, set to 1. + monitor (dict): Monitor configuration details. + + Methods: + __iter__: Returns an iterator object. + __next__: Captures the next screenshot and returns it. + """ + + def __init__(self, source): + """Source = [screen_number left top width height] (pixels).""" + check_requirements("mss") import mss # noqa source, *params = source.split() @@ -157,19 +221,19 @@ class LoadScreenshots: left, top, width, height = (int(x) for x in params) elif len(params) == 5: self.screen, left, top, width, height = (int(x) for x in params) - self.imgsz = imgsz - self.mode = 'stream' + self.mode = "stream" self.frame = 0 self.sct = mss.mss() self.bs = 1 + self.fps = 30 # Parse monitor shape monitor = self.sct.monitors[self.screen] - self.top = monitor['top'] if top is None else (monitor['top'] + top) - self.left = monitor['left'] if left is None else (monitor['left'] + left) - self.width = width or monitor['width'] - self.height = height or monitor['height'] - self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height} + self.top = monitor["top"] if top is None else (monitor["top"] + top) + self.left = monitor["left"] if left is None else (monitor["left"] + left) + self.width = width or monitor["width"] + self.height = height or monitor["height"] + self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height} def __iter__(self): """Returns an iterator of the object.""" @@ -178,53 +242,75 @@ class LoadScreenshots: def __next__(self): """mss screen capture: get raw pixels from the screen as np array.""" im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR - s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: ' + s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: " self.frame += 1 - return [str(self.screen)], [im0], None, s # screen, img, vid_cap, string + return [str(self.screen)], [im0], [s] # screen, img, string -class LoadImages: - """YOLOv8 image/video dataloader, i.e. 
`yolo predict source=image.jpg/vid.mp4`.""" +class LoadImagesAndVideos: + """ + YOLOv8 image/video dataloader. - def __init__(self, path, imgsz=640, vid_stride=1): + This class manages the loading and pre-processing of image and video data for YOLOv8. It supports loading from + various formats, including single image files, video files, and lists of image and video paths. + + Attributes: + files (list): List of image and video file paths. + nf (int): Total number of files (images and videos). + video_flag (list): Flags indicating whether a file is a video (True) or an image (False). + mode (str): Current mode, 'image' or 'video'. + vid_stride (int): Stride for video frame-rate, defaults to 1. + bs (int): Batch size, set to 1 for this class. + cap (cv2.VideoCapture): Video capture object for OpenCV. + frame (int): Frame counter for video. + frames (int): Total number of frames in the video. + count (int): Counter for iteration, initialized at 0 during `__iter__()`. + + Methods: + _new_video(path): Create a new cv2.VideoCapture object for a given video path. + """ + + def __init__(self, path, batch=1, vid_stride=1): """Initialize the Dataloader and raise FileNotFoundError if file not found.""" parent = None - if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line + if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line parent = Path(path).parent path = Path(path).read_text().splitlines() # list of sources files = [] for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912 - if '*' in a: + if "*" in a: files.extend(sorted(glob.glob(a, recursive=True))) # glob elif os.path.isdir(a): - files.extend(sorted(glob.glob(os.path.join(a, '*.*')))) # dir + files.extend(sorted(glob.glob(os.path.join(a, "*.*")))) # dir elif os.path.isfile(a): files.append(a) # files (absolute or relative to CWD) elif parent and (parent / p).is_file(): files.append(str((parent / p).absolute())) # files (relative to *.txt file parent) else: - raise FileNotFoundError(f'{p} does not exist') + raise FileNotFoundError(f"{p} does not exist") - images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS] - videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS] + images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS] + videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS] ni, nv = len(images), len(videos) - self.imgsz = imgsz self.files = images + videos self.nf = ni + nv # number of files + self.ni = ni # number of images self.video_flag = [False] * ni + [True] * nv - self.mode = 'image' + self.mode = "image" self.vid_stride = vid_stride # video frame-rate stride - self.bs = 1 + self.bs = batch if any(videos): self._new_video(videos[0]) # new video else: self.cap = None if self.nf == 0: - raise FileNotFoundError(f'No images or videos found in {p}. ' - f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}') + raise FileNotFoundError( + f"No images or videos found in {p}. 
" + f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}" + ) def __iter__(self): """Returns an iterator object for VideoStream or ImageFolder.""" @@ -232,71 +318,105 @@ class LoadImages: return self def __next__(self): - """Return next image, path and metadata from dataset.""" - if self.count == self.nf: - raise StopIteration - path = self.files[self.count] - - if self.video_flag[self.count]: - # Read video - self.mode = 'video' - for _ in range(self.vid_stride): - self.cap.grab() - success, im0 = self.cap.retrieve() - while not success: - self.count += 1 - self.cap.release() - if self.count == self.nf: # last video + """Returns the next batch of images or video frames along with their paths and metadata.""" + paths, imgs, info = [], [], [] + while len(imgs) < self.bs: + if self.count >= self.nf: # end of file list + if len(imgs) > 0: + return paths, imgs, info # return last partial batch + else: raise StopIteration - path = self.files[self.count] - self._new_video(path) - success, im0 = self.cap.read() - self.frame += 1 - # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False - s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ' + path = self.files[self.count] + if self.video_flag[self.count]: + self.mode = "video" + if not self.cap or not self.cap.isOpened(): + self._new_video(path) - else: - # Read image - self.count += 1 - im0 = cv2.imread(path) # BGR - if im0 is None: - raise FileNotFoundError(f'Image Not Found {path}') - s = f'image {self.count}/{self.nf} {path}: ' + for _ in range(self.vid_stride): + success = self.cap.grab() + if not success: + break # end of video or failure - return [path], [im0], self.cap, s + if success: + success, im0 = self.cap.retrieve() + if success: + self.frame += 1 + paths.append(path) + imgs.append(im0) + info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ") + if self.frame == self.frames: # end of video + self.count += 1 + self.cap.release() + else: + # Move to the next file if the current video ended or failed to open + self.count += 1 + if self.cap: + self.cap.release() + if self.count < self.nf: + self._new_video(self.files[self.count]) + else: + self.mode = "image" + im0 = cv2.imread(path) # BGR + if im0 is None: + raise FileNotFoundError(f"Image Not Found {path}") + paths.append(path) + imgs.append(im0) + info.append(f"image {self.count + 1}/{self.nf} {path}: ") + self.count += 1 # move to the next file + if self.count >= self.ni: # end of image list + break + + return paths, imgs, info def _new_video(self, path): - """Create a new video capture object.""" + """Creates a new video capture object for the given path.""" self.frame = 0 self.cap = cv2.VideoCapture(path) + self.fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + if not self.cap.isOpened(): + raise FileNotFoundError(f"Failed to open video {path}") self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride) def __len__(self): - """Returns the number of files in the object.""" - return self.nf # number of files + """Returns the number of batches in the object.""" + return math.ceil(self.nf / self.bs) # number of files class LoadPilAndNumpy: + """ + Load images from PIL and Numpy arrays for batch processing. - def __init__(self, im0, imgsz=640): + This class is designed to manage loading and pre-processing of image data from both PIL and Numpy formats. + It performs basic validation and format conversion to ensure that the images are in the required format for + downstream processing. 
+ + Attributes: + paths (list): List of image paths or autogenerated filenames. + im0 (list): List of images stored as Numpy arrays. + mode (str): Type of data being processed, defaults to 'image'. + bs (int): Batch size, equivalent to the length of `im0`. + + Methods: + _single_check(im): Validate and format a single image to a Numpy array. + """ + + def __init__(self, im0): """Initialize PIL and Numpy Dataloader.""" if not isinstance(im0, list): im0 = [im0] - self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] + self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] self.im0 = [self._single_check(im) for im in im0] - self.imgsz = imgsz - self.mode = 'image' - # Generate fake paths + self.mode = "image" self.bs = len(self.im0) @staticmethod def _single_check(im): """Validate and format an image to numpy array.""" - assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}' + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" if isinstance(im, Image.Image): - if im.mode != 'RGB': - im = im.convert('RGB') + if im.mode != "RGB": + im = im.convert("RGB") im = np.asarray(im)[:, :, ::-1] im = np.ascontiguousarray(im) # contiguous return im @@ -310,7 +430,7 @@ class LoadPilAndNumpy: if self.count == 1: # loop only once as it's batch inference raise StopIteration self.count += 1 - return self.paths, self.im0, None, '' + return self.paths, self.im0, [""] * self.bs def __iter__(self): """Enables iteration for class LoadPilAndNumpy.""" @@ -319,18 +439,36 @@ class LoadPilAndNumpy: class LoadTensor: + """ + Load images from torch.Tensor data. + + This class manages the loading and pre-processing of image data from PyTorch tensors for further processing. + + Attributes: + im0 (torch.Tensor): The input tensor containing the image(s). + bs (int): Batch size, inferred from the shape of `im0`. + mode (str): Current mode, set to 'image'. + paths (list): List of image paths or filenames. + count (int): Counter for iteration, initialized at 0 during `__iter__()`. + + Methods: + _single_check(im, stride): Validate and possibly modify the input tensor. + """ def __init__(self, im0) -> None: + """Initialize Tensor Dataloader.""" self.im0 = self._single_check(im0) self.bs = self.im0.shape[0] - self.mode = 'image' - self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] + self.mode = "image" + self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] @staticmethod def _single_check(im, stride=32): """Validate and format an image to torch.Tensor.""" - s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \ - f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.' + s = ( + f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) " + f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible." + ) if len(im.shape) != 4: if len(im.shape) != 3: raise ValueError(s) @@ -338,9 +476,11 @@ class LoadTensor: im = im.unsqueeze(0) if im.shape[2] % stride or im.shape[3] % stride: raise ValueError(s) - if im.max() > 1.0: - LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. 
' - f'Dividing input by 255.') + if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07 + LOGGER.warning( + f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. " + f"Dividing input by 255." + ) im = im.float() / 255.0 return im @@ -355,7 +495,7 @@ class LoadTensor: if self.count == 1: raise StopIteration self.count += 1 - return self.paths, self.im0, None, '' + return self.paths, self.im0, [""] * self.bs def __len__(self): """Returns the batch size.""" @@ -363,26 +503,23 @@ class LoadTensor: def autocast_list(source): - """ - Merges a list of source of different types into a list of numpy arrays or PIL images - """ + """Merges a list of source of different types into a list of numpy arrays or PIL images.""" files = [] for im in source: if isinstance(im, (str, Path)): # filename or uri - files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im)) + files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im)) elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image files.append(im) else: - raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n' - f'See https://docs.ultralytics.com/modes/predict for supported source types.') + raise TypeError( + f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n" + f"See https://docs.ultralytics.com/modes/predict for supported source types." + ) return files -LOADERS = LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots # tuple - - -def get_best_youtube_url(url, use_pafy=False): +def get_best_youtube_url(url, use_pafy=True): """ Retrieves the URL of the best quality MP4 video stream from a given YouTube video. @@ -397,16 +534,22 @@ def get_best_youtube_url(url, use_pafy=False): (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found. 
""" if use_pafy: - check_requirements(('pafy', 'youtube_dl==2020.12.2')) + check_requirements(("pafy", "youtube_dl==2020.12.2")) import pafy # noqa - return pafy.new(url).getbestvideo(preftype='mp4').url + + return pafy.new(url).getbestvideo(preftype="mp4").url else: - check_requirements('yt-dlp') + check_requirements("yt-dlp") import yt_dlp - with yt_dlp.YoutubeDL({'quiet': True}) as ydl: + + with yt_dlp.YoutubeDL({"quiet": True}) as ydl: info_dict = ydl.extract_info(url, download=False) # extract info - for f in reversed(info_dict.get('formats', [])): # reversed because best is usually last + for f in reversed(info_dict.get("formats", [])): # reversed because best is usually last # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size - good_size = (f.get('width') or 0) >= 1920 or (f.get('height') or 0) >= 1080 - if good_size and f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4': - return f.get('url') + good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080 + if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4": + return f.get("url") + + +# Define constants +LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots) diff --git a/ultralytics/data/scripts/get_coco.sh b/ultralytics/data/scripts/get_coco.sh index 126e7f0..764e280 100644 --- a/ultralytics/data/scripts/get_coco.sh +++ b/ultralytics/data/scripts/get_coco.sh @@ -1,6 +1,6 @@ #!/bin/bash # Ultralytics YOLO 🚀, AGPL-3.0 license -# Download COCO 2017 dataset http://cocodataset.org +# Download COCO 2017 dataset https://cocodataset.org # Example usage: bash data/scripts/get_coco.sh # parent # ├── ultralytics diff --git a/ultralytics/data/split_dota.py b/ultralytics/data/split_dota.py new file mode 100644 index 0000000..8a5469b --- /dev/null +++ b/ultralytics/data/split_dota.py @@ -0,0 +1,288 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import itertools +from glob import glob +from math import ceil +from pathlib import Path + +import cv2 +import numpy as np +from PIL import Image +from tqdm import tqdm + +from ultralytics.data.utils import exif_size, img2label_paths +from ultralytics.utils.checks import check_requirements + +check_requirements("shapely") +from shapely.geometry import Polygon + + +def bbox_iof(polygon1, bbox2, eps=1e-6): + """ + Calculate iofs between bbox1 and bbox2. + + Args: + polygon1 (np.ndarray): Polygon coordinates, (n, 8). + bbox2 (np.ndarray): Bounding boxes, (n ,4). 
+ """ + polygon1 = polygon1.reshape(-1, 4, 2) + lt_point = np.min(polygon1, axis=-2) + rb_point = np.max(polygon1, axis=-2) + bbox1 = np.concatenate([lt_point, rb_point], axis=-1) + + lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2]) + rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:]) + wh = np.clip(rb - lt, 0, np.inf) + h_overlaps = wh[..., 0] * wh[..., 1] + + l, t, r, b = (bbox2[..., i] for i in range(4)) + polygon2 = np.stack([l, t, r, t, r, b, l, b], axis=-1).reshape(-1, 4, 2) + + sg_polys1 = [Polygon(p) for p in polygon1] + sg_polys2 = [Polygon(p) for p in polygon2] + overlaps = np.zeros(h_overlaps.shape) + for p in zip(*np.nonzero(h_overlaps)): + overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area + unions = np.array([p.area for p in sg_polys1], dtype=np.float32) + unions = unions[..., None] + + unions = np.clip(unions, eps, np.inf) + outputs = overlaps / unions + if outputs.ndim == 1: + outputs = outputs[..., None] + return outputs + + +def load_yolo_dota(data_root, split="train"): + """ + Load DOTA dataset. + + Args: + data_root (str): Data root. + split (str): The split data set, could be train or val. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - train + - val + - labels + - train + - val + """ + assert split in ["train", "val"] + im_dir = Path(data_root) / "images" / split + assert im_dir.exists(), f"Can't find {im_dir}, please check your data root." + im_files = glob(str(Path(data_root) / "images" / split / "*")) + lb_files = img2label_paths(im_files) + annos = [] + for im_file, lb_file in zip(im_files, lb_files): + w, h = exif_size(Image.open(im_file)) + with open(lb_file) as f: + lb = [x.split() for x in f.read().strip().splitlines() if len(x)] + lb = np.array(lb, dtype=np.float32) + annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file)) + return annos + + +def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.01): + """ + Get the coordinates of windows. + + Args: + im_size (tuple): Original image size, (h, w). + crop_sizes (List(int)): Crop size of windows. + gaps (List(int)): Gap between crops. + im_rate_thr (float): Threshold of windows areas divided by image ares. 
+ """ + h, w = im_size + windows = [] + for crop_size, gap in zip(crop_sizes, gaps): + assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]" + step = crop_size - gap + + xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1) + xs = [step * i for i in range(xn)] + if len(xs) > 1 and xs[-1] + crop_size > w: + xs[-1] = w - crop_size + + yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1) + ys = [step * i for i in range(yn)] + if len(ys) > 1 and ys[-1] + crop_size > h: + ys[-1] = h - crop_size + + start = np.array(list(itertools.product(xs, ys)), dtype=np.int64) + stop = start + crop_size + windows.append(np.concatenate([start, stop], axis=1)) + windows = np.concatenate(windows, axis=0) + + im_in_wins = windows.copy() + im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w) + im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h) + im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1]) + win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1]) + im_rates = im_areas / win_areas + if not (im_rates > im_rate_thr).any(): + max_rate = im_rates.max() + im_rates[abs(im_rates - max_rate) < eps] = 1 + return windows[im_rates > im_rate_thr] + + +def get_window_obj(anno, windows, iof_thr=0.7): + """Get objects for each window.""" + h, w = anno["ori_size"] + label = anno["label"] + if len(label): + label[:, 1::2] *= w + label[:, 2::2] *= h + iofs = bbox_iof(label[:, 1:], windows) + # Unnormalized and misaligned coordinates + return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))] # window_anns + else: + return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))] # window_anns + + +def crop_and_save(anno, windows, window_objs, im_dir, lb_dir): + """ + Crop images and save new labels. + + Args: + anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys. + windows (list): A list of windows coordinates. + window_objs (list): A list of labels inside each window. + im_dir (str): The output directory path of images. + lb_dir (str): The output directory path of labels. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - train + - val + - labels + - train + - val + """ + im = cv2.imread(anno["filepath"]) + name = Path(anno["filepath"]).stem + for i, window in enumerate(windows): + x_start, y_start, x_stop, y_stop = window.tolist() + new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}" + patch_im = im[y_start:y_stop, x_start:x_stop] + ph, pw = patch_im.shape[:2] + + cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im) + label = window_objs[i] + if len(label) == 0: + continue + label[:, 1::2] -= x_start + label[:, 2::2] -= y_start + label[:, 1::2] /= pw + label[:, 2::2] /= ph + + with open(Path(lb_dir) / f"{new_name}.txt", "w") as f: + for lb in label: + formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]] + f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n") + + +def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=[1024], gaps=[200]): + """ + Split both images and labels. 
+ + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - split + - labels + - split + and the output directory structure is: + - save_dir + - images + - split + - labels + - split + """ + im_dir = Path(save_dir) / "images" / split + im_dir.mkdir(parents=True, exist_ok=True) + lb_dir = Path(save_dir) / "labels" / split + lb_dir.mkdir(parents=True, exist_ok=True) + + annos = load_yolo_dota(data_root, split=split) + for anno in tqdm(annos, total=len(annos), desc=split): + windows = get_windows(anno["ori_size"], crop_sizes, gaps) + window_objs = get_window_obj(anno, windows) + crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir)) + + +def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]): + """ + Split train and val set of DOTA. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - train + - val + - labels + - train + - val + and the output directory structure is: + - save_dir + - images + - train + - val + - labels + - train + - val + """ + crop_sizes, gaps = [], [] + for r in rates: + crop_sizes.append(int(crop_size / r)) + gaps.append(int(gap / r)) + for split in ["train", "val"]: + split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps) + + +def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]): + """ + Split test set of DOTA, labels are not included within this set. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - test + and the output directory structure is: + - save_dir + - images + - test + """ + crop_sizes, gaps = [], [] + for r in rates: + crop_sizes.append(int(crop_size / r)) + gaps.append(int(gap / r)) + save_dir = Path(save_dir) / "images" / "test" + save_dir.mkdir(parents=True, exist_ok=True) + + im_dir = Path(data_root) / "images" / "test" + assert im_dir.exists(), f"Can't find {im_dir}, please check your data root." + im_files = glob(str(im_dir / "*")) + for im_file in tqdm(im_files, total=len(im_files), desc="test"): + w, h = exif_size(Image.open(im_file)) + windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps) + im = cv2.imread(im_file) + name = Path(im_file).stem + for window in windows: + x_start, y_start, x_stop, y_stop = window.tolist() + new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}" + patch_im = im[y_start:y_stop, x_start:x_stop] + cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im) + + +if __name__ == "__main__": + split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split") + split_test(data_root="DOTAv2", save_dir="DOTAv2-split") diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py index 3b780f2..c0a0773 100644 --- a/ultralytics/data/utils.py +++ b/ultralytics/data/utils.py @@ -17,36 +17,47 @@ import numpy as np from PIL import Image, ImageOps from ultralytics.nn.autobackend import check_class_names -from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, TQDM, clean_url, colorstr, - emojis, yaml_load) +from ultralytics.utils import ( + DATASETS_DIR, + LOGGER, + NUM_THREADS, + ROOT, + SETTINGS_YAML, + TQDM, + clean_url, + colorstr, + emojis, + yaml_load, + yaml_save, +) from ultralytics.utils.checks import check_file, check_font, is_ascii from ultralytics.utils.downloads import download, safe_download, unzip_file from ultralytics.utils.ops import segments2boxes -HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.' 
-IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes -VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes -PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders +HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance." +IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes +VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes +PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders def img2label_paths(img_paths): """Define label paths as a function of image paths.""" - sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings - return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths] + sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings + return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths] def get_hash(paths): """Returns a single hash value of a list of paths (files or dirs).""" size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes h = hashlib.sha256(str(size).encode()) # hash sizes - h.update(''.join(paths).encode()) # hash paths + h.update("".join(paths).encode()) # hash paths return h.hexdigest() # return hash def exif_size(img: Image.Image): """Returns exif-corrected PIL size.""" s = img.size # (width, height) - if img.format == 'JPEG': # only support JPEG images + if img.format == "JPEG": # only support JPEG images with contextlib.suppress(Exception): exif = img.getexif() if exif: @@ -60,24 +71,24 @@ def verify_image(args): """Verify one image.""" (im_file, cls), prefix = args # Number (found, corrupt), message - nf, nc, msg = 0, 0, '' + nf, nc, msg = 0, 0, "" try: im = Image.open(im_file) im.verify() # PIL verify shape = exif_size(im) # image size shape = (shape[1], shape[0]) # hw - assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' - assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}' - if im.format.lower() in ('jpg', 'jpeg'): - with open(im_file, 'rb') as f: + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" + if im.format.lower() in ("jpg", "jpeg"): + with open(im_file, "rb") as f: f.seek(-2, 2) - if f.read() != b'\xff\xd9': # corrupt JPEG - ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100) - msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' + if f.read() != b"\xff\xd9": # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) + msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" nf = 1 except Exception as e: nc = 1 - msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' + msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" return (im_file, cls), nf, nc, msg @@ -85,21 +96,21 @@ def verify_image_label(args): """Verify one image-label pair.""" im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args # Number (missing, found, empty, corrupt), message, segments, keypoints - nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None + nm, nf, ne, 
nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None try: # Verify images im = Image.open(im_file) im.verify() # PIL verify shape = exif_size(im) # image size shape = (shape[1], shape[0]) # hw - assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' - assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}' - if im.format.lower() in ('jpg', 'jpeg'): - with open(im_file, 'rb') as f: + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" + if im.format.lower() in ("jpg", "jpeg"): + with open(im_file, "rb") as f: f.seek(-2, 2) - if f.read() != b'\xff\xd9': # corrupt JPEG - ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100) - msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' + if f.read() != b"\xff\xd9": # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) + msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" # Verify labels if os.path.isfile(lb_file): @@ -114,32 +125,32 @@ def verify_image_label(args): nl = len(lb) if nl: if keypoint: - assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each' - assert (lb[:, 5::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels' - assert (lb[:, 6::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels' + assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each" + points = lb[:, 5:].reshape(-1, ndim)[:, :2] else: - assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected' - assert (lb[:, 1:] <= 1).all(), \ - f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}' - assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}' + assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" + points = lb[:, 1:] + assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}" + assert lb.min() >= 0, f"negative label values {lb[lb < 0]}" + # All labels - max_cls = int(lb[:, 0].max()) # max label count - assert max_cls <= num_cls, \ - f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \ - f'Possible class labels are 0-{num_cls - 1}' + max_cls = lb[:, 0].max() # max label count + assert max_cls <= num_cls, ( + f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. 
" + f"Possible class labels are 0-{num_cls - 1}" + ) _, i = np.unique(lb, axis=0, return_index=True) if len(i) < nl: # duplicate row check lb = lb[i] # remove duplicates if segments: segments = [segments[x] for x in i] - msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed' + msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed" else: ne = 1 # label empty - lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros( - (0, 5), dtype=np.float32) + lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32) else: nm = 1 # label missing - lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32) + lb = np.zeros((0, (5 + nkpt * ndim) if keypoints else 5), dtype=np.float32) if keypoint: keypoints = lb[:, 5:].reshape(-1, nkpt, ndim) if ndim == 2: @@ -149,42 +160,56 @@ def verify_image_label(args): return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg except Exception as e: nc = 1 - msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' + msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" return [None, None, None, None, None, nm, nf, ne, nc, msg] def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1): """ + Convert a list of polygons to a binary mask of the specified image size. + Args: - imgsz (tuple): The image size. - polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2). - color (int): color - downsample_ratio (int): downsample ratio + imgsz (tuple): The size of the image as (height, width). + polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where + N is the number of polygons, and M is the number of points such that M % 2 = 0. + color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1. + downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1. + + Returns: + (np.ndarray): A binary mask of the specified image size with the polygons filled in. """ mask = np.zeros(imgsz, dtype=np.uint8) polygons = np.asarray(polygons, dtype=np.int32) polygons = polygons.reshape((polygons.shape[0], -1, 2)) cv2.fillPoly(mask, polygons, color=color) nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) - # NOTE: fillPoly first then resize is trying to keep the same way of loss calculation when mask-ratio=1. + # Note: fillPoly first then resize is trying to keep the same loss calculation method when mask-ratio=1 return cv2.resize(mask, (nw, nh)) def polygons2masks(imgsz, polygons, color, downsample_ratio=1): """ + Convert a list of polygons to a set of binary masks of the specified image size. + Args: - imgsz (tuple): The image size. - polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0) - color (int): color - downsample_ratio (int): downsample ratio + imgsz (tuple): The size of the image as (height, width). + polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where + N is the number of polygons, and M is the number of points such that M % 2 = 0. + color (int): The color value to fill in the polygons on the masks. + downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1. + + Returns: + (np.ndarray): A set of binary masks of the specified image size with the polygons filled in. 
""" return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons]) def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): """Return a (640, 640) overlap mask.""" - masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), - dtype=np.int32 if len(segments) > 255 else np.uint8) + masks = np.zeros( + (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), + dtype=np.int32 if len(segments) > 255 else np.uint8, + ) areas = [] ms = [] for si in range(len(segments)): @@ -206,7 +231,7 @@ def find_dataset_yaml(path: Path) -> Path: Find and return the YAML file associated with a Detect, Segment or Pose dataset. This function searches for a YAML file at the root level of the provided directory first, and if not found, it - performs a recursive search. It prefers YAML files that have the samestem as the provided path. An AssertionError + performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError is raised if no YAML file is found or if multiple YAML files are found. Args: @@ -215,7 +240,7 @@ def find_dataset_yaml(path: Path) -> Path: Returns: (Path): The path of the found YAML file. """ - files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml')) # try root level first and then recursive + files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive assert files, f"No YAML file found in '{path.resolve()}'" if len(files) > 1: files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match @@ -239,57 +264,57 @@ def check_det_dataset(dataset, autodownload=True): (dict): Parsed dataset information and paths. """ - data = check_file(dataset) + file = check_file(dataset) # Download (optional) - extract_dir = '' - if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)): - new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False) - data = find_dataset_yaml(DATASETS_DIR / new_dir) - extract_dir, autodownload = data.parent, False + extract_dir = "" + if zipfile.is_zipfile(file) or is_tarfile(file): + new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False) + file = find_dataset_yaml(DATASETS_DIR / new_dir) + extract_dir, autodownload = file.parent, False - # Read YAML (optional) - if isinstance(data, (str, Path)): - data = yaml_load(data, append_filename=True) # dictionary + # Read YAML + data = yaml_load(file, append_filename=True) # dictionary # Checks - for k in 'train', 'val': + for k in "train", "val": if k not in data: - if k == 'val' and 'validation' in data: - LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.") - data['val'] = data.pop('validation') # replace 'validation' key with 'val' key - else: + if k != "val" or "validation" not in data: raise SyntaxError( - emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")) - if 'names' not in data and 'nc' not in data: + emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.") + ) + LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.") + data["val"] = data.pop("validation") # replace 'validation' key with 'val' key + if "names" not in data and "nc" not in data: raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs.")) - if 'names' in data and 'nc' in data and len(data['names']) != data['nc']: + if 
"names" in data and "nc" in data and len(data["names"]) != data["nc"]: raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match.")) - if 'names' not in data: - data['names'] = [f'class_{i}' for i in range(data['nc'])] + if "names" not in data: + data["names"] = [f"class_{i}" for i in range(data["nc"])] else: - data['nc'] = len(data['names']) + data["nc"] = len(data["names"]) - data['names'] = check_class_names(data['names']) + data["names"] = check_class_names(data["names"]) # Resolve paths - path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root - + path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root if not path.is_absolute(): path = (DATASETS_DIR / path).resolve() - data['path'] = path # download scripts - for k in 'train', 'val', 'test': + + # Set paths + data["path"] = path # download scripts + for k in "train", "val", "test": if data.get(k): # prepend path if isinstance(data[k], str): x = (path / data[k]).resolve() - if not x.exists() and data[k].startswith('../'): + if not x.exists() and data[k].startswith("../"): x = (path / data[k][3:]).resolve() data[k] = str(x) else: data[k] = [str((path / x).resolve()) for x in data[k]] # Parse YAML - train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download')) + val, s = (data.get(x) for x in ("val", "download")) if val: val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path if not all(x.exists() for x in val): @@ -302,22 +327,22 @@ def check_det_dataset(dataset, autodownload=True): raise FileNotFoundError(m) t = time.time() r = None # success - if s.startswith('http') and s.endswith('.zip'): # URL + if s.startswith("http") and s.endswith(".zip"): # URL safe_download(url=s, dir=DATASETS_DIR, delete=True) - elif s.startswith('bash '): # bash script - LOGGER.info(f'Running {s} ...') + elif s.startswith("bash "): # bash script + LOGGER.info(f"Running {s} ...") r = os.system(s) else: # python script - exec(s, {'yaml': data}) - dt = f'({round(time.time() - t, 1)}s)' - s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌' - LOGGER.info(f'Dataset download {s}\n') - check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts + exec(s, {"yaml": data}) + dt = f"({round(time.time() - t, 1)}s)" + s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" + LOGGER.info(f"Dataset download {s}\n") + check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts return data # dictionary -def check_cls_dataset(dataset, split=''): +def check_cls_dataset(dataset, split=""): """ Checks a classification dataset such as Imagenet. 
@@ -338,54 +363,62 @@ def check_cls_dataset(dataset, split=''): """ # Download (optional if dataset=https://file.zip is passed directly) - if str(dataset).startswith(('http:/', 'https:/')): + if str(dataset).startswith(("http:/", "https:/")): dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False) + elif Path(dataset).suffix in (".zip", ".tar", ".gz"): + file = check_file(dataset) + dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False) dataset = Path(dataset) data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve() if not data_dir.is_dir(): - LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') + LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...") t = time.time() - if str(dataset) == 'imagenet': + if str(dataset) == "imagenet": subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) else: - url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' + url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip" download(url, dir=data_dir.parent) s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" LOGGER.info(s) - train_set = data_dir / 'train' - val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \ - (data_dir / 'validation').exists() else None # data/test or data/val - test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test - if split == 'val' and not val_set: + train_set = data_dir / "train" + val_set = ( + data_dir / "val" + if (data_dir / "val").exists() + else data_dir / "validation" + if (data_dir / "validation").exists() + else None + ) # data/test or data/val + test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test + if split == "val" and not val_set: LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.") - elif split == 'test' and not test_set: + elif split == "test" and not test_set: LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.") - nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes - names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list + nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes + names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list names = dict(enumerate(sorted(names))) # Print to console - for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items(): + for k, v in {"train": train_set, "val": val_set, "test": test_set}.items(): prefix = f'{colorstr(f"{k}:")} {v}...' 
if v is None: LOGGER.info(prefix) else: - files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS] + files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS] nf = len(files) # number of files nd = len({file.parent for file in files}) # number of directories if nf == 0: - if k == 'train': + if k == "train": raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ ")) else: - LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found') + LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found") elif nd != nc: - LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}') + LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}") else: - LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ') + LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ") - return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names} + return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names} class HUBDatasetStats: @@ -393,7 +426,7 @@ class HUBDatasetStats: A class for generating HUB dataset JSON and `-hub` dataset directory. Args: - path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'. + path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'. task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'. autodownload (bool): Attempt to download dataset if not found locally. Default is False. @@ -413,39 +446,42 @@ class HUBDatasetStats: ``` """ - def __init__(self, path='coco128.yaml', task='detect', autodownload=False): + def __init__(self, path="coco8.yaml", task="detect", autodownload=False): """Initialize class.""" path = Path(path).resolve() - LOGGER.info(f'Starting HUB dataset checks for {path}....') + LOGGER.info(f"Starting HUB dataset checks for {path}....") self.task = task # detect, segment, pose, classify - if self.task == 'classify': + if self.task == "classify": unzip_dir = unzip_file(path) data = check_cls_dataset(unzip_dir) - data['path'] = unzip_dir + data["path"] = unzip_dir else: # detect, segment, pose - zipped, data_dir, yaml_path = self._unzip(Path(path)) + _, data_dir, yaml_path = self._unzip(Path(path)) try: - # data = yaml_load(check_yaml(yaml_path)) # data dict - data = check_det_dataset(yaml_path, autodownload) # data dict - if zipped: - data['path'] = data_dir + # Load YAML with checks + data = yaml_load(yaml_path) + data["path"] = "" # strip path since YAML should be in dataset root for all HUB datasets + yaml_save(yaml_path, data) + data = check_det_dataset(yaml_path, autodownload) # dict + data["path"] = data_dir # YAML path should be set to '' (relative) or parent (absolute) except Exception as e: - raise Exception('error/HUB/dataset_stats/init') from e + raise Exception("error/HUB/dataset_stats/init") from e self.hub_dir = Path(f'{data["path"]}-hub') - self.im_dir = self.hub_dir / 'images' - self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images - self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary + self.im_dir = self.hub_dir / "images" + self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary self.data = data - def _unzip(self, path): + @staticmethod + def 
_unzip(path): """Unzip data.zip.""" - if not str(path).endswith('.zip'): # path is data.yaml + if not str(path).endswith(".zip"): # path is data.yaml return False, None, path unzip_dir = unzip_file(path, path=path.parent) - assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \ - f'path/to/abc.zip MUST unzip to path/to/abc/' + assert unzip_dir.is_dir(), ( + f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/" + ) return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path def _hub_ops(self, f): @@ -457,31 +493,31 @@ class HUBDatasetStats: def _round(labels): """Update labels to integer class and 4 decimal place floats.""" - if self.task == 'detect': - coordinates = labels['bboxes'] - elif self.task == 'segment': - coordinates = [x.flatten() for x in labels['segments']] - elif self.task == 'pose': - n = labels['keypoints'].shape[0] - coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1) + if self.task == "detect": + coordinates = labels["bboxes"] + elif self.task == "segment": + coordinates = [x.flatten() for x in labels["segments"]] + elif self.task == "pose": + n = labels["keypoints"].shape[0] + coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1) else: - raise ValueError('Undefined dataset task.') - zipped = zip(labels['cls'], coordinates) + raise ValueError("Undefined dataset task.") + zipped = zip(labels["cls"], coordinates) return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped] - for split in 'train', 'val', 'test': + for split in "train", "val", "test": self.stats[split] = None # predefine path = self.data.get(split) # Check split if path is None: # no split continue - files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS] # image files in split + files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split if not files: # no images continue # Get dataset statistics - if self.task == 'classify': + if self.task == "classify": from torchvision.datasets import ImageFolder dataset = ImageFolder(self.data[split]) @@ -491,41 +527,36 @@ class HUBDatasetStats: x[im[1]] += 1 self.stats[split] = { - 'instance_stats': { - 'total': len(dataset), - 'per_class': x.tolist()}, - 'image_stats': { - 'total': len(dataset), - 'unlabelled': 0, - 'per_class': x.tolist()}, - 'labels': [{ - Path(k).name: v} for k, v in dataset.imgs]} + "instance_stats": {"total": len(dataset), "per_class": x.tolist()}, + "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()}, + "labels": [{Path(k).name: v} for k, v in dataset.imgs], + } else: from ultralytics.data import YOLODataset - dataset = YOLODataset(img_path=self.data[split], - data=self.data, - use_segments=self.task == 'segment', - use_keypoints=self.task == 'pose') - x = np.array([ - np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc']) - for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80) + dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task) + x = np.array( + [ + np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"]) + for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics") + ] + ) # shape(128x80) self.stats[split] = { - 'instance_stats': { - 'total': int(x.sum()), - 'per_class': x.sum(0).tolist()}, - 'image_stats': { - 'total': len(dataset), - 
'unlabelled': int(np.all(x == 0, 1).sum()), - 'per_class': (x > 0).sum(0).tolist()}, - 'labels': [{ - Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]} + "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()}, + "image_stats": { + "total": len(dataset), + "unlabelled": int(np.all(x == 0, 1).sum()), + "per_class": (x > 0).sum(0).tolist(), + }, + "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)], + } # Save, print and return if save: - stats_path = self.hub_dir / 'stats.json' - LOGGER.info(f'Saving {stats_path.resolve()}...') - with open(stats_path, 'w') as f: + self.hub_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/ + stats_path = self.hub_dir / "stats.json" + LOGGER.info(f"Saving {stats_path.resolve()}...") + with open(stats_path, "w") as f: json.dump(self.stats, f) # save stats.json if verbose: LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False)) @@ -535,22 +566,23 @@ class HUBDatasetStats: """Compress images for Ultralytics HUB.""" from ultralytics.data import YOLODataset # ClassificationDataset - for split in 'train', 'val', 'test': + self.im_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/images/ + for split in "train", "val", "test": if self.data.get(split) is None: continue dataset = YOLODataset(img_path=self.data[split], data=self.data) with ThreadPool(NUM_THREADS) as pool: - for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'): + for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"): pass - LOGGER.info(f'Done. All images saved to {self.im_dir}') + LOGGER.info(f"Done. All images saved to {self.im_dir}") return self.im_dir def compress_one_image(f, f_new=None, max_dim=1920, quality=50): """ - Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the - Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will - not be resized. + Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python + Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be + resized. Args: f (str): The path to the input image file. @@ -573,9 +605,9 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50): r = max_dim / max(im.height, im.width) # ratio if r < 1.0: # image too large im = im.resize((int(im.width * r), int(im.height * r))) - im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save + im.save(f_new or f, "JPEG", quality=quality, optimize=True) # save except Exception as e: # use OpenCV - LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}') + LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}") im = cv2.imread(f) im_height, im_width = im.shape[:2] r = max_dim / max(im_height, im_width) # ratio @@ -584,7 +616,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50): cv2.imwrite(str(f_new or f), im) -def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False): +def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False): """ Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files. 
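A minimal usage sketch for the reworked HUBDatasetStats above; the get_json(save=..., verbose=...) and process_images() signatures are assumed from the surrounding upstream code and are not fully visible in this hunk.
```python
# Sketch: exercising HUBDatasetStats as reworked in the hunk above.
# Assumes get_json(save=..., verbose=...) and process_images() signatures; adjust to the installed ultralytics version.
from ultralytics.data.utils import HUBDatasetStats

stats = HUBDatasetStats("coco8.yaml", task="detect")  # also accepts a path to a data.zip
stats.get_json(save=True, verbose=False)              # creates <dataset>-hub/ and writes stats.json
stats.process_images()                                # creates <dataset>-hub/images/ and compresses images
```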
@@ -602,18 +634,18 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot """ path = Path(path) # images dir - files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only + files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only n = len(files) # number of files random.seed(0) # for reproducibility indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split - txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files + txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"] # 3 txt files for x in txt: if (path.parent / x).exists(): (path.parent / x).unlink() # remove existing - LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only) + LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only) for i, img in TQDM(zip(indices, files), total=n): if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label - with open(path.parent / txt[i], 'a') as f: - f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file + with open(path.parent / txt[i], "a") as f: + f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file diff --git a/ultralytics/engine/__init__.py b/ultralytics/engine/__init__.py index e69de29..9e68dc1 100644 --- a/ultralytics/engine/__init__.py +++ b/ultralytics/engine/__init__.py @@ -0,0 +1 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license diff --git a/ultralytics/engine/__pycache__/__init__.cpython-312.pyc b/ultralytics/engine/__pycache__/__init__.cpython-312.pyc index 87fdfda..fdaad91 100644 Binary files a/ultralytics/engine/__pycache__/__init__.cpython-312.pyc and b/ultralytics/engine/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/engine/__pycache__/__init__.cpython-39.pyc b/ultralytics/engine/__pycache__/__init__.cpython-39.pyc index 1c9e29a..d1d5e84 100644 Binary files a/ultralytics/engine/__pycache__/__init__.cpython-39.pyc and b/ultralytics/engine/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/engine/__pycache__/exporter.cpython-39.pyc b/ultralytics/engine/__pycache__/exporter.cpython-39.pyc new file mode 100644 index 0000000..dc597f0 Binary files /dev/null and b/ultralytics/engine/__pycache__/exporter.cpython-39.pyc differ diff --git a/ultralytics/engine/__pycache__/model.cpython-312.pyc b/ultralytics/engine/__pycache__/model.cpython-312.pyc index 4ac711e..06be6f0 100644 Binary files a/ultralytics/engine/__pycache__/model.cpython-312.pyc and b/ultralytics/engine/__pycache__/model.cpython-312.pyc differ diff --git a/ultralytics/engine/__pycache__/model.cpython-39.pyc b/ultralytics/engine/__pycache__/model.cpython-39.pyc index 01b2025..b1d5c97 100644 Binary files a/ultralytics/engine/__pycache__/model.cpython-39.pyc and b/ultralytics/engine/__pycache__/model.cpython-39.pyc differ diff --git a/ultralytics/engine/__pycache__/predictor.cpython-312.pyc b/ultralytics/engine/__pycache__/predictor.cpython-312.pyc index 5eb87c7..2cb4f14 100644 Binary files a/ultralytics/engine/__pycache__/predictor.cpython-312.pyc and b/ultralytics/engine/__pycache__/predictor.cpython-312.pyc differ diff --git a/ultralytics/engine/__pycache__/predictor.cpython-39.pyc b/ultralytics/engine/__pycache__/predictor.cpython-39.pyc index ec432e9..958802e 100644 Binary files 
a/ultralytics/engine/__pycache__/predictor.cpython-39.pyc and b/ultralytics/engine/__pycache__/predictor.cpython-39.pyc differ diff --git a/ultralytics/engine/__pycache__/results.cpython-312.pyc b/ultralytics/engine/__pycache__/results.cpython-312.pyc index af95d4f..b892a86 100644 Binary files a/ultralytics/engine/__pycache__/results.cpython-312.pyc and b/ultralytics/engine/__pycache__/results.cpython-312.pyc differ diff --git a/ultralytics/engine/__pycache__/results.cpython-39.pyc b/ultralytics/engine/__pycache__/results.cpython-39.pyc index 933f4e1..a8f289a 100644 Binary files a/ultralytics/engine/__pycache__/results.cpython-39.pyc and b/ultralytics/engine/__pycache__/results.cpython-39.pyc differ diff --git a/ultralytics/engine/__pycache__/trainer.cpython-312.pyc b/ultralytics/engine/__pycache__/trainer.cpython-312.pyc index a3675cb..69d0a55 100644 Binary files a/ultralytics/engine/__pycache__/trainer.cpython-312.pyc and b/ultralytics/engine/__pycache__/trainer.cpython-312.pyc differ diff --git a/ultralytics/engine/__pycache__/trainer.cpython-39.pyc b/ultralytics/engine/__pycache__/trainer.cpython-39.pyc index 7edd344..6ab5ba2 100644 Binary files a/ultralytics/engine/__pycache__/trainer.cpython-39.pyc and b/ultralytics/engine/__pycache__/trainer.cpython-39.pyc differ diff --git a/ultralytics/engine/__pycache__/validator.cpython-312.pyc b/ultralytics/engine/__pycache__/validator.cpython-312.pyc index 583bbd8..8d29bff 100644 Binary files a/ultralytics/engine/__pycache__/validator.cpython-312.pyc and b/ultralytics/engine/__pycache__/validator.cpython-312.pyc differ diff --git a/ultralytics/engine/__pycache__/validator.cpython-39.pyc b/ultralytics/engine/__pycache__/validator.cpython-39.pyc index 7af53b6..469bb61 100644 Binary files a/ultralytics/engine/__pycache__/validator.cpython-39.pyc and b/ultralytics/engine/__pycache__/validator.cpython-39.pyc differ diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 5c43edc..6ac170c 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -16,7 +16,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite TensorFlow.js | `tfjs` | yolov8n_web_model/ PaddlePaddle | `paddle` | yolov8n_paddle_model/ -ncnn | `ncnn` | yolov8n_ncnn_model/ +NCNN | `ncnn` | yolov8n_ncnn_model/ Requirements: $ pip install "ultralytics[export]" @@ -41,6 +41,7 @@ Inference: yolov8n.tflite # TensorFlow Lite yolov8n_edgetpu.tflite # TensorFlow Edge TPU yolov8n_paddle_model # PaddlePaddle + yolov8n_ncnn_model # NCNN TensorFlow.js: $ cd .. 
&& git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example @@ -48,6 +49,7 @@ TensorFlow.js: $ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model $ npm start """ + import json import os import shutil @@ -64,36 +66,50 @@ import torch from ultralytics.cfg import get_cfg from ultralytics.data.dataset import YOLODataset from ultralytics.data.utils import check_det_dataset -from ultralytics.nn.autobackend import check_class_names -from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder -from ultralytics.nn.tasks import DetectionModel, SegmentationModel -from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks, - colorstr, get_default_args, yaml_save) -from ultralytics.utils.checks import check_imgsz, check_requirements, check_version +from ultralytics.nn.autobackend import check_class_names, default_class_names +from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder, v10Detect +from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel +from ultralytics.utils import ( + ARM64, + DEFAULT_CFG, + LINUX, + LOGGER, + MACOS, + ROOT, + WINDOWS, + __version__, + callbacks, + colorstr, + get_default_args, + yaml_save, +) +from ultralytics.utils.checks import PYTHON_VERSION, check_imgsz, check_is_path_safe, check_requirements, check_version from ultralytics.utils.downloads import attempt_download_asset, get_github_assets from ultralytics.utils.files import file_size, spaces_in_path from ultralytics.utils.ops import Profile -from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode +from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device, smart_inference_mode def export_formats(): """YOLOv8 export formats.""" import pandas + x = [ - ['PyTorch', '-', '.pt', True, True], - ['TorchScript', 'torchscript', '.torchscript', True, True], - ['ONNX', 'onnx', '.onnx', True, True], - ['OpenVINO', 'openvino', '_openvino_model', True, False], - ['TensorRT', 'engine', '.engine', False, True], - ['CoreML', 'coreml', '.mlpackage', True, False], - ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True], - ['TensorFlow GraphDef', 'pb', '.pb', True, True], - ['TensorFlow Lite', 'tflite', '.tflite', True, False], - ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False], - ['TensorFlow.js', 'tfjs', '_web_model', True, False], - ['PaddlePaddle', 'paddle', '_paddle_model', True, True], - ['ncnn', 'ncnn', '_ncnn_model', True, True], ] - return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU']) + ["PyTorch", "-", ".pt", True, True], + ["TorchScript", "torchscript", ".torchscript", True, True], + ["ONNX", "onnx", ".onnx", True, True], + ["OpenVINO", "openvino", "_openvino_model", True, False], + ["TensorRT", "engine", ".engine", False, True], + ["CoreML", "coreml", ".mlpackage", True, False], + ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True], + ["TensorFlow GraphDef", "pb", ".pb", True, True], + ["TensorFlow Lite", "tflite", ".tflite", True, False], + ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False], + ["TensorFlow.js", "tfjs", "_web_model", True, False], + ["PaddlePaddle", "paddle", "_paddle_model", True, True], + ["NCNN", "ncnn", "_ncnn_model", True, True], + ] + return pandas.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"]) def gd_outputs(gd): @@ -102,7 +118,7 @@ def gd_outputs(gd): for node in gd.node: # 
tensorflow.core.framework.node_def_pb2.NodeDef name_list.append(node.name) input_list.extend(node.input) - return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp')) + return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp")) def try_export(inner_func): @@ -111,14 +127,14 @@ def try_export(inner_func): def outer_func(*args, **kwargs): """Export a model.""" - prefix = inner_args['prefix'] + prefix = inner_args["prefix"] try: with Profile() as dt: f, model = inner_func(*args, **kwargs) LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)") return f, model except Exception as e: - LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}') + LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}") raise e return outer_func @@ -140,53 +156,65 @@ class Exporter: Args: cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. overrides (dict, optional): Configuration overrides. Defaults to None. - _callbacks (list, optional): List of callback functions. Defaults to None. + _callbacks (dict, optional): Dictionary of callback functions. Defaults to None. """ self.args = get_cfg(cfg, overrides) + if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback + self.callbacks = _callbacks or callbacks.get_default_callbacks() callbacks.add_integration_callbacks(self) @smart_inference_mode() def __call__(self, model=None): """Returns list of exported files/dirs after running callbacks.""" - self.run_callbacks('on_export_start') + self.run_callbacks("on_export_start") t = time.time() - format = self.args.format.lower() # to lowercase - if format in ('tensorrt', 'trt'): # 'engine' aliases - format = 'engine' - if format in ('mlmodel', 'mlpackage', 'mlprogram', 'apple', 'ios', 'coreml'): # 'coreml' aliases - os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' # fix attempt for protobuf<3.20.x errors - format = 'coreml' - fmts = tuple(export_formats()['Argument'][1:]) # available export formats - flags = [x == format for x in fmts] + fmt = self.args.format.lower() # to lowercase + if fmt in ("tensorrt", "trt"): # 'engine' aliases + fmt = "engine" + if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases + fmt = "coreml" + fmts = tuple(export_formats()["Argument"][1:]) # available export formats + flags = [x == fmt for x in fmts] if sum(flags) != 1: - raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}") + raise ValueError(f"Invalid export format='{fmt}'. 
Valid formats are {fmts}") jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans # Device - if format == 'engine' and self.args.device is None: - LOGGER.warning('WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0') - self.args.device = '0' - self.device = select_device('cpu' if self.args.device is None else self.args.device) + if fmt == "engine" and self.args.device is None: + LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0") + self.args.device = "0" + self.device = select_device("cpu" if self.args.device is None else self.args.device) # Checks + if not hasattr(model, "names"): + model.names = default_class_names() model.names = check_class_names(model.names) - if self.args.half and onnx and self.device.type == 'cpu': - LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0') + if self.args.half and onnx and self.device.type == "cpu": + LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0") self.args.half = False - assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.' + assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one." self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size if self.args.optimize: assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False" - assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'" + assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'" if edgetpu and not LINUX: - raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/') + raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/") + if isinstance(model, WorldModel): + LOGGER.warning( + "WARNING ⚠️ YOLOWorld (original version) export is not supported to any format.\n" + "WARNING ⚠️ YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to " + "(torchscript, onnx, openvino, engine, coreml) formats. " + "See https://docs.ultralytics.com/models/yolo-world for details." 
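The ValueError above checks `format` against the 'Argument' column of the export_formats() table defined earlier in this file; a small sketch of inspecting that table (import path taken from this diff):
```python
# Sketch: listing the supported export formats that __call__ validates against.
from ultralytics.engine.exporter import export_formats

df = export_formats()                      # pandas DataFrame: Format / Argument / Suffix / CPU / GPU
print(df[["Format", "Argument", "Suffix"]].to_string(index=False))
valid = tuple(df["Argument"][1:])          # skip 'PyTorch'; these are the accepted `format=` strings
```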
+ ) # Input im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device) file = Path( - getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', '')) - if file.suffix in ('.yaml', '.yml'): + getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "") + ) + if file.suffix in {".yaml", ".yml"}: file = Path(file.name) # Update model @@ -197,10 +225,13 @@ class Exporter: model.float() model = model.fuse() for m in model.modules(): - if isinstance(m, (Detect, RTDETRDecoder)): # Segment and Pose use Detect base class + if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB m.dynamic = self.args.dynamic m.export = True m.format = self.args.format + if isinstance(m, v10Detect): + m.max_det = self.args.max_det + elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)): # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph m.forward = m.forward_split @@ -208,47 +239,54 @@ class Exporter: y = None for _ in range(2): y = model(im) # dry runs - if self.args.half and (engine or onnx) and self.device.type != 'cpu': + if self.args.half and onnx and self.device.type != "cpu": im, model = im.half(), model.half() # to FP16 # Filter warnings - warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning - warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant missing ONNX warning - warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning + warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning + warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning # Assign self.im = im self.model = model self.file = file - self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \ - tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y) - self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO') - data = model.args['data'] if hasattr(model, 'args') and isinstance(model.args, dict) else '' + self.output_shape = ( + tuple(y.shape) + if isinstance(y, torch.Tensor) + else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y) + ) + self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO") + data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else "" description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}' self.metadata = { - 'description': description, - 'author': 'Ultralytics', - 'license': 'AGPL-3.0 https://ultralytics.com/license', - 'date': datetime.now().isoformat(), - 'version': __version__, - 'stride': int(max(model.stride)), - 'task': model.task, - 'batch': self.args.batch, - 'imgsz': self.imgsz, - 'names': model.names} # model metadata - if model.task == 'pose': - self.metadata['kpt_shape'] = model.model[-1].kpt_shape + "description": description, + "author": "Ultralytics", + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + "stride": int(max(model.stride)), + "task": model.task, + "batch": self.args.batch, + 
"imgsz": self.imgsz, + "names": model.names, + } # model metadata + if model.task == "pose": + self.metadata["kpt_shape"] = model.model[-1].kpt_shape - LOGGER.info(f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and " - f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)') + LOGGER.info( + f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and " + f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)' + ) # Exports - f = [''] * len(fmts) # exported filenames + f = [""] * len(fmts) # exported filenames if jit or ncnn: # TorchScript f[0], _ = self.export_torchscript() if engine: # TensorRT required before ONNX f[1], _ = self.export_engine() - if onnx or xml: # OpenVINO requires ONNX + if onnx: # ONNX f[2], _ = self.export_onnx() if xml: # OpenVINO f[3], _ = self.export_openvino() @@ -262,12 +300,12 @@ class Exporter: if tflite: f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms) if edgetpu: - f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite') + f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite") if tfjs: f[9], _ = self.export_tfjs() if paddle: # PaddlePaddle f[10], _ = self.export_paddle() - if ncnn: # ncnn + if ncnn: # NCNN f[11], _ = self.export_ncnn() # Finish @@ -275,58 +313,65 @@ class Exporter: if any(f): f = str(Path(f[-1])) square = self.imgsz[0] == self.imgsz[1] - s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \ - f"work. Use export 'imgsz={max(self.imgsz)}' if val is required." - imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '') - predict_data = f'data={data}' if model.task == 'segment' and format == 'pb' else '' - q = 'int8' if self.args.int8 else 'half' if self.args.half else '' # quantization - LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' - f"\nResults saved to {colorstr('bold', file.parent.resolve())}" - f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}' - f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}' - f'\nVisualize: https://netron.app') + s = ( + "" + if square + else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " + f"work. Use export 'imgsz={max(self.imgsz)}' if val is required." 
+ ) + imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "") + predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else "" + q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization + LOGGER.info( + f'\nExport complete ({time.time() - t:.1f}s)' + f"\nResults saved to {colorstr('bold', file.parent.resolve())}" + f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}' + f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}' + f'\nVisualize: https://netron.app' + ) - self.run_callbacks('on_export_end') + self.run_callbacks("on_export_end") return f # return list of exported files/dirs @try_export - def export_torchscript(self, prefix=colorstr('TorchScript:')): + def export_torchscript(self, prefix=colorstr("TorchScript:")): """YOLOv8 TorchScript model export.""" - LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') - f = self.file.with_suffix('.torchscript') + LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...") + f = self.file.with_suffix(".torchscript") ts = torch.jit.trace(self.model, self.im, strict=False) - extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap() + extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap() if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html - LOGGER.info(f'{prefix} optimizing for mobile...') + LOGGER.info(f"{prefix} optimizing for mobile...") from torch.utils.mobile_optimizer import optimize_for_mobile + optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files) else: ts.save(str(f), _extra_files=extra_files) return f, None @try_export - def export_onnx(self, prefix=colorstr('ONNX:')): + def export_onnx(self, prefix=colorstr("ONNX:")): """YOLOv8 ONNX export.""" - requirements = ['onnx>=1.12.0'] + requirements = ["onnx>=1.12.0"] if self.args.simplify: - requirements += ['onnxsim>=0.4.33', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'] + requirements += ["onnxslim==0.1.31", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")] check_requirements(requirements) import onnx # noqa opset_version = self.args.opset or get_latest_opset() - LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...') - f = str(self.file.with_suffix('.onnx')) + LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...") + f = str(self.file.with_suffix(".onnx")) - output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0'] + output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"] dynamic = self.args.dynamic if dynamic: - dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) + dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640) if isinstance(self.model, SegmentationModel): - dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 116, 8400) - dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160) + dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 116, 8400) + dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160) elif isinstance(self.model, DetectionModel): - dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 84, 8400) + dynamic["output0"] = {0: "batch", 2: "anchors"} # 
shape(1, 84, 8400) torch.onnx.export( self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu @@ -335,9 +380,10 @@ class Exporter: verbose=False, opset_version=opset_version, do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False - input_names=['images'], + input_names=["images"], output_names=output_names, - dynamic_axes=dynamic or None) + dynamic_axes=dynamic or None, + ) # Checks model_onnx = onnx.load(f) # load onnx model @@ -346,14 +392,17 @@ class Exporter: # Simplify if self.args.simplify: try: - import onnxsim + import onnxslim - LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...') - # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True) - model_onnx, check = onnxsim.simplify(model_onnx) - assert check, 'Simplified ONNX model could not be validated' + LOGGER.info(f"{prefix} simplifying with onnxslim {onnxslim.__version__}...") + model_onnx = onnxslim.slim(model_onnx) + + # ONNX Simplifier (deprecated as must be compiled with 'cmake' in aarch64 and Conda CI environments) + # import onnxsim + # model_onnx, check = onnxsim.simplify(model_onnx) + # assert check, "Simplified ONNX model could not be validated" except Exception as e: - LOGGER.info(f'{prefix} simplifier failure: {e}') + LOGGER.warning(f"{prefix} simplifier failure: {e}") # Metadata for k, v in self.metadata.items(): @@ -364,162 +413,193 @@ class Exporter: return f, model_onnx @try_export - def export_openvino(self, prefix=colorstr('OpenVINO:')): + def export_openvino(self, prefix=colorstr("OpenVINO:")): """YOLOv8 OpenVINO export.""" - check_requirements('openvino-dev>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/ - import openvino.runtime as ov # noqa - from openvino.tools import mo # noqa + check_requirements("openvino>=2024.0.0") # requires openvino: https://pypi.org/project/openvino/ + import openvino as ov - LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...') - f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}') - fq = str(self.file).replace(self.file.suffix, f'_int8_openvino_model{os.sep}') - f_onnx = self.file.with_suffix('.onnx') - f_ov = str(Path(f) / self.file.with_suffix('.xml').name) - fq_ov = str(Path(fq) / self.file.with_suffix('.xml').name) + LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...") + assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed" + ov_model = ov.convert_model( + self.model.cpu(), + input=None if self.args.dynamic else [self.im.shape], + example_input=self.im, + ) def serialize(ov_model, file): """Set RT info, serialize and save metadata YAML.""" - ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type']) - ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels']) - ov_model.set_rt_info(114, ['model_info', 'pad_value']) - ov_model.set_rt_info([255.0], ['model_info', 'scale_values']) - ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold']) - ov_model.set_rt_info([v.replace(' ', '_') for v in self.model.names.values()], ['model_info', 'labels']) - if self.model.task != 'classify': - ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type']) + ov_model.set_rt_info("YOLOv8", ["model_info", "model_type"]) + ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"]) + ov_model.set_rt_info(114, ["model_info", "pad_value"]) + ov_model.set_rt_info([255.0], ["model_info", "scale_values"]) + 
ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"]) + ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"]) + if self.model.task != "classify": + ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"]) - ov.serialize(ov_model, file) # save - yaml_save(Path(file).parent / 'metadata.yaml', self.metadata) # add metadata.yaml - - ov_model = mo.convert_model(f_onnx, - model_name=self.pretty_name, - framework='onnx', - compress_to_fp16=self.args.half) # export + ov.runtime.save_model(ov_model, file, compress_to_fp16=self.args.half) + yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml if self.args.int8: - assert self.args.data, "INT8 export requires a data argument for calibration, i.e. 'data=coco8.yaml'" - check_requirements('nncf>=2.5.0') + fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}") + fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name) + if not self.args.data: + self.args.data = DEFAULT_CFG.data or "coco128.yaml" + LOGGER.warning( + f"{prefix} WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. " + f"Using default 'data={self.args.data}'." + ) + check_requirements("nncf>=2.8.0") import nncf def transform_fn(data_item): """Quantization transform function.""" - im = data_item['img'].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0 + assert ( + data_item["img"].dtype == torch.uint8 + ), "Input image must be uint8 for the quantization preprocessing" + im = data_item["img"].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0 return np.expand_dims(im, 0) if im.ndim == 3 else im # Generate calibration data for integer quantization LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'") data = check_det_dataset(self.args.data) - dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False) + dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False) + n = len(dataset) + if n < 300: + LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.") quantization_dataset = nncf.Dataset(dataset, transform_fn) - ignored_scope = nncf.IgnoredScope(types=['Multiply', 'Subtract', 'Sigmoid']) # ignore operation - quantized_ov_model = nncf.quantize(ov_model, - quantization_dataset, - preset=nncf.QuantizationPreset.MIXED, - ignored_scope=ignored_scope) + + ignored_scope = None + if isinstance(self.model.model[-1], Detect): + # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect + head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2]) + + ignored_scope = nncf.IgnoredScope( # ignore operations + patterns=[ + f".*{head_module_name}/.*/Add", + f".*{head_module_name}/.*/Sub*", + f".*{head_module_name}/.*/Mul*", + f".*{head_module_name}/.*/Div*", + f".*{head_module_name}\\.dfl.*", + ], + types=["Sigmoid"], + ) + + quantized_ov_model = nncf.quantize( + ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope + ) serialize(quantized_ov_model, fq_ov) return fq, None + f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}") + f_ov = str(Path(f) / self.file.with_suffix(".xml").name) + serialize(ov_model, f_ov) return f, None @try_export - def export_paddle(self, prefix=colorstr('PaddlePaddle:')): + def export_paddle(self, prefix=colorstr("PaddlePaddle:")): """YOLOv8 Paddle 
export.""" - check_requirements(('paddlepaddle', 'x2paddle')) + check_requirements(("paddlepaddle", "x2paddle")) import x2paddle # noqa from x2paddle.convert import pytorch2paddle # noqa - LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...') - f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}') + LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...") + f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}") - pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export - yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml + pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im]) # export + yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml return f, None @try_export - def export_ncnn(self, prefix=colorstr('ncnn:')): + def export_ncnn(self, prefix=colorstr("NCNN:")): """ - YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx. + YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx. """ - check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn + check_requirements("ncnn") import ncnn # noqa - LOGGER.info(f'\n{prefix} starting export with ncnn {ncnn.__version__}...') - f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}')) - f_ts = self.file.with_suffix('.torchscript') + LOGGER.info(f"\n{prefix} starting export with NCNN {ncnn.__version__}...") + f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}")) + f_ts = self.file.with_suffix(".torchscript") - pnnx_filename = 'pnnx.exe' if WINDOWS else 'pnnx' - if Path(pnnx_filename).is_file(): - pnnx = pnnx_filename - elif (ROOT / pnnx_filename).is_file(): - pnnx = ROOT / pnnx_filename - else: + name = Path("pnnx.exe" if WINDOWS else "pnnx") # PNNX filename + pnnx = name if name.is_file() else ROOT / name + if not pnnx.is_file(): LOGGER.warning( - f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from ' - 'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory ' - f'or in {ROOT}. See PNNX repo for full installation instructions.') - _, assets = get_github_assets(repo='pnnx/pnnx', retry=True) - system = 'macos' if MACOS else 'ubuntu' if LINUX else 'windows' # operating system - asset = [x for x in assets if system in x][0] if assets else \ - f'https://github.com/pnnx/pnnx/releases/download/20230816/pnnx-20230816-{system}.zip' # fallback - asset = attempt_download_asset(asset, repo='pnnx/pnnx', release='latest') - unzip_dir = Path(asset).with_suffix('') - pnnx = ROOT / pnnx_filename # new location - (unzip_dir / pnnx_filename).rename(pnnx) # move binary to ROOT - shutil.rmtree(unzip_dir) # delete unzip dir - Path(asset).unlink() # delete zip - pnnx.chmod(0o777) # set read, write, and execute permissions for everyone + f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from " + "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory " + f"or in {ROOT}. See PNNX repo for full installation instructions." 
+ ) + system = "macos" if MACOS else "windows" if WINDOWS else "linux-aarch64" if ARM64 else "linux" + _, assets = get_github_assets(repo="pnnx/pnnx", retry=True) + if assets: + url = [x for x in assets if f"{system}.zip" in x][0] + else: + url = f"https://github.com/pnnx/pnnx/releases/download/20240226/pnnx-20240226-{system}.zip" + LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found, using default {url}") + asset = attempt_download_asset(url, repo="pnnx/pnnx", release="latest") + if check_is_path_safe(Path.cwd(), asset): # avoid path traversal security vulnerability + unzip_dir = Path(asset).with_suffix("") + (unzip_dir / name).rename(pnnx) # move binary to ROOT + shutil.rmtree(unzip_dir) # delete unzip dir + Path(asset).unlink() # delete zip + pnnx.chmod(0o777) # set read, write, and execute permissions for everyone ncnn_args = [ f'ncnnparam={f / "model.ncnn.param"}', f'ncnnbin={f / "model.ncnn.bin"}', - f'ncnnpy={f / "model_ncnn.py"}', ] + f'ncnnpy={f / "model_ncnn.py"}', + ] pnnx_args = [ f'pnnxparam={f / "model.pnnx.param"}', f'pnnxbin={f / "model.pnnx.bin"}', f'pnnxpy={f / "model_pnnx.py"}', - f'pnnxonnx={f / "model.pnnx.onnx"}', ] + f'pnnxonnx={f / "model.pnnx.onnx"}', + ] cmd = [ str(pnnx), str(f_ts), *ncnn_args, *pnnx_args, - f'fp16={int(self.args.half)}', - f'device={self.device.type}', - f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ] + f"fp16={int(self.args.half)}", + f"device={self.device.type}", + f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', + ] f.mkdir(exist_ok=True) # make ncnn_model directory LOGGER.info(f"{prefix} running '{' '.join(cmd)}'") subprocess.run(cmd, check=True) # Remove debug files - pnnx_files = [x.split('=')[-1] for x in pnnx_args] - for f_debug in ('debug.bin', 'debug.param', 'debug2.bin', 'debug2.param', *pnnx_files): + pnnx_files = [x.split("=")[-1] for x in pnnx_args] + for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files): Path(f_debug).unlink(missing_ok=True) - yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml + yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml return str(f), None @try_export - def export_coreml(self, prefix=colorstr('CoreML:')): + def export_coreml(self, prefix=colorstr("CoreML:")): """YOLOv8 CoreML export.""" - mlmodel = self.args.format.lower() == 'mlmodel' # legacy *.mlmodel export format requested - check_requirements('coremltools>=6.0,<=6.2' if mlmodel else 'coremltools>=7.0.b1') + mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested + check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=7.0") import coremltools as ct # noqa - LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') - f = self.file.with_suffix('.mlmodel' if mlmodel else '.mlpackage') + LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...") + assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux." 
+ f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage") if f.is_dir(): shutil.rmtree(f) bias = [0.0, 0.0, 0.0] scale = 1 / 255 classifier_config = None - if self.model.task == 'classify': + if self.model.task == "classify": classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None model = self.model - elif self.model.task == 'detect': + elif self.model.task == "detect": model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model else: if self.args.nms: @@ -528,67 +608,71 @@ class Exporter: model = self.model ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model - ct_model = ct.convert(ts, - inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)], - classifier_config=classifier_config, - convert_to='neuralnetwork' if mlmodel else 'mlprogram') - bits, mode = (8, 'kmeans') if self.args.int8 else (16, 'linear') if self.args.half else (32, None) + ct_model = ct.convert( + ts, + inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)], + classifier_config=classifier_config, + convert_to="neuralnetwork" if mlmodel else "mlprogram", + ) + bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None) if bits < 32: - if 'kmeans' in mode: - check_requirements('scikit-learn') # scikit-learn package required for k-means quantization + if "kmeans" in mode: + check_requirements("scikit-learn") # scikit-learn package required for k-means quantization if mlmodel: ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) elif bits == 8: # mlprogram already quantized to FP16 import coremltools.optimize.coreml as cto - op_config = cto.OpPalettizerConfig(mode='kmeans', nbits=bits, weight_threshold=512) + + op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512) config = cto.OptimizationConfig(global_config=op_config) ct_model = cto.palettize_weights(ct_model, config=config) - if self.args.nms and self.model.task == 'detect': + if self.args.nms and self.model.task == "detect": if mlmodel: - import platform - # coremltools<=6.2 NMS export requires Python<3.11 - check_version(platform.python_version(), '<3.11', name='Python ', hard=True) + check_version(PYTHON_VERSION, "<3.11", name="Python ", hard=True) weights_dir = None else: ct_model.save(str(f)) # save otherwise weights_dir does not exist - weights_dir = str(f / 'Data/com.apple.CoreML/weights') + weights_dir = str(f / "Data/com.apple.CoreML/weights") ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir) m = self.metadata # metadata dict - ct_model.short_description = m.pop('description') - ct_model.author = m.pop('author') - ct_model.license = m.pop('license') - ct_model.version = m.pop('version') + ct_model.short_description = m.pop("description") + ct_model.author = m.pop("author") + ct_model.license = m.pop("license") + ct_model.version = m.pop("version") ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()}) try: ct_model.save(str(f)) # save *.mlpackage except Exception as e: LOGGER.warning( - f'{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. ' - f'Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928.') - f = f.with_suffix('.mlmodel') + f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. 
" + f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928." + ) + f = f.with_suffix(".mlmodel") ct_model.save(str(f)) return f, ct_model @try_export - def export_engine(self, prefix=colorstr('TensorRT:')): + def export_engine(self, prefix=colorstr("TensorRT:")): """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt.""" - assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'" + assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'" + f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016 + try: import tensorrt as trt # noqa except ImportError: if LINUX: - check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com') + check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com") import tensorrt as trt # noqa - check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0 - self.args.simplify = True - f_onnx, _ = self.export_onnx() + check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0 - LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') - assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}' - f = self.file.with_suffix('.engine') # TensorRT engine file + self.args.simplify = True + + LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...") + assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}" + f = self.file.with_suffix(".engine") # TensorRT engine file logger = trt.Logger(trt.Logger.INFO) if self.args.verbose: logger.min_severity = trt.Logger.Severity.VERBOSE @@ -598,11 +682,11 @@ class Exporter: config.max_workspace_size = self.args.workspace * 1 << 30 # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice - flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) network = builder.create_network(flag) parser = trt.OnnxParser(network, logger) if not parser.parse_from_file(f_onnx): - raise RuntimeError(f'failed to load ONNX file: {f_onnx}') + raise RuntimeError(f"failed to load ONNX file: {f_onnx}") inputs = [network.get_input(i) for i in range(network.num_inputs)] outputs = [network.get_output(i) for i in range(network.num_outputs)] @@ -621,15 +705,19 @@ class Exporter: config.add_optimization_profile(profile) LOGGER.info( - f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}') + f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}" + ) if builder.platform_has_fast_fp16 and self.args.half: config.set_flag(trt.BuilderFlag.FP16) + del self.model + torch.cuda.empty_cache() + # Write file - with builder.build_engine(network, config) as engine, open(f, 'wb') as t: + with builder.build_engine(network, config) as engine, open(f, "wb") as t: # Metadata meta = json.dumps(self.metadata) - t.write(len(meta).to_bytes(4, byteorder='little', signed=True)) + t.write(len(meta).to_bytes(4, byteorder="little", signed=True)) t.write(meta.encode()) # Model t.write(engine.serialize()) @@ -637,83 +725,114 @@ class Exporter: return f, None @try_export - def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')): + def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")): """YOLOv8 
TensorFlow SavedModel export.""" cuda = torch.cuda.is_available() try: import tensorflow as tf # noqa except ImportError: - check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}") + suffix = "-macos" if MACOS else "-aarch64" if ARM64 else "" if cuda else "-cpu" + version = "" if ARM64 else "<=2.13.1" + check_requirements(f"tensorflow{suffix}{version}") import tensorflow as tf # noqa + if ARM64: + check_requirements("cmake") # 'cmake' is needed to build onnxsim on aarch64 check_requirements( - ('onnx', 'onnx2tf>=1.15.4', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.33', 'onnx_graphsurgeon>=0.3.26', - 'tflite_support', 'onnxruntime-gpu' if cuda else 'onnxruntime'), - cmds='--extra-index-url https://pypi.ngc.nvidia.com') # onnx_graphsurgeon only on NVIDIA + ( + "onnx>=1.12.0", + "onnx2tf>=1.15.4,<=1.17.5", + "sng4onnx>=1.0.1", + "onnxslim==0.1.31", + "onnx_graphsurgeon>=0.3.26", + "tflite_support", + "flatbuffers>=23.5.26,<100", # update old 'flatbuffers' included inside tensorflow package + "onnxruntime-gpu" if cuda else "onnxruntime", + ), + cmds="--extra-index-url https://pypi.ngc.nvidia.com", + ) # onnx_graphsurgeon only on NVIDIA - LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') - f = Path(str(self.file).replace(self.file.suffix, '_saved_model')) + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + check_version( + tf.__version__, + "<=2.13.1", + name="tensorflow", + verbose=True, + msg="https://github.com/ultralytics/ultralytics/issues/5161", + ) + import onnx2tf + + f = Path(str(self.file).replace(self.file.suffix, "_saved_model")) if f.is_dir(): - import shutil shutil.rmtree(f) # delete output folder + # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545 + onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy") + if not onnx2tf_file.exists(): + attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True) + # Export to ONNX self.args.simplify = True f_onnx, _ = self.export_onnx() # Export to TF - tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file + tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file + np_data = None if self.args.int8: - verbosity = '--verbosity info' + verbosity = "info" if self.args.data: # Generate calibration data for integer quantization LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'") data = check_det_dataset(self.args.data) - dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False) + dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False) images = [] for i, batch in enumerate(dataset): if i >= 100: # maximum number of calibration images break - im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC + im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC images.append(im) f.mkdir() images = torch.cat(images, 0).float() # mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53] # std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375] np.save(str(tmp_file), images.numpy()) # BHWC - int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"' - else: - int8 = '-oiqt -qt per-tensor' + np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]] else: - verbosity = '--non_verbose' - int8 = '' + verbosity = "error" - cmd = 
f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo {verbosity} {int8}'.strip() - LOGGER.info(f"{prefix} running '{cmd}'") - subprocess.run(cmd, shell=True) - yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml + LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...") + onnx2tf.convert( + input_onnx_file_path=f_onnx, + output_folder_path=str(f), + not_use_onnxsim=True, + verbosity=verbosity, + output_integer_quantized_tflite=self.args.int8, + quant_type="per-tensor", # "per-tensor" (faster) or "per-channel" (slower but more accurate) + custom_input_op_name_np_data_path=np_data, + ) + yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml # Remove/rename TFLite models if self.args.int8: tmp_file.unlink(missing_ok=True) - for file in f.rglob('*_dynamic_range_quant.tflite'): - file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix)) - for file in f.rglob('*_integer_quant_with_int16_act.tflite'): + for file in f.rglob("*_dynamic_range_quant.tflite"): + file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix)) + for file in f.rglob("*_integer_quant_with_int16_act.tflite"): file.unlink() # delete extra fp16 activation TFLite files # Add TFLite metadata - for file in f.rglob('*.tflite'): - f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file) + for file in f.rglob("*.tflite"): + f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file) return str(f), tf.saved_model.load(f, tags=None, options=None) # load saved_model as Keras model @try_export - def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')): + def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")): """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow.""" import tensorflow as tf # noqa from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa - LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') - f = self.file.with_suffix('.pb') + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + f = self.file.with_suffix(".pb") m = tf.function(lambda x: keras_model(x)) # full model m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)) @@ -723,40 +842,43 @@ class Exporter: return f, None @try_export - def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')): + def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")): """YOLOv8 TensorFlow Lite export.""" import tensorflow as tf # noqa - LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') - saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model')) + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model")) if self.args.int8: - f = saved_model / f'{self.file.stem}_int8.tflite' # fp32 in/out + f = saved_model / f"{self.file.stem}_int8.tflite" # fp32 in/out elif self.args.half: - f = saved_model / f'{self.file.stem}_float16.tflite' # fp32 in/out + f = saved_model / f"{self.file.stem}_float16.tflite" # fp32 in/out else: - f = saved_model / f'{self.file.stem}_float32.tflite' + f = saved_model / f"{self.file.stem}_float32.tflite" return str(f), None @try_export - def export_edgetpu(self, 
tflite_model='', prefix=colorstr('Edge TPU:')): + def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")): """YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/.""" - LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185') + LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185") - cmd = 'edgetpu_compiler --version' - help_url = 'https://coral.ai/docs/edgetpu/compiler/' - assert LINUX, f'export only supported on Linux. See {help_url}' + cmd = "edgetpu_compiler --version" + help_url = "https://coral.ai/docs/edgetpu/compiler/" + assert LINUX, f"export only supported on Linux. See {help_url}" if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0: - LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}') - sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system + LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}") + sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0 # sudo installed on system for c in ( - 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', - 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', - 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'): - subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True) + "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -", + 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' + "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list", + "sudo apt-get update", + "sudo apt-get install edgetpu-compiler", + ): + subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True) ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] - LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') - f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model + LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...") + f = str(tflite_model).replace(".tflite", "_edgetpu.tflite") # Edge TPU model cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"' LOGGER.info(f"{prefix} running '{cmd}'") @@ -765,28 +887,35 @@ class Exporter: return f, None @try_export - def export_tfjs(self, prefix=colorstr('TensorFlow.js:')): + def export_tfjs(self, prefix=colorstr("TensorFlow.js:")): """YOLOv8 TensorFlow.js export.""" - check_requirements('tensorflowjs') + check_requirements("tensorflowjs") + if ARM64: + # Fix error: `np.object` was a deprecated alias for the builtin `object` when exporting to TF.js on ARM64 + check_requirements("numpy==1.23.5") import tensorflow as tf import tensorflowjs as tfjs # noqa - LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') - f = str(self.file).replace(self.file.suffix, '_web_model') # js dir - f_pb = str(self.file.with_suffix('.pb')) # *.pb path + LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...") + f = str(self.file).replace(self.file.suffix, "_web_model") # js dir + f_pb = str(self.file.with_suffix(".pb")) # *.pb path gd = tf.Graph().as_graph_def() # TF GraphDef - with open(f_pb, 'rb') 
as file: + with open(f_pb, "rb") as file: gd.ParseFromString(file.read()) - outputs = ','.join(gd_outputs(gd)) - LOGGER.info(f'\n{prefix} output node names: {outputs}') + outputs = ",".join(gd_outputs(gd)) + LOGGER.info(f"\n{prefix} output node names: {outputs}") + quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else "" with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path - cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} "{fpb_}" "{f_}"' + cmd = ( + "tensorflowjs_converter " + f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"' + ) LOGGER.info(f"{prefix} running '{cmd}'") subprocess.run(cmd, shell=True) - if ' ' in str(f): + if " " in f: LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.") # f_json = Path(f) / 'model.json' # *.json path @@ -803,7 +932,7 @@ class Exporter: # f_json.read_text(), # ) # j.write(subst) - yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml + yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml return f, None def _add_tflite_metadata(self, file): @@ -814,14 +943,14 @@ class Exporter: # Create model info model_meta = _metadata_fb.ModelMetadataT() - model_meta.name = self.metadata['description'] - model_meta.version = self.metadata['version'] - model_meta.author = self.metadata['author'] - model_meta.license = self.metadata['license'] + model_meta.name = self.metadata["description"] + model_meta.version = self.metadata["version"] + model_meta.author = self.metadata["author"] + model_meta.license = self.metadata["license"] # Label file - tmp_file = Path(file).parent / 'temp_meta.txt' - with open(tmp_file, 'w') as f: + tmp_file = Path(file).parent / "temp_meta.txt" + with open(tmp_file, "w") as f: f.write(str(self.metadata)) label_file = _metadata_fb.AssociatedFileT() @@ -830,8 +959,8 @@ class Exporter: # Create input info input_meta = _metadata_fb.TensorMetadataT() - input_meta.name = 'image' - input_meta.description = 'Input image to be detected.' + input_meta.name = "image" + input_meta.description = "Input image to be detected." 
input_meta.content = _metadata_fb.ContentT() input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT() input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB @@ -839,19 +968,19 @@ class Exporter: # Create output info output1 = _metadata_fb.TensorMetadataT() - output1.name = 'output' - output1.description = 'Coordinates of detected objects, class labels, and confidence score' + output1.name = "output" + output1.description = "Coordinates of detected objects, class labels, and confidence score" output1.associatedFiles = [label_file] - if self.model.task == 'segment': + if self.model.task == "segment": output2 = _metadata_fb.TensorMetadataT() - output2.name = 'output' - output2.description = 'Mask protos' + output2.name = "output" + output2.description = "Mask protos" output2.associatedFiles = [label_file] # Create subgraph info subgraph = _metadata_fb.SubGraphMetadataT() subgraph.inputTensorMetadata = [input_meta] - subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1] + subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1] model_meta.subgraphMetadata = [subgraph] b = flatbuffers.Builder(0) @@ -864,11 +993,11 @@ class Exporter: populator.populate() tmp_file.unlink() - def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr('CoreML Pipeline:')): + def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")): """YOLOv8 CoreML pipeline.""" import coremltools as ct # noqa - LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...') + LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...") _, _, h, w = list(self.im.shape) # BCHW # Output shapes @@ -876,8 +1005,9 @@ class Exporter: out0, out1 = iter(spec.description.output) if MACOS: from PIL import Image - img = Image.new('RGB', (w, h)) # w=192, h=320 - out = model.predict({'image': img}) + + img = Image.new("RGB", (w, h)) # w=192, h=320 + out = model.predict({"image": img}) out0_shape = out[out0.name].shape # (3780, 80) out1_shape = out[out1.name].shape # (3780, 4) else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y @@ -885,11 +1015,11 @@ class Exporter: out1_shape = self.output_shape[2], 4 # (3780, 4) # Checks - names = self.metadata['names'] + names = self.metadata["names"] nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height _, nc = out0_shape # number of anchors, number of classes # _, nc = out0.type.multiArrayType.shape - assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check + assert len(names) == nc, f"{len(names)} names found for nc={nc}" # check # Define output shapes (missing) out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80) @@ -923,8 +1053,8 @@ class Exporter: nms_spec.description.output.add() nms_spec.description.output[i].ParseFromString(decoder_output) - nms_spec.description.output[0].name = 'confidence' - nms_spec.description.output[1].name = 'coordinates' + nms_spec.description.output[0].name = "confidence" + nms_spec.description.output[1].name = "coordinates" output_sizes = [nc, 4] for i in range(2): @@ -940,10 +1070,10 @@ class Exporter: nms = nms_spec.nonMaximumSuppression nms.confidenceInputFeatureName = out0.name # 1x507x80 nms.coordinatesInputFeatureName = out1.name # 1x507x4 - nms.confidenceOutputFeatureName = 'confidence' - nms.coordinatesOutputFeatureName = 'coordinates' - 
nms.iouThresholdInputFeatureName = 'iouThreshold' - nms.confidenceThresholdInputFeatureName = 'confidenceThreshold' + nms.confidenceOutputFeatureName = "confidence" + nms.coordinatesOutputFeatureName = "coordinates" + nms.iouThresholdInputFeatureName = "iouThreshold" + nms.confidenceThresholdInputFeatureName = "confidenceThreshold" nms.iouThreshold = 0.45 nms.confidenceThreshold = 0.25 nms.pickTop.perClass = True @@ -951,10 +1081,14 @@ class Exporter: nms_model = ct.models.MLModel(nms_spec) # 4. Pipeline models together - pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)), - ('iouThreshold', ct.models.datatypes.Double()), - ('confidenceThreshold', ct.models.datatypes.Double())], - output_features=['confidence', 'coordinates']) + pipeline = ct.models.pipeline.Pipeline( + input_features=[ + ("image", ct.models.datatypes.Array(3, ny, nx)), + ("iouThreshold", ct.models.datatypes.Double()), + ("confidenceThreshold", ct.models.datatypes.Double()), + ], + output_features=["confidence", "coordinates"], + ) pipeline.add_model(model) pipeline.add_model(nms_model) @@ -965,25 +1099,24 @@ class Exporter: # Update metadata pipeline.spec.specificationVersion = 5 - pipeline.spec.description.metadata.userDefined.update({ - 'IoU threshold': str(nms.iouThreshold), - 'Confidence threshold': str(nms.confidenceThreshold)}) + pipeline.spec.description.metadata.userDefined.update( + {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)} + ) # Save the model model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir) - model.input_description['image'] = 'Input image' - model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})' - model.input_description['confidenceThreshold'] = \ - f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})' - model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")' - model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)' - LOGGER.info(f'{prefix} pipeline success') + model.input_description["image"] = "Input image" + model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})" + model.input_description["confidenceThreshold"] = ( + f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})" + ) + model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")' + model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)" + LOGGER.info(f"{prefix} pipeline success") return model def add_callback(self, event: str, callback): - """ - Appends the given callback. 
- """ + """Appends the given callback.""" self.callbacks[event].append(callback) def run_callbacks(self, event: str): diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index ac57323..ef5c93c 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -5,64 +5,109 @@ import sys from pathlib import Path from typing import Union +import numpy as np +import torch + from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir from ultralytics.hub.utils import HUB_WEB_ROOT from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load -from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, emojis, yaml_load -from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml -from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS -from ultralytics.utils.torch_utils import smart_inference_mode +from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load -class Model: +class Model(nn.Module): """ - A base model class to unify apis for all the models. + A base class for implementing YOLO models, unifying APIs across different model types. + + This class provides a common interface for various operations related to YOLO models, such as training, + validation, prediction, exporting, and benchmarking. It handles different types of models, including those + loaded from local files, Ultralytics HUB, or Triton Server. The class is designed to be flexible and + extendable for different tasks and model configurations. Args: - model (str, Path): Path to the model file to load or create. - task (Any, optional): Task type for the YOLO model. Defaults to None. + model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file + path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'. + task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's + application domain, such as object detection, segmentation, etc. Defaults to None. + verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False. Attributes: - predictor (Any): The predictor object. - model (Any): The model object. - trainer (Any): The trainer object. - task (str): The type of model task. - ckpt (Any): The checkpoint object if the model loaded from *.pt file. - cfg (str): The model configuration if loaded from *.yaml file. - ckpt_path (str): The checkpoint file path. - overrides (dict): Overrides for the trainer object. - metrics (Any): The data for metrics. + callbacks (dict): A dictionary of callback functions for various events during model operations. + predictor (BasePredictor): The predictor object used for making predictions. + model (nn.Module): The underlying PyTorch model. + trainer (BaseTrainer): The trainer object used for training the model. + ckpt (dict): The checkpoint data if the model is loaded from a *.pt file. + cfg (str): The configuration of the model if loaded from a *.yaml file. + ckpt_path (str): The path to the checkpoint file. + overrides (dict): A dictionary of overrides for model configuration. + metrics (dict): The latest training/validation metrics. + session (HUBTrainingSession): The Ultralytics HUB session, if applicable. + task (str): The type of task the model is intended for. + model_name (str): The name of the model. 
Methods: - __call__(source=None, stream=False, **kwargs): - Alias for the predict method. - _new(cfg:str, verbose:bool=True) -> None: - Initializes a new model and infers the task type from the model definitions. - _load(weights:str, task:str='') -> None: - Initializes a new model and infers the task type from the model head. - _check_is_pytorch_model() -> None: - Raises TypeError if the model is not a PyTorch model. - reset() -> None: - Resets the model modules. - info(verbose:bool=False) -> None: - Logs the model info. - fuse() -> None: - Fuses the model for faster inference. - predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]: - Performs prediction using the YOLO model. + __call__: Alias for the predict method, enabling the model instance to be callable. + _new: Initializes a new model based on a configuration file. + _load: Loads a model from a checkpoint file. + _check_is_pytorch_model: Ensures that the model is a PyTorch model. + reset_weights: Resets the model's weights to their initial state. + load: Loads model weights from a specified file. + save: Saves the current state of the model to a file. + info: Logs or returns information about the model. + fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference. + predict: Performs object detection predictions. + track: Performs object tracking. + val: Validates the model on a dataset. + benchmark: Benchmarks the model on various export formats. + export: Exports the model to different formats. + train: Trains the model on a dataset. + tune: Performs hyperparameter tuning. + _apply: Applies a function to the model's tensors. + add_callback: Adds a callback function for an event. + clear_callback: Clears all callbacks for an event. + reset_callbacks: Resets all callbacks to their default functions. + _get_hub_session: Retrieves or creates an Ultralytics HUB session. + is_triton_model: Checks if a model is a Triton Server model. + is_hub_model: Checks if a model is an Ultralytics HUB model. + _reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model. + _smart_load: Loads the appropriate module based on the model task. + task_map: Provides a mapping from model tasks to corresponding classes. - Returns: - list(ultralytics.engine.results.Results): The prediction results. + Raises: + FileNotFoundError: If the specified model file does not exist or is inaccessible. + ValueError: If the model file or configuration is invalid or unsupported. + ImportError: If required dependencies for specific model types (like HUB SDK) are not installed. + TypeError: If the model is not a PyTorch model when required. + AttributeError: If required attributes or methods are not implemented or available. + NotImplementedError: If a specific model task or mode is not supported. """ - def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None: + def __init__( + self, + model: Union[str, Path] = "yolov8n.pt", + task: str = None, + verbose: bool = False, + ) -> None: """ - Initializes the YOLO model. + Initializes a new instance of the YOLO model class. + + This constructor sets up the model based on the provided model path or name. It handles various types of model + sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several + important attributes of the model and prepares it for operations like training, prediction, or export. Args: - model (Union[str, Path], optional): Path or name of the model to load or create. 
Defaults to 'yolov8n.pt'. - task (Any, optional): Task type for the YOLO model. Defaults to None. + model (Union[str, Path], optional): The path or model file to load or create. This can be a local + file path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'. + task (Any, optional): The task type associated with the YOLO model, specifying its application domain. + Defaults to None. + verbose (bool, optional): If True, enables verbose output during the model's initialization and subsequent + operations. Defaults to False. + + Raises: + FileNotFoundError: If the specified model file does not exist or is inaccessible. + ValueError: If the model file or configuration is invalid or unsupported. + ImportError: If required dependencies for specific model types (like HUB SDK) are not installed. """ + super().__init__() self.callbacks = callbacks.get_default_callbacks() self.predictor = None # reuse predictor self.model = None # model object @@ -74,36 +119,80 @@ class Model: self.metrics = None # validation/training metrics self.session = None # HUB session self.task = task # task type - model = str(model).strip() # strip spaces + model = str(model).strip() # Check if Ultralytics HUB model from https://hub.ultralytics.com if self.is_hub_model(model): - from ultralytics.hub.session import HUBTrainingSession - self.session = HUBTrainingSession(model) + # Fetch model from HUB + checks.check_requirements("hub-sdk>=0.0.6") + self.session = self._get_hub_session(model) model = self.session.model_file - # Load or create new YOLO model - suffix = Path(model).suffix - if not suffix and Path(model).stem in GITHUB_ASSETS_STEMS: - model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt - if suffix in ('.yaml', '.yml'): - self._new(model, task) - else: - self._load(model, task) + # Check if Triton Server model + elif self.is_triton_model(model): + self.model_name = self.model = model + self.task = task + return - def __call__(self, source=None, stream=False, **kwargs): - """Calls the 'predict' function with given arguments to perform object detection.""" + # Load or create new YOLO model + if Path(model).suffix in (".yaml", ".yml"): + self._new(model, task=task, verbose=verbose) + else: + self._load(model, task=task) + + def __call__( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs, + ) -> list: + """ + An alias for the predict method, enabling the model instance to be callable. + + This method simplifies the process of making predictions by allowing the model instance to be called directly + with the required arguments for prediction. + + Args: + source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making + predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays. + Defaults to None. + stream (bool, optional): If True, treats the input source as a continuous stream for predictions. + Defaults to False. + **kwargs (any): Additional keyword arguments for configuring the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class. + """ return self.predict(source, stream, **kwargs) @staticmethod - def is_hub_model(model): - """Check if the provided model is a HUB model.""" - return any(( - model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. 
https://hub.ultralytics.com/models/MODEL_ID - [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID - len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID + def _get_hub_session(model: str): + """Creates a session for Hub Training.""" + from ultralytics.hub.session import HUBTrainingSession - def _new(self, cfg: str, task=None, model=None, verbose=True): + session = HUBTrainingSession(model) + return session if session.client.authenticated else None + + @staticmethod + def is_triton_model(model: str) -> bool: + """Is model a Triton Server URL string, i.e. :////""" + from urllib.parse import urlsplit + + url = urlsplit(model) + return url.netloc and url.path and url.scheme in {"http", "grpc"} + + @staticmethod + def is_hub_model(model: str) -> bool: + """Check if the provided model is a HUB model.""" + return any( + ( + model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID + [len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL + len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL + ) + ) + + def _new(self, cfg: str, task=None, model=None, verbose=False) -> None: """ Initializes a new model and infers the task type from the model definitions. @@ -116,16 +205,16 @@ class Model: cfg_dict = yaml_model_load(cfg) self.cfg = cfg self.task = task or guess_model_task(cfg_dict) - self.model = (model or self.smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1) # build model - self.overrides['model'] = self.cfg - self.overrides['task'] = self.task + self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model + self.overrides["model"] = self.cfg + self.overrides["task"] = self.task # Below added to allow export from YAMLs - args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args - self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model + self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args) self.model.task = self.task + self.model_name = cfg - def _load(self, weights: str, task=None): + def _load(self, weights: str, task=None) -> None: """ Initializes a new model and infers the task type from the model head. @@ -133,49 +222,74 @@ class Model: weights (str): model checkpoint to be loaded task (str | None): model task """ - suffix = Path(weights).suffix - if suffix == '.pt': + if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): + weights = checks.check_file(weights) # automatically download and return local filename + weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. 
yolov8n -> yolov8n.pt + + if Path(weights).suffix == ".pt": self.model, self.ckpt = attempt_load_one_weight(weights) - self.task = self.model.args['task'] + self.task = self.model.args["task"] self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) self.ckpt_path = self.model.pt_path else: - weights = check_file(weights) + weights = checks.check_file(weights) # runs in all cases, not redundant with above call self.model, self.ckpt = weights, None self.task = task or guess_model_task(weights) self.ckpt_path = weights - self.overrides['model'] = weights - self.overrides['task'] = self.task + self.overrides["model"] = weights + self.overrides["task"] = self.task + self.model_name = weights - def _check_is_pytorch_model(self): - """ - Raises TypeError is model is not a PyTorch model - """ - pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt' + def _check_is_pytorch_model(self) -> None: + """Raises TypeError is model is not a PyTorch model.""" + pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt" pt_module = isinstance(self.model, nn.Module) if not (pt_module or pt_str): - raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. " - f'PyTorch models can be used to train, val, predict and export, i.e. ' - f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only " - f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.") + raise TypeError( + f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. " + f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported " + f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, " + f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device " + f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'" + ) - @smart_inference_mode() - def reset_weights(self): + def reset_weights(self) -> "Model": """ - Resets the model modules parameters to randomly initialized values, losing all training information. + Resets the model parameters to randomly initialized values, effectively discarding all training information. + + This method iterates through all modules in the model and resets their parameters if they have a + 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them + to be updated during training. + + Returns: + self (ultralytics.engine.model.Model): The instance of the class with reset weights. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() for m in self.model.modules(): - if hasattr(m, 'reset_parameters'): + if hasattr(m, "reset_parameters"): m.reset_parameters() for p in self.model.parameters(): p.requires_grad = True return self - @smart_inference_mode() - def load(self, weights='yolov8n.pt'): + def load(self, weights: Union[str, Path] = "yolov8n.pt") -> "Model": """ - Transfers parameters with matching names and shapes from 'weights' to model. + Loads parameters from the specified weights file into the model. + + This method supports loading weights from a file or directly from a weights object. It matches parameters by + name and shape and transfers them to the model. + + Args: + weights (str | Path): Path to the weights file or a weights object. Defaults to 'yolov8n.pt'. 
+ + Returns: + self (ultralytics.engine.model.Model): The instance of the class with loaded weights. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() if isinstance(weights, (str, Path)): @@ -183,160 +297,362 @@ class Model: self.model.load(weights) return self - def info(self, detailed=False, verbose=True): + def save(self, filename: Union[str, Path] = "saved_model.pt", use_dill=True) -> None: """ - Logs model info. + Saves the current model state to a file. + + This method exports the model's checkpoint (ckpt) to the specified filename. Args: - detailed (bool): Show detailed information about model. - verbose (bool): Controls verbosity. + filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'. + use_dill (bool): Whether to try using dill for serialization if available. Defaults to True. + + Raises: + AssertionError: If the model is not a PyTorch model. + """ + self._check_is_pytorch_model() + from ultralytics import __version__ + from datetime import datetime + + updates = { + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + } + torch.save({**self.ckpt, **updates}, filename, use_dill=use_dill) + + def info(self, detailed: bool = False, verbose: bool = True): + """ + Logs or returns model information. + + This method provides an overview or detailed information about the model, depending on the arguments passed. + It can control the verbosity of the output. + + Args: + detailed (bool): If True, shows detailed information about the model. Defaults to False. + verbose (bool): If True, prints the information. If False, returns the information. Defaults to True. + + Returns: + (list): Various types of information about the model, depending on the 'detailed' and 'verbose' parameters. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() return self.model.info(detailed=detailed, verbose=verbose) def fuse(self): - """Fuse PyTorch Conv2d and BatchNorm2d layers.""" + """ + Fuses Conv2d and BatchNorm2d layers in the model. + + This method optimizes the model by fusing Conv2d and BatchNorm2d layers, which can improve inference speed. + + Raises: + AssertionError: If the model is not a PyTorch model. + """ self._check_is_pytorch_model() self.model.fuse() - @smart_inference_mode() - def predict(self, source=None, stream=False, predictor=None, **kwargs): + def embed( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs, + ) -> list: """ - Perform prediction using the YOLO model. + Generates image embeddings based on the provided source. + + This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image source. + It allows customization of the embedding process through various keyword arguments. Args: - source (str | int | PIL | np.ndarray): The source of the image to make predictions on. - Accepts all source types accepted by the YOLO model. - stream (bool): Whether to stream the predictions or not. Defaults to False. - predictor (BasePredictor): Customized predictor. - **kwargs : Additional keyword arguments passed to the predictor. - Check the 'configuration' section in the documentation for all available options. + source (str | int | PIL.Image | np.ndarray): The source of the image for generating embeddings. 
+ The source can be a file path, URL, PIL image, numpy array, etc. Defaults to None. + stream (bool): If True, predictions are streamed. Defaults to False. + **kwargs (any): Additional keyword arguments for configuring the embedding process. Returns: - (List[ultralytics.engine.results.Results]): The prediction results. + (List[torch.Tensor]): A list containing the image embeddings. + + Raises: + AssertionError: If the model is not a PyTorch model. + """ + if not kwargs.get("embed"): + kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed + return self.predict(source, stream, **kwargs) + + def predict( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + predictor=None, + **kwargs, + ) -> list: + """ + Performs predictions on the given image source using the YOLO model. + + This method facilitates the prediction process, allowing various configurations through keyword arguments. + It supports predictions with custom predictors or the default predictor method. The method handles different + types of image sources and can operate in a streaming mode. It also provides support for SAM-type models + through 'prompts'. + + The method sets up a new predictor if not already present and updates its arguments with each call. + It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it + is being called from the command line interface and adjusts its behavior accordingly, including setting defaults + for confidence threshold and saving behavior. + + Args: + source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions. + Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to ASSETS. + stream (bool, optional): Treats the input source as a continuous stream for predictions. Defaults to False. + predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions. + If None, the method uses a default predictor. Defaults to None. + **kwargs (any): Additional keyword arguments for configuring the prediction process. These arguments allow + for further customization of the prediction behavior. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class. + + Raises: + AttributeError: If the predictor is not properly set up. """ if source is None: source = ASSETS LOGGER.warning(f"WARNING ⚠️ 'source' is missing. 
Using 'source={source}'.") - is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any( - x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track')) + is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any( + x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track") + ) - custom = {'conf': 0.25, 'save': is_cli} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'} # highest priority args on the right - prompts = args.pop('prompts', None) # for SAM-type models + custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults + args = {**self.overrides, **custom, **kwargs} # highest priority args on the right + prompts = args.pop("prompts", None) # for SAM-type models if not self.predictor: - self.predictor = (predictor or self.smart_load('predictor'))(overrides=args, _callbacks=self.callbacks) + self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks) self.predictor.setup_model(model=self.model, verbose=is_cli) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, args) - if 'project' in args or 'name' in args: + if "project" in args or "name" in args: self.predictor.save_dir = get_save_dir(self.predictor.args) - if prompts and hasattr(self.predictor, 'set_prompts'): # for SAM-type models + if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models self.predictor.set_prompts(prompts) return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) - def track(self, source=None, stream=False, persist=False, **kwargs): + def track( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + persist: bool = False, + **kwargs, + ) -> list: """ - Perform object tracking on the input source using the registered trackers. + Conducts object tracking on the specified input source using the registered trackers. + + This method performs object tracking using the model's predictors and optionally registered trackers. It is + capable of handling different types of input sources such as file paths or video streams. The method supports + customization of the tracking process through various keyword arguments. It registers trackers if they are not + already present and optionally persists them based on the 'persist' flag. + + The method sets a default confidence threshold specifically for ByteTrack-based tracking, which requires low + confidence predictions as input. The tracking mode is explicitly set in the keyword arguments. Args: - source (str, optional): The input source for object tracking. Can be a file path or a video stream. - stream (bool, optional): Whether the input source is a video stream. Defaults to False. - persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. - **kwargs (optional): Additional keyword arguments for the tracking process. + source (str, optional): The input source for object tracking. It can be a file path, URL, or video stream. + stream (bool, optional): Treats the input source as a continuous video stream. Defaults to False. + persist (bool, optional): Persists the trackers between different calls to this method. Defaults to False. + **kwargs (any): Additional keyword arguments for configuring the tracking process. 
These arguments allow + for further customization of the tracking behavior. Returns: - (List[ultralytics.engine.results.Results]): The tracking results. + (List[ultralytics.engine.results.Results]): A list of tracking results, encapsulated in the Results class. + + Raises: + AttributeError: If the predictor does not have registered trackers. """ - if not hasattr(self.predictor, 'trackers'): + if not hasattr(self.predictor, "trackers"): from ultralytics.trackers import register_tracker + register_tracker(self, persist) - # ByteTrack-based method needs low confidence predictions as input - kwargs['conf'] = kwargs.get('conf') or 0.1 - kwargs['mode'] = 'track' + kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input + kwargs["batch"] = kwargs.get("batch") or 1 # batch-size 1 for tracking in videos + kwargs["mode"] = "track" return self.predict(source=source, stream=stream, **kwargs) - @smart_inference_mode() - def val(self, validator=None, **kwargs): + def val( + self, + validator=None, + **kwargs, + ): """ - Validate a model on a given dataset. + Validates the model using a specified dataset and validation configuration. + + This method facilitates the model validation process, allowing for a range of customization through various + settings and configurations. It supports validation with a custom validator or the default validation approach. + The method combines default configurations, method-specific defaults, and user-provided arguments to configure + the validation process. After validation, it updates the model's metrics with the results obtained from the + validator. + + The method supports various arguments that allow customization of the validation process. For a comprehensive + list of all configurable options, users should refer to the 'configuration' section in the documentation. Args: - validator (BaseValidator): Customized validator. - **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs - """ - custom = {'rect': True} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right - args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1) + validator (BaseValidator, optional): An instance of a custom validator class for validating the model. If + None, the method uses a default validator. Defaults to None. + **kwargs (any): Arbitrary keyword arguments representing the validation configuration. These arguments are + used to customize various aspects of the validation process. - validator = (validator or self.smart_load('validator'))(args=args, _callbacks=self.callbacks) + Returns: + (dict): Validation metrics obtained from the validation process. + + Raises: + AssertionError: If the model is not a PyTorch model. + """ + custom = {"rect": True} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right + + validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks) validator(model=self.model) self.metrics = validator.metrics return validator.metrics - @smart_inference_mode() - def benchmark(self, **kwargs): + def benchmark( + self, + **kwargs, + ): """ - Benchmark a model on all export formats. + Benchmarks the model across various export formats to evaluate performance. + + This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. 
+ It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured + using a combination of default configuration values, model-specific arguments, method-specific defaults, and + any additional user-provided keyword arguments. + + The method supports various arguments that allow customization of the benchmarking process, such as dataset + choice, image size, precision modes, device selection, and verbosity. For a comprehensive list of all + configurable options, users should refer to the 'configuration' section in the documentation. Args: - **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs + **kwargs (any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with + default configurations, model-specific arguments, and method defaults. + + Returns: + (dict): A dictionary containing the results of the benchmarking process. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() from ultralytics.utils.benchmarks import benchmark - custom = {'verbose': False} # method defaults - args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'} + custom = {"verbose": False} # method defaults + args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"} return benchmark( model=self, - data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets - imgsz=args['imgsz'], - half=args['half'], - int8=args['int8'], - device=args['device'], - verbose=kwargs.get('verbose')) + data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets + imgsz=args["imgsz"], + half=args["half"], + int8=args["int8"], + device=args["device"], + verbose=kwargs.get("verbose"), + ) - def export(self, **kwargs): + def export( + self, + **kwargs, + ): """ - Export model. + Exports the model to a different format suitable for deployment. + + This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment + purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method + defaults, and any additional arguments provided. The combined arguments are used to configure export settings. + + The method supports a wide range of arguments to customize the export process. For a comprehensive list of all + possible arguments, refer to the 'configuration' section in the documentation. Args: - **kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs. + **kwargs (any): Arbitrary keyword arguments to customize the export process. These are combined with the + model's overrides and method defaults. + + Returns: + (object): The exported model in the specified format, or an object related to the export process. + + Raises: + AssertionError: If the model is not a PyTorch model. 
""" self._check_is_pytorch_model() from .exporter import Exporter - custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right + custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) - def train(self, trainer=None, **kwargs): + def train( + self, + trainer=None, + **kwargs, + ): """ - Trains the model on a given dataset. + Trains the model using the specified dataset and training configuration. + + This method facilitates model training with a range of customizable settings and configurations. It supports + training with a custom trainer or the default training approach defined in the method. The method handles + different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and + updating model and configuration after training. + + When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training + arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default + configurations, method-specific defaults, and user-provided arguments to configure the training process. After + training, it updates the model and its configurations, and optionally attaches metrics. Args: - trainer (BaseTrainer, optional): Customized trainer. - **kwargs (Any): Any number of arguments representing the training configuration. + trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the + method uses a default trainer. Defaults to None. + **kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are + used to customize various aspects of the training process. + + Returns: + (dict | None): Training metrics if available and training is successful; otherwise, None. + + Raises: + AssertionError: If the model is not a PyTorch model. + PermissionError: If there is a permission issue with the HUB session. + ModuleNotFoundError: If the HUB SDK is not installed. 
""" self._check_is_pytorch_model() - if self.session: # Ultralytics HUB session + if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model if any(kwargs): - LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.') - kwargs = self.session.train_args - check_pip_update_available() + LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.") + kwargs = self.session.train_args # overwrite kwargs - overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides - custom = {'data': TASK2DATA[self.task]} # method defaults - args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right - if args.get('resume'): - args['resume'] = self.ckpt_path + checks.check_pip_update_available() - self.trainer = (trainer or self.smart_load('trainer'))(overrides=args, _callbacks=self.callbacks) - if not args.get('resume'): # manually set model only if not resuming + overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides + custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults + args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + if args.get("resume"): + args["resume"] = self.ckpt_path + + self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks) + if not args.get("resume"): # manually set model only if not resuming self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) self.model = self.trainer.model + + if SETTINGS["hub"] is True and not self.session: + # Create a model in HUB + try: + self.session = self._get_hub_session(self.model_name) + if self.session: + self.session.create_model(args) + # Check model was created + if not getattr(self.session.model, "id", None): + self.session = None + except (PermissionError, ModuleNotFoundError): + # Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed + pass + self.trainer.hub_session = self.session # attach optional HUB session self.trainer.train() # Update model and cfg after training @@ -344,78 +660,148 @@ class Model: ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last self.model, _ = attempt_load_one_weight(ckpt) self.overrides = self.model.args - self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP + self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP return self.metrics - def tune(self, use_ray=False, iterations=10, *args, **kwargs): + def tune( + self, + use_ray=False, + iterations=10, + *args, + **kwargs, + ): """ - Runs hyperparameter tuning, optionally using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args. + Conducts hyperparameter tuning for the model, with an option to use Ray Tune. + + This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. + When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. + Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and + custom arguments to configure the tuning process. + + Args: + use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False. + iterations (int): The number of tuning iterations to perform. 
Defaults to 10. + *args (list): Variable length argument list for additional arguments. + **kwargs (any): Arbitrary keyword arguments. These are combined with the model's overrides and defaults. Returns: (dict): A dictionary containing the results of the hyperparameter search. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() if use_ray: from ultralytics.utils.tuner import run_ray_tune + return run_ray_tune(self, max_samples=iterations, *args, **kwargs) else: from .tuner import Tuner - custom = {'plots': False, 'save': False} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right + custom = {} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations) - def to(self, device): - """ - Sends the model to the given device. - - Args: - device (str): device - """ + def _apply(self, fn) -> "Model": + """Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers.""" self._check_is_pytorch_model() - self.model.to(device) + self = super()._apply(fn) # noqa + self.predictor = None # reset predictor as device may have changed + self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' return self @property - def names(self): - """Returns class names of the loaded model.""" - return self.model.names if hasattr(self.model, 'names') else None + def names(self) -> list: + """ + Retrieves the class names associated with the loaded model. + + This property returns the class names if they are defined in the model. It checks the class names for validity + using the 'check_class_names' function from the ultralytics.nn.autobackend module. + + Returns: + (list | None): The class names of the model if available, otherwise None. + """ + from ultralytics.nn.autobackend import check_class_names + + return check_class_names(self.model.names) if hasattr(self.model, "names") else None @property - def device(self): - """Returns device if PyTorch model.""" + def device(self) -> torch.device: + """ + Retrieves the device on which the model's parameters are allocated. + + This property is used to determine whether the model's parameters are on CPU or GPU. It only applies to models + that are instances of nn.Module. + + Returns: + (torch.device | None): The device (CPU/GPU) of the model if it is a PyTorch model, otherwise None. + """ return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None @property def transforms(self): - """Returns transform of the loaded model.""" - return self.model.transforms if hasattr(self.model, 'transforms') else None + """ + Retrieves the transformations applied to the input data of the loaded model. - def add_callback(self, event: str, func): - """Add a callback.""" + This property returns the transformations if they are defined in the model. + + Returns: + (object | None): The transform object of the model if available, otherwise None. + """ + return self.model.transforms if hasattr(self.model, "transforms") else None + + def add_callback(self, event: str, func) -> None: + """ + Adds a callback function for a specified event. + + This method allows the user to register a custom callback function that is triggered on a specific event during + model training or inference. 
+ + Args: + event (str): The name of the event to attach the callback to. + func (callable): The callback function to be registered. + + Raises: + ValueError: If the event name is not recognized. + """ self.callbacks[event].append(func) - def clear_callback(self, event: str): - """Clear all event callbacks.""" + def clear_callback(self, event: str) -> None: + """ + Clears all callback functions registered for a specified event. + + This method removes all custom and default callback functions associated with the given event. + + Args: + event (str): The name of the event for which to clear the callbacks. + + Raises: + ValueError: If the event name is not recognized. + """ self.callbacks[event] = [] - @staticmethod - def _reset_ckpt_args(args): - """Reset arguments when loading a PyTorch model.""" - include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model - return {k: v for k, v in args.items() if k in include} + def reset_callbacks(self) -> None: + """ + Resets all callbacks to their default functions. - def _reset_callbacks(self): - """Reset all registered callbacks.""" + This method reinstates the default callback functions for all events, removing any custom callbacks that were + added previously. + """ for event in callbacks.default_callbacks.keys(): self.callbacks[event] = [callbacks.default_callbacks[event][0]] - def __getattr__(self, attr): - """Raises error if object has no requested attribute.""" - name = self.__class__.__name__ - raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + @staticmethod + def _reset_ckpt_args(args: dict) -> dict: + """Reset arguments when loading a PyTorch model.""" + include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model + return {k: v for k, v in args.items() if k in include} - def smart_load(self, key): + # def __getattr__(self, attr): + # """Raises error if object has no requested attribute.""" + # name = self.__class__.__name__ + # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + + def _smart_load(self, key: str): """Load model/trainer/validator/predictor.""" try: return self.task_map[self.task][key] @@ -423,14 +809,15 @@ class Model: name = self.__class__.__name__ mode = inspect.stack()[1][3] # get the function name. raise NotImplementedError( - emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e + emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.") + ) from e @property - def task_map(self): + def task_map(self) -> dict: """ Map head to model, trainer, validator, and predictor classes. Returns: task_map (dict): The map of model task to mode classes. 
""" - raise NotImplementedError('Please provide task map for your model!') + raise NotImplementedError("Please provide task map for your model!") diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py index c649090..9ec803a 100644 --- a/ultralytics/engine/predictor.py +++ b/ultralytics/engine/predictor.py @@ -11,8 +11,8 @@ Usage - sources: list.txt # list of images list.streams # list of streams 'path/*.jpg' # glob - 'https://youtu.be/Zgi9g1ksQHc' # YouTube - 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream + 'https://youtu.be/LNwODJXcvt4' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream Usage - formats: $ yolo mode=predict model=yolov8n.pt # PyTorch @@ -26,8 +26,12 @@ Usage - formats: yolov8n.tflite # TensorFlow Lite yolov8n_edgetpu.tflite # TensorFlow Edge TPU yolov8n_paddle_model # PaddlePaddle + yolov8n_ncnn_model # NCNN """ + import platform +import re +import threading from pathlib import Path import cv2 @@ -58,7 +62,7 @@ Example: class BasePredictor: """ - BasePredictor + BasePredictor. A base class for creating predictors. @@ -70,9 +74,7 @@ class BasePredictor: data (dict): Data configuration. device (torch.device): Device used for prediction. dataset (Dataset): Dataset used for prediction. - vid_path (str): Path to video file. - vid_writer (cv2.VideoWriter): Video writer for saving video output. - data_path (str): Path to data. + vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output. """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): @@ -97,19 +99,22 @@ class BasePredictor: self.imgsz = None self.device = None self.dataset = None - self.vid_path, self.vid_writer = None, None + self.vid_writer = {} # dict of {save_path: video_writer, ...} self.plotted_img = None - self.data_path = None self.source_type = None + self.seen = 0 + self.windows = [] self.batch = None self.results = None self.transforms = None self.callbacks = _callbacks or callbacks.get_default_callbacks() self.txt_path = None + self._lock = threading.Lock() # for automatic thread-safe inference callbacks.add_integration_callbacks(self) def preprocess(self, im): - """Prepares input image before inference. + """ + Prepares input image before inference. Args: im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list. @@ -128,9 +133,13 @@ class BasePredictor: return im def inference(self, im, *args, **kwargs): - visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem, - mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False - return self.model(im, augment=self.args.augment, visualize=visualize) + """Runs inference on a given image using the specified model and arguments.""" + visualize = ( + increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True) + if self.args.visualize and (not self.source_type.tensor) + else False + ) + return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs) def pre_transform(self, im): """ @@ -142,45 +151,10 @@ class BasePredictor: Returns: (list): A list of transformed images. 
""" - same_shapes = all(x.shape == im[0].shape for x in im) + same_shapes = len({x.shape for x in im}) == 1 letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride) return [letterbox(image=x) for x in im] - def write_results(self, idx, results, batch): - """Write inference results to a file or directory.""" - p, im, _ = batch - log_string = '' - if len(im.shape) == 3: - im = im[None] # expand for batch dim - if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1 - log_string += f'{idx}: ' - frame = self.dataset.count - else: - frame = getattr(self.dataset, 'frame', 0) - self.data_path = p - self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') - log_string += '%gx%g ' % im.shape[2:] # print string - result = results[idx] - log_string += result.verbose() - - if self.args.save or self.args.show: # Add bbox to image - plot_args = { - 'line_width': self.args.line_width, - 'boxes': self.args.boxes, - 'conf': self.args.show_conf, - 'labels': self.args.show_labels} - if not self.args.retina_masks: - plot_args['im_gpu'] = im[idx] - self.plotted_img = result.plot(**plot_args) - # Write - if self.args.save_txt: - result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf) - if self.args.save_crop: - result.save_crop(save_dir=self.save_dir / 'crops', - file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}')) - - return log_string - def postprocess(self, preds, img, orig_imgs): """Post-processes predictions for an image and returns them.""" return preds @@ -194,157 +168,224 @@ class BasePredictor: return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one def predict_cli(self, source=None, model=None): - """Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode.""" + """ + Method used for CLI prediction. + + It uses always generator as outputs as not required by CLI mode. 
+ """ gen = self.stream_inference(source, model) - for _ in gen: # running CLI inference without accumulating any outputs (do not modify) + for _ in gen: # noqa, running CLI inference without accumulating any outputs (do not modify) pass def setup_source(self, source): """Sets up source and inference mode.""" self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size - self.transforms = getattr(self.model.model, 'transforms', classify_transforms( - self.imgsz[0])) if self.args.task == 'classify' else None - self.dataset = load_inference_source(source=source, - imgsz=self.imgsz, - vid_stride=self.args.vid_stride, - stream_buffer=self.args.stream_buffer) + self.transforms = ( + getattr( + self.model.model, + "transforms", + classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction), + ) + if self.args.task == "classify" + else None + ) + self.dataset = load_inference_source( + source=source, + batch=self.args.batch, + vid_stride=self.args.vid_stride, + buffer=self.args.stream_buffer, + ) self.source_type = self.dataset.source_type - if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams - len(self.dataset) > 1000 or # images - any(getattr(self.dataset, 'video_flag', [False]))): # videos + if not getattr(self, "stream", True) and ( + self.source_type.stream + or self.source_type.screenshot + or len(self.dataset) > 1000 # many images + or any(getattr(self.dataset, "video_flag", [False])) + ): # videos LOGGER.warning(STREAM_WARNING) - self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs + self.vid_writer = {} @smart_inference_mode() def stream_inference(self, source=None, model=None, *args, **kwargs): """Streams real-time inference on camera feed and saves results to file.""" if self.args.verbose: - LOGGER.info('') + LOGGER.info("") # Setup model if not self.model: self.setup_model(model) - # Setup source every time predict is called - self.setup_source(source if source is not None else self.args.source) + with self._lock: # for thread-safe inference + # Setup source every time predict is called + self.setup_source(source if source is not None else self.args.source) - # Check if save_dir/ label file exists - if self.args.save or self.args.save_txt: - (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) + # Check if save_dir/ label file exists + if self.args.save or self.args.save_txt: + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) - # Warmup model - if not self.done_warmup: - self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz)) - self.done_warmup = True + # Warmup model + if not self.done_warmup: + self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz)) + self.done_warmup = True - self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile()) - self.run_callbacks('on_predict_start') - for batch in self.dataset: - self.run_callbacks('on_predict_batch_start') - self.batch = batch - path, im0s, vid_cap, s = batch + self.seen, self.windows, self.batch = 0, [], None + profilers = ( + ops.Profile(device=self.device), + ops.Profile(device=self.device), + ops.Profile(device=self.device), + ) + self.run_callbacks("on_predict_start") + for self.batch in self.dataset: + self.run_callbacks("on_predict_batch_start") + paths, im0s, s = self.batch - 
# Preprocess - with profilers[0]: - im = self.preprocess(im0s) + # Preprocess + with profilers[0]: + im = self.preprocess(im0s) - # Inference - with profilers[1]: - preds = self.inference(im, *args, **kwargs) + # Inference + with profilers[1]: + preds = self.inference(im, *args, **kwargs) + if self.args.embed: + yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors + continue - # Postprocess - with profilers[2]: - self.results = self.postprocess(preds, im, im0s) - self.run_callbacks('on_predict_postprocess_end') + # Postprocess + with profilers[2]: + self.results = self.postprocess(preds, im, im0s) + self.run_callbacks("on_predict_postprocess_end") - # Visualize, save, write results - n = len(im0s) - for i in range(n): - self.seen += 1 - self.results[i].speed = { - 'preprocess': profilers[0].dt * 1E3 / n, - 'inference': profilers[1].dt * 1E3 / n, - 'postprocess': profilers[2].dt * 1E3 / n} - p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy() - p = Path(p) + # Visualize, save, write results + n = len(im0s) + for i in range(n): + self.seen += 1 + self.results[i].speed = { + "preprocess": profilers[0].dt * 1e3 / n, + "inference": profilers[1].dt * 1e3 / n, + "postprocess": profilers[2].dt * 1e3 / n, + } + if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: + s[i] += self.write_results(i, Path(paths[i]), im, s) - if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: - s += self.write_results(i, self.results, (p, im, im0)) - if self.args.save or self.args.save_txt: - self.results[i].save_dir = self.save_dir.__str__() - if self.args.show and self.plotted_img is not None: - self.show(p) - if self.args.save and self.plotted_img is not None: - self.save_preds(vid_cap, i, str(self.save_dir / p.name)) + # Print batch results + if self.args.verbose: + LOGGER.info("\n".join(s)) - self.run_callbacks('on_predict_batch_end') - yield from self.results - - # Print time (inference-only) - if self.args.verbose: - LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms') + self.run_callbacks("on_predict_batch_end") + yield from self.results # Release assets - if isinstance(self.vid_writer[-1], cv2.VideoWriter): - self.vid_writer[-1].release() # release final video writer + for v in self.vid_writer.values(): + if isinstance(v, cv2.VideoWriter): + v.release() - # Print results + # Print final results if self.args.verbose and self.seen: - t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image - LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape ' - f'{(1, 3, *im.shape[2:])}' % t) + t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image + LOGGER.info( + f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " + f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t + ) if self.args.save or self.args.save_txt or self.args.save_crop: - nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels - s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else '' + nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels + s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") - - self.run_callbacks('on_predict_end') + self.run_callbacks("on_predict_end") def setup_model(self, model, verbose=True): """Initialize YOLO model with given 
parameters and set it to evaluation mode.""" - self.model = AutoBackend(model or self.args.model, - device=select_device(self.args.device, verbose=verbose), - dnn=self.args.dnn, - data=self.args.data, - fp16=self.args.half, - fuse=True, - verbose=verbose) + self.model = AutoBackend( + weights=model or self.args.model, + device=select_device(self.args.device, verbose=verbose), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + batch=self.args.batch, + fuse=True, + verbose=verbose, + ) self.device = self.model.device # update device self.args.half = self.model.fp16 # update half self.model.eval() - def show(self, p): - """Display an image in a window using OpenCV imshow().""" - im0 = self.plotted_img - if platform.system() == 'Linux' and p not in self.windows: - self.windows.append(p) - cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) - cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) - cv2.imshow(str(p), im0) - cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # 1 millisecond + def write_results(self, i, p, im, s): + """Write inference results to a file or directory.""" + string = "" # print string + if len(im.shape) == 3: + im = im[None] # expand for batch dim + if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1 + string += f"{i}: " + frame = self.dataset.count + else: + match = re.search(r"frame (\d+)/", s[i]) + frame = int(match.group(1)) if match else None # 0 if frame undetermined - def save_preds(self, vid_cap, idx, save_path): + self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}")) + string += "%gx%g " % im.shape[2:] + result = self.results[i] + result.save_dir = self.save_dir.__str__() # used in other locations + string += result.verbose() + f"{result.speed['inference']:.1f}ms" + + # Add predictions to image + if self.args.save or self.args.show: + self.plotted_img = result.plot( + line_width=self.args.line_width, + boxes=self.args.show_boxes, + conf=self.args.show_conf, + labels=self.args.show_labels, + im_gpu=None if self.args.retina_masks else im[i], + ) + + # Save results + if self.args.save_txt: + result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf) + if self.args.save_crop: + result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem) + if self.args.show: + self.show(str(p)) + if self.args.save: + self.save_predicted_images(str(self.save_dir / (p.name or "tmp.jpg")), frame) + + return string + + def save_predicted_images(self, save_path="", frame=0): """Save video predictions as mp4 at specified path.""" - im0 = self.plotted_img - # Save imgs - if self.dataset.mode == 'image': - cv2.imwrite(save_path, im0) - else: # 'video' or 'stream' - if self.vid_path[idx] != save_path: # new video - self.vid_path[idx] = save_path - if isinstance(self.vid_writer[idx], cv2.VideoWriter): - self.vid_writer[idx].release() # release previous video writer - if vid_cap: # video - fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec - w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - else: # stream - fps, w, h = 30, im0.shape[1], im0.shape[0] - suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG') - save_path = str(Path(save_path).with_suffix(suffix)) - self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, 
h)) - self.vid_writer[idx].write(im0) + im = self.plotted_img + + # Save videos and streams + if self.dataset.mode in {"stream", "video"}: + fps = self.dataset.fps if self.dataset.mode == "video" else 30 + frames_path = f'{save_path.split(".", 1)[0]}_frames/' + if save_path not in self.vid_writer: # new video + if self.args.save_frames: + Path(frames_path).mkdir(parents=True, exist_ok=True) + suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG") + self.vid_writer[save_path] = cv2.VideoWriter( + filename=str(Path(save_path).with_suffix(suffix)), + fourcc=cv2.VideoWriter_fourcc(*fourcc), + fps=fps, # integer required, floats produce error in MP4 codec + frameSize=(im.shape[1], im.shape[0]), # (width, height) + ) + + # Save video + self.vid_writer[save_path].write(im) + if self.args.save_frames: + cv2.imwrite(f"{frames_path}{frame}.jpg", im) + + # Save images + else: + cv2.imwrite(save_path, im) + + def show(self, p=""): + """Display an image in a window using OpenCV imshow().""" + im = self.plotted_img + if platform.system() == "Linux" and p not in self.windows: + self.windows.append(p) + cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) + cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height) + cv2.imshow(p, im) + cv2.waitKey(300 if self.dataset.mode == "image" else 1) # 1 millisecond def run_callbacks(self, event: str): """Runs all registered callbacks for a specific event.""" @@ -352,7 +393,5 @@ class BasePredictor: callback(self) def add_callback(self, event: str, func): - """ - Add callback - """ + """Add callback.""" self.callbacks[event].append(func) diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index d6763ff..85849c3 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license """ -Ultralytics Results, Boxes and Masks classes for handling inference results +Ultralytics Results, Boxes and Masks classes for handling inference results. Usage: See https://docs.ultralytics.com/modes/predict/ """ @@ -13,17 +13,17 @@ import numpy as np import torch from ultralytics.data.augment import LetterBox -from ultralytics.utils import LOGGER, SimpleClass, deprecation_warn, ops +from ultralytics.utils import LOGGER, SimpleClass, ops from ultralytics.utils.plotting import Annotator, colors, save_one_box +from ultralytics.utils.torch_utils import smart_inference_mode class BaseTensor(SimpleClass): - """ - Base tensor class with additional methods for easy manipulation and device handling. - """ + """Base tensor class with additional methods for easy manipulation and device handling.""" def __init__(self, data, orig_shape) -> None: - """Initialize BaseTensor with data and original shape. + """ + Initialize BaseTensor with data and original shape. Args: data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints. @@ -67,45 +67,63 @@ class Results(SimpleClass): """ A class for storing and manipulating inference results. - Args: - orig_img (numpy.ndarray): The original image as a numpy array. - path (str): The path to the image file. - names (dict): A dictionary of class names. - boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection. - masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image. - probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task. 
- keypoints (List[List[float]], optional): A list of detected keypoints for each object. - Attributes: - orig_img (numpy.ndarray): The original image as a numpy array. - orig_shape (tuple): The original image shape in (height, width) format. - boxes (Boxes, optional): A Boxes object containing the detection bounding boxes. - masks (Masks, optional): A Masks object containing the detection masks. - probs (Probs, optional): A Probs object containing probabilities of each class for classification task. - keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object. - speed (dict): A dictionary of preprocess, inference, and postprocess speeds in milliseconds per image. - names (dict): A dictionary of class names. - path (str): The path to the image file. - _keys (tuple): A tuple of attribute names for non-empty attributes. + orig_img (numpy.ndarray): Original image as a numpy array. + orig_shape (tuple): Original image shape in (height, width) format. + boxes (Boxes, optional): Object containing detection bounding boxes. + masks (Masks, optional): Object containing detection masks. + probs (Probs, optional): Object containing class probabilities for classification tasks. + keypoints (Keypoints, optional): Object containing detected keypoints for each object. + speed (dict): Dictionary of preprocess, inference, and postprocess speeds (ms/image). + names (dict): Dictionary of class names. + path (str): Path to the image file. + + Methods: + update(boxes=None, masks=None, probs=None, obb=None): Updates object attributes with new detection results. + cpu(): Returns a copy of the Results object with all tensors on CPU memory. + numpy(): Returns a copy of the Results object with all tensors as numpy arrays. + cuda(): Returns a copy of the Results object with all tensors on GPU memory. + to(*args, **kwargs): Returns a copy of the Results object with tensors on a specified device and dtype. + new(): Returns a new Results object with the same image, path, and names. + plot(...): Plots detection results on an input image, returning an annotated image. + show(): Show annotated results to screen. + save(filename): Save annotated results to file. + verbose(): Returns a log string for each task, detailing detections and classifications. + save_txt(txt_file, save_conf=False): Saves detection results to a text file. + save_crop(save_dir, file_name=Path("im.jpg")): Saves cropped detection images. + tojson(normalize=False): Converts detection results to JSON format. """ - def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None: - """Initialize the Results class.""" + def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None) -> None: + """ + Initialize the Results class. + + Args: + orig_img (numpy.ndarray): The original image as a numpy array. + path (str): The path to the image file. + names (dict): A dictionary of class names. + boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection. + masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image. + probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task. + keypoints (torch.tensor, optional): A 2D tensor of keypoint coordinates for each detection. + obb (torch.tensor, optional): A 2D tensor of oriented bounding box coordinates for each detection. 
+ """ self.orig_img = orig_img self.orig_shape = orig_img.shape[:2] self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks self.probs = Probs(probs) if probs is not None else None self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None - self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image + self.obb = OBB(obb, self.orig_shape) if obb is not None else None + self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image self.names = names self.path = path self.save_dir = None - self._keys = 'boxes', 'masks', 'probs', 'keypoints' + self._keys = "boxes", "masks", "probs", "keypoints", "obb" def __getitem__(self, idx): """Return a Results object for the specified index.""" - return self._apply('__getitem__', idx) + return self._apply("__getitem__", idx) def __len__(self): """Return the number of detections in the Results object.""" @@ -114,17 +132,30 @@ class Results(SimpleClass): if v is not None: return len(v) - def update(self, boxes=None, masks=None, probs=None): + def update(self, boxes=None, masks=None, probs=None, obb=None): """Update the boxes, masks, and probs attributes of the Results object.""" if boxes is not None: - ops.clip_boxes(boxes, self.orig_shape) # clip boxes - self.boxes = Boxes(boxes, self.orig_shape) + self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape) if masks is not None: self.masks = Masks(masks, self.orig_shape) if probs is not None: self.probs = probs + if obb is not None: + self.obb = OBB(obb, self.orig_shape) def _apply(self, fn, *args, **kwargs): + """ + Applies a function to all non-empty attributes and returns a new Results object with modified attributes. This + function is internally called by methods like .to(), .cuda(), .cpu(), etc. + + Args: + fn (str): The name of the function to apply. + *args: Variable length argument list to pass to the function. + **kwargs: Arbitrary keyword arguments to pass to the function. + + Returns: + Results: A new Results object with attributes modified by the applied function. 
+ """ r = self.new() for k in self._keys: v = getattr(self, k) @@ -134,40 +165,42 @@ class Results(SimpleClass): def cpu(self): """Return a copy of the Results object with all tensors on CPU memory.""" - return self._apply('cpu') + return self._apply("cpu") def numpy(self): """Return a copy of the Results object with all tensors as numpy arrays.""" - return self._apply('numpy') + return self._apply("numpy") def cuda(self): """Return a copy of the Results object with all tensors on GPU memory.""" - return self._apply('cuda') + return self._apply("cuda") def to(self, *args, **kwargs): """Return a copy of the Results object with tensors on the specified device and dtype.""" - return self._apply('to', *args, **kwargs) + return self._apply("to", *args, **kwargs) def new(self): """Return a new Results object with the same image, path, and names.""" return Results(orig_img=self.orig_img, path=self.path, names=self.names) def plot( - self, - conf=True, - line_width=None, - font_size=None, - font='Arial.ttf', - pil=False, - img=None, - im_gpu=None, - kpt_radius=5, - kpt_line=True, - labels=True, - boxes=True, - masks=True, - probs=True, - **kwargs # deprecated args TODO: remove support in 8.2 + self, + conf=True, + line_width=None, + font_size=None, + font="Arial.ttf", + pil=False, + img=None, + im_gpu=None, + kpt_radius=5, + kpt_line=True, + labels=True, + boxes=True, + masks=True, + probs=True, + show=False, + save=False, + filename=None, ): """ Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image. @@ -186,6 +219,9 @@ class Results(SimpleClass): boxes (bool): Whether to plot the bounding boxes. masks (bool): Whether to plot the masks. probs (bool): Whether to plot classification probability + show (bool): Whether to display the annotated image directly. + save (bool): Whether to save the annotated image to `filename`. + filename (str): Filename to save image to if save is True. Returns: (numpy.ndarray): A numpy array of the annotated image. 
@@ -207,19 +243,9 @@ class Results(SimpleClass): if img is None and isinstance(self.orig_img, torch.Tensor): img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy() - # Deprecation warn TODO: remove in 8.2 - if 'show_conf' in kwargs: - deprecation_warn('show_conf', 'conf') - conf = kwargs['show_conf'] - assert isinstance(conf, bool), '`show_conf` should be of boolean type, i.e, show_conf=True/False' - - if 'line_thickness' in kwargs: - deprecation_warn('line_thickness', 'line_width') - line_width = kwargs['line_thickness'] - assert isinstance(line_width, int), '`line_width` should be of int type, i.e, line_width=3' - names = self.names - pred_boxes, show_boxes = self.boxes, boxes + is_obb = self.obb is not None + pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes pred_masks, show_masks = self.masks, masks pred_probs, show_probs = self.probs, probs annotator = Annotator( @@ -228,28 +254,35 @@ class Results(SimpleClass): font_size, font, pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True - example=names) + example=names, + ) # Plot Segment results if pred_masks and show_masks: if im_gpu is None: img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) - im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute( - 2, 0, 1).flip(0).contiguous() / 255 + im_gpu = ( + torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device) + .permute(2, 0, 1) + .flip(0) + .contiguous() + / 255 + ) idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu) # Plot Detect results - if pred_boxes and show_boxes: + if pred_boxes is not None and show_boxes: for d in reversed(pred_boxes): c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) - name = ('' if id is None else f'id:{id} ') + names[c] - label = (f'{name} {conf:.2f}' if conf else name) if labels else None - annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) + name = ("" if id is None else f"id:{id} ") + names[c] + label = (f"{name} {conf:.2f}" if conf else name) if labels else None + box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze() + annotator.box_label(box, label, color=colors(c, True), rotated=is_obb) # Plot Classify results if pred_probs is not None and show_probs: - text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5) + text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5) x = round(self.orig_shape[0] * 0.03) annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors @@ -258,17 +291,34 @@ class Results(SimpleClass): for k in reversed(self.keypoints.data): annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line) + # Show results + if show: + annotator.show(self.path) + + # Save results + if save: + annotator.save(filename) + return annotator.result() + def show(self, *args, **kwargs): + """Show annotated results image.""" + self.plot(show=True, *args, **kwargs) + + def save(self, filename=None, *args, **kwargs): + """Save annotated results image.""" + if not filename: + filename = f"results_{Path(self.path).name}" + self.plot(save=True, filename=filename, *args, **kwargs) + return filename + def verbose(self): - """ - Return log string for each task. 
- """ - log_string = '' + """Return log string for each task.""" + log_string = "" probs = self.probs boxes = self.boxes if len(self) == 0: - return log_string if probs is not None else f'{log_string}(no detections), ' + return log_string if probs is not None else f"{log_string}(no detections), " if probs is not None: log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, " if boxes: @@ -285,34 +335,35 @@ class Results(SimpleClass): txt_file (str): txt file path. save_conf (bool): save confidence score or not. """ - boxes = self.boxes + is_obb = self.obb is not None + boxes = self.obb if is_obb else self.boxes masks = self.masks probs = self.probs kpts = self.keypoints texts = [] if probs is not None: # Classify - [texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5] + [texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5] elif boxes: # Detect/segment/pose for j, d in enumerate(boxes): c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) - line = (c, *d.xywhn.view(-1)) + line = (c, *(d.xyxyxyxyn.view(-1) if is_obb else d.xywhn.view(-1))) if masks: seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2) line = (c, *seg) if kpts is not None: kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn - line += (*kpt.reshape(-1).tolist(), ) - line += (conf, ) * save_conf + (() if id is None else (id, )) - texts.append(('%g ' * len(line)).rstrip() % line) + line += (*kpt.reshape(-1).tolist(),) + line += (conf,) * save_conf + (() if id is None else (id,)) + texts.append(("%g " * len(line)).rstrip() % line) if texts: Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory - with open(txt_file, 'a') as f: - f.writelines(text + '\n' for text in texts) + with open(txt_file, "a") as f: + f.writelines(text + "\n" for text in texts) - def save_crop(self, save_dir, file_name=Path('im.jpg')): + def save_crop(self, save_dir, file_name=Path("im.jpg")): """ Save cropped predictions to `save_dir/cls/file_name.jpg`. @@ -321,79 +372,105 @@ class Results(SimpleClass): file_name (str | pathlib.Path): File name. 
""" if self.probs is not None: - LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.') + LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.") + return + if self.obb is not None: + LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.") return for d in self.boxes: - save_one_box(d.xyxy, - self.orig_img.copy(), - file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name).stem}.jpg', - BGR=True) + save_one_box( + d.xyxy, + self.orig_img.copy(), + file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg", + BGR=True, + ) - def tojson(self, normalize=False): - """Convert the object to JSON format.""" + def summary(self, normalize=False, decimals=5): + """Convert the results to a summarized format.""" if self.probs is not None: - LOGGER.warning('Warning: Classify task do not support `tojson` yet.') + LOGGER.warning("Warning: Classify results do not support the `summary()` method yet.") return - import json - # Create list of detection dictionaries results = [] data = self.boxes.data.cpu().tolist() h, w = self.orig_shape if normalize else (1, 1) for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id - box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h} - conf = row[-2] + box = { + "x1": round(row[0] / w, decimals), + "y1": round(row[1] / h, decimals), + "x2": round(row[2] / w, decimals), + "y2": round(row[3] / h, decimals), + } + conf = round(row[-2], decimals) class_id = int(row[-1]) - name = self.names[class_id] - result = {'name': name, 'class': class_id, 'confidence': conf, 'box': box} + result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": box} if self.boxes.is_track: - result['track_id'] = int(row[-3]) # track ID + result["track_id"] = int(row[-3]) # track ID if self.masks: - x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array - result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()} + result["segments"] = { + "x": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(), + "y": (self.masks.xy[i][:, 1] / h).round(decimals).tolist(), + } if self.keypoints is not None: x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor - result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()} + result["keypoints"] = { + "x": (x / w).numpy().round(decimals).tolist(), # decimals named argument required + "y": (y / h).numpy().round(decimals).tolist(), + "visible": visible.numpy().round(decimals).tolist(), + } results.append(result) - # Convert detections to JSON - return json.dumps(results, indent=2) + return results + + def tojson(self, normalize=False, decimals=5): + """Convert the results to JSON format.""" + import json + + return json.dumps(self.summary(normalize=normalize, decimals=decimals), indent=2) class Boxes(BaseTensor): """ - A class for storing and manipulating detection boxes. - - Args: - boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes, - with shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values. - If present, the third last column contains track IDs. - orig_shape (tuple): Original image size, in the format (height, width). + Manages detection boxes, providing easy access and manipulation of box coordinates, confidence scores, class + identifiers, and optional tracking IDs. Supports multiple formats for box coordinates, including both absolute and + normalized forms. 
Attributes: - xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format. - conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes. - cls (torch.Tensor | numpy.ndarray): The class values of the boxes. - id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available). - xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format. - xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size. - xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size. - data (torch.Tensor): The raw bboxes tensor (alias for `boxes`). + data (torch.Tensor): The raw tensor containing detection boxes and their associated data. + orig_shape (tuple): The original image size as a tuple (height, width), used for normalization. + is_track (bool): Indicates whether tracking IDs are included in the box data. + + Properties: + xyxy (torch.Tensor | numpy.ndarray): Boxes in [x1, y1, x2, y2] format. + conf (torch.Tensor | numpy.ndarray): Confidence scores for each box. + cls (torch.Tensor | numpy.ndarray): Class labels for each box. + id (torch.Tensor | numpy.ndarray, optional): Tracking IDs for each box, if available. + xywh (torch.Tensor | numpy.ndarray): Boxes in [x, y, width, height] format, calculated on demand. + xyxyn (torch.Tensor | numpy.ndarray): Normalized [x1, y1, x2, y2] boxes, relative to `orig_shape`. + xywhn (torch.Tensor | numpy.ndarray): Normalized [x, y, width, height] boxes, relative to `orig_shape`. Methods: - cpu(): Move the object to CPU memory. - numpy(): Convert the object to a numpy array. - cuda(): Move the object to CUDA memory. - to(*args, **kwargs): Move the object to the specified device. + cpu(): Moves the boxes to CPU memory. + numpy(): Converts the boxes to a numpy array format. + cuda(): Moves the boxes to CUDA (GPU) memory. + to(device, dtype=None): Moves the boxes to the specified device. """ def __init__(self, boxes, orig_shape) -> None: - """Initialize the Boxes class.""" + """ + Initialize the Boxes class. + + Args: + boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes, with + shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values. + If present, the third last column contains track IDs. + orig_shape (tuple): Original image size, in the format (height, width). + """ if boxes.ndim == 1: boxes = boxes[None, :] n = boxes.shape[-1] - assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, track_id, conf, cls + assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls super().__init__(boxes, orig_shape) self.is_track = n == 7 self.orig_shape = orig_shape @@ -442,19 +519,12 @@ class Boxes(BaseTensor): xywh[..., [1, 3]] /= self.orig_shape[0] return xywh - @property - def boxes(self): - """Return the raw bboxes tensor (deprecated).""" - LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.") - return self.data - class Masks(BaseTensor): """ A class for storing and manipulating detection masks. Attributes: - segments (list): Deprecated property for segments (normalized). xy (list): A list of segments in pixel coordinates. xyn (list): A list of normalized segments. @@ -471,22 +541,14 @@ class Masks(BaseTensor): masks = masks[None, :] super().__init__(masks, orig_shape) - @property - @lru_cache(maxsize=1) - def segments(self): - """Return segments (normalized). 
Deprecated; use xyn property instead.""" - LOGGER.warning( - "WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and 'Masks.xy' for segments (pixels) instead." - ) - return self.xyn - @property @lru_cache(maxsize=1) def xyn(self): """Return normalized segments.""" return [ ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True) - for x in ops.masks2segments(self.data)] + for x in ops.masks2segments(self.data) + ] @property @lru_cache(maxsize=1) @@ -494,13 +556,8 @@ class Masks(BaseTensor): """Return segments in pixel coordinates.""" return [ ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False) - for x in ops.masks2segments(self.data)] - - @property - def masks(self): - """Return the raw masks tensor. Deprecated; use data attribute instead.""" - LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.") - return self.data + for x in ops.masks2segments(self.data) + ] class Keypoints(BaseTensor): @@ -519,10 +576,14 @@ class Keypoints(BaseTensor): to(device, dtype): Returns a copy of the keypoints tensor with the specified device and dtype. """ + @smart_inference_mode() # avoid keypoints < conf in-place error def __init__(self, keypoints, orig_shape) -> None: """Initializes the Keypoints object with detection keypoints and original image size.""" if keypoints.ndim == 2: keypoints = keypoints[None, :] + if keypoints.shape[2] == 3: # x, y, conf + mask = keypoints[..., 2] < 0.5 # points with conf < 0.5 (not visible) + keypoints[..., :2][mask] = 0 super().__init__(keypoints, orig_shape) self.has_visible = self.data.shape[-1] == 3 @@ -566,6 +627,7 @@ class Probs(BaseTensor): """ def __init__(self, probs, orig_shape=None) -> None: + """Initialize the Probs class with classification probabilities and optional original shape of the image.""" super().__init__(probs, orig_shape) @property @@ -591,3 +653,91 @@ class Probs(BaseTensor): def top5conf(self): """Return the confidences of top 5.""" return self.data[self.top5] + + +class OBB(BaseTensor): + """ + A class for storing and manipulating Oriented Bounding Boxes (OBB). + + Args: + boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes, + with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values. + If present, the third last column contains track IDs, and the fifth column from the left contains rotation. + orig_shape (tuple): Original image size, in the format (height, width). + + Attributes: + xywhr (torch.Tensor | numpy.ndarray): The boxes in [x_center, y_center, width, height, rotation] format. + conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes. + cls (torch.Tensor | numpy.ndarray): The class values of the boxes. + id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available). + xyxyxyxyn (torch.Tensor | numpy.ndarray): The rotated boxes in xyxyxyxy format normalized by orig image size. + xyxyxyxy (torch.Tensor | numpy.ndarray): The rotated boxes in xyxyxyxy format. + xyxy (torch.Tensor | numpy.ndarray): The horizontal boxes in xyxyxyxy format. + data (torch.Tensor): The raw OBB tensor (alias for `boxes`). + + Methods: + cpu(): Move the object to CPU memory. + numpy(): Convert the object to a numpy array. + cuda(): Move the object to CUDA memory. + to(*args, **kwargs): Move the object to the specified device. 
+ """ + + def __init__(self, boxes, orig_shape) -> None: + """Initialize the Boxes class.""" + if boxes.ndim == 1: + boxes = boxes[None, :] + n = boxes.shape[-1] + assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls + super().__init__(boxes, orig_shape) + self.is_track = n == 8 + self.orig_shape = orig_shape + + @property + def xywhr(self): + """Return the rotated boxes in xywhr format.""" + return self.data[:, :5] + + @property + def conf(self): + """Return the confidence values of the boxes.""" + return self.data[:, -2] + + @property + def cls(self): + """Return the class values of the boxes.""" + return self.data[:, -1] + + @property + def id(self): + """Return the track IDs of the boxes (if available).""" + return self.data[:, -3] if self.is_track else None + + @property + @lru_cache(maxsize=2) + def xyxyxyxy(self): + """Return the boxes in xyxyxyxy format, (N, 4, 2).""" + return ops.xywhr2xyxyxyxy(self.xywhr) + + @property + @lru_cache(maxsize=2) + def xyxyxyxyn(self): + """Return the boxes in xyxyxyxy format, (N, 4, 2).""" + xyxyxyxyn = self.xyxyxyxy.clone() if isinstance(self.xyxyxyxy, torch.Tensor) else np.copy(self.xyxyxyxy) + xyxyxyxyn[..., 0] /= self.orig_shape[1] + xyxyxyxyn[..., 1] /= self.orig_shape[0] + return xyxyxyxyn + + @property + @lru_cache(maxsize=2) + def xyxy(self): + """ + Return the horizontal boxes in xyxy format, (N, 4). + + Accepts both torch and numpy boxes. + """ + x1 = self.xyxyxyxy[..., 0].min(1).values + x2 = self.xyxyxyxy[..., 0].max(1).values + y1 = self.xyxyxyxy[..., 1].min(1).values + y2 = self.xyxyxyxy[..., 1].max(1).values + xyxy = [x1, y1, x2, y2] + return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 4ff4229..2e7a7db 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license """ -Train a model on a dataset +Train a model on a dataset. 
Usage: $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16 @@ -19,31 +19,45 @@ import numpy as np import torch from torch import distributed as dist from torch import nn, optim -from torch.cuda import amp -from torch.nn.parallel import DistributedDataParallel as DDP from ultralytics.cfg import get_cfg, get_save_dir from ultralytics.data.utils import check_cls_dataset, check_det_dataset from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights -from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis, - yaml_save) +from ultralytics.utils import ( + DEFAULT_CFG, + LOGGER, + RANK, + TQDM, + __version__, + callbacks, + clean_url, + colorstr, + emojis, + yaml_save, +) from ultralytics.utils.autobatch import check_train_batch_size -from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args +from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.utils.files import get_latest_run -from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device, - strip_optimizer) +from ultralytics.utils.torch_utils import ( + EarlyStopping, + ModelEMA, + de_parallel, + init_seeds, + one_cycle, + select_device, + strip_optimizer, +) class BaseTrainer: """ - BaseTrainer + BaseTrainer. A base class for creating trainers. Attributes: args (SimpleNamespace): Configuration for the trainer. - check_resume (method): Method to check if training should be resumed from a saved checkpoint. validator (BaseValidator): Validator instance. model (nn.Module): Model instance. callbacks (defaultdict): Dictionary of callbacks. @@ -62,6 +76,7 @@ class BaseTrainer: trainset (torch.utils.data.Dataset): Training dataset. testset (torch.utils.data.Dataset): Testing dataset. ema (nn.Module): EMA (Exponential Moving Average) of the model. + resume (bool): Resume training from a checkpoint. lf (nn.Module): Loss function. scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. best_fitness (float): The best fitness value achieved. 
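# Rough Python-API counterpart of the CLI usage string above, with a custom
# callback attached. This is a sketch only: it assumes the `ultralytics`
# package, the illustrative names "yolov8n.pt" and "coco128.yaml", and that
# callbacks registered on the Model are forwarded to the trainer.
from ultralytics import YOLO

def log_epoch(trainer):
    # The trainer exposes the attributes listed in the BaseTrainer docstring above.
    print(f"epoch {trainer.epoch + 1}/{trainer.epochs}, fitness={trainer.fitness}")

model = YOLO("yolov8n.pt")
model.add_callback("on_fit_epoch_end", log_epoch)  # event name used by run_callbacks()
model.train(data="coco128.yaml", imgsz=640, epochs=100, batch=16)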
@@ -84,19 +99,19 @@ class BaseTrainer: self.check_resume(overrides) self.device = select_device(self.args.device, self.args.batch) self.validator = None - self.model = None self.metrics = None self.plots = {} init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) # Dirs self.save_dir = get_save_dir(self.args) - self.wdir = self.save_dir / 'weights' # weights dir + self.args.name = self.save_dir.name # update name for loggers + self.wdir = self.save_dir / "weights" # weights dir if RANK in (-1, 0): self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.args.save_dir = str(self.save_dir) - yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args - self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths + yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args + self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths self.save_period = self.args.save_period self.batch_size = self.args.batch @@ -106,18 +121,23 @@ class BaseTrainer: print_args(vars(self.args)) # Device - if self.device.type in ('cpu', 'mps'): + if self.device.type in ("cpu", "mps"): self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading # Model and Dataset - self.model = self.args.model + self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt try: - if self.args.task == 'classify': + if self.args.task == "classify": self.data = check_cls_dataset(self.args.data) - elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'): + elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ( + "detect", + "segment", + "pose", + "obb", + ): self.data = check_det_dataset(self.args.data) - if 'yaml_file' in self.data: - self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage + if "yaml_file" in self.data: + self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage except Exception as e: raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e @@ -133,8 +153,8 @@ class BaseTrainer: self.fitness = None self.loss = None self.tloss = None - self.loss_names = ['Loss'] - self.csv = self.save_dir / 'results.csv' + self.loss_names = ["Loss"] + self.csv = self.save_dir / "results.csv" self.plot_idx = [0, 1, 2] # Callbacks @@ -143,15 +163,11 @@ class BaseTrainer: callbacks.add_integration_callbacks(self) def add_callback(self, event: str, callback): - """ - Appends the given callback. - """ + """Appends the given callback.""" self.callbacks[event].append(callback) def set_callback(self, event: str, callback): - """ - Overrides the existing callbacks with the given callback. - """ + """Overrides the existing callbacks with the given callback.""" self.callbacks[event] = [callback] def run_callbacks(self, event: str): @@ -162,7 +178,7 @@ class BaseTrainer: def train(self): """Allow device='', device=None on Multi-GPU systems to default to device=0.""" if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3' - world_size = len(self.args.device.split(',')) + world_size = len(self.args.device.split(",")) elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list) world_size = len(self.args.device) elif torch.cuda.is_available(): # i.e. 
device=None or device='' or device=number @@ -171,14 +187,16 @@ class BaseTrainer: world_size = 0 # Run subprocess if DDP training, else train normally - if world_size > 1 and 'LOCAL_RANK' not in os.environ: + if world_size > 1 and "LOCAL_RANK" not in os.environ: # Argument checks if self.args.rect: LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'") self.args.rect = False if self.args.batch == -1: - LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting " - "default 'batch=16'") + LOGGER.warning( + "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting " + "default 'batch=16'" + ) self.args.batch = 16 # Command @@ -194,42 +212,56 @@ class BaseTrainer: else: self._do_train(world_size) + def _setup_scheduler(self): + """Initialize training learning rate scheduler.""" + if self.args.cos_lr: + self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf'] + else: + self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear + self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) + def _setup_ddp(self, world_size): """Initializes and sets the DistributedDataParallel parameters for training.""" torch.cuda.set_device(RANK) - self.device = torch.device('cuda', RANK) + self.device = torch.device("cuda", RANK) # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') - os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout + os.environ["NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout dist.init_process_group( - 'nccl' if dist.is_nccl_available() else 'gloo', + backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=10800), # 3 hours rank=RANK, - world_size=world_size) + world_size=world_size, + ) def _setup_train(self, world_size): - """ - Builds dataloaders and optimizer on correct rank process. - """ + """Builds dataloaders and optimizer on correct rank process.""" # Model - self.run_callbacks('on_pretrain_routine_start') + self.run_callbacks("on_pretrain_routine_start") ckpt = self.setup_model() self.model = self.model.to(self.device) self.set_model_attributes() # Freeze layers - freeze_list = self.args.freeze if isinstance( - self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else [] - always_freeze_names = ['.dfl'] # always freeze these layers - freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names + freeze_list = ( + self.args.freeze + if isinstance(self.args.freeze, list) + else range(self.args.freeze) + if isinstance(self.args.freeze, int) + else [] + ) + always_freeze_names = [".dfl"] # always freeze these layers + freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names for k, v in self.model.named_parameters(): # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results) if any(x in k for x in freeze_layer_names): LOGGER.info(f"Freezing layer '{k}'") v.requires_grad = False - elif not v.requires_grad: - LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " - 'See ultralytics.engine.trainer for customization of frozen layers.') + elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients + LOGGER.info( + f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. 
" + "See ultralytics.engine.trainer for customization of frozen layers." + ) v.requires_grad = True # Check AMP @@ -241,13 +273,14 @@ class BaseTrainer: if RANK > -1 and world_size > 1: # DDP dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None) self.amp = bool(self.amp) # as boolean - self.scaler = amp.GradScaler(enabled=self.amp) + self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp) if world_size > 1: - self.model = DDP(self.model, device_ids=[RANK]) + self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK]) # Check imgsz - gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride) + gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride) self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1) + self.stride = gs # for multiscale training # Batch size if self.batch_size == -1 and RANK == -1: # single-GPU only, estimate best batch size @@ -255,11 +288,14 @@ class BaseTrainer: # Dataloaders batch_size = self.batch_size // max(world_size, 1) - self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train') + self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train") if RANK in (-1, 0): - self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val') + # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects. + self.test_loader = self.get_dataloader( + self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val" + ) self.validator = self.get_validator() - metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val') + metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val") self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) self.ema = ModelEMA(self.model) if self.args.plots: @@ -269,22 +305,20 @@ class BaseTrainer: self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs - self.optimizer = self.build_optimizer(model=self.model, - name=self.args.optimizer, - lr=self.args.lr0, - momentum=self.args.momentum, - decay=weight_decay, - iterations=iterations) + self.optimizer = self.build_optimizer( + model=self.model, + name=self.args.optimizer, + lr=self.args.lr0, + momentum=self.args.momentum, + decay=weight_decay, + iterations=iterations, + ) # Scheduler - if self.args.cos_lr: - self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf'] - else: - self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear - self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) + self._setup_scheduler() self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False self.resume_training(ckpt) self.scheduler.last_epoch = self.start_epoch - 1 # do not move - self.run_callbacks('on_pretrain_routine_end') + self.run_callbacks("on_pretrain_routine_end") def _do_train(self, world_size=1): """Train completed, evaluate and plot if specified by arguments.""" @@ -292,35 +326,33 @@ class BaseTrainer: self._setup_ddp(world_size) 
self._setup_train(world_size) - self.epoch_time = None - self.epoch_time_start = time.time() - self.train_time_start = time.time() nb = len(self.train_loader) # number of batches nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations last_opt_step = -1 - self.run_callbacks('on_train_start') - LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' - f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' - f"Logging results to {colorstr('bold', self.save_dir)}\n" - f'Starting training for {self.epochs} epochs...') + self.epoch_time = None + self.epoch_time_start = time.time() + self.train_time_start = time.time() + self.run_callbacks("on_train_start") + LOGGER.info( + f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' + f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' + f"Logging results to {colorstr('bold', self.save_dir)}\n" + f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...") + ) if self.args.close_mosaic: base_idx = (self.epochs - self.args.close_mosaic) * nb self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2]) - epoch = self.epochs # predefine for resume fully trained model edge cases - for epoch in range(self.start_epoch, self.epochs): + epoch = self.start_epoch + while True: self.epoch = epoch - self.run_callbacks('on_train_epoch_start') + self.run_callbacks("on_train_epoch_start") self.model.train() if RANK != -1: self.train_loader.sampler.set_epoch(epoch) pbar = enumerate(self.train_loader) # Update dataloader attributes (optional) if epoch == (self.epochs - self.args.close_mosaic): - LOGGER.info('Closing dataloader mosaic') - if hasattr(self.train_loader.dataset, 'mosaic'): - self.train_loader.dataset.mosaic = False - if hasattr(self.train_loader.dataset, 'close_mosaic'): - self.train_loader.dataset.close_mosaic(hyp=self.args) + self._close_dataloader_mosaic() self.train_loader.reset() if RANK in (-1, 0): @@ -329,18 +361,19 @@ class BaseTrainer: self.tloss = None self.optimizer.zero_grad() for i, batch in pbar: - self.run_callbacks('on_train_batch_start') + self.run_callbacks("on_train_batch_start") # Warmup ni = i + nb * epoch if ni <= nw: xi = [0, nw] # x interp - self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()) + self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())) for j, x in enumerate(self.optimizer.param_groups): # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 - x['lr'] = np.interp( - ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)]) - if 'momentum' in x: - x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) + x["lr"] = np.interp( + ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)] + ) + if "momentum" in x: + x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) # Forward with torch.cuda.amp.autocast(self.amp): @@ -348,8 +381,9 @@ class BaseTrainer: self.loss, self.loss_items = self.model(batch) if RANK != -1: self.loss *= world_size - self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ - else self.loss_items + self.tloss = ( + (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items + ) # Backward self.scaler.scale(self.loss).backward() @@ -359,115 +393,137 @@ 
class BaseTrainer: self.optimizer_step() last_opt_step = ni + # Timed stopping + if self.args.time: + self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600) + if RANK != -1: # if DDP training + broadcast_list = [self.stop if RANK == 0 else None] + dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks + self.stop = broadcast_list[0] + if self.stop: # training time exceeded + break + # Log - mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) - loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1 + mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB) + loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0) if RANK in (-1, 0): pbar.set_description( - ('%11s' * 2 + '%11.4g' * (2 + loss_len)) % - (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1])) - self.run_callbacks('on_batch_end') + ("%11s" * 2 + "%11.4g" * (2 + loss_len)) + % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]) + ) + self.run_callbacks("on_batch_end") if self.args.plots and ni in self.plot_idx: self.plot_training_samples(batch, ni) - self.run_callbacks('on_train_batch_end') - - self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers - - with warnings.catch_warnings(): - warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()' - self.scheduler.step() - self.run_callbacks('on_train_epoch_end') + self.run_callbacks("on_train_batch_end") + self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers + self.run_callbacks("on_train_epoch_end") if RANK in (-1, 0): + final_epoch = epoch + 1 == self.epochs + self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) # Validation - self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) - final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop - - if self.args.val or final_epoch: + if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \ + or final_epoch or self.stopper.possible_stop or self.stop: self.metrics, self.fitness = self.validate() self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr}) - self.stop = self.stopper(epoch + 1, self.fitness) + self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch + if self.args.time: + self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600) # Save model - if self.args.save or (epoch + 1 == self.epochs): + if self.args.save or final_epoch: self.save_model() - self.run_callbacks('on_model_save') + self.run_callbacks("on_model_save") - tnow = time.time() - self.epoch_time = tnow - self.epoch_time_start - self.epoch_time_start = tnow - self.run_callbacks('on_fit_epoch_end') - torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors + # Scheduler + t = time.time() + self.epoch_time = t - self.epoch_time_start + self.epoch_time_start = t + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()' + if self.args.time: + mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1) + self.epochs = 
self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time) + self._setup_scheduler() + self.scheduler.last_epoch = self.epoch # do not move + self.stop |= epoch >= self.epochs # stop if exceeded epochs + self.scheduler.step() + self.run_callbacks("on_fit_epoch_end") + torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors # Early Stopping if RANK != -1: # if DDP training broadcast_list = [self.stop if RANK == 0 else None] dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks - if RANK != 0: - self.stop = broadcast_list[0] + self.stop = broadcast_list[0] if self.stop: break # must break all DDP ranks + epoch += 1 if RANK in (-1, 0): # Do final val with best.pt - LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in ' - f'{(time.time() - self.train_time_start) / 3600:.3f} hours.') + LOGGER.info( + f"\n{epoch - self.start_epoch + 1} epochs completed in " + f"{(time.time() - self.train_time_start) / 3600:.3f} hours." + ) self.final_eval() if self.args.plots: self.plot_metrics() - self.run_callbacks('on_train_end') + self.run_callbacks("on_train_end") torch.cuda.empty_cache() - self.run_callbacks('teardown') + self.run_callbacks("teardown") def save_model(self): - """Save model checkpoints based on various conditions.""" + """Save model training checkpoints with additional metadata.""" + import pandas as pd # scope for faster startup + + metrics = {**self.metrics, **{"fitness": self.fitness}} + results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()} ckpt = { - 'epoch': self.epoch, - 'best_fitness': self.best_fitness, - 'model': deepcopy(de_parallel(self.model)).half(), - 'ema': deepcopy(self.ema.ema).half(), - 'updates': self.ema.updates, - 'optimizer': self.optimizer.state_dict(), - 'train_args': vars(self.args), # save as dict - 'date': datetime.now().isoformat(), - 'version': __version__} + "epoch": self.epoch, + "best_fitness": self.best_fitness, + "model": deepcopy(de_parallel(self.model)).half(), + "ema": deepcopy(self.ema.ema).half(), + "updates": self.ema.updates, + "optimizer": self.optimizer.state_dict(), + "train_args": vars(self.args), # save as dict + "train_metrics": metrics, + "train_results": results, + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + } - # Use dill (if exists) to serialize the lambda functions where pickle does not do this - try: - import dill as pickle - except ImportError: - import pickle - - # Save last, best and delete - torch.save(ckpt, self.last, pickle_module=pickle) + # Save last and best + torch.save(ckpt, self.last) if self.best_fitness == self.fitness: - torch.save(ckpt, self.best, pickle_module=pickle) - if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0): - torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle) - del ckpt + torch.save(ckpt, self.best) + if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0): + torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt") @staticmethod def get_dataset(data): """ - Get train, val path from data dict if it exists. Returns None if data format is not recognized. + Get train, val path from data dict if it exists. + + Returns None if data format is not recognized. 
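Aside on the trainer hunk above: when a wall-clock budget (`args.time`, in hours) is set, the total epoch count is recomputed from the mean duration of the epochs finished so far. A minimal standalone sketch of that arithmetic, with a hypothetical helper name outside the trainer:

```python
import math
import time

def epochs_for_time_budget(train_time_start: float, start_epoch: int, epoch: int, time_hours: float) -> int:
    """Estimate how many total epochs fit into a wall-clock budget.

    Mirrors the arithmetic above: the mean time of the epochs completed so far is
    extrapolated to convert `time_hours` into an epoch count.
    """
    elapsed = time.time() - train_time_start                # seconds spent training so far
    mean_epoch_time = elapsed / (epoch - start_epoch + 1)   # average seconds per completed epoch
    return math.ceil(time_hours * 3600 / mean_epoch_time)

# Hypothetical example: 3 epochs finished in ~30 minutes with a 4-hour budget -> 24 epochs total
print(epochs_for_time_budget(time.time() - 1800, start_epoch=0, epoch=2, time_hours=4.0))
```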
""" - return data['train'], data.get('val') or data.get('test') + return data["train"], data.get("val") or data.get("test") def setup_model(self): - """ - load/create/download model for any task. - """ + """Load/create/download model for any task.""" if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed return model, weights = self.model, None ckpt = None - if str(model).endswith('.pt'): + if str(model).endswith(".pt"): weights, ckpt = attempt_load_one_weight(model) - cfg = ckpt['model'].yaml + cfg = ckpt["model"].yaml else: cfg = model self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights) @@ -484,17 +540,17 @@ class BaseTrainer: self.ema.update(self.model) def preprocess_batch(self, batch): - """ - Allows custom preprocessing model inputs and ground truths depending on task type. - """ + """Allows custom preprocessing model inputs and ground truths depending on task type.""" return batch def validate(self): """ - Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key. + Runs validation on test set using self.validator. + + The returned dict is expected to contain "fitness" key. """ metrics = self.validator(self) - fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found + fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found if not self.best_fitness or self.best_fitness < fitness: self.best_fitness = fitness return metrics, fitness @@ -505,30 +561,28 @@ class BaseTrainer: def get_validator(self): """Returns a NotImplementedError when the get_validator function is called.""" - raise NotImplementedError('get_validator function not implemented in trainer') + raise NotImplementedError("get_validator function not implemented in trainer") - def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): - """ - Returns dataloader derived from torch.data.Dataloader. - """ - raise NotImplementedError('get_dataloader function not implemented in trainer') + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): + """Returns dataloader derived from torch.data.Dataloader.""" + raise NotImplementedError("get_dataloader function not implemented in trainer") - def build_dataset(self, img_path, mode='train', batch=None): - """Build dataset""" - raise NotImplementedError('build_dataset function not implemented in trainer') + def build_dataset(self, img_path, mode="train", batch=None): + """Build dataset.""" + raise NotImplementedError("build_dataset function not implemented in trainer") - def label_loss_items(self, loss_items=None, prefix='train'): + def label_loss_items(self, loss_items=None, prefix="train"): """ - Returns a loss dict with labelled training loss items tensor + Returns a loss dict with labelled training loss items tensor. + + Note: + This is not needed for classification but necessary for segmentation & detection """ - # Not needed for classification but necessary for segmentation & detection - return {'loss': loss_items} if loss_items is not None else ['loss'] + return {"loss": loss_items} if loss_items is not None else ["loss"] def set_model_attributes(self): - """ - To set or update model parameters before training. 
- """ - self.model.names = self.data['names'] + """To set or update model parameters before training.""" + self.model.names = self.data["names"] def build_targets(self, preds, targets): """Builds target tensors for training YOLO model.""" @@ -536,11 +590,11 @@ class BaseTrainer: def progress_string(self): """Returns a string describing training progress.""" - return '' + return "" # TODO: may need to put these following functions into callback def plot_training_samples(self, batch, ni): - """Plots training samples during YOLOv5 training.""" + """Plots training samples during YOLO training.""" pass def plot_training_labels(self): @@ -551,9 +605,9 @@ class BaseTrainer: """Saves training metrics to a CSV file.""" keys, vals = list(metrics.keys()), list(metrics.values()) n = len(metrics) + 1 # number of cols - s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header - with open(self.csv, 'a') as f: - f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n') + s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header + with open(self.csv, "a") as f: + f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n") def plot_metrics(self): """Plot and display metrics visually.""" @@ -562,7 +616,7 @@ class BaseTrainer: def on_plot(self, name, data=None): """Registers plots (e.g. to be consumed in callbacks)""" path = Path(name) - self.plots[path] = {'data': data, 'timestamp': time.time()} + self.plots[path] = {"data": data, "timestamp": time.time()} def final_eval(self): """Performs final evaluation and validation for object detection YOLO model.""" @@ -570,11 +624,11 @@ class BaseTrainer: if f.exists(): strip_optimizer(f) # strip optimizers if f is self.best: - LOGGER.info(f'\nValidating {f}...') + LOGGER.info(f"\nValidating {f}...") self.validator.args.plots = self.args.plots self.metrics = self.validator(model=f) - self.metrics.pop('fitness', None) - self.run_callbacks('on_fit_epoch_end') + self.metrics.pop("fitness", None) + self.run_callbacks("on_fit_epoch_end") def check_resume(self, overrides): """Check if resume checkpoint exists and update arguments accordingly.""" @@ -586,56 +640,62 @@ class BaseTrainer: # Check that resume data YAML exists, otherwise strip to force re-download of dataset ckpt_args = attempt_load_weights(last).args - if not Path(ckpt_args['data']).exists(): - ckpt_args['data'] = self.args.data + if not Path(ckpt_args["data"]).exists(): + ckpt_args["data"] = self.args.data resume = True self.args = get_cfg(ckpt_args) - self.args.model = str(last) # reinstate model - for k in 'imgsz', 'batch': # allow arg updates to reduce memory on resume if crashed due to CUDA OOM + self.args.model = self.args.resume = str(last) # reinstate model + for k in "imgsz", "batch", "device": # allow arg updates to reduce memory or update device on resume if k in overrides: setattr(self.args, k, overrides[k]) except Exception as e: - raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, ' - "i.e. 'yolo train resume model=path/to/last.pt'") from e + raise FileNotFoundError( + "Resume checkpoint not found. Please pass a valid checkpoint to resume from, " + "i.e. 
'yolo train resume model=path/to/last.pt'" + ) from e self.resume = resume def resume_training(self, ckpt): """Resume YOLO training from given epoch and best fitness.""" - if ckpt is None: + if ckpt is None or not self.resume: return best_fitness = 0.0 - start_epoch = ckpt['epoch'] + 1 - if ckpt['optimizer'] is not None: - self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer - best_fitness = ckpt['best_fitness'] - if self.ema and ckpt.get('ema'): - self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA - self.ema.updates = ckpt['updates'] - if self.resume: - assert start_epoch > 0, \ - f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ - f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" - LOGGER.info( - f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs') + start_epoch = ckpt["epoch"] + 1 + if ckpt["optimizer"] is not None: + self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer + best_fitness = ckpt["best_fitness"] + if self.ema and ckpt.get("ema"): + self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA + self.ema.updates = ckpt["updates"] + assert start_epoch > 0, ( + f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n" + f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" + ) + LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs") if self.epochs < start_epoch: LOGGER.info( - f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.") - self.epochs += ckpt['epoch'] # finetune additional epochs + f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." + ) + self.epochs += ckpt["epoch"] # finetune additional epochs self.best_fitness = best_fitness self.start_epoch = start_epoch if start_epoch > (self.epochs - self.args.close_mosaic): - LOGGER.info('Closing dataloader mosaic') - if hasattr(self.train_loader.dataset, 'mosaic'): - self.train_loader.dataset.mosaic = False - if hasattr(self.train_loader.dataset, 'close_mosaic'): - self.train_loader.dataset.close_mosaic(hyp=self.args) + self._close_dataloader_mosaic() - def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5): + def _close_dataloader_mosaic(self): + """Update dataloaders to stop using mosaic augmentation.""" + if hasattr(self.train_loader.dataset, "mosaic"): + self.train_loader.dataset.mosaic = False + if hasattr(self.train_loader.dataset, "close_mosaic"): + LOGGER.info("Closing dataloader mosaic") + self.train_loader.dataset.close_mosaic(hyp=self.args) + + def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5): """ - Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, - momentum, weight decay, and number of iterations. + Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum, + weight decay, and number of iterations. Args: model (torch.nn.Module): The model for which to build an optimizer. @@ -652,38 +712,45 @@ class BaseTrainer: """ g = [], [], [] # optimizer parameter groups - bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. 
BatchNorm2d() - if name == 'auto': - nc = getattr(model, 'nc', 10) # number of classes + bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d() + if name == "auto": + LOGGER.info( + f"{colorstr('optimizer:')} 'optimizer=auto' found, " + f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and " + f"determining best 'optimizer', 'lr0' and 'momentum' automatically... " + ) + nc = getattr(model, "nc", 10) # number of classes lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places - name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9) + name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9) self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam for module_name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): - fullname = f'{module_name}.{param_name}' if module_name else param_name - if 'bias' in fullname: # bias (no decay) + fullname = f"{module_name}.{param_name}" if module_name else param_name + if "bias" in fullname: # bias (no decay) g[2].append(param) elif isinstance(module, bn): # weight (no decay) g[1].append(param) else: # weight (with decay) g[0].append(param) - if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'): + if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"): optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) - elif name == 'RMSProp': + elif name == "RMSProp": optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) - elif name == 'SGD': + elif name == "SGD": optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) else: raise NotImplementedError( f"Optimizer '{name}' not found in list of available optimizers " - f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].' - 'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.') + f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]." + "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics." 
+ ) - optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay - optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights) + optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay + optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights) LOGGER.info( f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups " - f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)') + f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)' + ) return optimizer diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py index 0702690..f4fe57e 100644 --- a/ultralytics/engine/tuner.py +++ b/ultralytics/engine/tuner.py @@ -13,48 +13,59 @@ Example: from ultralytics import YOLO model = YOLO('yolov8n.pt') - model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10) + model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False) ``` """ + import random +import shutil +import subprocess import time -from copy import deepcopy import numpy as np +import torch -from ultralytics import YOLO from ultralytics.cfg import get_cfg, get_save_dir -from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save +from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save +from ultralytics.utils.plotting import plot_tune_results class Tuner: """ - Class responsible for hyperparameter tuning of YOLO models. + Class responsible for hyperparameter tuning of YOLO models. - The class evolves YOLO model hyperparameters over a given number of iterations - by mutating them according to the search space and retraining the model to evaluate their performance. + The class evolves YOLO model hyperparameters over a given number of iterations + by mutating them according to the search space and retraining the model to evaluate their performance. - Attributes: - space (dict): Hyperparameter search space containing bounds and scaling factors for mutation. - tune_dir (Path): Directory where evolution logs and results will be saved. - evolve_csv (Path): Path to the CSV file where evolution logs are saved. + Attributes: + space (dict): Hyperparameter search space containing bounds and scaling factors for mutation. + tune_dir (Path): Directory where evolution logs and results will be saved. + tune_csv (Path): Path to the CSV file where evolution logs are saved. - Methods: - _mutate(hyp: dict) -> dict: - Mutates the given hyperparameters within the bounds specified in `self.space`. + Methods: + _mutate(hyp: dict) -> dict: + Mutates the given hyperparameters within the bounds specified in `self.space`. - __call__(): - Executes the hyperparameter evolution across multiple iterations. + __call__(): + Executes the hyperparameter evolution across multiple iterations. - Example: - Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations. - ```python - from ultralytics import YOLO + Example: + Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations. 
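As a reading aid for the `build_optimizer` hunk above: with `optimizer='auto'`, the optimizer name, initial learning rate, and momentum are derived from the iteration count and the model's class count. A standalone sketch of that heuristic (hypothetical function name):

```python
def auto_optimizer_choice(nc: int = 10, iterations: float = 1e5):
    """Return (optimizer name, lr0, momentum) as chosen by the 'auto' branch above.

    Long schedules get SGD with lr0=0.01; short ones get AdamW with an lr0 fitted to the class count.
    """
    lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation, rounded to 6 decimal places
    return ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)

print(auto_optimizer_choice(nc=80, iterations=3000))    # ('AdamW', 0.000119, 0.9)
print(auto_optimizer_choice(nc=80, iterations=50000))   # ('SGD', 0.01, 0.9)
```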
+ ```python + from ultralytics import YOLO - model = YOLO('yolov8n.pt') - model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10, val=False, cache=True) - ``` - """ + model = YOLO('yolov8n.pt') + model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False) + ``` + + Tune with custom search space. + ```python + from ultralytics import YOLO + + model = YOLO('yolov8n.pt') + model.tune(space={key1: val1, key2: val2}) # custom search space dictionary + ``` + """ def __init__(self, args=DEFAULT_CFG, _callbacks=None): """ @@ -63,37 +74,44 @@ class Tuner: Args: args (dict, optional): Configuration for hyperparameter evolution. """ - self.args = get_cfg(overrides=args) - self.space = { # key: (min, max, gain(optionaL)) + self.space = args.pop("space", None) or { # key: (min, max, gain(optional)) # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), - 'lr0': (1e-5, 1e-1), - 'lrf': (0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) - 'momentum': (0.6, 0.98, 0.3), # SGD momentum/Adam beta1 - 'weight_decay': (0.0, 0.001), # optimizer weight decay 5e-4 - 'warmup_epochs': (0.0, 5.0), # warmup epochs (fractions ok) - 'warmup_momentum': (0.0, 0.95), # warmup initial momentum - 'box': (0.02, 0.2), # box loss gain - 'cls': (0.2, 4.0), # cls loss gain (scale with pixels) - 'hsv_h': (0.0, 0.1), # image HSV-Hue augmentation (fraction) - 'hsv_s': (0.0, 0.9), # image HSV-Saturation augmentation (fraction) - 'hsv_v': (0.0, 0.9), # image HSV-Value augmentation (fraction) - 'degrees': (0.0, 45.0), # image rotation (+/- deg) - 'translate': (0.0, 0.9), # image translation (+/- fraction) - 'scale': (0.0, 0.9), # image scale (+/- gain) - 'shear': (0.0, 10.0), # image shear (+/- deg) - 'perspective': (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 - 'flipud': (0.0, 1.0), # image flip up-down (probability) - 'fliplr': (0.0, 1.0), # image flip left-right (probability) - 'mosaic': (0.0, 1.0), # image mixup (probability) - 'mixup': (0.0, 1.0), # image mixup (probability) - 'copy_paste': (0.0, 1.0)} # segment copy-paste (probability) - self.tune_dir = get_save_dir(self.args, name='_tune') - self.evolve_csv = self.tune_dir / 'evolve.csv' + "lr0": (1e-5, 1e-1), # initial learning rate (i.e. 
SGD=1E-2, Adam=1E-3) + "lrf": (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf) + "momentum": (0.7, 0.98, 0.3), # SGD momentum/Adam beta1 + "weight_decay": (0.0, 0.001), # optimizer weight decay 5e-4 + "warmup_epochs": (0.0, 5.0), # warmup epochs (fractions ok) + "warmup_momentum": (0.0, 0.95), # warmup initial momentum + "box": (1.0, 20.0), # box loss gain + "cls": (0.2, 4.0), # cls loss gain (scale with pixels) + "dfl": (0.4, 6.0), # dfl loss gain + "hsv_h": (0.0, 0.1), # image HSV-Hue augmentation (fraction) + "hsv_s": (0.0, 0.9), # image HSV-Saturation augmentation (fraction) + "hsv_v": (0.0, 0.9), # image HSV-Value augmentation (fraction) + "degrees": (0.0, 45.0), # image rotation (+/- deg) + "translate": (0.0, 0.9), # image translation (+/- fraction) + "scale": (0.0, 0.95), # image scale (+/- gain) + "shear": (0.0, 10.0), # image shear (+/- deg) + "perspective": (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 + "flipud": (0.0, 1.0), # image flip up-down (probability) + "fliplr": (0.0, 1.0), # image flip left-right (probability) + "bgr": (0.0, 1.0), # image channel bgr (probability) + "mosaic": (0.0, 1.0), # image mixup (probability) + "mixup": (0.0, 1.0), # image mixup (probability) + "copy_paste": (0.0, 1.0), # segment copy-paste (probability) + } + self.args = get_cfg(overrides=args) + self.tune_dir = get_save_dir(self.args, name="tune") + self.tune_csv = self.tune_dir / "tune_results.csv" self.callbacks = _callbacks or callbacks.get_default_callbacks() + self.prefix = colorstr("Tuner: ") callbacks.add_integration_callbacks(self) - LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.") + LOGGER.info( + f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n" + f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning" + ) - def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2): + def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2): """ Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`. @@ -106,17 +124,17 @@ class Tuner: Returns: (dict): A dictionary containing mutated hyperparameters. """ - if self.evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate + if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate # Select parent(s) - x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1) + x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) fitness = x[:, 0] # first column n = min(n, len(x)) # number of previous results to consider x = x[np.argsort(-fitness)][:n] # top n mutations - w = x[:, 0] - x[:, 0].min() + 1E-6 # weights (sum > 0) - if parent == 'single' or len(x) == 1: + w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0) + if parent == "single" or len(x) == 1: # x = x[random.randint(0, n - 1)] # random selection x = x[random.choices(range(n), weights=w)[0]] # weighted selection - elif parent == 'weighted': + elif parent == "weighted": x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination # Mutate @@ -139,7 +157,7 @@ class Tuner: return hyp - def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')): + def __call__(self, model=None, iterations=10, cleanup=True): """ Executes the hyperparameter evolution process when the Tuner instance is called. @@ -152,54 +170,73 @@ class Tuner: Args: model (Model): A pre-initialized YOLO model to be used for training. 
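The Tuner search space above stores `(min, max)` or `(min, max, gain)` bounds per hyperparameter, and `_mutate` perturbs a selected parent inside those bounds. Below is a simplified, approximate sketch of that mutate-within-bounds idea using a hypothetical two-entry space; it is not the exact `_mutate` implementation:

```python
import random

# Hypothetical miniature search space; the full space is listed in the diff above.
SPACE = {"lr0": (1e-5, 1e-1), "momentum": (0.7, 0.98)}

def mutate(parent: dict, mutation_prob: float = 0.8, sigma: float = 0.2) -> dict:
    """Multiply each gene by small Gaussian noise with some probability, then clip to its bounds."""
    child = {}
    for key, (lo, hi, *_) in SPACE.items():
        value = parent[key]
        if random.random() < mutation_prob:
            value *= 1.0 + random.gauss(0.0, sigma)
        child[key] = min(max(value, lo), hi)   # keep the mutated value inside the search bounds
    return child

print(mutate({"lr0": 0.01, "momentum": 0.9}))
```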
iterations (int): The number of generations to run the evolution for. + cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning. Note: - The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores. + The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores. Ensure this path is set correctly in the Tuner instance. """ t0 = time.time() best_save_dir, best_metrics = None, None - self.tune_dir.mkdir(parents=True, exist_ok=True) + (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True) for i in range(iterations): # Mutate hyperparameters mutated_hyp = self._mutate() - LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}') + LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}") + metrics = {} + train_args = {**vars(self.args), **mutated_hyp} + save_dir = get_save_dir(get_cfg(train_args)) + weights_dir = save_dir / "weights" try: - # Train YOLO model with mutated hyperparameters - train_args = {**vars(self.args), **mutated_hyp} - results = (deepcopy(model) or YOLO(self.args.model)).train(**train_args) - fitness = results.fitness + # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang) + cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())] + return_code = subprocess.run(cmd, check=True).returncode + ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt") + metrics = torch.load(ckpt_file)["train_metrics"] + assert return_code == 0, "training failed" + except Exception as e: - LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i}\n{e}') - fitness = 0.0 + LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}") - # Save results and mutated_hyp to evolve_csv + # Save results and mutated_hyp to CSV + fitness = metrics.get("fitness", 0.0) log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()] - headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n') - with open(self.evolve_csv, 'a') as f: - f.write(headers + ','.join(map(str, log_row)) + '\n') + headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n") + with open(self.tune_csv, "a") as f: + f.write(headers + ",".join(map(str, log_row)) + "\n") - # Print tuning results - x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1) + # Get best results + x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) fitness = x[:, 0] # first column best_idx = fitness.argmax() best_is_current = best_idx == i if best_is_current: - best_save_dir = results.save_dir - best_metrics = {k: round(v, 5) for k, v in results.results_dict.items()} - header = (f'{prefix} {i + 1} iterations complete ✅ ({time.time() - t0:.2f}s)\n' - f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n' - f'{prefix} Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n' - f'{prefix} Best fitness metrics are {best_metrics}\n' - f'{prefix} Best fitness model is {best_save_dir}\n' - f'{prefix} Best fitness hyperparameters are printed below.\n') + best_save_dir = save_dir + best_metrics = {k: round(v, 5) for k, v in metrics.items()} + for ckpt in weights_dir.glob("*.pt"): + shutil.copy2(ckpt, self.tune_dir / "weights") + elif cleanup: + 
shutil.rmtree(ckpt_file.parent) # remove iteration weights/ dir to reduce storage space - LOGGER.info('\n' + header) + # Plot tune results + plot_tune_results(self.tune_csv) - # Save turning results - data = {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())} - header = header.replace(prefix, '#').replace('/', '').replace('', '') + '\n' - yaml_save(self.tune_dir / 'best.yaml', data=data, header=header) - yaml_print(self.tune_dir / 'best.yaml') + # Save and print tune results + header = ( + f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n' + f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n' + f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n' + f'{self.prefix}Best fitness metrics are {best_metrics}\n' + f'{self.prefix}Best fitness model is {best_save_dir}\n' + f'{self.prefix}Best fitness hyperparameters are printed below.\n' + ) + LOGGER.info("\n" + header) + data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())} + yaml_save( + self.tune_dir / "best_hyperparameters.yaml", + data=data, + header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n", + ) + yaml_print(self.tune_dir / "best_hyperparameters.yaml") diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index 9730c9b..e6d3f67 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -17,7 +17,9 @@ Usage - formats: yolov8n.tflite # TensorFlow Lite yolov8n_edgetpu.tflite # TensorFlow Edge TPU yolov8n_paddle_model # PaddlePaddle + yolov8n_ncnn_model # NCNN """ + import json import time from pathlib import Path @@ -36,7 +38,7 @@ from ultralytics.utils.torch_utils import de_parallel, select_device, smart_infe class BaseValidator: """ - BaseValidator + BaseValidator. A base class for creating validators. @@ -77,7 +79,7 @@ class BaseValidator: self.args = get_cfg(overrides=args) self.dataloader = dataloader self.pbar = pbar - self.model = None + self.stride = None self.data = None self.device = None self.batch_i = None @@ -89,20 +91,20 @@ class BaseValidator: self.nc = None self.iouv = None self.jdict = None - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} self.save_dir = save_dir or get_save_dir(self.args) - (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) if self.args.conf is None: self.args.conf = 0.001 # default conf=0.001 + self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1) self.plots = {} self.callbacks = _callbacks or callbacks.get_default_callbacks() @smart_inference_mode() def __call__(self, trainer=None, model=None): - """ - Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer + """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer gets priority). 
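Each tuning iteration above appends one row to `tune_results.csv` (fitness first, then the mutated hyperparameters in search-space key order), and the best iteration is recovered with an argmax over the first column. A compact sketch of that bookkeeping, with hypothetical helper names:

```python
from pathlib import Path
import numpy as np

def append_tune_row(tune_csv: Path, fitness: float, hyp: dict, space_keys: list) -> None:
    """Append one iteration; write the header only when the CSV does not exist yet."""
    headers = "" if tune_csv.exists() else (",".join(["fitness"] + space_keys) + "\n")
    row = [round(fitness, 5)] + [hyp[k] for k in space_keys]
    with open(tune_csv, "a") as f:
        f.write(headers + ",".join(map(str, row)) + "\n")

def best_iteration(tune_csv: Path) -> int:
    """Return the 0-based index of the row with the highest fitness."""
    x = np.loadtxt(tune_csv, ndmin=2, delimiter=",", skiprows=1)
    return int(x[:, 0].argmax())

csv = Path("tune_results_demo.csv")
append_tune_row(csv, 0.41, {"lr0": 0.010, "momentum": 0.90}, ["lr0", "momentum"])
append_tune_row(csv, 0.44, {"lr0": 0.008, "momentum": 0.93}, ["lr0", "momentum"])
print(best_iteration(csv))  # 1
```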
""" self.training = trainer is not None @@ -110,7 +112,7 @@ class BaseValidator: if self.training: self.device = trainer.device self.data = trainer.data - self.args.half = self.device.type != 'cpu' # force FP16 val during training + # self.args.half = self.device.type != "cpu" # force FP16 val during training model = trainer.ema.ema or trainer.model model = model.half() if self.args.half else model.float() # self.model = model @@ -119,12 +121,13 @@ class BaseValidator: model.eval() else: callbacks.add_integration_callbacks(self) - self.run_callbacks('on_val_start') - model = AutoBackend(model or self.args.model, - device=select_device(self.args.device, self.args.batch), - dnn=self.args.dnn, - data=self.args.data, - fp16=self.args.half) + model = AutoBackend( + weights=model or self.args.model, + device=select_device(self.args.device, self.args.batch), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + ) # self.model = model self.device = model.device # update device self.args.half = model.fp16 # update half @@ -134,30 +137,37 @@ class BaseValidator: self.args.batch = model.batch_size elif not pt and not jit: self.args.batch = 1 # export.py models default to batch-size 1 - LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') + LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models") - if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'): + if str(self.args.data).split(".")[-1] in ("yaml", "yml"): self.data = check_det_dataset(self.args.data) - elif self.args.task == 'classify': + elif self.args.task == "classify": self.data = check_cls_dataset(self.args.data, split=self.args.split) else: raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌")) - if self.device.type in ('cpu', 'mps'): + if self.device.type in ("cpu", "mps"): self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading if not pt: self.args.rect = False + self.stride = model.stride # used in get_dataloader() for padding self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch) model.eval() model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup - dt = Profile(), Profile(), Profile(), Profile() + self.run_callbacks("on_val_start") + dt = ( + Profile(device=self.device), + Profile(device=self.device), + Profile(device=self.device), + Profile(device=self.device), + ) bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader)) self.init_metrics(de_parallel(model)) self.jdict = [] # empty before each val for batch_i, batch in enumerate(bar): - self.run_callbacks('on_val_batch_start') + self.run_callbacks("on_val_batch_start") self.batch_i = batch_i # Preprocess with dt[0]: @@ -165,7 +175,7 @@ class BaseValidator: # Inference with dt[1]: - preds = model(batch['img'], augment=augment) + preds = model(batch["img"], augment=augment) # Loss with dt[2]: @@ -181,23 +191,32 @@ class BaseValidator: self.plot_val_samples(batch, batch_i) self.plot_predictions(batch, preds, batch_i) - self.run_callbacks('on_val_batch_end') + self.run_callbacks("on_val_batch_end") stats = self.get_stats() self.check_stats(stats) - self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt))) + self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt))) self.finalize_metrics() - self.print_results() - 
self.run_callbacks('on_val_end') + if not (self.args.save_json and self.is_coco and len(self.jdict)): + self.print_results() + self.run_callbacks("on_val_end") if self.training: model.float() - results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')} + if self.args.save_json and self.jdict: + with open(str(self.save_dir / "predictions.json"), "w") as f: + LOGGER.info(f"Saving {f.name}...") + json.dump(self.jdict, f) # flatten and save + stats = self.eval_json(stats) # update stats + stats['fitness'] = stats['metrics/mAP50-95(B)'] + results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats else: - LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' % - tuple(self.speed.values())) + LOGGER.info( + "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image" + % tuple(self.speed.values()) + ) if self.args.save_json and self.jdict: - with open(str(self.save_dir / 'predictions.json'), 'w') as f: - LOGGER.info(f'Saving {f.name}...') + with open(str(self.save_dir / "predictions.json"), "w") as f: + LOGGER.info(f"Saving {f.name}...") json.dump(self.jdict, f) # flatten and save stats = self.eval_json(stats) # update stats if self.args.plots or self.args.save_json: @@ -227,6 +246,7 @@ class BaseValidator: if use_scipy: # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708 import scipy # scope import to avoid importing for all commands + cost_matrix = iou * (iou >= threshold) if cost_matrix.any(): labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True) @@ -256,11 +276,11 @@ class BaseValidator: def get_dataloader(self, dataset_path, batch_size): """Get data loader from dataset path and batch size.""" - raise NotImplementedError('get_dataloader function not implemented for this validator') + raise NotImplementedError("get_dataloader function not implemented for this validator") def build_dataset(self, img_path): - """Build dataset""" - raise NotImplementedError('build_dataset function not implemented in validator') + """Build dataset.""" + raise NotImplementedError("build_dataset function not implemented in validator") def preprocess(self, batch): """Preprocesses an input batch.""" @@ -305,7 +325,7 @@ class BaseValidator: def on_plot(self, name, data=None): """Registers plots (e.g. to be consumed in callbacks)""" - self.plots[Path(name)] = {'data': data, 'timestamp': time.time()} + self.plots[Path(name)] = {"data": data, "timestamp": time.time()} # TODO: may need to put these following functions into callback def plot_val_samples(self, batch, ni): diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py index daed439..4ea2fff 100644 --- a/ultralytics/hub/__init__.py +++ b/ultralytics/hub/__init__.py @@ -5,24 +5,51 @@ import requests from ultralytics.data.utils import HUBDatasetStats from ultralytics.hub.auth import Auth from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX -from ultralytics.utils import LOGGER, SETTINGS +from ultralytics.utils import LOGGER, SETTINGS, checks -def login(api_key=''): +def login(api_key: str = None, save=True) -> bool: """ Log in to the Ultralytics HUB API using the provided API key. 
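The validator hunk above scopes the scipy import and matches labels to detections with a thresholded IoU cost matrix. A self-contained sketch of that matching step (hypothetical function name; the real method also iterates over multiple IoU thresholds):

```python
import numpy as np
import scipy.optimize

def match_by_iou(iou: np.ndarray, threshold: float = 0.5):
    """Hungarian-match labels (rows) to detections (columns) on pairs whose IoU exceeds a threshold."""
    cost_matrix = iou * (iou >= threshold)          # zero out pairs below the threshold
    if not cost_matrix.any():
        return np.empty(0, dtype=int), np.empty(0, dtype=int)
    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
    valid = cost_matrix[labels_idx, detections_idx] > 0   # discard assignments forced onto zeroed pairs
    return labels_idx[valid], detections_idx[valid]

# Hypothetical 2 labels x 3 detections IoU matrix
print(match_by_iou(np.array([[0.9, 0.1, 0.0], [0.2, 0.6, 0.3]])))  # (array([0, 1]), array([0, 1]))
```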
+ The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY + environment variable if successfully authenticated. + Args: - api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id + api_key (str, optional): API key to use for authentication. + If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable. + save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful. - Example: - ```python - from ultralytics import hub - - hub.login('API_KEY') - ``` + Returns: + (bool): True if authentication is successful, False otherwise. """ - Auth(api_key, verbose=True) + checks.check_requirements("hub-sdk>=0.0.6") + from hub_sdk import HUBClient + + api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL + saved_key = SETTINGS.get("api_key") + active_key = api_key or saved_key + credentials = {"api_key": active_key} if active_key and active_key != "" else None # set credentials + + client = HUBClient(credentials) # initialize HUBClient + + if client.authenticated: + # Successfully authenticated with HUB + + if save and client.api_key != saved_key: + SETTINGS.update({"api_key": client.api_key}) # update settings with valid API key + + # Set message based on whether key was provided or retrieved from settings + log_message = ( + "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅" + ) + LOGGER.info(f"{PREFIX}{log_message}") + + return True + else: + # Failed to authenticate with HUB + LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo hub login API_KEY'") + return False def logout(): @@ -36,52 +63,53 @@ def logout(): hub.logout() ``` """ - SETTINGS['api_key'] = '' + SETTINGS["api_key"] = "" SETTINGS.save() LOGGER.info(f"{PREFIX}logged out ✅. 
To log in again, use 'yolo hub login'.") -def reset_model(model_id=''): +def reset_model(model_id=""): """Reset a trained model to an untrained state.""" - r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id}) + r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key}) if r.status_code == 200: - LOGGER.info(f'{PREFIX}Model reset successfully') + LOGGER.info(f"{PREFIX}Model reset successfully") return - LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}') + LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}") def export_fmts_hub(): """Returns a list of HUB-supported export formats.""" from ultralytics.engine.exporter import export_formats - return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml'] + + return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"] -def export_model(model_id='', format='torchscript'): +def export_model(model_id="", format="torchscript"): """Export a model to all formats.""" assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" - r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export', - json={'format': format}, - headers={'x-api-key': Auth().api_key}) - assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}' - LOGGER.info(f'{PREFIX}{format} export started ✅') + r = requests.post( + f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key} + ) + assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}" + LOGGER.info(f"{PREFIX}{format} export started ✅") -def get_export(model_id='', format='torchscript'): +def get_export(model_id="", format="torchscript"): """Get an exported model dictionary with download URL.""" assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" - r = requests.post(f'{HUB_API_ROOT}/get-export', - json={ - 'apiKey': Auth().api_key, - 'modelId': model_id, - 'format': format}) - assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}' + r = requests.post( + f"{HUB_API_ROOT}/get-export", + json={"apiKey": Auth().api_key, "modelId": model_id, "format": format}, + headers={"x-api-key": Auth().api_key}, + ) + assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}" return r.json() -def check_dataset(path='', task='detect'): +def check_dataset(path="", task="detect"): """ - Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is - uploaded to the HUB. Usage examples are given below. + Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded + to the HUB. Usage examples are given below. Args: path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''. @@ -97,4 +125,4 @@ def check_dataset(path='', task='detect'): ``` """ HUBDatasetStats(path=path, task=task).get_json() - LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.') + LOGGER.info(f"Checks completed correctly ✅. 
Upload this dataset to {HUB_WEB_ROOT}/datasets/.") diff --git a/ultralytics/hub/__pycache__/__init__.cpython-312.pyc b/ultralytics/hub/__pycache__/__init__.cpython-312.pyc index 9984227..bef1603 100644 Binary files a/ultralytics/hub/__pycache__/__init__.cpython-312.pyc and b/ultralytics/hub/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/hub/__pycache__/__init__.cpython-39.pyc b/ultralytics/hub/__pycache__/__init__.cpython-39.pyc index 3fbf987..3bb4eff 100644 Binary files a/ultralytics/hub/__pycache__/__init__.cpython-39.pyc and b/ultralytics/hub/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/hub/__pycache__/auth.cpython-312.pyc b/ultralytics/hub/__pycache__/auth.cpython-312.pyc index c1304db..203c243 100644 Binary files a/ultralytics/hub/__pycache__/auth.cpython-312.pyc and b/ultralytics/hub/__pycache__/auth.cpython-312.pyc differ diff --git a/ultralytics/hub/__pycache__/auth.cpython-39.pyc b/ultralytics/hub/__pycache__/auth.cpython-39.pyc index 2a60626..1f909c7 100644 Binary files a/ultralytics/hub/__pycache__/auth.cpython-39.pyc and b/ultralytics/hub/__pycache__/auth.cpython-39.pyc differ diff --git a/ultralytics/hub/__pycache__/utils.cpython-312.pyc b/ultralytics/hub/__pycache__/utils.cpython-312.pyc index b082510..099777d 100644 Binary files a/ultralytics/hub/__pycache__/utils.cpython-312.pyc and b/ultralytics/hub/__pycache__/utils.cpython-312.pyc differ diff --git a/ultralytics/hub/__pycache__/utils.cpython-39.pyc b/ultralytics/hub/__pycache__/utils.cpython-39.pyc index 825ac3b..007109c 100644 Binary files a/ultralytics/hub/__pycache__/utils.cpython-39.pyc and b/ultralytics/hub/__pycache__/utils.cpython-39.pyc differ diff --git a/ultralytics/hub/auth.py b/ultralytics/hub/auth.py index 9963d79..6ede303 100644 --- a/ultralytics/hub/auth.py +++ b/ultralytics/hub/auth.py @@ -5,13 +5,27 @@ import requests from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab -API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys' +API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys" class Auth: + """ + Manages authentication processes including API key handling, cookie-based authentication, and header generation. + + The class supports different methods of authentication: + 1. Directly using an API key. + 2. Authenticating using browser cookies (specifically in Google Colab). + 3. Prompting the user to enter an API key. + + Attributes: + id_token (str or bool): Token used for identity verification, initialized as False. + api_key (str or bool): API key for authentication, initialized as False. + model_key (bool): Placeholder for model key, initialized as False. + """ + id_token = api_key = model_key = False - def __init__(self, api_key='', verbose=False): + def __init__(self, api_key="", verbose=False): """ Initialize the Auth class with an optional API key. @@ -19,18 +33,18 @@ class Auth: api_key (str, optional): May be an API key or a combination API key and model ID, i.e. 
key_id """ # Split the input API key in case it contains a combined key_model and keep only the API key part - api_key = api_key.split('_')[0] + api_key = api_key.split("_")[0] # Set API key attribute as value passed or SETTINGS API key if none passed - self.api_key = api_key or SETTINGS.get('api_key', '') + self.api_key = api_key or SETTINGS.get("api_key", "") # If an API key is provided if self.api_key: # If the provided API key matches the API key in the SETTINGS - if self.api_key == SETTINGS.get('api_key'): + if self.api_key == SETTINGS.get("api_key"): # Log that the user is already logged in if verbose: - LOGGER.info(f'{PREFIX}Authenticated ✅') + LOGGER.info(f"{PREFIX}Authenticated ✅") return else: # Attempt to authenticate with the provided API key @@ -45,62 +59,65 @@ class Auth: # Update SETTINGS with the new API key after successful authentication if success: - SETTINGS.update({'api_key': self.api_key}) + SETTINGS.update({"api_key": self.api_key}) # Log that the new login was successful if verbose: - LOGGER.info(f'{PREFIX}New authentication successful ✅') + LOGGER.info(f"{PREFIX}New authentication successful ✅") elif verbose: - LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}') + LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo hub login API_KEY'") def request_api_key(self, max_attempts=3): """ - Prompt the user to input their API key. Returns the model ID. + Prompt the user to input their API key. + + Returns the model ID. """ import getpass + for attempts in range(max_attempts): - LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}') - input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ') - self.api_key = input_key.split('_')[0] # remove model id if present + LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}") + input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ") + self.api_key = input_key.split("_")[0] # remove model id if present if self.authenticate(): return True - raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌')) + raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌")) def authenticate(self) -> bool: """ Attempt to authenticate with the server using either id_token or API key. Returns: - bool: True if authentication is successful, False otherwise. + (bool): True if authentication is successful, False otherwise. """ try: if header := self.get_auth_header(): - r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header) - if not r.json().get('success', False): - raise ConnectionError('Unable to authenticate.') + r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header) + if not r.json().get("success", False): + raise ConnectionError("Unable to authenticate.") return True - raise ConnectionError('User has not authenticated locally.') + raise ConnectionError("User has not authenticated locally.") except ConnectionError: self.id_token = self.api_key = False # reset invalid - LOGGER.warning(f'{PREFIX}Invalid API key ⚠️') + LOGGER.warning(f"{PREFIX}Invalid API key ⚠️") return False def auth_with_cookies(self) -> bool: """ - Attempt to fetch authentication via cookies and set id_token. - User must be logged in to HUB and running in a supported browser. + Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a + supported browser. Returns: - bool: True if authentication is successful, False otherwise. + (bool): True if authentication is successful, False otherwise. 
""" if not is_colab(): return False # Currently only works with Colab try: - authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto') - if authn.get('success', False): - self.id_token = authn.get('data', {}).get('idToken', None) + authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto") + if authn.get("success", False): + self.id_token = authn.get("data", {}).get("idToken", None) self.authenticate() return True - raise ConnectionError('Unable to fetch browser authentication details.') + raise ConnectionError("Unable to fetch browser authentication details.") except ConnectionError: self.id_token = False # reset invalid return False @@ -113,7 +130,7 @@ class Auth: (dict): The authentication header if id_token or API key is set, None otherwise. """ if self.id_token: - return {'authorization': f'Bearer {self.id_token}'} + return {"authorization": f"Bearer {self.id_token}"} elif self.api_key: - return {'x-api-key': self.api_key} + return {"x-api-key": self.api_key} # else returns None diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 595de29..ebde7aa 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -1,29 +1,26 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -import signal -import sys +import threading +import time +from http import HTTPStatus from pathlib import Path -from time import sleep import requests -from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request -from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded +from ultralytics.hub.utils import HUB_WEB_ROOT, HELP_MSG, PREFIX, TQDM +from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab from ultralytics.utils.errors import HUBModelError -AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local' +AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local" class HUBTrainingSession: """ HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing. - Args: - url (str): Model identifier used to initialize the HUB training session. - Attributes: agent_id (str): Identifier for the instance communicating with the server. - model_id (str): Identifier for the YOLOv5 model being trained. + model_id (str): Identifier for the YOLO model being trained. model_url (str): URL for the model in Ultralytics HUB. api_url (str): API URL for the model in Ultralytics HUB. auth_header (dict): Authentication header for the Ultralytics HUB API requests. @@ -34,110 +31,287 @@ class HUBTrainingSession: alive (bool): Indicates if the heartbeat loop is active. """ - def __init__(self, url): + def __init__(self, identifier): """ Initialize the HUBTrainingSession with the provided model identifier. Args: - url (str): Model identifier used to initialize the HUB training session. - It can be a URL string or a model key with specific format. + identifier (str): Model identifier used to initialize the HUB training session. + It can be a URL string or a model key with specific format. Raises: ValueError: If the provided model identifier is invalid. ConnectionError: If connecting with global API key is not supported. + ModuleNotFoundError: If hub-sdk package is not installed. 
""" + from hub_sdk import HUBClient - from ultralytics.hub.auth import Auth + self.rate_limits = { + "metrics": 3.0, + "ckpt": 900.0, + "heartbeat": 300.0, + } # rate limits (seconds) + self.metrics_queue = {} # holds metrics for each epoch until upload + self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed + self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py # Parse input - if url.startswith(f'{HUB_WEB_ROOT}/models/'): - url = url.split(f'{HUB_WEB_ROOT}/models/')[-1] - if [len(x) for x in url.split('_')] == [42, 20]: - key, model_id = url.split('_') - elif len(url) == 20: - key, model_id = '', url + api_key, model_id, self.filename = self._parse_identifier(identifier) + + # Get credentials + active_key = api_key or SETTINGS.get("api_key") + credentials = {"api_key": active_key} if active_key else None # set credentials + + # Initialize client + self.client = HUBClient(credentials) + + if model_id: + self.load_model(model_id) # load existing model else: - raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. " - f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.") + self.model = self.client.model() # load empty model - # Authorize - auth = Auth(key) - self.agent_id = None # identifies which instance is communicating with server - self.model_id = model_id - self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}' - self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}' - self.auth_header = auth.get_auth_header() - self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds) - self.timers = {} # rate limit timers (seconds) - self.metrics_queue = {} # metrics queue - self.model = self._get_model() - self.alive = True - self._start_heartbeat() # start heartbeats - self._register_signal_handlers() - LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀') + def load_model(self, model_id): + """Loads an existing model from Ultralytics HUB using the provided model identifier.""" + self.model = self.client.model(model_id) + if not self.model.data: # then model does not exist + raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling - def _register_signal_handlers(self): - """Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination.""" - signal.signal(signal.SIGTERM, self._handle_signal) - signal.signal(signal.SIGINT, self._handle_signal) + self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" - def _handle_signal(self, signum, frame): + self._set_train_args() + + # Start heartbeats for HUB to monitor agent + self.model.start_heartbeat(self.rate_limits["heartbeat"]) + LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") + + def create_model(self, model_args): + """Initializes a HUB training session with the specified model identifier.""" + payload = { + "config": { + "batchSize": model_args.get("batch", -1), + "epochs": model_args.get("epochs", 300), + "imageSize": model_args.get("imgsz", 640), + "patience": model_args.get("patience", 100), + "device": model_args.get("device", ""), + "cache": model_args.get("cache", "ram"), + }, + "dataset": {"name": model_args.get("data")}, + "lineage": { + "architecture": { + "name": self.filename.replace(".pt", "").replace(".yaml", ""), + }, + "parent": {}, + }, + "meta": {"name": self.filename}, + } + + if self.filename.endswith(".pt"): + payload["lineage"]["parent"]["name"] = self.filename + + self.model.create_model(payload) + + # Model could not be created + # TODO: 
improve error handling + if not self.model.id: + return + + self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" + + # Start heartbeats for HUB to monitor agent + self.model.start_heartbeat(self.rate_limits["heartbeat"]) + + LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") + + def _parse_identifier(self, identifier): """ - Handle kill signals and prevent heartbeats from being sent on Colab after termination. - This method does not use frame, it is included as it is passed by signal. - """ - if self.alive is True: - LOGGER.info(f'{PREFIX}Kill signal received! ❌') - self._stop_heartbeat() - sys.exit(signum) + Parses the given identifier to determine the type of identifier and extract relevant components. - def _stop_heartbeat(self): - """Terminate the heartbeat loop.""" - self.alive = False + The method supports different identifier formats: + - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/' + - An identifier containing an API key and a model ID separated by an underscore + - An identifier that is solely a model ID of a fixed length + - A local filename that ends with '.pt' or '.yaml' + + Args: + identifier (str): The identifier string to be parsed. + + Returns: + (tuple): A tuple containing the API key, model ID, and filename as applicable. + + Raises: + HUBModelError: If the identifier format is not recognized. + """ + + # Initialize variables + api_key, model_id, filename = None, None, None + + # Check if identifier is a HUB URL + if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): + # Extract the model_id after the HUB_WEB_ROOT URL + model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1] + else: + # Split the identifier based on underscores only if it's not a HUB URL + parts = identifier.split("_") + + # Check if identifier is in the format of API key and model ID + if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20: + api_key, model_id = parts + # Check if identifier is a single model ID + elif len(parts) == 1 and len(parts[0]) == 20: + model_id = parts[0] + # Check if identifier is a local filename + elif identifier.endswith(".pt") or identifier.endswith(".yaml"): + filename = identifier + else: + raise HUBModelError( + f"model='{identifier}' could not be parsed. Check format is correct. " + f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file." + ) + + return api_key, model_id, filename + + def _set_train_args(self): + """ + Initializes training arguments and creates a model entry on the Ultralytics HUB. + + This method sets up training arguments based on the model's state and updates them with any additional + arguments provided. It handles different states of the model, such as whether it's resumable, pretrained, + or requires specific file setup. + + Raises: + ValueError: If the model is already trained, if required dataset information is missing, or if there are + issues with the provided training arguments. 
+ """ + if self.model.is_trained(): + raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀")) + + if self.model.is_resumable(): + # Model has saved weights + self.train_args = {"data": self.model.get_dataset_url(), "resume": True} + self.model_file = self.model.get_weights_url("last") + else: + # Model has no saved weights + self.train_args = self.model.data.get("train_args") # new response + + # Set the model file as either a *.pt or *.yaml file + self.model_file = ( + self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture() + ) + + if "data" not in self.train_args: + # RF bug - datasets are sometimes not exported + raise ValueError("Dataset may still be processing. Please wait a minute and try again.") + + self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u + self.model_id = self.model.id + + def request_queue( + self, + request_func, + retry=3, + timeout=30, + thread=True, + verbose=True, + progress_total=None, + *args, + **kwargs, + ): + def retry_request(): + """Attempts to call `request_func` with retries, timeout, and optional threading.""" + t0 = time.time() # Record the start time for the timeout + for i in range(retry + 1): + if (time.time() - t0) > timeout: + LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}") + break # Timeout reached, exit loop + + response = request_func(*args, **kwargs) + if response is None: + LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}") + time.sleep(2**i) # Exponential backoff before retrying + continue # Skip further processing and retry + + if progress_total: + self._show_upload_progress(progress_total, response) + + if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES: + # if request related to metrics upload + if kwargs.get("metrics"): + self.metrics_upload_failed_queue = {} + return response # Success, no need to retry + + if i == 0: + # Initial attempt, check status code and provide messages + message = self._get_failure_message(response, retry, timeout) + + if verbose: + LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})") + + if not self._should_retry(response.status_code): + LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}") + break # Not an error that should be retried, exit loop + + time.sleep(2**i) # Exponential backoff for retries + + # if request related to metrics upload and exceed retries + if response is None and kwargs.get("metrics"): + self.metrics_upload_failed_queue.update(kwargs.get("metrics", None)) + + return response + + if thread: + # Start a new thread to run the retry_request function + threading.Thread(target=retry_request, daemon=True).start() + else: + # If running in the main thread, call retry_request directly + return retry_request() + + def _should_retry(self, status_code): + """Determines if a request should be retried based on the HTTP status code.""" + retry_codes = { + HTTPStatus.REQUEST_TIMEOUT, + HTTPStatus.BAD_GATEWAY, + HTTPStatus.GATEWAY_TIMEOUT, + } + return status_code in retry_codes + + def _get_failure_message(self, response: requests.Response, retry: int, timeout: int): + """ + Generate a retry message based on the response status code. + + Args: + response: The HTTP response object. + retry: The number of retry attempts allowed. + timeout: The maximum timeout duration. + + Returns: + (str): The retry message. 
+ """ + if self._should_retry(response.status_code): + return f"Retrying {retry}x for {timeout}s." if retry else "" + elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit + headers = response.headers + return ( + f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). " + f"Please retry after {headers['Retry-After']}s." + ) + else: + try: + return response.json().get("message", "No JSON message.") + except AttributeError: + return "Unable to read JSON." def upload_metrics(self): """Upload model metrics to Ultralytics HUB.""" - payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'} - smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2) + return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True) - def _get_model(self): - """Fetch and return model data from Ultralytics HUB.""" - api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}' - - try: - response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0) - data = response.json().get('data', None) - - if data.get('status', None) == 'trained': - raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀')) - - if not data.get('data', None): - raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix - self.model_id = data['id'] - - if data['status'] == 'new': # new model to start training - self.train_args = { - # TODO: deprecate 'batch_size' key for 'batch' in 3Q23 - 'batch': data['batch' if ('batch' in data) else 'batch_size'], - 'epochs': data['epochs'], - 'imgsz': data['imgsz'], - 'patience': data['patience'], - 'device': data['device'], - 'cache': data['cache'], - 'data': data['data']} - self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False - self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u - elif data['status'] == 'training': # existing model to resume training - self.train_args = {'data': data['data'], 'resume': True} - self.model_file = data['resume'] - - return data - except requests.exceptions.ConnectionError as e: - raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e - except Exception: - raise - - def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): + def upload_model( + self, + epoch: int, + weights: str, + is_best: bool = False, + map: float = 0.0, + final: bool = False, + ) -> None: """ Upload a model checkpoint to Ultralytics HUB. @@ -149,42 +323,33 @@ class HUBTrainingSession: final (bool): Indicates if the model is the final model after training. """ if Path(weights).is_file(): - with open(weights, 'rb') as f: - file = f.read() + progress_total = Path(weights).stat().st_size if final else None # Only show progress if final + self.request_queue( + self.model.upload_model, + epoch=epoch, + weights=weights, + is_best=is_best, + map=map, + final=final, + retry=10, + timeout=3600, + thread=not final, + progress_total=progress_total, + ) else: - LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. 
Missing model {weights}.') - file = None - url = f'{self.api_url}/upload' - # url = 'http://httpbin.org/post' # for debug - data = {'epoch': epoch} - if final: - data.update({'type': 'final', 'map': map}) - smart_request('post', - url, - data=data, - files={'best.pt': file}, - headers=self.auth_header, - retry=10, - timeout=3600, - thread=False, - progress=True, - code=4) - else: - data.update({'type': 'epoch', 'isBest': bool(is_best)}) - smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3) + LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.") - @threaded - def _start_heartbeat(self): - """Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB.""" - while self.alive: - r = smart_request('post', - f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', - json={ - 'agent': AGENT_NAME, - 'agentId': self.agent_id}, - headers=self.auth_header, - retry=0, - code=5, - thread=False) # already in a thread - self.agent_id = r.json().get('data', {}).get('agentId', None) - sleep(self.rate_limits['heartbeat']) + def _show_upload_progress(self, content_length: int, response: requests.Response) -> None: + """ + Display a progress bar to track the upload progress of a file. + + Args: + content_length (int): The total size of the content to be uploaded in bytes. + response (requests.Response): The response object from the file upload request. + + Returns: + None + """ + with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar: + for data in response.iter_content(chunk_size=1024): + pbar.update(len(data)) diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py index 07da970..5c00076 100644 --- a/ultralytics/hub/utils.py +++ b/ultralytics/hub/utils.py @@ -10,14 +10,29 @@ from pathlib import Path import requests -from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__, - colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package) +from ultralytics.utils import ( + ENVIRONMENT, + LOGGER, + ONLINE, + RANK, + SETTINGS, + TESTS_RUNNING, + TQDM, + TryExcept, + __version__, + colorstr, + get_git_origin_url, + is_colab, + is_git_dir, + is_pip_package, +) from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES -PREFIX = colorstr('Ultralytics HUB: ') -HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.' -HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com') -HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.com') +HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com") +HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com") + +PREFIX = colorstr("Ultralytics HUB: ") +HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance." def request_with_credentials(url: str) -> any: @@ -34,11 +49,13 @@ def request_with_credentials(url: str) -> any: OSError: If the function is not run in a Google Colab environment.
""" if not is_colab(): - raise OSError('request_with_credentials() must run in a Colab environment') + raise OSError("request_with_credentials() must run in a Colab environment") from google.colab import output # noqa from IPython import display # noqa + display.display( - display.Javascript(""" + display.Javascript( + """ window._hub_tmp = new Promise((resolve, reject) => { const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000) fetch("%s", { @@ -53,8 +70,11 @@ def request_with_credentials(url: str) -> any: reject(err); }); }); - """ % url)) - return output.eval_js('_hub_tmp') + """ + % url + ) + ) + return output.eval_js("_hub_tmp") def requests_with_progress(method, url, **kwargs): @@ -64,22 +84,23 @@ def requests_with_progress(method, url, **kwargs): Args: method (str): The HTTP method to use (e.g. 'GET', 'POST'). url (str): The URL to send the request to. - **kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function. + **kwargs (any): Additional keyword arguments to pass to the underlying `requests.request` function. Returns: (requests.Response): The response object from the HTTP request. Note: - If 'progress' is set to True, the progress bar will display the download progress - for responses with a known content length. + - If 'progress' is set to True, the progress bar will display the download progress for responses with a known + content length. + - If 'progress' is a number then progress bar will display assuming content length = progress. """ - progress = kwargs.pop('progress', False) + progress = kwargs.pop("progress", False) if not progress: return requests.request(method, url, **kwargs) response = requests.request(method, url, stream=True, **kwargs) - total = int(response.headers.get('content-length', 0)) # total size + total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress) # total size try: - pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024) + pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024) for data in response.iter_content(chunk_size=1024): pbar.update(len(data)) pbar.close() @@ -101,7 +122,7 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos code (int, optional): An identifier for the request, used for logging purposes. Default is -1. verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True. progress (bool, optional): Whether to show a progress bar during the request. Default is False. - **kwargs (dict): Keyword arguments to be passed to the requests function specified in method. + **kwargs (any): Keyword arguments to be passed to the requests function specified in method. Returns: (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None. @@ -120,25 +141,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful" break try: - m = r.json().get('message', 'No JSON message.') + m = r.json().get("message", "No JSON message.") except AttributeError: - m = 'Unable to read JSON.' + m = "Unable to read JSON." if i == 0: if r.status_code in retry_codes: - m += f' Retrying {retry}x for {timeout}s.' if retry else '' + m += f" Retrying {retry}x for {timeout}s." 
if retry else "" elif r.status_code == 429: # rate limit h = r.headers # response headers - m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \ + m = ( + f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " f"Please retry after {h['Retry-After']}s." + ) if verbose: - LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})') + LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})") if r.status_code not in retry_codes: return r - time.sleep(2 ** i) # exponential standoff + time.sleep(2**i) # exponential standoff return r args = method, url - kwargs['progress'] = progress + kwargs["progress"] = progress if thread: threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start() else: @@ -157,29 +180,29 @@ class Events: enabled (bool): A flag to enable or disable Events based on certain conditions. """ - url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw' + url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw" def __init__(self): - """ - Initializes the Events object with default values for events, rate_limit, and metadata. - """ + """Initializes the Events object with default values for events, rate_limit, and metadata.""" self.events = [] # events list self.rate_limit = 60.0 # rate limit (seconds) self.t = 0.0 # rate limit timer (seconds) self.metadata = { - 'cli': Path(sys.argv[0]).name == 'yolo', - 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other', - 'python': '.'.join(platform.python_version_tuple()[:2]), # i.e. 3.10 - 'version': __version__, - 'env': ENVIRONMENT, - 'session_id': round(random.random() * 1E15), - 'engagement_time_msec': 1000} - self.enabled = \ - SETTINGS['sync'] and \ - RANK in (-1, 0) and \ - not TESTS_RUNNING and \ - ONLINE and \ - (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git') + "cli": Path(sys.argv[0]).name == "yolo", + "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other", + "python": ".".join(platform.python_version_tuple()[:2]), # i.e. 
3.10 + "version": __version__, + "env": ENVIRONMENT, + "session_id": round(random.random() * 1e15), + "engagement_time_msec": 1000, + } + self.enabled = ( + SETTINGS["sync"] + and RANK in (-1, 0) + and not TESTS_RUNNING + and ONLINE + and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git") + ) def __call__(self, cfg): """ @@ -195,11 +218,13 @@ class Events: # Attempt to add to events if len(self.events) < 25: # Events list limited to 25 events (drop any events past this) params = { - **self.metadata, 'task': cfg.task, - 'model': cfg.model if cfg.model in GITHUB_ASSETS_NAMES else 'custom'} - if cfg.mode == 'export': - params['format'] = cfg.format - self.events.append({'name': cfg.mode, 'params': params}) + **self.metadata, + "task": cfg.task, + "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom", + } + if cfg.mode == "export": + params["format"] = cfg.format + self.events.append({"name": cfg.mode, "params": params}) # Check rate limit t = time.time() @@ -208,10 +233,10 @@ class Events: return # Time is over rate limiter, send now - data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list + data = {"client_id": SETTINGS["uuid"], "events": self.events} # SHA-256 anonymized UUID hash and events list # POST equivalent to requests.post(self.url, json=data) - smart_request('post', self.url, json=data, retry=0, verbose=False) + smart_request("post", self.url, json=data, retry=0, verbose=False) # Reset events and rate limit timer self.events = [] diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py index e96f893..42de3fb 100644 --- a/ultralytics/models/__init__.py +++ b/ultralytics/models/__init__.py @@ -2,6 +2,7 @@ from .rtdetr import RTDETR from .sam import SAM -from .yolo import YOLO +from .yolo import YOLO, YOLOWorld +from .yolov10 import YOLOv10 -__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import +__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10" # allow simpler import diff --git a/ultralytics/models/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/__pycache__/__init__.cpython-312.pyc index 0024398..d0c12fb 100644 Binary files a/ultralytics/models/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/__pycache__/__init__.cpython-39.pyc index 213a9b8..4b37be5 100644 Binary files a/ultralytics/models/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/fastsam/__init__.py b/ultralytics/models/fastsam/__init__.py index 8f47772..eabf5b9 100644 --- a/ultralytics/models/fastsam/__init__.py +++ b/ultralytics/models/fastsam/__init__.py @@ -5,4 +5,4 @@ from .predict import FastSAMPredictor from .prompt import FastSAMPrompt from .val import FastSAMValidator -__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator' +__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator" diff --git a/ultralytics/models/fastsam/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/fastsam/__pycache__/__init__.cpython-312.pyc index 12c40d4..2851e25 100644 Binary files a/ultralytics/models/fastsam/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/fastsam/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/__init__.cpython-39.pyc 
b/ultralytics/models/fastsam/__pycache__/__init__.cpython-39.pyc index fa55f52..d2dff01 100644 Binary files a/ultralytics/models/fastsam/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/fastsam/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/model.cpython-312.pyc b/ultralytics/models/fastsam/__pycache__/model.cpython-312.pyc index f95fff0..05a4e0d 100644 Binary files a/ultralytics/models/fastsam/__pycache__/model.cpython-312.pyc and b/ultralytics/models/fastsam/__pycache__/model.cpython-312.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/model.cpython-39.pyc b/ultralytics/models/fastsam/__pycache__/model.cpython-39.pyc index 79fbabd..05ea098 100644 Binary files a/ultralytics/models/fastsam/__pycache__/model.cpython-39.pyc and b/ultralytics/models/fastsam/__pycache__/model.cpython-39.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/predict.cpython-312.pyc b/ultralytics/models/fastsam/__pycache__/predict.cpython-312.pyc index 7c7958f..005e2cf 100644 Binary files a/ultralytics/models/fastsam/__pycache__/predict.cpython-312.pyc and b/ultralytics/models/fastsam/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/predict.cpython-39.pyc b/ultralytics/models/fastsam/__pycache__/predict.cpython-39.pyc index dab61b7..d526f37 100644 Binary files a/ultralytics/models/fastsam/__pycache__/predict.cpython-39.pyc and b/ultralytics/models/fastsam/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/prompt.cpython-312.pyc b/ultralytics/models/fastsam/__pycache__/prompt.cpython-312.pyc index 4706636..944eedb 100644 Binary files a/ultralytics/models/fastsam/__pycache__/prompt.cpython-312.pyc and b/ultralytics/models/fastsam/__pycache__/prompt.cpython-312.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/prompt.cpython-39.pyc b/ultralytics/models/fastsam/__pycache__/prompt.cpython-39.pyc index 59f365a..c4fd184 100644 Binary files a/ultralytics/models/fastsam/__pycache__/prompt.cpython-39.pyc and b/ultralytics/models/fastsam/__pycache__/prompt.cpython-39.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/utils.cpython-312.pyc b/ultralytics/models/fastsam/__pycache__/utils.cpython-312.pyc index 602a6d6..bf0b4e7 100644 Binary files a/ultralytics/models/fastsam/__pycache__/utils.cpython-312.pyc and b/ultralytics/models/fastsam/__pycache__/utils.cpython-312.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/utils.cpython-39.pyc b/ultralytics/models/fastsam/__pycache__/utils.cpython-39.pyc index 71db99a..2760a0e 100644 Binary files a/ultralytics/models/fastsam/__pycache__/utils.cpython-39.pyc and b/ultralytics/models/fastsam/__pycache__/utils.cpython-39.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/val.cpython-312.pyc b/ultralytics/models/fastsam/__pycache__/val.cpython-312.pyc index 7d57672..ea2a82c 100644 Binary files a/ultralytics/models/fastsam/__pycache__/val.cpython-312.pyc and b/ultralytics/models/fastsam/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/fastsam/__pycache__/val.cpython-39.pyc b/ultralytics/models/fastsam/__pycache__/val.cpython-39.pyc index e69b299..5ca0872 100644 Binary files a/ultralytics/models/fastsam/__pycache__/val.cpython-39.pyc and b/ultralytics/models/fastsam/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/fastsam/model.py b/ultralytics/models/fastsam/model.py index c1895fc..c01e66b 100644 --- 
a/ultralytics/models/fastsam/model.py +++ b/ultralytics/models/fastsam/model.py @@ -3,7 +3,6 @@ from pathlib import Path from ultralytics.engine.model import Model - from .predict import FastSAMPredictor from .val import FastSAMValidator @@ -21,13 +20,14 @@ class FastSAM(Model): ``` """ - def __init__(self, model='FastSAM-x.pt'): - """Call the __init__ method of the parent class (YOLO) with the updated default model""" - if str(model) == 'FastSAM.pt': - model = 'FastSAM-x.pt' - assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.' - super().__init__(model=model, task='segment') + def __init__(self, model="FastSAM-x.pt"): + """Call the __init__ method of the parent class (YOLO) with the updated default model.""" + if str(model) == "FastSAM.pt": + model = "FastSAM-x.pt" + assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models." + super().__init__(model=model, task="segment") @property def task_map(self): - return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}} + """Returns a dictionary mapping segment task to corresponding predictor and validator classes.""" + return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}} diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py index f94a173..0ef1803 100644 --- a/ultralytics/models/fastsam/predict.py +++ b/ultralytics/models/fastsam/predict.py @@ -9,19 +9,54 @@ from ultralytics.utils import DEFAULT_CFG, ops class FastSAMPredictor(DetectionPredictor): + """ + FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics + YOLO framework. + + This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM. + It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing + for single-class segmentation. + + Attributes: + cfg (dict): Configuration parameters for prediction. + overrides (dict, optional): Optional parameter overrides for custom behavior. + _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction. + """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'. + + Args: + cfg (dict): Configuration parameters for prediction. + overrides (dict, optional): Optional parameter overrides for custom behavior. + _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction. + """ super().__init__(cfg, overrides, _callbacks) - self.args.task = 'segment' + self.args.task = "segment" def postprocess(self, preds, img, orig_imgs): - p = ops.non_max_suppression(preds[0], - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - nc=len(self.model.names), - classes=self.args.classes) + """ + Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image + size, and returns the final results. + + Args: + preds (list): The raw output predictions from the model. + img (torch.Tensor): The processed image tensor. + orig_imgs (list | torch.Tensor): The original image or list of images. + + Returns: + (list): A list of Results objects, each containing processed boxes, masks, and other metadata. 
+ """ + p = ops.non_max_suppression( + preds[0], + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=1, # set to 1 class since SAM has no class predictions + classes=self.args.classes, + ) full_box = torch.zeros(p[0].shape[1], device=p[0].device) full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 full_box = full_box.view(1, -1) diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py index 9d5ae25..f7bf5ad 100644 --- a/ultralytics/models/fastsam/prompt.py +++ b/ultralytics/models/fastsam/prompt.py @@ -13,54 +13,73 @@ from ultralytics.utils import TQDM class FastSAMPrompt: + """ + Fast Segment Anything Model class for image annotation and visualization. - def __init__(self, source, results, device='cuda') -> None: + Attributes: + device (str): Computing device ('cuda' or 'cpu'). + results: Object detection or segmentation results. + source: Source image or image path. + clip: CLIP model for linear assignment. + """ + + def __init__(self, source, results, device="cuda") -> None: + """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment.""" self.device = device self.results = results self.source = source # Import and assign clip try: - import clip # for linear_assignment + import clip except ImportError: from ultralytics.utils.checks import check_requirements - check_requirements('git+https://github.com/openai/CLIP.git') + + check_requirements("git+https://github.com/openai/CLIP.git") import clip self.clip = clip @staticmethod def _segment_image(image, bbox): + """Segments the given image according to the provided bounding box coordinates.""" image_array = np.array(image) segmented_image_array = np.zeros_like(image_array) x1, y1, x2, y2 = bbox segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] segmented_image = Image.fromarray(segmented_image_array) - black_image = Image.new('RGB', image.size, (255, 255, 255)) + black_image = Image.new("RGB", image.size, (255, 255, 255)) # transparency_mask = np.zeros_like((), dtype=np.uint8) transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8) transparency_mask[y1:y2, x1:x2] = 255 - transparency_mask_image = Image.fromarray(transparency_mask, mode='L') + transparency_mask_image = Image.fromarray(transparency_mask, mode="L") black_image.paste(segmented_image, mask=transparency_mask_image) return black_image @staticmethod def _format_results(result, filter=0): + """Formats detection results into list of annotations each containing ID, segmentation, bounding box, score and + area. + """ annotations = [] n = len(result.masks.data) if result.masks is not None else 0 for i in range(n): mask = result.masks.data[i] == 1.0 if torch.sum(mask) >= filter: annotation = { - 'id': i, - 'segmentation': mask.cpu().numpy(), - 'bbox': result.boxes.data[i], - 'score': result.boxes.conf[i]} - annotation['area'] = annotation['segmentation'].sum() + "id": i, + "segmentation": mask.cpu().numpy(), + "bbox": result.boxes.data[i], + "score": result.boxes.conf[i], + } + annotation["area"] = annotation["segmentation"].sum() annotations.append(annotation) return annotations @staticmethod def _get_bbox_from_mask(mask): + """Applies morphological transformations to the mask, displays it, and if with_contours is True, draws + contours. 
+ """ mask = mask.astype(np.uint8) contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) x1, y1, w, h = cv2.boundingRect(contours[0]) @@ -74,22 +93,38 @@ class FastSAMPrompt: y2 = max(y2, y_t + h_t) return [x1, y1, x2, y2] - def plot(self, - annotations, - output, - bbox=None, - points=None, - point_label=None, - mask_random_color=True, - better_quality=True, - retina=False, - with_contours=True): + def plot( + self, + annotations, + output, + bbox=None, + points=None, + point_label=None, + mask_random_color=True, + better_quality=True, + retina=False, + with_contours=True, + ): + """ + Plots annotations, bounding boxes, and points on images and saves the output. + + Args: + annotations (list): Annotations to be plotted. + output (str or Path): Output directory for saving the plots. + bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None. + points (list, optional): Points to be plotted. Defaults to None. + point_label (list, optional): Labels for the points. Defaults to None. + mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True. + better_quality (bool, optional): Whether to apply morphological transformations for better mask quality. Defaults to True. + retina (bool, optional): Whether to use retina mask. Defaults to False. + with_contours (bool, optional): Whether to plot contours. Defaults to True. + """ pbar = TQDM(annotations, total=len(annotations)) for ann in pbar: result_name = os.path.basename(ann.path) - image = ann.orig_img + image = ann.orig_img[..., ::-1] # BGR to RGB original_h, original_w = ann.orig_shape - # for macOS only + # For macOS only # plt.switch_backend('TkAgg') plt.figure(figsize=(original_w / 100, original_h / 100)) # Add subplot with no margin. @@ -134,19 +169,13 @@ class FastSAMPrompt: contour_mask = temp / 255 * color.reshape(1, 1, -1) plt.imshow(contour_mask) - plt.axis('off') - fig = plt.gcf() - - # Check if the canvas has been drawn - if fig.canvas.get_renderer() is None: # macOS requires this or tests fail - fig.canvas.draw() - + # Save the figure save_path = Path(output) / result_name save_path.parent.mkdir(exist_ok=True, parents=True) - image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) - image.save(save_path) + plt.axis("off") + plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True) plt.close() - pbar.set_description(f'Saving {result_name} to {save_path}') + pbar.set_description(f"Saving {result_name} to {save_path}") @staticmethod def fast_show_mask( @@ -160,6 +189,20 @@ class FastSAMPrompt: target_height=960, target_width=960, ): + """ + Quickly shows the mask annotations on the given matplotlib axis. + + Args: + annotation (array-like): Mask annotation. + ax (matplotlib.axes.Axes): Matplotlib axis. + random_color (bool, optional): Whether to use random color for masks. Defaults to False. + bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None. + points (list, optional): Points to be plotted. Defaults to None. + pointlabel (list, optional): Labels for the points. Defaults to None. + retinamask (bool, optional): Whether to use retina mask. Defaults to True. + target_height (int, optional): Target height for resizing. Defaults to 960. + target_width (int, optional): Target width for resizing. Defaults to 960. 
+ """ n, h, w = annotation.shape # batch, height, width areas = np.sum(annotation, axis=(1, 2)) @@ -175,26 +218,26 @@ class FastSAMPrompt: mask_image = np.expand_dims(annotation, -1) * visual show = np.zeros((h, w, 4)) - h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij') + h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij") indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) show[h_indices, w_indices, :] = mask_image[indices] if bbox is not None: x1, y1, x2, y2 = bbox - ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) + ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1)) # Draw point if points is not None: plt.scatter( [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], s=20, - c='y', + c="y", ) plt.scatter( [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], s=20, - c='m', + c="m", ) if not retinamask: @@ -203,6 +246,7 @@ class FastSAMPrompt: @torch.no_grad() def retrieve(self, model, preprocess, elements, search_text: str, device) -> int: + """Processes images and text with a model, calculates similarity, and returns softmax score.""" preprocessed_images = [preprocess(image).to(device) for image in elements] tokenized_text = self.clip.tokenize([search_text]).to(device) stacked_images = torch.stack(preprocessed_images) @@ -214,12 +258,13 @@ class FastSAMPrompt: return probs[:, 0].softmax(dim=0) def _crop_image(self, format_results): + """Crops an image based on provided annotation format and returns cropped images and related data.""" if os.path.isdir(self.source): raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.") image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB)) ori_w, ori_h = image.size annotations = format_results - mask_h, mask_w = annotations[0]['segmentation'].shape + mask_h, mask_w = annotations[0]["segmentation"].shape if ori_w != mask_w or ori_h != mask_h: image = image.resize((mask_w, mask_h)) cropped_boxes = [] @@ -227,18 +272,19 @@ class FastSAMPrompt: not_crop = [] filter_id = [] for _, mask in enumerate(annotations): - if np.sum(mask['segmentation']) <= 100: + if np.sum(mask["segmentation"]) <= 100: filter_id.append(_) continue - bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox - cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片 - cropped_images.append(bbox) # 保存裁剪的图片的bbox + bbox = self._get_bbox_from_mask(mask["segmentation"]) # bbox from mask + cropped_boxes.append(self._segment_image(image, bbox)) # save cropped image + cropped_images.append(bbox) # save cropped image bbox return cropped_boxes, cropped_images, not_crop, filter_id, annotations def box_prompt(self, bbox): + """Modifies the bounding box properties and calculates IoU between masks and bounding box.""" if self.results[0].masks is not None: - assert (bbox[2] != 0 and bbox[3] != 0) + assert bbox[2] != 0 and bbox[3] != 0 if os.path.isdir(self.source): raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.") masks = self.results[0].masks.data @@ -250,7 +296,8 @@ class FastSAMPrompt: int(bbox[0] * w / target_width), int(bbox[1] * h / target_height), int(bbox[2] * w / target_width), - int(bbox[3] * h / target_height), ] + int(bbox[3] * h / 
target_height), + ] bbox[0] = max(round(bbox[0]), 0) bbox[1] = max(round(bbox[1]), 0) bbox[2] = min(round(bbox[2]), w) @@ -259,29 +306,30 @@ class FastSAMPrompt: # IoUs = torch.zeros(len(masks), dtype=torch.float32) bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) - masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2)) + masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2)) orig_masks_area = torch.sum(masks, dim=(1, 2)) union = bbox_area + orig_masks_area - masks_area - IoUs = masks_area / union - max_iou_index = torch.argmax(IoUs) + iou = masks_area / union + max_iou_index = torch.argmax(iou) self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()])) return self.results - def point_prompt(self, points, pointlabel): # numpy 处理 + def point_prompt(self, points, pointlabel): # numpy + """Adjusts points on detected masks based on user input and returns the modified results.""" if self.results[0].masks is not None: if os.path.isdir(self.source): raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.") masks = self._format_results(self.results[0], 0) target_height, target_width = self.results[0].orig_shape - h = masks[0]['segmentation'].shape[0] - w = masks[0]['segmentation'].shape[1] + h = masks[0]["segmentation"].shape[0] + w = masks[0]["segmentation"].shape[1] if h != target_height or w != target_width: points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] onemask = np.zeros((h, w)) for annotation in masks: - mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation + mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation for i, point in enumerate(points): if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: onemask += mask @@ -292,16 +340,18 @@ class FastSAMPrompt: return self.results def text_prompt(self, text): + """Processes a text prompt, applies it to existing results and returns the updated results.""" if self.results[0].masks is not None: format_results = self._format_results(self.results[0], 0) cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results) - clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device) + clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device) scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device) max_idx = scores.argsort() max_idx = max_idx[-1] max_idx += sum(np.array(filter_id) <= int(max_idx)) - self.results[0].masks.data = torch.tensor(np.array([ann['segmentation'] for ann in annotations])) + self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]])) return self.results def everything_prompt(self): + """Returns the processed results from the previous methods in the class.""" return self.results diff --git a/ultralytics/models/fastsam/utils.py b/ultralytics/models/fastsam/utils.py index e99fd62..480e903 100644 --- a/ultralytics/models/fastsam/utils.py +++ b/ultralytics/models/fastsam/utils.py @@ -42,23 +42,23 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres """ boxes = adjust_bboxes_to_image_border(boxes, image_shape) - # obtain coordinates for intersections + # Obtain coordinates for intersections x1 = torch.max(box1[0], boxes[:, 0]) y1 = torch.max(box1[1], boxes[:, 1]) x2 = torch.min(box1[2], 
boxes[:, 2]) y2 = torch.min(box1[3], boxes[:, 3]) - # compute the area of intersection + # Compute the area of intersection intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) - # compute the area of both individual boxes + # Compute the area of both individual boxes box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) - # compute the area of union + # Compute the area of union union = box1_area + box2_area - intersection - # compute the IoU + # Compute the IoU iou = intersection / union # Should be shape (n, ) if raw_output: return 0 if iou.numel() == 0 else iou diff --git a/ultralytics/models/fastsam/val.py b/ultralytics/models/fastsam/val.py index fa25e49..9014b27 100644 --- a/ultralytics/models/fastsam/val.py +++ b/ultralytics/models/fastsam/val.py @@ -5,10 +5,36 @@ from ultralytics.utils.metrics import SegmentMetrics class FastSAMValidator(SegmentationValidator): + """ + Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework. + + Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class + sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled + to avoid errors during validation. + + Attributes: + dataloader: The data loader object used for validation. + save_dir (str): The directory where validation results will be saved. + pbar: A progress bar object. + args: Additional arguments for customization. + _callbacks: List of callback functions to be invoked during validation. + """ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): - """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.""" + """ + Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics. + + Args: + dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation. + save_dir (Path, optional): Directory to save results. + pbar (tqdm.tqdm): Progress bar for displaying progress. + args (SimpleNamespace): Configuration for the validator. + _callbacks (dict): Dictionary to store various callback functions. + + Notes: + Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors. 
+ """ super().__init__(dataloader, save_dir, pbar, args, _callbacks) - self.args.task = 'segment' + self.args.task = "segment" self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) diff --git a/ultralytics/models/nas/__init__.py b/ultralytics/models/nas/__init__.py index eec3837..b095a05 100644 --- a/ultralytics/models/nas/__init__.py +++ b/ultralytics/models/nas/__init__.py @@ -4,4 +4,4 @@ from .model import NAS from .predict import NASPredictor from .val import NASValidator -__all__ = 'NASPredictor', 'NASValidator', 'NAS' +__all__ = "NASPredictor", "NASValidator", "NAS" diff --git a/ultralytics/models/nas/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/nas/__pycache__/__init__.cpython-312.pyc index 84e3932..5c155c1 100644 Binary files a/ultralytics/models/nas/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/nas/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/nas/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/nas/__pycache__/__init__.cpython-39.pyc index 6dc818e..f9e576e 100644 Binary files a/ultralytics/models/nas/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/nas/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/nas/__pycache__/model.cpython-312.pyc b/ultralytics/models/nas/__pycache__/model.cpython-312.pyc index e33e052..b7b209c 100644 Binary files a/ultralytics/models/nas/__pycache__/model.cpython-312.pyc and b/ultralytics/models/nas/__pycache__/model.cpython-312.pyc differ diff --git a/ultralytics/models/nas/__pycache__/model.cpython-39.pyc b/ultralytics/models/nas/__pycache__/model.cpython-39.pyc index 71e30ac..5dcd9ff 100644 Binary files a/ultralytics/models/nas/__pycache__/model.cpython-39.pyc and b/ultralytics/models/nas/__pycache__/model.cpython-39.pyc differ diff --git a/ultralytics/models/nas/__pycache__/predict.cpython-312.pyc b/ultralytics/models/nas/__pycache__/predict.cpython-312.pyc index 475134d..243864e 100644 Binary files a/ultralytics/models/nas/__pycache__/predict.cpython-312.pyc and b/ultralytics/models/nas/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/nas/__pycache__/predict.cpython-39.pyc b/ultralytics/models/nas/__pycache__/predict.cpython-39.pyc index 7f5e883..96c5834 100644 Binary files a/ultralytics/models/nas/__pycache__/predict.cpython-39.pyc and b/ultralytics/models/nas/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/nas/__pycache__/val.cpython-312.pyc b/ultralytics/models/nas/__pycache__/val.cpython-312.pyc index 63b8a8e..ca819a2 100644 Binary files a/ultralytics/models/nas/__pycache__/val.cpython-312.pyc and b/ultralytics/models/nas/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/nas/__pycache__/val.cpython-39.pyc b/ultralytics/models/nas/__pycache__/val.cpython-39.pyc index 0cfab75..c8aa1ce 100644 Binary files a/ultralytics/models/nas/__pycache__/val.cpython-39.pyc and b/ultralytics/models/nas/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index f848cc4..7997e96 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -17,26 +17,47 @@ import torch from ultralytics.engine.model import Model from ultralytics.utils.torch_utils import model_info, smart_inference_mode - from .predict import NASPredictor from .val import NASValidator class NAS(Model): + """ + YOLO NAS model for object 
detection. - def __init__(self, model='yolo_nas_s.pt') -> None: - assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.' - super().__init__(model, task='detect') + This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. + It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models. + + Example: + ```python + from ultralytics import NAS + + model = NAS('yolo_nas_s') + results = model.predict('ultralytics/assets/bus.jpg') + ``` + + Attributes: + model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'. + + Note: + YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files. + """ + + def __init__(self, model="yolo_nas_s.pt") -> None: + """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model.""" + assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models." + super().__init__(model, task="detect") @smart_inference_mode() def _load(self, weights: str, task: str): - # Load or create new NAS model + """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" import super_gradients + suffix = Path(weights).suffix - if suffix == '.pt': + if suffix == ".pt": self.model = torch.load(weights) - elif suffix == '': - self.model = super_gradients.training.models.get(weights, pretrained_weights='coco') + elif suffix == "": + self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") # Standardize model self.model.fuse = lambda verbose=True: self.model self.model.stride = torch.tensor([32]) @@ -44,7 +65,7 @@ class NAS(Model): self.model.is_fused = lambda: False # for info() self.model.yaml = {} # for info() self.model.pt_path = weights # for export() - self.model.task = 'detect' # for export() + self.model.task = "detect" # for export() def info(self, detailed=False, verbose=True): """ @@ -58,4 +79,5 @@ class NAS(Model): @property def task_map(self): - return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}} + """Returns a dictionary mapping tasks to respective predictor and validator classes.""" + return {"detect": {"predictor": NASPredictor, "validator": NASValidator}} diff --git a/ultralytics/models/nas/predict.py b/ultralytics/models/nas/predict.py index fe06c29..2e48546 100644 --- a/ultralytics/models/nas/predict.py +++ b/ultralytics/models/nas/predict.py @@ -8,6 +8,29 @@ from ultralytics.utils import ops class NASPredictor(BasePredictor): + """ + Ultralytics YOLO NAS Predictor for object detection. + + This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the + raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and + scaling the bounding boxes to fit the original image dimensions. + + Attributes: + args (Namespace): Namespace containing various configurations for post-processing. + + Example: + ```python + from ultralytics import NAS + + model = NAS('yolo_nas_s') + predictor = model.predictor + # Assumes that raw_preds, img, orig_imgs are available + results = predictor.postprocess(raw_preds, img, orig_imgs) + ``` + + Note: + Typically, this class is not instantiated directly. It is used internally within the `NAS` class. 
+ """ def postprocess(self, preds_in, img, orig_imgs): """Postprocess predictions and returns a list of Results objects.""" @@ -16,12 +39,14 @@ class NASPredictor(BasePredictor): boxes = ops.xyxy2xywh(preds_in[0][0]) preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) - preds = ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - classes=self.args.classes) + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) diff --git a/ultralytics/models/nas/val.py b/ultralytics/models/nas/val.py index 5c39171..a4a4f99 100644 --- a/ultralytics/models/nas/val.py +++ b/ultralytics/models/nas/val.py @@ -5,20 +5,46 @@ import torch from ultralytics.models.yolo.detect import DetectionValidator from ultralytics.utils import ops -__all__ = ['NASValidator'] +__all__ = ["NASValidator"] class NASValidator(DetectionValidator): + """ + Ultralytics YOLO NAS Validator for object detection. + + Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions + generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes, + ultimately producing the final detections. + + Attributes: + args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds. + lb (torch.Tensor): Optional tensor for multilabel NMS. + + Example: + ```python + from ultralytics import NAS + + model = NAS('yolo_nas_s') + validator = model.validator + # Assumes that raw_preds are available + final_preds = validator.postprocess(raw_preds) + ``` + + Note: + This class is generally not instantiated directly but is used internally within the `NAS` class. 
+ """ def postprocess(self, preds_in): """Apply Non-maximum suppression to prediction outputs.""" boxes = ops.xyxy2xywh(preds_in[0][0]) preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) - return ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - labels=self.lb, - multi_label=False, - agnostic=self.args.single_cls, - max_det=self.args.max_det, - max_time_img=0.5) + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=False, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + max_time_img=0.5, + ) diff --git a/ultralytics/models/rtdetr/__init__.py b/ultralytics/models/rtdetr/__init__.py index 4d12115..172c74b 100644 --- a/ultralytics/models/rtdetr/__init__.py +++ b/ultralytics/models/rtdetr/__init__.py @@ -4,4 +4,4 @@ from .model import RTDETR from .predict import RTDETRPredictor from .val import RTDETRValidator -__all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR' +__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR" diff --git a/ultralytics/models/rtdetr/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/rtdetr/__pycache__/__init__.cpython-312.pyc index e62c419..dc53474 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/rtdetr/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/rtdetr/__pycache__/__init__.cpython-39.pyc index 2522cd6..1224ec8 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/rtdetr/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/model.cpython-312.pyc b/ultralytics/models/rtdetr/__pycache__/model.cpython-312.pyc index 78dfea7..7713b7b 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/model.cpython-312.pyc and b/ultralytics/models/rtdetr/__pycache__/model.cpython-312.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/model.cpython-39.pyc b/ultralytics/models/rtdetr/__pycache__/model.cpython-39.pyc index bccf6b9..ae5cd9d 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/model.cpython-39.pyc and b/ultralytics/models/rtdetr/__pycache__/model.cpython-39.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/predict.cpython-312.pyc b/ultralytics/models/rtdetr/__pycache__/predict.cpython-312.pyc index d4266dc..20225e1 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/predict.cpython-312.pyc and b/ultralytics/models/rtdetr/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/predict.cpython-39.pyc b/ultralytics/models/rtdetr/__pycache__/predict.cpython-39.pyc index 53f0f0d..23bf393 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/predict.cpython-39.pyc and b/ultralytics/models/rtdetr/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/train.cpython-312.pyc b/ultralytics/models/rtdetr/__pycache__/train.cpython-312.pyc index f0449d9..87bec8d 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/train.cpython-312.pyc and b/ultralytics/models/rtdetr/__pycache__/train.cpython-312.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/train.cpython-39.pyc b/ultralytics/models/rtdetr/__pycache__/train.cpython-39.pyc index a154802..5847ea7 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/train.cpython-39.pyc and 
b/ultralytics/models/rtdetr/__pycache__/train.cpython-39.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/val.cpython-312.pyc b/ultralytics/models/rtdetr/__pycache__/val.cpython-312.pyc index d995e6c..da53777 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/val.cpython-312.pyc and b/ultralytics/models/rtdetr/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/rtdetr/__pycache__/val.cpython-39.pyc b/ultralytics/models/rtdetr/__pycache__/val.cpython-39.pyc index 0d0c2db..230f46c 100644 Binary files a/ultralytics/models/rtdetr/__pycache__/val.cpython-39.pyc and b/ultralytics/models/rtdetr/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py index c20d72f..440df17 100644 --- a/ultralytics/models/rtdetr/model.py +++ b/ultralytics/models/rtdetr/model.py @@ -1,7 +1,12 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license """ -RT-DETR model interface +Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time +performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient +hybrid encoder and IoU-aware query selection for enhanced detection accuracy. + +For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf """ + from ultralytics.engine.model import Model from ultralytics.nn.tasks import RTDETRDetectionModel @@ -12,19 +17,38 @@ from .val import RTDETRValidator class RTDETR(Model): """ - RTDETR model interface. + Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance + with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed. + + Attributes: + model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'. """ - def __init__(self, model='rtdetr-l.pt') -> None: - if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'): - raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.') - super().__init__(model=model, task='detect') + def __init__(self, model="rtdetr-l.pt") -> None: + """ + Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats. + + Args: + model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'. + + Raises: + NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. + """ + super().__init__(model=model, task="detect") @property - def task_map(self): + def task_map(self) -> dict: + """ + Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes. + + Returns: + dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model. + """ return { - 'detect': { - 'predictor': RTDETRPredictor, - 'validator': RTDETRValidator, - 'trainer': RTDETRTrainer, - 'model': RTDETRDetectionModel}} + "detect": { + "predictor": RTDETRPredictor, + "validator": RTDETRValidator, + "trainer": RTDETRTrainer, + "model": RTDETRDetectionModel, + } + } diff --git a/ultralytics/models/rtdetr/predict.py b/ultralytics/models/rtdetr/predict.py index 33d5d7a..7fc918b 100644 --- a/ultralytics/models/rtdetr/predict.py +++ b/ultralytics/models/rtdetr/predict.py @@ -10,7 +10,11 @@ from ultralytics.utils import ops class RTDETRPredictor(BasePredictor): """ - A class extending the BasePredictor class for prediction based on an RT-DETR detection model. 
+ RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using + Baidu's RT-DETR model. + + This class leverages the power of Vision Transformers to provide real-time object detection while maintaining + high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection. Example: ```python @@ -21,10 +25,30 @@ class RTDETRPredictor(BasePredictor): predictor = RTDETRPredictor(overrides=args) predictor.predict_cli() ``` + + Attributes: + imgsz (int): Image size for inference (must be square and scale-filled). + args (dict): Argument overrides for the predictor. """ def postprocess(self, preds, img, orig_imgs): - """Postprocess predictions and returns a list of Results objects.""" + """ + Postprocess the raw predictions from the model to generate bounding boxes and confidence scores. + + The method filters detections based on confidence and class if specified in `self.args`. + + Args: + preds (list): List of [predictions, extra] from the model. + img (torch.Tensor): Processed input images. + orig_imgs (list or torch.Tensor): Original, unprocessed images. + + Returns: + (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores, + and class labels. + """ + if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference + preds = [preds, None] + nd = preds[0].shape[-1] bboxes, scores = preds[0].split((4, nd - 4), dim=-1) @@ -48,15 +72,15 @@ class RTDETRPredictor(BasePredictor): return results def pre_transform(self, im): - """Pre-transform input image before inference. + """ + Pre-transforms the input images before feeding them into the model for inference. The input images are + letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled. Args: - im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. - - Notes: The size must be square(640) and scaleFilled. + im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list. Returns: - (list): A list of transformed imgs. + (list): List of pre-transformed images ready for model inference. """ letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True) return [letterbox(image=x) for x in im] diff --git a/ultralytics/models/rtdetr/train.py b/ultralytics/models/rtdetr/train.py index 1e58668..10a8f9b 100644 --- a/ultralytics/models/rtdetr/train.py +++ b/ultralytics/models/rtdetr/train.py @@ -7,16 +7,17 @@ import torch from ultralytics.models.yolo.detect import DetectionTrainer from ultralytics.nn.tasks import RTDETRDetectionModel from ultralytics.utils import RANK, colorstr - from .val import RTDETRDataset, RTDETRValidator class RTDETRTrainer(DetectionTrainer): """ - A class extending the DetectionTrainer class for training based on an RT-DETR detection model. + Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer + class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision + Transformers and has capabilities like IoU-aware query selection and adaptable inference speed. Notes: - - F.grid_sample used in rt-detr does not support the `deterministic=True` argument. + - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument. - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching. 
Example: @@ -30,43 +31,71 @@ class RTDETRTrainer(DetectionTrainer): """ def get_model(self, cfg=None, weights=None, verbose=True): - """Return a YOLO detection model.""" - model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) + """ + Initialize and return an RT-DETR model for object detection tasks. + + Args: + cfg (dict, optional): Model configuration. Defaults to None. + weights (str, optional): Path to pre-trained model weights. Defaults to None. + verbose (bool): Verbose logging if True. Defaults to True. + + Returns: + (RTDETRDetectionModel): Initialized model. + """ + model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) if weights: model.load(weights) return model - def build_dataset(self, img_path, mode='val', batch=None): - """Build RTDETR Dataset + def build_dataset(self, img_path, mode="val", batch=None): + """ + Build and return an RT-DETR dataset for training or validation. Args: img_path (str): Path to the folder containing images. - mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. - batch (int, optional): Size of batches, this is for `rect`. Defaults to None. + mode (str): Dataset mode, either 'train' or 'val'. + batch (int, optional): Batch size for rectangle training. Defaults to None. + + Returns: + (RTDETRDataset): Dataset object for the specific mode. """ return RTDETRDataset( img_path=img_path, imgsz=self.args.imgsz, batch_size=batch, - augment=mode == 'train', # no augmentation + augment=mode == "train", hyp=self.args, - rect=False, # no rect + rect=False, cache=self.args.cache or None, - prefix=colorstr(f'{mode}: '), - data=self.data) + prefix=colorstr(f"{mode}: "), + data=self.data, + ) def get_validator(self): - """Returns a DetectionValidator for RTDETR model validation.""" - self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss' + """ + Returns a DetectionValidator suitable for RT-DETR model validation. + + Returns: + (RTDETRValidator): Validator object for model validation. + """ + self.loss_names = "giou_loss", "cls_loss", "l1_loss" return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) def preprocess_batch(self, batch): - """Preprocesses a batch of images by scaling and converting to float.""" + """ + Preprocess a batch of images. Scales and converts the images to float format. + + Args: + batch (dict): Dictionary containing a batch of images, bboxes, and labels. + + Returns: + (dict): Preprocessed batch. 
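The RT-DETR hunks above document the model, predictor, and trainer interfaces together; a hedged model-level sketch follows, with the 'rtdetr-l.pt' checkpoint and image path assumed purely for illustration.

```python
# Hedged sketch of the RT-DETR interface documented above; assumes the
# 'rtdetr-l.pt' checkpoint and a local test image are available.
from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")                   # .pt or .yaml per the __init__ docstring
model.info()
results = model.predict("bus.jpg", conf=0.25)   # RTDETRPredictor.postprocess filters by conf/classes
for r in results:
    print(r.boxes.xyxy, r.boxes.conf)
```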
+ """ batch = super().preprocess_batch(batch) - bs = len(batch['img']) - batch_idx = batch['batch_idx'] + bs = len(batch["img"]) + batch_idx = batch["batch_idx"] gt_bbox, gt_class = [], [] for i in range(bs): - gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device)) - gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) + gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device)) + gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) return batch diff --git a/ultralytics/models/rtdetr/val.py b/ultralytics/models/rtdetr/val.py index 9b984be..88bb0ae 100644 --- a/ultralytics/models/rtdetr/val.py +++ b/ultralytics/models/rtdetr/val.py @@ -1,7 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from pathlib import Path - import torch from ultralytics.data import YOLODataset @@ -9,16 +7,22 @@ from ultralytics.data.augment import Compose, Format, v8_transforms from ultralytics.models.yolo.detect import DetectionValidator from ultralytics.utils import colorstr, ops -__all__ = 'RTDETRValidator', # tuple or list +__all__ = ("RTDETRValidator",) # tuple or list -# TODO: Temporarily RT-DETR does not need padding. class RTDETRDataset(YOLODataset): + """ + Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class. + + This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for + real-time detection and tracking tasks. + """ def __init__(self, *args, data=None, **kwargs): - super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs) + """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.""" + super().__init__(*args, data=data, **kwargs) - # NOTE: add stretch version load_image for rtdetr mosaic + # NOTE: add stretch version load_image for RTDETR mosaic def load_image(self, i, rect_mode=False): """Loads 1 image from dataset index 'i', returns (im, resized hw).""" return super().load_image(i=i, rect_mode=rect_mode) @@ -33,19 +37,26 @@ class RTDETRDataset(YOLODataset): # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)]) transforms = Compose([]) transforms.append( - Format(bbox_format='xywh', - normalize=True, - return_mask=self.use_segments, - return_keypoint=self.use_keypoints, - batch_idx=True, - mask_ratio=hyp.mask_ratio, - mask_overlap=hyp.overlap_mask)) + Format( + bbox_format="xywh", + normalize=True, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + batch_idx=True, + mask_ratio=hyp.mask_ratio, + mask_overlap=hyp.overlap_mask, + ) + ) return transforms class RTDETRValidator(DetectionValidator): """ - A class extending the DetectionValidator class for validation based on an RT-DETR detection model. + RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for + the RT-DETR (Real-Time DETR) object detection model. + + The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for + post-processing, and updates evaluation metrics accordingly. Example: ```python @@ -55,9 +66,12 @@ class RTDETRValidator(DetectionValidator): validator = RTDETRValidator(args=args) validator() ``` + + Note: + For further details on the attributes and methods, refer to the parent DetectionValidator class. 
""" - def build_dataset(self, img_path, mode='val', batch=None): + def build_dataset(self, img_path, mode="val", batch=None): """ Build an RTDETR Dataset. @@ -74,11 +88,15 @@ class RTDETRValidator(DetectionValidator): hyp=self.args, rect=False, # no rect cache=self.args.cache or None, - prefix=colorstr(f'{mode}: '), - data=self.data) + prefix=colorstr(f"{mode}: "), + data=self.data, + ) def postprocess(self, preds): """Apply Non-maximum suppression to prediction outputs.""" + if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference + preds = [preds, None] + bs, _, nd = preds[0].shape bboxes, scores = preds[0].split((4, nd - 4), dim=-1) bboxes *= self.args.imgsz @@ -86,56 +104,32 @@ class RTDETRValidator(DetectionValidator): for i, bbox in enumerate(bboxes): # (300, 4) bbox = ops.xywh2xyxy(bbox) score, cls = scores[i].max(-1) # (300, ) - # Do not need threshold for evaluation as only got 300 boxes here. + # Do not need threshold for evaluation as only got 300 boxes here # idx = score > self.args.conf pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter - # sort by confidence to correctly get internal metrics. + # Sort by confidence to correctly get internal metrics pred = pred[score.argsort(descending=True)] outputs[i] = pred # [idx] return outputs - def update_metrics(self, preds, batch): - """Metrics.""" - for si, pred in enumerate(preds): - idx = batch['batch_idx'] == si - cls = batch['cls'][idx] - bbox = batch['bboxes'][idx] - nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions - shape = batch['ori_shape'][si] - correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init - self.seen += 1 + def _prepare_batch(self, si, batch): + """Prepares a batch for training or inference by applying transformations.""" + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox = ops.xywh2xyxy(bbox) # target boxes + bbox[..., [0, 2]] *= ori_shape[1] # native-space pred + bbox[..., [1, 3]] *= ori_shape[0] # native-space pred + return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad) - if npr == 0: - if nl: - self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1))) - if self.args.plots: - self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1)) - continue - - # Predictions - if self.args.single_cls: - pred[:, 5] = 0 - predn = pred.clone() - predn[..., [0, 2]] *= shape[1] / self.args.imgsz # native-space pred - predn[..., [1, 3]] *= shape[0] / self.args.imgsz # native-space pred - - # Evaluate - if nl: - tbox = ops.xywh2xyxy(bbox) # target boxes - tbox[..., [0, 2]] *= shape[1] # native-space pred - tbox[..., [1, 3]] *= shape[0] # native-space pred - labelsn = torch.cat((cls, tbox), 1) # native-space labels - # NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type. 
- correct_bboxes = self._process_batch(predn.float(), labelsn) - # TODO: maybe remove these `self.` arguments as they already are member variable - if self.args.plots: - self.confusion_matrix.process_batch(predn, labelsn) - self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls) - - # Save - if self.args.save_json: - self.pred_to_json(predn, batch['im_file'][si]) - if self.args.save_txt: - file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt' - self.save_one_txt(predn, self.args.save_conf, shape, file) + def _prepare_pred(self, pred, pbatch): + """Prepares and returns a batch with transformed bounding boxes and class labels.""" + predn = pred.clone() + predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred + predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred + return predn.float() diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py index 35f4efa..8701fcc 100644 --- a/ultralytics/models/sam/__init__.py +++ b/ultralytics/models/sam/__init__.py @@ -3,6 +3,4 @@ from .model import SAM from .predict import Predictor -# from .build import build_sam - -__all__ = 'SAM', 'Predictor' # tuple or list +__all__ = "SAM", "Predictor" # tuple or list diff --git a/ultralytics/models/sam/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/sam/__pycache__/__init__.cpython-312.pyc index c91bc2d..8b41d71 100644 Binary files a/ultralytics/models/sam/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/sam/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/sam/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/sam/__pycache__/__init__.cpython-39.pyc index f503cbb..81a0c71 100644 Binary files a/ultralytics/models/sam/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/sam/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/sam/__pycache__/amg.cpython-312.pyc b/ultralytics/models/sam/__pycache__/amg.cpython-312.pyc index ca37818..06b5fc4 100644 Binary files a/ultralytics/models/sam/__pycache__/amg.cpython-312.pyc and b/ultralytics/models/sam/__pycache__/amg.cpython-312.pyc differ diff --git a/ultralytics/models/sam/__pycache__/amg.cpython-39.pyc b/ultralytics/models/sam/__pycache__/amg.cpython-39.pyc index 18a5f58..6cf79c4 100644 Binary files a/ultralytics/models/sam/__pycache__/amg.cpython-39.pyc and b/ultralytics/models/sam/__pycache__/amg.cpython-39.pyc differ diff --git a/ultralytics/models/sam/__pycache__/build.cpython-312.pyc b/ultralytics/models/sam/__pycache__/build.cpython-312.pyc index 0ba57c1..ee3e4c8 100644 Binary files a/ultralytics/models/sam/__pycache__/build.cpython-312.pyc and b/ultralytics/models/sam/__pycache__/build.cpython-312.pyc differ diff --git a/ultralytics/models/sam/__pycache__/build.cpython-39.pyc b/ultralytics/models/sam/__pycache__/build.cpython-39.pyc index 8f95385..43e49b1 100644 Binary files a/ultralytics/models/sam/__pycache__/build.cpython-39.pyc and b/ultralytics/models/sam/__pycache__/build.cpython-39.pyc differ diff --git a/ultralytics/models/sam/__pycache__/model.cpython-312.pyc b/ultralytics/models/sam/__pycache__/model.cpython-312.pyc index 1da9eab..7c9e884 100644 Binary files a/ultralytics/models/sam/__pycache__/model.cpython-312.pyc and b/ultralytics/models/sam/__pycache__/model.cpython-312.pyc differ diff --git a/ultralytics/models/sam/__pycache__/model.cpython-39.pyc b/ultralytics/models/sam/__pycache__/model.cpython-39.pyc index 
eca6f8e..1826afc 100644 Binary files a/ultralytics/models/sam/__pycache__/model.cpython-39.pyc and b/ultralytics/models/sam/__pycache__/model.cpython-39.pyc differ diff --git a/ultralytics/models/sam/__pycache__/predict.cpython-312.pyc b/ultralytics/models/sam/__pycache__/predict.cpython-312.pyc index 43347e6..8253d47 100644 Binary files a/ultralytics/models/sam/__pycache__/predict.cpython-312.pyc and b/ultralytics/models/sam/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/sam/__pycache__/predict.cpython-39.pyc b/ultralytics/models/sam/__pycache__/predict.cpython-39.pyc index ad73446..ea68a38 100644 Binary files a/ultralytics/models/sam/__pycache__/predict.cpython-39.pyc and b/ultralytics/models/sam/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/sam/amg.py b/ultralytics/models/sam/amg.py index f251fe4..128108f 100644 --- a/ultralytics/models/sam/amg.py +++ b/ultralytics/models/sam/amg.py @@ -8,10 +8,9 @@ import numpy as np import torch -def is_box_near_crop_edge(boxes: torch.Tensor, - crop_box: List[int], - orig_box: List[int], - atol: float = 20.0) -> torch.Tensor: +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: """Return a boolean tensor indicating if boxes are near the crop edge.""" crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) @@ -24,23 +23,25 @@ def is_box_near_crop_edge(boxes: torch.Tensor, def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: """Yield batches of data from the input arguments.""" - assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.' + assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs." n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) for b in range(n_batches): - yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args] + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor: """ - Computes the stability score for a batch of masks. The stability - score is the IoU between the binary masks obtained by thresholding - the predicted mask logits at high and low values. + Computes the stability score for a batch of masks. + + The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high + and low values. + + Notes: + - One mask is always contained inside the other. + - Save memory by preventing unnecessary cast to torch.int64 """ - # One mask is always contained inside the other. 
- # Save memory by preventing unnecessary cast to torch.int64 - intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, - dtype=torch.int32)) - unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)) + intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) return intersections / unions @@ -55,12 +56,17 @@ def build_point_grid(n_per_side: int) -> np.ndarray: def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]: """Generate point grids for all crop layers.""" - return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)] + return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)] -def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int, - overlap_ratio: float) -> Tuple[List[List[int]], List[int]]: - """Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.""" +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. + + Each layer has (2**i)**2 boxes for the ith layer. + """ crop_boxes, layer_idxs = [], [] im_h, im_w = im_size short_side = min(im_h, im_w) @@ -127,8 +133,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator.""" import cv2 # type: ignore - assert mode in {'holes', 'islands'} - correct_holes = mode == 'holes' + assert mode in {"holes", "islands"} + correct_holes = mode == "holes" working_mask = (correct_holes ^ mask).astype(np.uint8) n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) sizes = stats[:, -1][1:] # Row 0 is background label @@ -145,8 +151,9 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: """ - Calculates boxes in XYXY format around masks. Return [0,0,0,0] for - an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + Calculates boxes in XYXY format around masks. + + Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
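Because the amg.py changes above are largely reformatting, the helper semantics are easy to lose in the noise; the toy sketch below exercises the signatures visible in this hunk, with arbitrary illustrative shapes.

```python
# Toy sketch of the amg helpers touched above; shapes and values are illustrative only.
import torch
from ultralytics.models.sam.amg import batch_iterator, batched_mask_to_box, calculate_stability_score

points, labels = torch.rand(10, 2), torch.ones(10)
for chunk in batch_iterator(4, points, labels):   # same-length inputs, yielded in batches of 4
    print(len(chunk[0]))                          # -> 4, 4, 2

mask_logits = torch.randn(2, 256, 256)
scores = calculate_stability_score(mask_logits, mask_threshold=0.0, threshold_offset=1.0)
boxes = batched_mask_to_box(mask_logits > 0)      # XYXY boxes; [0, 0, 0, 0] for an empty mask
print(scores.shape, boxes.shape)                  # torch.Size([2]) torch.Size([2, 4])
```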
""" # torch.max below raises an error on empty inputs, just skip in this case if torch.numel(masks) == 0: diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py index 21da265..266587e 100644 --- a/ultralytics/models/sam/build.py +++ b/ultralytics/models/sam/build.py @@ -11,7 +11,6 @@ from functools import partial import torch from ultralytics.utils.downloads import attempt_download_asset - from .modules.decoders import MaskDecoder from .modules.encoders import ImageEncoderViT, PromptEncoder from .modules.sam import Sam @@ -64,46 +63,47 @@ def build_mobile_sam(checkpoint=None): ) -def _build_sam(encoder_embed_dim, - encoder_depth, - encoder_num_heads, - encoder_global_attn_indexes, - checkpoint=None, - mobile_sam=False): +def _build_sam( + encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False +): """Builds the selected SAM model architecture.""" prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size - image_encoder = (TinyViT( - img_size=1024, - in_chans=3, - num_classes=1000, - embed_dims=encoder_embed_dim, - depths=encoder_depth, - num_heads=encoder_num_heads, - window_sizes=[7, 7, 14, 7], - mlp_ratio=4.0, - drop_rate=0.0, - drop_path_rate=0.0, - use_checkpoint=False, - mbconv_expand_ratio=4.0, - local_conv_size=3, - layer_lr_decay=0.8, - ) if mobile_sam else ImageEncoderViT( - depth=encoder_depth, - embed_dim=encoder_embed_dim, - img_size=image_size, - mlp_ratio=4, - norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), - num_heads=encoder_num_heads, - patch_size=vit_patch_size, - qkv_bias=True, - use_rel_pos=True, - global_attn_indexes=encoder_global_attn_indexes, - window_size=14, - out_chans=prompt_embed_dim, - )) + image_encoder = ( + TinyViT( + img_size=1024, + in_chans=3, + num_classes=1000, + embed_dims=encoder_embed_dim, + depths=encoder_depth, + num_heads=encoder_num_heads, + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8, + ) + if mobile_sam + else ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + ) sam = Sam( image_encoder=image_encoder, prompt_encoder=PromptEncoder( @@ -129,7 +129,7 @@ def _build_sam(encoder_embed_dim, ) if checkpoint is not None: checkpoint = attempt_download_asset(checkpoint) - with open(checkpoint, 'rb') as f: + with open(checkpoint, "rb") as f: state_dict = torch.load(f) sam.load_state_dict(state_dict) sam.eval() @@ -139,20 +139,22 @@ def _build_sam(encoder_embed_dim, sam_model_map = { - 'sam_h.pt': build_sam_vit_h, - 'sam_l.pt': build_sam_vit_l, - 'sam_b.pt': build_sam_vit_b, - 'mobile_sam.pt': build_mobile_sam, } + "sam_h.pt": build_sam_vit_h, + "sam_l.pt": build_sam_vit_l, + "sam_b.pt": build_sam_vit_b, + "mobile_sam.pt": build_mobile_sam, +} -def build_sam(ckpt='sam_b.pt'): +def build_sam(ckpt="sam_b.pt"): """Build a SAM model specified by ckpt.""" model_builder = None + ckpt = str(ckpt) # to allow Path ckpt types for k in sam_model_map.keys(): if ckpt.endswith(k): model_builder = sam_model_map.get(k) if not model_builder: - raise FileNotFoundError(f'{ckpt} is not a supported sam 
model. Available models are: \n {sam_model_map.keys()}') + raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}") return model_builder(ckpt) diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index 2ca3501..cb12bc7 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -1,51 +1,114 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license """ -SAM model interface +SAM model interface. + +This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image +segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis, +and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new +image distributions and tasks without prior knowledge. + +Key Features: + - Promptable segmentation + - Real-time performance + - Zero-shot transfer capabilities + - Trained on SA-1B dataset """ from pathlib import Path from ultralytics.engine.model import Model from ultralytics.utils.torch_utils import model_info - from .build import build_sam from .predict import Predictor class SAM(Model): """ - SAM model interface. + SAM (Segment Anything Model) interface class. + + SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as + bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B + dataset. """ - def __init__(self, model='sam_b.pt') -> None: - if model and Path(model).suffix not in ('.pt', '.pth'): - raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.') - super().__init__(model=model, task='segment') + def __init__(self, model="sam_b.pt") -> None: + """ + Initializes the SAM model with a pre-trained model file. + + Args: + model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension. + + Raises: + NotImplementedError: If the model file extension is not .pt or .pth. + """ + if model and Path(model).suffix not in (".pt", ".pth"): + raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") + super().__init__(model=model, task="segment") def _load(self, weights: str, task=None): + """ + Loads the specified weights into the SAM model. + + Args: + weights (str): Path to the weights file. + task (str, optional): Task name. Defaults to None. + """ self.model = build_sam(weights) def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): - """Predicts and returns segmentation masks for given image or video source.""" - overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) + """ + Performs segmentation prediction on the given image or video source. + + Args: + source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object. + stream (bool, optional): If True, enables real-time streaming. Defaults to False. + bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None. + points (list, optional): List of points for prompted segmentation. Defaults to None. + labels (list, optional): List of labels for prompted segmentation. Defaults to None. + + Returns: + (list): The model predictions. 
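The predict() signature documented above accepts box and point prompts directly; a hedged sketch of prompted inference follows, assuming the 'sam_b.pt' checkpoint, a local image, and illustrative prompt coordinates.

```python
# Hedged sketch of prompted SAM inference per the predict() signature above;
# 'sam_b.pt', the image path, and the prompt coordinates are assumptions.
from ultralytics import SAM

model = SAM("sam_b.pt")
model.info()
results = model.predict("bus.jpg", bboxes=[100, 100, 400, 500])    # box prompt
results = model.predict("bus.jpg", points=[300, 250], labels=[1])  # positive point prompt
for r in results:
    print(r.masks.data.shape)                                       # predicted mask tensor
```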
+ """ + overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024) kwargs.update(overrides) prompts = dict(bboxes=bboxes, points=points, labels=labels) return super().predict(source, stream, prompts=prompts, **kwargs) def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): - """Calls the 'predict' function with given arguments to perform object detection.""" + """ + Alias for the 'predict' method. + + Args: + source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object. + stream (bool, optional): If True, enables real-time streaming. Defaults to False. + bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None. + points (list, optional): List of points for prompted segmentation. Defaults to None. + labels (list, optional): List of labels for prompted segmentation. Defaults to None. + + Returns: + (list): The model predictions. + """ return self.predict(source, stream, bboxes, points, labels, **kwargs) def info(self, detailed=False, verbose=True): """ - Logs model info. + Logs information about the SAM model. Args: - detailed (bool): Show detailed information about model. - verbose (bool): Controls verbosity. + detailed (bool, optional): If True, displays detailed information about the model. Defaults to False. + verbose (bool, optional): If True, displays information on the console. Defaults to True. + + Returns: + (tuple): A tuple containing the model's information. """ return model_info(self.model, detailed=detailed, verbose=verbose) @property def task_map(self): - return {'segment': {'predictor': Predictor}} + """ + Provides a mapping from the 'segment' task to its corresponding 'Predictor'. + + Returns: + (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'. 
+ """ + return {"segment": {"predictor": Predictor}} diff --git a/ultralytics/models/sam/modules/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/sam/modules/__pycache__/__init__.cpython-312.pyc index e19a05b..e266ea7 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/sam/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/sam/modules/__pycache__/__init__.cpython-39.pyc index 6efab26..f3e51cb 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/sam/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/decoders.cpython-312.pyc b/ultralytics/models/sam/modules/__pycache__/decoders.cpython-312.pyc index bc57016..a4948fe 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/decoders.cpython-312.pyc and b/ultralytics/models/sam/modules/__pycache__/decoders.cpython-312.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/decoders.cpython-39.pyc b/ultralytics/models/sam/modules/__pycache__/decoders.cpython-39.pyc index f3d0706..bae68e2 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/decoders.cpython-39.pyc and b/ultralytics/models/sam/modules/__pycache__/decoders.cpython-39.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/encoders.cpython-312.pyc b/ultralytics/models/sam/modules/__pycache__/encoders.cpython-312.pyc index 16294e4..75a377c 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/encoders.cpython-312.pyc and b/ultralytics/models/sam/modules/__pycache__/encoders.cpython-312.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/encoders.cpython-39.pyc b/ultralytics/models/sam/modules/__pycache__/encoders.cpython-39.pyc index 67b2580..1a0e848 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/encoders.cpython-39.pyc and b/ultralytics/models/sam/modules/__pycache__/encoders.cpython-39.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/sam.cpython-312.pyc b/ultralytics/models/sam/modules/__pycache__/sam.cpython-312.pyc index 71f3fdc..e47bcff 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/sam.cpython-312.pyc and b/ultralytics/models/sam/modules/__pycache__/sam.cpython-312.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/sam.cpython-39.pyc b/ultralytics/models/sam/modules/__pycache__/sam.cpython-39.pyc index 3fb36d5..1b16234 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/sam.cpython-39.pyc and b/ultralytics/models/sam/modules/__pycache__/sam.cpython-39.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/tiny_encoder.cpython-312.pyc b/ultralytics/models/sam/modules/__pycache__/tiny_encoder.cpython-312.pyc index ac2f1bc..b5ca8ae 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/tiny_encoder.cpython-312.pyc and b/ultralytics/models/sam/modules/__pycache__/tiny_encoder.cpython-312.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/tiny_encoder.cpython-39.pyc b/ultralytics/models/sam/modules/__pycache__/tiny_encoder.cpython-39.pyc index 293a2c3..4fb47c1 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/tiny_encoder.cpython-39.pyc and b/ultralytics/models/sam/modules/__pycache__/tiny_encoder.cpython-39.pyc differ diff --git 
a/ultralytics/models/sam/modules/__pycache__/transformer.cpython-312.pyc b/ultralytics/models/sam/modules/__pycache__/transformer.cpython-312.pyc index c7d0571..c22df5f 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/transformer.cpython-312.pyc and b/ultralytics/models/sam/modules/__pycache__/transformer.cpython-312.pyc differ diff --git a/ultralytics/models/sam/modules/__pycache__/transformer.cpython-39.pyc b/ultralytics/models/sam/modules/__pycache__/transformer.cpython-39.pyc index a2a7a17..9a69d56 100644 Binary files a/ultralytics/models/sam/modules/__pycache__/transformer.cpython-39.pyc and b/ultralytics/models/sam/modules/__pycache__/transformer.cpython-39.pyc differ diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py index 0c64a7e..073b1ad 100644 --- a/ultralytics/models/sam/modules/decoders.py +++ b/ultralytics/models/sam/modules/decoders.py @@ -10,6 +10,21 @@ from ultralytics.nn.modules import LayerNorm2d class MaskDecoder(nn.Module): + """ + Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict + masks given image and prompt embeddings. + + Attributes: + transformer_dim (int): Channel dimension for the transformer module. + transformer (nn.Module): The transformer module used for mask prediction. + num_multimask_outputs (int): Number of masks to predict for disambiguating masks. + iou_token (nn.Embedding): Embedding for the IoU token. + num_mask_tokens (int): Number of mask tokens. + mask_tokens (nn.Embedding): Embedding for the mask tokens. + output_upscaling (nn.Sequential): Neural network sequence for upscaling the output. + output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks. + iou_prediction_head (nn.Module): MLP for predicting mask quality. + """ def __init__( self, @@ -49,8 +64,9 @@ class MaskDecoder(nn.Module): nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), activation(), ) - self.output_hypernetworks_mlps = nn.ModuleList([ - MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]) + self.output_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] + ) self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth) @@ -98,10 +114,14 @@ class MaskDecoder(nn.Module): sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Predicts masks. See 'forward' for more details.""" + """ + Predicts masks. + + See 'forward' for more details. 
+ """ # Concatenate output tokens output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) - output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # Expand per-image data in batch direction to be per-mask @@ -113,13 +133,14 @@ class MaskDecoder(nn.Module): # Run the transformer hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, 0, :] - mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, h, w) upscaled_embedding = self.output_upscaling(src) hyper_in_list: List[torch.Tensor] = [ - self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)] + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens) + ] hyper_in = torch.stack(hyper_in_list, dim=1) b, c, h, w = upscaled_embedding.shape masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) @@ -132,7 +153,7 @@ class MaskDecoder(nn.Module): class MLP(nn.Module): """ - Lightly adapted from + MLP (Multi-Layer Perceptron) model lightly adapted from https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py """ @@ -144,6 +165,16 @@ class MLP(nn.Module): num_layers: int, sigmoid_output: bool = False, ) -> None: + """ + Initializes the MLP (Multi-Layer Perceptron) model. + + Args: + input_dim (int): The dimensionality of the input features. + hidden_dim (int): The dimensionality of the hidden layers. + output_dim (int): The dimensionality of the output layer. + num_layers (int): The number of hidden layers. + sigmoid_output (bool, optional): Apply a sigmoid activation to the output layer. Defaults to False. + """ super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py index eb9352f..a51c347 100644 --- a/ultralytics/models/sam/modules/encoders.py +++ b/ultralytics/models/sam/modules/encoders.py @@ -10,27 +10,41 @@ import torch.nn.functional as F from ultralytics.nn.modules import LayerNorm2d, MLPBlock -# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa class ImageEncoderViT(nn.Module): + """ + An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The + encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks. + The encoded patches are then processed through a neck to generate the final encoded representation. + + This class and its supporting functions below lightly adapted from the ViTDet backbone available at + https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py. + + Attributes: + img_size (int): Dimension of input images, assumed to be square. + patch_embed (PatchEmbed): Module for patch embedding. + pos_embed (nn.Parameter, optional): Absolute positional embedding for patches. 
+ blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings. + neck (nn.Sequential): Neck module to further process the output. + """ def __init__( - self, - img_size: int = 1024, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - mlp_ratio: float = 4.0, - out_chans: int = 256, - qkv_bias: bool = True, - norm_layer: Type[nn.Module] = nn.LayerNorm, - act_layer: Type[nn.Module] = nn.GELU, - use_abs_pos: bool = True, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - window_size: int = 0, - global_attn_indexes: Tuple[int, ...] = (), + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ Args: @@ -100,6 +114,9 @@ class ImageEncoderViT(nn.Module): ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Processes input through patch embedding, applies positional embedding if present, and passes through blocks + and neck. + """ x = self.patch_embed(x) if self.pos_embed is not None: x = x + self.pos_embed @@ -109,6 +126,22 @@ class ImageEncoderViT(nn.Module): class PromptEncoder(nn.Module): + """ + Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder + produces both sparse and dense embeddings for the input prompts. + + Attributes: + embed_dim (int): Dimension of the embeddings. + input_image_size (Tuple[int, int]): Size of the input image as (H, W). + image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W). + pe_layer (PositionEmbeddingRandom): Module for random position embedding. + num_point_embeddings (int): Number of point embeddings for different types of points. + point_embeddings (nn.ModuleList): List of point embeddings. + not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label. + mask_input_size (Tuple[int, int]): Size of the input mask. + mask_downscaling (nn.Sequential): Neural network for downscaling the mask. + no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided. + """ def __init__( self, @@ -157,20 +190,15 @@ class PromptEncoder(nn.Module): def get_dense_pe(self) -> torch.Tensor: """ - Returns the positional encoding used to encode point prompts, - applied to a dense set of points the shape of the image encoding. + Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the + image encoding. Returns: torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w) """ return self.pe_layer(self.image_embedding_size).unsqueeze(0) - def _embed_points( - self, - points: torch.Tensor, - labels: torch.Tensor, - pad: bool, - ) -> torch.Tensor: + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: """Embeds point prompts.""" points = points + 0.5 # Shift to center of pixel if pad: @@ -204,9 +232,7 @@ class PromptEncoder(nn.Module): boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor], ) -> int: - """ - Gets the batch size of the output given the batch size of the input prompts. 
- """ + """Gets the batch size of the output given the batch size of the input prompts.""" if points is not None: return points[0].shape[0] elif boxes is not None: @@ -217,6 +243,7 @@ class PromptEncoder(nn.Module): return 1 def _get_device(self) -> torch.device: + """Returns the device of the first point embedding's weight tensor.""" return self.point_embeddings[0].weight.device def forward( @@ -251,23 +278,22 @@ class PromptEncoder(nn.Module): if masks is not None: dense_embeddings = self._embed_masks(masks) else: - dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, - 1).expand(bs, -1, self.image_embedding_size[0], - self.image_embedding_size[1]) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) return sparse_embeddings, dense_embeddings class PositionEmbeddingRandom(nn.Module): - """ - Positional encoding using random spatial frequencies. - """ + """Positional encoding using random spatial frequencies.""" def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + """Initializes a position embedding using random spatial frequencies.""" super().__init__() if scale is None or scale <= 0.0: scale = 1.0 - self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats))) + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats))) # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation' torch.use_deterministic_algorithms(False) @@ -275,11 +301,11 @@ class PositionEmbeddingRandom(nn.Module): def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: """Positionally encode points that are normalized to [0,1].""" - # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coords = 2 * coords - 1 coords = coords @ self.positional_encoding_gaussian_matrix coords = 2 * np.pi * coords - # outputs d_1 x ... x d_n x C shape + # Outputs d_1 x ... x d_n x C shape return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) def forward(self, size: Tuple[int, int]) -> torch.Tensor: @@ -304,7 +330,7 @@ class PositionEmbeddingRandom(nn.Module): class Block(nn.Module): - """Transformer blocks with support of window attention and residual propagation blocks""" + """Transformer blocks with support of window attention and residual propagation blocks.""" def __init__( self, @@ -351,6 +377,7 @@ class Block(nn.Module): self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: + """Executes a forward pass through the transformer block with window attention and non-overlapping windows.""" shortcut = x x = self.norm1(x) # Window partition @@ -380,6 +407,8 @@ class Attention(nn.Module): input_size: Optional[Tuple[int, int]] = None, ) -> None: """ + Initialize Attention module. + Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. @@ -391,19 +420,20 @@ class Attention(nn.Module): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = head_dim ** -0.5 + self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: - assert (input_size is not None), 'Input size must be provided if using relative positional encoding.' 
- # initialize relative positional embeddings + assert input_size is not None, "Input size must be provided if using relative positional encoding." + # Initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies the forward operation including attention, normalization, MLP, and indexing within window limits.""" B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) @@ -444,10 +474,12 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T return windows, (Hp, Wp) -def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], - hw: Tuple[int, int]) -> torch.Tensor: +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. + Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. @@ -470,8 +502,8 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[in def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ - Get relative positional embeddings according to the relative positions of - query and key sizes. + Get relative positional embeddings according to the relative positions of query and key sizes. + Args: q_size (int): size of query q. k_size (int): size of key k. @@ -487,7 +519,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, - mode='linear', + mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: @@ -510,8 +542,9 @@ def add_decomposed_rel_pos( k_size: Tuple[int, int], ) -> torch.Tensor: """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Calculate decomposed Relative Positional Embeddings from mvitv2 paper at + https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py. + Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). @@ -530,29 +563,30 @@ def add_decomposed_rel_pos( B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) - rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) - rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( - B, q_h * q_w, k_h * k_w) + B, q_h * q_w, k_h * k_w + ) return attn class PatchEmbed(nn.Module): - """ - Image to Patch Embedding. 
- """ + """Image to Patch Embedding.""" def __init__( - self, - kernel_size: Tuple[int, int] = (16, 16), - stride: Tuple[int, int] = (16, 16), - padding: Tuple[int, int] = (0, 0), - in_chans: int = 3, - embed_dim: int = 768, + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, ) -> None: """ + Initialize PatchEmbed module. + Args: kernel_size (Tuple): kernel size of the projection layer. stride (Tuple): stride of the projection layer. @@ -565,4 +599,5 @@ class PatchEmbed(nn.Module): self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes patch embedding by applying convolution and transposing resulting tensor.""" return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index 5649920..95d9bbe 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -16,8 +16,23 @@ from .encoders import ImageEncoderViT, PromptEncoder class Sam(nn.Module): + """ + Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image + embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask + decoder to predict object masks. + + Attributes: + mask_threshold (float): Threshold value for mask prediction. + image_format (str): Format of the input image, default is 'RGB'. + image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings. + pixel_mean (List[float]): Mean pixel values for image normalization. + pixel_std (List[float]): Standard deviation values for image normalization. + """ + mask_threshold: float = 0.0 - image_format: str = 'RGB' + image_format: str = "RGB" def __init__( self, @@ -25,25 +40,26 @@ class Sam(nn.Module): prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder, pixel_mean: List[float] = (123.675, 116.28, 103.53), - pixel_std: List[float] = (58.395, 57.12, 57.375) + pixel_std: List[float] = (58.395, 57.12, 57.375), ) -> None: """ - SAM predicts object masks from an image and input prompts. + Initialize the Sam class to predict object masks from an image and input prompts. Note: All forward() operations moved to SAMPredictor. Args: - image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for - efficient mask prediction. - prompt_encoder (PromptEncoder): Encodes various types of input prompts. - mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. - pixel_mean (list(float)): Mean values for normalizing pixels in the input image. - pixel_std (list(float)): Std values for normalizing pixels in the input image. + image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. + pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to + (123.675, 116.28, 103.53). 
+ pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to + (58.395, 57.12, 57.375). """ super().__init__() self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.mask_decoder = mask_decoder - self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False) - self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False) + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py index ca8de50..98f5ac0 100644 --- a/ultralytics/models/sam/modules/tiny_encoder.py +++ b/ultralytics/models/sam/modules/tiny_encoder.py @@ -21,19 +21,27 @@ from ultralytics.utils.instance import to_2tuple class Conv2d_BN(torch.nn.Sequential): + """A sequential container that performs 2D convolution followed by batch normalization.""" def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and + drop path. + """ super().__init__() - self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) bn = torch.nn.BatchNorm2d(b) torch.nn.init.constant_(bn.weight, bn_weight_init) torch.nn.init.constant_(bn.bias, 0) - self.add_module('bn', bn) + self.add_module("bn", bn) class PatchEmbed(nn.Module): + """Embeds images into patches and projects them into a specified embedding dimension.""" def __init__(self, in_chans, embed_dim, resolution, activation): + """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation + function. + """ super().__init__() img_size: Tuple[int, int] = to_2tuple(resolution) self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) @@ -48,12 +56,17 @@ class PatchEmbed(nn.Module): ) def forward(self, x): + """Runs input tensor 'x' through the PatchMerging model's sequence of operations.""" return self.seq(x) class MBConv(nn.Module): + """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.""" def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path): + """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation + function. + """ super().__init__() self.in_chans = in_chans self.hidden_chans = int(in_chans * expand_ratio) @@ -73,6 +86,7 @@ class MBConv(nn.Module): self.drop_path = nn.Identity() def forward(self, x): + """Implements the forward pass for the model architecture.""" shortcut = x x = self.conv1(x) x = self.act1(x) @@ -85,8 +99,12 @@ class MBConv(nn.Module): class PatchMerging(nn.Module): + """Merges neighboring patches in the feature map and projects to a new dimension.""" def __init__(self, input_resolution, dim, out_dim, activation): + """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other + optional parameters. 
+ """ super().__init__() self.input_resolution = input_resolution @@ -99,6 +117,7 @@ class PatchMerging(nn.Module): self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) def forward(self, x): + """Applies forward pass on the input utilizing convolution and activation layers, and returns the result.""" if x.ndim == 3: H, W = self.input_resolution B = len(x) @@ -115,6 +134,11 @@ class PatchMerging(nn.Module): class ConvLayer(nn.Module): + """ + Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv). + + Optionally applies downsample operations to the output, and provides support for gradient checkpointing. + """ def __init__( self, @@ -122,41 +146,69 @@ class ConvLayer(nn.Module): input_resolution, depth, activation, - drop_path=0., + drop_path=0.0, downsample=None, use_checkpoint=False, out_dim=None, - conv_expand_ratio=4., + conv_expand_ratio=4.0, ): + """ + Initializes the ConvLayer with the given dimensions and settings. + + Args: + dim (int): The dimensionality of the input and output. + input_resolution (Tuple[int, int]): The resolution of the input image. + depth (int): The number of MBConv layers in the block. + activation (Callable): Activation function applied after each convolution. + drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv. + downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`. + conv_expand_ratio (float): Expansion ratio for the MBConv layers. + """ super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint - # build blocks - self.blocks = nn.ModuleList([ - MBConv( - dim, - dim, - conv_expand_ratio, - activation, - drop_path[i] if isinstance(drop_path, list) else drop_path, - ) for i in range(depth)]) + # Build blocks + self.blocks = nn.ModuleList( + [ + MBConv( + dim, + dim, + conv_expand_ratio, + activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth) + ] + ) - # patch merging layer - self.downsample = None if downsample is None else downsample( - input_resolution, dim=dim, out_dim=out_dim, activation=activation) + # Patch merging layer + self.downsample = ( + None + if downsample is None + else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + ) def forward(self, x): + """Processes the input through a series of convolutional layers and returns the activated output.""" for blk in self.blocks: x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) return x if self.downsample is None else self.downsample(x) class Mlp(nn.Module): + """ + Multi-layer Perceptron (MLP) for transformer architectures. - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + This layer takes an input with in_features, applies layer normalization and two fully-connected layers. 
+ """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc.""" super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -167,6 +219,7 @@ class Mlp(nn.Module): self.drop = nn.Dropout(drop) def forward(self, x): + """Applies operations on input x and returns modified x, runs downsample if not None.""" x = self.norm(x) x = self.fc1(x) x = self.act(x) @@ -176,20 +229,41 @@ class Mlp(nn.Module): class Attention(torch.nn.Module): + """ + Multi-head attention module with support for spatial awareness, applying attention biases based on spatial + resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution + grid. + + Attributes: + ab (Tensor, optional): Cached attention biases for inference, deleted during training. + """ def __init__( - self, - dim, - key_dim, - num_heads=8, - attn_ratio=4, - resolution=(14, 14), + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), ): + """ + Initializes the Attention module. + + Args: + dim (int): The dimensionality of the input and output. + key_dim (int): The dimensionality of the keys and queries. + num_heads (int, optional): Number of attention heads. Default is 8. + attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4. + resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14). + + Raises: + AssertionError: If `resolution` is not a tuple of length 2. + """ super().__init__() - # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 self.num_heads = num_heads - self.scale = key_dim ** -0.5 + self.scale = key_dim**-0.5 self.key_dim = key_dim self.nh_kd = nh_kd = key_dim * num_heads self.d = int(attn_ratio * key_dim) @@ -212,18 +286,20 @@ class Attention(torch.nn.Module): attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False) @torch.no_grad() def train(self, mode=True): + """Sets the module in training mode and handles attribute 'ab' based on the mode.""" super().train(mode) - if mode and hasattr(self, 'ab'): + if mode and hasattr(self, "ab"): del self.ab else: self.ab = self.attention_biases[:, self.attention_bias_idxs] - def forward(self, x): # x (B,N,C) - B, N, _ = x.shape + def forward(self, x): # x + """Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values.""" + B, N, _ = x.shape # B, N, C # Normalization x = self.norm(x) @@ -237,28 +313,16 @@ class Attention(torch.nn.Module): v = v.permute(0, 2, 1, 3) self.ab = self.ab.to(self.attention_biases.device) - attn = ((q @ k.transpose(-2, -1)) * self.scale + - (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab)) + attn = (q @ k.transpose(-2, -1)) * self.scale + ( + self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + ) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) return self.proj(x) class TinyViTBlock(nn.Module): - """ - TinyViT Block. 
- - Args: - dim (int): Number of input channels. - input_resolution (tuple[int, int]): Input resolution. - num_heads (int): Number of attention heads. - window_size (int): Window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - drop (float, optional): Dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3 - activation (torch.nn): the activation function. Default: nn.GELU - """ + """TinyViT Block that applies self-attention and a local convolution to the input.""" def __init__( self, @@ -266,17 +330,35 @@ class TinyViTBlock(nn.Module): input_resolution, num_heads, window_size=7, - mlp_ratio=4., - drop=0., - drop_path=0., + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, local_conv_size=3, activation=nn.GELU, ): + """ + Initializes the TinyViTBlock. + + Args: + dim (int): The dimensionality of the input and output. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + num_heads (int): Number of attention heads. + window_size (int, optional): Window size for attention. Default is 7. + mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4. + drop (float, optional): Dropout rate. Default is 0. + drop_path (float, optional): Stochastic depth rate. Default is 0. + local_conv_size (int, optional): The kernel size of the local convolution. Default is 3. + activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU. + + Raises: + AssertionError: If `window_size` is not greater than 0. + AssertionError: If `dim` is not divisible by `num_heads`. + """ super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads - assert window_size > 0, 'window_size must be greater than 0' + assert window_size > 0, "window_size must be greater than 0" self.window_size = window_size self.mlp_ratio = mlp_ratio @@ -284,7 +366,7 @@ class TinyViTBlock(nn.Module): # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = nn.Identity() - assert dim % num_heads == 0, 'dim must be divisible by num_heads' + assert dim % num_heads == 0, "dim must be divisible by num_heads" head_dim = dim // num_heads window_resolution = (window_size, window_size) @@ -298,9 +380,12 @@ class TinyViTBlock(nn.Module): self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) def forward(self, x): + """Applies attention-based transformation or padding to input 'x' before passing it through a local + convolution. 
+ """ H, W = self.input_resolution B, L, C = x.shape - assert L == H * W, 'input feature has wrong size' + assert L == H * W, "input feature has wrong size" res_x = x if H == self.window_size and W == self.window_size: x = self.attn(x) @@ -316,11 +401,14 @@ class TinyViTBlock(nn.Module): pH, pW = H + pad_b, W + pad_r nH = pH // self.window_size nW = pW // self.window_size - # window partition - x = x.view(B, nH, self.window_size, nW, self.window_size, - C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C) + # Window partition + x = ( + x.view(B, nH, self.window_size, nW, self.window_size, C) + .transpose(2, 3) + .reshape(B * nH * nW, self.window_size * self.window_size, C) + ) x = self.attn(x) - # window reverse + # Window reverse x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C) if padding: @@ -337,29 +425,17 @@ class TinyViTBlock(nn.Module): return x + self.drop_path(self.mlp(x)) def extra_repr(self) -> str: - return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \ - f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}' + """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of + attentions heads, window size, and MLP ratio. + """ + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + ) class BasicLayer(nn.Module): - """ - A basic TinyViT layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - drop (float, optional): Dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3 - activation (torch.nn): the activation function. Default: nn.GELU - out_dim (int | optional): the output dimension of the layer. Default: None - """ + """A basic TinyViT layer for one stage in a TinyViT architecture.""" def __init__( self, @@ -368,57 +444,90 @@ class BasicLayer(nn.Module): depth, num_heads, window_size, - mlp_ratio=4., - drop=0., - drop_path=0., + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, downsample=None, use_checkpoint=False, local_conv_size=3, activation=nn.GELU, out_dim=None, ): + """ + Initializes the BasicLayer. + + Args: + dim (int): The dimensionality of the input and output. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + depth (int): Number of TinyViT blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4. + drop (float, optional): Dropout rate. Default is 0. + drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0. + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None. + use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False. 
+ local_conv_size (int, optional): Kernel size of the local convolution. Default is 3. + activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU. + out_dim (int | None, optional): The output dimension of the layer. Default is None. + + Raises: + ValueError: If `drop_path` is a list of float but its length doesn't match `depth`. + """ super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint - # build blocks - self.blocks = nn.ModuleList([ - TinyViTBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - drop=drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - local_conv_size=local_conv_size, - activation=activation, - ) for i in range(depth)]) + # Build blocks + self.blocks = nn.ModuleList( + [ + TinyViTBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth) + ] + ) - # patch merging layer - self.downsample = None if downsample is None else downsample( - input_resolution, dim=dim, out_dim=out_dim, activation=activation) + # Patch merging layer + self.downsample = ( + None + if downsample is None + else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + ) def forward(self, x): + """Performs forward propagation on the input tensor and returns a normalized tensor.""" for blk in self.blocks: x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) return x if self.downsample is None else self.downsample(x) def extra_repr(self) -> str: - return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + """Returns a string representation of the extra_repr function with the layer's parameters.""" + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" class LayerNorm2d(nn.Module): + """A PyTorch implementation of Layer Normalization in 2D.""" def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + """Initialize LayerNorm2d with the number of channels and an optional epsilon.""" super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform a forward pass, normalizing the input tensor.""" u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) @@ -426,6 +535,30 @@ class LayerNorm2d(nn.Module): class TinyViT(nn.Module): + """ + The TinyViT architecture for vision tasks. + + Attributes: + img_size (int): Input image size. + in_chans (int): Number of input channels. + num_classes (int): Number of classification classes. + embed_dims (List[int]): List of embedding dimensions for each layer. + depths (List[int]): List of depths for each layer. + num_heads (List[int]): List of number of attention heads for each layer. + window_sizes (List[int]): List of window sizes for each layer. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop_rate (float): Dropout rate for drop layers. + drop_path_rate (float): Drop path rate for stochastic depth. + use_checkpoint (bool): Use checkpointing for efficient memory usage. 
+ mbconv_expand_ratio (float): Expansion ratio for MBConv layer. + local_conv_size (int): Local convolution kernel size. + layer_lr_decay (float): Layer-wise learning rate decay. + + Note: + This implementation is generalized to accept a list of depths, attention heads, + embedding dimensions and window sizes, which allows you to create a + "stack" of TinyViT models of varying configurations. + """ def __init__( self, @@ -436,14 +569,33 @@ class TinyViT(nn.Module): depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_sizes=[7, 7, 14, 7], - mlp_ratio=4., - drop_rate=0., + mlp_ratio=4.0, + drop_rate=0.0, drop_path_rate=0.1, use_checkpoint=False, mbconv_expand_ratio=4.0, local_conv_size=3, layer_lr_decay=1.0, ): + """ + Initializes the TinyViT model. + + Args: + img_size (int, optional): The input image size. Defaults to 224. + in_chans (int, optional): Number of input channels. Defaults to 3. + num_classes (int, optional): Number of classification classes. Defaults to 1000. + embed_dims (List[int], optional): List of embedding dimensions for each layer. Defaults to [96, 192, 384, 768]. + depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2]. + num_heads (List[int], optional): List of number of attention heads for each layer. Defaults to [3, 6, 12, 24]. + window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7]. + mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4. + drop_rate (float, optional): Dropout rate. Defaults to 0. + drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1. + use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False. + mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0. + local_conv_size (int, optional): Local convolution kernel size. Defaults to 3. + layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0. 
+ """ super().__init__() self.img_size = img_size self.num_classes = num_classes @@ -453,50 +605,52 @@ class TinyViT(nn.Module): activation = nn.GELU - self.patch_embed = PatchEmbed(in_chans=in_chans, - embed_dim=embed_dims[0], - resolution=img_size, - activation=activation) + self.patch_embed = PatchEmbed( + in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation + ) patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution - # stochastic depth + # Stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - # build layers + # Build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): kwargs = dict( dim=embed_dims[i_layer], - input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), - patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))), + input_resolution=( + patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + ), # input_resolution=(patches_resolution[0] // (2 ** i_layer), # patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, - out_dim=embed_dims[min(i_layer + 1, - len(embed_dims) - 1)], + out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)], activation=activation, ) if i_layer == 0: layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs) else: - layer = BasicLayer(num_heads=num_heads[i_layer], - window_size=window_sizes[i_layer], - mlp_ratio=self.mlp_ratio, - drop=drop_rate, - local_conv_size=local_conv_size, - **kwargs) + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs, + ) self.layers.append(layer) # Classifier head self.norm_head = nn.LayerNorm(embed_dims[-1]) self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() - # init weights + # Init weights self.apply(self._init_weights) self.set_layer_lr_decay(layer_lr_decay) self.neck = nn.Sequential( @@ -518,13 +672,15 @@ class TinyViT(nn.Module): ) def set_layer_lr_decay(self, layer_lr_decay): + """Sets the learning rate decay for each layer in the TinyViT model.""" decay_rate = layer_lr_decay - # layers -> blocks (depth) + # Layers -> blocks (depth) depth = sum(self.depths) lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] def _set_lr_scale(m, scale): + """Sets the learning rate scale for each layer in the model based on the layer's depth.""" for p in m.parameters(): p.lr_scale = scale @@ -544,12 +700,14 @@ class TinyViT(nn.Module): p.param_name = k def _check_lr_scale(m): + """Checks if the learning rate scale attribute is present in module's parameters.""" for p in m.parameters(): - assert hasattr(p, 'lr_scale'), p.param_name + assert hasattr(p, "lr_scale"), p.param_name self.apply(_check_lr_scale) def _init_weights(self, m): + """Initializes weights for linear layers and layer normalization in the given module.""" if isinstance(m, nn.Linear): # NOTE: This initialization is needed only for training. 
# trunc_normal_(m.weight, std=.02) @@ -561,11 +719,12 @@ class TinyViT(nn.Module): @torch.jit.ignore def no_weight_decay_keywords(self): - return {'attention_biases'} + """Returns a dictionary of parameter names where weight decay should not be applied.""" + return {"attention_biases"} def forward_features(self, x): - # x: (N, C, H, W) - x = self.patch_embed(x) + """Runs the input through the model layers and returns the transformed output.""" + x = self.patch_embed(x) # x input is (N, C, H, W) x = self.layers[0](x) start_i = 1 @@ -573,10 +732,11 @@ class TinyViT(nn.Module): for i in range(start_i, len(self.layers)): layer = self.layers[i] x = layer(x) - B, _, C = x.size() + B, _, C = x.shape x = x.view(B, 64, 64, C) x = x.permute(0, 3, 1, 2) return self.neck(x) def forward(self, x): + """Executes a forward pass on the input tensor through the constructed model layers.""" return self.forward_features(x) diff --git a/ultralytics/models/sam/modules/transformer.py b/ultralytics/models/sam/modules/transformer.py index f925538..1ad0741 100644 --- a/ultralytics/models/sam/modules/transformer.py +++ b/ultralytics/models/sam/modules/transformer.py @@ -10,6 +10,21 @@ from ultralytics.nn.modules import MLPBlock class TwoWayTransformer(nn.Module): + """ + A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class + serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding + is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud + processing. + + Attributes: + depth (int): The number of layers in the transformer. + embedding_dim (int): The channel dimension for the input embeddings. + num_heads (int): The number of heads for multihead attention. + mlp_dim (int): The internal channel dimension for the MLP block. + layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer. + final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image. + norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries. + """ def __init__( self, @@ -21,8 +36,7 @@ class TwoWayTransformer(nn.Module): attention_downsample_rate: int = 2, ) -> None: """ - A transformer decoder that attends to an input image using - queries whose positional embedding is supplied. + A transformer decoder that attends to an input image using queries whose positional embedding is supplied. Args: depth (int): number of layers in the transformer @@ -48,7 +62,8 @@ class TwoWayTransformer(nn.Module): activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), - )) + ) + ) self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) self.norm_final_attn = nn.LayerNorm(embedding_dim) @@ -99,6 +114,23 @@ class TwoWayTransformer(nn.Module): class TwoWayAttentionBlock(nn.Module): + """ + An attention block that performs both self-attention and cross-attention in two directions: queries to keys and + keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention + of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to + sparse inputs. + + Attributes: + self_attn (Attention): The self-attention layer for the queries. + norm1 (nn.LayerNorm): Layer normalization following the first attention block. 
+ cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. + norm2 (nn.LayerNorm): Layer normalization following the second attention block. + mlp (MLPBlock): MLP block that transforms the query embeddings. + norm3 (nn.LayerNorm): Layer normalization following the MLP block. + norm4 (nn.LayerNorm): Layer normalization following the third attention block. + cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. + skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. + """ def __init__( self, @@ -171,8 +203,7 @@ class TwoWayAttentionBlock(nn.Module): class Attention(nn.Module): - """ - An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ @@ -182,24 +213,37 @@ class Attention(nn.Module): num_heads: int, downsample_rate: int = 1, ) -> None: + """ + Initializes the Attention model with the given dimensions and settings. + + Args: + embedding_dim (int): The dimensionality of the input embeddings. + num_heads (int): The number of attention heads. + downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1. + + Raises: + AssertionError: If 'num_heads' does not evenly divide the internal dimension (embedding_dim / downsample_rate). + """ super().__init__() self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads - assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.' + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) - def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + @staticmethod + def _separate_heads(x: Tensor, num_heads: int) -> Tensor: """Separate the input tensor into the specified number of attention heads.""" b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head - def _recombine_heads(self, x: Tensor) -> Tensor: + @staticmethod + def _recombine_heads(x: Tensor) -> Tensor: """Recombine the separated attention heads into a single tensor.""" b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index e8a8197..63ca632 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -1,4 +1,12 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +""" +Generate predictions using the Segment Anything Model (SAM). + +SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. +This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation +using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image +segmentation tasks. 
+""" import numpy as np import torch @@ -10,129 +18,155 @@ from ultralytics.engine.predictor import BasePredictor from ultralytics.engine.results import Results from ultralytics.utils import DEFAULT_CFG, ops from ultralytics.utils.torch_utils import select_device - -from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score, - generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks) +from .amg import ( + batch_iterator, + batched_mask_to_box, + build_all_layer_point_grids, + calculate_stability_score, + generate_crop_boxes, + is_box_near_crop_edge, + remove_small_regions, + uncrop_boxes_xyxy, + uncrop_masks, +) from .build import build_sam class Predictor(BasePredictor): + """ + Predictor class for the Segment Anything Model (SAM), extending BasePredictor. + + The class provides an interface for model inference tailored to image segmentation tasks. + With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time + mask generation. The class is capable of working with various types of prompts such as bounding boxes, + points, and low-resolution masks. + + Attributes: + cfg (dict): Configuration dictionary specifying model and task-related parameters. + overrides (dict): Dictionary containing values that override the default configuration. + _callbacks (dict): Dictionary of user-defined callback functions to augment behavior. + args (namespace): Namespace to hold command-line arguments or other operational variables. + im (torch.Tensor): Preprocessed input image tensor. + features (torch.Tensor): Extracted image features used for inference. + prompts (dict): Collection of various prompt types, such as bounding boxes and points. + segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones. + """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the Predictor with configuration, overrides, and callbacks. + + The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It + initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results. + + Args: + cfg (dict): Configuration dictionary. + overrides (dict, optional): Dictionary of values to override default configuration. + _callbacks (dict, optional): Dictionary of callback functions to customize behavior. + """ if overrides is None: overrides = {} - overrides.update(dict(task='segment', mode='predict', imgsz=1024)) + overrides.update(dict(task="segment", mode="predict", imgsz=1024)) super().__init__(cfg, overrides, _callbacks) - # SAM needs retina_masks=True, or the results would be a mess. self.args.retina_masks = True - # Args for set_image self.im = None self.features = None - # Args for set_prompts self.prompts = {} - # Args for segment everything self.segment_all = False def preprocess(self, im): - """Prepares input image before inference. + """ + Preprocess the input image for model inference. + + The method prepares the input image by applying transformations and normalization. + It supports both torch.Tensor and list of np.ndarray as input formats. Args: - im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list. + im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays. + + Returns: + (torch.Tensor): The preprocessed image tensor. 
""" if self.im is not None: return self.im not_tensor = not isinstance(im, torch.Tensor) if not_tensor: im = np.stack(self.pre_transform(im)) - im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) - im = np.ascontiguousarray(im) # contiguous + im = im[..., ::-1].transpose((0, 3, 1, 2)) + im = np.ascontiguousarray(im) im = torch.from_numpy(im) im = im.to(self.device) - im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 + im = im.half() if self.model.fp16 else im.float() if not_tensor: im = (im - self.mean) / self.std return im def pre_transform(self, im): """ - Pre-transform input image before inference. + Perform initial transformations on the input image for preprocessing. + + The method applies transformations such as resizing to prepare the image for further preprocessing. + Currently, batched inference is not supported; hence the list length should be 1. Args: - im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. + im (List[np.ndarray]): List containing images in HWC numpy array format. Returns: - (list): A list of transformed images. + (List[np.ndarray]): List of transformed images. """ - assert len(im) == 1, 'SAM model does not currently support batched inference' + assert len(im) == 1, "SAM model does not currently support batched inference" letterbox = LetterBox(self.args.imgsz, auto=False, center=False) return [letterbox(image=x) for x in im] def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): """ - Predict masks for the given input prompts, using the currently set image. + Perform image segmentation inference based on the given input cues, using the currently loaded image. This + method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and + mask decoder for real-time and promptable segmentation tasks. Args: - im (torch.Tensor): The preprocessed image, (N, C, H, W). - bboxes (np.ndarray | List, None): (N, 4), in XYXY format. - points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels. - labels (np.ndarray | List, None): (N, ), labels for the point prompts. - 1 indicates a foreground point and 0 indicates a background point. - masks (np.ndarray, None): A low resolution mask input to the model, typically - coming from a previous prediction iteration. Has form (N, H, W), where - for SAM, H=W=256. - multimask_output (bool): If true, the model will return three masks. - For ambiguous input prompts (such as a single click), this will often - produce better masks than a single prediction. If only a single - mask is needed, the model's predicted quality score can be used - to select the best mask. For non-ambiguous prompts, such as multiple - input prompts, multimask_output=False can give better results. + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background. + masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256. + multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False. 
Returns: - (np.ndarray): The output masks in CxHxW format, where C is the - number of masks, and (H, W) is the original image size. - (np.ndarray): An array of length C containing the model's - predictions for the quality of each mask. - (np.ndarray): An array of shape CxHxW, where C is the number - of masks and H=W=256. These low resolution logits can be passed to - a subsequent iteration as mask input. + (tuple): Contains the following three elements. + - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks. + - np.ndarray: An array of length C containing quality scores predicted by the model for each mask. + - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. """ - # Get prompts from self.prompts first - bboxes = self.prompts.pop('bboxes', bboxes) - points = self.prompts.pop('points', points) - masks = self.prompts.pop('masks', masks) + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + if all(i is None for i in [bboxes, points, masks]): return self.generate(im, *args, **kwargs) + return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): """ - Predict masks for the given input prompts, using the currently set image. + Internal function for image segmentation inference based on cues like bounding boxes, points, and masks. + Leverages SAM's specialized architecture for prompt-based, real-time segmentation. Args: - im (torch.Tensor): The preprocessed image, (N, C, H, W). - bboxes (np.ndarray | List, None): (N, 4), in XYXY format. - points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels. - labels (np.ndarray | List, None): (N, ), labels for the point prompts. - 1 indicates a foreground point and 0 indicates a background point. - masks (np.ndarray, None): A low resolution mask input to the model, typically - coming from a previous prediction iteration. Has form (N, H, W), where - for SAM, H=W=256. - multimask_output (bool): If true, the model will return three masks. - For ambiguous input prompts (such as a single click), this will often - produce better masks than a single prediction. If only a single - mask is needed, the model's predicted quality score can be used - to select the best mask. For non-ambiguous prompts, such as multiple - input prompts, multimask_output=False can give better results. + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background. + masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256. + multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False. Returns: - (np.ndarray): The output masks in CxHxW format, where C is the - number of masks, and (H, W) is the original image size. - (np.ndarray): An array of length C containing the model's - predictions for the quality of each mask. 
- (np.ndarray): An array of shape CxHxW, where C is the number - of masks and H=W=256. These low resolution logits can be passed to - a subsequent iteration as mask input. + (tuple): Contains the following three elements. + - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks. + - np.ndarray: An array of length C containing quality scores predicted by the model for each mask. + - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. """ features = self.model.image_encoder(im) if self.features is None else self.features @@ -158,11 +192,7 @@ class Predictor(BasePredictor): points = (points, labels) if points is not None else None # Embed prompts - sparse_embeddings, dense_embeddings = self.model.prompt_encoder( - points=points, - boxes=bboxes, - masks=masks, - ) + sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) # Predict masks pred_masks, pred_scores = self.model.mask_decoder( @@ -177,58 +207,50 @@ class Predictor(BasePredictor): # `d` could be 1 or 3 depends on `multimask_output`. return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) - def generate(self, - im, - crop_n_layers=0, - crop_overlap_ratio=512 / 1500, - crop_downscale_factor=1, - point_grids=None, - points_stride=32, - points_batch_size=64, - conf_thres=0.88, - stability_score_thresh=0.95, - stability_score_offset=0.95, - crop_nms_thresh=0.7): - """Segment the whole image. + def generate( + self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7, + ): + """ + Perform image segmentation using the Segment Anything Model (SAM). + + This function segments an entire image into constituent parts by leveraging SAM's advanced architecture + and real-time performance capabilities. It can optionally work on image crops for finer segmentation. Args: - im (torch.Tensor): The preprocessed image, (N, C, H, W). - crop_n_layers (int): If >0, mask prediction will be run again on - crops of the image. Sets the number of layers to run, where each - layer has 2**i_layer number of image crops. - crop_overlap_ratio (float): Sets the degree to which crops overlap. - In the first crop layer, crops will overlap by this fraction of - the image length. Later layers with more crops scale down this overlap. - crop_downscale_factor (int): The number of points-per-side - sampled in layer n is scaled down by crop_n_points_downscale_factor**n. - point_grids (list(np.ndarray), None): A list over explicit grids - of points used for sampling, normalized to [0,1]. The nth grid in the - list is used in the nth crop layer. Exclusive with points_per_side. - points_stride (int, None): The number of points to be sampled - along one side of the image. The total number of points is - points_per_side**2. If None, 'point_grids' must provide explicit - point sampling. - points_batch_size (int): Sets the number of points run simultaneously - by the model. Higher numbers may be faster but use more GPU memory. - conf_thres (float): A filtering threshold in [0,1], using the - model's predicted mask quality. - stability_score_thresh (float): A filtering threshold in [0,1], using - the stability of the mask under changes to the cutoff used to binarize - the model's mask predictions. 
- stability_score_offset (float): The amount to shift the cutoff when - calculated the stability score. - crop_nms_thresh (float): The box IoU cutoff used by non-maximal - suppression to filter duplicate masks between different crops. + im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W). + crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops. + Each layer produces 2**i_layer number of image crops. + crop_overlap_ratio (float): Determines the extent of overlap between crops. Scaled down in subsequent layers. + crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer. + point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1]. + Used in the nth crop layer. + points_stride (int, optional): Number of points to sample along each side of the image. + Exclusive with 'point_grids'. + points_batch_size (int): Batch size for the number of points processed simultaneously. + conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction. + stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability. + stability_score_offset (float): Offset value for calculating stability score. + crop_nms_thresh (float): IoU cutoff for Non-Maximum Suppression (NMS) to remove duplicate masks between crops. + + Returns: + (tuple): A tuple containing segmented masks, confidence scores, and bounding boxes. """ self.segment_all = True ih, iw = im.shape[2:] crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) if point_grids is None: - point_grids = build_all_layer_point_grids( - points_stride, - crop_n_layers, - crop_downscale_factor, - ) + point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] for crop_region, layer_idx in zip(crop_regions, layer_idxs): x1, y1, x2, y2 = crop_region @@ -236,19 +258,20 @@ class Predictor(BasePredictor): area = torch.tensor(w * h, device=im.device) points_scale = np.array([[w, h]]) # w, h # Crop image and interpolate to input size - crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False) + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) # (num_points, 2) points_for_image = point_grids[layer_idx] * points_scale crop_masks, crop_scores, crop_bboxes = [], [], [] - for (points, ) in batch_iterator(points_batch_size, points_for_image): + for (points,) in batch_iterator(points_batch_size, points_for_image): pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) # Interpolate predicted masks to input size - pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0] + pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] idx = pred_score > conf_thres pred_mask, pred_score = pred_mask[idx], pred_score[idx] - stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold, - stability_score_offset) + stability_score = calculate_stability_score( + pred_mask, self.model.mask_threshold, stability_score_offset + ) idx = stability_score > stability_score_thresh pred_mask, pred_score = pred_mask[idx], pred_score[idx] # Bool type is much more memory-efficient. 
@@ -291,7 +314,22 @@ class Predictor(BasePredictor): return pred_masks, pred_scores, pred_bboxes def setup_model(self, model, verbose=True): - """Set up YOLO model with specified thresholds and device.""" + """ + Initializes the Segment Anything Model (SAM) for inference. + + This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary + parameters for image normalization and other Ultralytics compatibility settings. + + Args: + model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration. + verbose (bool): If True, prints selected device information. + + Attributes: + model (torch.nn.Module): The SAM model allocated to the chosen device for inference. + device (torch.device): The device to which the model and tensors are allocated. + mean (torch.Tensor): The mean values for image normalization. + std (torch.Tensor): The standard deviation values for image normalization. + """ device = select_device(self.args.device, verbose=verbose) if model is None: model = build_sam(self.args.model) @@ -300,7 +338,8 @@ class Predictor(BasePredictor): self.device = device self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) - # TODO: Temporary settings for compatibility + + # Ultralytics compatibility settings self.model.pt = False self.model.triton = False self.model.stride = 32 @@ -308,7 +347,20 @@ class Predictor(BasePredictor): self.done_warmup = True def postprocess(self, preds, img, orig_imgs): - """Post-processes inference output predictions to create detection masks for objects.""" + """ + Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. + + The method scales masks and boxes to the original image size and applies a threshold to the mask predictions. The + SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance. + + Args: + preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes. + img (torch.Tensor): The processed input image tensor. + orig_imgs (list | torch.Tensor): The original, unprocessed images. + + Returns: + (list): List of Results objects containing detection masks, bounding boxes, and other metadata. + """ # (N, 1, H, W), (N, 1) pred_masks, pred_scores = preds[:2] pred_bboxes = preds[2] if self.segment_all else None @@ -334,21 +386,36 @@ class Predictor(BasePredictor): return results def setup_source(self, source): - """Sets up source and inference mode.""" + """ + Sets up the data source for inference. + + This method configures the data source from which images will be fetched for inference. The source could be a + directory, a video file, or other types of image data sources. + + Args: + source (str | Path): The path to the image data source for inference. + """ if source is not None: super().setup_source(source) def set_image(self, image): - """Set image in advance. - Args: + """ + Preprocesses and sets a single image for inference. - image (str | np.ndarray): image file path or np.ndarray image by cv2. + This function sets up the model if not already initialized, configures the data source to the specified image, + and preprocesses the image for feature extraction. Only one image can be set at a time. + + Args: + image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2. 
+ + Raises: + AssertionError: If more than one image is set. """ if self.model is None: model = build_sam(self.args.model) self.setup_model(model) self.setup_source(image) - assert len(self.dataset) == 1, '`set_image` only supports setting one image!' + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" for batch in self.dataset: im = self.preprocess(batch[1]) self.features = self.model.image_encoder(im) @@ -360,23 +427,27 @@ class Predictor(BasePredictor): self.prompts = prompts def reset_image(self): + """Resets the image and its features to None.""" self.im = None self.features = None @staticmethod def remove_small_regions(masks, min_area=0, nms_thresh=0.7): """ - Removes small disconnected regions and holes in masks, then reruns - box NMS to remove any new duplicates. Requires open-cv as a dependency. + Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this + function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum + Suppression (NMS) to eliminate any newly created duplicate boxes. Args: - masks (torch.Tensor): Masks, (N, H, W). - min_area (int): Minimum area threshold. - nms_thresh (float): NMS threshold. + masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is + the number of masks, H is height, and W is width. + min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0. + nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7. + Returns: - new_masks (torch.Tensor): New Masks, (N, H, W). - keep (List[int]): The indices of the new masks, which can be used to filter - the corresponding boxes. + (tuple([torch.Tensor, List[int]])): + - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W). + - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes. 
""" if len(masks) == 0: return masks @@ -386,23 +457,18 @@ class Predictor(BasePredictor): scores = [] for mask in masks: mask = mask.cpu().numpy().astype(np.uint8) - mask, changed = remove_small_regions(mask, min_area, mode='holes') + mask, changed = remove_small_regions(mask, min_area, mode="holes") unchanged = not changed - mask, changed = remove_small_regions(mask, min_area, mode='islands') + mask, changed = remove_small_regions(mask, min_area, mode="islands") unchanged = unchanged and not changed new_masks.append(torch.as_tensor(mask).unsqueeze(0)) - # Give score=0 to changed masks and score=1 to unchanged masks - # so NMS will prefer ones that didn't need postprocessing + # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing scores.append(float(unchanged)) # Recalculate boxes and remove any new duplicates new_masks = torch.cat(new_masks, dim=0) boxes = batched_mask_to_box(new_masks) - keep = torchvision.ops.nms( - boxes.float(), - torch.as_tensor(scores), - nms_thresh, - ) + keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep diff --git a/ultralytics/models/utils/loss.py b/ultralytics/models/utils/loss.py index 95406e1..ac48775 100644 --- a/ultralytics/models/utils/loss.py +++ b/ultralytics/models/utils/loss.py @@ -6,20 +6,32 @@ import torch.nn.functional as F from ultralytics.utils.loss import FocalLoss, VarifocalLoss from ultralytics.utils.metrics import bbox_iou - from .ops import HungarianMatcher class DETRLoss(nn.Module): + """ + DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the + DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary + losses. - def __init__(self, - nc=80, - loss_gain=None, - aux_loss=True, - use_fl=True, - use_vfl=False, - use_uni_match=False, - uni_match_ind=0): + Attributes: + nc (int): The number of classes. + loss_gain (dict): Coefficients for different loss components. + aux_loss (bool): Whether to compute auxiliary losses. + use_fl (bool): Use FocalLoss or not. + use_vfl (bool): Use VarifocalLoss or not. + use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch. + uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True. + matcher (HungarianMatcher): Object to compute matching cost and indices. + fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None. + vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None. + device (torch.device): Device on which tensors are stored. + """ + + def __init__( + self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0 + ): """ DETR loss function. 
@@ -34,9 +46,9 @@ class DETRLoss(nn.Module): super().__init__() if loss_gain is None: - loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1} + loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1} self.nc = nc - self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2}) + self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2}) self.loss_gain = loss_gain self.aux_loss = aux_loss self.fl = FocalLoss() if use_fl else None @@ -46,9 +58,10 @@ class DETRLoss(nn.Module): self.uni_match_ind = uni_match_ind self.device = None - def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''): - # logits: [b, query, num_classes], gt_class: list[[n, 1]] - name_class = f'loss_class{postfix}' + def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""): + """Computes the classification loss based on predictions, target values, and ground truth scores.""" + # Logits: [b, query, num_classes], gt_class: list[[n, 1]] + name_class = f"loss_class{postfix}" bs, nq = pred_scores.shape[:2] # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes) one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device) @@ -63,25 +76,28 @@ class DETRLoss(nn.Module): loss_cls = self.fl(pred_scores, one_hot.float()) loss_cls /= max(num_gts, 1) / nq else: - loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss + loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss - return {name_class: loss_cls.squeeze() * self.loss_gain['class']} + return {name_class: loss_cls.squeeze() * self.loss_gain["class"]} - def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''): - # boxes: [b, query, 4], gt_bbox: list[[n, 4]] - name_bbox = f'loss_bbox{postfix}' - name_giou = f'loss_giou{postfix}' + def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""): + """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding + boxes. 
+ """ + # Boxes: [b, query, 4], gt_bbox: list[[n, 4]] + name_bbox = f"loss_bbox{postfix}" + name_giou = f"loss_giou{postfix}" loss = {} if len(gt_bboxes) == 0: - loss[name_bbox] = torch.tensor(0., device=self.device) - loss[name_giou] = torch.tensor(0., device=self.device) + loss[name_bbox] = torch.tensor(0.0, device=self.device) + loss[name_giou] = torch.tensor(0.0, device=self.device) return loss - loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes) + loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes) loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True) loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes) - loss[name_giou] = self.loss_gain['giou'] * loss[name_giou] + loss[name_giou] = self.loss_gain["giou"] * loss[name_giou] return {k: v.squeeze() for k, v in loss.items()} # This function is for future RT-DETR Segment models @@ -115,50 +131,57 @@ class DETRLoss(nn.Module): # loss = 1 - (numerator + 1) / (denominator + 1) # return loss.sum() / num_gts - def _get_loss_aux(self, - pred_bboxes, - pred_scores, - gt_bboxes, - gt_cls, - gt_groups, - match_indices=None, - postfix='', - masks=None, - gt_mask=None): - """Get auxiliary losses""" + def _get_loss_aux( + self, + pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + match_indices=None, + postfix="", + masks=None, + gt_mask=None, + ): + """Get auxiliary losses.""" # NOTE: loss class, bbox, giou, mask, dice loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device) if match_indices is None and self.use_uni_match: - match_indices = self.matcher(pred_bboxes[self.uni_match_ind], - pred_scores[self.uni_match_ind], - gt_bboxes, - gt_cls, - gt_groups, - masks=masks[self.uni_match_ind] if masks is not None else None, - gt_mask=gt_mask) + match_indices = self.matcher( + pred_bboxes[self.uni_match_ind], + pred_scores[self.uni_match_ind], + gt_bboxes, + gt_cls, + gt_groups, + masks=masks[self.uni_match_ind] if masks is not None else None, + gt_mask=gt_mask, + ) for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)): aux_masks = masks[i] if masks is not None else None - loss_ = self._get_loss(aux_bboxes, - aux_scores, - gt_bboxes, - gt_cls, - gt_groups, - masks=aux_masks, - gt_mask=gt_mask, - postfix=postfix, - match_indices=match_indices) - loss[0] += loss_[f'loss_class{postfix}'] - loss[1] += loss_[f'loss_bbox{postfix}'] - loss[2] += loss_[f'loss_giou{postfix}'] + loss_ = self._get_loss( + aux_bboxes, + aux_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=aux_masks, + gt_mask=gt_mask, + postfix=postfix, + match_indices=match_indices, + ) + loss[0] += loss_[f"loss_class{postfix}"] + loss[1] += loss_[f"loss_bbox{postfix}"] + loss[2] += loss_[f"loss_giou{postfix}"] # if masks is not None and gt_mask is not None: # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix) # loss[3] += loss_[f'loss_mask{postfix}'] # loss[4] += loss_[f'loss_dice{postfix}'] loss = { - f'loss_class_aux{postfix}': loss[0], - f'loss_bbox_aux{postfix}': loss[1], - f'loss_giou_aux{postfix}': loss[2]} + f"loss_class_aux{postfix}": loss[0], + f"loss_bbox_aux{postfix}": loss[1], + f"loss_giou_aux{postfix}": loss[2], + } # if masks is not None and gt_mask is not None: # loss[f'loss_mask_aux{postfix}'] = loss[3] # loss[f'loss_dice_aux{postfix}'] = loss[4] @@ -166,39 +189,45 @@ class DETRLoss(nn.Module): @staticmethod def _get_index(match_indices): + """Returns batch 
indices, source indices, and destination indices from provided match indices.""" batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)]) src_idx = torch.cat([src for (src, _) in match_indices]) dst_idx = torch.cat([dst for (_, dst) in match_indices]) return (batch_idx, src_idx), dst_idx def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices): - pred_assigned = torch.cat([ - t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device) - for t, (I, _) in zip(pred_bboxes, match_indices)]) - gt_assigned = torch.cat([ - t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device) - for t, (_, J) in zip(gt_bboxes, match_indices)]) + """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices.""" + pred_assigned = torch.cat( + [ + t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device) + for t, (i, _) in zip(pred_bboxes, match_indices) + ] + ) + gt_assigned = torch.cat( + [ + t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device) + for t, (_, j) in zip(gt_bboxes, match_indices) + ] + ) return pred_assigned, gt_assigned - def _get_loss(self, - pred_bboxes, - pred_scores, - gt_bboxes, - gt_cls, - gt_groups, - masks=None, - gt_mask=None, - postfix='', - match_indices=None): - """Get losses""" + def _get_loss( + self, + pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=None, + gt_mask=None, + postfix="", + match_indices=None, + ): + """Get losses.""" if match_indices is None: - match_indices = self.matcher(pred_bboxes, - pred_scores, - gt_bboxes, - gt_cls, - gt_groups, - masks=masks, - gt_mask=gt_mask) + match_indices = self.matcher( + pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask + ) idx, gt_idx = self._get_index(match_indices) pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx] @@ -218,7 +247,7 @@ class DETRLoss(nn.Module): # loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix)) return loss - def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs): + def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs): """ Args: pred_bboxes (torch.Tensor): [l, b, query, 4] @@ -230,43 +259,62 @@ class DETRLoss(nn.Module): postfix (str): postfix of loss name. """ self.device = pred_bboxes.device - match_indices = kwargs.get('match_indices', None) - gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups'] + match_indices = kwargs.get("match_indices", None) + gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"] - total_loss = self._get_loss(pred_bboxes[-1], - pred_scores[-1], - gt_bboxes, - gt_cls, - gt_groups, - postfix=postfix, - match_indices=match_indices) + total_loss = self._get_loss( + pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices + ) if self.aux_loss: total_loss.update( - self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, - postfix)) + self._get_loss_aux( + pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix + ) + ) return total_loss class RTDETRDetectionLoss(DETRLoss): + """ + Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss. + + This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as + an additional denoising training loss when provided with denoising metadata. 
+ """ def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None): + """ + Forward pass to compute the detection loss. + + Args: + preds (tuple): Predicted bounding boxes and scores. + batch (dict): Batch data containing ground truth information. + dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None. + dn_scores (torch.Tensor, optional): Denoising scores. Default is None. + dn_meta (dict, optional): Metadata for denoising. Default is None. + + Returns: + (dict): Dictionary containing the total loss and, if applicable, the denoising loss. + """ pred_bboxes, pred_scores = preds total_loss = super().forward(pred_bboxes, pred_scores, batch) + # Check for denoising metadata to compute denoising training loss if dn_meta is not None: - dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group'] - assert len(batch['gt_groups']) == len(dn_pos_idx) + dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"] + assert len(batch["gt_groups"]) == len(dn_pos_idx) - # Denoising match indices - match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups']) + # Get the match indices for denoising + match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"]) - # Compute denoising training loss - dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices) + # Compute the denoising training loss + dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices) total_loss.update(dn_loss) else: - total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()}) + # If no denoising metadata is provided, set denoising loss to zero + total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()}) return total_loss @@ -276,12 +324,12 @@ class RTDETRDetectionLoss(DETRLoss): Get the match indices for denoising. Args: - dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising. - dn_num_group (int): The number of groups of denoising. - gt_groups (List(int)): a list of batch size length includes the number of gts of each image. + dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising. + dn_num_group (int): Number of denoising groups. + gt_groups (List[int]): List of integers representing the number of ground truths for each image. Returns: - dn_match_indices (List(tuple)): Matched indices. + (List[tuple]): List of tuples containing matched indices for denoising. """ dn_match_indices = [] idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) @@ -289,8 +337,8 @@ class RTDETRDetectionLoss(DETRLoss): if num_gt > 0: gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i] gt_idx = gt_idx.repeat(dn_num_group) - assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, ' - f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.' + assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, " + f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively." 
dn_match_indices.append((dn_pos_idx[i], gt_idx)) else: dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long))) diff --git a/ultralytics/models/utils/ops.py b/ultralytics/models/utils/ops.py index eb1ebfb..4f66fee 100644 --- a/ultralytics/models/utils/ops.py +++ b/ultralytics/models/utils/ops.py @@ -11,8 +11,8 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh class HungarianMatcher(nn.Module): """ - A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in - an end-to-end fashion. + A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an + end-to-end fashion. HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost function that considers classification scores, bounding box coordinates, and optionally, mask predictions. @@ -32,9 +32,12 @@ class HungarianMatcher(nn.Module): """ def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0): + """Initializes HungarianMatcher with cost coefficients, Focal Loss, mask prediction, sample points, and alpha + gamma factors. + """ super().__init__() if cost_gain is None: - cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1} + cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1} self.cost_gain = cost_gain self.use_fl = use_fl self.with_mask = with_mask @@ -45,8 +48,8 @@ class HungarianMatcher(nn.Module): def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): """ Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth - (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching - between predictions and ground truth based on these costs. + (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching between + predictions and ground truth based on these costs. Args: pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4]. 
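Before the cost-computation hunk below, a small sketch of calling the matcher directly, assuming HungarianMatcher is importable from ultralytics.models.utils.ops; shapes follow the forward() docstring above, with all ground-truth boxes concatenated and gt_groups giving the per-image counts:

import torch
from ultralytics.models.utils.ops import HungarianMatcher

matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})

bs, nq, nc = 2, 100, 80
pred_bboxes = torch.rand(bs, nq, 4)            # normalized xywh predictions
pred_scores = torch.randn(bs, nq, nc)          # raw class logits
gt_bboxes = torch.rand(3, 4)                   # all ground-truth boxes in the batch, concatenated
gt_cls = torch.tensor([1, 0, 42])              # their class ids
gt_groups = [2, 1]                             # ground truths per image

indices = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups)
# One (query_idx, gt_idx) tensor pair per image; gt_idx indexes the concatenated ground-truth list
for query_idx, gt_idx in indices:
    print(query_idx.tolist(), gt_idx.tolist())

The NaN/Inf guard added in the next hunk protects exactly this cost matrix before linear_sum_assignment runs.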
@@ -83,7 +86,7 @@ class HungarianMatcher(nn.Module): # Compute the classification cost pred_scores = pred_scores[:, gt_cls] if self.use_fl: - neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log()) + neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log()) pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log()) cost_class = pos_cost_class - neg_cost_class else: @@ -96,19 +99,25 @@ class HungarianMatcher(nn.Module): cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1) # Final cost matrix - C = self.cost_gain['class'] * cost_class + \ - self.cost_gain['bbox'] * cost_bbox + \ - self.cost_gain['giou'] * cost_giou + C = ( + self.cost_gain["class"] * cost_class + + self.cost_gain["bbox"] * cost_bbox + + self.cost_gain["giou"] * cost_giou + ) # Compute the mask cost and dice cost if self.with_mask: C += self._cost_mask(bs, gt_groups, masks, gt_mask) + # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries) + C[C.isnan() | C.isinf()] = 0.0 + C = C.view(bs, nq, -1).cpu() indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))] - gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) - # (idx for queries, idx for gt) - return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k]) - for k, (i, j) in enumerate(indices)] + gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt) + return [ + (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k]) + for k, (i, j) in enumerate(indices) + ] # This function is for future RT-DETR Segment models # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None): @@ -141,18 +150,13 @@ class HungarianMatcher(nn.Module): # return C -def get_cdn_group(batch, - num_classes, - num_queries, - class_embed, - num_dn=100, - cls_noise_ratio=0.5, - box_noise_scale=1.0, - training=False): +def get_cdn_group( + batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False +): """ - Get contrastive denoising training group. This function creates a contrastive denoising training group with - positive and negative samples from the ground truths (gt). It applies noise to the class labels and bounding - box coordinates, and returns the modified labels, bounding boxes, attention mask and meta information. + Get contrastive denoising training group. This function creates a contrastive denoising training group with positive + and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates, + and returns the modified labels, bounding boxes, attention mask and meta information. 
Args: batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes' @@ -174,7 +178,7 @@ def get_cdn_group(batch, if (not training) or num_dn <= 0: return None, None, None, None - gt_groups = batch['gt_groups'] + gt_groups = batch["gt_groups"] total_num = sum(gt_groups) max_nums = max(gt_groups) if max_nums == 0: @@ -182,26 +186,26 @@ def get_cdn_group(batch, num_group = num_dn // max_nums num_group = 1 if num_group == 0 else num_group - # pad gt to max_num of a batch + # Pad gt to max_num of a batch bs = len(gt_groups) - gt_cls = batch['cls'] # (bs*num, ) - gt_bbox = batch['bboxes'] # bs*num, 4 - b_idx = batch['batch_idx'] + gt_cls = batch["cls"] # (bs*num, ) + gt_bbox = batch["bboxes"] # bs*num, 4 + b_idx = batch["batch_idx"] - # each group has positive and negative queries. + # Each group has positive and negative queries. dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, ) dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4 dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, ) - # positive and negative mask + # Positive and negative mask # (bs*num*num_group, ), the second total_num*num_group part as negative samples neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num if cls_noise_ratio > 0: - # half of bbox prob + # Half of bbox prob mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5) idx = torch.nonzero(mask).squeeze(-1) - # randomly put a new one here + # Randomly put a new one here new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device) dn_cls[idx] = new_label @@ -217,10 +221,9 @@ def get_cdn_group(batch, known_bbox += rand_part * diff known_bbox.clip_(min=0.0, max=1.0) dn_bbox = xyxy2xywh(known_bbox) - dn_bbox = inverse_sigmoid(dn_bbox) + dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid - # total denoising queries - num_dn = int(max_nums * 2 * num_group) + num_dn = int(max_nums * 2 * num_group) # total denoising queries # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)]) dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256 padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device) @@ -235,27 +238,26 @@ def get_cdn_group(batch, tgt_size = num_dn + num_queries attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool) - # match query cannot see the reconstruct + # Match query cannot see the reconstruct attn_mask[num_dn:, :num_dn] = True - # reconstruct cannot see each other + # Reconstruct cannot see each other for i in range(num_group): if i == 0: - attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True if i == num_group - 1: - attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True else: - attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True - attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True dn_meta = { - 'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)], - 'dn_num_group': num_group, - 'dn_num_split': 
[num_dn, num_queries]} + "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)], + "dn_num_group": num_group, + "dn_num_split": [num_dn, num_queries], + } - return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to( - class_embed.device), dn_meta - - -def inverse_sigmoid(x, eps=1e-6): - """Inverse sigmoid function.""" - x = x.clip(min=0., max=1.) - return torch.log(x / (1 - x + eps) + eps) + return ( + padding_cls.to(class_embed.device), + padding_bbox.to(class_embed.device), + attn_mask.to(class_embed.device), + dn_meta, + ) diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py index c66e376..7b1a597 100644 --- a/ultralytics/models/yolo/__init__.py +++ b/ultralytics/models/yolo/__init__.py @@ -1,7 +1,7 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from ultralytics.models.yolo import classify, detect, pose, segment +from ultralytics.models.yolo import classify, detect, obb, pose, segment -from .model import YOLO +from .model import YOLO, YOLOWorld -__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO' +__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld" diff --git a/ultralytics/models/yolo/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/yolo/__pycache__/__init__.cpython-312.pyc index 6546cfc..fe407d2 100644 Binary files a/ultralytics/models/yolo/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/yolo/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/yolo/__pycache__/__init__.cpython-39.pyc index 58c2936..f8d9866 100644 Binary files a/ultralytics/models/yolo/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/yolo/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/__pycache__/model.cpython-312.pyc b/ultralytics/models/yolo/__pycache__/model.cpython-312.pyc index afe0c46..f371ca9 100644 Binary files a/ultralytics/models/yolo/__pycache__/model.cpython-312.pyc and b/ultralytics/models/yolo/__pycache__/model.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/__pycache__/model.cpython-39.pyc b/ultralytics/models/yolo/__pycache__/model.cpython-39.pyc index ffd6446..f489232 100644 Binary files a/ultralytics/models/yolo/__pycache__/model.cpython-39.pyc and b/ultralytics/models/yolo/__pycache__/model.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/classify/__init__.py b/ultralytics/models/yolo/classify/__init__.py index 33d72e6..ca92f89 100644 --- a/ultralytics/models/yolo/classify/__init__.py +++ b/ultralytics/models/yolo/classify/__init__.py @@ -4,4 +4,4 @@ from ultralytics.models.yolo.classify.predict import ClassificationPredictor from ultralytics.models.yolo.classify.train import ClassificationTrainer from ultralytics.models.yolo.classify.val import ClassificationValidator -__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator' +__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator" diff --git a/ultralytics/models/yolo/classify/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/yolo/classify/__pycache__/__init__.cpython-312.pyc index 74695b1..d05f0b5 100644 Binary files a/ultralytics/models/yolo/classify/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/yolo/classify/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/classify/__pycache__/__init__.cpython-39.pyc 
b/ultralytics/models/yolo/classify/__pycache__/__init__.cpython-39.pyc index 0de9dbe..419dfac 100644 Binary files a/ultralytics/models/yolo/classify/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/yolo/classify/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/classify/__pycache__/predict.cpython-312.pyc b/ultralytics/models/yolo/classify/__pycache__/predict.cpython-312.pyc index d0af994..8fbe323 100644 Binary files a/ultralytics/models/yolo/classify/__pycache__/predict.cpython-312.pyc and b/ultralytics/models/yolo/classify/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/classify/__pycache__/predict.cpython-39.pyc b/ultralytics/models/yolo/classify/__pycache__/predict.cpython-39.pyc index d1deddc..2d90a86 100644 Binary files a/ultralytics/models/yolo/classify/__pycache__/predict.cpython-39.pyc and b/ultralytics/models/yolo/classify/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/classify/__pycache__/train.cpython-312.pyc b/ultralytics/models/yolo/classify/__pycache__/train.cpython-312.pyc index 44eda6f..07671e6 100644 Binary files a/ultralytics/models/yolo/classify/__pycache__/train.cpython-312.pyc and b/ultralytics/models/yolo/classify/__pycache__/train.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/classify/__pycache__/train.cpython-39.pyc b/ultralytics/models/yolo/classify/__pycache__/train.cpython-39.pyc index 2c526ca..8078fc1 100644 Binary files a/ultralytics/models/yolo/classify/__pycache__/train.cpython-39.pyc and b/ultralytics/models/yolo/classify/__pycache__/train.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/classify/__pycache__/val.cpython-312.pyc b/ultralytics/models/yolo/classify/__pycache__/val.cpython-312.pyc index 615b694..b8e1706 100644 Binary files a/ultralytics/models/yolo/classify/__pycache__/val.cpython-312.pyc and b/ultralytics/models/yolo/classify/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/classify/__pycache__/val.cpython-39.pyc b/ultralytics/models/yolo/classify/__pycache__/val.cpython-39.pyc index c45dd1c..63b50a9 100644 Binary files a/ultralytics/models/yolo/classify/__pycache__/val.cpython-39.pyc and b/ultralytics/models/yolo/classify/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/classify/predict.py b/ultralytics/models/yolo/classify/predict.py index a22616e..853ef04 100644 --- a/ultralytics/models/yolo/classify/predict.py +++ b/ultralytics/models/yolo/classify/predict.py @@ -1,6 +1,8 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +import cv2 import torch +from PIL import Image from ultralytics.engine.predictor import BasePredictor from ultralytics.engine.results import Results @@ -26,13 +28,23 @@ class ClassificationPredictor(BasePredictor): """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes ClassificationPredictor setting the task to 'classify'.""" super().__init__(cfg, overrides, _callbacks) - self.args.task = 'classify' + self.args.task = "classify" + self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor" def preprocess(self, img): """Converts input image to model-compatible data type.""" if not isinstance(img, torch.Tensor): - img = torch.stack([self.transforms(im) for im in img], dim=0) + is_legacy_transform = any( + self._legacy_transform_name in str(transform) for transform in self.transforms.transforms + ) + if is_legacy_transform: # to handle legacy transforms + img = torch.stack([self.transforms(im) for im in 
img], dim=0) + else: + img = torch.stack( + [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0 + ) img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index 0829f05..42c6554 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -33,23 +33,23 @@ class ClassificationTrainer(BaseTrainer): """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.""" if overrides is None: overrides = {} - overrides['task'] = 'classify' - if overrides.get('imgsz') is None: - overrides['imgsz'] = 224 + overrides["task"] = "classify" + if overrides.get("imgsz") is None: + overrides["imgsz"] = 224 super().__init__(cfg, overrides, _callbacks) def set_model_attributes(self): """Set the YOLO model's class names from the loaded dataset.""" - self.model.names = self.data['names'] + self.model.names = self.data["names"] def get_model(self, cfg=None, weights=None, verbose=True): """Returns a modified PyTorch model configured for training YOLO.""" - model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) + model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) if weights: model.load(weights) for m in model.modules(): - if not self.args.pretrained and hasattr(m, 'reset_parameters'): + if not self.args.pretrained and hasattr(m, "reset_parameters"): m.reset_parameters() if isinstance(m, torch.nn.Dropout) and self.args.dropout: m.p = self.args.dropout # set dropout @@ -64,31 +64,32 @@ class ClassificationTrainer(BaseTrainer): model, ckpt = str(self.model), None # Load a YOLO model locally, from torchvision, or from Ultralytics assets - if model.endswith('.pt'): - self.model, ckpt = attempt_load_one_weight(model, device='cpu') + if model.endswith(".pt"): + self.model, ckpt = attempt_load_one_weight(model, device="cpu") for p in self.model.parameters(): p.requires_grad = True # for training - elif model.split('.')[-1] in ('yaml', 'yml'): + elif model.split(".")[-1] in ("yaml", "yml"): self.model = self.get_model(cfg=model) elif model in torchvision.models.__dict__: - self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None) + self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None) else: - FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') - ClassificationModel.reshape_outputs(self.model, self.data['nc']) + raise FileNotFoundError(f"ERROR: model={model} not found locally or online. 
Please check model name.") + ClassificationModel.reshape_outputs(self.model, self.data["nc"]) return ckpt - def build_dataset(self, img_path, mode='train', batch=None): - return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode) + def build_dataset(self, img_path, mode="train", batch=None): + """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.).""" + return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode) - def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): """Returns PyTorch DataLoader with transforms to preprocess images for inference.""" with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP dataset = self.build_dataset(dataset_path, mode) loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank) # Attach inference transforms - if mode != 'train': + if mode != "train": if is_parallel(self.model): self.model.module.transforms = loader.dataset.torch_transforms else: @@ -97,26 +98,32 @@ class ClassificationTrainer(BaseTrainer): def preprocess_batch(self, batch): """Preprocesses a batch of images and classes.""" - batch['img'] = batch['img'].to(self.device) - batch['cls'] = batch['cls'].to(self.device) + batch["img"] = batch["img"].to(self.device) + batch["cls"] = batch["cls"].to(self.device) return batch def progress_string(self): """Returns a formatted string showing training progress.""" - return ('\n' + '%11s' * (4 + len(self.loss_names))) % \ - ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) def get_validator(self): """Returns an instance of ClassificationValidator for validation.""" - self.loss_names = ['loss'] - return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir) + self.loss_names = ["loss"] + return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks) - def label_loss_items(self, loss_items=None, prefix='train'): + def label_loss_items(self, loss_items=None, prefix="train"): """ - Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for - segmentation & detection + Returns a loss dict with labelled training loss items tensor. 
+ + Not needed for classification but necessary for segmentation & detection """ - keys = [f'{prefix}/{x}' for x in self.loss_names] + keys = [f"{prefix}/{x}" for x in self.loss_names] if loss_items is None: return keys loss_items = [round(float(loss_items), 5)] @@ -132,19 +139,20 @@ class ClassificationTrainer(BaseTrainer): if f.exists(): strip_optimizer(f) # strip optimizers if f is self.best: - LOGGER.info(f'\nValidating {f}...') + LOGGER.info(f"\nValidating {f}...") self.validator.args.data = self.args.data self.validator.args.plots = self.args.plots self.metrics = self.validator(model=f) - self.metrics.pop('fitness', None) - self.run_callbacks('on_fit_epoch_end') + self.metrics.pop("fitness", None) + self.run_callbacks("on_fit_epoch_end") LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") def plot_training_samples(self, batch, ni): """Plots training samples with their annotations.""" plot_images( - images=batch['img'], - batch_idx=torch.arange(len(batch['img'])), - cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models - fname=self.save_dir / f'train_batch{ni}.jpg', - on_plot=self.on_plot) + images=batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py index 0748e27..de3cff2 100644 --- a/ultralytics/models/yolo/classify/val.py +++ b/ultralytics/models/yolo/classify/val.py @@ -31,43 +31,42 @@ class ClassificationValidator(BaseValidator): super().__init__(dataloader, save_dir, pbar, args, _callbacks) self.targets = None self.pred = None - self.args.task = 'classify' + self.args.task = "classify" self.metrics = ClassifyMetrics() def get_desc(self): """Returns a formatted string summarizing classification metrics.""" - return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc') + return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc") def init_metrics(self, model): """Initialize confusion matrix, class names, and top-1 and top-5 accuracy.""" self.names = model.names self.nc = len(model.names) - self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify') + self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify") self.pred = [] self.targets = [] def preprocess(self, batch): """Preprocesses input batch and returns it.""" - batch['img'] = batch['img'].to(self.device, non_blocking=True) - batch['img'] = batch['img'].half() if self.args.half else batch['img'].float() - batch['cls'] = batch['cls'].to(self.device) + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = batch["img"].half() if self.args.half else batch["img"].float() + batch["cls"] = batch["cls"].to(self.device) return batch def update_metrics(self, preds, batch): """Updates running metrics with model predictions and batch targets.""" n5 = min(len(self.names), 5) self.pred.append(preds.argsort(1, descending=True)[:, :n5]) - self.targets.append(batch['cls']) + self.targets.append(batch["cls"]) def finalize_metrics(self, *args, **kwargs): """Finalizes metrics of the model such as confusion_matrix and speed.""" self.confusion_matrix.process_cls_preds(self.pred, self.targets) if self.args.plots: for normalize in True, False: - self.confusion_matrix.plot(save_dir=self.save_dir, - names=self.names.values(), - 
normalize=normalize, - on_plot=self.on_plot) + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) self.metrics.speed = self.speed self.metrics.confusion_matrix = self.confusion_matrix self.metrics.save_dir = self.save_dir @@ -78,6 +77,7 @@ class ClassificationValidator(BaseValidator): return self.metrics.results_dict def build_dataset(self, img_path): + """Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters.""" return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split) def get_dataloader(self, dataset_path, batch_size): @@ -87,24 +87,27 @@ class ClassificationValidator(BaseValidator): def print_results(self): """Prints evaluation metrics for YOLO object detection model.""" - pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format - LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5)) + pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5)) def plot_val_samples(self, batch, ni): """Plot validation image samples.""" plot_images( - images=batch['img'], - batch_idx=torch.arange(len(batch['img'])), - cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models - fname=self.save_dir / f'val_batch{ni}_labels.jpg', + images=batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models + fname=self.save_dir / f"val_batch{ni}_labels.jpg", names=self.names, - on_plot=self.on_plot) + on_plot=self.on_plot, + ) def plot_predictions(self, batch, preds, ni): """Plots predicted bounding boxes on input images and saves the result.""" - plot_images(batch['img'], - batch_idx=torch.arange(len(batch['img'])), - cls=torch.argmax(preds, dim=1), - fname=self.save_dir / f'val_batch{ni}_pred.jpg', - names=self.names, - on_plot=self.on_plot) # pred + plot_images( + batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=torch.argmax(preds, dim=1), + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py index 20fc0c4..5f3e62c 100644 --- a/ultralytics/models/yolo/detect/__init__.py +++ b/ultralytics/models/yolo/detect/__init__.py @@ -4,4 +4,4 @@ from .predict import DetectionPredictor from .train import DetectionTrainer from .val import DetectionValidator -__all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator' +__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator" diff --git a/ultralytics/models/yolo/detect/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/yolo/detect/__pycache__/__init__.cpython-312.pyc index 87cc14b..6374723 100644 Binary files a/ultralytics/models/yolo/detect/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/yolo/detect/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/detect/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/yolo/detect/__pycache__/__init__.cpython-39.pyc index 773c37d..aa695ec 100644 Binary files a/ultralytics/models/yolo/detect/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/yolo/detect/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/detect/__pycache__/predict.cpython-312.pyc 
b/ultralytics/models/yolo/detect/__pycache__/predict.cpython-312.pyc index e2fea2f..7ea333e 100644 Binary files a/ultralytics/models/yolo/detect/__pycache__/predict.cpython-312.pyc and b/ultralytics/models/yolo/detect/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/detect/__pycache__/predict.cpython-39.pyc b/ultralytics/models/yolo/detect/__pycache__/predict.cpython-39.pyc index fe70b67..25c3247 100644 Binary files a/ultralytics/models/yolo/detect/__pycache__/predict.cpython-39.pyc and b/ultralytics/models/yolo/detect/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/detect/__pycache__/train.cpython-312.pyc b/ultralytics/models/yolo/detect/__pycache__/train.cpython-312.pyc index 7b5b461..786811e 100644 Binary files a/ultralytics/models/yolo/detect/__pycache__/train.cpython-312.pyc and b/ultralytics/models/yolo/detect/__pycache__/train.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/detect/__pycache__/train.cpython-39.pyc b/ultralytics/models/yolo/detect/__pycache__/train.cpython-39.pyc index 226df4e..c3a0b6f 100644 Binary files a/ultralytics/models/yolo/detect/__pycache__/train.cpython-39.pyc and b/ultralytics/models/yolo/detect/__pycache__/train.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/detect/__pycache__/val.cpython-312.pyc b/ultralytics/models/yolo/detect/__pycache__/val.cpython-312.pyc index 53190ee..a6aa07b 100644 Binary files a/ultralytics/models/yolo/detect/__pycache__/val.cpython-312.pyc and b/ultralytics/models/yolo/detect/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/detect/__pycache__/val.cpython-39.pyc b/ultralytics/models/yolo/detect/__pycache__/val.cpython-39.pyc index f6e55e9..fa187f3 100644 Binary files a/ultralytics/models/yolo/detect/__pycache__/val.cpython-39.pyc and b/ultralytics/models/yolo/detect/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/detect/predict.py b/ultralytics/models/yolo/detect/predict.py index 28cbd7c..3a0c628 100644 --- a/ultralytics/models/yolo/detect/predict.py +++ b/ultralytics/models/yolo/detect/predict.py @@ -22,12 +22,14 @@ class DetectionPredictor(BasePredictor): def postprocess(self, preds, img, orig_imgs): """Post-processes predictions and returns a list of Results objects.""" - preds = ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - classes=self.args.classes) + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py index 56d9243..3326512 100644 --- a/ultralytics/models/yolo/detect/train.py +++ b/ultralytics/models/yolo/detect/train.py @@ -1,8 +1,11 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +import math +import random from copy import copy import numpy as np +import torch.nn as nn from ultralytics.data import build_dataloader, build_yolo_dataset from ultralytics.engine.trainer import BaseTrainer @@ -27,7 +30,7 @@ class DetectionTrainer(BaseTrainer): ``` """ - def build_dataset(self, img_path, mode='train', batch=None): + def build_dataset(self, img_path, mode="train", batch=None): """ Build YOLO Dataset. 
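As a rough illustration of the post-processing pattern in DetectionPredictor above, the snippet below runs the same ops.non_max_suppression call on a dummy prediction tensor; the (1, 84, 8400) shape assumes an 80-class YOLOv8-style head and is only an example:

import torch
from ultralytics.utils import ops

raw = torch.randn(1, 84, 8400)                 # batch, 4 box coords + 80 class scores, candidate boxes
detections = ops.non_max_suppression(
    raw,
    conf_thres=0.25,
    iou_thres=0.7,
    agnostic=False,
    max_det=300,
    classes=None,                              # or e.g. [0, 2] to keep only those class ids
)
# One (n, 6) tensor per image: x1, y1, x2, y2, confidence, class
print(detections[0].shape)

In the predictor itself the surviving boxes are then rescaled to the original image size and wrapped into Results objects, as shown in the hunk above.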
@@ -37,53 +40,70 @@ class DetectionTrainer(BaseTrainer): batch (int, optional): Size of batches, this is for `rect`. Defaults to None. """ gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) - return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs) + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) - def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): """Construct and return dataloader.""" - assert mode in ['train', 'val'] + assert mode in ["train", "val"] with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP dataset = self.build_dataset(dataset_path, mode, batch_size) - shuffle = mode == 'train' - if getattr(dataset, 'rect', False) and shuffle: + shuffle = mode == "train" + if getattr(dataset, "rect", False) and shuffle: LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") shuffle = False - workers = self.args.workers if mode == 'train' else self.args.workers * 2 + workers = self.args.workers if mode == "train" else self.args.workers * 2 return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader def preprocess_batch(self, batch): """Preprocesses a batch of images by scaling and converting to float.""" - batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255 + batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 + if self.args.multi_scale: + imgs = batch["img"] + sz = ( + random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) + // self.stride + * self.stride + ) # size + sf = sz / max(imgs.shape[2:]) # scale factor + if sf != 1: + ns = [ + math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:] + ] # new shape (stretched to gs-multiple) + imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) + batch["img"] = imgs return batch def set_model_attributes(self): - """nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps).""" + """Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps).""" # self.args.box *= 3 / nl # scale to layers # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers - self.model.nc = self.data['nc'] # attach number of classes to model - self.model.names = self.data['names'] # attach class names to model + self.model.nc = self.data["nc"] # attach number of classes to model + self.model.names = self.data["names"] # attach class names to model self.model.args = self.args # attach hyperparameters to model # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc def get_model(self, cfg=None, weights=None, verbose=True): """Return a YOLO detection model.""" - model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) + model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) if weights: model.load(weights) return model def get_validator(self): """Returns a DetectionValidator for YOLO model validation.""" - self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' - return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) + 
self.loss_names = "box_loss", "cls_loss", "dfl_loss" + return yolo.detect.DetectionValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) - def label_loss_items(self, loss_items=None, prefix='train'): + def label_loss_items(self, loss_items=None, prefix="train"): """ - Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for - segmentation & detection + Returns a loss dict with labelled training loss items tensor. + + Not needed for classification but necessary for segmentation & detection """ - keys = [f'{prefix}/{x}' for x in self.loss_names] + keys = [f"{prefix}/{x}" for x in self.loss_names] if loss_items is not None: loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats return dict(zip(keys, loss_items)) @@ -92,18 +112,25 @@ class DetectionTrainer(BaseTrainer): def progress_string(self): """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size.""" - return ('\n' + '%11s' * - (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) def plot_training_samples(self, batch, ni): """Plots training samples with their annotations.""" - plot_images(images=batch['img'], - batch_idx=batch['batch_idx'], - cls=batch['cls'].squeeze(-1), - bboxes=batch['bboxes'], - paths=batch['im_file'], - fname=self.save_dir / f'train_batch{ni}.jpg', - on_plot=self.on_plot) + plot_images( + images=batch["img"], + batch_idx=batch["batch_idx"], + cls=batch["cls"].squeeze(-1), + bboxes=batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) def plot_metrics(self): """Plots metrics from a CSV file.""" @@ -111,6 +138,6 @@ class DetectionTrainer(BaseTrainer): def plot_training_labels(self): """Create a labeled training plot of the YOLO model.""" - boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0) - cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0) - plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot) + boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0) + cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0) + plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot) diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py index 6fca481..5550ec3 100644 --- a/ultralytics/models/yolo/detect/val.py +++ b/ultralytics/models/yolo/detect/val.py @@ -12,7 +12,6 @@ from ultralytics.utils import LOGGER, ops from ultralytics.utils.checks import check_requirements from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou from ultralytics.utils.plotting import output_to_target, plot_images -from ultralytics.utils.torch_utils import de_parallel class DetectionValidator(BaseValidator): @@ -35,35 +34,40 @@ class DetectionValidator(BaseValidator): self.nt_per_class = None self.is_coco = False self.class_map = None - self.args.task = 'detect' + self.args.task = "detect" self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) - self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 + self.iouv = torch.linspace(0.5, 0.95, 
10) # IoU vector for mAP@0.5:0.95 self.niou = self.iouv.numel() self.lb = [] # for autolabelling def preprocess(self, batch): """Preprocesses batch of images for YOLO training.""" - batch['img'] = batch['img'].to(self.device, non_blocking=True) - batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255 - for k in ['batch_idx', 'cls', 'bboxes']: + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 + for k in ["batch_idx", "cls", "bboxes"]: batch[k] = batch[k].to(self.device) if self.args.save_hybrid: - height, width = batch['img'].shape[2:] - nb = len(batch['img']) - bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device) - self.lb = [ - torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1) - for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling + height, width = batch["img"].shape[2:] + nb = len(batch["img"]) + bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device) + self.lb = ( + [ + torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1) + for i in range(nb) + ] + if self.args.save_hybrid + else [] + ) # for autolabelling return batch def init_metrics(self, model): """Initialize evaluation metrics for YOLO.""" - val = self.data.get(self.args.split, '') # validation path - self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO + val = self.data.get(self.args.split, "") # validation path + self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000)) - self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO + self.args.save_json |= self.is_coco # run on final val if training COCO self.names = model.names self.nc = len(model.names) self.metrics.names = self.names @@ -71,67 +75,88 @@ class DetectionValidator(BaseValidator): self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf) self.seen = 0 self.jdict = [] - self.stats = [] + self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[]) def get_desc(self): """Return a formatted string summarizing class metrics of YOLO model.""" - return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)') + return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)") def postprocess(self, preds): """Apply Non-maximum suppression to prediction outputs.""" - return ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - labels=self.lb, - multi_label=True, - agnostic=self.args.single_cls, - max_det=self.args.max_det) + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + ) + + def _prepare_batch(self, si, batch): + """Prepares a batch of images and annotations for validation.""" + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes + ops.scale_boxes(imgsz, 
bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels + return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad) + + def _prepare_pred(self, pred, pbatch): + """Prepares a batch of images and annotations for validation.""" + predn = pred.clone() + ops.scale_boxes( + pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"] + ) # native-space pred + return predn def update_metrics(self, preds, batch): """Metrics.""" for si, pred in enumerate(preds): - idx = batch['batch_idx'] == si - cls = batch['cls'][idx] - bbox = batch['bboxes'][idx] - nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions - shape = batch['ori_shape'][si] - correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init self.seen += 1 - + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls if npr == 0: if nl: - self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1))) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) if self.args.plots: - self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1)) + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) continue # Predictions if self.args.single_cls: pred[:, 5] = 0 - predn = pred.clone() - ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape, - ratio_pad=batch['ratio_pad'][si]) # native-space pred + predn = self._prepare_pred(pred, pbatch) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] # Evaluate if nl: - height, width = batch['img'].shape[2:] - tbox = ops.xywh2xyxy(bbox) * torch.tensor( - (width, height, width, height), device=self.device) # target boxes - ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape, - ratio_pad=batch['ratio_pad'][si]) # native-space labels - labelsn = torch.cat((cls, tbox), 1) # native-space labels - correct_bboxes = self._process_batch(predn, labelsn) - # TODO: maybe remove these `self.` arguments as they already are member variable + stat["tp"] = self._process_batch(predn, bbox, cls) if self.args.plots: - self.confusion_matrix.process_batch(predn, labelsn) - self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls) + self.confusion_matrix.process_batch(predn, bbox, cls) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) # Save if self.args.save_json: - self.pred_to_json(predn, batch['im_file'][si]) + self.pred_to_json(predn, batch["im_file"][si]) if self.args.save_txt: - file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt' - self.save_one_txt(predn, self.args.save_conf, shape, file) + file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt' + self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file) def finalize_metrics(self, *args, **kwargs): """Set final values for metrics speed and confusion matrix.""" @@ -140,19 +165,20 @@ class DetectionValidator(BaseValidator): def get_stats(self): """Returns metrics statistics and results dictionary.""" - stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy - if len(stats) and stats[0].any(): - self.metrics.process(*stats) - self.nt_per_class = np.bincount(stats[-1].astype(int), 
minlength=self.nc) # number of targets per class + stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy + if len(stats) and stats["tp"].any(): + self.metrics.process(**stats) + self.nt_per_class = np.bincount( + stats["target_cls"].astype(int), minlength=self.nc + ) # number of targets per class return self.metrics.results_dict def print_results(self): """Prints training/validation set metrics per class.""" - pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format - LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) + pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) if self.nt_per_class.sum() == 0: - LOGGER.warning( - f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels') + LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels") # Print results per class if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): @@ -161,12 +187,11 @@ class DetectionValidator(BaseValidator): if self.args.plots: for normalize in True, False: - self.confusion_matrix.plot(save_dir=self.save_dir, - names=self.names.values(), - normalize=normalize, - on_plot=self.on_plot) + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) - def _process_batch(self, detections, labels): + def _process_batch(self, detections, gt_bboxes, gt_cls): """ Return correct prediction matrix. @@ -179,10 +204,10 @@ class DetectionValidator(BaseValidator): Returns: (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels. """ - iou = box_iou(labels[:, 1:], detections[:, :4]) - return self.match_predictions(detections[:, 5], labels[:, 0], iou) + iou = box_iou(gt_bboxes, detections[:, :4]) + return self.match_predictions(detections[:, 5], gt_cls, iou) - def build_dataset(self, img_path, mode='val', batch=None): + def build_dataset(self, img_path, mode="val", batch=None): """ Build YOLO Dataset. @@ -191,33 +216,36 @@ class DetectionValidator(BaseValidator): mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. batch (int, optional): Size of batches, this is for `rect`. Defaults to None. 
""" - gs = max(int(de_parallel(self.model).stride if self.model else 0), 32) - return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=gs) + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride) def get_dataloader(self, dataset_path, batch_size): """Construct and return dataloader.""" - dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val') + dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val") return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader def plot_val_samples(self, batch, ni): """Plot validation image samples.""" - plot_images(batch['img'], - batch['batch_idx'], - batch['cls'].squeeze(-1), - batch['bboxes'], - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_labels.jpg', - names=self.names, - on_plot=self.on_plot) + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) def plot_predictions(self, batch, preds, ni): """Plots predicted bounding boxes on input images and saves the result.""" - plot_images(batch['img'], - *output_to_target(preds, max_det=self.args.max_det), - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_pred.jpg', - names=self.names, - on_plot=self.on_plot) # pred + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred def save_one_txt(self, predn, save_conf, shape, file): """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" @@ -225,8 +253,8 @@ class DetectionValidator(BaseValidator): for *xyxy, conf, cls in predn.tolist(): xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format - with open(file, 'a') as f: - f.write(('%g ' * len(line)).rstrip() % line + '\n') + with open(file, "a") as f: + f.write(("%g " * len(line)).rstrip() % line + "\n") def pred_to_json(self, predn, filename): """Serialize YOLO predictions to COCO json format.""" @@ -235,28 +263,31 @@ class DetectionValidator(BaseValidator): box = ops.xyxy2xywh(predn[:, :4]) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner for p, b in zip(predn.tolist(), box.tolist()): - self.jdict.append({ - 'image_id': image_id, - 'category_id': self.class_map[int(p[5])], - 'bbox': [round(x, 3) for x in b], - 'score': round(p[4], 5)}) + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + } + ) def eval_json(self, stats): """Evaluates YOLO output in JSON format and returns performance statistics.""" if self.args.save_json and self.is_coco and len(self.jdict): - anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations - pred_json = self.save_dir / 'predictions.json' # predictions - LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') + anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") try: # 
https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb - check_requirements('pycocotools>=2.0.6') + check_requirements("pycocotools>=2.0.6") from pycocotools.coco import COCO # noqa from pycocotools.cocoeval import COCOeval # noqa for x in anno_json, pred_json: - assert x.is_file(), f'{x} file not found' + assert x.is_file(), f"{x} file not found" anno = COCO(str(anno_json)) # init annotations api pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) - eval = COCOeval(anno, pred, 'bbox') + eval = COCOeval(anno, pred, "bbox") if self.is_coco: eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval eval.evaluate() @@ -264,5 +295,5 @@ class DetectionValidator(BaseValidator): eval.summarize() stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50 except Exception as e: - LOGGER.warning(f'pycocotools unable to run: {e}') + LOGGER.warning(f"pycocotools unable to run: {e}") return stats diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py index b85d46b..f10dc97 100644 --- a/ultralytics/models/yolo/model.py +++ b/ultralytics/models/yolo/model.py @@ -1,36 +1,111 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +from pathlib import Path + from ultralytics.engine.model import Model -from ultralytics.models import yolo # noqa -from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel +from ultralytics.models import yolo +from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel +from ultralytics.utils import yaml_load, ROOT class YOLO(Model): - """ - YOLO (You Only Look Once) object detection model. 
- """ + """YOLO (You Only Look Once) object detection model.""" + + def __init__(self, model="yolov8n.pt", task=None, verbose=False): + """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'.""" + path = Path(model) + if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model + new_instance = YOLOWorld(path) + self.__class__ = type(new_instance) + self.__dict__ = new_instance.__dict__ + elif "yolov10" in path.stem: + from ultralytics import YOLOv10 + new_instance = YOLOv10(path) + self.__class__ = type(new_instance) + self.__dict__ = new_instance.__dict__ + else: + # Continue with default YOLO initialization + super().__init__(model=model, task=task, verbose=verbose) @property def task_map(self): - """Map head to model, trainer, validator, and predictor classes""" + """Map head to model, trainer, validator, and predictor classes.""" return { - 'classify': { - 'model': ClassificationModel, - 'trainer': yolo.classify.ClassificationTrainer, - 'validator': yolo.classify.ClassificationValidator, - 'predictor': yolo.classify.ClassificationPredictor, }, - 'detect': { - 'model': DetectionModel, - 'trainer': yolo.detect.DetectionTrainer, - 'validator': yolo.detect.DetectionValidator, - 'predictor': yolo.detect.DetectionPredictor, }, - 'segment': { - 'model': SegmentationModel, - 'trainer': yolo.segment.SegmentationTrainer, - 'validator': yolo.segment.SegmentationValidator, - 'predictor': yolo.segment.SegmentationPredictor, }, - 'pose': { - 'model': PoseModel, - 'trainer': yolo.pose.PoseTrainer, - 'validator': yolo.pose.PoseValidator, - 'predictor': yolo.pose.PosePredictor, }, } + "classify": { + "model": ClassificationModel, + "trainer": yolo.classify.ClassificationTrainer, + "validator": yolo.classify.ClassificationValidator, + "predictor": yolo.classify.ClassificationPredictor, + }, + "detect": { + "model": DetectionModel, + "trainer": yolo.detect.DetectionTrainer, + "validator": yolo.detect.DetectionValidator, + "predictor": yolo.detect.DetectionPredictor, + }, + "segment": { + "model": SegmentationModel, + "trainer": yolo.segment.SegmentationTrainer, + "validator": yolo.segment.SegmentationValidator, + "predictor": yolo.segment.SegmentationPredictor, + }, + "pose": { + "model": PoseModel, + "trainer": yolo.pose.PoseTrainer, + "validator": yolo.pose.PoseValidator, + "predictor": yolo.pose.PosePredictor, + }, + "obb": { + "model": OBBModel, + "trainer": yolo.obb.OBBTrainer, + "validator": yolo.obb.OBBValidator, + "predictor": yolo.obb.OBBPredictor, + }, + } + + +class YOLOWorld(Model): + """YOLO-World object detection model.""" + + def __init__(self, model="yolov8s-world.pt") -> None: + """ + Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats. + + Args: + model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'. + """ + super().__init__(model=model, task="detect") + + # Assign default COCO class names when there are no custom names + if not hasattr(self.model, "names"): + self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names") + + @property + def task_map(self): + """Map head to model, validator, and predictor classes.""" + return { + "detect": { + "model": WorldModel, + "validator": yolo.detect.DetectionValidator, + "predictor": yolo.detect.DetectionPredictor, + } + } + + def set_classes(self, classes): + """ + Set classes. + + Args: + classes (List(str)): A list of categories i.e ["person"]. 
+ """ + self.model.set_classes(classes) + # Remove background if it's given + background = " " + if background in classes: + classes.remove(background) + self.model.names = classes + + # Reset method class names + # self.predictor = None # reset predictor otherwise old names remain + if self.predictor: + self.predictor.model.names = classes diff --git a/ultralytics/models/yolo/obb/__init__.py b/ultralytics/models/yolo/obb/__init__.py new file mode 100644 index 0000000..f60349a --- /dev/null +++ b/ultralytics/models/yolo/obb/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from .predict import OBBPredictor +from .train import OBBTrainer +from .val import OBBValidator + +__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator" diff --git a/ultralytics/models/yolo/obb/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/yolo/obb/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..a4c4cf2 Binary files /dev/null and b/ultralytics/models/yolo/obb/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/obb/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/yolo/obb/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..56eb79a Binary files /dev/null and b/ultralytics/models/yolo/obb/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/obb/__pycache__/predict.cpython-312.pyc b/ultralytics/models/yolo/obb/__pycache__/predict.cpython-312.pyc new file mode 100644 index 0000000..fb46f2f Binary files /dev/null and b/ultralytics/models/yolo/obb/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/obb/__pycache__/predict.cpython-39.pyc b/ultralytics/models/yolo/obb/__pycache__/predict.cpython-39.pyc new file mode 100644 index 0000000..51d5ac8 Binary files /dev/null and b/ultralytics/models/yolo/obb/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/obb/__pycache__/train.cpython-312.pyc b/ultralytics/models/yolo/obb/__pycache__/train.cpython-312.pyc new file mode 100644 index 0000000..4e86ac5 Binary files /dev/null and b/ultralytics/models/yolo/obb/__pycache__/train.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/obb/__pycache__/train.cpython-39.pyc b/ultralytics/models/yolo/obb/__pycache__/train.cpython-39.pyc new file mode 100644 index 0000000..9c2cb08 Binary files /dev/null and b/ultralytics/models/yolo/obb/__pycache__/train.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/obb/__pycache__/val.cpython-312.pyc b/ultralytics/models/yolo/obb/__pycache__/val.cpython-312.pyc new file mode 100644 index 0000000..ba785e6 Binary files /dev/null and b/ultralytics/models/yolo/obb/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/obb/__pycache__/val.cpython-39.pyc b/ultralytics/models/yolo/obb/__pycache__/val.cpython-39.pyc new file mode 100644 index 0000000..c9ca611 Binary files /dev/null and b/ultralytics/models/yolo/obb/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py new file mode 100644 index 0000000..bb8d4d3 --- /dev/null +++ b/ultralytics/models/yolo/obb/predict.py @@ -0,0 +1,53 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import torch + +from ultralytics.engine.results import Results +from ultralytics.models.yolo.detect.predict import DetectionPredictor +from ultralytics.utils import DEFAULT_CFG, ops + + +class OBBPredictor(DetectionPredictor): + """ + A class extending the DetectionPredictor class for 
prediction based on an Oriented Bounding Box (OBB) model. + + Example: + ```python + from ultralytics.utils import ASSETS + from ultralytics.models.yolo.obb import OBBPredictor + + args = dict(model='yolov8n-obb.pt', source=ASSETS) + predictor = OBBPredictor(overrides=args) + predictor.predict_cli() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes OBBPredictor with optional model and data configuration overrides.""" + super().__init__(cfg, overrides, _callbacks) + self.args.task = "obb" + + def postprocess(self, preds, img, orig_imgs): + """Post-processes predictions and returns a list of Results objects.""" + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + classes=self.args.classes, + rotated=True, + ) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): + rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1)) + rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True) + # xywh, r, conf, cls + obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1) + results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb)) + return results diff --git a/ultralytics/models/yolo/obb/train.py b/ultralytics/models/yolo/obb/train.py new file mode 100644 index 0000000..40a35a9 --- /dev/null +++ b/ultralytics/models/yolo/obb/train.py @@ -0,0 +1,42 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from copy import copy + +from ultralytics.models import yolo +from ultralytics.nn.tasks import OBBModel +from ultralytics.utils import DEFAULT_CFG, RANK + + +class OBBTrainer(yolo.detect.DetectionTrainer): + """ + A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model. 
+ + Example: + ```python + from ultralytics.models.yolo.obb import OBBTrainer + + args = dict(model='yolov8n-obb.pt', data='dota8.yaml', epochs=3) + trainer = OBBTrainer(overrides=args) + trainer.train() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize an OBBTrainer object with given arguments.""" + if overrides is None: + overrides = {} + overrides["task"] = "obb" + super().__init__(cfg, overrides, _callbacks) + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return OBBModel initialized with specified config and weights.""" + model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + + return model + + def get_validator(self): + """Return an instance of OBBValidator for validation of YOLO model.""" + self.loss_names = "box_loss", "cls_loss", "dfl_loss" + return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) diff --git a/ultralytics/models/yolo/obb/val.py b/ultralytics/models/yolo/obb/val.py new file mode 100644 index 0000000..c440fe2 --- /dev/null +++ b/ultralytics/models/yolo/obb/val.py @@ -0,0 +1,185 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from pathlib import Path + +import torch + +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import LOGGER, ops +from ultralytics.utils.metrics import OBBMetrics, batch_probiou +from ultralytics.utils.plotting import output_to_rotated_target, plot_images + + +class OBBValidator(DetectionValidator): + """ + A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model. + + Example: + ```python + from ultralytics.models.yolo.obb import OBBValidator + + args = dict(model='yolov8n-obb.pt', data='dota8.yaml') + validator = OBBValidator(args=args) + validator(model=args['model']) + ``` + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.args.task = "obb" + self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot) + + def init_metrics(self, model): + """Initialize evaluation metrics for YOLO.""" + super().init_metrics(model) + val = self.data.get(self.args.split, "") # validation path + self.is_dota = isinstance(val, str) and "DOTA" in val # is DOTA + + def postprocess(self, preds): + """Apply Non-maximum suppression to prediction outputs.""" + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + nc=self.nc, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + rotated=True, + ) + + def _process_batch(self, detections, gt_bboxes, gt_cls): + """ + Return correct prediction matrix. + + Args: + detections (torch.Tensor): Tensor of shape [N, 7] representing detections. + Each detection is of the format: x1, y1, x2, y2, conf, class, angle. + gt_bboxes (torch.Tensor): Tensor of shape [M, 5] representing rotated boxes. + Each box is of the format: x1, y1, x2, y2, angle. + gt_cls (torch.Tensor): Tensor of shape [M] representing class labels. + + Returns: + (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
+ """ + iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1)) + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def _prepare_batch(self, si, batch): + """Prepares and returns a batch for OBB validation.""" + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes + ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels + return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad) + + def _prepare_pred(self, pred, pbatch): + """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes.""" + predn = pred.clone() + ops.scale_boxes( + pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True + ) # native-space pred + return predn + + def plot_predictions(self, batch, preds, ni): + """Plots predicted bounding boxes on input images and saves the result.""" + plot_images( + batch["img"], + *output_to_rotated_target(preds, max_det=self.args.max_det), + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + + def pred_to_json(self, predn, filename): + """Serialize YOLO predictions to COCO json format.""" + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1) + poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8) + for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(predn[i, 5].item())], + "score": round(predn[i, 4].item(), 5), + "rbox": [round(x, 3) for x in r], + "poly": [round(x, 3) for x in b], + } + ) + + def save_one_txt(self, predn, save_conf, shape, file): + """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" + gn = torch.tensor(shape)[[1, 0]] # normalization gain whwh + for *xywh, conf, cls, angle in predn.tolist(): + xywha = torch.tensor([*xywh, angle]).view(1, 5) + xyxyxyxy = (ops.xywhr2xyxyxyxy(xywha) / gn).view(-1).tolist() # normalized xywh + line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format + with open(file, "a") as f: + f.write(("%g " * len(line)).rstrip() % line + "\n") + + def eval_json(self, stats): + """Evaluates YOLO output in JSON format and returns performance statistics.""" + if self.args.save_json and self.is_dota and len(self.jdict): + import json + import re + from collections import defaultdict + + pred_json = self.save_dir / "predictions.json" # predictions + pred_txt = self.save_dir / "predictions_txt" # predictions + pred_txt.mkdir(parents=True, exist_ok=True) + data = json.load(open(pred_json)) + # Save split results + LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...") + for d in data: + image_id = d["image_id"] + score = d["score"] + classname = self.names[d["category_id"]].replace(" ", "-") + p = d["poly"] + + with open(f'{pred_txt / f"Task1_{classname}"}.txt', "a") as f: + f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n") + # Save merged results, this could result slightly lower map than using official merging script, + # because of the probiou calculation. 
+ pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions + pred_merged_txt.mkdir(parents=True, exist_ok=True) + merged_results = defaultdict(list) + LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...") + for d in data: + image_id = d["image_id"].split("__")[0] + pattern = re.compile(r"\d+___\d+") + x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___")) + bbox, score, cls = d["rbox"], d["score"], d["category_id"] + bbox[0] += x + bbox[1] += y + bbox.extend([score, cls]) + merged_results[image_id].append(bbox) + for image_id, bbox in merged_results.items(): + bbox = torch.tensor(bbox) + max_wh = torch.max(bbox[:, :2]).item() * 2 + c = bbox[:, 6:7] * max_wh # classes + scores = bbox[:, 5] # scores + b = bbox[:, :5].clone() + b[:, :2] += c + # 0.3 could get results close to the ones from official merging script, even slightly better. + i = ops.nms_rotated(b, scores, 0.3) + bbox = bbox[i] + + b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8) + for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist(): + classname = self.names[int(x[-1])].replace(" ", "-") + p = [round(i, 3) for i in x[:-2]] # poly + score = round(x[-2], 3) + + with open(f'{pred_merged_txt / f"Task1_{classname}"}.txt', "a") as f: + f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n") + + return stats diff --git a/ultralytics/models/yolo/pose/__init__.py b/ultralytics/models/yolo/pose/__init__.py index 2a79f0f..d566943 100644 --- a/ultralytics/models/yolo/pose/__init__.py +++ b/ultralytics/models/yolo/pose/__init__.py @@ -4,4 +4,4 @@ from .predict import PosePredictor from .train import PoseTrainer from .val import PoseValidator -__all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor' +__all__ = "PoseTrainer", "PoseValidator", "PosePredictor" diff --git a/ultralytics/models/yolo/pose/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/yolo/pose/__pycache__/__init__.cpython-312.pyc index ea6cab1..3cc2a22 100644 Binary files a/ultralytics/models/yolo/pose/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/yolo/pose/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/pose/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/yolo/pose/__pycache__/__init__.cpython-39.pyc index 35cd12f..1db6261 100644 Binary files a/ultralytics/models/yolo/pose/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/yolo/pose/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/pose/__pycache__/predict.cpython-312.pyc b/ultralytics/models/yolo/pose/__pycache__/predict.cpython-312.pyc index 5142060..d4f16cc 100644 Binary files a/ultralytics/models/yolo/pose/__pycache__/predict.cpython-312.pyc and b/ultralytics/models/yolo/pose/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/pose/__pycache__/predict.cpython-39.pyc b/ultralytics/models/yolo/pose/__pycache__/predict.cpython-39.pyc index 6d92524..de16763 100644 Binary files a/ultralytics/models/yolo/pose/__pycache__/predict.cpython-39.pyc and b/ultralytics/models/yolo/pose/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/pose/__pycache__/train.cpython-312.pyc b/ultralytics/models/yolo/pose/__pycache__/train.cpython-312.pyc index f008c3d..38c4493 100644 Binary files a/ultralytics/models/yolo/pose/__pycache__/train.cpython-312.pyc and b/ultralytics/models/yolo/pose/__pycache__/train.cpython-312.pyc differ diff --git 
a/ultralytics/models/yolo/pose/__pycache__/train.cpython-39.pyc b/ultralytics/models/yolo/pose/__pycache__/train.cpython-39.pyc index a27ca4d..e7268de 100644 Binary files a/ultralytics/models/yolo/pose/__pycache__/train.cpython-39.pyc and b/ultralytics/models/yolo/pose/__pycache__/train.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/pose/__pycache__/val.cpython-312.pyc b/ultralytics/models/yolo/pose/__pycache__/val.cpython-312.pyc index 748a7a5..5276eff 100644 Binary files a/ultralytics/models/yolo/pose/__pycache__/val.cpython-312.pyc and b/ultralytics/models/yolo/pose/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/pose/__pycache__/val.cpython-39.pyc b/ultralytics/models/yolo/pose/__pycache__/val.cpython-39.pyc index 50fdd65..a331001 100644 Binary files a/ultralytics/models/yolo/pose/__pycache__/val.cpython-39.pyc and b/ultralytics/models/yolo/pose/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/pose/predict.py b/ultralytics/models/yolo/pose/predict.py index 14ae40b..7c55709 100644 --- a/ultralytics/models/yolo/pose/predict.py +++ b/ultralytics/models/yolo/pose/predict.py @@ -21,21 +21,26 @@ class PosePredictor(DetectionPredictor): """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device.""" super().__init__(cfg, overrides, _callbacks) - self.args.task = 'pose' - if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': - LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " - 'See https://github.com/ultralytics/ultralytics/issues/4031.') + self.args.task = "pose" + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." 
+ ) def postprocess(self, preds, img, orig_imgs): """Return detection results for a given input image or list of images.""" - preds = ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - classes=self.args.classes, - nc=len(self.model.names)) + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + nc=len(self.model.names), + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) @@ -48,5 +53,6 @@ class PosePredictor(DetectionPredictor): pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape) img_path = self.batch[0][i] results.append( - Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)) + Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts) + ) return results diff --git a/ultralytics/models/yolo/pose/train.py b/ultralytics/models/yolo/pose/train.py index 2d4f4e0..f5229e5 100644 --- a/ultralytics/models/yolo/pose/train.py +++ b/ultralytics/models/yolo/pose/train.py @@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer): """Initialize a PoseTrainer object with specified configurations and overrides.""" if overrides is None: overrides = {} - overrides['task'] = 'pose' + overrides["task"] = "pose" super().__init__(cfg, overrides, _callbacks) - if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': - LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " - 'See https://github.com/ultralytics/ultralytics/issues/4031.') + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." 
+ ) def get_model(self, cfg=None, weights=None, verbose=True): """Get pose estimation model with specified configuration and weights.""" - model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose) + model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose) if weights: model.load(weights) @@ -44,29 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer): def set_model_attributes(self): """Sets keypoints shape attribute of PoseModel.""" super().set_model_attributes() - self.model.kpt_shape = self.data['kpt_shape'] + self.model.kpt_shape = self.data["kpt_shape"] def get_validator(self): """Returns an instance of the PoseValidator class for validation.""" - self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' - return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) + self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss" + return yolo.pose.PoseValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) def plot_training_samples(self, batch, ni): """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" - images = batch['img'] - kpts = batch['keypoints'] - cls = batch['cls'].squeeze(-1) - bboxes = batch['bboxes'] - paths = batch['im_file'] - batch_idx = batch['batch_idx'] - plot_images(images, - batch_idx, - cls, - bboxes, - kpts=kpts, - paths=paths, - fname=self.save_dir / f'train_batch{ni}.jpg', - on_plot=self.on_plot) + images = batch["img"] + kpts = batch["keypoints"] + cls = batch["cls"].squeeze(-1) + bboxes = batch["bboxes"] + paths = batch["im_file"] + batch_idx = batch["batch_idx"] + plot_images( + images, + batch_idx, + cls, + bboxes, + kpts=kpts, + paths=paths, + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) def plot_metrics(self): """Plots training/val metrics.""" diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py index b8ebf57..8405686 100644 --- a/ultralytics/models/yolo/pose/val.py +++ b/ultralytics/models/yolo/pose/val.py @@ -31,100 +31,125 @@ class PoseValidator(DetectionValidator): super().__init__(dataloader, save_dir, pbar, args, _callbacks) self.sigma = None self.kpt_shape = None - self.args.task = 'pose' + self.args.task = "pose" self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot) - if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': - LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " - 'See https://github.com/ultralytics/ultralytics/issues/4031.') + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." 
+ ) def preprocess(self, batch): """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device.""" batch = super().preprocess(batch) - batch['keypoints'] = batch['keypoints'].to(self.device).float() + batch["keypoints"] = batch["keypoints"].to(self.device).float() return batch def get_desc(self): """Returns description of evaluation metrics in string format.""" - return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P', - 'R', 'mAP50', 'mAP50-95)') + return ("%22s" + "%11s" * 10) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + "Pose(P", + "R", + "mAP50", + "mAP50-95)", + ) def postprocess(self, preds): """Apply non-maximum suppression and return detections with high confidence scores.""" - return ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - labels=self.lb, - multi_label=True, - agnostic=self.args.single_cls, - max_det=self.args.max_det, - nc=self.nc) + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + nc=self.nc, + ) def init_metrics(self, model): """Initiate pose estimation metrics for YOLO model.""" super().init_metrics(model) - self.kpt_shape = self.data['kpt_shape'] + self.kpt_shape = self.data["kpt_shape"] is_pose = self.kpt_shape == [17, 3] nkpt = self.kpt_shape[0] self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt + self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[]) + + def _prepare_batch(self, si, batch): + """Prepares a batch for processing by converting keypoints to float and moving to device.""" + pbatch = super()._prepare_batch(si, batch) + kpts = batch["keypoints"][batch["batch_idx"] == si] + h, w = pbatch["imgsz"] + kpts = kpts.clone() + kpts[..., 0] *= w + kpts[..., 1] *= h + kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) + pbatch["kpts"] = kpts + return pbatch + + def _prepare_pred(self, pred, pbatch): + """Prepares and scales keypoints in a batch for pose processing.""" + predn = super()._prepare_pred(pred, pbatch) + nk = pbatch["kpts"].shape[1] + pred_kpts = predn[:, 6:].view(len(predn), nk, -1) + ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) + return predn, pred_kpts def update_metrics(self, preds, batch): """Metrics.""" for si, pred in enumerate(preds): - idx = batch['batch_idx'] == si - cls = batch['cls'][idx] - bbox = batch['bboxes'][idx] - kpts = batch['keypoints'][idx] - nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions - nk = kpts.shape[1] # number of keypoints - shape = batch['ori_shape'][si] - correct_kpts = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init - correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init self.seen += 1 - + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls if npr == 0: if nl: - self.stats.append((correct_bboxes, correct_kpts, *torch.zeros( - (2, 0), device=self.device), cls.squeeze(-1))) + for k in 
self.stats.keys(): + self.stats[k].append(stat[k]) if self.args.plots: - self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1)) + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) continue # Predictions if self.args.single_cls: pred[:, 5] = 0 - predn = pred.clone() - ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape, - ratio_pad=batch['ratio_pad'][si]) # native-space pred - pred_kpts = predn[:, 6:].view(npr, nk, -1) - ops.scale_coords(batch['img'][si].shape[1:], pred_kpts, shape, ratio_pad=batch['ratio_pad'][si]) + predn, pred_kpts = self._prepare_pred(pred, pbatch) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] # Evaluate if nl: - height, width = batch['img'].shape[2:] - tbox = ops.xywh2xyxy(bbox) * torch.tensor( - (width, height, width, height), device=self.device) # target boxes - ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape, - ratio_pad=batch['ratio_pad'][si]) # native-space labels - tkpts = kpts.clone() - tkpts[..., 0] *= width - tkpts[..., 1] *= height - tkpts = ops.scale_coords(batch['img'][si].shape[1:], tkpts, shape, ratio_pad=batch['ratio_pad'][si]) - labelsn = torch.cat((cls, tbox), 1) # native-space labels - correct_bboxes = self._process_batch(predn[:, :6], labelsn) - correct_kpts = self._process_batch(predn[:, :6], labelsn, pred_kpts, tkpts) + stat["tp"] = self._process_batch(predn, bbox, cls) + stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"]) if self.args.plots: - self.confusion_matrix.process_batch(predn, labelsn) + self.confusion_matrix.process_batch(predn, bbox, cls) - # Append correct_masks, correct_boxes, pconf, pcls, tcls - self.stats.append((correct_bboxes, correct_kpts, pred[:, 4], pred[:, 5], cls.squeeze(-1))) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) # Save if self.args.save_json: - self.pred_to_json(predn, batch['im_file'][si]) + self.pred_to_json(predn, batch["im_file"][si]) # if self.args.save_txt: # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') - def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None): + def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None): """ Return correct prediction matrix. 
@@ -142,35 +167,39 @@ class PoseValidator(DetectionValidator): """ if pred_kpts is not None and gt_kpts is not None: # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384 - area = ops.xyxy2xywh(labels[:, 1:])[:, 2:].prod(1) * 0.53 + area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53 iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area) else: # boxes - iou = box_iou(labels[:, 1:], detections[:, :4]) + iou = box_iou(gt_bboxes, detections[:, :4]) - return self.match_predictions(detections[:, 5], labels[:, 0], iou) + return self.match_predictions(detections[:, 5], gt_cls, iou) def plot_val_samples(self, batch, ni): """Plots and saves validation set samples with predicted bounding boxes and keypoints.""" - plot_images(batch['img'], - batch['batch_idx'], - batch['cls'].squeeze(-1), - batch['bboxes'], - kpts=batch['keypoints'], - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_labels.jpg', - names=self.names, - on_plot=self.on_plot) + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + kpts=batch["keypoints"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) def plot_predictions(self, batch, preds, ni): """Plots predictions for YOLO model.""" pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0) - plot_images(batch['img'], - *output_to_target(preds, max_det=self.args.max_det), - kpts=pred_kpts, - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_pred.jpg', - names=self.names, - on_plot=self.on_plot) # pred + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + kpts=pred_kpts, + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred def pred_to_json(self, predn, filename): """Converts YOLO predictions to COCO JSON format.""" @@ -179,37 +208,41 @@ class PoseValidator(DetectionValidator): box = ops.xyxy2xywh(predn[:, :4]) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner for p, b in zip(predn.tolist(), box.tolist()): - self.jdict.append({ - 'image_id': image_id, - 'category_id': self.class_map[int(p[5])], - 'bbox': [round(x, 3) for x in b], - 'keypoints': p[6:], - 'score': round(p[4], 5)}) + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "keypoints": p[6:], + "score": round(p[4], 5), + } + ) def eval_json(self, stats): """Evaluates object detection model using COCO JSON format.""" if self.args.save_json and self.is_coco and len(self.jdict): - anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations - pred_json = self.save_dir / 'predictions.json' # predictions - LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') + anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb - check_requirements('pycocotools>=2.0.6') + check_requirements("pycocotools>=2.0.6") from pycocotools.coco import COCO # noqa from pycocotools.cocoeval import COCOeval # noqa for x in anno_json, pred_json: - assert x.is_file(), f'{x} file not found' + assert x.is_file(), 
f"{x} file not found" anno = COCO(str(anno_json)) # init annotations api pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) - for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]): + for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]): if self.is_coco: eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval eval.evaluate() eval.accumulate() eval.summarize() idx = i * 4 + 2 - stats[self.metrics.keys[idx + 1]], stats[ - self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 + stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 except Exception as e: - LOGGER.warning(f'pycocotools unable to run: {e}') + LOGGER.warning(f"pycocotools unable to run: {e}") return stats diff --git a/ultralytics/models/yolo/segment/__init__.py b/ultralytics/models/yolo/segment/__init__.py index c84a570..ec1ac79 100644 --- a/ultralytics/models/yolo/segment/__init__.py +++ b/ultralytics/models/yolo/segment/__init__.py @@ -4,4 +4,4 @@ from .predict import SegmentationPredictor from .train import SegmentationTrainer from .val import SegmentationValidator -__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator' +__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator" diff --git a/ultralytics/models/yolo/segment/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/yolo/segment/__pycache__/__init__.cpython-312.pyc index 60035cc..a9bf274 100644 Binary files a/ultralytics/models/yolo/segment/__pycache__/__init__.cpython-312.pyc and b/ultralytics/models/yolo/segment/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/segment/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/yolo/segment/__pycache__/__init__.cpython-39.pyc index 670a340..05ab95f 100644 Binary files a/ultralytics/models/yolo/segment/__pycache__/__init__.cpython-39.pyc and b/ultralytics/models/yolo/segment/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/segment/__pycache__/predict.cpython-312.pyc b/ultralytics/models/yolo/segment/__pycache__/predict.cpython-312.pyc index b860998..64e90a9 100644 Binary files a/ultralytics/models/yolo/segment/__pycache__/predict.cpython-312.pyc and b/ultralytics/models/yolo/segment/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/segment/__pycache__/predict.cpython-39.pyc b/ultralytics/models/yolo/segment/__pycache__/predict.cpython-39.pyc index ebc040d..135cf37 100644 Binary files a/ultralytics/models/yolo/segment/__pycache__/predict.cpython-39.pyc and b/ultralytics/models/yolo/segment/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/segment/__pycache__/train.cpython-312.pyc b/ultralytics/models/yolo/segment/__pycache__/train.cpython-312.pyc index f43270b..acedd40 100644 Binary files a/ultralytics/models/yolo/segment/__pycache__/train.cpython-312.pyc and b/ultralytics/models/yolo/segment/__pycache__/train.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/segment/__pycache__/train.cpython-39.pyc b/ultralytics/models/yolo/segment/__pycache__/train.cpython-39.pyc index 25a7406..44c61db 100644 Binary files a/ultralytics/models/yolo/segment/__pycache__/train.cpython-39.pyc and b/ultralytics/models/yolo/segment/__pycache__/train.cpython-39.pyc differ diff --git 
a/ultralytics/models/yolo/segment/__pycache__/val.cpython-312.pyc b/ultralytics/models/yolo/segment/__pycache__/val.cpython-312.pyc index e8d7ea7..7cf7781 100644 Binary files a/ultralytics/models/yolo/segment/__pycache__/val.cpython-312.pyc and b/ultralytics/models/yolo/segment/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/yolo/segment/__pycache__/val.cpython-39.pyc b/ultralytics/models/yolo/segment/__pycache__/val.cpython-39.pyc index 274331b..b5e60d5 100644 Binary files a/ultralytics/models/yolo/segment/__pycache__/val.cpython-39.pyc and b/ultralytics/models/yolo/segment/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/yolo/segment/predict.py b/ultralytics/models/yolo/segment/predict.py index 7d51f7d..9d7015f 100644 --- a/ultralytics/models/yolo/segment/predict.py +++ b/ultralytics/models/yolo/segment/predict.py @@ -21,23 +21,27 @@ class SegmentationPredictor(DetectionPredictor): """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks.""" super().__init__(cfg, overrides, _callbacks) - self.args.task = 'segment' + self.args.task = "segment" def postprocess(self, preds, img, orig_imgs): - p = ops.non_max_suppression(preds[0], - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - nc=len(self.model.names), - classes=self.args.classes) + """Applies non-max suppression and processes detections for each image in an input batch.""" + p = ops.non_max_suppression( + preds[0], + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + classes=self.args.classes, + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) results = [] - proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported + proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1] # tuple if PyTorch model or array if exported for i, pred in enumerate(p): orig_img = orig_imgs[i] img_path = self.batch[0][i] diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py index b290192..126baf2 100644 --- a/ultralytics/models/yolo/segment/train.py +++ b/ultralytics/models/yolo/segment/train.py @@ -26,12 +26,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): """Initialize a SegmentationTrainer object with given arguments.""" if overrides is None: overrides = {} - overrides['task'] = 'segment' + overrides["task"] = "segment" super().__init__(cfg, overrides, _callbacks) def get_model(self, cfg=None, weights=None, verbose=True): """Return SegmentationModel initialized with specified config and weights.""" - model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) + model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) if weights: model.load(weights) @@ -39,19 +39,23 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): def get_validator(self): """Return an instance of SegmentationValidator for validation of YOLO model.""" - self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' - return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) + self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss" + return yolo.segment.SegmentationValidator( + 
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) def plot_training_samples(self, batch, ni): """Creates a plot of training sample images with labels and box coordinates.""" - plot_images(batch['img'], - batch['batch_idx'], - batch['cls'].squeeze(-1), - batch['bboxes'], - batch['masks'], - paths=batch['im_file'], - fname=self.save_dir / f'train_batch{ni}.jpg', - on_plot=self.on_plot) + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) def plot_metrics(self): """Plots training/val metrics.""" diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py index 0a2acb4..94757c4 100644 --- a/ultralytics/models/yolo/segment/val.py +++ b/ultralytics/models/yolo/segment/val.py @@ -33,13 +33,13 @@ class SegmentationValidator(DetectionValidator): super().__init__(dataloader, save_dir, pbar, args, _callbacks) self.plot_masks = None self.process = None - self.args.task = 'segment' + self.args.task = "segment" self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) def preprocess(self, batch): """Preprocesses batch by converting masks to float and sending to device.""" batch = super().preprocess(batch) - batch['masks'] = batch['masks'].to(self.device).float() + batch["masks"] = batch["masks"].to(self.device).float() return batch def init_metrics(self, model): @@ -47,82 +47,99 @@ class SegmentationValidator(DetectionValidator): super().init_metrics(model) self.plot_masks = [] if self.args.save_json: - check_requirements('pycocotools>=2.0.6') + check_requirements("pycocotools>=2.0.6") self.process = ops.process_mask_upsample # more accurate else: self.process = ops.process_mask # faster + self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[]) def get_desc(self): """Return a formatted description of evaluation metrics.""" - return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', - 'R', 'mAP50', 'mAP50-95)') + return ("%22s" + "%11s" * 10) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + "Mask(P", + "R", + "mAP50", + "mAP50-95)", + ) def postprocess(self, preds): """Post-processes YOLO predictions and returns output detections with proto.""" - p = ops.non_max_suppression(preds[0], - self.args.conf, - self.args.iou, - labels=self.lb, - multi_label=True, - agnostic=self.args.single_cls, - max_det=self.args.max_det, - nc=self.nc) + p = ops.non_max_suppression( + preds[0], + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + nc=self.nc, + ) proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported return p, proto + def _prepare_batch(self, si, batch): + """Prepares a batch for training or inference by processing images and targets.""" + prepared_batch = super()._prepare_batch(si, batch) + midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si + prepared_batch["masks"] = batch["masks"][midx] + return prepared_batch + + def _prepare_pred(self, pred, pbatch, proto): + """Prepares a batch for training or inference by processing images and targets.""" + predn = super()._prepare_pred(pred, pbatch) + pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"]) + return predn, 
pred_masks + def update_metrics(self, preds, batch): """Metrics.""" for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): - idx = batch['batch_idx'] == si - cls = batch['cls'][idx] - bbox = batch['bboxes'][idx] - nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions - shape = batch['ori_shape'][si] - correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init - correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init self.seen += 1 - + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls if npr == 0: if nl: - self.stats.append((correct_bboxes, correct_masks, *torch.zeros( - (2, 0), device=self.device), cls.squeeze(-1))) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) if self.args.plots: - self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1)) + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) continue # Masks - midx = [si] if self.args.overlap_mask else idx - gt_masks = batch['masks'][midx] - pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:]) - + gt_masks = pbatch.pop("masks") # Predictions if self.args.single_cls: pred[:, 5] = 0 - predn = pred.clone() - ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape, - ratio_pad=batch['ratio_pad'][si]) # native-space pred + predn, pred_masks = self._prepare_pred(pred, pbatch, proto) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] # Evaluate if nl: - height, width = batch['img'].shape[2:] - tbox = ops.xywh2xyxy(bbox) * torch.tensor( - (width, height, width, height), device=self.device) # target boxes - ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape, - ratio_pad=batch['ratio_pad'][si]) # native-space labels - labelsn = torch.cat((cls, tbox), 1) # native-space labels - correct_bboxes = self._process_batch(predn, labelsn) - # TODO: maybe remove these `self.` arguments as they already are member variable - correct_masks = self._process_batch(predn, - labelsn, - pred_masks, - gt_masks, - overlap=self.args.overlap_mask, - masks=True) + stat["tp"] = self._process_batch(predn, bbox, cls) + stat["tp_m"] = self._process_batch( + predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True + ) if self.args.plots: - self.confusion_matrix.process_batch(predn, labelsn) + self.confusion_matrix.process_batch(predn, bbox, cls) - # Append correct_masks, correct_boxes, pconf, pcls, tcls - self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1))) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) if self.args.plots and self.batch_i < 3: @@ -130,10 +147,12 @@ class SegmentationValidator(DetectionValidator): # Save if self.args.save_json: - pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), - shape, - ratio_pad=batch['ratio_pad'][si]) - self.pred_to_json(predn, batch['im_file'][si], pred_masks) + pred_masks = ops.scale_image( + pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), + pbatch["ori_shape"], + ratio_pad=batch["ratio_pad"][si], + ) + 
self.pred_to_json(predn, batch["im_file"][si], pred_masks) # if self.args.save_txt: # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') @@ -142,9 +161,9 @@ class SegmentationValidator(DetectionValidator): self.metrics.speed = self.speed self.metrics.confusion_matrix = self.confusion_matrix - def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False): + def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False): """ - Return correct prediction matrix + Return correct prediction matrix. Args: detections (array[N, 6]), x1, y1, x2, y2, conf, class @@ -155,52 +174,59 @@ class SegmentationValidator(DetectionValidator): """ if masks: if overlap: - nl = len(labels) + nl = len(gt_cls) index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) gt_masks = torch.where(gt_masks == index, 1.0, 0.0) if gt_masks.shape[1:] != pred_masks.shape[1:]: - gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0] + gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0] gt_masks = gt_masks.gt_(0.5) iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) else: # boxes - iou = box_iou(labels[:, 1:], detections[:, :4]) + iou = box_iou(gt_bboxes, detections[:, :4]) - return self.match_predictions(detections[:, 5], labels[:, 0], iou) + return self.match_predictions(detections[:, 5], gt_cls, iou) def plot_val_samples(self, batch, ni): """Plots validation samples with bounding box labels.""" - plot_images(batch['img'], - batch['batch_idx'], - batch['cls'].squeeze(-1), - batch['bboxes'], - batch['masks'], - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_labels.jpg', - names=self.names, - on_plot=self.on_plot) + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) def plot_predictions(self, batch, preds, ni): """Plots batch predictions with masks and bounding boxes.""" plot_images( - batch['img'], + batch["img"], *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_pred.jpg', + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", names=self.names, - on_plot=self.on_plot) # pred + on_plot=self.on_plot, + ) # pred self.plot_masks.clear() def pred_to_json(self, predn, filename, pred_masks): - """Save one JSON result.""" - # Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + """ + Save one JSON result. 
+ + Examples: + >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + """ from pycocotools.mask import encode # noqa def single_encode(x): """Encode predicted masks as RLE and append results to jdict.""" - rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0] - rle['counts'] = rle['counts'].decode('utf-8') + rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] + rle["counts"] = rle["counts"].decode("utf-8") return rle stem = Path(filename).stem @@ -211,37 +237,41 @@ class SegmentationValidator(DetectionValidator): with ThreadPool(NUM_THREADS) as pool: rles = pool.map(single_encode, pred_masks) for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): - self.jdict.append({ - 'image_id': image_id, - 'category_id': self.class_map[int(p[5])], - 'bbox': [round(x, 3) for x in b], - 'score': round(p[4], 5), - 'segmentation': rles[i]}) + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + "segmentation": rles[i], + } + ) def eval_json(self, stats): """Return COCO-style object detection evaluation metrics.""" if self.args.save_json and self.is_coco and len(self.jdict): - anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations - pred_json = self.save_dir / 'predictions.json' # predictions - LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') + anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb - check_requirements('pycocotools>=2.0.6') + check_requirements("pycocotools>=2.0.6") from pycocotools.coco import COCO # noqa from pycocotools.cocoeval import COCOeval # noqa for x in anno_json, pred_json: - assert x.is_file(), f'{x} file not found' + assert x.is_file(), f"{x} file not found" anno = COCO(str(anno_json)) # init annotations api pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) - for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]): + for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]): if self.is_coco: eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval eval.evaluate() eval.accumulate() eval.summarize() idx = i * 4 + 2 - stats[self.metrics.keys[idx + 1]], stats[ - self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 + stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 except Exception as e: - LOGGER.warning(f'pycocotools unable to run: {e}') + LOGGER.warning(f"pycocotools unable to run: {e}") return stats diff --git a/ultralytics/models/yolov10/__init__.py b/ultralytics/models/yolov10/__init__.py new file mode 100644 index 0000000..97f137f --- /dev/null +++ b/ultralytics/models/yolov10/__init__.py @@ -0,0 +1,5 @@ +from .model import YOLOv10 +from .predict import YOLOv10DetectionPredictor +from .val import YOLOv10DetectionValidator + +__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10" diff --git a/ultralytics/models/yolov10/__pycache__/__init__.cpython-312.pyc b/ultralytics/models/yolov10/__pycache__/__init__.cpython-312.pyc new 
file mode 100644 index 0000000..4ad531b Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/__init__.cpython-39.pyc b/ultralytics/models/yolov10/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..54fffc5 Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/card.cpython-312.pyc b/ultralytics/models/yolov10/__pycache__/card.cpython-312.pyc new file mode 100644 index 0000000..6c5ae8a Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/card.cpython-312.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/card.cpython-39.pyc b/ultralytics/models/yolov10/__pycache__/card.cpython-39.pyc new file mode 100644 index 0000000..45addf0 Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/card.cpython-39.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/model.cpython-312.pyc b/ultralytics/models/yolov10/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000..6f88e23 Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/model.cpython-312.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/model.cpython-39.pyc b/ultralytics/models/yolov10/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000..402cd0b Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/model.cpython-39.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/predict.cpython-312.pyc b/ultralytics/models/yolov10/__pycache__/predict.cpython-312.pyc new file mode 100644 index 0000000..509badc Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/predict.cpython-312.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/predict.cpython-39.pyc b/ultralytics/models/yolov10/__pycache__/predict.cpython-39.pyc new file mode 100644 index 0000000..b0b5942 Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/predict.cpython-39.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/train.cpython-312.pyc b/ultralytics/models/yolov10/__pycache__/train.cpython-312.pyc new file mode 100644 index 0000000..e3fb6c1 Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/train.cpython-312.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/train.cpython-39.pyc b/ultralytics/models/yolov10/__pycache__/train.cpython-39.pyc new file mode 100644 index 0000000..a0e1e89 Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/train.cpython-39.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/val.cpython-312.pyc b/ultralytics/models/yolov10/__pycache__/val.cpython-312.pyc new file mode 100644 index 0000000..6258144 Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/val.cpython-312.pyc differ diff --git a/ultralytics/models/yolov10/__pycache__/val.cpython-39.pyc b/ultralytics/models/yolov10/__pycache__/val.cpython-39.pyc new file mode 100644 index 0000000..24e02c3 Binary files /dev/null and b/ultralytics/models/yolov10/__pycache__/val.cpython-39.pyc differ diff --git a/ultralytics/models/yolov10/card.py b/ultralytics/models/yolov10/card.py new file mode 100644 index 0000000..fc2405c --- /dev/null +++ b/ultralytics/models/yolov10/card.py @@ -0,0 +1,64 @@ +card_template_text = """ +--- +license: agpl-3.0 +library_name: ultralytics +repo_url: https://github.com/THU-MIG/yolov10 +tags: +- object-detection +- computer-vision 
+- yolov10 +datasets: +- detection-datasets/coco +inference: false +--- + +### Model Description +[YOLOv10: Real-Time End-to-End Object Detection](https://arxiv.org/abs/2405.14458v1) + +- arXiv: https://arxiv.org/abs/2405.14458v1 +- github: https://github.com/THU-MIG/yolov10 + +### Installation +``` +pip install git+https://github.com/THU-MIG/yolov10.git +``` + +### Training and validation +```python +from ultralytics import YOLOv10 + +model = YOLOv10.from_pretrained('jameslahm/yolov10n') +# Training +model.train(...) +# after training, one can push to the hub +model.push_to_hub("your-hf-username/yolov10-finetuned") + +# Validation +model.val(...) +``` + +### Inference + +Here's an end-to-end example showcasing inference on a cats image: + +```python +from ultralytics import YOLOv10 + +model = YOLOv10.from_pretrained('jameslahm/yolov10n') +source = 'http://images.cocodataset.org/val2017/000000039769.jpg' +model.predict(source=source, save=True) +``` +which shows: + +![image/png](https://cdn-uploads.huggingface.co/production/uploads/628ece6054698ce61d1e7be3/tBwAsKcQA_96HCYQp7BRr.png) + +### BibTeX Entry and Citation Info +``` +@article{wang2024yolov10, + title={YOLOv10: Real-Time End-to-End Object Detection}, + author={Wang, Ao and Chen, Hui and Liu, Lihao and Chen, Kai and Lin, Zijia and Han, Jungong and Ding, Guiguang}, + journal={arXiv preprint arXiv:2405.14458}, + year={2024} +} +``` +""".strip() \ No newline at end of file diff --git a/ultralytics/models/yolov10/model.py b/ultralytics/models/yolov10/model.py new file mode 100644 index 0000000..09592c8 --- /dev/null +++ b/ultralytics/models/yolov10/model.py @@ -0,0 +1,36 @@ +from ultralytics.engine.model import Model +from ultralytics.nn.tasks import YOLOv10DetectionModel +from .val import YOLOv10DetectionValidator +from .predict import YOLOv10DetectionPredictor +from .train import YOLOv10DetectionTrainer + +from huggingface_hub import PyTorchModelHubMixin +from .card import card_template_text + +class YOLOv10(Model, PyTorchModelHubMixin, model_card_template=card_template_text): + + def __init__(self, model="yolov10n.pt", task=None, verbose=False, + names=None): + super().__init__(model=model, task=task, verbose=verbose) + if names is not None: + setattr(self.model, 'names', names) + + def push_to_hub(self, repo_name, **kwargs): + config = kwargs.get('config', {}) + config['names'] = self.names + config['model'] = self.model.yaml['yaml_file'] + config['task'] = self.task + kwargs['config'] = config + super().push_to_hub(repo_name, **kwargs) + + @property + def task_map(self): + """Map head to model, trainer, validator, and predictor classes.""" + return { + "detect": { + "model": YOLOv10DetectionModel, + "trainer": YOLOv10DetectionTrainer, + "validator": YOLOv10DetectionValidator, + "predictor": YOLOv10DetectionPredictor, + }, + } \ No newline at end of file diff --git a/ultralytics/models/yolov10/predict.py b/ultralytics/models/yolov10/predict.py new file mode 100644 index 0000000..77644e9 --- /dev/null +++ b/ultralytics/models/yolov10/predict.py @@ -0,0 +1,38 @@ +from ultralytics.models.yolo.detect import DetectionPredictor +import torch +from ultralytics.utils import ops +from ultralytics.engine.results import Results + + +class YOLOv10DetectionPredictor(DetectionPredictor): + def postprocess(self, preds, img, orig_imgs): + if isinstance(preds, dict): + preds = preds["one2one"] + + if isinstance(preds, (list, tuple)): + preds = preds[0] + + if preds.shape[-1] == 6: + pass + else: + preds = preds.transpose(-1, -2) + bboxes, scores, 
labels = ops.v10postprocess(preds, self.args.max_det, preds.shape[-1]-4) + bboxes = ops.xywh2xyxy(bboxes) + preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) + + mask = preds[..., 4] > self.args.conf + if self.args.classes is not None: + mask = mask & (preds[..., 5:6] == torch.tensor(self.args.classes, device=preds.device).unsqueeze(0)).any(2) + + preds = [p[mask[idx]] for idx, p in enumerate(preds)] + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for i, pred in enumerate(preds): + orig_img = orig_imgs[i] + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + img_path = self.batch[0][i] + results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred)) + return results diff --git a/ultralytics/models/yolov10/train.py b/ultralytics/models/yolov10/train.py new file mode 100644 index 0000000..7305bca --- /dev/null +++ b/ultralytics/models/yolov10/train.py @@ -0,0 +1,20 @@ +from ultralytics.models.yolo.detect import DetectionTrainer +from .val import YOLOv10DetectionValidator +from .model import YOLOv10DetectionModel +from copy import copy +from ultralytics.utils import RANK + +class YOLOv10DetectionTrainer(DetectionTrainer): + def get_validator(self): + """Returns a DetectionValidator for YOLO model validation.""" + self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", + return YOLOv10DetectionValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return a YOLO detection model.""" + model = YOLOv10DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + return model diff --git a/ultralytics/models/yolov10/val.py b/ultralytics/models/yolov10/val.py new file mode 100644 index 0000000..19a019c --- /dev/null +++ b/ultralytics/models/yolov10/val.py @@ -0,0 +1,24 @@ +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import ops +import torch + +class YOLOv10DetectionValidator(DetectionValidator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.args.save_json |= self.is_coco + + def postprocess(self, preds): + if isinstance(preds, dict): + preds = preds["one2one"] + + if isinstance(preds, (list, tuple)): + preds = preds[0] + + # Acknowledgement: Thanks to sanha9999 in #190 and #181! 
+ if preds.shape[-1] == 6: + return preds + else: + preds = preds.transpose(-1, -2) + boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc) + bboxes = ops.xywh2xyxy(boxes) + return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) \ No newline at end of file diff --git a/ultralytics/nn/__init__.py b/ultralytics/nn/__init__.py index 9889b7e..6905d34 100644 --- a/ultralytics/nn/__init__.py +++ b/ultralytics/nn/__init__.py @@ -1,9 +1,29 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, - attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load, - yaml_model_load) +from .tasks import ( + BaseModel, + ClassificationModel, + DetectionModel, + SegmentationModel, + attempt_load_one_weight, + attempt_load_weights, + guess_model_scale, + guess_model_task, + parse_model, + torch_safe_load, + yaml_model_load, +) -__all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task', - 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel', - 'BaseModel') +__all__ = ( + "attempt_load_one_weight", + "attempt_load_weights", + "parse_model", + "yaml_model_load", + "guess_model_task", + "guess_model_scale", + "torch_safe_load", + "DetectionModel", + "SegmentationModel", + "ClassificationModel", + "BaseModel", +) diff --git a/ultralytics/nn/__pycache__/__init__.cpython-312.pyc b/ultralytics/nn/__pycache__/__init__.cpython-312.pyc index c6e62c9..2e2cef5 100644 Binary files a/ultralytics/nn/__pycache__/__init__.cpython-312.pyc and b/ultralytics/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/nn/__pycache__/__init__.cpython-39.pyc b/ultralytics/nn/__pycache__/__init__.cpython-39.pyc index 40bf2cf..0bef79a 100644 Binary files a/ultralytics/nn/__pycache__/__init__.cpython-39.pyc and b/ultralytics/nn/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/nn/__pycache__/autobackend.cpython-312.pyc b/ultralytics/nn/__pycache__/autobackend.cpython-312.pyc index 966828c..975c422 100644 Binary files a/ultralytics/nn/__pycache__/autobackend.cpython-312.pyc and b/ultralytics/nn/__pycache__/autobackend.cpython-312.pyc differ diff --git a/ultralytics/nn/__pycache__/autobackend.cpython-39.pyc b/ultralytics/nn/__pycache__/autobackend.cpython-39.pyc index 594a614..3f6ef4e 100644 Binary files a/ultralytics/nn/__pycache__/autobackend.cpython-39.pyc and b/ultralytics/nn/__pycache__/autobackend.cpython-39.pyc differ diff --git a/ultralytics/nn/__pycache__/tasks.cpython-312.pyc b/ultralytics/nn/__pycache__/tasks.cpython-312.pyc index 2640d75..3d17fab 100644 Binary files a/ultralytics/nn/__pycache__/tasks.cpython-312.pyc and b/ultralytics/nn/__pycache__/tasks.cpython-312.pyc differ diff --git a/ultralytics/nn/__pycache__/tasks.cpython-39.pyc b/ultralytics/nn/__pycache__/tasks.cpython-39.pyc index 519c646..60201b2 100644 Binary files a/ultralytics/nn/__pycache__/tasks.cpython-39.pyc and b/ultralytics/nn/__pycache__/tasks.cpython-39.pyc differ diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 9010815..abd255c 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -7,7 +7,6 @@ import platform import zipfile from collections import OrderedDict, namedtuple from pathlib import Path -from urllib.parse import urlparse import cv2 import numpy as np @@ -21,7 +20,11 @@ from 
ultralytics.utils.downloads import attempt_download_asset, is_url def check_class_names(names): - """Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.""" + """ + Check class names. + + Map imagenet class codes to human-readable names if required. Convert lists to dicts. + """ if isinstance(names, list): # names is a list names = dict(enumerate(names)) # convert to dict if isinstance(names, dict): @@ -29,44 +32,39 @@ def check_class_names(names): names = {int(k): str(v) for k, v in names.items()} n = len(names) if max(names.keys()) >= n: - raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices ' - f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.') - if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764' - map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names - names = {k: map[v] for k, v in names.items()} + raise KeyError( + f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices " + f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML." + ) + if isinstance(names[0], str) and names[0].startswith("n0"): # imagenet class codes, i.e. 'n01440764' + names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"] # human-readable names + names = {k: names_map[v] for k, v in names.items()} return names +def default_class_names(data=None): + """Applies default class names to an input YAML file or returns numerical class names.""" + if data: + with contextlib.suppress(Exception): + return yaml_load(check_yaml(data))["names"] + return {i: f"class{i}" for i in range(999)} # return default if above errors + + class AutoBackend(nn.Module): + """ + Handles dynamic backend selection for running inference using Ultralytics YOLO models. - def __init__(self, - weights='yolov8n.pt', - device=torch.device('cpu'), - dnn=False, - data=None, - fp16=False, - fuse=True, - verbose=True): - """ - MultiBackend class for python inference on various platforms using Ultralytics YOLO. + The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide + range of formats, each with specific naming conventions as outlined below: - Args: - weights (str): The path to the weights file. Default: 'yolov8n.pt' - device (torch.device): The device to run the model on. - dnn (bool): Use OpenCV DNN module for inference if True, defaults to False. - data (str | Path | optional): Additional data.yaml file for class names. - fp16 (bool): If True, use half precision. Default: False - fuse (bool): Whether to fuse the model or not. Default: True - verbose (bool): Whether to run in verbose mode or not. 
Default: True - - Supported formats and their naming conventions: - | Format | Suffix | + Supported Formats and Naming Conventions: + | Format | File Suffix | |-----------------------|------------------| | PyTorch | *.pt | | TorchScript | *.torchscript | | ONNX Runtime | *.onnx | - | ONNX OpenCV DNN | *.onnx dnn=True | - | OpenVINO | *.xml | + | ONNX OpenCV DNN | *.onnx (dnn=True)| + | OpenVINO | *openvino_model/ | | CoreML | *.mlpackage | | TensorRT | *.engine | | TensorFlow SavedModel | *_saved_model | @@ -74,103 +72,166 @@ class AutoBackend(nn.Module): | TensorFlow Lite | *.tflite | | TensorFlow Edge TPU | *_edgetpu.tflite | | PaddlePaddle | *_paddle_model | - | ncnn | *_ncnn_model | + | NCNN | *_ncnn_model | + + This class offers dynamic backend switching capabilities based on the input model format, making it easier to deploy + models across various platforms. + """ + + @torch.no_grad() + def __init__( + self, + weights="yolov8n.pt", + device=torch.device("cpu"), + dnn=False, + data=None, + fp16=False, + batch=1, + fuse=True, + verbose=True, + ): + """ + Initialize the AutoBackend for inference. + + Args: + weights (str): Path to the model weights file. Defaults to 'yolov8n.pt'. + device (torch.device): Device to run the model on. Defaults to CPU. + dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False. + data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional. + fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False. + batch (int): Batch-size to assume for inference. + fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True. + verbose (bool): Enable verbose logging. Defaults to True. """ super().__init__() w = str(weights[0] if isinstance(weights, list) else weights) nn_module = isinstance(weights, torch.nn.Module) - pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \ - self._model_type(w) + ( + pt, + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + ncnn, + triton, + ) = self._model_type(w) fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) stride = 32 # default stride model, metadata = None, None # Set device - cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA - if cuda and not any([nn_module, pt, jit, engine]): # GPU dataloader formats - device = torch.device('cpu') + cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA + if cuda and not any([nn_module, pt, jit, engine, onnx]): # GPU dataloader formats + device = torch.device("cpu") cuda = False # Download if not local if not (pt or triton or nn_module): w = attempt_download_asset(w) - # Load model - if nn_module: # in-memory PyTorch model + # In-memory PyTorch model + if nn_module: model = weights.to(device) model = model.fuse(verbose=verbose) if fuse else model - if hasattr(model, 'kpt_shape'): + if hasattr(model, "kpt_shape"): kpt_shape = model.kpt_shape # pose-only stride = max(int(model.stride.max()), 32) # model stride - names = model.module.names if hasattr(model, 'module') else model.names # get class names + names = model.module.names if hasattr(model, "module") else model.names # get class names model.half() if fp16 else model.float() self.model = model # explicitly assign for to(), cpu(), cuda(), half() pt = True - elif pt: # PyTorch + + # PyTorch + 
elif pt: from ultralytics.nn.tasks import attempt_load_weights - model = attempt_load_weights(weights if isinstance(weights, list) else w, - device=device, - inplace=True, - fuse=fuse) - if hasattr(model, 'kpt_shape'): + + model = attempt_load_weights( + weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse + ) + if hasattr(model, "kpt_shape"): kpt_shape = model.kpt_shape # pose-only stride = max(int(model.stride.max()), 32) # model stride - names = model.module.names if hasattr(model, 'module') else model.names # get class names + names = model.module.names if hasattr(model, "module") else model.names # get class names model.half() if fp16 else model.float() self.model = model # explicitly assign for to(), cpu(), cuda(), half() - elif jit: # TorchScript - LOGGER.info(f'Loading {w} for TorchScript inference...') - extra_files = {'config.txt': ''} # model metadata + + # TorchScript + elif jit: + LOGGER.info(f"Loading {w} for TorchScript inference...") + extra_files = {"config.txt": ""} # model metadata model = torch.jit.load(w, _extra_files=extra_files, map_location=device) model.half() if fp16 else model.float() - if extra_files['config.txt']: # load metadata dict - metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items())) - elif dnn: # ONNX OpenCV DNN - LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...') - check_requirements('opencv-python>=4.5.4') + if extra_files["config.txt"]: # load metadata dict + metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items())) + + # ONNX OpenCV DNN + elif dnn: + LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") + check_requirements("opencv-python>=4.5.4") net = cv2.dnn.readNetFromONNX(w) - elif onnx: # ONNX Runtime - LOGGER.info(f'Loading {w} for ONNX Runtime inference...') - check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime')) + + # ONNX Runtime + elif onnx: + LOGGER.info(f"Loading {w} for ONNX Runtime inference...") + check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) import onnxruntime - providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] + + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"] session = onnxruntime.InferenceSession(w, providers=providers) output_names = [x.name for x in session.get_outputs()] - metadata = session.get_modelmeta().custom_metadata_map # metadata - elif xml: # OpenVINO - LOGGER.info(f'Loading {w} for OpenVINO inference...') - check_requirements('openvino>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/ - from openvino.runtime import Core, Layout, get_batch # noqa - core = Core() + metadata = session.get_modelmeta().custom_metadata_map + + # OpenVINO + elif xml: + LOGGER.info(f"Loading {w} for OpenVINO inference...") + check_requirements("openvino>=2024.0.0") + import openvino as ov + + core = ov.Core() w = Path(w) if not w.is_file(): # if not *.xml - w = next(w.glob('*.xml')) # get *.xml file from *_openvino_model dir - ov_model = core.read_model(model=str(w), weights=w.with_suffix('.bin')) + w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir + ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) if ov_model.get_parameters()[0].get_layout().empty: - ov_model.get_parameters()[0].set_layout(Layout('NCHW')) - batch_dim = get_batch(ov_model) - if batch_dim.is_static: - batch_size = batch_dim.get_length() - 
ov_compiled_model = core.compile_model(ov_model, device_name='AUTO') # AUTO selects best available device - metadata = w.parent / 'metadata.yaml' - elif engine: # TensorRT - LOGGER.info(f'Loading {w} for TensorRT inference...') + ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW")) + + # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT' + inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY" + LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch={batch} inference...") + ov_compiled_model = core.compile_model( + ov_model, + device_name="AUTO", # AUTO selects best available device, do not modify + config={"PERFORMANCE_HINT": inference_mode}, + ) + input_name = ov_compiled_model.input().get_any_name() + metadata = w.parent / "metadata.yaml" + + # TensorRT + elif engine: + LOGGER.info(f"Loading {w} for TensorRT inference...") try: import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download except ImportError: if LINUX: - check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com') + check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com") import tensorrt as trt # noqa - check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0 - if device.type == 'cpu': - device = torch.device('cuda:0') - Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) + check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0 + if device.type == "cpu": + device = torch.device("cuda:0") + Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) logger = trt.Logger(trt.Logger.INFO) # Read file - with open(w, 'rb') as f, trt.Runtime(logger) as runtime: - meta_len = int.from_bytes(f.read(4), byteorder='little') # read metadata length - metadata = json.loads(f.read(meta_len).decode('utf-8')) # read metadata + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length + metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata model = runtime.deserialize_cuda_engine(f.read()) # read engine context = model.create_execution_context() bindings = OrderedDict() @@ -192,126 +253,152 @@ class AutoBackend(nn.Module): im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) - batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size - elif coreml: # CoreML - LOGGER.info(f'Loading {w} for CoreML inference...') + batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size + + # CoreML + elif coreml: + LOGGER.info(f"Loading {w} for CoreML inference...") import coremltools as ct + model = ct.models.MLModel(w) metadata = dict(model.user_defined_metadata) - elif saved_model: # TF SavedModel - LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...') + + # TF SavedModel + elif saved_model: + LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") import tensorflow as tf + keras = False # assume TF1 saved_model model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) - metadata = Path(w) / 'metadata.yaml' - elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt - LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...') + metadata = Path(w) / 
"metadata.yaml" + + # TF GraphDef + elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") import tensorflow as tf from ultralytics.engine.exporter import gd_outputs def wrap_frozen_graph(gd, inputs, outputs): """Wrap frozen graphs for deployment.""" - x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped + x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped ge = x.graph.as_graph_element return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) gd = tf.Graph().as_graph_def() # TF GraphDef - with open(w, 'rb') as f: + with open(w, "rb") as f: gd.ParseFromString(f.read()) - frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd)) + frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) + + # TFLite or TFLite Edge TPU elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu from tflite_runtime.interpreter import Interpreter, load_delegate except ImportError: import tensorflow as tf + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime - LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...') - delegate = { - 'Linux': 'libedgetpu.so.1', - 'Darwin': 'libedgetpu.1.dylib', - 'Windows': 'edgetpu.dll'}[platform.system()] + LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...") + delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[ + platform.system() + ] interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)]) else: # TFLite - LOGGER.info(f'Loading {w} for TensorFlow Lite inference...') + LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") interpreter = Interpreter(model_path=w) # load TFLite model interpreter.allocate_tensors() # allocate input_details = interpreter.get_input_details() # inputs output_details = interpreter.get_output_details() # outputs # Load metadata with contextlib.suppress(zipfile.BadZipFile): - with zipfile.ZipFile(w, 'r') as model: + with zipfile.ZipFile(w, "r") as model: meta_file = model.namelist()[0] - metadata = ast.literal_eval(model.read(meta_file).decode('utf-8')) - elif tfjs: # TF.js - raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.') - elif paddle: # PaddlePaddle - LOGGER.info(f'Loading {w} for PaddlePaddle inference...') - check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle') + metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) + + # TF.js + elif tfjs: + raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.") + + # PaddlePaddle + elif paddle: + LOGGER.info(f"Loading {w} for PaddlePaddle inference...") + check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") import paddle.inference as pdi # noqa + w = Path(w) if not w.is_file(): # if not *.pdmodel - w = next(w.rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir - config = pdi.Config(str(w), str(w.with_suffix('.pdiparams'))) + w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir + config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) if cuda: 
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) predictor = pdi.create_predictor(config) input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) output_names = predictor.get_output_names() - metadata = w.parents[1] / 'metadata.yaml' - elif ncnn: # ncnn - LOGGER.info(f'Loading {w} for ncnn inference...') - check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn + metadata = w.parents[1] / "metadata.yaml" + + # NCNN + elif ncnn: + LOGGER.info(f"Loading {w} for NCNN inference...") + check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN import ncnn as pyncnn + net = pyncnn.Net() net.opt.use_vulkan_compute = cuda w = Path(w) if not w.is_file(): # if not *.param - w = next(w.glob('*.param')) # get *.param file from *_ncnn_model dir + w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir net.load_param(str(w)) - net.load_model(str(w.with_suffix('.bin'))) - metadata = w.parent / 'metadata.yaml' - elif triton: # NVIDIA Triton Inference Server - """TODO - check_requirements('tritonclient[all]') - from utils.triton import TritonRemoteModel - model = TritonRemoteModel(url=w) - nhwc = model.runtime.startswith("tensorflow") - """ - raise NotImplementedError('Triton Inference Server is not currently supported.') + net.load_model(str(w.with_suffix(".bin"))) + metadata = w.parent / "metadata.yaml" + + # NVIDIA Triton Inference Server + elif triton: + check_requirements("tritonclient[all]") + from ultralytics.utils.triton import TritonRemoteModel + + model = TritonRemoteModel(w) + + # Any other format (unsupported) else: from ultralytics.engine.exporter import export_formats - raise TypeError(f"model='{w}' is not a supported model format. " - 'See https://docs.ultralytics.com/modes/predict for help.' - f'\n\n{export_formats()}') + + raise TypeError( + f"model='{w}' is not a supported model format. " + f"See https://docs.ultralytics.com/modes/predict for help.\n\n{export_formats()}" + ) # Load external metadata YAML if isinstance(metadata, (str, Path)) and Path(metadata).exists(): metadata = yaml_load(metadata) if metadata: for k, v in metadata.items(): - if k in ('stride', 'batch'): + if k in ("stride", "batch"): metadata[k] = int(v) - elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str): + elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str): metadata[k] = eval(v) - stride = metadata['stride'] - task = metadata['task'] - batch = metadata['batch'] - imgsz = metadata['imgsz'] - names = metadata['names'] - kpt_shape = metadata.get('kpt_shape') + stride = metadata["stride"] + task = metadata["task"] + batch = metadata["batch"] + imgsz = metadata["imgsz"] + names = metadata["names"] + kpt_shape = metadata.get("kpt_shape") elif not (pt or triton or nn_module): LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") # Check names - if 'names' not in locals(): # names missing - names = self._apply_default_class_names(data) + if "names" not in locals(): # names missing + names = default_class_names(data) names = check_class_names(names) + # Disable gradients + if pt: + for p in model.parameters(): + p.requires_grad = False + self.__dict__.update(locals()) # assign all variables to self - def forward(self, im, augment=False, visualize=False): + def forward(self, im, augment=False, visualize=False, embed=None): """ Runs inference on the YOLOv8 MultiBackend model. 
@@ -319,6 +406,7 @@ class AutoBackend(nn.Module): im (torch.Tensor): The image tensor to perform inference on. augment (bool): whether to perform data augmentation during inference, defaults to False visualize (bool): whether to visualize the output predictions, defaults to False + embed (list, optional): A list of feature vectors/embeddings to return. Returns: (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True) @@ -329,41 +417,75 @@ class AutoBackend(nn.Module): if self.nhwc: im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) - if self.pt or self.nn_module: # PyTorch - y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im) - elif self.jit: # TorchScript + # PyTorch + if self.pt or self.nn_module: + y = self.model(im, augment=augment, visualize=visualize, embed=embed) + + # TorchScript + elif self.jit: y = self.model(im) - elif self.dnn: # ONNX OpenCV DNN + + # ONNX OpenCV DNN + elif self.dnn: im = im.cpu().numpy() # torch to numpy self.net.setInput(im) y = self.net.forward() - elif self.onnx: # ONNX Runtime + + # ONNX Runtime + elif self.onnx: im = im.cpu().numpy() # torch to numpy y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im}) - elif self.xml: # OpenVINO + + # OpenVINO + elif self.xml: im = im.cpu().numpy() # FP32 - y = list(self.ov_compiled_model(im).values()) - elif self.engine: # TensorRT - if self.dynamic and im.shape != self.bindings['images'].shape: - i = self.model.get_binding_index('images') + + if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes + n = im.shape[0] # number of images in batch + results = [None] * n # preallocate list with None to match the number of images + + def callback(request, userdata): + """Places result in preallocated list using userdata index.""" + results[userdata] = request.results + + # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image + async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model) + async_queue.set_callback(callback) + for i in range(n): + # Start async inference with userdata=i to specify the position in results list + async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW + async_queue.wait_all() # wait for all inference requests to complete + y = np.concatenate([list(r.values())[0] for r in results]) + + else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1 + y = list(self.ov_compiled_model(im).values()) + + # TensorRT + elif self.engine: + if self.dynamic and im.shape != self.bindings["images"].shape: + i = self.model.get_binding_index("images") self.context.set_binding_shape(i, im.shape) # reshape if dynamic - self.bindings['images'] = self.bindings['images']._replace(shape=im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) for name in self.output_names: i = self.model.get_binding_index(name) self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) - s = self.bindings['images'].shape + s = self.bindings["images"].shape assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" - self.binding_addrs['images'] = int(im.data_ptr()) + self.binding_addrs["images"] = int(im.data_ptr()) self.context.execute_v2(list(self.binding_addrs.values())) y = [self.bindings[x].data for x in 
sorted(self.output_names)] - elif self.coreml: # CoreML + + # CoreML + elif self.coreml: im = im[0].cpu().numpy() - im_pil = Image.fromarray((im * 255).astype('uint8')) + im_pil = Image.fromarray((im * 255).astype("uint8")) # im = im.resize((192, 320), Image.BILINEAR) - y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized - if 'confidence' in y: - raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with ' - f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.") + y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized + if "confidence" in y: + raise TypeError( + "Ultralytics only supports inference of non-pipelined CoreML models exported with " + f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export." + ) # TODO: CoreML NMS inference handling # from ultralytics.utils.ops import xywh2xyxy # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels @@ -373,24 +495,28 @@ class AutoBackend(nn.Module): y = list(y.values()) elif len(y) == 2: # segmentation model y = list(reversed(y.values())) # reversed for segmentation models (pred, proto) - elif self.paddle: # PaddlePaddle + + # PaddlePaddle + elif self.paddle: im = im.cpu().numpy().astype(np.float32) self.input_handle.copy_from_cpu(im) self.predictor.run() y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names] - elif self.ncnn: # ncnn + + # NCNN + elif self.ncnn: mat_in = self.pyncnn.Mat(im[0].cpu().numpy()) - ex = self.net.create_extractor() - input_names, output_names = self.net.input_names(), self.net.output_names() - ex.input(input_names[0], mat_in) - y = [] - for output_name in output_names: - mat_out = self.pyncnn.Mat() - ex.extract(output_name, mat_out) - y.append(np.array(mat_out)[None]) - elif self.triton: # NVIDIA Triton Inference Server + with self.net.create_extractor() as ex: + ex.input(self.net.input_names()[0], mat_in) + y = [np.array(ex.extract(x)[1])[None] for x in self.net.output_names()] + + # NVIDIA Triton Inference Server + elif self.triton: + im = im.cpu().numpy() # torch to numpy y = self.model(im) - else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + + # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + else: im = im.cpu().numpy() if self.saved_model: # SavedModel y = self.model(im, training=False) if self.keras else self.model(im) @@ -401,20 +527,20 @@ class AutoBackend(nn.Module): if len(y) == 2 and len(self.names) == 999: # segments and names not defined ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400) - self.names = {i: f'class{i}' for i in range(nc)} + self.names = {i: f"class{i}" for i in range(nc)} else: # Lite or Edge TPU details = self.input_details[0] - integer = details['dtype'] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model + integer = details["dtype"] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model if integer: - scale, zero_point = details['quantization'] - im = (im / scale + zero_point).astype(details['dtype']) # de-scale - self.interpreter.set_tensor(details['index'], im) + scale, zero_point = details["quantization"] + im = (im / scale + zero_point).astype(details["dtype"]) # de-scale + self.interpreter.set_tensor(details["index"], im) self.interpreter.invoke() y = [] for output in self.output_details: - x = self.interpreter.get_tensor(output['index']) + x = 
self.interpreter.get_tensor(output["index"]) if integer: - scale, zero_point = output['quantization'] + scale, zero_point = output["quantization"] x = (x.astype(np.float32) - zero_point) * scale # re-scale if x.ndim > 2: # if task is not classification # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695 @@ -438,14 +564,14 @@ class AutoBackend(nn.Module): def from_numpy(self, x): """ - Convert a numpy array to a tensor. + Convert a numpy array to a tensor. - Args: - x (np.ndarray): The array to be converted. + Args: + x (np.ndarray): The array to be converted. - Returns: - (torch.Tensor): The converted tensor - """ + Returns: + (torch.Tensor): The converted tensor + """ return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x def warmup(self, imgsz=(1, 3, 640, 640)): @@ -454,44 +580,41 @@ class AutoBackend(nn.Module): Args: imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width) - - Returns: - (None): This method runs the forward pass and don't return any value """ warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module - if any(warmup_types) and (self.device.type != 'cpu' or self.triton): + if any(warmup_types) and (self.device.type != "cpu" or self.triton): im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input - for _ in range(2 if self.jit else 1): # + for _ in range(2 if self.jit else 1): self.forward(im) # warmup @staticmethod - def _apply_default_class_names(data): - """Applies default class names to an input YAML file or returns numerical class names.""" - with contextlib.suppress(Exception): - return yaml_load(check_yaml(data))['names'] - return {i: f'class{i}' for i in range(999)} # return default if above errors - - @staticmethod - def _model_type(p='path/to/model.pt'): + def _model_type(p="path/to/model.pt"): """ - This function takes a path to a model file and returns the model type + This function takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, + engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle. Args: p: path to the model file. Defaults to path/to/model.pt + + Examples: + >>> model = AutoBackend(weights="path/to/model.onnx") + >>> model_type = model._model_type() # returns "onnx" """ - # Return model type from model path, i.e. 
path='path/to/model.onnx' -> type=onnx - # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle] from ultralytics.engine.exporter import export_formats + sf = list(export_formats().Suffix) # export suffixes - if not is_url(p, check=False) and not isinstance(p, str): + if not is_url(p) and not isinstance(p, str): check_suffix(p, sf) # checks name = Path(p).name types = [s in name for s in sf] - types[5] |= name.endswith('.mlmodel') # retain support for older Apple CoreML *.mlmodel formats + types[5] |= name.endswith(".mlmodel") # retain support for older Apple CoreML *.mlmodel formats types[8] &= not types[9] # tflite &= not edgetpu if any(types): triton = False else: - url = urlparse(p) # if url may be Triton inference server - triton = all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc]) + from urllib.parse import urlsplit + + url = urlsplit(p) + triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"} + return types + [triton] diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py index b6dc6c4..7f4c4fe 100644 --- a/ultralytics/nn/modules/__init__.py +++ b/ultralytics/nn/modules/__init__.py @@ -1,29 +1,147 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license """ -Ultralytics modules. Visualize with: +Ultralytics modules. -from ultralytics.nn.modules import * -import torch -import os +Example: + Visualize a module with Netron. + ```python + from ultralytics.nn.modules import * + import torch + import os -x = torch.ones(1, 128, 40, 40) -m = Conv(128, 128) -f = f'{m._get_name()}.onnx' -torch.onnx.export(m, x, f) -os.system(f'onnxsim {f} {f} && open {f}') + x = torch.ones(1, 128, 40, 40) + m = Conv(128, 128) + f = f'{m._get_name()}.onnx' + torch.onnx.export(m, x, f) + os.system(f'onnxslim {f} {f} && open {f}') # pip install onnxslim + ``` """ -from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck, - HGBlock, HGStem, Proto, RepC3) -from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, - GhostConv, LightConv, RepConv, SpatialAttention) -from .head import Classify, Detect, Pose, RTDETRDecoder, Segment -from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, - MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer) +from .block import ( + C1, + C2, + C3, + C3TR, + DFL, + SPP, + SPPF, + Bottleneck, + BottleneckCSP, + C2f, + C2fAttn, + ImagePoolingAttn, + C3Ghost, + C3x, + GhostBottleneck, + HGBlock, + HGStem, + Proto, + RepC3, + ResNetLayer, + ContrastiveHead, + BNContrastiveHead, + RepNCSPELAN4, + ADown, + SPPELAN, + CBFuse, + CBLinear, + Silence, + PSA, + C2fCIB, + SCDown, + RepVGGDW +) +from .conv import ( + CBAM, + ChannelAttention, + Concat, + Conv, + Conv2, + ConvTranspose, + DWConv, + DWConvTranspose2d, + Focus, + GhostConv, + LightConv, + RepConv, + SpatialAttention, +) +from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect +from .transformer import ( + AIFI, + MLP, + DeformableTransformerDecoder, + DeformableTransformerDecoderLayer, + LayerNorm2d, + MLPBlock, + MSDeformAttn, + TransformerBlock, + TransformerEncoderLayer, + TransformerLayer, +) -__all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', - 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', - 
'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', - 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', - 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI', - 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP') +__all__ = ( + "Conv", + "Conv2", + "LightConv", + "RepConv", + "DWConv", + "DWConvTranspose2d", + "ConvTranspose", + "Focus", + "GhostConv", + "ChannelAttention", + "SpatialAttention", + "CBAM", + "Concat", + "TransformerLayer", + "TransformerBlock", + "MLPBlock", + "LayerNorm2d", + "DFL", + "HGBlock", + "HGStem", + "SPP", + "SPPF", + "C1", + "C2", + "C3", + "C2f", + "C2fAttn", + "C3x", + "C3TR", + "C3Ghost", + "GhostBottleneck", + "Bottleneck", + "BottleneckCSP", + "Proto", + "Detect", + "Segment", + "Pose", + "Classify", + "TransformerEncoderLayer", + "RepC3", + "RTDETRDecoder", + "AIFI", + "DeformableTransformerDecoder", + "DeformableTransformerDecoderLayer", + "MSDeformAttn", + "MLP", + "ResNetLayer", + "OBB", + "WorldDetect", + "ImagePoolingAttn", + "ContrastiveHead", + "BNContrastiveHead", + "RepNCSPELAN4", + "ADown", + "SPPELAN", + "CBFuse", + "CBLinear", + "Silence", + "PSA", + "C2fCIB", + "SCDown", + "RepVGGDW", + "v10Detect" +) diff --git a/ultralytics/nn/modules/__pycache__/__init__.cpython-312.pyc b/ultralytics/nn/modules/__pycache__/__init__.cpython-312.pyc index 5bd4b3d..1b13305 100644 Binary files a/ultralytics/nn/modules/__pycache__/__init__.cpython-312.pyc and b/ultralytics/nn/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/__init__.cpython-39.pyc b/ultralytics/nn/modules/__pycache__/__init__.cpython-39.pyc index be57fd6..8afb5eb 100644 Binary files a/ultralytics/nn/modules/__pycache__/__init__.cpython-39.pyc and b/ultralytics/nn/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/block.cpython-312.pyc b/ultralytics/nn/modules/__pycache__/block.cpython-312.pyc index 8913826..d0a03bd 100644 Binary files a/ultralytics/nn/modules/__pycache__/block.cpython-312.pyc and b/ultralytics/nn/modules/__pycache__/block.cpython-312.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/block.cpython-39.pyc b/ultralytics/nn/modules/__pycache__/block.cpython-39.pyc index 912cb7c..56ce422 100644 Binary files a/ultralytics/nn/modules/__pycache__/block.cpython-39.pyc and b/ultralytics/nn/modules/__pycache__/block.cpython-39.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/conv.cpython-312.pyc b/ultralytics/nn/modules/__pycache__/conv.cpython-312.pyc index 4c92622..218ceb8 100644 Binary files a/ultralytics/nn/modules/__pycache__/conv.cpython-312.pyc and b/ultralytics/nn/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/conv.cpython-39.pyc b/ultralytics/nn/modules/__pycache__/conv.cpython-39.pyc index 53e2e8e..c967af1 100644 Binary files a/ultralytics/nn/modules/__pycache__/conv.cpython-39.pyc and b/ultralytics/nn/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/head.cpython-312.pyc b/ultralytics/nn/modules/__pycache__/head.cpython-312.pyc index d5773fc..97dab04 100644 Binary files a/ultralytics/nn/modules/__pycache__/head.cpython-312.pyc and b/ultralytics/nn/modules/__pycache__/head.cpython-312.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/head.cpython-39.pyc 
b/ultralytics/nn/modules/__pycache__/head.cpython-39.pyc index 01ba79e..35edea3 100644 Binary files a/ultralytics/nn/modules/__pycache__/head.cpython-39.pyc and b/ultralytics/nn/modules/__pycache__/head.cpython-39.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/transformer.cpython-312.pyc b/ultralytics/nn/modules/__pycache__/transformer.cpython-312.pyc index 23ed098..941b74a 100644 Binary files a/ultralytics/nn/modules/__pycache__/transformer.cpython-312.pyc and b/ultralytics/nn/modules/__pycache__/transformer.cpython-312.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/transformer.cpython-39.pyc b/ultralytics/nn/modules/__pycache__/transformer.cpython-39.pyc index 9f832c0..565233c 100644 Binary files a/ultralytics/nn/modules/__pycache__/transformer.cpython-39.pyc and b/ultralytics/nn/modules/__pycache__/transformer.cpython-39.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/utils.cpython-312.pyc b/ultralytics/nn/modules/__pycache__/utils.cpython-312.pyc index 470ed3c..f5943b6 100644 Binary files a/ultralytics/nn/modules/__pycache__/utils.cpython-312.pyc and b/ultralytics/nn/modules/__pycache__/utils.cpython-312.pyc differ diff --git a/ultralytics/nn/modules/__pycache__/utils.cpython-39.pyc b/ultralytics/nn/modules/__pycache__/utils.cpython-39.pyc index 0a78c19..51ac323 100644 Binary files a/ultralytics/nn/modules/__pycache__/utils.cpython-39.pyc and b/ultralytics/nn/modules/__pycache__/utils.cpython-39.pyc differ diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py index d8183d8..d11c16e 100644 --- a/ultralytics/nn/modules/block.py +++ b/ultralytics/nn/modules/block.py @@ -1,22 +1,50 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Block modules -""" +"""Block modules.""" import torch import torch.nn as nn import torch.nn.functional as F -from .conv import Conv, DWConv, GhostConv, LightConv, RepConv +from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad from .transformer import TransformerBlock +from ultralytics.utils.torch_utils import fuse_conv_and_bn -__all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost', - 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3') +__all__ = ( + "DFL", + "HGBlock", + "HGStem", + "SPP", + "SPPF", + "C1", + "C2", + "C3", + "C2f", + "C2fAttn", + "ImagePoolingAttn", + "ContrastiveHead", + "BNContrastiveHead", + "C3x", + "C3TR", + "C3Ghost", + "GhostBottleneck", + "Bottleneck", + "BottleneckCSP", + "Proto", + "RepC3", + "ResNetLayer", + "RepNCSPELAN4", + "ADown", + "SPPELAN", + "CBFuse", + "CBLinear", + "Silence", +) class DFL(nn.Module): """ Integral module of Distribution Focal Loss (DFL). + Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391 """ @@ -30,7 +58,7 @@ class DFL(nn.Module): def forward(self, x): """Applies a transformer layer on input tensor 'x' and returns a tensor.""" - b, c, a = x.shape # batch, channels, anchors + b, _, a = x.shape # batch, channels, anchors return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a) @@ -38,7 +66,12 @@ class DFL(nn.Module): class Proto(nn.Module): """YOLOv8 mask Proto module for segmentation models.""" - def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks + def __init__(self, c1, c_=256, c2=32): + """ + Initializes the YOLOv8 mask Proto module with specified number of protos and masks. 
+ + Input arguments are ch_in, number of protos, number of masks. + """ super().__init__() self.cv1 = Conv(c1, c_, k=3) self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest') @@ -51,11 +84,14 @@ class Proto(nn.Module): class HGStem(nn.Module): - """StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d. + """ + StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d. + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py """ def __init__(self, c1, cm, c2): + """Initialize the SPP layer with input/output channels and specified kernel sizes for max pooling.""" super().__init__() self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU()) self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU()) @@ -79,11 +115,14 @@ class HGStem(nn.Module): class HGBlock(nn.Module): - """HG_Block of PPHGNetV2 with 2 convolutions and LightConv. + """ + HG_Block of PPHGNetV2 with 2 convolutions and LightConv. + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py """ def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()): + """Initializes a CSP Bottleneck with 1 convolution using specified input and output channels.""" super().__init__() block = LightConv if lightconv else Conv self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n)) @@ -119,7 +158,12 @@ class SPP(nn.Module): class SPPF(nn.Module): """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.""" - def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) + def __init__(self, c1, c2, k=5): + """ + Initializes the SPPF layer with given input/output channels and kernel size. + + This module is equivalent to SPP(k=(5, 9, 13)). + """ super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = Conv(c1, c_, 1, 1) @@ -137,7 +181,8 @@ class SPPF(nn.Module): class C1(nn.Module): """CSP Bottleneck with 1 convolution.""" - def __init__(self, c1, c2, n=1): # ch_in, ch_out, number + def __init__(self, c1, c2, n=1): + """Initializes the CSP Bottleneck with configurations for 1 convolution with arguments ch_in, ch_out, number.""" super().__init__() self.cv1 = Conv(c1, c2, 1, 1) self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n))) @@ -151,7 +196,10 @@ class C1(nn.Module): class C2(nn.Module): """CSP Bottleneck with 2 convolutions.""" - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes the CSP Bottleneck with 2 convolutions module with arguments ch_in, ch_out, number, shortcut, + groups, expansion. + """ super().__init__() self.c = int(c2 * e) # hidden channels self.cv1 = Conv(c1, 2 * self.c, 1, 1) @@ -168,7 +216,10 @@ class C2(nn.Module): class C2f(nn.Module): """Faster Implementation of CSP Bottleneck with 2 convolutions.""" - def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): + """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups, + expansion. 
+ """ super().__init__() self.c = int(c2 * e) # hidden channels self.cv1 = Conv(c1, 2 * self.c, 1, 1) @@ -191,7 +242,8 @@ class C2f(nn.Module): class C3(nn.Module): """CSP Bottleneck with 3 convolutions.""" - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initialize the CSP Bottleneck with given channels, number, shortcut, groups, and expansion values.""" super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, 1, 1) @@ -218,6 +270,7 @@ class RepC3(nn.Module): """Rep C3.""" def __init__(self, c1, c2, n=3, e=1.0): + """Initialize CSP Bottleneck with a single convolution using input channels, output channels, and number.""" super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c2, 1, 1) @@ -253,15 +306,18 @@ class C3Ghost(C3): class GhostBottleneck(nn.Module): """Ghost Bottleneck https://github.com/huawei-noah/ghostnet.""" - def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride + def __init__(self, c1, c2, k=3, s=1): + """Initializes GhostBottleneck module with arguments ch_in, ch_out, kernel, stride.""" super().__init__() c_ = c2 // 2 self.conv = nn.Sequential( GhostConv(c1, c_, 1, 1), # pw DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw - GhostConv(c_, c2, 1, 1, act=False)) # pw-linear - self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, - act=False)) if s == 2 else nn.Identity() + GhostConv(c_, c2, 1, 1, act=False), # pw-linear + ) + self.shortcut = ( + nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() + ) def forward(self, x): """Applies skip connection and concatenation to input tensor.""" @@ -271,7 +327,10 @@ class GhostBottleneck(nn.Module): class Bottleneck(nn.Module): """Standard bottleneck.""" - def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand + def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): + """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and + expansion. 
+ """ super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, k[0], 1) @@ -279,14 +338,15 @@ class Bottleneck(nn.Module): self.add = shortcut and c1 == c2 def forward(self, x): - """'forward()' applies the YOLOv5 FPN to input data.""" + """'forward()' applies the YOLO FPN to input data.""" return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) class BottleneckCSP(nn.Module): """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks.""" - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes the CSP Bottleneck given arguments for ch_in, ch_out, number, shortcut, groups, expansion.""" super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, 1, 1) @@ -302,3 +362,466 @@ class BottleneckCSP(nn.Module): y1 = self.cv3(self.m(self.cv1(x))) y2 = self.cv2(x) return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) + + +class ResNetBlock(nn.Module): + """ResNet block with standard convolution layers.""" + + def __init__(self, c1, c2, s=1, e=4): + """Initialize convolution with given parameters.""" + super().__init__() + c3 = e * c2 + self.cv1 = Conv(c1, c2, k=1, s=1, act=True) + self.cv2 = Conv(c2, c2, k=3, s=s, p=1, act=True) + self.cv3 = Conv(c2, c3, k=1, act=False) + self.shortcut = nn.Sequential(Conv(c1, c3, k=1, s=s, act=False)) if s != 1 or c1 != c3 else nn.Identity() + + def forward(self, x): + """Forward pass through the ResNet block.""" + return F.relu(self.cv3(self.cv2(self.cv1(x))) + self.shortcut(x)) + + +class ResNetLayer(nn.Module): + """ResNet layer with multiple ResNet blocks.""" + + def __init__(self, c1, c2, s=1, is_first=False, n=1, e=4): + """Initializes the ResNetLayer given arguments.""" + super().__init__() + self.is_first = is_first + + if self.is_first: + self.layer = nn.Sequential( + Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ) + else: + blocks = [ResNetBlock(c1, c2, s, e=e)] + blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)]) + self.layer = nn.Sequential(*blocks) + + def forward(self, x): + """Forward pass through the ResNet layer.""" + return self.layer(x) + + +class MaxSigmoidAttnBlock(nn.Module): + """Max Sigmoid attention block.""" + + def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False): + """Initializes MaxSigmoidAttnBlock with specified arguments.""" + super().__init__() + self.nh = nh + self.hc = c2 // nh + self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None + self.gl = nn.Linear(gc, ec) + self.bias = nn.Parameter(torch.zeros(nh)) + self.proj_conv = Conv(c1, c2, k=3, s=1, act=False) + self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0 + + def forward(self, x, guide): + """Forward process.""" + bs, _, h, w = x.shape + + guide = self.gl(guide) + guide = guide.view(bs, -1, self.nh, self.hc) + embed = self.ec(x) if self.ec is not None else x + embed = embed.view(bs, self.nh, self.hc, h, w) + + aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide) + aw = aw.max(dim=-1)[0] + aw = aw / (self.hc**0.5) + aw = aw + self.bias[None, :, None, None] + aw = aw.sigmoid() * self.scale + + x = self.proj_conv(x) + x = x.view(bs, self.nh, -1, h, w) + x = x * aw.unsqueeze(2) + return x.view(bs, -1, h, w) + + +class C2fAttn(nn.Module): + """C2f module with an additional attn module.""" + + def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5): + 
"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups, + expansion. + """ + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv((3 + n) * self.c, c2, 1) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) + self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh) + + def forward(self, x, guide): + """Forward pass through C2f layer.""" + y = list(self.cv1(x).chunk(2, 1)) + y.extend(m(y[-1]) for m in self.m) + y.append(self.attn(y[-1], guide)) + return self.cv2(torch.cat(y, 1)) + + def forward_split(self, x, guide): + """Forward pass using split() instead of chunk().""" + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in self.m) + y.append(self.attn(y[-1], guide)) + return self.cv2(torch.cat(y, 1)) + + +class ImagePoolingAttn(nn.Module): + """ImagePoolingAttn: Enhance the text embeddings with image-aware information.""" + + def __init__(self, ec=256, ch=(), ct=512, nh=8, k=3, scale=False): + """Initializes ImagePoolingAttn with specified arguments.""" + super().__init__() + + nf = len(ch) + self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec)) + self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec)) + self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec)) + self.proj = nn.Linear(ec, ct) + self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0 + self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch]) + self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)]) + self.ec = ec + self.nh = nh + self.nf = nf + self.hc = ec // nh + self.k = k + + def forward(self, x, text): + """Executes attention mechanism on input tensor x and guide tensor.""" + bs = x[0].shape[0] + assert len(x) == self.nf + num_patches = self.k**2 + x = [pool(proj(x)).view(bs, -1, num_patches) for (x, proj, pool) in zip(x, self.projections, self.im_pools)] + x = torch.cat(x, dim=-1).transpose(1, 2) + q = self.query(text) + k = self.key(x) + v = self.value(x) + + # q = q.reshape(1, text.shape[1], self.nh, self.hc).repeat(bs, 1, 1, 1) + q = q.reshape(bs, -1, self.nh, self.hc) + k = k.reshape(bs, -1, self.nh, self.hc) + v = v.reshape(bs, -1, self.nh, self.hc) + + aw = torch.einsum("bnmc,bkmc->bmnk", q, k) + aw = aw / (self.hc**0.5) + aw = F.softmax(aw, dim=-1) + + x = torch.einsum("bmnk,bkmc->bnmc", aw, v) + x = self.proj(x.reshape(bs, -1, self.ec)) + return x * self.scale + text + + +class ContrastiveHead(nn.Module): + """Contrastive Head for YOLO-World compute the region-text scores according to the similarity between image and text + features. + """ + + def __init__(self): + """Initializes ContrastiveHead with specified region-text similarity parameters.""" + super().__init__() + self.bias = nn.Parameter(torch.zeros([])) + self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log()) + + def forward(self, x, w): + """Forward function of contrastive learning.""" + x = F.normalize(x, dim=1, p=2) + w = F.normalize(w, dim=-1, p=2) + x = torch.einsum("bchw,bkc->bkhw", x, w) + return x * self.logit_scale.exp() + self.bias + + +class BNContrastiveHead(nn.Module): + """ + Batch Norm Contrastive Head for YOLO-World using batch norm instead of l2-normalization. + + Args: + embed_dims (int): Embed dimensions of text and image features. 
+ """ + + def __init__(self, embed_dims: int): + """Initialize ContrastiveHead with region-text similarity parameters.""" + super().__init__() + self.norm = nn.BatchNorm2d(embed_dims) + self.bias = nn.Parameter(torch.zeros([])) + # use -1.0 is more stable + self.logit_scale = nn.Parameter(-1.0 * torch.ones([])) + + def forward(self, x, w): + """Forward function of contrastive learning.""" + x = self.norm(x) + w = F.normalize(w, dim=-1, p=2) + x = torch.einsum("bchw,bkc->bkhw", x, w) + return x * self.logit_scale.exp() + self.bias + + +class RepBottleneck(nn.Module): + """Rep bottleneck.""" + + def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): + """Initializes a RepBottleneck module with customizable in/out channels, shortcut option, groups and expansion + ratio. + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = RepConv(c1, c_, k[0], 1) + self.cv2 = Conv(c_, c2, k[1], 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + """Forward pass through RepBottleneck layer.""" + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class RepCSP(nn.Module): + """Rep CSP Bottleneck with 3 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio.""" + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2) + self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + """Forward pass through RepCSP layer.""" + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) + + +class RepNCSPELAN4(nn.Module): + """CSP-ELAN.""" + + def __init__(self, c1, c2, c3, c4, n=1): + """Initializes CSP-ELAN layer with specified channel sizes, repetitions, and convolutions.""" + super().__init__() + self.c = c3 // 2 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1)) + self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1)) + self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1) + + def forward(self, x): + """Forward pass through RepNCSPELAN4 layer.""" + y = list(self.cv1(x).chunk(2, 1)) + y.extend((m(y[-1])) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + def forward_split(self, x): + """Forward pass using split() instead of chunk().""" + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + +class ADown(nn.Module): + """ADown.""" + + def __init__(self, c1, c2): + """Initializes ADown module with convolution layers to downsample input from channels c1 to c2.""" + super().__init__() + self.c = c2 // 2 + self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1) + self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0) + + def forward(self, x): + """Forward pass through ADown layer.""" + x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True) + x1, x2 = x.chunk(2, 1) + x1 = self.cv1(x1) + x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1) + x2 = self.cv2(x2) + return torch.cat((x1, x2), 1) + + +class SPPELAN(nn.Module): + """SPP-ELAN.""" + + def __init__(self, c1, c2, c3, k=5): + """Initializes SPP-ELAN block with convolution and max pooling layers for spatial pyramid pooling.""" + super().__init__() + self.c = c3 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + 
self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv5 = Conv(4 * c3, c2, 1, 1) + + def forward(self, x): + """Forward pass through SPPELAN layer.""" + y = [self.cv1(x)] + y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4]) + return self.cv5(torch.cat(y, 1)) + + +class Silence(nn.Module): + """Silence.""" + + def __init__(self): + """Initializes the Silence module.""" + super(Silence, self).__init__() + + def forward(self, x): + """Forward pass through Silence layer.""" + return x + + +class CBLinear(nn.Module): + """CBLinear.""" + + def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): + """Initializes the CBLinear module, passing inputs unchanged.""" + super(CBLinear, self).__init__() + self.c2s = c2s + self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True) + + def forward(self, x): + """Forward pass through CBLinear layer.""" + outs = self.conv(x).split(self.c2s, dim=1) + return outs + + +class CBFuse(nn.Module): + """CBFuse.""" + + def __init__(self, idx): + """Initializes CBFuse module with layer index for selective feature fusion.""" + super(CBFuse, self).__init__() + self.idx = idx + + def forward(self, xs): + """Forward pass through CBFuse layer.""" + target_size = xs[-1].shape[2:] + res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])] + out = torch.sum(torch.stack(res + xs[-1:]), dim=0) + return out + + +class RepVGGDW(torch.nn.Module): + def __init__(self, ed) -> None: + super().__init__() + self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False) + self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False) + self.dim = ed + self.act = nn.SiLU() + + def forward(self, x): + return self.act(self.conv(x) + self.conv1(x)) + + def forward_fuse(self, x): + return self.act(self.conv(x)) + + @torch.no_grad() + def fuse(self): + conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn) + conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn) + + conv_w = conv.weight + conv_b = conv.bias + conv1_w = conv1.weight + conv1_b = conv1.bias + + conv1_w = torch.nn.functional.pad(conv1_w, [2,2,2,2]) + + final_conv_w = conv_w + conv1_w + final_conv_b = conv_b + conv1_b + + conv.weight.data.copy_(final_conv_w) + conv.bias.data.copy_(final_conv_b) + + self.conv = conv + del self.conv1 + +class CIB(nn.Module): + """Standard bottleneck.""" + + def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False): + """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and + expansion. + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = nn.Sequential( + Conv(c1, c1, 3, g=c1), + Conv(c1, 2 * c_, 1), + Conv(2 * c_, 2 * c_, 3, g=2 * c_) if not lk else RepVGGDW(2 * c_), + Conv(2 * c_, c2, 1), + Conv(c2, c2, 3, g=c2), + ) + + self.add = shortcut and c1 == c2 + + def forward(self, x): + """'forward()' applies the YOLO FPN to input data.""" + return x + self.cv1(x) if self.add else self.cv1(x) + +class C2fCIB(C2f): + """Faster Implementation of CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5): + """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups, + expansion. 
+ """ + super().__init__(c1, c2, n, shortcut, g, e) + self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n)) + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, + attn_ratio=0.5): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.key_dim = int(self.head_dim * attn_ratio) + self.scale = self.key_dim ** -0.5 + nh_kd = nh_kd = self.key_dim * num_heads + h = dim + nh_kd * 2 + self.qkv = Conv(dim, h, 1, act=False) + self.proj = Conv(dim, dim, 1, act=False) + self.pe = Conv(dim, dim, 3, 1, g=dim, act=False) + + def forward(self, x): + B, C, H, W = x.shape + N = H * W + qkv = self.qkv(x) + q, k, v = qkv.view(B, self.num_heads, self.key_dim*2 + self.head_dim, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2) + + attn = ( + (q.transpose(-2, -1) @ k) * self.scale + ) + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W)) + x = self.proj(x) + return x + +class PSA(nn.Module): + + def __init__(self, c1, c2, e=0.5): + super().__init__() + assert(c1 == c2) + self.c = int(c1 * e) + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c1, 1) + + self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64) + self.ffn = nn.Sequential( + Conv(self.c, self.c*2, 1), + Conv(self.c*2, self.c, 1, act=False) + ) + + def forward(self, x): + a, b = self.cv1(x).split((self.c, self.c), dim=1) + b = b + self.attn(b) + b = b + self.ffn(b) + return self.cv2(torch.cat((a, b), 1)) + +class SCDown(nn.Module): + def __init__(self, c1, c2, k, s): + super().__init__() + self.cv1 = Conv(c1, c2, 1, 1) + self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False) + + def forward(self, x): + return self.cv2(self.cv1(x)) \ No newline at end of file diff --git a/ultralytics/nn/modules/conv.py b/ultralytics/nn/modules/conv.py index 77e99c0..399c422 100644 --- a/ultralytics/nn/modules/conv.py +++ b/ultralytics/nn/modules/conv.py @@ -1,7 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Convolution modules -""" +"""Convolution modules.""" import math @@ -9,8 +7,21 @@ import numpy as np import torch import torch.nn as nn -__all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv', - 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv') +__all__ = ( + "Conv", + "Conv2", + "LightConv", + "DWConv", + "DWConvTranspose2d", + "ConvTranspose", + "Focus", + "GhostConv", + "ChannelAttention", + "SpatialAttention", + "CBAM", + "Concat", + "RepConv", +) def autopad(k, p=None, d=1): # kernel, padding, dilation @@ -24,6 +35,7 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation class Conv(nn.Module): """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation).""" + default_act = nn.SiLU() # default activation def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): @@ -62,14 +74,16 @@ class Conv2(Conv): """Fuse parallel convolutions.""" w = torch.zeros_like(self.conv.weight.data) i = [x // 2 for x in w.shape[2:]] - w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone() + w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone() self.conv.weight.data += w - self.__delattr__('cv2') + self.__delattr__("cv2") self.forward = self.forward_fuse class LightConv(nn.Module): - """Light convolution with args(ch_in, ch_out, kernel). + """ + Light convolution with args(ch_in, ch_out, kernel). 
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py """ @@ -88,6 +102,7 @@ class DWConv(Conv): """Depth-wise convolution.""" def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation + """Initialize Depth-wise convolution with given parameters.""" super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) @@ -95,11 +110,13 @@ class DWConvTranspose2d(nn.ConvTranspose2d): """Depth-wise transpose convolution.""" def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out + """Initialize DWConvTranspose2d class with given parameters.""" super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2)) class ConvTranspose(nn.Module): """Convolution transpose 2d layer.""" + default_act = nn.SiLU() # default activation def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True): @@ -121,12 +138,18 @@ class ConvTranspose(nn.Module): class Focus(nn.Module): """Focus wh information into c-space.""" - def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): + """Initializes Focus object with user defined channel, convolution, padding, group and activation values.""" super().__init__() self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act) # self.contract = Contract(gain=2) - def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + def forward(self, x): + """ + Applies convolution to concatenated tensor and returns the output. + + Input shape is (b,c,w,h) and output shape is (b,4c,w/2,h/2). + """ return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1)) # return self.conv(self.contract(x)) @@ -134,7 +157,10 @@ class Focus(nn.Module): class GhostConv(nn.Module): """Ghost Convolution https://github.com/huawei-noah/ghostnet.""" - def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups + def __init__(self, c1, c2, k=1, s=1, g=1, act=True): + """Initializes the GhostConv object with input channels, output channels, kernel size, stride, groups and + activation. + """ super().__init__() c_ = c2 // 2 # hidden channels self.cv1 = Conv(c1, c_, k, s, None, g, act=act) @@ -148,12 +174,16 @@ class GhostConv(nn.Module): class RepConv(nn.Module): """ - RepConv is a basic rep-style block, including training and deploy status. This module is used in RT-DETR. + RepConv is a basic rep-style block, including training and deploy status. + + This module is used in RT-DETR. 
Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py """ + default_act = nn.SiLU() # default activation def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False): + """Initializes Light Convolution layer with inputs, outputs & optional activation function.""" super().__init__() assert k == 3 and p == 1 self.g = g @@ -166,27 +196,30 @@ class RepConv(nn.Module): self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False) def forward_fuse(self, x): - """Forward process""" + """Forward process.""" return self.act(self.conv(x)) def forward(self, x): - """Forward process""" + """Forward process.""" id_out = 0 if self.bn is None else self.bn(x) return self.act(self.conv1(x) + self.conv2(x) + id_out) def get_equivalent_kernel_bias(self): + """Returns equivalent kernel and bias by adding 3x3 kernel, 1x1 kernel and identity kernel with their biases.""" kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) kernelid, biasid = self._fuse_bn_tensor(self.bn) return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid def _pad_1x1_to_3x3_tensor(self, kernel1x1): + """Pads a 1x1 tensor to a 3x3 tensor.""" if kernel1x1 is None: return 0 else: return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) def _fuse_bn_tensor(self, branch): + """Generates appropriate kernels and biases for convolution by fusing branches of the neural network.""" if branch is None: return 0, 0 if isinstance(branch, Conv): @@ -197,7 +230,7 @@ class RepConv(nn.Module): beta = branch.bn.bias eps = branch.bn.eps elif isinstance(branch, nn.BatchNorm2d): - if not hasattr(self, 'id_tensor'): + if not hasattr(self, "id_tensor"): input_dim = self.c1 // self.g kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32) for i in range(self.c1): @@ -214,41 +247,46 @@ class RepConv(nn.Module): return kernel * t, beta - running_mean * gamma / std def fuse_convs(self): - if hasattr(self, 'conv'): + """Combines two convolution layers into a single layer and removes unused attributes from the class.""" + if hasattr(self, "conv"): return kernel, bias = self.get_equivalent_kernel_bias() - self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels, - out_channels=self.conv1.conv.out_channels, - kernel_size=self.conv1.conv.kernel_size, - stride=self.conv1.conv.stride, - padding=self.conv1.conv.padding, - dilation=self.conv1.conv.dilation, - groups=self.conv1.conv.groups, - bias=True).requires_grad_(False) + self.conv = nn.Conv2d( + in_channels=self.conv1.conv.in_channels, + out_channels=self.conv1.conv.out_channels, + kernel_size=self.conv1.conv.kernel_size, + stride=self.conv1.conv.stride, + padding=self.conv1.conv.padding, + dilation=self.conv1.conv.dilation, + groups=self.conv1.conv.groups, + bias=True, + ).requires_grad_(False) self.conv.weight.data = kernel self.conv.bias.data = bias for para in self.parameters(): para.detach_() - self.__delattr__('conv1') - self.__delattr__('conv2') - if hasattr(self, 'nm'): - self.__delattr__('nm') - if hasattr(self, 'bn'): - self.__delattr__('bn') - if hasattr(self, 'id_tensor'): - self.__delattr__('id_tensor') + self.__delattr__("conv1") + self.__delattr__("conv2") + if hasattr(self, "nm"): + self.__delattr__("nm") + if hasattr(self, "bn"): + self.__delattr__("bn") + if hasattr(self, "id_tensor"): + self.__delattr__("id_tensor") class ChannelAttention(nn.Module): """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet.""" 
def __init__(self, channels: int) -> None: + """Initializes the class and sets the basic configurations and instance variables required.""" super().__init__() self.pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) self.act = nn.Sigmoid() def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies forward pass using activation on convolutions of the input, optionally using batch normalization.""" return x * self.act(self.fc(self.pool(x))) @@ -258,7 +296,7 @@ class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): """Initialize Spatial-attention module with kernel size argument.""" super().__init__() - assert kernel_size in (3, 7), 'kernel size must be 3 or 7' + assert kernel_size in (3, 7), "kernel size must be 3 or 7" padding = 3 if kernel_size == 7 else 1 self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.act = nn.Sigmoid() @@ -271,7 +309,8 @@ class SpatialAttention(nn.Module): class CBAM(nn.Module): """Convolutional Block Attention Module.""" - def __init__(self, c1, kernel_size=7): # ch_in, kernels + def __init__(self, c1, kernel_size=7): + """Initialize CBAM with given input channel (c1) and kernel size.""" super().__init__() self.channel_attention = ChannelAttention(c1) self.spatial_attention = SpatialAttention(kernel_size) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 0b02eb3..a9c5d9e 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -1,7 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Model head modules -""" +"""Model head modules.""" import math @@ -9,25 +7,28 @@ import torch import torch.nn as nn from torch.nn.init import constant_, xavier_uniform_ -from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors - -from .block import DFL, Proto +from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors +from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead from .conv import Conv from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer -from .utils import bias_init_with_prob, linear_init_ +from .utils import bias_init_with_prob, linear_init +import copy +from ultralytics.utils import ops -__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder' +__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder" class Detect(nn.Module): """YOLOv8 Detect head for detection models.""" + dynamic = False # force grid reconstruction export = False # export mode shape = None anchors = torch.empty(0) # init strides = torch.empty(0) # init - def __init__(self, nc=80, ch=()): # detection layer + def __init__(self, nc=80, ch=()): + """Initializes the YOLOv8 detection layer with specified number of classes and channels.""" super().__init__() self.nc = nc # number of classes self.nl = len(ch) # number of detection layers @@ -36,41 +37,54 @@ class Detect(nn.Module): self.stride = torch.zeros(self.nl) # strides computed during build c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels self.cv2 = nn.ModuleList( - nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch) + nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch + ) self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() - def forward(self, x): - 
"""Concatenates and returns predicted bounding boxes and class probabilities.""" + def inference(self, x): + # Inference path shape = x[0].shape # BCHW - for i in range(self.nl): - x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) - if self.training: - return x - elif self.dynamic or self.shape != shape: + x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) + if self.dynamic or self.shape != shape: self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.shape = shape - x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) - if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops - box = x_cat[:, :self.reg_max * 4] - cls = x_cat[:, self.reg_max * 4:] + if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops + box = x_cat[:, : self.reg_max * 4] + cls = x_cat[:, self.reg_max * 4 :] else: box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) - dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides - if self.export and self.format in ('tflite', 'edgetpu'): - # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5: - # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309 - # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695 - img_h = shape[2] * self.stride[0] - img_w = shape[3] * self.stride[0] - img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1) - dbox /= img_size + if self.export and self.format in ("tflite", "edgetpu"): + # Precompute normalization factor to increase numerical stability + # See https://github.com/ultralytics/ultralytics/issues/7371 + grid_h = shape[2] + grid_w = shape[3] + grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * grid_size) + dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) + else: + dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides y = torch.cat((dbox, cls.sigmoid()), 1) return y if self.export else (y, x) + def forward_feat(self, x, cv2, cv3): + y = [] + for i in range(self.nl): + y.append(torch.cat((cv2[i](x[i]), cv3[i](x[i])), 1)) + return y + + def forward(self, x): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + y = self.forward_feat(x, self.cv2, self.cv3) + + if self.training: + return y + + return self.inference(y) + def bias_init(self): """Initialize Detect() biases, WARNING: requires stride availability.""" m = self # self.model[-1] # Detect() module @@ -78,7 +92,13 @@ class Detect(nn.Module): # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency for a, b, s in zip(m.cv2, m.cv3, m.stride): # from a[-1].bias.data[:] = 1.0 # box - b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + + def decode_bboxes(self, bboxes, anchors): + """Decode bounding boxes.""" + if self.export: + return dist2bbox(bboxes, anchors, xywh=False, dim=1) + return dist2bbox(bboxes, anchors, xywh=True, dim=1) class Segment(Detect): @@ -107,6 +127,37 @@ class Segment(Detect): return 
(torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p)) +class OBB(Detect): + """YOLOv8 OBB detection head for detection with rotation models.""" + + def __init__(self, nc=80, ne=1, ch=()): + """Initialize OBB with number of classes `nc` and layer channels `ch`.""" + super().__init__(nc, ch) + self.ne = ne # number of extra parameters + self.detect = Detect.forward + + c4 = max(ch[0] // 4, self.ne) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch) + + def forward(self, x): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + bs = x[0].shape[0] # batch size + angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits + # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it. + angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4] + # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2] + if not self.training: + self.angle = angle + x = self.detect(self, x) + if self.training: + return x, angle + return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle)) + + def decode_bboxes(self, bboxes, anchors): + """Decode rotated bounding boxes.""" + return dist2rbox(bboxes, self.angle, anchors, dim=1) + + class Pose(Detect): """YOLOv8 Pose head for keypoints models.""" @@ -142,7 +193,7 @@ class Pose(Detect): else: y = kpts.clone() if ndim == 3: - y[:, 2::3].sigmoid_() # inplace sigmoid + y[:, 2::3] = y[:, 2::3].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug) y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides return y @@ -151,7 +202,10 @@ class Pose(Detect): class Classify(nn.Module): """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2).""" - def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): + """Initializes YOLOv8 classification head with specified input and output channels, kernel size, stride, + padding, and groups. 
+ """ super().__init__() c_ = 1280 # efficientnet_b0 size self.conv = Conv(c1, c_, k, s, p, g) @@ -167,27 +221,99 @@ class Classify(nn.Module): return x if self.training else x.softmax(1) +class WorldDetect(Detect): + def __init__(self, nc=80, embed=512, with_bn=False, ch=()): + """Initialize YOLOv8 detection layer with nc classes and layer channels ch.""" + super().__init__(nc, ch) + c3 = max(ch[0], min(self.nc, 100)) + self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch) + self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch) + + def forward(self, x, text): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1) + if self.training: + return x + + # Inference path + shape = x[0].shape # BCHW + x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2) + if self.dynamic or self.shape != shape: + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape + + if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops + box = x_cat[:, : self.reg_max * 4] + cls = x_cat[:, self.reg_max * 4 :] + else: + box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) + + if self.export and self.format in ("tflite", "edgetpu"): + # Precompute normalization factor to increase numerical stability + # See https://github.com/ultralytics/ultralytics/issues/7371 + grid_h = shape[2] + grid_w = shape[3] + grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * grid_size) + dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) + else: + dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides + + y = torch.cat((dbox, cls.sigmoid()), 1) + return y if self.export else (y, x) + + class RTDETRDecoder(nn.Module): + """ + Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection. + + This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes + and class labels for objects in an image. It integrates features from multiple layers and runs through a series of + Transformer decoder layers to output the final predictions. + """ + export = False # export mode def __init__( - self, - nc=80, - ch=(512, 1024, 2048), - hd=256, # hidden dim - nq=300, # num queries - ndp=4, # num decoder points - nh=8, # num head - ndl=6, # num decoder layers - d_ffn=1024, # dim of feedforward - dropout=0., - act=nn.ReLU(), - eval_idx=-1, - # training args - nd=100, # num denoising - label_noise_ratio=0.5, - box_noise_scale=1.0, - learnt_init_query=False): + self, + nc=80, + ch=(512, 1024, 2048), + hd=256, # hidden dim + nq=300, # num queries + ndp=4, # num decoder points + nh=8, # num head + ndl=6, # num decoder layers + d_ffn=1024, # dim of feedforward + dropout=0.0, + act=nn.ReLU(), + eval_idx=-1, + # Training args + nd=100, # num denoising + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=False, + ): + """ + Initializes the RTDETRDecoder module with the given parameters. + + Args: + nc (int): Number of classes. Default is 80. + ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048). 
+ hd (int): Dimension of hidden layers. Default is 256. + nq (int): Number of query points. Default is 300. + ndp (int): Number of decoder points. Default is 4. + nh (int): Number of heads in multi-head attention. Default is 8. + ndl (int): Number of decoder layers. Default is 6. + d_ffn (int): Dimension of the feed-forward networks. Default is 1024. + dropout (float): Dropout rate. Default is 0. + act (nn.Module): Activation function. Default is nn.ReLU. + eval_idx (int): Evaluation index. Default is -1. + nd (int): Number of denoising. Default is 100. + label_noise_ratio (float): Label noise ratio. Default is 0.5. + box_noise_scale (float): Box noise scale. Default is 1.0. + learnt_init_query (bool): Whether to learn initial query embeddings. Default is False. + """ super().__init__() self.hidden_dim = hd self.nhead = nh @@ -196,7 +322,7 @@ class RTDETRDecoder(nn.Module): self.num_queries = nq self.num_decoder_layers = ndl - # backbone feature projection + # Backbone feature projection self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch) # NOTE: simplified version but it's not consistent with .pt weights. # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch) @@ -205,58 +331,61 @@ class RTDETRDecoder(nn.Module): decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp) self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx) - # denoising part + # Denoising part self.denoising_class_embed = nn.Embedding(nc, hd) self.num_denoising = nd self.label_noise_ratio = label_noise_ratio self.box_noise_scale = box_noise_scale - # decoder embedding + # Decoder embedding self.learnt_init_query = learnt_init_query if learnt_init_query: self.tgt_embed = nn.Embedding(nq, hd) self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2) - # encoder head + # Encoder head self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd)) self.enc_score_head = nn.Linear(hd, nc) self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3) - # decoder head + # Decoder head self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)]) self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)]) self._reset_parameters() def forward(self, x, batch=None): + """Runs the forward pass of the module, returning bounding box and classification scores for the input.""" from ultralytics.models.utils.ops import get_cdn_group - # input projection and embedding + # Input projection and embedding feats, shapes = self._get_encoder_input(x) - # prepare denoising training - dn_embed, dn_bbox, attn_mask, dn_meta = \ - get_cdn_group(batch, - self.nc, - self.num_queries, - self.denoising_class_embed.weight, - self.num_denoising, - self.label_noise_ratio, - self.box_noise_scale, - self.training) + # Prepare denoising training + dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group( + batch, + self.nc, + self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale, + self.training, + ) - embed, refer_bbox, enc_bboxes, enc_scores = \ - self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) + embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) - # decoder - dec_bboxes, dec_scores = self.decoder(embed, - refer_bbox, - feats, - shapes, - self.dec_bbox_head, - self.dec_score_head, - self.query_pos_head, - attn_mask=attn_mask) + # Decoder + dec_bboxes, dec_scores = 
self.decoder( + embed, + refer_bbox, + feats, + shapes, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_mask, + ) x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta if self.training: return x @@ -264,29 +393,31 @@ class RTDETRDecoder(nn.Module): y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1) return y if self.export else (y, x) - def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2): + def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2): + """Generates anchor bounding boxes for given shapes with specific grid size and validates them.""" anchors = [] for i, (h, w) in enumerate(shapes): sy = torch.arange(end=h, dtype=dtype, device=device) sx = torch.arange(end=w, dtype=dtype, device=device) - grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx) + grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2) - valid_WH = torch.tensor([h, w], dtype=dtype, device=device) + valid_WH = torch.tensor([w, h], dtype=dtype, device=device) grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2) - wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i) + wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i) anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4) anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) - valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 + valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 anchors = torch.log(anchors / (1 - anchors)) - anchors = anchors.masked_fill(~valid_mask, float('inf')) + anchors = anchors.masked_fill(~valid_mask, float("inf")) return anchors, valid_mask def _get_encoder_input(self, x): - # get projection features + """Processes and returns encoder inputs by getting projection features from input and concatenating them.""" + # Get projection features x = [self.input_proj[i](feat) for i, feat in enumerate(x)] - # get encoder inputs + # Get encoder inputs feats = [] shapes = [] for feat in x: @@ -301,14 +432,15 @@ class RTDETRDecoder(nn.Module): return feats, shapes def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None): - bs = len(feats) - # prepare input for decoder + """Generates and prepares the input required for the decoder from the provided features and shapes.""" + bs = feats.shape[0] + # Prepare input for decoder anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device) features = self.enc_output(valid_mask * feats) # bs, h*w, 256 enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) - # query selection + # Query selection # (bs, num_queries) topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1) # (bs, num_queries) @@ -319,7 +451,7 @@ class RTDETRDecoder(nn.Module): # (bs, num_queries, 4) top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1) - # dynamic anchors + static content + # Dynamic anchors + static content refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors enc_bboxes = refer_bbox.sigmoid() @@ -339,20 +471,21 @@ class RTDETRDecoder(nn.Module): # TODO def _reset_parameters(self): - # class and bbox head init + """Initializes or resets the parameters of the 
model's various components with predefined weights and biases.""" + # Class and bbox head init bias_cls = bias_init_with_prob(0.01) / 80 * self.nc - # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets. - # linear_init_(self.enc_score_head) + # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets. + # linear_init(self.enc_score_head) constant_(self.enc_score_head.bias, bias_cls) - constant_(self.enc_bbox_head.layers[-1].weight, 0.) - constant_(self.enc_bbox_head.layers[-1].bias, 0.) + constant_(self.enc_bbox_head.layers[-1].weight, 0.0) + constant_(self.enc_bbox_head.layers[-1].bias, 0.0) for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): - # linear_init_(cls_) + # linear_init(cls_) constant_(cls_.bias, bias_cls) - constant_(reg_.layers[-1].weight, 0.) - constant_(reg_.layers[-1].bias, 0.) + constant_(reg_.layers[-1].weight, 0.0) + constant_(reg_.layers[-1].bias, 0.0) - linear_init_(self.enc_output[0]) + linear_init(self.enc_output[0]) xavier_uniform_(self.enc_output[0].weight) if self.learnt_init_query: xavier_uniform_(self.tgt_embed.weight) @@ -360,3 +493,43 @@ class RTDETRDecoder(nn.Module): xavier_uniform_(self.query_pos_head.layers[1].weight) for layer in self.input_proj: xavier_uniform_(layer[0].weight) + +class v10Detect(Detect): + + max_det = 300 + + def __init__(self, nc=80, ch=()): + super().__init__(nc, ch) + c3 = max(ch[0], min(self.nc, 100)) # channels + self.cv3 = nn.ModuleList(nn.Sequential(nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)), \ + nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)), \ + nn.Conv2d(c3, self.nc, 1)) for i, x in enumerate(ch)) + + self.one2one_cv2 = copy.deepcopy(self.cv2) + self.one2one_cv3 = copy.deepcopy(self.cv3) + + def forward(self, x): + one2one = self.forward_feat([xi.detach() for xi in x], self.one2one_cv2, self.one2one_cv3) + if not self.export: + one2many = super().forward(x) + + if not self.training: + one2one = self.inference(one2one) + if not self.export: + return {"one2many": one2many, "one2one": one2one} + else: + assert(self.max_det != -1) + boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc) + return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1) + else: + return {"one2many": one2many, "one2one": one2one} + + def bias_init(self): + super().bias_init() + """Initialize Detect() biases, WARNING: requires stride availability.""" + m = self # self.model[-1] # Detect() module + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1 + # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency + for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py index 9a51d2c..062c609 100644 --- a/ultralytics/nn/modules/transformer.py +++ b/ultralytics/nn/modules/transformer.py @@ -1,7 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Transformer modules -""" +"""Transformer modules.""" import math @@ -13,19 +11,32 @@ from torch.nn.init import constant_, xavier_uniform_ from .conv import Conv from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch -__all__ = ('TransformerEncoderLayer', 
'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI', - 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP') +__all__ = ( + "TransformerEncoderLayer", + "TransformerLayer", + "TransformerBlock", + "MLPBlock", + "LayerNorm2d", + "AIFI", + "DeformableTransformerDecoder", + "DeformableTransformerDecoderLayer", + "MSDeformAttn", + "MLP", +) class TransformerEncoderLayer(nn.Module): - """Transformer Encoder.""" + """Defines a single layer of the transformer encoder.""" def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False): + """Initialize the TransformerEncoderLayer with specified parameters.""" super().__init__() from ...utils.torch_utils import TORCH_1_9 + if not TORCH_1_9: raise ModuleNotFoundError( - 'TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).') + "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)." + ) self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True) # Implementation of Feedforward model self.fc1 = nn.Linear(c1, cm) @@ -40,11 +51,13 @@ class TransformerEncoderLayer(nn.Module): self.act = act self.normalize_before = normalize_before - def with_pos_embed(self, tensor, pos=None): - """Add position embeddings if given.""" + @staticmethod + def with_pos_embed(tensor, pos=None): + """Add position embeddings to the tensor if provided.""" return tensor if pos is None else tensor + pos def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None): + """Performs forward pass with post-normalization.""" q = k = self.with_pos_embed(src, pos) src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] src = src + self.dropout1(src2) @@ -54,6 +67,7 @@ class TransformerEncoderLayer(nn.Module): return self.norm2(src) def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None): + """Performs forward pass with pre-normalization.""" src2 = self.norm1(src) q = k = self.with_pos_embed(src2, pos) src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] @@ -70,27 +84,30 @@ class TransformerEncoderLayer(nn.Module): class AIFI(TransformerEncoderLayer): + """Defines the AIFI transformer layer.""" def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False): + """Initialize the AIFI instance with specified parameters.""" super().__init__(c1, cm, num_heads, dropout, act, normalize_before) def forward(self, x): + """Forward pass for the AIFI transformer layer.""" c, h, w = x.shape[1:] pos_embed = self.build_2d_sincos_position_embedding(w, h, c) - # flatten [B, C, H, W] to [B, HxW, C] + # Flatten [B, C, H, W] to [B, HxW, C] x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype)) return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous() @staticmethod - def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.): - grid_w = torch.arange(int(w), dtype=torch.float32) - grid_h = torch.arange(int(h), dtype=torch.float32) - grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') - assert embed_dim % 4 == 0, \ - 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0): + """Builds 2D sine-cosine position embedding.""" + assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 
2D sin-cos position embedding" + grid_w = torch.arange(w, dtype=torch.float32) + grid_h = torch.arange(h, dtype=torch.float32) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim - omega = 1. / (temperature ** omega) + omega = 1.0 / (temperature**omega) out_w = grid_w.flatten()[..., None] @ omega[None] out_h = grid_h.flatten()[..., None] @ omega[None] @@ -140,27 +157,32 @@ class TransformerBlock(nn.Module): class MLPBlock(nn.Module): + """Implements a single block of a multi-layer perceptron.""" def __init__(self, embedding_dim, mlp_dim, act=nn.GELU): + """Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.""" super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = act() def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for the MLPBlock.""" return self.lin2(self.act(self.lin1(x))) class MLP(nn.Module): - """ Very simple multi-layer perceptron (also called FFN)""" + """Implements a simple multi-layer perceptron (also called FFN).""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + """Initialize the MLP with specified input, hidden, output dimensions and number of layers.""" super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): + """Forward pass for the entire MLP.""" for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x @@ -168,17 +190,23 @@ class MLP(nn.Module): class LayerNorm2d(nn.Module): """ - LayerNorm2d module from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py - https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 + 2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations. + + Original implementations in + https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py + and + https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py. """ def __init__(self, num_channels, eps=1e-6): + """Initialize LayerNorm2d with the given parameters.""" super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x): + """Perform forward pass for 2D layer normalization.""" u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) @@ -187,17 +215,19 @@ class LayerNorm2d(nn.Module): class MSDeformAttn(nn.Module): """ - Original Multi-Scale Deformable Attention Module. + Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations. 
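For reference, a minimal standalone sketch of the 2D sine-cosine position embedding that AIFI builds above. The helper name sincos_pos_embed_2d is illustrative and the sin/cos concatenation order follows the usual per-axis layout, so treat this as an approximation of the patched method rather than a drop-in copy.

import torch

def sincos_pos_embed_2d(w, h, embed_dim=256, temperature=10000.0):
    # embed_dim splits into 4 groups: sin/cos over x and sin/cos over y.
    assert embed_dim % 4 == 0, "embed_dim must be divisible by 4"
    grid_w, grid_h = torch.meshgrid(
        torch.arange(w, dtype=torch.float32), torch.arange(h, dtype=torch.float32), indexing="ij"
    )
    pos_dim = embed_dim // 4
    omega = 1.0 / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = grid_w.flatten()[..., None] @ omega[None]  # (h*w, pos_dim)
    out_h = grid_h.flatten()[..., None] @ omega[None]
    return torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None]  # (1, h*w, embed_dim)

print(sincos_pos_embed_2d(20, 20, 256).shape)  # torch.Size([1, 400, 256])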
+ https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py """ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """Initialize MSDeformAttn with the given parameters.""" super().__init__() if d_model % n_heads != 0: - raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}') + raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}") _d_per_head = d_model // n_heads - # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation - assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`' + # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation + assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`" self.im2col_step = 64 @@ -214,25 +244,32 @@ class MSDeformAttn(nn.Module): self._reset_parameters() def _reset_parameters(self): - constant_(self.sampling_offsets.weight.data, 0.) + """Reset module parameters.""" + constant_(self.sampling_offsets.weight.data, 0.0) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat( - 1, self.n_levels, self.n_points, 1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - constant_(self.attention_weights.weight.data, 0.) - constant_(self.attention_weights.bias.data, 0.) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) xavier_uniform_(self.value_proj.weight.data) - constant_(self.value_proj.bias.data, 0.) + constant_(self.value_proj.bias.data, 0.0) xavier_uniform_(self.output_proj.weight.data) - constant_(self.output_proj.bias.data, 0.) + constant_(self.output_proj.bias.data, 0.0) def forward(self, query, refer_bbox, value, value_shapes, value_mask=None): """ + Perform forward pass for multiscale deformable attention. + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py + Args: query (torch.Tensor): [bs, query_length, C] refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), @@ -265,31 +302,34 @@ class MSDeformAttn(nn.Module): add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5 sampling_locations = refer_bbox[:, :, None, :, None, :2] + add else: - raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.') + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.") output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights) return self.output_proj(output) class DeformableTransformerDecoderLayer(nn.Module): """ + Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations. 
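The 4-dim branch of MSDeformAttn.forward above scales each learned offset by half the reference box size before adding it to the box centre. A small worked example of that formula (all values made up):

import torch

# One query, one head, one level, one sampling point.
refer_bbox = torch.tensor([0.5, 0.5, 0.2, 0.4])  # cx, cy, w, h, normalized to [0, 1]
offset = torch.tensor([1.0, -1.0])               # learned sampling offset for this point
n_points = 4

# sampling_location = centre + offset / n_points * (w, h) * 0.5
loc = refer_bbox[:2] + offset / n_points * refer_bbox[2:] * 0.5
print(loc)  # tensor([0.5250, 0.4500])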
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py """ - def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4): + def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4): + """Initialize the DeformableTransformerDecoderLayer with the given parameters.""" super().__init__() - # self attention + # Self attention self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) self.dropout1 = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) - # cross attention + # Cross attention self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout2 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) - # ffn + # FFN self.linear1 = nn.Linear(d_model, d_ffn) self.act = act self.dropout3 = nn.Dropout(dropout) @@ -299,37 +339,46 @@ class DeformableTransformerDecoderLayer(nn.Module): @staticmethod def with_pos_embed(tensor, pos): + """Add positional embeddings to the input tensor, if provided.""" return tensor if pos is None else tensor + pos def forward_ffn(self, tgt): + """Perform forward pass through the Feed-Forward Network part of the layer.""" tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt)))) tgt = tgt + self.dropout4(tgt2) return self.norm3(tgt) def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None): - # self attention + """Perform the forward pass through the entire decoder layer.""" + + # Self attention q = k = self.with_pos_embed(embed, query_pos) - tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), - attn_mask=attn_mask)[0].transpose(0, 1) + tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[ + 0 + ].transpose(0, 1) embed = embed + self.dropout1(tgt) embed = self.norm1(embed) - # cross attention - tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, - padding_mask) + # Cross attention + tgt = self.cross_attn( + self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask + ) embed = embed + self.dropout2(tgt) embed = self.norm2(embed) - # ffn + # FFN return self.forward_ffn(embed) class DeformableTransformerDecoder(nn.Module): """ + Implementation of Deformable Transformer Decoder based on PaddleDetection. 
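The decoder below stacks num_layers deep copies of one layer via _get_clones and resolves a negative eval_idx relative to the stack depth. A self-contained sketch, with an nn.Linear standing in for the real DeformableTransformerDecoderLayer:

import copy
import torch.nn as nn

def get_clones(module, n):
    # Same idea as _get_clones in ultralytics/nn/modules/utils.py: n independent deep copies.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

num_layers, eval_idx = 6, -1
layers = get_clones(nn.Linear(256, 256), num_layers)
eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx  # -1 -> last layer index
print(len(layers), eval_idx)  # 6 5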
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py """ def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): + """Initialize the DeformableTransformerDecoder with the given parameters.""" super().__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers @@ -337,16 +386,18 @@ class DeformableTransformerDecoder(nn.Module): self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx def forward( - self, - embed, # decoder embeddings - refer_bbox, # anchor - feats, # image features - shapes, # feature shapes - bbox_head, - score_head, - pos_mlp, - attn_mask=None, - padding_mask=None): + self, + embed, # decoder embeddings + refer_bbox, # anchor + feats, # image features + shapes, # feature shapes + bbox_head, + score_head, + pos_mlp, + attn_mask=None, + padding_mask=None, + ): + """Perform the forward pass through the entire decoder.""" output = embed dec_bboxes = [] dec_cls = [] diff --git a/ultralytics/nn/modules/utils.py b/ultralytics/nn/modules/utils.py index f8636dc..1512967 100644 --- a/ultralytics/nn/modules/utils.py +++ b/ultralytics/nn/modules/utils.py @@ -1,7 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Module utils -""" +"""Module utils.""" import copy import math @@ -12,37 +10,44 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.init import uniform_ -__all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid' +__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid" def _get_clones(module, n): + """Create a list of cloned modules from the given module.""" return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) def bias_init_with_prob(prior_prob=0.01): - """initialize conv/fc bias value according to a given probability value.""" + """Initialize conv/fc bias value according to a given probability value.""" return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init -def linear_init_(module): +def linear_init(module): + """Initialize the weights and biases of a linear module.""" bound = 1 / math.sqrt(module.weight.shape[0]) uniform_(module.weight, -bound, bound) - if hasattr(module, 'bias') and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None: uniform_(module.bias, -bound, bound) def inverse_sigmoid(x, eps=1e-5): + """Calculate the inverse sigmoid function for a tensor.""" x = x.clamp(min=0, max=1) x1 = x.clamp(min=eps) x2 = (1 - x).clamp(min=eps) return torch.log(x1 / x2) -def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor, - sampling_locations: torch.Tensor, - attention_weights: torch.Tensor) -> torch.Tensor: +def multi_scale_deformable_attn_pytorch( + value: torch.Tensor, + value_spatial_shapes: torch.Tensor, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, +) -> torch.Tensor: """ - Multi-scale deformable attention. + Multiscale deformable attention. 
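Two of the small helpers touched in utils.py are easy to sanity-check numerically; the snippet below reproduces their math directly:

import math
import torch

# bias_init_with_prob: pick a bias so that sigmoid(bias) equals the prior probability.
prior_prob = 0.01
bias = -math.log((1 - prior_prob) / prior_prob)
print(round(bias, 3), round(1 / (1 + math.exp(-bias)), 3))  # -4.595 0.01

# inverse_sigmoid: numerically safe logit, clamping x away from 0 and 1.
x = torch.tensor([0.0, 0.5, 1.0]).clamp(0, 1)
eps = 1e-5
logit = torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
print(logit)  # approx. [-11.5129, 0.0, 11.5129]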
+ https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py """ @@ -56,23 +61,25 @@ def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shape # bs, H_*W_, num_heads*embed_dims -> # bs, num_heads*embed_dims, H_*W_ -> # bs*num_heads, embed_dims, H_, W_ - value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)) + value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_) # bs, num_queries, num_heads, num_points, 2 -> # bs, num_heads, num_queries, num_points, 2 -> # bs*num_heads, num_queries, num_points, 2 sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) # bs*num_heads, embed_dims, num_queries, num_points - sampling_value_l_ = F.grid_sample(value_l_, - sampling_grid_l_, - mode='bilinear', - padding_mode='zeros', - align_corners=False) + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) sampling_value_list.append(sampling_value_l_) # (bs, num_queries, num_heads, num_levels, num_points) -> # (bs, num_heads, num_queries, num_levels, num_points) -> # (bs, num_heads, 1, num_queries, num_levels*num_points) - attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, - num_levels * num_points) - output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view( - bs, num_heads * embed_dims, num_queries)) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(bs, num_heads * embed_dims, num_queries) + ) return output.transpose(1, 2).contiguous() diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 24153d2..268bd12 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -7,16 +7,68 @@ from pathlib import Path import torch import torch.nn as nn -from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, - Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d, - Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv, - RTDETRDecoder, Segment) +from ultralytics.nn.modules import ( + AIFI, + C1, + C2, + C3, + C3TR, + OBB, + SPP, + SPPF, + Bottleneck, + BottleneckCSP, + C2f, + C2fAttn, + ImagePoolingAttn, + C3Ghost, + C3x, + Classify, + Concat, + Conv, + Conv2, + ConvTranspose, + Detect, + DWConv, + DWConvTranspose2d, + Focus, + GhostBottleneck, + GhostConv, + HGBlock, + HGStem, + Pose, + RepC3, + RepConv, + ResNetLayer, + RTDETRDecoder, + Segment, + WorldDetect, + RepNCSPELAN4, + ADown, + SPPELAN, + CBFuse, + CBLinear, + Silence, + C2fCIB, + PSA, + SCDown, + RepVGGDW, + v10Detect +) from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml -from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss +from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss from ultralytics.utils.plotting import feature_visualization -from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts, - 
make_divisible, model_info, scale_img, time_sync) +from ultralytics.utils.torch_utils import ( + fuse_conv_and_bn, + fuse_deconv_and_bn, + initialize_weights, + intersect_dicts, + make_divisible, + model_info, + scale_img, + time_sync, +) try: import thop @@ -25,14 +77,11 @@ except ImportError: class BaseModel(nn.Module): - """ - The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family. - """ + """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.""" def forward(self, x, *args, **kwargs): """ - Forward pass of the model on a single scale. - Wrapper for `_forward_once` method. + Forward pass of the model on a single scale. Wrapper for `_forward_once` method. Args: x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels. @@ -44,7 +93,7 @@ class BaseModel(nn.Module): return self.loss(x, *args, **kwargs) return self.predict(x, *args, **kwargs) - def predict(self, x, profile=False, visualize=False, augment=False): + def predict(self, x, profile=False, visualize=False, augment=False, embed=None): """ Perform a forward pass through the network. @@ -53,15 +102,16 @@ class BaseModel(nn.Module): profile (bool): Print the computation time of each layer if True, defaults to False. visualize (bool): Save the feature maps of the model if True, defaults to False. augment (bool): Augment image during prediction, defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. Returns: (torch.Tensor): The last output of the model. """ if augment: return self._predict_augment(x) - return self._predict_once(x, profile, visualize) + return self._predict_once(x, profile, visualize, embed) - def _predict_once(self, x, profile=False, visualize=False): + def _predict_once(self, x, profile=False, visualize=False, embed=None): """ Perform a forward pass through the network. @@ -69,11 +119,12 @@ class BaseModel(nn.Module): x (torch.Tensor): The input tensor to the model. profile (bool): Print the computation time of each layer if True, defaults to False. visualize (bool): Save the feature maps of the model if True, defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. Returns: (torch.Tensor): The last output of the model. """ - y, dt = [], [] # outputs + y, dt, embeddings = [], [], [] # outputs for m in self.model: if m.f != -1: # if not from previous layer x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers @@ -83,18 +134,24 @@ class BaseModel(nn.Module): y.append(x if m.i in self.save else None) # save output if visualize: feature_visualization(x, m.type, m.i, save_dir=visualize) + if embed and m.i in embed: + embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten + if m.i == max(embed): + return torch.unbind(torch.cat(embeddings, 1), dim=0) return x def _predict_augment(self, x): """Perform augmentations on input image x and return augmented inference.""" - LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. ' - f'Reverting to single-scale inference instead.') + LOGGER.warning( + f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. " + f"Reverting to single-scale inference instead." + ) return self._predict_once(x) def _profile_one_layer(self, m, x, dt): """ - Profile the computation time and FLOPs of a single layer of the model on a given input. 
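A speculative usage sketch for the new embed argument threaded through predict/_predict_once: requesting layer 9 (the SPPF stage in the stock yolov8n layout) returns one pooled feature vector per image instead of detections. The checkpoint name and layer index are assumptions, not taken from this patch.

import torch
from ultralytics import YOLO

det_model = YOLO("yolov8n.pt").model.eval()  # underlying DetectionModel (a BaseModel subclass)
with torch.no_grad():
    # Stops the forward pass after layer 9 and returns its globally pooled activations.
    feats = det_model.predict(torch.zeros(1, 3, 640, 640), embed=[9])
print([f.shape for f in feats])  # one 1-D embedding per image in the batch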
- Appends the results to the provided list. + Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to + the provided list. Args: m (nn.Module): The layer to be profiled. @@ -105,14 +162,14 @@ class BaseModel(nn.Module): None """ c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix - flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs + flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs t = time_sync() for _ in range(10): m(x.copy() if c else x) dt.append((time_sync() - t) * 100) if m == self.model[0]: LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module") - LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}') + LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}") if c: LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total") @@ -126,19 +183,22 @@ class BaseModel(nn.Module): """ if not self.is_fused(): for m in self.model.modules(): - if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'): + if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"): if isinstance(m, Conv2): m.fuse_convs() m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv - delattr(m, 'bn') # remove batchnorm + delattr(m, "bn") # remove batchnorm m.forward = m.forward_fuse # update forward - if isinstance(m, ConvTranspose) and hasattr(m, 'bn'): + if isinstance(m, ConvTranspose) and hasattr(m, "bn"): m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn) - delattr(m, 'bn') # remove batchnorm + delattr(m, "bn") # remove batchnorm m.forward = m.forward_fuse # update forward if isinstance(m, RepConv): m.fuse_convs() m.forward = m.forward_fuse # update forward + if isinstance(m, RepVGGDW): + m.fuse() + m.forward = m.forward_fuse self.info(verbose=verbose) return self @@ -153,12 +213,12 @@ class BaseModel(nn.Module): Returns: (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise. """ - bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d() + bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d() return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model def info(self, detailed=False, verbose=True, imgsz=640): """ - Prints model information + Prints model information. Args: detailed (bool): if True, prints out detailed information about the model. Defaults to False @@ -175,11 +235,11 @@ class BaseModel(nn.Module): fn (function): the function to apply to the model Returns: - A model that is a Detect() object. + (BaseModel): An updated BaseModel object. """ self = super()._apply(fn) m = self.model[-1] # Detect() - if isinstance(m, (Detect, Segment)): + if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect m.stride = fn(m.stride) m.anchors = fn(m.anchors) m.strides = fn(m.strides) @@ -193,53 +253,57 @@ class BaseModel(nn.Module): weights (dict | torch.nn.Module): The pre-trained weights to be loaded. verbose (bool, optional): Whether to log the transfer progress. Defaults to True. 
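fuse() above folds each BatchNorm into its preceding convolution via fuse_conv_and_bn. A minimal sketch of the underlying algebra; fold_bn is an illustrative stand-in, not the Ultralytics helper:

import torch
import torch.nn as nn

def fold_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # BN(conv(x)) == fused(x) with weight scaled by gamma/sigma and bias shifted accordingly.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)      # per-channel gamma / sigma
    fused.weight.data = conv.weight * scale.reshape(-1, 1, 1, 1)
    b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (b - bn.running_mean) * scale + bn.bias
    return fused

conv, bn = nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8).eval()
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fold_bn(conv, bn)(x), atol=1e-5))  # True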
""" - model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts + model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts csd = model.float().state_dict() # checkpoint state_dict as FP32 csd = intersect_dicts(csd, self.state_dict()) # intersect self.load_state_dict(csd, strict=False) # load if verbose: - LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights') + LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights") def loss(self, batch, preds=None): """ - Compute loss + Compute loss. Args: batch (dict): Batch to compute loss on preds (torch.Tensor | List[torch.Tensor]): Predictions. """ - if not hasattr(self, 'criterion'): + if not hasattr(self, "criterion"): self.criterion = self.init_criterion() - preds = self.forward(batch['img']) if preds is None else preds + preds = self.forward(batch["img"]) if preds is None else preds return self.criterion(preds, batch) def init_criterion(self): - raise NotImplementedError('compute_loss() needs to be implemented by task heads') + """Initialize the loss criterion for the BaseModel.""" + raise NotImplementedError("compute_loss() needs to be implemented by task heads") class DetectionModel(BaseModel): """YOLOv8 detection model.""" - def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes + def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes + """Initialize the YOLOv8 detection model with the given config and parameters.""" super().__init__() self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict # Define model - ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels - if nc and nc != self.yaml['nc']: + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") - self.yaml['nc'] = nc # override YAML value + self.yaml["nc"] = nc # override YAML value self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist - self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict - self.inplace = self.yaml.get('inplace', True) + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict + self.inplace = self.yaml.get("inplace", True) # Build strides m = self.model[-1] # Detect() - if isinstance(m, (Detect, Segment, Pose)): + if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect s = 256 # 2x min stride m.inplace = self.inplace - forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x) + forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x) + if isinstance(m, v10Detect): + forward = lambda x: self.forward(x)["one2many"] m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward self.stride = m.stride m.bias_init() # only run once @@ -250,7 +314,7 @@ class DetectionModel(BaseModel): initialize_weights(self) if verbose: self.info() - LOGGER.info('') + LOGGER.info("") def _predict_augment(self, x): """Perform augmentations on input image x and return augmented inference and train outputs.""" @@ -260,7 +324,11 @@ class DetectionModel(BaseModel): y = [] # outputs for si, fi in zip(s, f): xi = 
scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) - yi = super().predict(xi)[0] # forward + yi = super().predict(xi) # forward + if isinstance(yi, dict): + yi = yi["one2one"] # yolov10 outputs + if isinstance(yi, (list, tuple)): + yi = yi[0] yi = self._descale_pred(yi, fi, si, img_size) y.append(yi) y = self._clip_augmented(y) # clip augmented tails @@ -278,51 +346,66 @@ class DetectionModel(BaseModel): return torch.cat((x, y, wh, cls), dim) def _clip_augmented(self, y): - """Clip YOLOv5 augmented inference tails.""" + """Clip YOLO augmented inference tails.""" nl = self.model[-1].nl # number of detection layers (P3-P5) - g = sum(4 ** x for x in range(nl)) # grid points + g = sum(4**x for x in range(nl)) # grid points e = 1 # exclude layer count - i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices + i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices y[0] = y[0][..., :-i] # large i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices y[-1] = y[-1][..., i:] # small return y def init_criterion(self): + """Initialize the loss criterion for the DetectionModel.""" return v8DetectionLoss(self) +class OBBModel(DetectionModel): + """YOLOv8 Oriented Bounding Box (OBB) model.""" + + def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True): + """Initialize YOLOv8 OBB model with given config and parameters.""" + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the model.""" + return v8OBBLoss(self) + + class SegmentationModel(DetectionModel): """YOLOv8 segmentation model.""" - def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True): + def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True): """Initialize YOLOv8 segmentation model with given config and parameters.""" super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) def init_criterion(self): + """Initialize the loss criterion for the SegmentationModel.""" return v8SegmentationLoss(self) class PoseModel(DetectionModel): """YOLOv8 pose model.""" - def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True): + def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True): """Initialize YOLOv8 Pose model.""" if not isinstance(cfg, dict): cfg = yaml_model_load(cfg) # load model YAML - if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']): + if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]): LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}") - cfg['kpt_shape'] = data_kpt_shape + cfg["kpt_shape"] = data_kpt_shape super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) def init_criterion(self): + """Initialize the loss criterion for the PoseModel.""" return v8PoseLoss(self) class ClassificationModel(BaseModel): """YOLOv8 classification model.""" - def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True): + def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True): """Init ClassificationModel with YAML, channels, number of classes, verbose flag.""" super().__init__() self._from_yaml(cfg, ch, nc, verbose) @@ -332,21 +415,21 @@ class ClassificationModel(BaseModel): self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict # Define model - ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels - if nc and nc != self.yaml['nc']: + ch = 
self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") - self.yaml['nc'] = nc # override YAML value - elif not nc and not self.yaml.get('nc', None): - raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.') + self.yaml["nc"] = nc # override YAML value + elif not nc and not self.yaml.get("nc", None): + raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.") self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist self.stride = torch.Tensor([1]) # no stride constraints - self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict self.info() @staticmethod def reshape_outputs(model, nc): """Update a TorchVision classification model to class count 'n' if required.""" - name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module + name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module if isinstance(m, Classify): # YOLO Classify() head if m.linear.out_features != nc: m.linear = nn.Linear(m.linear.in_features, nc) @@ -365,70 +448,109 @@ class ClassificationModel(BaseModel): m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None) def init_criterion(self): - """Compute the classification loss between predictions and true labels.""" + """Initialize the loss criterion for the ClassificationModel.""" return v8ClassificationLoss() class RTDETRDetectionModel(DetectionModel): + """ + RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class. - def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True): + This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both + the training and inference processes. RTDETR is an object detection and tracking model that extends from the + DetectionModel base class. + + Attributes: + cfg (str): The configuration file path or preset string. Default is 'rtdetr-l.yaml'. + ch (int): Number of input channels. Default is 3 (RGB). + nc (int, optional): Number of classes for object detection. Default is None. + verbose (bool): Specifies if summary statistics are shown during initialization. Default is True. + + Methods: + init_criterion: Initializes the criterion used for loss calculation. + loss: Computes and returns the loss during training. + predict: Performs a forward pass through the network and returns the output. + """ + + def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True): + """ + Initialize the RTDETRDetectionModel. + + Args: + cfg (str): Configuration file name or path. + ch (int): Number of input channels. + nc (int, optional): Number of classes. Defaults to None. + verbose (bool, optional): Print additional information during initialization. Defaults to True. 
+ """ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) def init_criterion(self): - """Compute the classification loss between predictions and true labels.""" + """Initialize the loss criterion for the RTDETRDetectionModel.""" from ultralytics.models.utils.loss import RTDETRDetectionLoss return RTDETRDetectionLoss(nc=self.nc, use_vfl=True) def loss(self, batch, preds=None): - if not hasattr(self, 'criterion'): + """ + Compute the loss for the given batch of data. + + Args: + batch (dict): Dictionary containing image and label data. + preds (torch.Tensor, optional): Precomputed model predictions. Defaults to None. + + Returns: + (tuple): A tuple containing the total loss and main three losses in a tensor. + """ + if not hasattr(self, "criterion"): self.criterion = self.init_criterion() - img = batch['img'] + img = batch["img"] # NOTE: preprocess gt_bbox and gt_labels to list. bs = len(img) - batch_idx = batch['batch_idx'] + batch_idx = batch["batch_idx"] gt_groups = [(batch_idx == i).sum().item() for i in range(bs)] targets = { - 'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1), - 'bboxes': batch['bboxes'].to(device=img.device), - 'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1), - 'gt_groups': gt_groups} + "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1), + "bboxes": batch["bboxes"].to(device=img.device), + "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1), + "gt_groups": gt_groups, + } preds = self.predict(img, batch=targets) if preds is None else preds dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1] if dn_meta is None: dn_bboxes, dn_scores = None, None else: - dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2) - dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2) + dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2) + dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2) dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4) dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores]) - loss = self.criterion((dec_bboxes, dec_scores), - targets, - dn_bboxes=dn_bboxes, - dn_scores=dn_scores, - dn_meta=dn_meta) + loss = self.criterion( + (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta + ) # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses. - return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']], - device=img.device) + return sum(loss.values()), torch.as_tensor( + [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device + ) - def predict(self, x, profile=False, visualize=False, batch=None, augment=False): + def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None): """ - Perform a forward pass through the network. + Perform a forward pass through the model. Args: - x (torch.Tensor): The input tensor to the model - profile (bool): Print the computation time of each layer if True, defaults to False. - visualize (bool): Save the feature maps of the model if True, defaults to False - batch (dict): A dict including gt boxes and labels from dataloader. + x (torch.Tensor): The input tensor. + profile (bool, optional): If True, profile the computation time for each layer. Defaults to False. 
+ visualize (bool, optional): If True, save feature maps for visualization. Defaults to False. + batch (dict, optional): Ground truth data for evaluation. Defaults to None. + augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. Returns: - (torch.Tensor): The last output of the model. + (torch.Tensor): Model's output tensor. """ - y, dt = [], [] # outputs + y, dt, embeddings = [], [], [] # outputs for m in self.model[:-1]: # except the head part if m.f != -1: # if not from previous layer x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers @@ -438,11 +560,91 @@ class RTDETRDetectionModel(DetectionModel): y.append(x if m.i in self.save else None) # save output if visualize: feature_visualization(x, m.type, m.i, save_dir=visualize) + if embed and m.i in embed: + embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten + if m.i == max(embed): + return torch.unbind(torch.cat(embeddings, 1), dim=0) head = self.model[-1] x = head([y[j] for j in head.f], batch) # head inference return x +class WorldModel(DetectionModel): + """YOLOv8 World Model.""" + + def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True): + """Initialize YOLOv8 world model with given config and parameters.""" + self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder + self.clip_model = None # CLIP model placeholder + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def set_classes(self, text): + """Perform a forward pass with optional profiling, visualization, and embedding extraction.""" + try: + import clip + except ImportError: + check_requirements("git+https://github.com/openai/CLIP.git") + import clip + + if not getattr(self, "clip_model", None): # for backwards compatibility of models lacking clip_model attribute + self.clip_model = clip.load("ViT-B/32")[0] + device = next(self.clip_model.parameters()).device + text_token = clip.tokenize(text).to(device) + txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32) + txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) + self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach() + self.model[-1].nc = len(text) + + def init_criterion(self): + """Initialize the loss criterion for the model.""" + raise NotImplementedError + + def predict(self, x, profile=False, visualize=False, augment=False, embed=None): + """ + Perform a forward pass through the model. + + Args: + x (torch.Tensor): The input tensor. + profile (bool, optional): If True, profile the computation time for each layer. Defaults to False. + visualize (bool, optional): If True, save feature maps for visualization. Defaults to False. + augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (torch.Tensor): Model's output tensor. 
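WorldModel.set_classes encodes the class names with CLIP and L2-normalises the text features. A sketch of that step in isolation, assuming the openai/CLIP package is installed (device pinned to CPU for simplicity):

import torch
import clip  # pip install git+https://github.com/openai/CLIP.git

model, _ = clip.load("ViT-B/32", device="cpu")
tokens = clip.tokenize(["person", "bus", "traffic light"])
with torch.no_grad():
    txt_feats = model.encode_text(tokens).float()
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)  # unit-norm feature per class name
print(txt_feats.shape)  # torch.Size([3, 512])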
+ """ + txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype) + if len(txt_feats) != len(x): + txt_feats = txt_feats.repeat(len(x), 1, 1) + ori_txt_feats = txt_feats.clone() + y, dt, embeddings = [], [], [] # outputs + for m in self.model: # except the head part + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + if profile: + self._profile_one_layer(m, x, dt) + if isinstance(m, C2fAttn): + x = m(x, txt_feats) + elif isinstance(m, WorldDetect): + x = m(x, ori_txt_feats) + elif isinstance(m, ImagePoolingAttn): + txt_feats = m(x, txt_feats) + else: + x = m(x) # run + + y.append(x if m.i in self.save else None) # save output + if visualize: + feature_visualization(x, m.type, m.i, save_dir=visualize) + if embed and m.i in embed: + embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten + if m.i == max(embed): + return torch.unbind(torch.cat(embeddings, 1), dim=0) + return x + +class YOLOv10DetectionModel(DetectionModel): + def init_criterion(self): + return v10DetectLoss(self) + class Ensemble(nn.ModuleList): """Ensemble of models.""" @@ -451,7 +653,7 @@ class Ensemble(nn.ModuleList): super().__init__() def forward(self, x, augment=False, profile=False, visualize=False): - """Function generates the YOLOv5 network's final layer.""" + """Function generates the YOLO network's final layer.""" y = [module(x, augment, profile, visualize)[0] for module in self] # y = torch.stack(y).max(0)[0] # max ensemble # y = torch.stack(y).mean(0) # mean ensemble @@ -490,6 +692,7 @@ def temporary_modules(modules=None): import importlib import sys + try: # Set modules in sys.modules under their old name for old, new in modules.items(): @@ -517,30 +720,47 @@ def torch_safe_load(weight): """ from ultralytics.utils.downloads import attempt_download_asset - check_suffix(file=weight, suffix='.pt') + check_suffix(file=weight, suffix=".pt") file = attempt_download_asset(weight) # search online if missing locally try: - with temporary_modules({ - 'ultralytics.yolo.utils': 'ultralytics.utils', - 'ultralytics.yolo.v8': 'ultralytics.models.yolo', - 'ultralytics.yolo.data': 'ultralytics.data'}): # for legacy 8.0 Classify and Pose models - return torch.load(file, map_location='cpu'), file # load + with temporary_modules( + { + "ultralytics.yolo.utils": "ultralytics.utils", + "ultralytics.yolo.v8": "ultralytics.models.yolo", + "ultralytics.yolo.data": "ultralytics.data", + } + ): # for legacy 8.0 Classify and Pose models + ckpt = torch.load(file, map_location="cpu") except ModuleNotFoundError as e: # e.name is missing module name - if e.name == 'models': + if e.name == "models": raise TypeError( - emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained ' - f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with ' - f'YOLOv8 at https://github.com/ultralytics/ultralytics.' - f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " - f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e - LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements." - f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future." - f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " - f"run a command with an official YOLOv8 model, i.e. 
'yolo predict model=yolov8n.pt'") + emojis( + f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained " + f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with " + f"YOLOv8 at https://github.com/ultralytics/ultralytics." + f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " + f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'" + ) + ) from e + LOGGER.warning( + f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements." + f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future." + f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " + f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'" + ) check_requirements(e.name) # install missing module + ckpt = torch.load(file, map_location="cpu") - return torch.load(file, map_location='cpu'), file # load + if not isinstance(ckpt, dict): + # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt") + LOGGER.warning( + f"WARNING ⚠️ The file '{weight}' appears to be improperly saved or formatted. " + f"For optimal results, use model.save('filename.pt') to correctly save YOLO models." + ) + ckpt = {"model": ckpt.model} + + return ckpt, file # load def attempt_load_weights(weights, device=None, inplace=True, fuse=False): @@ -549,25 +769,24 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): ensemble = Ensemble() for w in weights if isinstance(weights, list) else [weights]: ckpt, w = torch_safe_load(w) # load ckpt - args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args - model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model + args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args + model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model # Model compatibility updates model.args = args # attach args to model model.pt_path = w # attach *.pt file path to model model.task = guess_model_task(model) - if not hasattr(model, 'stride'): - model.stride = torch.tensor([32.]) + if not hasattr(model, "stride"): + model.stride = torch.tensor([32.0]) # Append - ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode + ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode # Module updates for m in ensemble.modules(): - t = type(m) - if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment): + if hasattr(m, "inplace"): m.inplace = inplace - elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'): + elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): m.recompute_scale_factor = None # torch 1.11.0 compatibility # Return model @@ -575,35 +794,34 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): return ensemble[-1] # Return ensemble - LOGGER.info(f'Ensemble created with {weights}\n') - for k in 'names', 'nc', 'yaml': + LOGGER.info(f"Ensemble created with {weights}\n") + for k in "names", "nc", "yaml": setattr(ensemble, k, getattr(ensemble[0], k)) - ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride - assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class 
counts {[m.nc for m in ensemble]}' + ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride + assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}" return ensemble def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): """Loads a single model weights.""" ckpt, weight = torch_safe_load(weight) # load ckpt - args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args - model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model + args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args + model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model # Model compatibility updates model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model model.pt_path = weight # attach *.pt file path to model model.task = guess_model_task(model) - if not hasattr(model, 'stride'): - model.stride = torch.tensor([32.]) + if not hasattr(model, "stride"): + model.stride = torch.tensor([32.0]) - model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode + model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode # Module updates for m in model.modules(): - t = type(m) - if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment): + if hasattr(m, "inplace"): m.inplace = inplace - elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'): + elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): m.recompute_scale_factor = None # torch 1.11.0 compatibility # Return model and ckpt @@ -615,11 +833,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) import ast # Args - max_channels = float('inf') - nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales')) - depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape')) + max_channels = float("inf") + nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales")) + depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape")) if scales: - scale = d.get('scale') + scale = d.get("scale") if not scale: scale = tuple(scales.keys())[0] LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.") @@ -634,52 +852,92 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}") ch = [ch] layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out - for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args - m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args + m = getattr(torch.nn, m[3:]) if "nn." 
in m else globals()[m] # get module for j, a in enumerate(args): if isinstance(a, str): with contextlib.suppress(ValueError): args[j] = locals()[a] if a in locals() else ast.literal_eval(a) n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain - if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, - BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3): + if m in { + Classify, + Conv, + ConvTranspose, + GhostConv, + Bottleneck, + GhostBottleneck, + SPP, + SPPF, + DWConv, + Focus, + BottleneckCSP, + C1, + C2, + C2f, + RepNCSPELAN4, + ADown, + SPPELAN, + C2fAttn, + C3, + C3TR, + C3Ghost, + nn.ConvTranspose2d, + DWConvTranspose2d, + C3x, + RepC3, + PSA, + SCDown, + C2fCIB + }: c1, c2 = ch[f], args[0] if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output) c2 = make_divisible(min(c2, max_channels) * width, 8) + if m is C2fAttn: + args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels + args[2] = int( + max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2] + ) # num heads args = [c1, c2, *args[1:]] - if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3): + if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB): args.insert(2, n) # number of repeats n = 1 elif m is AIFI: args = [ch[f], *args] - elif m in (HGStem, HGBlock): + elif m in {HGStem, HGBlock}: c1, cm, c2 = ch[f], args[0], args[1] args = [c1, cm, c2, *args[2:]] if m is HGBlock: args.insert(4, n) # number of repeats n = 1 - + elif m is ResNetLayer: + c2 = args[1] if args[3] else args[1] * 4 elif m is nn.BatchNorm2d: args = [ch[f]] elif m is Concat: c2 = sum(ch[x] for x in f) - elif m in (Detect, Segment, Pose): + elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}: args.append([ch[x] for x in f]) if m is Segment: args[2] = make_divisible(min(args[2], max_channels) * width, 8) elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1 args.insert(1, [ch[x] for x in f]) + elif m is CBLinear: + c2 = args[0] + c1 = ch[f] + args = [c1, c2, *args[1:]] + elif m is CBFuse: + c2 = ch[f[-1]] else: c2 = ch[f] m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module - t = str(m)[8:-2].replace('__main__.', '') # module type + t = str(m)[8:-2].replace("__main__.", "") # module type m.np = sum(x.numel() for x in m_.parameters()) # number params m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type if verbose: - LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print + LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist layers.append(m_) if i == 0: @@ -693,24 +951,27 @@ def yaml_model_load(path): import re path = Path(path) - if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)): - new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem) - LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.') + if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)): + new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem) + LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. 
Renaming {path.stem} to {new_stem}.") path = path.with_name(new_stem + path.suffix) - unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path)) # i.e. yolov8x.yaml -> yolov8.yaml + if "v10" not in str(path): + unified_path = re.sub(r"(\d+)([nsblmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml + else: + unified_path = path yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path) d = yaml_load(yaml_file) # model dict - d['scale'] = guess_model_scale(path) - d['yaml_file'] = str(path) + d["scale"] = guess_model_scale(path) + d["yaml_file"] = str(path) return d def guess_model_scale(model_path): """ - Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. - The function uses regular expression matching to find the pattern of the model scale in the YAML file name, - which is denoted by n, s, m, l, or x. The function returns the size character of the model scale as a string. + Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. The function + uses regular expression matching to find the pattern of the model scale in the YAML file name, which is denoted by + n, s, m, l, or x. The function returns the size character of the model scale as a string. Args: model_path (str | Path): The path to the YOLO model's YAML file. @@ -720,8 +981,9 @@ def guess_model_scale(model_path): """ with contextlib.suppress(AttributeError): import re - return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1) # n, s, m, l, or x - return '' + + return re.search(r"yolov\d+([nsblmx])", Path(model_path).stem).group(1) # n, s, m, l, or x + return "" def guess_model_task(model): @@ -740,15 +1002,17 @@ def guess_model_task(model): def cfg2task(cfg): """Guess from YAML dictionary.""" - m = cfg['head'][-1][-2].lower() # output module name - if m in ('classify', 'classifier', 'cls', 'fc'): - return 'classify' - if m == 'detect': - return 'detect' - if m == 'segment': - return 'segment' - if m == 'pose': - return 'pose' + m = cfg["head"][-1][-2].lower() # output module name + if m in {"classify", "classifier", "cls", "fc"}: + return "classify" + if m == "detect" or m == "v10detect": + return "detect" + if m == "segment": + return "segment" + if m == "pose": + return "pose" + if m == "obb": + return "obb" # Guess from model cfg if isinstance(model, dict): @@ -757,36 +1021,42 @@ def guess_model_task(model): # Guess from PyTorch model if isinstance(model, nn.Module): # PyTorch model - for x in 'model.args', 'model.model.args', 'model.model.model.args': + for x in "model.args", "model.model.args", "model.model.model.args": with contextlib.suppress(Exception): - return eval(x)['task'] - for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml': + return eval(x)["task"] + for x in "model.yaml", "model.model.yaml", "model.model.model.yaml": with contextlib.suppress(Exception): return cfg2task(eval(x)) for m in model.modules(): - if isinstance(m, Detect): - return 'detect' - elif isinstance(m, Segment): - return 'segment' + if isinstance(m, Segment): + return "segment" elif isinstance(m, Classify): - return 'classify' + return "classify" elif isinstance(m, Pose): - return 'pose' + return "pose" + elif isinstance(m, OBB): + return "obb" + elif isinstance(m, (Detect, WorldDetect, v10Detect)): + return "detect" # Guess from model filename if isinstance(model, (str, Path)): model = Path(model) - if '-seg' in model.stem or 'segment' in model.parts: - return 'segment' - elif '-cls' in 
model.stem or 'classify' in model.parts: - return 'classify' - elif '-pose' in model.stem or 'pose' in model.parts: - return 'pose' - elif 'detect' in model.parts: - return 'detect' + if "-seg" in model.stem or "segment" in model.parts: + return "segment" + elif "-cls" in model.stem or "classify" in model.parts: + return "classify" + elif "-pose" in model.stem or "pose" in model.parts: + return "pose" + elif "-obb" in model.stem or "obb" in model.parts: + return "obb" + elif "detect" in model.parts: + return "detect" # Unable to determine task from model - LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " - "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.") - return 'detect' # assume detect + LOGGER.warning( + "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " + "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'." + ) + return "detect" # assume detect diff --git a/ultralytics/solutions/__init__.py b/ultralytics/solutions/__init__.py new file mode 100644 index 0000000..9e68dc1 --- /dev/null +++ b/ultralytics/solutions/__init__.py @@ -0,0 +1 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py new file mode 100644 index 0000000..b78cf59 --- /dev/null +++ b/ultralytics/solutions/ai_gym.py @@ -0,0 +1,150 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import cv2 + +from ultralytics.utils.checks import check_imshow +from ultralytics.utils.plotting import Annotator + + +class AIGym: + """A class to manage the gym steps of people in a real-time video stream based on their poses.""" + + def __init__(self): + """Initializes the AIGym with default values for Visual and Image parameters.""" + + # Image and line thickness + self.im0 = None + self.tf = None + + # Keypoints and count information + self.keypoints = None + self.poseup_angle = None + self.posedown_angle = None + self.threshold = 0.001 + + # Store stage, count and angle information + self.angle = None + self.count = None + self.stage = None + self.pose_type = "pushup" + self.kpts_to_check = None + + # Visual Information + self.view_img = False + self.annotator = None + + # Check if environment support imshow + self.env_check = check_imshow(warn=True) + + def set_args( + self, + kpts_to_check, + line_thickness=2, + view_img=False, + pose_up_angle=145.0, + pose_down_angle=90.0, + pose_type="pullup", + ): + """ + Configures the AIGym line_thickness, save image and view image parameters. + + Args: + kpts_to_check (list): 3 keypoints for counting + line_thickness (int): Line thickness for bounding boxes. + view_img (bool): display the im0 + pose_up_angle (float): Angle to set pose position up + pose_down_angle (float): Angle to set pose position down + pose_type (str): "pushup", "pullup" or "abworkout" + """ + self.kpts_to_check = kpts_to_check + self.tf = line_thickness + self.view_img = view_img + self.poseup_angle = pose_up_angle + self.posedown_angle = pose_down_angle + self.pose_type = pose_type + + def start_counting(self, im0, results, frame_count): + """ + Function used to count the gym steps. + + Args: + im0 (ndarray): Current frame from the video stream. 
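AIGym.start_counting below relies on Annotator.estimate_pose_angle to measure the joint angle spanned by three keypoints. A generic three-point angle helper is shown here for intuition only; it is not the Annotator implementation:

import math

def keypoint_angle(a, b, c):
    """Angle at point b (degrees) formed by points a-b-c, each an (x, y) pair."""
    ang = math.degrees(math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0]))
    ang = abs(ang)
    return 360 - ang if ang > 180 else ang

print(keypoint_angle((0, 0), (1, 0), (1, 1)))  # 90.0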
+ results (list): Pose estimation data + frame_count (int): store current frame count + """ + self.im0 = im0 + if frame_count == 1: + self.count = [0] * len(results[0]) + self.angle = [0] * len(results[0]) + self.stage = ["-" for _ in results[0]] + self.keypoints = results[0].keypoints.data + self.annotator = Annotator(im0, line_width=2) + + for ind, k in enumerate(reversed(self.keypoints)): + if self.pose_type in ["pushup", "pullup"]: + self.angle[ind] = self.annotator.estimate_pose_angle( + k[int(self.kpts_to_check[0])].cpu(), + k[int(self.kpts_to_check[1])].cpu(), + k[int(self.kpts_to_check[2])].cpu(), + ) + self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10) + + if self.pose_type == "abworkout": + self.angle[ind] = self.annotator.estimate_pose_angle( + k[int(self.kpts_to_check[0])].cpu(), + k[int(self.kpts_to_check[1])].cpu(), + k[int(self.kpts_to_check[2])].cpu(), + ) + self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10) + if self.angle[ind] > self.poseup_angle: + self.stage[ind] = "down" + if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down": + self.stage[ind] = "up" + self.count[ind] += 1 + self.annotator.plot_angle_and_count_and_stage( + angle_text=self.angle[ind], + count_text=self.count[ind], + stage_text=self.stage[ind], + center_kpt=k[int(self.kpts_to_check[1])], + line_thickness=self.tf, + ) + + if self.pose_type == "pushup": + if self.angle[ind] > self.poseup_angle: + self.stage[ind] = "up" + if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up": + self.stage[ind] = "down" + self.count[ind] += 1 + self.annotator.plot_angle_and_count_and_stage( + angle_text=self.angle[ind], + count_text=self.count[ind], + stage_text=self.stage[ind], + center_kpt=k[int(self.kpts_to_check[1])], + line_thickness=self.tf, + ) + if self.pose_type == "pullup": + if self.angle[ind] > self.poseup_angle: + self.stage[ind] = "down" + if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down": + self.stage[ind] = "up" + self.count[ind] += 1 + self.annotator.plot_angle_and_count_and_stage( + angle_text=self.angle[ind], + count_text=self.count[ind], + stage_text=self.stage[ind], + center_kpt=k[int(self.kpts_to_check[1])], + line_thickness=self.tf, + ) + + self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True) + + if self.env_check and self.view_img: + cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0) + if cv2.waitKey(1) & 0xFF == ord("q"): + return + + return self.im0 + + +if __name__ == "__main__": + AIGym() diff --git a/ultralytics/solutions/distance_calculation.py b/ultralytics/solutions/distance_calculation.py new file mode 100644 index 0000000..f09209e --- /dev/null +++ b/ultralytics/solutions/distance_calculation.py @@ -0,0 +1,181 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import math + +import cv2 + +from ultralytics.utils.checks import check_imshow +from ultralytics.utils.plotting import Annotator, colors + + +class DistanceCalculation: + """A class to calculate distance between two objects in real-time video stream based on their tracks.""" + + def __init__(self): + """Initializes the distance calculation class with default values for Visual, Image, track and distance + parameters. 
+ """ + + # Visual & im0 information + self.im0 = None + self.annotator = None + self.view_img = False + self.line_color = (255, 255, 0) + self.centroid_color = (255, 0, 255) + + # Predict/track information + self.clss = None + self.names = None + self.boxes = None + self.line_thickness = 2 + self.trk_ids = None + + # Distance calculation information + self.centroids = [] + self.pixel_per_meter = 10 + + # Mouse event + self.left_mouse_count = 0 + self.selected_boxes = {} + + # Check if environment support imshow + self.env_check = check_imshow(warn=True) + + def set_args( + self, + names, + pixels_per_meter=10, + view_img=False, + line_thickness=2, + line_color=(255, 255, 0), + centroid_color=(255, 0, 255), + ): + """ + Configures the distance calculation and display parameters. + + Args: + names (dict): object detection classes names + pixels_per_meter (int): Number of pixels in meter + view_img (bool): Flag indicating frame display + line_thickness (int): Line thickness for bounding boxes. + line_color (RGB): color of centroids line + centroid_color (RGB): colors of bbox centroids + """ + self.names = names + self.pixel_per_meter = pixels_per_meter + self.view_img = view_img + self.line_thickness = line_thickness + self.line_color = line_color + self.centroid_color = centroid_color + + def mouse_event_for_distance(self, event, x, y, flags, param): + """ + This function is designed to move region with mouse events in a real-time video stream. + + Args: + event (int): The type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.). + x (int): The x-coordinate of the mouse pointer. + y (int): The y-coordinate of the mouse pointer. + flags (int): Any flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, + cv2.EVENT_FLAG_SHIFTKEY, etc.). + param (dict): Additional parameters you may want to pass to the function. + """ + global selected_boxes + global left_mouse_count + if event == cv2.EVENT_LBUTTONDOWN: + self.left_mouse_count += 1 + if self.left_mouse_count <= 2: + for box, track_id in zip(self.boxes, self.trk_ids): + if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes: + self.selected_boxes[track_id] = [] + self.selected_boxes[track_id] = box + + if event == cv2.EVENT_RBUTTONDOWN: + self.selected_boxes = {} + self.left_mouse_count = 0 + + def extract_tracks(self, tracks): + """ + Extracts results from the provided data. + + Args: + tracks (list): List of tracks obtained from the object tracking process. + """ + self.boxes = tracks[0].boxes.xyxy.cpu() + self.clss = tracks[0].boxes.cls.cpu().tolist() + self.trk_ids = tracks[0].boxes.id.int().cpu().tolist() + + def calculate_centroid(self, box): + """ + Calculate the centroid of bounding box. + + Args: + box (list): Bounding box data + """ + return int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2) + + def calculate_distance(self, centroid1, centroid2): + """ + Calculate distance between two centroids. + + Args: + centroid1 (point): First bounding box data + centroid2 (point): Second bounding box data + """ + pixel_distance = math.sqrt((centroid1[0] - centroid2[0]) ** 2 + (centroid1[1] - centroid2[1]) ** 2) + return pixel_distance / self.pixel_per_meter, (pixel_distance / self.pixel_per_meter) * 1000 + + def start_process(self, im0, tracks): + """ + Calculate distance between two bounding boxes based on tracking data. + + Args: + im0 (nd array): Image + tracks (list): List of tracks obtained from the object tracking process. 
+ """ + self.im0 = im0 + if tracks[0].boxes.id is None: + if self.view_img: + self.display_frames() + return + self.extract_tracks(tracks) + + self.annotator = Annotator(self.im0, line_width=2) + + for box, cls, track_id in zip(self.boxes, self.clss, self.trk_ids): + self.annotator.box_label(box, color=colors(int(cls), True), label=self.names[int(cls)]) + + if len(self.selected_boxes) == 2: + for trk_id, _ in self.selected_boxes.items(): + if trk_id == track_id: + self.selected_boxes[track_id] = box + + if len(self.selected_boxes) == 2: + for trk_id, box in self.selected_boxes.items(): + centroid = self.calculate_centroid(self.selected_boxes[trk_id]) + self.centroids.append(centroid) + + distance_m, distance_mm = self.calculate_distance(self.centroids[0], self.centroids[1]) + self.annotator.plot_distance_and_line( + distance_m, distance_mm, self.centroids, self.line_color, self.centroid_color + ) + + self.centroids = [] + + if self.view_img and self.env_check: + self.display_frames() + + return im0 + + def display_frames(self): + """Display frame.""" + cv2.namedWindow("Ultralytics Distance Estimation") + cv2.setMouseCallback("Ultralytics Distance Estimation", self.mouse_event_for_distance) + cv2.imshow("Ultralytics Distance Estimation", self.im0) + + if cv2.waitKey(1) & 0xFF == ord("q"): + return + + +if __name__ == "__main__": + DistanceCalculation() diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py new file mode 100644 index 0000000..f70e62b --- /dev/null +++ b/ultralytics/solutions/heatmap.py @@ -0,0 +1,281 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from collections import defaultdict + +import cv2 +import numpy as np + +from ultralytics.utils.checks import check_imshow, check_requirements +from ultralytics.utils.plotting import Annotator + +check_requirements("shapely>=2.0.0") + +from shapely.geometry import LineString, Point, Polygon + + +class Heatmap: + """A class to draw heatmaps in real-time video stream based on their tracks.""" + + def __init__(self): + """Initializes the heatmap class with default values for Visual, Image, track, count and heatmap parameters.""" + + # Visual information + self.annotator = None + self.view_img = False + self.shape = "circle" + + # Image information + self.imw = None + self.imh = None + self.im0 = None + self.view_in_counts = True + self.view_out_counts = True + + # Heatmap colormap and heatmap np array + self.colormap = None + self.heatmap = None + self.heatmap_alpha = 0.5 + + # Predict/track information + self.boxes = None + self.track_ids = None + self.clss = None + self.track_history = defaultdict(list) + + # Region & Line Information + self.count_reg_pts = None + self.counting_region = None + self.line_dist_thresh = 15 + self.region_thickness = 5 + self.region_color = (255, 0, 255) + + # Object Counting Information + self.in_counts = 0 + self.out_counts = 0 + self.counting_list = [] + self.count_txt_thickness = 0 + self.count_txt_color = (0, 0, 0) + self.count_color = (255, 255, 255) + + # Decay factor + self.decay_factor = 0.99 + + # Check if environment support imshow + self.env_check = check_imshow(warn=True) + + def set_args( + self, + imw, + imh, + colormap=cv2.COLORMAP_JET, + heatmap_alpha=0.5, + view_img=False, + view_in_counts=True, + view_out_counts=True, + count_reg_pts=None, + count_txt_thickness=2, + count_txt_color=(0, 0, 0), + count_color=(255, 255, 255), + count_reg_color=(255, 0, 255), + region_thickness=5, + line_dist_thresh=15, + decay_factor=0.99, + shape="circle", + ): + """ + Configures 
the heatmap colormap, width, height and display parameters. + + Args: + colormap (cv2.COLORMAP): The colormap to be set. + imw (int): The width of the frame. + imh (int): The height of the frame. + heatmap_alpha (float): alpha value for heatmap display + view_img (bool): Flag indicating frame display + view_in_counts (bool): Flag to control whether to display the incounts on video stream. + view_out_counts (bool): Flag to control whether to display the outcounts on video stream. + count_reg_pts (list): Object counting region points + count_txt_thickness (int): Text thickness for object counting display + count_txt_color (RGB color): count text color value + count_color (RGB color): count text background color value + count_reg_color (RGB color): Color of object counting region + region_thickness (int): Object counting Region thickness + line_dist_thresh (int): Euclidean Distance threshold for line counter + decay_factor (float): value for removing heatmap area after object passed + shape (str): Heatmap shape, rect or circle shape supported + """ + self.imw = imw + self.imh = imh + self.heatmap_alpha = heatmap_alpha + self.view_img = view_img + self.view_in_counts = view_in_counts + self.view_out_counts = view_out_counts + self.colormap = colormap + + # Region and line selection + if count_reg_pts is not None: + if len(count_reg_pts) == 2: + print("Line Counter Initiated.") + self.count_reg_pts = count_reg_pts + self.counting_region = LineString(count_reg_pts) + + elif len(count_reg_pts) == 4: + print("Region Counter Initiated.") + self.count_reg_pts = count_reg_pts + self.counting_region = Polygon(self.count_reg_pts) + + else: + print("Region or line points Invalid, 2 or 4 points supported") + print("Using Line Counter Now") + self.counting_region = Polygon([(20, 400), (1260, 400)]) # dummy points + + # Heatmap new frame + self.heatmap = np.zeros((int(self.imh), int(self.imw)), dtype=np.float32) + + self.count_txt_thickness = count_txt_thickness + self.count_txt_color = count_txt_color + self.count_color = count_color + self.region_color = count_reg_color + self.region_thickness = region_thickness + self.decay_factor = decay_factor + self.line_dist_thresh = line_dist_thresh + self.shape = shape + + # shape of heatmap, if not selected + if self.shape not in ["circle", "rect"]: + print("Unknown shape value provided, 'circle' & 'rect' supported") + print("Using Circular shape now") + self.shape = "circle" + + def extract_results(self, tracks): + """ + Extracts results from the provided data. + + Args: + tracks (list): List of tracks obtained from the object tracking process. + """ + self.boxes = tracks[0].boxes.xyxy.cpu() + self.clss = tracks[0].boxes.cls.cpu().tolist() + self.track_ids = tracks[0].boxes.id.int().cpu().tolist() + + def generate_heatmap(self, im0, tracks): + """ + Generate heatmap based on tracking data. + + Args: + im0 (nd array): Image + tracks (list): List of tracks obtained from the object tracking process. 
+ """ + self.im0 = im0 + if tracks[0].boxes.id is None: + self.heatmap = np.zeros((int(self.imh), int(self.imw)), dtype=np.float32) + if self.view_img and self.env_check: + self.display_frames() + return im0 + self.heatmap *= self.decay_factor # decay factor + self.extract_results(tracks) + self.annotator = Annotator(self.im0, self.count_txt_thickness, None) + + if self.count_reg_pts is not None: + # Draw counting region + if self.view_in_counts or self.view_out_counts: + self.annotator.draw_region( + reg_pts=self.count_reg_pts, color=self.region_color, thickness=self.region_thickness + ) + + for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids): + if self.shape == "circle": + center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)) + radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2 + + y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]] + mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2 + + self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += ( + 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] + ) + + else: + self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2 + + # Store tracking hist + track_line = self.track_history[track_id] + track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))) + if len(track_line) > 30: + track_line.pop(0) + + # Count objects + if len(self.count_reg_pts) == 4: + if self.counting_region.contains(Point(track_line[-1])) and track_id not in self.counting_list: + self.counting_list.append(track_id) + if box[0] < self.counting_region.centroid.x: + self.out_counts += 1 + else: + self.in_counts += 1 + + elif len(self.count_reg_pts) == 2: + distance = Point(track_line[-1]).distance(self.counting_region) + if distance < self.line_dist_thresh and track_id not in self.counting_list: + self.counting_list.append(track_id) + if box[0] < self.counting_region.centroid.x: + self.out_counts += 1 + else: + self.in_counts += 1 + else: + for box, cls in zip(self.boxes, self.clss): + if self.shape == "circle": + center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)) + radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2 + + y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]] + mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2 + + self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += ( + 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] + ) + + else: + self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2 + + # Normalize, apply colormap to heatmap and combine with original image + heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX) + heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap) + + incount_label = f"In Count : {self.in_counts}" + outcount_label = f"OutCount : {self.out_counts}" + + # Display counts based on user choice + counts_label = None + if not self.view_in_counts and not self.view_out_counts: + counts_label = None + elif not self.view_in_counts: + counts_label = outcount_label + elif not self.view_out_counts: + counts_label = incount_label + else: + counts_label = f"{incount_label} {outcount_label}" + + if self.count_reg_pts is not None and counts_label is not None: + self.annotator.count_labels( + counts=counts_label, + count_txt_size=self.count_txt_thickness, + txt_color=self.count_txt_color, + color=self.count_color, + ) + + self.im0 = 
cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0) + + if self.env_check and self.view_img: + self.display_frames() + + return self.im0 + + def display_frames(self): + """Display frame.""" + cv2.imshow("Ultralytics Heatmap", self.im0) + + if cv2.waitKey(1) & 0xFF == ord("q"): + return + + +if __name__ == "__main__": + Heatmap() diff --git a/ultralytics/solutions/object_counter.py b/ultralytics/solutions/object_counter.py new file mode 100644 index 0000000..18f42c6 --- /dev/null +++ b/ultralytics/solutions/object_counter.py @@ -0,0 +1,278 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from collections import defaultdict + +import cv2 + +from ultralytics.utils.checks import check_imshow, check_requirements +from ultralytics.utils.plotting import Annotator, colors + +check_requirements("shapely>=2.0.0") + +from shapely.geometry import LineString, Point, Polygon + + +class ObjectCounter: + """A class to manage the counting of objects in a real-time video stream based on their tracks.""" + + def __init__(self): + """Initializes the Counter with default values for various tracking and counting parameters.""" + + # Mouse events + self.is_drawing = False + self.selected_point = None + + # Region & Line Information + self.reg_pts = [(20, 400), (1260, 400)] + self.line_dist_thresh = 15 + self.counting_region = None + self.region_color = (255, 0, 255) + self.region_thickness = 5 + + # Image and annotation Information + self.im0 = None + self.tf = None + self.view_img = False + self.view_in_counts = True + self.view_out_counts = True + + self.names = None # Classes names + self.annotator = None # Annotator + self.window_name = "Ultralytics YOLOv8 Object Counter" + + # Object counting Information + self.in_counts = 0 + self.out_counts = 0 + self.counting_dict = {} + self.count_txt_thickness = 0 + self.count_txt_color = (0, 0, 0) + self.count_color = (255, 255, 255) + + # Tracks info + self.track_history = defaultdict(list) + self.track_thickness = 2 + self.draw_tracks = False + self.track_color = (0, 255, 0) + + # Check if environment support imshow + self.env_check = check_imshow(warn=True) + + def set_args( + self, + classes_names, + reg_pts, + count_reg_color=(255, 0, 255), + line_thickness=2, + track_thickness=2, + view_img=False, + view_in_counts=True, + view_out_counts=True, + draw_tracks=False, + count_txt_thickness=2, + count_txt_color=(0, 0, 0), + count_color=(255, 255, 255), + track_color=(0, 255, 0), + region_thickness=5, + line_dist_thresh=15, + ): + """ + Configures the Counter's image, bounding box line thickness, and counting region points. + + Args: + line_thickness (int): Line thickness for bounding boxes. + view_img (bool): Flag to control whether to display the video stream. + view_in_counts (bool): Flag to control whether to display the incounts on video stream. + view_out_counts (bool): Flag to control whether to display the outcounts on video stream. + reg_pts (list): Initial list of points defining the counting region. 
+ classes_names (dict): Classes names + track_thickness (int): Track thickness + draw_tracks (Bool): draw tracks + count_txt_thickness (int): Text thickness for object counting display + count_txt_color (RGB color): count text color value + count_color (RGB color): count text background color value + count_reg_color (RGB color): Color of object counting region + track_color (RGB color): color for tracks + region_thickness (int): Object counting Region thickness + line_dist_thresh (int): Euclidean Distance threshold for line counter + """ + self.tf = line_thickness + self.view_img = view_img + self.view_in_counts = view_in_counts + self.view_out_counts = view_out_counts + self.track_thickness = track_thickness + self.draw_tracks = draw_tracks + + # Region and line selection + if len(reg_pts) == 2: + print("Line Counter Initiated.") + self.reg_pts = reg_pts + self.counting_region = LineString(self.reg_pts) + elif len(reg_pts) >= 3: + print("Region Counter Initiated.") + self.reg_pts = reg_pts + self.counting_region = Polygon(self.reg_pts) + else: + print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.") + print("Using Line Counter Now") + self.counting_region = LineString(self.reg_pts) + + self.names = classes_names + self.track_color = track_color + self.count_txt_thickness = count_txt_thickness + self.count_txt_color = count_txt_color + self.count_color = count_color + self.region_color = count_reg_color + self.region_thickness = region_thickness + self.line_dist_thresh = line_dist_thresh + + def mouse_event_for_region(self, event, x, y, flags, params): + """ + This function is designed to move region with mouse events in a real-time video stream. + + Args: + event (int): The type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.). + x (int): The x-coordinate of the mouse pointer. + y (int): The y-coordinate of the mouse pointer. + flags (int): Any flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, + cv2.EVENT_FLAG_SHIFTKEY, etc.). + params (dict): Additional parameters you may want to pass to the function. 
+ """ + if event == cv2.EVENT_LBUTTONDOWN: + for i, point in enumerate(self.reg_pts): + if ( + isinstance(point, (tuple, list)) + and len(point) >= 2 + and (abs(x - point[0]) < 10 and abs(y - point[1]) < 10) + ): + self.selected_point = i + self.is_drawing = True + break + + elif event == cv2.EVENT_MOUSEMOVE: + if self.is_drawing and self.selected_point is not None: + self.reg_pts[self.selected_point] = (x, y) + self.counting_region = Polygon(self.reg_pts) + + elif event == cv2.EVENT_LBUTTONUP: + self.is_drawing = False + self.selected_point = None + + def extract_and_process_tracks(self, tracks): + """Extracts and processes tracks for object counting in a video stream.""" + + # Annotator Init and region drawing + self.annotator = Annotator(self.im0, self.tf, self.names) + + if tracks[0].boxes.id is not None: + boxes = tracks[0].boxes.xyxy.cpu() + clss = tracks[0].boxes.cls.cpu().tolist() + track_ids = tracks[0].boxes.id.int().cpu().tolist() + + # Extract tracks + for box, track_id, cls in zip(boxes, track_ids, clss): + # Draw bounding box + self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(track_id), True)) + + # Draw Tracks + track_line = self.track_history[track_id] + track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))) + if len(track_line) > 30: + track_line.pop(0) + + # Draw track trails + if self.draw_tracks: + self.annotator.draw_centroid_and_tracks( + track_line, color=self.track_color, track_thickness=self.track_thickness + ) + + prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None + centroid = Point((box[:2] + box[2:]) / 2) + + # Count objects + if len(self.reg_pts) >= 3: # any polygon + is_inside = self.counting_region.contains(centroid) + current_position = "in" if is_inside else "out" + + if prev_position is not None: + if self.counting_dict[track_id] != current_position and is_inside: + self.in_counts += 1 + self.counting_dict[track_id] = "in" + elif self.counting_dict[track_id] != current_position and not is_inside: + self.out_counts += 1 + self.counting_dict[track_id] = "out" + else: + self.counting_dict[track_id] = current_position + + else: + self.counting_dict[track_id] = current_position + + elif len(self.reg_pts) == 2: + if prev_position is not None: + is_inside = (box[0] - prev_position[0]) * ( + self.counting_region.centroid.x - prev_position[0] + ) > 0 + current_position = "in" if is_inside else "out" + + if self.counting_dict[track_id] != current_position and is_inside: + self.in_counts += 1 + self.counting_dict[track_id] = "in" + elif self.counting_dict[track_id] != current_position and not is_inside: + self.out_counts += 1 + self.counting_dict[track_id] = "out" + else: + self.counting_dict[track_id] = current_position + else: + self.counting_dict[track_id] = None + + incount_label = f"In Count : {self.in_counts}" + outcount_label = f"OutCount : {self.out_counts}" + + # Display counts based on user choice + counts_label = None + if not self.view_in_counts and not self.view_out_counts: + counts_label = None + elif not self.view_in_counts: + counts_label = outcount_label + elif not self.view_out_counts: + counts_label = incount_label + else: + counts_label = f"{incount_label} {outcount_label}" + + if counts_label is not None: + self.annotator.count_labels( + counts=counts_label, + count_txt_size=self.count_txt_thickness, + txt_color=self.count_txt_color, + color=self.count_color, + ) + + def display_frames(self): + """Display frame.""" + if self.env_check: + 
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness) + cv2.namedWindow(self.window_name) + if len(self.reg_pts) == 4: # only add mouse event If user drawn region + cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts}) + cv2.imshow(self.window_name, self.im0) + # Break Window + if cv2.waitKey(1) & 0xFF == ord("q"): + return + + def start_counting(self, im0, tracks): + """ + Main function to start the object counting process. + + Args: + im0 (ndarray): Current frame from the video stream. + tracks (list): List of tracks obtained from the object tracking process. + """ + self.im0 = im0 # store image + self.extract_and_process_tracks(tracks) # draw region even if no objects + + if self.view_img: + self.display_frames() + return self.im0 + + +if __name__ == "__main__": + ObjectCounter() diff --git a/ultralytics/solutions/speed_estimation.py b/ultralytics/solutions/speed_estimation.py new file mode 100644 index 0000000..f3f1795 --- /dev/null +++ b/ultralytics/solutions/speed_estimation.py @@ -0,0 +1,198 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from collections import defaultdict +from time import time + +import cv2 +import numpy as np + +from ultralytics.utils.checks import check_imshow +from ultralytics.utils.plotting import Annotator, colors + + +class SpeedEstimator: + """A class to estimation speed of objects in real-time video stream based on their tracks.""" + + def __init__(self): + """Initializes the speed-estimator class with default values for Visual, Image, track and speed parameters.""" + + # Visual & im0 information + self.im0 = None + self.annotator = None + self.view_img = False + + # Region information + self.reg_pts = [(20, 400), (1260, 400)] + self.region_thickness = 3 + + # Predict/track information + self.clss = None + self.names = None + self.boxes = None + self.trk_ids = None + self.trk_pts = None + self.line_thickness = 2 + self.trk_history = defaultdict(list) + + # Speed estimator information + self.current_time = 0 + self.dist_data = {} + self.trk_idslist = [] + self.spdl_dist_thresh = 10 + self.trk_previous_times = {} + self.trk_previous_points = {} + + # Check if environment support imshow + self.env_check = check_imshow(warn=True) + + def set_args( + self, + reg_pts, + names, + view_img=False, + line_thickness=2, + region_thickness=5, + spdl_dist_thresh=10, + ): + """ + Configures the speed estimation and display parameters. + + Args: + reg_pts (list): Initial list of points defining the speed calculation region. + names (dict): object detection classes names + view_img (bool): Flag indicating frame display + line_thickness (int): Line thickness for bounding boxes. + region_thickness (int): Speed estimation region thickness + spdl_dist_thresh (int): Euclidean distance threshold for speed line + """ + if reg_pts is None: + print("Region points not provided, using default values") + else: + self.reg_pts = reg_pts + self.names = names + self.view_img = view_img + self.line_thickness = line_thickness + self.region_thickness = region_thickness + self.spdl_dist_thresh = spdl_dist_thresh + + def extract_tracks(self, tracks): + """ + Extracts results from the provided data. + + Args: + tracks (list): List of tracks obtained from the object tracking process. 
+ """ + self.boxes = tracks[0].boxes.xyxy.cpu() + self.clss = tracks[0].boxes.cls.cpu().tolist() + self.trk_ids = tracks[0].boxes.id.int().cpu().tolist() + + def store_track_info(self, track_id, box): + """ + Store track data. + + Args: + track_id (int): object track id. + box (list): object bounding box data + """ + track = self.trk_history[track_id] + bbox_center = (float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)) + track.append(bbox_center) + + if len(track) > 30: + track.pop(0) + + self.trk_pts = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) + return track + + def plot_box_and_track(self, track_id, box, cls, track): + """ + Plot track and bounding box. + + Args: + track_id (int): object track id. + box (list): object bounding box data + cls (str): object class name + track (list): tracking history for tracks path drawing + """ + speed_label = f"{int(self.dist_data[track_id])}km/ph" if track_id in self.dist_data else self.names[int(cls)] + bbox_color = colors(int(track_id)) if track_id in self.dist_data else (255, 0, 255) + + self.annotator.box_label(box, speed_label, bbox_color) + + cv2.polylines(self.im0, [self.trk_pts], isClosed=False, color=(0, 255, 0), thickness=1) + cv2.circle(self.im0, (int(track[-1][0]), int(track[-1][1])), 5, bbox_color, -1) + + def calculate_speed(self, trk_id, track): + """ + Calculation of object speed. + + Args: + trk_id (int): object track id. + track (list): tracking history for tracks path drawing + """ + + if not self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]: + return + if self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh: + direction = "known" + + elif self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[0][1] + self.spdl_dist_thresh: + direction = "known" + + else: + direction = "unknown" + + if self.trk_previous_times[trk_id] != 0 and direction != "unknown" and trk_id not in self.trk_idslist: + self.trk_idslist.append(trk_id) + + time_difference = time() - self.trk_previous_times[trk_id] + if time_difference > 0: + dist_difference = np.abs(track[-1][1] - self.trk_previous_points[trk_id][1]) + speed = dist_difference / time_difference + self.dist_data[trk_id] = speed + + self.trk_previous_times[trk_id] = time() + self.trk_previous_points[trk_id] = track[-1] + + def estimate_speed(self, im0, tracks, region_color=(255, 0, 0)): + """ + Calculate object based on tracking data. + + Args: + im0 (nd array): Image + tracks (list): List of tracks obtained from the object tracking process. + region_color (tuple): Color to use when drawing regions. 
+ """ + self.im0 = im0 + if tracks[0].boxes.id is None: + if self.view_img and self.env_check: + self.display_frames() + return im0 + self.extract_tracks(tracks) + + self.annotator = Annotator(self.im0, line_width=2) + self.annotator.draw_region(reg_pts=self.reg_pts, color=region_color, thickness=self.region_thickness) + + for box, trk_id, cls in zip(self.boxes, self.trk_ids, self.clss): + track = self.store_track_info(trk_id, box) + + if trk_id not in self.trk_previous_times: + self.trk_previous_times[trk_id] = 0 + + self.plot_box_and_track(trk_id, box, cls, track) + self.calculate_speed(trk_id, track) + + if self.view_img and self.env_check: + self.display_frames() + + return im0 + + def display_frames(self): + """Display frame.""" + cv2.imshow("Ultralytics Speed Estimation", self.im0) + if cv2.waitKey(1) & 0xFF == ord("q"): + return + + +if __name__ == "__main__": + SpeedEstimator() diff --git a/ultralytics/trackers/README.md b/ultralytics/trackers/README.md index a6505e0..2cab3c0 100644 --- a/ultralytics/trackers/README.md +++ b/ultralytics/trackers/README.md @@ -1,91 +1,318 @@ -# Tracker +# Multi-Object Tracking with Ultralytics YOLO -## Supported Trackers +YOLOv8 trackers visualization -- [x] ByteTracker -- [x] BoT-SORT +Object tracking in the realm of video analytics is a critical task that not only identifies the location and class of objects within the frame but also maintains a unique ID for each detected object as the video progresses. The applications are limitless—ranging from surveillance and security to real-time sports analytics. -## Usage +## Why Choose Ultralytics YOLO for Object Tracking? -### python interface: +The output from Ultralytics trackers is consistent with standard object detection but has the added value of object IDs. This makes it easy to track objects in video streams and perform subsequent analytics. Here's why you should consider using Ultralytics YOLO for your object tracking needs: -You can use the Python interface to track objects using the YOLO model. +- **Efficiency:** Process video streams in real-time without compromising accuracy. +- **Flexibility:** Supports multiple tracking algorithms and configurations. +- **Ease of Use:** Simple Python API and CLI options for quick integration and deployment. +- **Customizability:** Easy to use with custom trained YOLO models, allowing integration into domain-specific applications. + +**Video Tutorial:** [Object Detection and Tracking with Ultralytics YOLOv8](https://www.youtube.com/embed/hHyHmOtmEgs?si=VNZtXmm45Nb9s-N-). + +## Features at a Glance + +Ultralytics YOLO extends its object detection features to provide robust and versatile object tracking: + +- **Real-Time Tracking:** Seamlessly track objects in high-frame-rate videos. +- **Multiple Tracker Support:** Choose from a variety of established tracking algorithms. +- **Customizable Tracker Configurations:** Tailor the tracking algorithm to meet specific requirements by adjusting various parameters. + +## Available Trackers + +Ultralytics YOLO supports the following tracking algorithms. They can be enabled by passing the relevant YAML configuration file such as `tracker=tracker_type.yaml`: + +- [BoT-SORT](https://github.com/NirAharon/BoT-SORT) - Use `botsort.yaml` to enable this tracker. +- [ByteTrack](https://github.com/ifzhang/ByteTrack) - Use `bytetrack.yaml` to enable this tracker. + +The default tracker is BoT-SORT. 
+ +## Tracking + +To run the tracker on video streams, use a trained Detect, Segment or Pose model such as YOLOv8n, YOLOv8n-seg and YOLOv8n-pose. + +#### Python ```python from ultralytics import YOLO -model = YOLO("yolov8n.pt") # or a segmentation model .i.e yolov8n-seg.pt -model.track( - source="video/streams", - stream=True, - tracker="botsort.yaml", # or 'bytetrack.yaml' - show=True, +# Load an official or custom model +model = YOLO("yolov8n.pt") # Load an official Detect model +model = YOLO("yolov8n-seg.pt") # Load an official Segment model +model = YOLO("yolov8n-pose.pt") # Load an official Pose model +model = YOLO("path/to/best.pt") # Load a custom trained model + +# Perform tracking with the model +results = model.track( + source="https://youtu.be/LNwODJXcvt4", show=True +) # Tracking with default tracker +results = model.track( + source="https://youtu.be/LNwODJXcvt4", show=True, tracker="bytetrack.yaml" +) # Tracking with ByteTrack tracker +``` + +#### CLI + +```bash +# Perform tracking with various models using the command line interface +yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" # Official Detect model +yolo track model=yolov8n-seg.pt source="https://youtu.be/LNwODJXcvt4" # Official Segment model +yolo track model=yolov8n-pose.pt source="https://youtu.be/LNwODJXcvt4" # Official Pose model +yolo track model=path/to/best.pt source="https://youtu.be/LNwODJXcvt4" # Custom trained model + +# Track using ByteTrack tracker +yolo track model=path/to/best.pt tracker="bytetrack.yaml" +``` + +As can be seen in the above usage, tracking is available for all Detect, Segment and Pose models run on videos or streaming sources. + +## Configuration + +### Tracking Arguments + +Tracking configuration shares properties with Predict mode, such as `conf`, `iou`, and `show`. For further configurations, refer to the [Predict](https://docs.ultralytics.com/modes/predict/) model page. + +#### Python + +```python +from ultralytics import YOLO + +# Configure the tracking parameters and run the tracker +model = YOLO("yolov8n.pt") +results = model.track( + source="https://youtu.be/LNwODJXcvt4", conf=0.3, iou=0.5, show=True ) ``` -You can get the IDs of the tracked objects using the following code: +#### CLI + +```bash +# Configure tracking parameters and run the tracker using the command line interface +yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" conf=0.3, iou=0.5 show +``` + +### Tracker Selection + +Ultralytics also allows you to use a modified tracker configuration file. To do this, simply make a copy of a tracker config file (for example, `custom_tracker.yaml`) from [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) and modify any configurations (except the `tracker_type`) as per your needs. + +#### Python ```python from ultralytics import YOLO +# Load the model and run the tracker with a custom configuration file model = YOLO("yolov8n.pt") - -for result in model.track(source="video.mp4"): - print( - result.boxes.id.cpu().numpy().astype(int) - ) # this will print the IDs of the tracked objects in the frame +results = model.track( + source="https://youtu.be/LNwODJXcvt4", tracker="custom_tracker.yaml" +) ``` -If you want to use the tracker with a folder of images or when you loop on the video frames, you should use the `persist` parameter to tell the model that these frames are related to each other so the IDs will be fixed for the same objects. 
Otherwise, the IDs will be different in each frame because in each loop, the model creates a new object for tracking, but the `persist` parameter makes it use the same object for tracking. +#### CLI + +```bash +# Load the model and run the tracker with a custom configuration file using the command line interface +yolo track model=yolov8n.pt source="https://youtu.be/LNwODJXcvt4" tracker='custom_tracker.yaml' +``` + +For a comprehensive list of tracking arguments, refer to the [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) page. + +## Python Examples + +### Persisting Tracks Loop + +Here is a Python script using OpenCV (`cv2`) and YOLOv8 to run object tracking on video frames. This script still assumes you have already installed the necessary packages (`opencv-python` and `ultralytics`). The `persist=True` argument tells the tracker than the current image or frame is the next in a sequence and to expect tracks from the previous image in the current image. + +#### Python ```python import cv2 from ultralytics import YOLO -cap = cv2.VideoCapture("video.mp4") +# Load the YOLOv8 model model = YOLO("yolov8n.pt") -while True: - ret, frame = cap.read() - if not ret: - break - results = model.track(frame, persist=True) - boxes = results[0].boxes.xyxy.cpu().numpy().astype(int) - ids = results[0].boxes.id.cpu().numpy().astype(int) - for box, id in zip(boxes, ids): - cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2) - cv2.putText( - frame, - f"Id {id}", - (box[0], box[1]), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 0, 255), - 2, - ) - cv2.imshow("frame", frame) - if cv2.waitKey(1) & 0xFF == ord("q"): + +# Open the video file +video_path = "path/to/video.mp4" +cap = cv2.VideoCapture(video_path) + +# Loop through the video frames +while cap.isOpened(): + # Read a frame from the video + success, frame = cap.read() + + if success: + # Run YOLOv8 tracking on the frame, persisting tracks between frames + results = model.track(frame, persist=True) + + # Visualize the results on the frame + annotated_frame = results[0].plot() + + # Display the annotated frame + cv2.imshow("YOLOv8 Tracking", annotated_frame) + + # Break the loop if 'q' is pressed + if cv2.waitKey(1) & 0xFF == ord("q"): + break + else: + # Break the loop if the end of the video is reached break + +# Release the video capture object and close the display window +cap.release() +cv2.destroyAllWindows() ``` -## Change tracker parameters +Please note the change from `model(frame)` to `model.track(frame)`, which enables object tracking instead of simple detection. This modified script will run the tracker on each frame of the video, visualize the results, and display them in a window. The loop can be exited by pressing 'q'. -You can change the tracker parameters by editing the `tracker.yaml` file which is located in the ultralytics/cfg/trackers folder. +### Plotting Tracks Over Time -## Command Line Interface (CLI) +Visualizing object tracks over consecutive frames can provide valuable insights into the movement patterns and behavior of detected objects within a video. With Ultralytics YOLOv8, plotting these tracks is a seamless and efficient process. -You can also use the command line interface to track objects using the YOLO model. +In the following example, we demonstrate how to utilize YOLOv8's tracking capabilities to plot the movement of detected objects across multiple video frames. 
This script involves opening a video file, reading it frame by frame, and utilizing the YOLO model to identify and track various objects. By retaining the center points of the detected bounding boxes and connecting them, we can draw lines that represent the paths followed by the tracked objects. -```bash -yolo detect track source=... tracker=... -yolo segment track source=... tracker=... -yolo pose track source=... tracker=... +#### Python + +```python +from collections import defaultdict + +import cv2 +import numpy as np + +from ultralytics import YOLO + +# Load the YOLOv8 model +model = YOLO("yolov8n.pt") + +# Open the video file +video_path = "path/to/video.mp4" +cap = cv2.VideoCapture(video_path) + +# Store the track history +track_history = defaultdict(lambda: []) + +# Loop through the video frames +while cap.isOpened(): + # Read a frame from the video + success, frame = cap.read() + + if success: + # Run YOLOv8 tracking on the frame, persisting tracks between frames + results = model.track(frame, persist=True) + + # Get the boxes and track IDs + boxes = results[0].boxes.xywh.cpu() + track_ids = results[0].boxes.id.int().cpu().tolist() + + # Visualize the results on the frame + annotated_frame = results[0].plot() + + # Plot the tracks + for box, track_id in zip(boxes, track_ids): + x, y, w, h = box + track = track_history[track_id] + track.append((float(x), float(y))) # x, y center point + if len(track) > 30: # retain 90 tracks for 90 frames + track.pop(0) + + # Draw the tracking lines + points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines( + annotated_frame, + [points], + isClosed=False, + color=(230, 230, 230), + thickness=10, + ) + + # Display the annotated frame + cv2.imshow("YOLOv8 Tracking", annotated_frame) + + # Break the loop if 'q' is pressed + if cv2.waitKey(1) & 0xFF == ord("q"): + break + else: + # Break the loop if the end of the video is reached + break + +# Release the video capture object and close the display window +cap.release() +cv2.destroyAllWindows() ``` -By default, trackers will use the configuration in `ultralytics/cfg/trackers`. We also support using a modified tracker config file. Please refer to the tracker config files in `ultralytics/cfg/trackers`. +### Multithreaded Tracking -## Contribute to Our Trackers Section +Multithreaded tracking provides the capability to run object tracking on multiple video streams simultaneously. This is particularly useful when handling multiple video inputs, such as from multiple surveillance cameras, where concurrent processing can greatly enhance efficiency and performance. -Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section! Your real-world applications and solutions could be invaluable for users working on tracking tasks. +In the provided Python script, we make use of Python's `threading` module to run multiple instances of the tracker concurrently. Each thread is responsible for running the tracker on one video file, and all the threads run simultaneously in the background. + +To ensure that each thread receives the correct parameters (the video file and the model to use), we define a function `run_tracker_in_thread` that accepts these parameters and contains the main tracking loop. This function reads the video frame by frame, runs the tracker, and displays the results. 
+ +Two different models are used in this example: `yolov8n.pt` and `yolov8n-seg.pt`, each tracking objects in a different video file. The video files are specified in `video_file1` and `video_file2`. + +The `daemon=True` parameter in `threading.Thread` means that these threads will be closed as soon as the main program finishes. We then start the threads with `start()` and use `join()` to make the main thread wait until both tracker threads have finished. + +Finally, after all threads have completed their task, the windows displaying the results are closed using `cv2.destroyAllWindows()`. + +#### Python + +```python +import threading + +import cv2 +from ultralytics import YOLO + + +def run_tracker_in_thread(filename, model): + video = cv2.VideoCapture(filename) + frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + for _ in range(frames): + ret, frame = video.read() + if ret: + results = model.track(source=frame, persist=True) + res_plotted = results[0].plot() + cv2.imshow("p", res_plotted) + if cv2.waitKey(1) == ord("q"): + break + + +# Load the models +model1 = YOLO("yolov8n.pt") +model2 = YOLO("yolov8n-seg.pt") + +# Define the video files for the trackers +video_file1 = "path/to/video1.mp4" +video_file2 = "path/to/video2.mp4" + +# Create the tracker threads +tracker_thread1 = threading.Thread( + target=run_tracker_in_thread, args=(video_file1, model1), daemon=True +) +tracker_thread2 = threading.Thread( + target=run_tracker_in_thread, args=(video_file2, model2), daemon=True +) + +# Start the tracker threads +tracker_thread1.start() +tracker_thread2.start() + +# Wait for the tracker threads to finish +tracker_thread1.join() +tracker_thread2.join() + +# Clean up and close windows +cv2.destroyAllWindows() +``` + +This example can easily be extended to handle more video files and models by creating more threads and applying the same methodology. + +## Contribute New Trackers + +Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section in [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers)! Your real-world applications and solutions could be invaluable for users working on tracking tasks. By contributing to this section, you help expand the scope of tracking solutions available within the Ultralytics YOLO framework, adding another layer of functionality and utility for the community. 
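
This patch also adds the new `ultralytics/solutions` modules (AIGym, DistanceCalculation, Heatmap, ObjectCounter, SpeedEstimator) without an accompanying usage snippet. The sketch below shows one plausible way to wire the new `ObjectCounter` to `model.track()` output, following the `set_args(...)` signature added in `object_counter.py` above; it is a minimal sketch, not the canonical example — the model weights, video path and region points are placeholders.

```python
import cv2

from ultralytics import YOLO
from ultralytics.solutions.object_counter import ObjectCounter

# Placeholder model, video and counting line; adjust to your own data.
model = YOLO("yolov8n.pt")
counter = ObjectCounter()
counter.set_args(
    classes_names=model.names,          # dict of class names expected by set_args
    reg_pts=[(20, 400), (1080, 400)],   # two points -> line counter, >= 3 -> polygon region
    view_img=True,
)

cap = cv2.VideoCapture("path/to/video.mp4")
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    tracks = model.track(frame, persist=True)       # tracking results, one entry per frame
    frame = counter.start_counting(frame, tracks)   # draws the region and updates in/out counts
cap.release()
cv2.destroyAllWindows()
```

The same pattern — construct the class, call `set_args(...)`, then call the per-frame entry point (`generate_heatmap`, `estimate_speed`, `start_process`, or `start_counting`) — applies to the other solutions classes introduced in this patch.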
diff --git a/ultralytics/trackers/__init__.py b/ultralytics/trackers/__init__.py index 46e178e..bf51b8d 100644 --- a/ultralytics/trackers/__init__.py +++ b/ultralytics/trackers/__init__.py @@ -4,4 +4,4 @@ from .bot_sort import BOTSORT from .byte_tracker import BYTETracker from .track import register_tracker -__all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import +__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import diff --git a/ultralytics/trackers/__pycache__/__init__.cpython-39.pyc b/ultralytics/trackers/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index f4c0444..0000000 Binary files a/ultralytics/trackers/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/__pycache__/basetrack.cpython-39.pyc b/ultralytics/trackers/__pycache__/basetrack.cpython-39.pyc deleted file mode 100644 index 5476853..0000000 Binary files a/ultralytics/trackers/__pycache__/basetrack.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/__pycache__/bot_sort.cpython-39.pyc b/ultralytics/trackers/__pycache__/bot_sort.cpython-39.pyc deleted file mode 100644 index 1275148..0000000 Binary files a/ultralytics/trackers/__pycache__/bot_sort.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/__pycache__/byte_tracker.cpython-39.pyc b/ultralytics/trackers/__pycache__/byte_tracker.cpython-39.pyc deleted file mode 100644 index 9abb721..0000000 Binary files a/ultralytics/trackers/__pycache__/byte_tracker.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/__pycache__/track.cpython-39.pyc b/ultralytics/trackers/__pycache__/track.cpython-39.pyc deleted file mode 100644 index a5b8d03..0000000 Binary files a/ultralytics/trackers/__pycache__/track.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/basetrack.py b/ultralytics/trackers/basetrack.py index 3c7b0f7..c900cac 100644 --- a/ultralytics/trackers/basetrack.py +++ b/ultralytics/trackers/basetrack.py @@ -1,4 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +"""This module defines the base classes and structures for object tracking in YOLO.""" from collections import OrderedDict @@ -6,7 +7,15 @@ import numpy as np class TrackState: - """Enumeration of possible object tracking states.""" + """ + Enumeration class representing the possible states of an object being tracked. + + Attributes: + New (int): State when the object is newly detected. + Tracked (int): State when the object is successfully tracked in subsequent frames. + Lost (int): State when the object is no longer tracked. + Removed (int): State when the object is removed from tracking. + """ New = 0 Tracked = 1 @@ -15,24 +24,49 @@ class TrackState: class BaseTrack: - """Base class for object tracking, handling basic track attributes and operations.""" + """ + Base class for object tracking, providing foundational attributes and methods. + + Attributes: + _count (int): Class-level counter for unique track IDs. + track_id (int): Unique identifier for the track. + is_activated (bool): Flag indicating whether the track is currently active. + state (TrackState): Current state of the track. + history (OrderedDict): Ordered history of the track's states. + features (list): List of features extracted from the object for tracking. + curr_feature (any): The current feature of the object being tracked. + score (float): The confidence score of the tracking. + start_frame (int): The frame number where tracking started. 
+ frame_id (int): The most recent frame ID processed by the track. + time_since_update (int): Frames passed since the last update. + location (tuple): The location of the object in the context of multi-camera tracking. + + Methods: + end_frame: Returns the ID of the last frame where the object was tracked. + next_id: Increments and returns the next global track ID. + activate: Abstract method to activate the track. + predict: Abstract method to predict the next state of the track. + update: Abstract method to update the track with new data. + mark_lost: Marks the track as lost. + mark_removed: Marks the track as removed. + reset_id: Resets the global track ID counter. + """ _count = 0 - track_id = 0 - is_activated = False - state = TrackState.New - - history = OrderedDict() - features = [] - curr_feature = None - score = 0 - start_frame = 0 - frame_id = 0 - time_since_update = 0 - - # Multi-camera - location = (np.inf, np.inf) + def __init__(self): + """Initializes a new track with unique ID and foundational tracking attributes.""" + self.track_id = 0 + self.is_activated = False + self.state = TrackState.New + self.history = OrderedDict() + self.features = [] + self.curr_feature = None + self.score = 0 + self.start_frame = 0 + self.frame_id = 0 + self.time_since_update = 0 + self.location = (np.inf, np.inf) @property def end_frame(self): @@ -46,15 +80,15 @@ class BaseTrack: return BaseTrack._count def activate(self, *args): - """Activate the track with the provided arguments.""" + """Abstract method to activate the track with provided arguments.""" raise NotImplementedError def predict(self): - """Predict the next state of the track.""" + """Abstract method to predict the next state of the track.""" raise NotImplementedError def update(self, *args, **kwargs): - """Update the track with new observations.""" + """Abstract method to update the track with new observations.""" raise NotImplementedError def mark_lost(self): diff --git a/ultralytics/trackers/bot_sort.py b/ultralytics/trackers/bot_sort.py index 7bd63e5..31d5e1b 100644 --- a/ultralytics/trackers/bot_sort.py +++ b/ultralytics/trackers/bot_sort.py @@ -12,6 +12,34 @@ from .utils.kalman_filter import KalmanFilterXYWH class BOTrack(STrack): + """ + An extended version of the STrack class for YOLOv8, adding object tracking features. + + Attributes: + shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack. + smooth_feat (np.ndarray): Smoothed feature vector. + curr_feat (np.ndarray): Current feature vector. + features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`. + alpha (float): Smoothing factor for the exponential moving average of features. + mean (np.ndarray): The mean state of the Kalman filter. + covariance (np.ndarray): The covariance matrix of the Kalman filter. + + Methods: + update_features(feat): Update features vector and smooth it using exponential moving average. + predict(): Predicts the mean and covariance using Kalman filter. + re_activate(new_track, frame_id, new_id): Reactivates a track with updated features and optionally new ID. + update(new_track, frame_id): Update the YOLOv8 instance with new track and frame ID. + tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`. + multi_predict(stracks): Predicts the mean and covariance of multiple object tracks using shared Kalman filter. + convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format. 
+ tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`. + + Usage: + bo_track = BOTrack(tlwh, score, cls, feat) + bo_track.predict() + bo_track.update(new_track, frame_id) + """ + shared_kalman = KalmanFilterXYWH() def __init__(self, tlwh, score, cls, feat=None, feat_history=50): @@ -59,9 +87,7 @@ class BOTrack(STrack): @property def tlwh(self): - """Get current position in bounding box format `(top left x, top left y, - width, height)`. - """ + """Get current position in bounding box format `(top left x, top left y, width, height)`.""" if self.mean is None: return self._tlwh.copy() ret = self.mean[:4].copy() @@ -90,15 +116,37 @@ class BOTrack(STrack): @staticmethod def tlwh_to_xywh(tlwh): - """Convert bounding box to format `(center x, center y, width, - height)`. - """ + """Convert bounding box to format `(center x, center y, width, height)`.""" ret = np.asarray(tlwh).copy() ret[:2] += ret[2:] / 2 return ret class BOTSORT(BYTETracker): + """ + An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and GMC algorithm. + + Attributes: + proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections. + appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections. + encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled. + gmc (GMC): An instance of the GMC algorithm for data association. + args (object): Parsed command-line arguments containing tracking parameters. + + Methods: + get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking. + init_track(dets, scores, cls, img): Initialize track with detections, scores, and classes. + get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID. + multi_predict(tracks): Predict and track multiple objects with YOLOv8 model. + + Usage: + bot_sort = BOTSORT(args, frame_rate) + bot_sort.init_track(dets, scores, cls, img) + bot_sort.multi_predict(tracks) + + Note: + The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args. 
+ """ def __init__(self, args, frame_rate=30): """Initialize YOLOv8 object with ReID module and GMC algorithm.""" @@ -110,8 +158,7 @@ class BOTSORT(BYTETracker): if args.with_reid: # Haven't supported BoT-SORT(reid) yet self.encoder = None - - # self.gmc = GMC(method=args.gmc_method) # commented by WQG + self.gmc = GMC(method=args.gmc_method) def get_kalmanfilter(self): """Returns an instance of KalmanFilterXYWH for object tracking.""" @@ -130,7 +177,7 @@ class BOTSORT(BYTETracker): def get_dists(self, tracks, detections): """Get distances between tracks and detections using IoU and (optionally) ReID embeddings.""" dists = matching.iou_distance(tracks, detections) - dists_mask = (dists > self.proximity_thresh) + dists_mask = dists > self.proximity_thresh # TODO: mot20 # if not self.args.mot20: @@ -146,3 +193,8 @@ class BOTSORT(BYTETracker): def multi_predict(self, tracks): """Predict and track multiple objects with YOLOv8 model.""" BOTrack.multi_predict(tracks) + + def reset(self): + """Reset tracker.""" + super().reset() + self.gmc.reset_params() diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py index 91559df..01cbca9 100644 --- a/ultralytics/trackers/byte_tracker.py +++ b/ultralytics/trackers/byte_tracker.py @@ -1,29 +1,54 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license +# Ultralytics YOLO 🚀, AGPL-3.0 license import numpy as np from .basetrack import BaseTrack, TrackState from .utils import matching from .utils.kalman_filter import KalmanFilterXYAH - - -def dists_update(dists, strack_pool, detections): - if len(strack_pool) and len(detections): - alabel = np.array([int(stack.cls) for stack in strack_pool]) - blabel = np.array([int(stack.cls) for stack in detections]) - amlabel = np.expand_dims(alabel, axis=1).repeat(len(detections),axis=1) - bmlabel = np.expand_dims(blabel, axis=0).repeat(len(strack_pool),axis=0) - dist_label = 1 - (bmlabel == amlabel) - dists = np.where(dists > dist_label, dists, dist_label) - return dists +from ..utils.ops import xywh2ltwh +from ..utils import LOGGER class STrack(BaseTrack): + """ + Single object tracking representation that uses Kalman filtering for state estimation. + + This class is responsible for storing all the information regarding individual tracklets and performs state updates + and predictions based on Kalman filter. + + Attributes: + shared_kalman (KalmanFilterXYAH): Shared Kalman filter that is used across all STrack instances for prediction. + _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box. + kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track. + mean (np.ndarray): Mean state estimate vector. + covariance (np.ndarray): Covariance of state estimate. + is_activated (bool): Boolean flag indicating if the track has been activated. + score (float): Confidence score of the track. + tracklet_len (int): Length of the tracklet. + cls (any): Class label for the object. + idx (int): Index or identifier for the object. + frame_id (int): Current frame ID. + start_frame (int): Frame where the object was first detected. + + Methods: + predict(): Predict the next state of the object using Kalman filter. + multi_predict(stracks): Predict the next states for multiple tracks. + multi_gmc(stracks, H): Update multiple track states using a homography matrix. + activate(kalman_filter, frame_id): Activate a new tracklet. + re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet. 
+ update(new_track, frame_id): Update the state of a matched track. + convert_coords(tlwh): Convert bounding box to x-y-aspect-height format. + tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format. + """ + shared_kalman = KalmanFilterXYAH() - def __init__(self, tlwh, score, cls): - """wait activate.""" - self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32) + def __init__(self, xywh, score, cls): + """Initialize new STrack instance.""" + super().__init__() + # xywh+idx or xywha+idx + assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}" + self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32) self.kalman_filter = None self.mean, self.covariance = None, None self.is_activated = False @@ -31,7 +56,8 @@ class STrack(BaseTrack): self.score = score self.tracklet_len = 0 self.cls = cls - self.idx = tlwh[-1] + self.idx = xywh[-1] + self.angle = xywh[4] if len(xywh) == 6 else None def predict(self): """Predicts mean and covariance using Kalman filter.""" @@ -89,8 +115,9 @@ class STrack(BaseTrack): def re_activate(self, new_track, frame_id, new_id=False): """Reactivates a previously lost track with a new detection.""" - self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, - self.convert_coords(new_track.tlwh)) + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.convert_coords(new_track.tlwh) + ) self.tracklet_len = 0 self.state = TrackState.Tracked self.is_activated = True @@ -99,37 +126,39 @@ class STrack(BaseTrack): self.track_id = self.next_id() self.score = new_track.score self.cls = new_track.cls + self.angle = new_track.angle self.idx = new_track.idx def update(self, new_track, frame_id): """ - Update a matched track - :type new_track: STrack - :type frame_id: int - :return: + Update the state of a matched track. + + Args: + new_track (STrack): The new track containing updated information. + frame_id (int): The ID of the current frame. """ self.frame_id = frame_id self.tracklet_len += 1 new_tlwh = new_track.tlwh - self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, - self.convert_coords(new_tlwh)) + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.convert_coords(new_tlwh) + ) self.state = TrackState.Tracked self.is_activated = True self.score = new_track.score self.cls = new_track.cls + self.angle = new_track.angle self.idx = new_track.idx def convert_coords(self, tlwh): - """Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent.""" + """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.""" return self.tlwh_to_xyah(tlwh) @property def tlwh(self): - """Get current position in bounding box format `(top left x, top left y, - width, height)`. - """ + """Get current position in bounding box format (top left x, top left y, width, height).""" if self.mean is None: return self._tlwh.copy() ret = self.mean[:4].copy() @@ -138,44 +167,76 @@ class STrack(BaseTrack): return ret @property - def tlbr(self): - """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., - `(top left, bottom right)`. 
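Side note, not part of the diff: the new STrack.__init__ above expects detections as xywh(+angle)+idx and converts the box with ultralytics.utils.ops.xywh2ltwh. The conversion is just a centre-to-corner shift; a self-contained sketch with a worked example follows (the helper name here is illustrative).

def xywh_to_ltwh(x, y, w, h):
    """Centre-based (x, y, w, h) -> top-left (left, top, w, h), as xywh2ltwh does per box."""
    return x - w / 2, y - h / 2, w, h

# Example: a 40x20 box centred at (100, 50) has its top-left corner at (80, 40).
assert xywh_to_ltwh(100, 50, 40, 20) == (80.0, 40.0, 40, 20)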
- """ + def xyxy(self): + """Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right).""" ret = self.tlwh.copy() ret[2:] += ret[:2] return ret @staticmethod def tlwh_to_xyah(tlwh): - """Convert bounding box to format `(center x, center y, aspect ratio, - height)`, where the aspect ratio is `width / height`. + """Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width / + height. """ ret = np.asarray(tlwh).copy() ret[:2] += ret[2:] / 2 ret[2] /= ret[3] return ret - @staticmethod - def tlbr_to_tlwh(tlbr): - """Converts top-left bottom-right format to top-left width height format.""" - ret = np.asarray(tlbr).copy() - ret[2:] -= ret[:2] + @property + def xywh(self): + """Get current position in bounding box format (center x, center y, width, height).""" + ret = np.asarray(self.tlwh).copy() + ret[:2] += ret[2:] / 2 return ret - @staticmethod - def tlwh_to_tlbr(tlwh): - """Converts tlwh bounding box format to tlbr format.""" - ret = np.asarray(tlwh).copy() - ret[2:] += ret[:2] - return ret + @property + def xywha(self): + """Get current position in bounding box format (center x, center y, width, height, angle).""" + if self.angle is None: + LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.") + return self.xywh + return np.concatenate([self.xywh, self.angle[None]]) + + @property + def result(self): + """Get current tracking results.""" + coords = self.xyxy if self.angle is None else self.xywha + return coords.tolist() + [self.track_id, self.score, self.cls, self.idx] def __repr__(self): """Return a string representation of the BYTETracker object with start and end frames and track ID.""" - return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})' + return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})" class BYTETracker: + """ + BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking. + + The class is responsible for initializing, updating, and managing the tracks for detected objects in a video + sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for + predicting the new object locations, and performs data association. + + Attributes: + tracked_stracks (list[STrack]): List of successfully activated tracks. + lost_stracks (list[STrack]): List of lost tracks. + removed_stracks (list[STrack]): List of removed tracks. + frame_id (int): The current frame ID. + args (namespace): Command-line arguments. + max_time_lost (int): The maximum frames for a track to be considered as 'lost'. + kalman_filter (object): Kalman Filter object. + + Methods: + update(results, img=None): Updates object tracker with new detections. + get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes. + init_track(dets, scores, cls, img=None): Initialize object tracking with detections. + get_dists(tracks, detections): Calculates the distance between tracks and detections. + multi_predict(tracks): Predicts the location of tracks. + reset_id(): Resets the ID counter of STrack. + joint_stracks(tlista, tlistb): Combines two lists of stracks. + sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list. + remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU. 
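For clarity on the coordinate convention used by the Kalman filter, here is STrack.tlwh_to_xyah from above reproduced as a standalone snippet with a worked example: the filter state stores the box centre, the aspect ratio (width / height), and the height rather than the raw width.

import numpy as np

def tlwh_to_xyah(tlwh):
    """(top-left x, top-left y, w, h) -> (centre x, centre y, w/h, h)."""
    ret = np.asarray(tlwh, dtype=float).copy()
    ret[:2] += ret[2:] / 2      # shift the top-left corner to the box centre
    ret[2] /= ret[3]            # replace width with the aspect ratio
    return ret

print(tlwh_to_xyah([80, 40, 40, 20]))   # centre (100, 50), aspect 2.0, height 20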
+ """ def __init__(self, args, frame_rate=30): """Initialize a YOLOv8 object to track objects with given arguments and frame rate.""" @@ -198,7 +259,7 @@ class BYTETracker: removed_stracks = [] scores = results.conf - bboxes = results.xyxy + bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh # Add index bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1) cls = results.cls @@ -216,7 +277,6 @@ class BYTETracker: cls_second = cls[inds_second] detections = self.init_track(dets, scores_keep, cls_keep, img) - # Add newly detected tracklets to tracked_stracks unconfirmed = [] tracked_stracks = [] # type: list[STrack] @@ -225,24 +285,18 @@ class BYTETracker: unconfirmed.append(track) else: tracked_stracks.append(track) - - # Step 2: First association, with high score detection boxes strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks) # Predict the current location with KF self.multi_predict(strack_pool) - -# ============================================================= 没必要gmc,WQG -# if hasattr(self, 'gmc') and img is not None: -# warp = self.gmc.apply(img, dets) -# STrack.multi_gmc(strack_pool, warp) -# STrack.multi_gmc(unconfirmed, warp) -# ============================================================================= + if hasattr(self, "gmc") and img is not None: + warp = self.gmc.apply(img, dets) + STrack.multi_gmc(strack_pool, warp) + STrack.multi_gmc(unconfirmed, warp) dists = self.get_dists(strack_pool, detections) - dists = dists_update(dists, strack_pool, detections) - matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh) + for itracked, idet in matches: track = strack_pool[itracked] det = detections[idet] @@ -252,17 +306,11 @@ class BYTETracker: else: track.re_activate(det, self.frame_id, new_id=False) refind_stracks.append(track) - - - # Step 3: Second association, with low score detection boxes - # association the untrack to the low score detections + # Step 3: Second association, with low score detection boxes association the untrack to the low score detections detections_second = self.init_track(dets_second, scores_second, cls_second, img) r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] - # TODO dists = matching.iou_distance(r_tracked_stracks, detections_second) - dists = dists_update(dists, r_tracked_stracks, detections_second) - matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) for itracked, idet in matches: track = r_tracked_stracks[itracked] @@ -279,13 +327,9 @@ class BYTETracker: if track.state != TrackState.Lost: track.mark_lost() lost_stracks.append(track) - # Deal with unconfirmed tracks, usually tracks with only one beginning frame detections = [detections[i] for i in u_detection] dists = self.get_dists(unconfirmed, detections) - - dists = dists_update(dists, unconfirmed, detections) - matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7) for itracked, idet in matches: unconfirmed[itracked].update(detections[idet], self.frame_id) @@ -317,9 +361,8 @@ class BYTETracker: self.removed_stracks.extend(removed_stracks) if len(self.removed_stracks) > 1000: self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum - return np.asarray( - [x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated], - dtype=np.float32) + + return np.asarray([x.result for x in self.tracked_stracks if 
x.is_activated], dtype=np.float32) def get_kalmanfilter(self): """Returns a Kalman filter object for tracking bounding boxes.""" @@ -330,7 +373,7 @@ class BYTETracker: return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections def get_dists(self, tracks, detections): - """Calculates the distance between tracks and detections using IOU and fuses scores.""" + """Calculates the distance between tracks and detections using IoU and fuses scores.""" dists = matching.iou_distance(tracks, detections) # TODO: mot20 # if not self.args.mot20: @@ -341,10 +384,20 @@ class BYTETracker: """Returns the predicted tracks using the YOLOv8 network.""" STrack.multi_predict(tracks) - def reset_id(self): + @staticmethod + def reset_id(): """Resets the ID counter of STrack.""" STrack.reset_id() + def reset(self): + """Reset tracker.""" + self.tracked_stracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + self.frame_id = 0 + self.kalman_filter = self.get_kalmanfilter() + self.reset_id() + @staticmethod def joint_stracks(tlista, tlistb): """Combine two lists of stracks into a single one.""" @@ -375,7 +428,7 @@ class BYTETracker: @staticmethod def remove_duplicate_stracks(stracksa, stracksb): - """Remove duplicate stracks with non-maximum IOU distance.""" + """Remove duplicate stracks with non-maximum IoU distance.""" pdist = matching.iou_distance(stracksa, stracksb) pairs = np.where(pdist < 0.15) dupa, dupb = [], [] diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py index cfb4b08..7146a40 100644 --- a/ultralytics/trackers/track.py +++ b/ultralytics/trackers/track.py @@ -1,19 +1,20 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license from functools import partial +from pathlib import Path import torch from ultralytics.utils import IterableSimpleNamespace, yaml_load from ultralytics.utils.checks import check_yaml - from .bot_sort import BOTSORT from .byte_tracker import BYTETracker -TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT} +# A mapping of tracker types to corresponding tracker classes +TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT} -def on_predict_start(predictor, persist=False): +def on_predict_start(predictor: object, persist: bool = False) -> None: """ Initialize trackers for object tracking during prediction. @@ -24,43 +25,65 @@ def on_predict_start(predictor, persist=False): Raises: AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. """ - if hasattr(predictor, 'trackers') and persist: + if hasattr(predictor, "trackers") and persist: return + tracker = check_yaml(predictor.args.tracker) cfg = IterableSimpleNamespace(**yaml_load(tracker)) - assert cfg.tracker_type in ['bytetrack', 'botsort'], \ - f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'" + + if cfg.tracker_type not in ["bytetrack", "botsort"]: + raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'") + trackers = [] for _ in range(predictor.dataset.bs): tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) trackers.append(tracker) + if predictor.dataset.mode != "stream": # only need one tracker for other modes. 
+ break predictor.trackers = trackers + predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video -def on_predict_postprocess_end(predictor): - """Postprocess detected boxes and update with object tracking.""" - bs = predictor.dataset.bs - im0s = predictor.batch[1] - for i in range(bs): - det = predictor.results[i].boxes.cpu().numpy() +def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None: + """ + Postprocess detected boxes and update with object tracking. + + Args: + predictor (object): The predictor object containing the predictions. + persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. + """ + path, im0s = predictor.batch[:2] + + is_obb = predictor.args.task == "obb" + is_stream = predictor.dataset.mode == "stream" + for i in range(len(im0s)): + tracker = predictor.trackers[i if is_stream else 0] + vid_path = predictor.save_dir / Path(path[i]).name + if not persist and predictor.vid_path[i if is_stream else 0] != vid_path: + tracker.reset() + predictor.vid_path[i if is_stream else 0] = vid_path + + det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy() if len(det) == 0: continue - tracks = predictor.trackers[i].update(det, im0s[i]) + tracks = tracker.update(det, im0s[i]) if len(tracks) == 0: continue idx = tracks[:, -1].astype(int) predictor.results[i] = predictor.results[i][idx] - predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1])) + + update_args = dict() + update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1]) + predictor.results[i].update(**update_args) -def register_tracker(model, persist): +def register_tracker(model: object, persist: bool) -> None: """ Register tracking callbacks to the model for object tracking during prediction. Args: model (object): The model object to register tracking callbacks for. persist (bool): Whether to persist the trackers if they already exist. 
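User-level note (not part of the diff): the two callbacks rewritten above are wired in by register_tracker() below and are normally reached through the high-level track() call. A minimal usage sketch, assuming the standard Ultralytics API; the weight file and video path are placeholders.

from ultralytics import YOLO

model = YOLO("yolov8n.pt")                      # placeholder weights; any detect/OBB model works
results = model.track(
    source="path/to/video.mp4",                 # placeholder source
    tracker="botsort.yaml",                     # or "bytetrack.yaml"
    persist=False,                              # False -> trackers reset when a new video starts (vid_path logic above)
)
for r in results:
    print(r.boxes.id)                           # per-frame track IDs (None before any track is confirmed)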
- """ - model.add_callback('on_predict_start', partial(on_predict_start, persist=persist)) - model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end) + model.add_callback("on_predict_start", partial(on_predict_start, persist=persist)) + model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist)) diff --git a/ultralytics/trackers/utils/__pycache__/__init__.cpython-39.pyc b/ultralytics/trackers/utils/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 6cfb15e..0000000 Binary files a/ultralytics/trackers/utils/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/utils/__pycache__/gmc.cpython-39.pyc b/ultralytics/trackers/utils/__pycache__/gmc.cpython-39.pyc deleted file mode 100644 index 82e0b1b..0000000 Binary files a/ultralytics/trackers/utils/__pycache__/gmc.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/utils/__pycache__/kalman_filter.cpython-39.pyc b/ultralytics/trackers/utils/__pycache__/kalman_filter.cpython-39.pyc deleted file mode 100644 index 4c34a52..0000000 Binary files a/ultralytics/trackers/utils/__pycache__/kalman_filter.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/utils/__pycache__/matching.cpython-39.pyc b/ultralytics/trackers/utils/__pycache__/matching.cpython-39.pyc deleted file mode 100644 index 4fe80d6..0000000 Binary files a/ultralytics/trackers/utils/__pycache__/matching.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/trackers/utils/gmc.py b/ultralytics/trackers/utils/gmc.py index 4d91df4..806f1b5 100644 --- a/ultralytics/trackers/utils/gmc.py +++ b/ultralytics/trackers/utils/gmc.py @@ -9,67 +9,121 @@ from ultralytics.utils import LOGGER class GMC: + """ + Generalized Motion Compensation (GMC) class for tracking and object detection in video frames. - def __init__(self, method='sparseOptFlow', downscale=2): - """Initialize a video tracker with specified parameters.""" + This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB, + SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency. + + Attributes: + method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. + downscale (int): Factor by which to downscale the frames for processing. + prevFrame (np.ndarray): Stores the previous frame for tracking. + prevKeyPoints (list): Stores the keypoints from the previous frame. + prevDescriptors (np.ndarray): Stores the descriptors from the previous frame. + initializedFirstFrame (bool): Flag to indicate if the first frame has been processed. + + Methods: + __init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method + and downscale factor. + apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses + provided detections. + applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame. + applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame. + applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame. + """ + + def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None: + """ + Initialize a video tracker with specified parameters. + + Args: + method (str): The method used for tracking. 
Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. + downscale (int): Downscale factor for processing frames. + """ super().__init__() self.method = method self.downscale = max(1, int(downscale)) - if self.method == 'orb': + if self.method == "orb": self.detector = cv2.FastFeatureDetector_create(20) self.extractor = cv2.ORB_create() self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING) - elif self.method == 'sift': + elif self.method == "sift": self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) self.matcher = cv2.BFMatcher(cv2.NORM_L2) - elif self.method == 'ecc': + elif self.method == "ecc": number_of_iterations = 5000 termination_eps = 1e-6 self.warp_mode = cv2.MOTION_EUCLIDEAN self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps) - elif self.method == 'sparseOptFlow': - self.feature_params = dict(maxCorners=1000, - qualityLevel=0.01, - minDistance=1, - blockSize=3, - useHarrisDetector=False, - k=0.04) + elif self.method == "sparseOptFlow": + self.feature_params = dict( + maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04 + ) - elif self.method in ['none', 'None', None]: + elif self.method in {"none", "None", None}: self.method = None else: - raise ValueError(f'Error: Unknown GMC method:{method}') + raise ValueError(f"Error: Unknown GMC method:{method}") self.prevFrame = None self.prevKeyPoints = None self.prevDescriptors = None - self.initializedFirstFrame = False - def apply(self, raw_frame, detections=None): - """Apply object detection on a raw frame using specified method.""" - if self.method in ['orb', 'sift']: + def apply(self, raw_frame: np.array, detections: list = None) -> np.array: + """ + Apply object detection on a raw frame using specified method. + + Args: + raw_frame (np.ndarray): The raw frame to be processed. + detections (list): List of detections to be used in the processing. + + Returns: + (np.ndarray): Processed frame. + + Examples: + >>> gmc = GMC() + >>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]])) + array([[1, 2, 3], + [4, 5, 6]]) + """ + if self.method in ["orb", "sift"]: return self.applyFeatures(raw_frame, detections) - elif self.method == 'ecc': - return self.applyEcc(raw_frame, detections) - elif self.method == 'sparseOptFlow': - return self.applySparseOptFlow(raw_frame, detections) + elif self.method == "ecc": + return self.applyEcc(raw_frame) + elif self.method == "sparseOptFlow": + return self.applySparseOptFlow(raw_frame) else: return np.eye(2, 3) - def applyEcc(self, raw_frame, detections=None): - """Initialize.""" + def applyEcc(self, raw_frame: np.array) -> np.array: + """ + Apply ECC algorithm to a raw frame. + + Args: + raw_frame (np.ndarray): The raw frame to be processed. + + Returns: + (np.ndarray): Processed frame. + + Examples: + >>> gmc = GMC() + >>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]])) + array([[1, 2, 3], + [4, 5, 6]]) + """ height, width, _ = raw_frame.shape frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) H = np.eye(2, 3, dtype=np.float32) - # Downscale image (TODO: consider using pyramids) + # Downscale image if self.downscale > 1.0: frame = cv2.GaussianBlur(frame, (3, 3), 1.5) frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) @@ -89,33 +143,46 @@ class GMC: # Run the ECC algorithm. The results are stored in warp_matrix. 
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria) try: - (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1) + (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1) except Exception as e: - LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}') + LOGGER.warning(f"WARNING: find transform failed. Set warp as identity {e}") return H - def applyFeatures(self, raw_frame, detections=None): - """Initialize.""" + def applyFeatures(self, raw_frame: np.array, detections: list = None) -> np.array: + """ + Apply feature-based methods like ORB or SIFT to a raw frame. + + Args: + raw_frame (np.ndarray): The raw frame to be processed. + detections (list): List of detections to be used in the processing. + + Returns: + (np.ndarray): Processed frame. + + Examples: + >>> gmc = GMC() + >>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]])) + array([[1, 2, 3], + [4, 5, 6]]) + """ height, width, _ = raw_frame.shape frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) H = np.eye(2, 3) - # Downscale image (TODO: consider using pyramids) + # Downscale image if self.downscale > 1.0: - # frame = cv2.GaussianBlur(frame, (3, 3), 1.5) frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) width = width // self.downscale height = height // self.downscale # Find the keypoints mask = np.zeros_like(frame) - # mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255 - mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255 + mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255 if detections is not None: for det in detections: tlbr = (det[:4] / self.downscale).astype(np.int_) - mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0 + mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0 keypoints = self.detector.detect(frame, mask) @@ -134,10 +201,10 @@ class GMC: return H - # Match descriptors. 
+ # Match descriptors knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2) - # Filtered matches based on smallest spatial distance + # Filter matches based on smallest spatial distance matches = [] spatialDistances = [] @@ -157,11 +224,14 @@ class GMC: prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt currKeyPointLocation = keypoints[m.trainIdx].pt - spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0], - prevKeyPointLocation[1] - currKeyPointLocation[1]) + spatialDistance = ( + prevKeyPointLocation[0] - currKeyPointLocation[0], + prevKeyPointLocation[1] - currKeyPointLocation[1], + ) - if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \ - (np.abs(spatialDistance[1]) < maxSpatialDistance[1]): + if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and ( + np.abs(spatialDistance[1]) < maxSpatialDistance[1] + ): spatialDistances.append(spatialDistance) matches.append(m) @@ -187,7 +257,7 @@ class GMC: # import matplotlib.pyplot as plt # matches_img = np.hstack((self.prevFrame, frame)) # matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR) - # W = np.size(self.prevFrame, 1) + # W = self.prevFrame.shape[1] # for m in goodMatches: # prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_) # curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_) @@ -204,7 +274,7 @@ class GMC: # plt.show() # Find rigid matrix - if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)): + if prevPoints.shape[0] > 4: H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) # Handle downscale @@ -212,7 +282,7 @@ class GMC: H[0, 2] *= self.downscale H[1, 2] *= self.downscale else: - LOGGER.warning('WARNING: not enough matching points') + LOGGER.warning("WARNING: not enough matching points") # Store to next iteration self.prevFrame = frame.copy() @@ -221,15 +291,28 @@ class GMC: return H - def applySparseOptFlow(self, raw_frame, detections=None): - """Initialize.""" + def applySparseOptFlow(self, raw_frame: np.array) -> np.array: + """ + Apply Sparse Optical Flow method to a raw frame. + + Args: + raw_frame (np.ndarray): The raw frame to be processed. + + Returns: + (np.ndarray): Processed frame. 
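Clarification (not part of the diff): despite the "Processed frame" wording in the Returns sections above, apply() and its per-method helpers return a 2x3 affine warp describing camera motion between consecutive frames, which BYTETracker/BOTSORT feed to STrack.multi_gmc. A short usage sketch with placeholder image paths:

import cv2
from ultralytics.trackers.utils.gmc import GMC

gmc = GMC(method="sparseOptFlow", downscale=2)

prev = cv2.imread("frame_000.jpg")      # placeholder frames (BGR, 3-channel)
curr = cv2.imread("frame_001.jpg")

H = gmc.apply(prev)                      # first call only initialises -> identity warp
H = gmc.apply(curr)                      # 2x3 matrix [[a, b, tx], [c, d, ty]]
print(H.shape)                           # (2, 3)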
+ + Examples: + >>> gmc = GMC() + >>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]])) + array([[1, 2, 3], + [4, 5, 6]]) + """ height, width, _ = raw_frame.shape frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) H = np.eye(2, 3) # Downscale image if self.downscale > 1.0: - # frame = cv2.GaussianBlur(frame, (3, 3), 1.5) frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) # Find the keypoints @@ -237,17 +320,13 @@ class GMC: # Handle first frame if not self.initializedFirstFrame: - # Initialize data self.prevFrame = frame.copy() self.prevKeyPoints = copy.copy(keypoints) - - # Initialization done self.initializedFirstFrame = True - return H # Find correspondences - matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None) + matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None) # Leave good correspondences only prevPoints = [] @@ -262,18 +341,23 @@ class GMC: currPoints = np.array(currPoints) # Find rigid matrix - if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)): - H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == prevPoints.shape[0]): + H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) - # Handle downscale if self.downscale > 1.0: H[0, 2] *= self.downscale H[1, 2] *= self.downscale else: - LOGGER.warning('WARNING: not enough matching points') + LOGGER.warning("WARNING: not enough matching points") - # Store to next iteration self.prevFrame = frame.copy() self.prevKeyPoints = copy.copy(keypoints) return H + + def reset_params(self) -> None: + """Reset parameters.""" + self.prevFrame = None + self.prevKeyPoints = None + self.prevDescriptors = None + self.initializedFirstFrame = False diff --git a/ultralytics/trackers/utils/kalman_filter.py b/ultralytics/trackers/utils/kalman_filter.py index 9527ede..4ae68be 100644 --- a/ultralytics/trackers/utils/kalman_filter.py +++ b/ultralytics/trackers/utils/kalman_filter.py @@ -8,8 +8,8 @@ class KalmanFilterXYAH: """ For bytetrack. A simple Kalman filter for tracking bounding boxes in image space. - The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), - aspect ratio a, height h, and their respective velocities. + The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect + ratio a, height h, and their respective velocities. Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct observation of the state space (linear observation model). @@ -17,126 +17,126 @@ class KalmanFilterXYAH: def __init__(self): """Initialize Kalman filter model matrices with motion and observation uncertainty weights.""" - ndim, dt = 4, 1. + ndim, dt = 4, 1.0 - # Create Kalman filter model matrices. + # Create Kalman filter model matrices self._motion_mat = np.eye(2 * ndim, 2 * ndim) for i in range(ndim): self._motion_mat[i, ndim + i] = dt self._update_mat = np.eye(ndim, 2 * ndim) # Motion and observation uncertainty are chosen relative to the current state estimate. These weights control - # the amount of uncertainty in the model. This is a bit hacky. - self._std_weight_position = 1. / 20 - self._std_weight_velocity = 1. / 160 + # the amount of uncertainty in the model. 
+ self._std_weight_position = 1.0 / 20 + self._std_weight_velocity = 1.0 / 160 - def initiate(self, measurement): + def initiate(self, measurement: np.ndarray) -> tuple: """ Create track from unassociated measurement. - Parameters - ---------- - measurement : ndarray - Bounding box coordinates (x, y, a, h) with center position (x, y), - aspect ratio a, and height h. + Args: + measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a, + and height h. - Returns - ------- - (ndarray, ndarray) - Returns the mean vector (8 dimensional) and covariance matrix (8x8 - dimensional) of the new track. Unobserved velocities are initialized - to 0 mean. + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of + the new track. Unobserved velocities are initialized to 0 mean. """ mean_pos = measurement mean_vel = np.zeros_like(mean_pos) mean = np.r_[mean_pos, mean_vel] std = [ - 2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2, - 2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3], - 10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]] + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[3], + 1e-2, + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 1e-5, + 10 * self._std_weight_velocity * measurement[3], + ] covariance = np.diag(np.square(std)) return mean, covariance - def predict(self, mean, covariance): + def predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple: """ Run Kalman filter prediction step. - Parameters - ---------- - mean : ndarray - The 8 dimensional mean vector of the object state at the previous time step. - covariance : ndarray - The 8x8 dimensional covariance matrix of the object state at the previous time step. + Args: + mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step. + covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step. - Returns - ------- - (ndarray, ndarray) - Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are - initialized to 0 mean. + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved + velocities are initialized to 0 mean. """ std_pos = [ - self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2, - self._std_weight_position * mean[3]] + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-2, + self._std_weight_position * mean[3], + ] std_vel = [ - self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5, - self._std_weight_velocity * mean[3]] + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[3], + 1e-5, + self._std_weight_velocity * mean[3], + ] motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) - # mean = np.dot(self._motion_mat, mean) mean = np.dot(mean, self._motion_mat.T) covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov return mean, covariance - def project(self, mean, covariance): + def project(self, mean: np.ndarray, covariance: np.ndarray) -> tuple: """ Project state distribution to measurement space. 
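To make the motion model above concrete (not part of the diff): __init__ builds a constant-velocity transition matrix, and predict() applies mean' = F·mean and cov' = F·cov·Fᵀ + motion_cov. A small numeric sketch with illustrative values:

import numpy as np

ndim, dt = 4, 1.0
F = np.eye(2 * ndim)
for i in range(ndim):
    F[i, ndim + i] = dt      # each position component gains its velocity: x += vx*dt, ...

# State layout for KalmanFilterXYAH: (x, y, a, h, vx, vy, va, vh)
mean = np.array([100.0, 50.0, 2.0, 20.0, 1.0, -2.0, 0.0, 0.0])
print(F @ mean)              # one step: x=101, y=48, a and h unchanged, velocities kept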
- Parameters - ---------- - mean : ndarray - The state's mean vector (8 dimensional array). - covariance : ndarray - The state's covariance matrix (8x8 dimensional). + Args: + mean (ndarray): The state's mean vector (8 dimensional array). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). - Returns - ------- - (ndarray, ndarray) - Returns the projected mean and covariance matrix of the given state estimate. + Returns: + (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. """ std = [ - self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1, - self._std_weight_position * mean[3]] + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-1, + self._std_weight_position * mean[3], + ] innovation_cov = np.diag(np.square(std)) mean = np.dot(self._update_mat, mean) covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) return mean, covariance + innovation_cov - def multi_predict(self, mean, covariance): + def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple: """ Run Kalman filter prediction step (Vectorized version). - Parameters - ---------- - mean : ndarray - The Nx8 dimensional mean matrix of the object states at the previous time step. - covariance : ndarray - The Nx8x8 dimensional covariance matrix of the object states at the previous time step. + Args: + mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step. + covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step. - Returns - ------- - (ndarray, ndarray) - Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are - initialized to 0 mean. + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved + velocities are initialized to 0 mean. """ std_pos = [ - self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3], - 1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]] + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 3], + 1e-2 * np.ones_like(mean[:, 3]), + self._std_weight_position * mean[:, 3], + ] std_vel = [ - self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3], - 1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]] + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 3], + 1e-5 * np.ones_like(mean[:, 3]), + self._std_weight_velocity * mean[:, 3], + ] sqr = np.square(np.r_[std_pos, std_vel]).T motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] @@ -148,60 +148,57 @@ class KalmanFilterXYAH: return mean, covariance - def update(self, mean, covariance, measurement): + def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray) -> tuple: """ Run Kalman filter correction step. - Parameters - ---------- - mean : ndarray - The predicted state's mean vector (8 dimensional). - covariance : ndarray - The state's covariance matrix (8x8 dimensional). - measurement : ndarray - The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect - ratio, and h the height of the bounding box. + Args: + mean (ndarray): The predicted state's mean vector (8 dimensional). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). 
+ measurement (ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center + position, a the aspect ratio, and h the height of the bounding box. - Returns - ------- - (ndarray, ndarray) - Returns the measurement-corrected state distribution. + Returns: + (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution. """ projected_mean, projected_cov = self.project(mean, covariance) chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False) - kalman_gain = scipy.linalg.cho_solve((chol_factor, lower), - np.dot(covariance, self._update_mat.T).T, - check_finite=False).T + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False + ).T innovation = measurement - projected_mean new_mean = mean + np.dot(innovation, kalman_gain.T) new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T)) return new_mean, new_covariance - def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'): + def gating_distance( + self, + mean: np.ndarray, + covariance: np.ndarray, + measurements: np.ndarray, + only_position: bool = False, + metric: str = "maha", + ) -> np.ndarray: """ Compute gating distance between state distribution and measurements. A suitable distance threshold can be - obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of - freedom, otherwise 2. + obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom, + otherwise 2. - Parameters - ---------- - mean : ndarray - Mean vector over the state distribution (8 dimensional). - covariance : ndarray - Covariance of the state distribution (8x8 dimensional). - measurements : ndarray - An Nx4 dimensional matrix of N measurements, each in format (x, y, a, h) where (x, y) is the bounding box - center position, a the aspect ratio, and h the height. - only_position : Optional[bool] - If True, distance computation is done with respect to the bounding box center position only. + Args: + mean (ndarray): Mean vector over the state distribution (8 dimensional). + covariance (ndarray): Covariance of the state distribution (8x8 dimensional). + measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y) + is the bounding box center position, a the aspect ratio, and h the height. + only_position (bool, optional): If True, distance computation is done with respect to the bounding box + center position only. Defaults to False. + metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the + squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'. - Returns - ------- - ndarray - Returns an array of length N, where the i-th element contains the squared Mahalanobis distance between - (mean, covariance) and `measurements[i]`. + Returns: + (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between + (mean, covariance) and `measurements[i]`. 
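Supplementary note (not part of the diff): the gating_distance docstring above refers to chi2inv95 as the source of thresholds. Those values are simply 95% chi-square quantiles for 4 or 2 degrees of freedom; a hedged sketch of how such a gate is typically applied:

import scipy.stats

CHI2INV95_4DOF = scipy.stats.chi2.ppf(0.95, df=4)   # ~9.49, full (x, y, a, h) gate
CHI2INV95_2DOF = scipy.stats.chi2.ppf(0.95, df=2)   # ~5.99, gate when only_position=True

# d = kf.gating_distance(mean, covariance, measurements)   # squared Mahalanobis distances
# feasible = d <= CHI2INV95_4DOF                            # measurements outside the gate are rejected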
""" mean, covariance = self.project(mean, covariance) if only_position: @@ -209,77 +206,79 @@ class KalmanFilterXYAH: measurements = measurements[:, :2] d = measurements - mean - if metric == 'gaussian': + if metric == "gaussian": return np.sum(d * d, axis=1) - elif metric == 'maha': + elif metric == "maha": cholesky_factor = np.linalg.cholesky(covariance) z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True) return np.sum(z * z, axis=0) # square maha else: - raise ValueError('invalid distance metric') + raise ValueError("Invalid distance metric") class KalmanFilterXYWH(KalmanFilterXYAH): """ For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space. - The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), - width w, height h, and their respective velocities. + The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width + w, height h, and their respective velocities. Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct observation of the state space (linear observation model). """ - def initiate(self, measurement): + def initiate(self, measurement: np.ndarray) -> tuple: """ Create track from unassociated measurement. - Parameters - ---------- - measurement : ndarray - Bounding box coordinates (x, y, w, h) with center position (x, y), width w, and height h. + Args: + measurement (ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height. - Returns - ------- - (ndarray, ndarray) - Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track. - Unobserved velocities are initialized to 0 mean. + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of + the new track. Unobserved velocities are initialized to 0 mean. """ mean_pos = measurement mean_vel = np.zeros_like(mean_pos) mean = np.r_[mean_pos, mean_vel] std = [ - 2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3], - 2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3], - 10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3], - 10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]] + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + ] covariance = np.diag(np.square(std)) return mean, covariance - def predict(self, mean, covariance): + def predict(self, mean, covariance) -> tuple: """ Run Kalman filter prediction step. - Parameters - ---------- - mean : ndarray - The 8 dimensional mean vector of the object state at the previous time step. - covariance : ndarray - The 8x8 dimensional covariance matrix of the object state at the previous time step. + Args: + mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step. + covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step. 
- Returns - ------- - (ndarray, ndarray) - Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are - initialized to 0 mean. + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved + velocities are initialized to 0 mean. """ std_pos = [ - self._std_weight_position * mean[2], self._std_weight_position * mean[3], - self._std_weight_position * mean[2], self._std_weight_position * mean[3]] + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + ] std_vel = [ - self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3], - self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]] + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + ] motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) mean = np.dot(mean, self._motion_mat.T) @@ -287,54 +286,53 @@ class KalmanFilterXYWH(KalmanFilterXYAH): return mean, covariance - def project(self, mean, covariance): + def project(self, mean, covariance) -> tuple: """ Project state distribution to measurement space. - Parameters - ---------- - mean : ndarray - The state's mean vector (8 dimensional array). - covariance : ndarray - The state's covariance matrix (8x8 dimensional). + Args: + mean (ndarray): The state's mean vector (8 dimensional array). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). - Returns - ------- - (ndarray, ndarray) - Returns the projected mean and covariance matrix of the given state estimate. + Returns: + (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. """ std = [ - self._std_weight_position * mean[2], self._std_weight_position * mean[3], - self._std_weight_position * mean[2], self._std_weight_position * mean[3]] + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + ] innovation_cov = np.diag(np.square(std)) mean = np.dot(self._update_mat, mean) covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) return mean, covariance + innovation_cov - def multi_predict(self, mean, covariance): + def multi_predict(self, mean, covariance) -> tuple: """ Run Kalman filter prediction step (Vectorized version). - Parameters - ---------- - mean : ndarray - The Nx8 dimensional mean matrix of the object states at the previous time step. - covariance : ndarray - The Nx8x8 dimensional covariance matrix of the object states at the previous time step. + Args: + mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step. + covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step. - Returns - ------- - (ndarray, ndarray) - Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are - initialized to 0 mean. + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved + velocities are initialized to 0 mean. 
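For orientation (not part of the diff): KalmanFilterXYWH only changes the box parametrisation relative to KalmanFilterXYAH; the API is identical. A small sketch with placeholder measurements:

import numpy as np
from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH, KalmanFilterXYWH

kf_ah = KalmanFilterXYAH()                                         # ByteTrack: (x, y, aspect, h)
mean, cov = kf_ah.initiate(np.array([100.0, 50.0, 2.0, 20.0]))
print(mean.shape, cov.shape)                                       # (8,) (8, 8), velocities start at 0

kf_wh = KalmanFilterXYWH()                                         # BoT-SORT: (x, y, w, h)
mean, cov = kf_wh.initiate(np.array([100.0, 50.0, 40.0, 20.0]))
mean, cov = kf_wh.predict(mean, cov)                               # one constant-velocity step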
""" std_pos = [ - self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3], - self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]] + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + ] std_vel = [ - self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3], - self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]] + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + ] sqr = np.square(np.r_[std_pos, std_vel]).T motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] @@ -346,23 +344,17 @@ class KalmanFilterXYWH(KalmanFilterXYAH): return mean, covariance - def update(self, mean, covariance, measurement): + def update(self, mean, covariance, measurement) -> tuple: """ Run Kalman filter correction step. - Parameters - ---------- - mean : ndarray - The predicted state's mean vector (8 dimensional). - covariance : ndarray - The state's covariance matrix (8x8 dimensional). - measurement : ndarray - The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center position, w the width, - and h the height of the bounding box. + Args: + mean (ndarray): The predicted state's mean vector (8 dimensional). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). + measurement (ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center + position, w the width, and h the height of the bounding box. - Returns - ------- - (ndarray, ndarray) - Returns the measurement-corrected state distribution. + Returns: + (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution. """ return super().update(mean, covariance, measurement) diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py index f2ee75e..fa72b8b 100644 --- a/ultralytics/trackers/utils/matching.py +++ b/ultralytics/trackers/utils/matching.py @@ -4,7 +4,7 @@ import numpy as np import scipy from scipy.spatial.distance import cdist -from ultralytics.utils.metrics import bbox_ioa +from ultralytics.utils.metrics import bbox_ioa, batch_probiou try: import lap # for linear_assignment @@ -13,11 +13,11 @@ try: except (ImportError, AssertionError, AttributeError): from ultralytics.utils.checks import check_requirements - check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx + check_requirements("lapx>=0.5.2") # update to lap package from https://github.com/rathaROG/lapx import lap -def linear_assignment(cost_matrix, thresh, use_lap=True): +def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple: """ Perform linear assignment using scipy or lap.lapjv. @@ -27,19 +27,24 @@ def linear_assignment(cost_matrix, thresh, use_lap=True): use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True. Returns: - (tuple): Tuple containing matched indices, unmatched indices from 'a', and unmatched indices from 'b'. 
+ Tuple with: + - matched indices + - unmatched indices from 'a' + - unmatched indices from 'b' """ if cost_matrix.size == 0: return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) if use_lap: + # Use lap.lapjv # https://github.com/gatagat/lap _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0] unmatched_a = np.where(x < 0)[0] unmatched_b = np.where(y < 0)[0] else: + # Use scipy.optimize.linear_sum_assignment # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh]) @@ -53,7 +58,7 @@ def linear_assignment(cost_matrix, thresh, use_lap=True): return matches, unmatched_a, unmatched_b -def iou_distance(atracks, btracks): +def iou_distance(atracks: list, btracks: list) -> np.ndarray: """ Compute cost based on Intersection over Union (IoU) between tracks. @@ -65,23 +70,30 @@ def iou_distance(atracks, btracks): (np.ndarray): Cost matrix computed based on IoU. """ - if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \ - or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): + if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray): atlbrs = atracks btlbrs = btracks else: - atlbrs = [track.tlbr for track in atracks] - btlbrs = [track.tlbr for track in btracks] + atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks] + btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks] ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) if len(atlbrs) and len(btlbrs): - ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32), - np.ascontiguousarray(btlbrs, dtype=np.float32), - iou=True) + if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5: + ious = batch_probiou( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), + ).numpy() + else: + ious = bbox_ioa( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), + iou=True, + ) return 1 - ious # cost matrix -def embedding_distance(tracks, detections, metric='cosine'): +def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray: """ Compute distance between tracks and detections based on embeddings. @@ -105,7 +117,7 @@ def embedding_distance(tracks, detections, metric='cosine'): return cost_matrix -def fuse_score(cost_matrix, detections): +def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray: """ Fuses cost matrix with detection scores to produce a single similarity matrix. 
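Worked example (not part of the diff) for the matching utilities documented above: linear_assignment() solves the assignment over a cost matrix and discards pairs whose cost exceeds thresh; use_lap=False falls back to scipy so the snippet needs no extra dependency.

import numpy as np
from ultralytics.trackers.utils.matching import linear_assignment

cost = np.array([
    [0.1, 0.9, 0.8],    # track 0 is cheapest against detection 0
    [0.7, 0.2, 0.9],    # track 1 is cheapest against detection 1
])
matches, u_tracks, u_dets = linear_assignment(cost, thresh=0.5, use_lap=False)
print(matches)          # pairs (0, 0) and (1, 1): both costs are below the 0.5 threshold
print(u_tracks, u_dets) # no unmatched tracks; detection 2 stays unmatched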
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index 872ce5f..93347f5 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -9,6 +9,7 @@ import re import subprocess import sys import threading +import time import urllib import uuid from pathlib import Path @@ -25,23 +26,22 @@ from tqdm import tqdm as tqdm_original from ultralytics import __version__ # PyTorch Multi-GPU DDP Constants -RANK = int(os.getenv('RANK', -1)) -LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html +RANK = int(os.getenv("RANK", -1)) +LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html # Other Constants FILE = Path(__file__).resolve() ROOT = FILE.parents[1] # YOLO -ASSETS = ROOT / 'assets' # default images -DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml' +ASSETS = ROOT / "assets" # default images +DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml" NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads -AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode -VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode -TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None # tqdm bar format -LOGGING_NAME = 'ultralytics' -MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans -ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans -HELP_MSG = \ - """ +AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode +VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode +TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format +LOGGING_NAME = "ultralytics" +MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans +ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans +HELP_MSG = """ Usage examples for running YOLOv8: 1. 
Install the ultralytics package: @@ -77,7 +77,7 @@ HELP_MSG = \ yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01 - Predict a YouTube video using a pretrained segmentation model at image size 320: - yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320 + yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 - Val a pretrained detection model at batch-size 1 and image size 640: yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640 @@ -99,12 +99,12 @@ HELP_MSG = \ """ # Settings -torch.set_printoptions(linewidth=320, precision=4, profile='default') -np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 +torch.set_printoptions(linewidth=320, precision=4, profile="default") +np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5 cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) -os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads -os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab +os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # for deterministic training +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress verbose TF compiler warnings in Colab class TQDM(tqdm_original): @@ -113,19 +113,22 @@ class TQDM(tqdm_original): Args: *args (list): Positional arguments passed to original tqdm. - **kwargs (dict): Keyword arguments, with custom defaults applied. + **kwargs (any): Keyword arguments, with custom defaults applied. """ def __init__(self, *args, **kwargs): - # Set new default values (these can still be overridden when calling TQDM) - kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed - kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed + """ + Initialize custom Ultralytics tqdm class with different default arguments. + + Note these can still be overridden when calling TQDM. + """ + kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) # logical 'and' with default value if passed + kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed super().__init__(*args, **kwargs) class SimpleClass: - """ - Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute + """Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute access methods for easier debugging and usage. 
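Reviewer note — how the customized TQDM defaults introduced above behave (a sketch; the loop body is a placeholder):

from ultralytics.utils import TQDM

# 'disable' and 'bar_format' now default from YOLO_VERBOSE / TQDM_BAR_FORMAT,
# but either can still be overridden per call
for _ in TQDM(range(3), desc="demo", disable=False):
    pass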
""" @@ -134,14 +137,14 @@ class SimpleClass: attr = [] for a in dir(self): v = getattr(self, a) - if not callable(v) and not a.startswith('_'): + if not callable(v) and not a.startswith("_"): if isinstance(v, SimpleClass): # Display only the module and class name for subclasses - s = f'{a}: {v.__module__}.{v.__class__.__name__} object' + s = f"{a}: {v.__module__}.{v.__class__.__name__} object" else: - s = f'{a}: {repr(v)}' + s = f"{a}: {repr(v)}" attr.append(s) - return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr) + return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr) def __repr__(self): """Return a machine-readable string representation of the object.""" @@ -154,8 +157,7 @@ class SimpleClass: class IterableSimpleNamespace(SimpleNamespace): - """ - Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and + """Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and enables usage with dict() and for loops. """ @@ -165,24 +167,26 @@ class IterableSimpleNamespace(SimpleNamespace): def __str__(self): """Return a human-readable string representation of the object.""" - return '\n'.join(f'{k}={v}' for k, v in vars(self).items()) + return "\n".join(f"{k}={v}" for k, v in vars(self).items()) def __getattr__(self, attr): """Custom attribute access error message with helpful information.""" name = self.__class__.__name__ - raise AttributeError(f""" + raise AttributeError( + f""" '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics 'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace {DEFAULT_CFG_PATH} with the latest version from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml - """) + """ + ) def get(self, key, default=None): """Return the value of the specified key if it exists; otherwise, return the default value.""" return getattr(self, key, default) -def plt_settings(rcparams=None, backend='Agg'): +def plt_settings(rcparams=None, backend="Agg"): """ Decorator to temporarily set rc parameters and the backend for a plotting function. 
@@ -200,7 +204,7 @@ def plt_settings(rcparams=None, backend='Agg'): """ if rcparams is None: - rcparams = {'font.size': 11} + rcparams = {"font.size": 11} def decorator(func): """Decorator to apply temporary rc parameters and backend to a function.""" @@ -208,12 +212,16 @@ def plt_settings(rcparams=None, backend='Agg'): def wrapper(*args, **kwargs): """Sets rc parameters and backend, calls the original function, and restores the settings.""" original_backend = plt.get_backend() - plt.switch_backend(backend) + if backend.lower() != original_backend.lower(): + plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8 + plt.switch_backend(backend) with plt.rc_context(rcparams): result = func(*args, **kwargs) - plt.switch_backend(original_backend) + if backend != original_backend: + plt.close("all") + plt.switch_backend(original_backend) return result return wrapper @@ -222,58 +230,59 @@ def plt_settings(rcparams=None, backend='Agg'): def set_logging(name=LOGGING_NAME, verbose=True): - """Sets up logging for the given name.""" - rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings - level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR - logging.config.dictConfig({ - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - name: { - 'format': '%(message)s'}}, - 'handlers': { - name: { - 'class': 'logging.StreamHandler', - 'formatter': name, - 'level': level}}, - 'loggers': { - name: { - 'level': level, - 'handlers': [name], - 'propagate': False}}}) + """Sets up logging for the given name with UTF-8 encoding support.""" + level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings + # Configure the console (stdout) encoding to UTF-8 + formatter = logging.Formatter("%(message)s") # Default formatter + if WINDOWS and sys.stdout.encoding != "utf-8": + try: + if hasattr(sys.stdout, "reconfigure"): + sys.stdout.reconfigure(encoding="utf-8") + elif hasattr(sys.stdout, "buffer"): + import io -def emojis(string=''): - """Return platform-dependent emoji-safe version of string.""" - return string.encode().decode('ascii', 'ignore') if WINDOWS else string + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + else: + sys.stdout.encoding = "utf-8" + except Exception as e: + print(f"Creating custom formatter for non UTF-8 environments due to {e}") + class CustomFormatter(logging.Formatter): + def format(self, record): + """Sets up logging with UTF-8 encoding and configurable verbosity.""" + return emojis(super().format(record)) -class EmojiFilter(logging.Filter): - """ - A custom logging filter class for removing emojis in log messages. + formatter = CustomFormatter("%(message)s") # Use CustomFormatter to eliminate UTF-8 output as last recourse - This filter is particularly useful for ensuring compatibility with Windows terminals - that may not support the display of emojis in log messages. 
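Reviewer note — a sketch of the plt_settings decorator with the new backend-switching guard above (the file name and rcparams are illustrative):

import matplotlib.pyplot as plt
from ultralytics.utils import plt_settings

@plt_settings(rcparams={"font.size": 8}, backend="Agg")
def save_line_plot():
    # Runs under the temporary rcParams/backend; the originals are restored afterwards
    plt.plot([0, 1], [0, 1])
    plt.savefig("line.png")
    plt.close()

save_line_plot()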
- """ + # Create and configure the StreamHandler + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setFormatter(formatter) + stream_handler.setLevel(level) - def filter(self, record): - """Filter logs by emoji unicode characters on windows.""" - record.msg = emojis(record.msg) - return super().filter(record) + logger = logging.getLogger(name) + logger.setLevel(level) + logger.addHandler(stream_handler) + logger.propagate = False + return logger # Set logger -set_logging(LOGGING_NAME, verbose=VERBOSE) # run before defining LOGGER -LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.) -if WINDOWS: # emoji-safe logging - LOGGER.addFilter(EmojiFilter()) +LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.) +for logger in "sentry_sdk", "urllib3.connectionpool": + logging.getLogger(logger).setLevel(logging.CRITICAL + 1) + + +def emojis(string=""): + """Return platform-dependent emoji-safe version of string.""" + return string.encode().decode("ascii", "ignore") if WINDOWS else string class ThreadingLocked: """ - A decorator class for ensuring thread-safe execution of a function or method. - This class can be used as a decorator to make sure that if the decorated function - is called from multiple threads, only one thread at a time will be able to execute the function. + A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator + to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able + to execute the function. Attributes: lock (threading.Lock): A lock object used to manage access to the decorated function. @@ -290,20 +299,23 @@ class ThreadingLocked: """ def __init__(self): + """Initializes the decorator class for thread-safe execution of a function or method.""" self.lock = threading.Lock() def __call__(self, f): + """Run thread-safe execution of function or method.""" from functools import wraps @wraps(f) def decorated(*args, **kwargs): + """Applies thread-safety to the decorated function or method.""" with self.lock: return f(*args, **kwargs) return decorated -def yaml_save(file='data.yaml', data=None, header=''): +def yaml_save(file="data.yaml", data=None, header=""): """ Save YAML data to a file. @@ -323,18 +335,19 @@ def yaml_save(file='data.yaml', data=None, header=''): file.parent.mkdir(parents=True, exist_ok=True) # Convert Path objects to strings + valid_types = int, float, str, bool, list, tuple, dict, type(None) for k, v in data.items(): - if isinstance(v, Path): + if not isinstance(v, valid_types): data[k] = str(v) # Dump data to file in YAML format - with open(file, 'w', errors='ignore', encoding='utf-8') as f: + with open(file, "w", errors="ignore", encoding="utf-8") as f: if header: f.write(header) yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True) -def yaml_load(file='data.yaml', append_filename=False): +def yaml_load(file="data.yaml", append_filename=False): """ Load YAML data from a file. @@ -345,18 +358,18 @@ def yaml_load(file='data.yaml', append_filename=False): Returns: (dict): YAML data and file name. 
""" - assert Path(file).suffix in ('.yaml', '.yml'), f'Attempting to load non-YAML file {file} with yaml_load()' - with open(file, errors='ignore', encoding='utf-8') as f: + assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()" + with open(file, errors="ignore", encoding="utf-8") as f: s = f.read() # string # Remove special characters if not s.isprintable(): - s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s) + s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s) # Add YAML filename to dict and return data = yaml.safe_load(s) or {} # always return a dict (yaml.safe_load() may return None for empty files) if append_filename: - data['yaml_file'] = str(file) + data["yaml_file"] = str(file) return data @@ -368,7 +381,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None: yaml_file: The file path of the YAML file or a YAML-formatted dictionary. Returns: - None + (None) """ yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True) @@ -378,7 +391,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None: # Default configuration DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH) for k, v in DEFAULT_CFG_DICT.items(): - if isinstance(v, str) and v.lower() == 'none': + if isinstance(v, str) and v.lower() == "none": DEFAULT_CFG_DICT[k] = None DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys() DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT) @@ -392,8 +405,8 @@ def is_ubuntu() -> bool: (bool): True if OS is Ubuntu, False otherwise. """ with contextlib.suppress(FileNotFoundError): - with open('/etc/os-release') as f: - return 'ID=ubuntu' in f.read() + with open("/etc/os-release") as f: + return "ID=ubuntu" in f.read() return False @@ -404,7 +417,7 @@ def is_colab(): Returns: (bool): True if running inside a Colab notebook, False otherwise. """ - return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ + return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ def is_kaggle(): @@ -414,19 +427,19 @@ def is_kaggle(): Returns: (bool): True if running inside a Kaggle kernel, False otherwise. """ - return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com' + return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com" def is_jupyter(): """ - Check if the current script is running inside a Jupyter Notebook. - Verified on Colab, Jupyterlab, Kaggle, Paperspace. + Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace. Returns: (bool): True if running inside a Jupyter Notebook, False otherwise. """ with contextlib.suppress(Exception): from IPython import get_ipython + return get_ipython() is not None return False @@ -438,10 +451,10 @@ def is_docker() -> bool: Returns: (bool): True if the script is running inside a Docker container, False otherwise. 
""" - file = Path('/proc/self/cgroup') + file = Path("/proc/self/cgroup") if file.exists(): with open(file) as f: - return 'docker' in f.read() + return "docker" in f.read() else: return False @@ -455,7 +468,7 @@ def is_online() -> bool: """ import socket - for host in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS: + for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS: try: test_connection = socket.create_connection(address=(host, 53), timeout=2) except (socket.timeout, socket.gaierror, OSError): @@ -509,23 +522,23 @@ def is_pytest_running(): Returns: (bool): True if pytest is running, False otherwise. """ - return ('PYTEST_CURRENT_TEST' in os.environ) or ('pytest' in sys.modules) or ('pytest' in Path(sys.argv[0]).stem) + return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem) -def is_github_actions_ci() -> bool: +def is_github_action_running() -> bool: """ - Determine if the current environment is a GitHub Actions CI Python runner. + Determine if the current environment is a GitHub Actions runner. Returns: - (bool): True if the current environment is a GitHub Actions CI Python runner, False otherwise. + (bool): True if the current environment is a GitHub Actions runner, False otherwise. """ - return 'GITHUB_ACTIONS' in os.environ and 'RUNNER_OS' in os.environ and 'RUNNER_TOOL_CACHE' in os.environ + return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ def is_git_dir(): """ - Determines whether the current file is part of a git repository. - If the current file is not part of a git repository, returns None. + Determines whether the current file is part of a git repository. If the current file is not part of a git + repository, returns None. Returns: (bool): True if current file is part of a git repository. @@ -535,14 +548,14 @@ def is_git_dir(): def get_git_dir(): """ - Determines whether the current file is part of a git repository and if so, returns the repository root directory. - If the current file is not part of a git repository, returns None. + Determines whether the current file is part of a git repository and if so, returns the repository root directory. If + the current file is not part of a git repository, returns None. Returns: (Path | None): Git root directory if found or None if not found. """ for d in Path(__file__).parents: - if (d / '.git').is_dir(): + if (d / ".git").is_dir(): return d @@ -555,7 +568,7 @@ def get_git_origin_url(): """ if is_git_dir(): with contextlib.suppress(subprocess.CalledProcessError): - origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url']) + origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"]) return origin.decode().strip() @@ -568,12 +581,13 @@ def get_git_branch(): """ if is_git_dir(): with contextlib.suppress(subprocess.CalledProcessError): - origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) return origin.decode().strip() def get_default_args(func): - """Returns a dictionary of default arguments for a function. + """ + Returns a dictionary of default arguments for a function. Args: func (callable): The function to inspect. 
@@ -594,13 +608,13 @@ def get_ubuntu_version(): """ if is_ubuntu(): with contextlib.suppress(FileNotFoundError, AttributeError): - with open('/etc/os-release') as f: + with open("/etc/os-release") as f: return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1] -def get_user_config_dir(sub_dir='Ultralytics'): +def get_user_config_dir(sub_dir="yolov10"): """ - Get the user config directory. + Return the appropriate config directory based on the environment operating system. Args: sub_dir (str): The name of the subdirectory to create. @@ -608,21 +622,22 @@ def get_user_config_dir(sub_dir='Ultralytics'): Returns: (Path): The path to the user config directory. """ - # Return the appropriate config directory for each operating system if WINDOWS: - path = Path.home() / 'AppData' / 'Roaming' / sub_dir + path = Path.home() / "AppData" / "Roaming" / sub_dir elif MACOS: # macOS - path = Path.home() / 'Library' / 'Application Support' / sub_dir + path = Path.home() / "Library" / "Application Support" / sub_dir elif LINUX: - path = Path.home() / '.config' / sub_dir + path = Path.home() / ".config" / sub_dir else: - raise ValueError(f'Unsupported operating system: {platform.system()}') + raise ValueError(f"Unsupported operating system: {platform.system()}") # GCP and AWS lambda fix, only /tmp is writeable if not is_dir_writeable(path.parent): - LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD." - 'Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.') - path = Path('/tmp') / sub_dir if is_dir_writeable('/tmp') else Path().cwd() / sub_dir + LOGGER.warning( + f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD." + "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path." + ) + path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir # Create the subdirectory if it does not exist path.mkdir(parents=True, exist_ok=True) @@ -630,40 +645,99 @@ def get_user_config_dir(sub_dir='Ultralytics'): return path -USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR') or get_user_config_dir()) # Ultralytics settings dir -SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml' +USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir +SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml" def colorstr(*input): - """Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world').""" - *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string + """ + Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes. + See https://en.wikipedia.org/wiki/ANSI_escape_code for more details. + + This function can be called in two ways: + - colorstr('color', 'style', 'your string') + - colorstr('your string') + + In the second form, 'blue' and 'bold' will be applied by default. + + Args: + *input (str): A sequence of strings where the first n-1 strings are color and style arguments, + and the last string is the one to be colored. 
+ + Supported Colors and Styles: + Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white' + Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow', + 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white' + Misc: 'end', 'bold', 'underline' + + Returns: + (str): The input string wrapped with ANSI escape codes for the specified color and style. + + Examples: + >>> colorstr('blue', 'bold', 'hello world') + >>> '\033[34m\033[1mhello world\033[0m' + """ + *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string colors = { - 'black': '\033[30m', # basic colors - 'red': '\033[31m', - 'green': '\033[32m', - 'yellow': '\033[33m', - 'blue': '\033[34m', - 'magenta': '\033[35m', - 'cyan': '\033[36m', - 'white': '\033[37m', - 'bright_black': '\033[90m', # bright colors - 'bright_red': '\033[91m', - 'bright_green': '\033[92m', - 'bright_yellow': '\033[93m', - 'bright_blue': '\033[94m', - 'bright_magenta': '\033[95m', - 'bright_cyan': '\033[96m', - 'bright_white': '\033[97m', - 'end': '\033[0m', # misc - 'bold': '\033[1m', - 'underline': '\033[4m'} - return ''.join(colors[x] for x in args) + f'{string}' + colors['end'] + "black": "\033[30m", # basic colors + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", + "bright_black": "\033[90m", # bright colors + "bright_red": "\033[91m", + "bright_green": "\033[92m", + "bright_yellow": "\033[93m", + "bright_blue": "\033[94m", + "bright_magenta": "\033[95m", + "bright_cyan": "\033[96m", + "bright_white": "\033[97m", + "end": "\033[0m", # misc + "bold": "\033[1m", + "underline": "\033[4m", + } + return "".join(colors[x] for x in args) + f"{string}" + colors["end"] + + +def remove_colorstr(input_string): + """ + Removes ANSI escape codes from a string, effectively un-coloring it. + + Args: + input_string (str): The string to remove color and style from. + + Returns: + (str): A new string with all ANSI escape codes removed. + + Examples: + >>> remove_colorstr(colorstr('blue', 'bold', 'hello world')) + >>> 'hello world' + """ + ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]") + return ansi_escape.sub("", input_string) class TryExcept(contextlib.ContextDecorator): - """YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager.""" + """ + Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager. - def __init__(self, msg='', verbose=True): + Examples: + As a decorator: + >>> @TryExcept(msg="Error occurred in func", verbose=True) + >>> def func(): + >>> # Function logic here + >>> pass + + As a context manager: + >>> with TryExcept(msg="Error occurred in block", verbose=True): + >>> # Code block here + >>> pass + """ + + def __init__(self, msg="", verbose=True): """Initialize TryExcept class with optional message and verbosity settings.""" self.msg = msg self.verbose = verbose @@ -679,14 +753,80 @@ class TryExcept(contextlib.ContextDecorator): return True +class Retry(contextlib.ContextDecorator): + """ + Retry class for function execution with exponential backoff. + + Can be used as a decorator or a context manager to retry a function or block of code on exceptions, up to a + specified number of times with an exponentially increasing delay between retries. 
+ + Examples: + Example usage as a decorator: + >>> @Retry(times=3, delay=2) + >>> def test_func(): + >>> # Replace with function logic that may raise exceptions + >>> return True + + Example usage as a context manager: + >>> with Retry(times=3, delay=2): + >>> # Replace with code block that may raise exceptions + >>> pass + """ + + def __init__(self, times=3, delay=2): + """Initialize Retry class with specified number of retries and delay.""" + self.times = times + self.delay = delay + self._attempts = 0 + + def __call__(self, func): + """Decorator implementation for Retry with exponential backoff.""" + + def wrapped_func(*args, **kwargs): + """Applies retries to the decorated function or method.""" + self._attempts = 0 + while self._attempts < self.times: + try: + return func(*args, **kwargs) + except Exception as e: + self._attempts += 1 + print(f"Retry {self._attempts}/{self.times} failed: {e}") + if self._attempts >= self.times: + raise e + time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay + + return wrapped_func + + def __enter__(self): + """Enter the runtime context related to this object.""" + self._attempts = 0 + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the runtime context related to this object with exponential backoff.""" + if exc_type is not None: + self._attempts += 1 + if self._attempts < self.times: + print(f"Retry {self._attempts}/{self.times} failed: {exc_value}") + time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay + return True # Suppresses the exception and retries + return False # Re-raises the exception if retries are exhausted + + def threaded(func): - """Multi-threads a target function and returns thread. Usage: @threaded decorator.""" + """ + Multi-threads a target function by default and returns the thread or function result. + + Use as @threaded decorator. The function runs in a separate thread unless 'threaded=False' is passed. + """ def wrapper(*args, **kwargs): - """Multi-threads a given function and returns the thread.""" - thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) - thread.start() - return thread + """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result.""" + if kwargs.pop("threaded", True): # run in thread + thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) + thread.start() + return thread + else: + return func(*args, **kwargs) return wrapper @@ -723,27 +863,28 @@ def set_sentry(): Returns: dict: The modified event or None if the event should not be sent to Sentry. 
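Reviewer note — the new opt-out behaviour of the threaded decorator above (a sketch; the workload is a placeholder):

from ultralytics.utils import threaded

@threaded
def heavy_sum(n):
    return sum(range(n))

t = heavy_sum(10_000_000)                      # default: runs in a daemon thread, returns the Thread
value = heavy_sum(10_000_000, threaded=False)  # opt out: runs inline and returns the result
t.join()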
""" - if 'exc_info' in hint: - exc_type, exc_value, tb = hint['exc_info'] - if exc_type in (KeyboardInterrupt, FileNotFoundError) \ - or 'out of memory' in str(exc_value): + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value): return None # do not send event - event['tags'] = { - 'sys_argv': sys.argv[0], - 'sys_argv_name': Path(sys.argv[0]).name, - 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other', - 'os': ENVIRONMENT} + event["tags"] = { + "sys_argv": sys.argv[0], + "sys_argv_name": Path(sys.argv[0]).name, + "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other", + "os": ENVIRONMENT, + } return event - if SETTINGS['sync'] and \ - RANK in (-1, 0) and \ - Path(sys.argv[0]).name == 'yolo' and \ - not TESTS_RUNNING and \ - ONLINE and \ - is_pip_package() and \ - not is_git_dir(): - + if ( + SETTINGS["sync"] + and RANK in (-1, 0) + and Path(sys.argv[0]).name == "yolo" + and not TESTS_RUNNING + and ONLINE + and is_pip_package() + and not is_git_dir() + ): # If sentry_sdk package is not installed then return and do not use Sentry try: import sentry_sdk # noqa @@ -751,18 +892,15 @@ def set_sentry(): return sentry_sdk.init( - dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016', + dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016", debug=False, traces_sample_rate=1.0, release=__version__, - environment='production', # 'dev' or 'production' + environment="production", # 'dev' or 'production' before_send=before_send, - ignore_errors=[KeyboardInterrupt, FileNotFoundError]) - sentry_sdk.set_user({'id': SETTINGS['uuid']}) # SHA-256 anonymized UUID hash - - # Disable all sentry logging - for logger in 'sentry_sdk', 'sentry_sdk.errors': - logging.getLogger(logger).setLevel(logging.CRITICAL) + ignore_errors=[KeyboardInterrupt, FileNotFoundError], + ) + sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash class SettingsManager(dict): @@ -774,7 +912,10 @@ class SettingsManager(dict): version (str): Settings version. In case of local version mismatch, new default settings will be saved. """ - def __init__(self, file=SETTINGS_YAML, version='0.0.4'): + def __init__(self, file=SETTINGS_YAML, version="0.0.4"): + """Initialize the SettingsManager with default settings, load and validate current settings from the YAML + file. 
+ """ import copy import hashlib @@ -788,22 +929,24 @@ class SettingsManager(dict): self.file = Path(file) self.version = version self.defaults = { - 'settings_version': version, - 'datasets_dir': str(datasets_root / 'datasets'), - 'weights_dir': str(root / 'weights'), - 'runs_dir': str(root / 'runs'), - 'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), - 'sync': True, - 'api_key': '', - 'clearml': True, # integrations - 'comet': True, - 'dvc': True, - 'hub': True, - 'mlflow': True, - 'neptune': True, - 'raytune': True, - 'tensorboard': True, - 'wandb': True} + "settings_version": version, + "datasets_dir": str(datasets_root / "datasets"), + "weights_dir": str(root / "weights"), + "runs_dir": str(root / "runs"), + "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), + "sync": True, + "api_key": "", + "openai_api_key": "", + "clearml": True, # integrations + "comet": True, + "dvc": True, + "hub": True, + "mlflow": True, + "neptune": True, + "raytune": True, + "tensorboard": True, + "wandb": True, + } super().__init__(copy.deepcopy(self.defaults)) @@ -814,15 +957,26 @@ class SettingsManager(dict): self.load() correct_keys = self.keys() == self.defaults.keys() correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values())) - correct_version = check_version(self['settings_version'], self.version) + correct_version = check_version(self["settings_version"], self.version) + help_msg = ( + f"\nView settings with 'yolo settings' or at '{self.file}'" + "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. " + "For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings." + ) if not (correct_keys and correct_types and correct_version): LOGGER.warning( - 'WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem ' - 'with your settings or a recent ultralytics package update. ' - f"\nView settings with 'yolo settings' or at '{self.file}'" - "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'.") + "WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem " + f"with your settings or a recent ultralytics package update. {help_msg}" + ) self.reset() + if self.get("datasets_dir") == self.get("runs_dir"): + LOGGER.warning( + f"WARNING ⚠️ Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' " + f"must be different than 'runs_dir: {self.get('runs_dir')}'. " + f"Please change one to avoid possible issues during training. {help_msg}" + ) + def load(self): """Loads settings from the YAML file.""" super().update(yaml_load(self.file)) @@ -847,14 +1001,16 @@ def deprecation_warn(arg, new_arg, version=None): """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.""" if not version: version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release - LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. " - f"Please use '{new_arg}' instead.") + LOGGER.warning( + f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. " + f"Please use '{new_arg}' instead." + ) def clean_url(url): """Strip auth from URL, i.e. 
https://url.com/file.txt?auth -> https://url.com/file.txt.""" - url = Path(url).as_posix().replace(':/', '://') # Pathlib turns :// -> :/, as_posix() for Windows - return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth + url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows + return urllib.parse.unquote(url).split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth def url2file(url): @@ -865,12 +1021,23 @@ def url2file(url): # Run below code on utils init ------------------------------------------------------------------------------------ # Check first-install steps -PREFIX = colorstr('Ultralytics: ') +PREFIX = colorstr("Ultralytics: ") SETTINGS = SettingsManager() # initialize settings -DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory -ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \ - 'Docker' if is_docker() else platform.system() -TESTS_RUNNING = is_pytest_running() or is_github_actions_ci() +DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory +WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory +RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory +ENVIRONMENT = ( + "Colab" + if is_colab() + else "Kaggle" + if is_kaggle() + else "Jupyter" + if is_jupyter() + else "Docker" + if is_docker() + else platform.system() +) +TESTS_RUNNING = is_pytest_running() or is_github_action_running() set_sentry() # Apply monkey patches diff --git a/ultralytics/utils/__pycache__/__init__.cpython-312.pyc b/ultralytics/utils/__pycache__/__init__.cpython-312.pyc index a595202..0951c2a 100644 Binary files a/ultralytics/utils/__pycache__/__init__.cpython-312.pyc and b/ultralytics/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/__init__.cpython-39.pyc b/ultralytics/utils/__pycache__/__init__.cpython-39.pyc index c6c39b2..30dadc2 100644 Binary files a/ultralytics/utils/__pycache__/__init__.cpython-39.pyc and b/ultralytics/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/autobatch.cpython-312.pyc b/ultralytics/utils/__pycache__/autobatch.cpython-312.pyc index 239c707..e198f39 100644 Binary files a/ultralytics/utils/__pycache__/autobatch.cpython-312.pyc and b/ultralytics/utils/__pycache__/autobatch.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/autobatch.cpython-39.pyc b/ultralytics/utils/__pycache__/autobatch.cpython-39.pyc index 7080183..0915af0 100644 Binary files a/ultralytics/utils/__pycache__/autobatch.cpython-39.pyc and b/ultralytics/utils/__pycache__/autobatch.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/checks.cpython-312.pyc b/ultralytics/utils/__pycache__/checks.cpython-312.pyc index 32e65c2..aed4c3f 100644 Binary files a/ultralytics/utils/__pycache__/checks.cpython-312.pyc and b/ultralytics/utils/__pycache__/checks.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/checks.cpython-39.pyc b/ultralytics/utils/__pycache__/checks.cpython-39.pyc index a29aa8b..771b0dd 100644 Binary files a/ultralytics/utils/__pycache__/checks.cpython-39.pyc and b/ultralytics/utils/__pycache__/checks.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/dist.cpython-312.pyc b/ultralytics/utils/__pycache__/dist.cpython-312.pyc index 87941b5..84659f2 100644 Binary files a/ultralytics/utils/__pycache__/dist.cpython-312.pyc and 
b/ultralytics/utils/__pycache__/dist.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/dist.cpython-39.pyc b/ultralytics/utils/__pycache__/dist.cpython-39.pyc index e79d20c..ec3b3b5 100644 Binary files a/ultralytics/utils/__pycache__/dist.cpython-39.pyc and b/ultralytics/utils/__pycache__/dist.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/downloads.cpython-312.pyc b/ultralytics/utils/__pycache__/downloads.cpython-312.pyc index ecf94e0..b40e6d8 100644 Binary files a/ultralytics/utils/__pycache__/downloads.cpython-312.pyc and b/ultralytics/utils/__pycache__/downloads.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/downloads.cpython-39.pyc b/ultralytics/utils/__pycache__/downloads.cpython-39.pyc index 1537905..bdac1b6 100644 Binary files a/ultralytics/utils/__pycache__/downloads.cpython-39.pyc and b/ultralytics/utils/__pycache__/downloads.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/files.cpython-312.pyc b/ultralytics/utils/__pycache__/files.cpython-312.pyc index e3dbc82..7f4ff26 100644 Binary files a/ultralytics/utils/__pycache__/files.cpython-312.pyc and b/ultralytics/utils/__pycache__/files.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/files.cpython-39.pyc b/ultralytics/utils/__pycache__/files.cpython-39.pyc index 39062d4..1b367bb 100644 Binary files a/ultralytics/utils/__pycache__/files.cpython-39.pyc and b/ultralytics/utils/__pycache__/files.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/instance.cpython-312.pyc b/ultralytics/utils/__pycache__/instance.cpython-312.pyc index c891ce9..6f3eddf 100644 Binary files a/ultralytics/utils/__pycache__/instance.cpython-312.pyc and b/ultralytics/utils/__pycache__/instance.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/instance.cpython-39.pyc b/ultralytics/utils/__pycache__/instance.cpython-39.pyc index 0aefb93..5c8b014 100644 Binary files a/ultralytics/utils/__pycache__/instance.cpython-39.pyc and b/ultralytics/utils/__pycache__/instance.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/loss.cpython-312.pyc b/ultralytics/utils/__pycache__/loss.cpython-312.pyc index 75173d7..7bc1552 100644 Binary files a/ultralytics/utils/__pycache__/loss.cpython-312.pyc and b/ultralytics/utils/__pycache__/loss.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/loss.cpython-39.pyc b/ultralytics/utils/__pycache__/loss.cpython-39.pyc index fbf33cc..e1d4160 100644 Binary files a/ultralytics/utils/__pycache__/loss.cpython-39.pyc and b/ultralytics/utils/__pycache__/loss.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/metrics.cpython-312.pyc b/ultralytics/utils/__pycache__/metrics.cpython-312.pyc index d3865de..6d9b55f 100644 Binary files a/ultralytics/utils/__pycache__/metrics.cpython-312.pyc and b/ultralytics/utils/__pycache__/metrics.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/metrics.cpython-39.pyc b/ultralytics/utils/__pycache__/metrics.cpython-39.pyc index 6aa136b..2bb74a9 100644 Binary files a/ultralytics/utils/__pycache__/metrics.cpython-39.pyc and b/ultralytics/utils/__pycache__/metrics.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/ops.cpython-312.pyc b/ultralytics/utils/__pycache__/ops.cpython-312.pyc index d0ec5a2..48b67ce 100644 Binary files a/ultralytics/utils/__pycache__/ops.cpython-312.pyc and b/ultralytics/utils/__pycache__/ops.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/ops.cpython-39.pyc 
b/ultralytics/utils/__pycache__/ops.cpython-39.pyc index 43f7f86..3fd77f6 100644 Binary files a/ultralytics/utils/__pycache__/ops.cpython-39.pyc and b/ultralytics/utils/__pycache__/ops.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/patches.cpython-312.pyc b/ultralytics/utils/__pycache__/patches.cpython-312.pyc index 176c8c0..a60046a 100644 Binary files a/ultralytics/utils/__pycache__/patches.cpython-312.pyc and b/ultralytics/utils/__pycache__/patches.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/patches.cpython-39.pyc b/ultralytics/utils/__pycache__/patches.cpython-39.pyc index 5d8dd57..b456782 100644 Binary files a/ultralytics/utils/__pycache__/patches.cpython-39.pyc and b/ultralytics/utils/__pycache__/patches.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/plotting.cpython-312.pyc b/ultralytics/utils/__pycache__/plotting.cpython-312.pyc index 642bbd9..8367369 100644 Binary files a/ultralytics/utils/__pycache__/plotting.cpython-312.pyc and b/ultralytics/utils/__pycache__/plotting.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/plotting.cpython-39.pyc b/ultralytics/utils/__pycache__/plotting.cpython-39.pyc index 01ed426..08c62be 100644 Binary files a/ultralytics/utils/__pycache__/plotting.cpython-39.pyc and b/ultralytics/utils/__pycache__/plotting.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/tal.cpython-312.pyc b/ultralytics/utils/__pycache__/tal.cpython-312.pyc index 1b9e9bb..6c9ad96 100644 Binary files a/ultralytics/utils/__pycache__/tal.cpython-312.pyc and b/ultralytics/utils/__pycache__/tal.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/tal.cpython-39.pyc b/ultralytics/utils/__pycache__/tal.cpython-39.pyc index 342a967..a324052 100644 Binary files a/ultralytics/utils/__pycache__/tal.cpython-39.pyc and b/ultralytics/utils/__pycache__/tal.cpython-39.pyc differ diff --git a/ultralytics/utils/__pycache__/torch_utils.cpython-312.pyc b/ultralytics/utils/__pycache__/torch_utils.cpython-312.pyc index b47a76a..35db9da 100644 Binary files a/ultralytics/utils/__pycache__/torch_utils.cpython-312.pyc and b/ultralytics/utils/__pycache__/torch_utils.cpython-312.pyc differ diff --git a/ultralytics/utils/__pycache__/torch_utils.cpython-39.pyc b/ultralytics/utils/__pycache__/torch_utils.cpython-39.pyc index 5e4c172..18adc00 100644 Binary files a/ultralytics/utils/__pycache__/torch_utils.cpython-39.pyc and b/ultralytics/utils/__pycache__/torch_utils.cpython-39.pyc differ diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py index 4e9ed07..daea14e 100644 --- a/ultralytics/utils/autobatch.py +++ b/ultralytics/utils/autobatch.py @@ -1,7 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch. -""" +"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.""" from copy import deepcopy @@ -36,7 +34,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): Args: model (torch.nn.module): YOLO model to compute batch size for. imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640. - fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67. + fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60. batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16. 
Returns: @@ -44,14 +42,14 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): """ # Check device - prefix = colorstr('AutoBatch: ') - LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}') + prefix = colorstr("AutoBatch: ") + LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}") device = next(model.parameters()).device # get model device - if device.type == 'cpu': - LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') + if device.type == "cpu": + LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}") return batch_size if torch.backends.cudnn.benchmark: - LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}') + LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}") return batch_size # Inspect CUDA memory @@ -62,7 +60,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): r = torch.cuda.memory_reserved(device) / gb # GiB reserved a = torch.cuda.memory_allocated(device) / gb # GiB allocated f = t - (r + a) # GiB free - LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free') + LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free") # Profile batch sizes batch_sizes = [1, 2, 4, 8, 16] @@ -72,7 +70,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): # Fit a solution y = [x[2] for x in results if x] # memory [2] - p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit + p = np.polyfit(batch_sizes[: len(y)], y, deg=1) # first degree polynomial fit b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) if None in results: # some sizes failed i = results.index(None) # first fail index @@ -80,11 +78,11 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): b = batch_sizes[max(i - 1, 0)] # select prior safe point if b < 1 or b > 1024: # b outside of safe range b = batch_size - LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.') + LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.") fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted - LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') + LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅") return b except Exception as e: - LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.') + LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.") return batch_size diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index ad1bcf3..0286990 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license """ -Benchmark a YOLO model formats for speed and accuracy +Benchmark a YOLO model formats for speed and accuracy. 
Usage: from ultralytics.utils.benchmarks import ProfileModels, benchmark @@ -21,34 +21,29 @@ TensorFlow Lite | `tflite` | yolov8n.tflite TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite TensorFlow.js | `tfjs` | yolov8n_web_model/ PaddlePaddle | `paddle` | yolov8n_paddle_model/ -ncnn | `ncnn` | yolov8n_ncnn_model/ +NCNN | `ncnn` | yolov8n_ncnn_model/ """ import glob import platform -import sys import time from pathlib import Path import numpy as np import torch.cuda -from ultralytics import YOLO +from ultralytics import YOLO, YOLOWorld from ultralytics.cfg import TASK2DATA, TASK2METRIC from ultralytics.engine.exporter import export_formats -from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, SETTINGS, TQDM -from ultralytics.utils.checks import check_requirements, check_yolo +from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR +from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo from ultralytics.utils.files import file_size from ultralytics.utils.torch_utils import select_device -def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', - data=None, - imgsz=160, - half=False, - int8=False, - device='cpu', - verbose=False): +def benchmark( + model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False +): """ Benchmark a YOLO model across different formats for speed and accuracy. @@ -76,6 +71,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', """ import pandas as pd + pd.options.display.max_columns = 10 pd.options.display.width = 120 device = select_device(device, verbose=False) @@ -85,67 +81,72 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', y = [] t0 = time.time() for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU) - emoji, filename = '❌', None # export defaults + emoji, filename = "❌", None # export defaults try: - assert i != 9 or LINUX, 'Edge TPU export only supported on Linux' - if i == 10: - assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux' - elif i == 11: - assert sys.version_info < (3, 11), 'PaddlePaddle export only supported on Python<=3.10' - if 'cpu' in device.type: - assert cpu, 'inference not supported on CPU' - if 'cuda' in device.type: - assert gpu, 'inference not supported on GPU' + # Checks + if i == 9: # Edge TPU + assert LINUX, "Edge TPU export only supported on Linux" + elif i == 7: # TF GraphDef + assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task" + elif i in {5, 10}: # CoreML and TF.js + assert MACOS or LINUX, "export only supported on macOS and Linux" + if i in {3, 5}: # CoreML and OpenVINO + assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12" + if i in {6, 7, 8, 9, 10}: # All TF formats + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" + if i in {11}: # Paddle + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet" + if i in {12}: # NCNN + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet" + if "cpu" in device.type: + assert cpu, "inference not supported on CPU" + if "cuda" in device.type: + assert gpu, "inference not supported on GPU" # Export - if format == '-': + if format == "-": filename = model.ckpt_path or model.cfg - export = model # PyTorch format + exported_model = model # PyTorch format else: filename = model.export(imgsz=imgsz, 
format=format, half=half, int8=int8, device=device, verbose=False) - export = YOLO(filename, task=model.task) - assert suffix in str(filename), 'export failed' - emoji = '❎' # indicates export succeeded + exported_model = YOLO(filename, task=model.task) + assert suffix in str(filename), "export failed" + emoji = "❎" # indicates export succeeded # Predict - assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported' - assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported - assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML - export.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half) + assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported" + assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported + assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML + exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half) # Validate data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect - results = export.val(data=data, - batch=1, - imgsz=imgsz, - plots=False, - device=device, - half=half, - int8=int8, - verbose=False) - metric, speed = results.results_dict[key], results.speed['inference'] - y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)]) + results = exported_model.val( + data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False + ) + metric, speed = results.results_dict[key], results.speed["inference"] + y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2)]) except Exception as e: if verbose: - assert type(e) is AssertionError, f'Benchmark failure for {name}: {e}' - LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}') + assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}" + LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}") y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference # Print results check_yolo(device=device) # print system info - df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)']) + df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"]) name = Path(model.ckpt_path).name - s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n' + s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n" LOGGER.info(s) - with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f: + with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f: f.write(s) if verbose and isinstance(verbose, float): metrics = df[key].array # values to compare to floor floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n - assert all(x > floor for x in metrics if pd.notna(x)), f'Benchmark failure: metric(s) < floor {floor}' + assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}" return df @@ -154,8 +155,7 @@ class ProfileModels: """ ProfileModels class for profiling different models on ONNX and TensorRT. - This class profiles the performance of different models, provided their paths. 
The profiling includes parameters such as - model speed and FLOPs. + This class profiles the performance of different models, returning results such as model speed and FLOPs. Attributes: paths (list): Paths of the models to profile. @@ -175,15 +175,30 @@ class ProfileModels: ``` """ - def __init__(self, - paths: list, - num_timed_runs=100, - num_warmup_runs=10, - min_time=60, - imgsz=640, - half=True, - trt=True, - device=None): + def __init__( + self, + paths: list, + num_timed_runs=100, + num_warmup_runs=10, + min_time=60, + imgsz=640, + half=True, + trt=True, + device=None, + ): + """ + Initialize the ProfileModels class for profiling models. + + Args: + paths (list): List of paths of the models to be profiled. + num_timed_runs (int, optional): Number of timed runs for the profiling. Default is 100. + num_warmup_runs (int, optional): Number of warmup runs before the actual profiling starts. Default is 10. + min_time (float, optional): Minimum time in seconds for profiling a model. Default is 60. + imgsz (int, optional): Size of the image used during profiling. Default is 640. + half (bool, optional): Flag to indicate whether to use half-precision floating point for profiling. + trt (bool, optional): Flag to indicate whether to profile using TensorRT. Default is True. + device (torch.device, optional): Device used for profiling. If None, it is determined automatically. + """ self.paths = paths self.num_timed_runs = num_timed_runs self.num_warmup_runs = num_warmup_runs @@ -191,36 +206,32 @@ class ProfileModels: self.imgsz = imgsz self.half = half self.trt = trt # run TensorRT profiling - self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu') + self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu") def profile(self): + """Logs the benchmarking results of a model, checks metrics against floor and returns the results.""" files = self.get_files() if not files: - print('No matching *.pt or *.onnx files found.') + print("No matching *.pt or *.onnx files found.") return table_rows = [] output = [] for file in files: - engine_file = file.with_suffix('.engine') - if file.suffix in ('.pt', '.yaml', '.yml'): + engine_file = file.with_suffix(".engine") + if file.suffix in (".pt", ".yaml", ".yml"): model = YOLO(str(file)) model.fuse() # to report correct params and GFLOPs in model.info() model_info = model.info() - if self.trt and self.device.type != 'cpu' and not engine_file.is_file(): - engine_file = model.export(format='engine', - half=self.half, - imgsz=self.imgsz, - device=self.device, - verbose=False) - onnx_file = model.export(format='onnx', - half=self.half, - imgsz=self.imgsz, - simplify=True, - device=self.device, - verbose=False) - elif file.suffix == '.onnx': + if self.trt and self.device.type != "cpu" and not engine_file.is_file(): + engine_file = model.export( + format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False + ) + onnx_file = model.export( + format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False + ) + elif file.suffix == ".onnx": model_info = self.get_onnx_model_info(file) onnx_file = file else: @@ -235,25 +246,30 @@ class ProfileModels: return output def get_files(self): + """Returns a list of paths for all relevant model files given by the user.""" files = [] for path in self.paths: path = Path(path) if path.is_dir(): - extensions = ['*.pt', '*.onnx', '*.yaml'] + extensions = ["*.pt", "*.onnx", "*.yaml"] files.extend([file for ext in extensions for file in 
glob.glob(str(path / ext))]) - elif path.suffix in {'.pt', '.yaml', '.yml'}: # add non-existing + elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing files.append(str(path)) else: files.extend(glob.glob(str(path))) - print(f'Profiling: {sorted(files)}') + print(f"Profiling: {sorted(files)}") return [Path(file) for file in sorted(files)] def get_onnx_model_info(self, onnx_file: str): - # return (num_layers, num_params, num_gradients, num_flops) - return 0.0, 0.0, 0.0, 0.0 + """Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model + file. + """ + return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops) - def iterative_sigma_clipping(self, data, sigma=2, max_iters=3): + @staticmethod + def iterative_sigma_clipping(data, sigma=2, max_iters=3): + """Applies an iterative sigma clipping algorithm to the given data times number of iterations.""" data = np.array(data) for _ in range(max_iters): mean, std = np.mean(data), np.std(data) @@ -264,6 +280,7 @@ class ProfileModels: return data def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3): + """Profiles the TensorRT model, measuring average run time and standard deviation among runs.""" if not self.trt or not Path(engine_file).is_file(): return 0.0, 0.0 @@ -286,39 +303,44 @@ class ProfileModels: run_times = [] for _ in TQDM(range(num_runs), desc=engine_file): results = model(input_data, imgsz=self.imgsz, verbose=False) - run_times.append(results[0].speed['inference']) # Convert to milliseconds + run_times.append(results[0].speed["inference"]) # Convert to milliseconds run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping return np.mean(run_times), np.std(run_times) def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3): - check_requirements('onnxruntime') + """Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run + times. 
+ """ + check_requirements("onnxruntime") import onnxruntime as ort # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider' sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.intra_op_num_threads = 8 # Limit the number of threads - sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider']) + sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"]) input_tensor = sess.get_inputs()[0] input_type = input_tensor.type + dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape) # dynamic input shape + input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape # Mapping ONNX datatype to numpy datatype - if 'float16' in input_type: + if "float16" in input_type: input_dtype = np.float16 - elif 'float' in input_type: + elif "float" in input_type: input_dtype = np.float32 - elif 'double' in input_type: + elif "double" in input_type: input_dtype = np.float64 - elif 'int64' in input_type: + elif "int64" in input_type: input_dtype = np.int64 - elif 'int32' in input_type: + elif "int32" in input_type: input_dtype = np.int32 else: - raise ValueError(f'Unsupported ONNX datatype {input_type}') + raise ValueError(f"Unsupported ONNX datatype {input_type}") - input_data = np.random.rand(*input_tensor.shape).astype(input_dtype) + input_data = np.random.rand(*input_shape).astype(input_dtype) input_name = input_tensor.name output_name = sess.get_outputs()[0].name @@ -344,24 +366,39 @@ class ProfileModels: return np.mean(run_times), np.std(run_times) def generate_table_row(self, model_name, t_onnx, t_engine, model_info): + """Generates a formatted string for a table row that includes model performance and metric details.""" layers, params, gradients, flops = model_info - return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |' + return ( + f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± " + f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |" + ) - def generate_results_dict(self, model_name, t_onnx, t_engine, model_info): + @staticmethod + def generate_results_dict(model_name, t_onnx, t_engine, model_info): + """Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics.""" layers, params, gradients, flops = model_info return { - 'model/name': model_name, - 'model/parameters': params, - 'model/GFLOPs': round(flops, 3), - 'model/speed_ONNX(ms)': round(t_onnx[0], 3), - 'model/speed_TensorRT(ms)': round(t_engine[0], 3)} + "model/name": model_name, + "model/parameters": params, + "model/GFLOPs": round(flops, 3), + "model/speed_ONNX(ms)": round(t_onnx[0], 3), + "model/speed_TensorRT(ms)": round(t_engine[0], 3), + } - def print_table(self, table_rows): - gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU' - header = f'| Model | size
<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |' - separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|' + @staticmethod + def print_table(table_rows): + """Formats and prints a comparison table for different models with given statistics and performance data.""" + gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU" + header = ( + f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | " + f"Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>
(B) |" + ) + separator = ( + "|-------------|---------------------|--------------------|------------------------------|" + "-----------------------------------|------------------|-----------------|" + ) - print(f'\n\n{header}') + print(f"\n\n{header}") print(separator) for row in table_rows: print(row) diff --git a/ultralytics/utils/callbacks/__init__.py b/ultralytics/utils/callbacks/__init__.py index 8ad4ad6..116babe 100644 --- a/ultralytics/utils/callbacks/__init__.py +++ b/ultralytics/utils/callbacks/__init__.py @@ -2,4 +2,4 @@ from .base import add_integration_callbacks, default_callbacks, get_default_callbacks -__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks' +__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks" diff --git a/ultralytics/utils/callbacks/__pycache__/__init__.cpython-312.pyc b/ultralytics/utils/callbacks/__pycache__/__init__.cpython-312.pyc index 224d9a6..9cdee6c 100644 Binary files a/ultralytics/utils/callbacks/__pycache__/__init__.cpython-312.pyc and b/ultralytics/utils/callbacks/__pycache__/__init__.cpython-312.pyc differ diff --git a/ultralytics/utils/callbacks/__pycache__/__init__.cpython-39.pyc b/ultralytics/utils/callbacks/__pycache__/__init__.cpython-39.pyc index fa7f542..f4edc7a 100644 Binary files a/ultralytics/utils/callbacks/__pycache__/__init__.cpython-39.pyc and b/ultralytics/utils/callbacks/__pycache__/__init__.cpython-39.pyc differ diff --git a/ultralytics/utils/callbacks/__pycache__/base.cpython-312.pyc b/ultralytics/utils/callbacks/__pycache__/base.cpython-312.pyc index 0313f40..02e8cd1 100644 Binary files a/ultralytics/utils/callbacks/__pycache__/base.cpython-312.pyc and b/ultralytics/utils/callbacks/__pycache__/base.cpython-312.pyc differ diff --git a/ultralytics/utils/callbacks/__pycache__/base.cpython-39.pyc b/ultralytics/utils/callbacks/__pycache__/base.cpython-39.pyc index c679acb..acf88c8 100644 Binary files a/ultralytics/utils/callbacks/__pycache__/base.cpython-39.pyc and b/ultralytics/utils/callbacks/__pycache__/base.cpython-39.pyc differ diff --git a/ultralytics/utils/callbacks/__pycache__/hub.cpython-39.pyc b/ultralytics/utils/callbacks/__pycache__/hub.cpython-39.pyc index a85b433..4dcbe86 100644 Binary files a/ultralytics/utils/callbacks/__pycache__/hub.cpython-39.pyc and b/ultralytics/utils/callbacks/__pycache__/hub.cpython-39.pyc differ diff --git a/ultralytics/utils/callbacks/base.py b/ultralytics/utils/callbacks/base.py index 2e676bf..d015457 100644 --- a/ultralytics/utils/callbacks/base.py +++ b/ultralytics/utils/callbacks/base.py @@ -1,11 +1,10 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Base callbacks -""" +"""Base callbacks.""" from collections import defaultdict from copy import deepcopy + # Trainer callbacks ---------------------------------------------------------------------------------------------------- @@ -145,37 +144,35 @@ def on_export_end(exporter): default_callbacks = { # Run in trainer - 'on_pretrain_routine_start': [on_pretrain_routine_start], - 'on_pretrain_routine_end': [on_pretrain_routine_end], - 'on_train_start': [on_train_start], - 'on_train_epoch_start': [on_train_epoch_start], - 'on_train_batch_start': [on_train_batch_start], - 'optimizer_step': [optimizer_step], - 'on_before_zero_grad': [on_before_zero_grad], - 'on_train_batch_end': [on_train_batch_end], - 'on_train_epoch_end': [on_train_epoch_end], - 'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val - 'on_model_save': [on_model_save], - 'on_train_end': [on_train_end], - 
'on_params_update': [on_params_update], - 'teardown': [teardown], - + "on_pretrain_routine_start": [on_pretrain_routine_start], + "on_pretrain_routine_end": [on_pretrain_routine_end], + "on_train_start": [on_train_start], + "on_train_epoch_start": [on_train_epoch_start], + "on_train_batch_start": [on_train_batch_start], + "optimizer_step": [optimizer_step], + "on_before_zero_grad": [on_before_zero_grad], + "on_train_batch_end": [on_train_batch_end], + "on_train_epoch_end": [on_train_epoch_end], + "on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val + "on_model_save": [on_model_save], + "on_train_end": [on_train_end], + "on_params_update": [on_params_update], + "teardown": [teardown], # Run in validator - 'on_val_start': [on_val_start], - 'on_val_batch_start': [on_val_batch_start], - 'on_val_batch_end': [on_val_batch_end], - 'on_val_end': [on_val_end], - + "on_val_start": [on_val_start], + "on_val_batch_start": [on_val_batch_start], + "on_val_batch_end": [on_val_batch_end], + "on_val_end": [on_val_end], # Run in predictor - 'on_predict_start': [on_predict_start], - 'on_predict_batch_start': [on_predict_batch_start], - 'on_predict_postprocess_end': [on_predict_postprocess_end], - 'on_predict_batch_end': [on_predict_batch_end], - 'on_predict_end': [on_predict_end], - + "on_predict_start": [on_predict_start], + "on_predict_batch_start": [on_predict_batch_start], + "on_predict_postprocess_end": [on_predict_postprocess_end], + "on_predict_batch_end": [on_predict_batch_end], + "on_predict_end": [on_predict_end], # Run in exporter - 'on_export_start': [on_export_start], - 'on_export_end': [on_export_end]} + "on_export_start": [on_export_start], + "on_export_end": [on_export_end], +} def get_default_callbacks(): @@ -199,10 +196,11 @@ def add_integration_callbacks(instance): # Load HUB callbacks from .hub import callbacks as hub_cb + callbacks_list = [hub_cb] # Load training callbacks - if 'Trainer' in instance.__class__.__name__: + if "Trainer" in instance.__class__.__name__: from .clearml import callbacks as clear_cb from .comet import callbacks as comet_cb from .dvc import callbacks as dvc_cb @@ -211,12 +209,8 @@ def add_integration_callbacks(instance): from .raytune import callbacks as tune_cb from .tensorboard import callbacks as tb_cb from .wb import callbacks as wb_cb - callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb]) - # Load export callbacks (patch to avoid CoreML protobuf error) - if 'Exporter' in instance.__class__.__name__: - from .tensorboard import callbacks as tb_cb - callbacks_list.append(tb_cb) + callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb]) # Add the callbacks to the callbacks dictionary for callbacks in callbacks_list: diff --git a/ultralytics/utils/callbacks/clearml.py b/ultralytics/utils/callbacks/clearml.py index dfb2203..a030fc5 100644 --- a/ultralytics/utils/callbacks/clearml.py +++ b/ultralytics/utils/callbacks/clearml.py @@ -4,19 +4,19 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['clearml'] is True # verify integration is enabled + assert SETTINGS["clearml"] is True # verify integration is enabled import clearml from clearml import Task from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO from clearml.binding.matplotlib_bind import PatchedMatplotlib - assert hasattr(clearml, '__version__') # verify package is not directory + assert hasattr(clearml, "__version__") 
# verify package is not directory except (ImportError, AssertionError): clearml = None -def _log_debug_samples(files, title='Debug Samples') -> None: +def _log_debug_samples(files, title="Debug Samples") -> None: """ Log files (images) as debug samples in the ClearML task. @@ -29,12 +29,11 @@ def _log_debug_samples(files, title='Debug Samples') -> None: if task := Task.current_task(): for f in files: if f.exists(): - it = re.search(r'_batch(\d+)', f.name) + it = re.search(r"_batch(\d+)", f.name) iteration = int(it.groups()[0]) if it else 0 - task.get_logger().report_image(title=title, - series=f.name.replace(it.group(), ''), - local_path=str(f), - iteration=iteration) + task.get_logger().report_image( + title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration + ) def _log_plot(title, plot_path) -> None: @@ -50,13 +49,12 @@ def _log_plot(title, plot_path) -> None: img = mpimg.imread(plot_path) fig = plt.figure() - ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks ax.imshow(img) - Task.current_task().get_logger().report_matplotlib_figure(title=title, - series='', - figure=fig, - report_interactive=False) + Task.current_task().get_logger().report_matplotlib_figure( + title=title, series="", figure=fig, report_interactive=False + ) def on_pretrain_routine_start(trainer): @@ -68,19 +66,21 @@ def on_pretrain_routine_start(trainer): PatchPyTorchModelIO.update_current_task(None) PatchedMatplotlib.update_current_task(None) else: - task = Task.init(project_name=trainer.args.project or 'YOLOv8', - task_name=trainer.args.name, - tags=['YOLOv8'], - output_uri=True, - reuse_last_task_id=False, - auto_connect_frameworks={ - 'pytorch': False, - 'matplotlib': False}) - LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, ' - 'please add clearml-init and connect your arguments before initializing YOLO.') - task.connect(vars(trainer.args), name='General') + task = Task.init( + project_name=trainer.args.project or "YOLOv8", + task_name=trainer.args.name, + tags=["YOLOv8"], + output_uri=True, + reuse_last_task_id=False, + auto_connect_frameworks={"pytorch": False, "matplotlib": False}, + ) + LOGGER.warning( + "ClearML Initialized a new task. If you want to run remotely, " + "please add clearml-init and connect your arguments before initializing YOLO." + ) + task.connect(vars(trainer.args), name="General") except Exception as e: - LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. 
{e}") def on_train_epoch_end(trainer): @@ -88,22 +88,26 @@ def on_train_epoch_end(trainer): if task := Task.current_task(): # Log debug samples if trainer.epoch == 1: - _log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic') + _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic") # Report the current training progress - for k, v in trainer.validator.metrics.results_dict.items(): - task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch) + for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items(): + task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch) + for k, v in trainer.lr.items(): + task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch) def on_fit_epoch_end(trainer): """Reports model information to logger at the end of an epoch.""" if task := Task.current_task(): # You should have access to the validation bboxes under jdict - task.get_logger().report_scalar(title='Epoch Time', - series='Epoch Time', - value=trainer.epoch_time, - iteration=trainer.epoch) + task.get_logger().report_scalar( + title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch + ) + for k, v in trainer.metrics.items(): + task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch) if trainer.epoch == 0: from ultralytics.utils.torch_utils import model_info_for_loggers + for k, v in model_info_for_loggers(trainer).items(): task.get_logger().report_single_value(k, v) @@ -112,7 +116,7 @@ def on_val_end(validator): """Logs validation results including labels and predictions.""" if Task.current_task(): # Log val_labels and val_pred - _log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation') + _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation") def on_train_end(trainer): @@ -120,8 +124,11 @@ def on_train_end(trainer): if task := Task.current_task(): # Log final results, CM matrix + PR plots files = [ - 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png', - *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] + "results.png", + "confusion_matrix.png", + "confusion_matrix_normalized.png", + *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")), + ] files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter for f in files: _log_plot(title=f.stem, plot_path=f) @@ -132,9 +139,14 @@ def on_train_end(trainer): task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False) -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_epoch_end': on_train_epoch_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_val_end': on_val_end, - 'on_train_end': on_train_end} if clearml else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_val_end": on_val_end, + "on_train_end": on_train_end, + } + if clearml + else {} +) diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py index 2da71a9..1c5f585 100644 --- a/ultralytics/utils/callbacks/comet.py +++ b/ultralytics/utils/callbacks/comet.py @@ -4,20 +4,20 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['comet'] is True # verify integration is enabled + assert SETTINGS["comet"] is True # verify integration is enabled import 
comet_ml - assert hasattr(comet_ml, '__version__') # verify package is not directory + assert hasattr(comet_ml, "__version__") # verify package is not directory import os from pathlib import Path # Ensures certain logging functions only run for supported tasks - COMET_SUPPORTED_TASKS = ['detect'] + COMET_SUPPORTED_TASKS = ["detect"] # Names of plots created by YOLOv8 that are logged to Comet - EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix' - LABEL_PLOT_NAMES = 'labels', 'labels_correlogram' + EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix" + LABEL_PLOT_NAMES = "labels", "labels_correlogram" _comet_image_prediction_count = 0 @@ -26,37 +26,44 @@ except (ImportError, AssertionError): def _get_comet_mode(): - return os.getenv('COMET_MODE', 'online') + """Returns the mode of comet set in the environment variables, defaults to 'online' if not set.""" + return os.getenv("COMET_MODE", "online") def _get_comet_model_name(): - return os.getenv('COMET_MODEL_NAME', 'YOLOv8') + """Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'.""" + return os.getenv("COMET_MODEL_NAME", "YOLOv8") def _get_eval_batch_logging_interval(): - return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1)) + """Get the evaluation batch logging interval from environment variable or use default value 1.""" + return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1)) def _get_max_image_predictions_to_log(): - return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100)) + """Get the maximum number of image predictions to log from the environment variables.""" + return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100)) def _scale_confidence_score(score): - scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0)) + """Scales the given confidence score by a factor specified in an environment variable.""" + scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0)) return score * scale def _should_log_confusion_matrix(): - return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true' + """Determines if the confusion matrix should be logged based on the environment variable settings.""" + return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true" def _should_log_image_predictions(): - return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true' + """Determines whether to log image predictions based on a specified environment variable.""" + return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true" def _get_experiment_type(mode, project_name): """Return an experiment based on mode and project name.""" - if mode == 'offline': + if mode == "offline": return comet_ml.OfflineExperiment(project_name=project_name) return comet_ml.Experiment(project_name=project_name) @@ -68,18 +75,21 @@ def _create_experiment(args): return try: comet_mode = _get_comet_mode() - _project_name = os.getenv('COMET_PROJECT_NAME', args.project) + _project_name = os.getenv("COMET_PROJECT_NAME", args.project) experiment = _get_experiment_type(comet_mode, _project_name) experiment.log_parameters(vars(args)) - experiment.log_others({ - 'eval_batch_logging_interval': _get_eval_batch_logging_interval(), - 'log_confusion_matrix_on_eval': _should_log_confusion_matrix(), - 'log_image_predictions': _should_log_image_predictions(), - 'max_image_predictions': _get_max_image_predictions_to_log(), }) - experiment.log_other('Created from', 'yolov8') + 
experiment.log_others( + { + "eval_batch_logging_interval": _get_eval_batch_logging_interval(), + "log_confusion_matrix_on_eval": _should_log_confusion_matrix(), + "log_image_predictions": _should_log_image_predictions(), + "max_image_predictions": _get_max_image_predictions_to_log(), + } + ) + experiment.log_other("Created from", "yolov8") except Exception as e: - LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}") def _fetch_trainer_metadata(trainer): @@ -95,18 +105,14 @@ def _fetch_trainer_metadata(trainer): save_interval = curr_epoch % save_period == 0 save_assets = save and save_period > 0 and save_interval and not final_epoch - return dict( - curr_epoch=curr_epoch, - curr_step=curr_step, - save_assets=save_assets, - final_epoch=final_epoch, - ) + return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch) def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad): - """YOLOv8 resizes images during training and the label values - are normalized based on this resized shape. This function rescales the - bounding box labels to the original image shape. + """ + YOLOv8 resizes images during training and the label values are normalized based on this resized shape. + + This function rescales the bounding box labels to the original image shape. """ resized_image_height, resized_image_width = resized_image_shape @@ -126,29 +132,32 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None): """Format ground truth annotations for detection.""" - indices = batch['batch_idx'] == img_idx - bboxes = batch['bboxes'][indices] + indices = batch["batch_idx"] == img_idx + bboxes = batch["bboxes"][indices] if len(bboxes) == 0: - LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes labels') + LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels") return None - cls_labels = batch['cls'][indices].squeeze(1).tolist() + cls_labels = batch["cls"][indices].squeeze(1).tolist() if class_name_map: cls_labels = [str(class_name_map[label]) for label in cls_labels] - original_image_shape = batch['ori_shape'][img_idx] - resized_image_shape = batch['resized_shape'][img_idx] - ratio_pad = batch['ratio_pad'][img_idx] + original_image_shape = batch["ori_shape"][img_idx] + resized_image_shape = batch["resized_shape"][img_idx] + ratio_pad = batch["ratio_pad"][img_idx] data = [] for box, label in zip(bboxes, cls_labels): box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad) - data.append({ - 'boxes': [box], - 'label': f'gt_{label}', - 'score': _scale_confidence_score(1.0), }) + data.append( + { + "boxes": [box], + "label": f"gt_{label}", + "score": _scale_confidence_score(1.0), + } + ) - return {'name': 'ground_truth', 'data': data} + return {"name": "ground_truth", "data": data} def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None): @@ -158,31 +167,34 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab predictions = metadata.get(image_id) if not predictions: - LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions') + LOGGER.debug(f"COMET WARNING: Image: 
{image_path} has no bounding boxes predictions") return None data = [] for prediction in predictions: - boxes = prediction['bbox'] - score = _scale_confidence_score(prediction['score']) - cls_label = prediction['category_id'] + boxes = prediction["bbox"] + score = _scale_confidence_score(prediction["score"]) + cls_label = prediction["category_id"] if class_label_map: cls_label = str(class_label_map[cls_label]) - data.append({'boxes': [boxes], 'label': cls_label, 'score': score}) + data.append({"boxes": [boxes], "label": cls_label, "score": score}) - return {'name': 'prediction', 'data': data} + return {"name": "prediction", "data": data} def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map): """Join the ground truth and prediction annotations if they exist.""" - ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, - class_label_map) - prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map, - class_label_map) + ground_truth_annotations = _format_ground_truth_annotations_for_detection( + img_idx, image_path, batch, class_label_map + ) + prediction_annotations = _format_prediction_annotations_for_detection( + image_path, prediction_metadata_map, class_label_map + ) annotations = [ - annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None] + annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None + ] return [annotations] if annotations else None @@ -190,8 +202,8 @@ def _create_prediction_metadata_map(model_predictions): """Create metadata map for model predictions by groupings them based on image ID.""" pred_metadata_map = {} for prediction in model_predictions: - pred_metadata_map.setdefault(prediction['image_id'], []) - pred_metadata_map[prediction['image_id']].append(prediction) + pred_metadata_map.setdefault(prediction["image_id"], []) + pred_metadata_map[prediction["image_id"]].append(prediction) return pred_metadata_map @@ -199,13 +211,9 @@ def _create_prediction_metadata_map(model_predictions): def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch): """Log the confusion matrix to Comet experiment.""" conf_mat = trainer.validator.confusion_matrix.matrix - names = list(trainer.data['names'].values()) + ['background'] + names = list(trainer.data["names"].values()) + ["background"] experiment.log_confusion_matrix( - matrix=conf_mat, - labels=names, - max_categories=len(names), - epoch=curr_epoch, - step=curr_step, + matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step ) @@ -243,7 +251,7 @@ def _log_image_predictions(experiment, validator, curr_step): if (batch_idx + 1) % batch_logging_interval != 0: continue - image_paths = batch['im_file'] + image_paths = batch["im_file"] for img_idx, image_path in enumerate(image_paths): if _comet_image_prediction_count >= max_image_predictions: return @@ -267,28 +275,23 @@ def _log_image_predictions(experiment, validator, curr_step): def _log_plots(experiment, trainer): """Logs evaluation plots and label plots for the experiment.""" - plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES] + plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES] _log_images(experiment, plot_filenames, None) - label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES] + label_plot_filenames 
= [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES] _log_images(experiment, label_plot_filenames, None) def _log_model(experiment, trainer): """Log the best-trained model to Comet.ml.""" model_name = _get_comet_model_name() - experiment.log_model( - model_name, - file_or_folder=str(trainer.best), - file_name='best.pt', - overwrite=True, - ) + experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True) def on_pretrain_routine_start(trainer): """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine.""" experiment = comet_ml.get_global_experiment() - is_alive = getattr(experiment, 'alive', False) + is_alive = getattr(experiment, "alive", False) if not experiment or not is_alive: _create_experiment(trainer.args) @@ -300,17 +303,13 @@ def on_train_epoch_end(trainer): return metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata['curr_epoch'] - curr_step = metadata['curr_step'] + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] - experiment.log_metrics( - trainer.label_loss_items(trainer.tloss, prefix='train'), - step=curr_step, - epoch=curr_epoch, - ) + experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch) if curr_epoch == 1: - _log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step) + _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step) def on_fit_epoch_end(trainer): @@ -320,14 +319,15 @@ def on_fit_epoch_end(trainer): return metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata['curr_epoch'] - curr_step = metadata['curr_step'] - save_assets = metadata['save_assets'] + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] + save_assets = metadata["save_assets"] experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch) experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch) if curr_epoch == 1: from ultralytics.utils.torch_utils import model_info_for_loggers + experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch) if not save_assets: @@ -347,8 +347,8 @@ def on_train_end(trainer): return metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata['curr_epoch'] - curr_step = metadata['curr_step'] + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] plots = trainer.args.plots _log_model(experiment, trainer) @@ -363,8 +363,13 @@ def on_train_end(trainer): _comet_image_prediction_count = 0 -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_epoch_end': on_train_epoch_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_end': on_train_end} if comet_ml else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if comet_ml + else {} +) diff --git a/ultralytics/utils/callbacks/dvc.py b/ultralytics/utils/callbacks/dvc.py index b5bfa9d..ab51dc5 100644 --- a/ultralytics/utils/callbacks/dvc.py +++ b/ultralytics/utils/callbacks/dvc.py @@ -1,26 +1,18 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['dvc'] is True # verify integration is enabled + assert SETTINGS["dvc"] is True # verify 
integration is enabled import dvclive - assert hasattr(dvclive, '__version__') # verify package is not directory + assert checks.check_version("dvclive", "2.11.0", verbose=True) import os import re - from importlib.metadata import version from pathlib import Path - import pkg_resources as pkg - - ver = version('dvclive') - if pkg.parse_version(ver) < pkg.parse_version('2.11.0'): - LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).') - dvclive = None # noqa: F811 - # DVCLive logger instance live = None _processed_plots = {} @@ -33,108 +25,121 @@ except (ImportError, AssertionError, TypeError): dvclive = None -def _log_images(path, prefix=''): +def _log_images(path, prefix=""): + """Logs images at specified path with an optional prefix using DVCLive.""" if live: name = path.name # Group images by batch to enable sliders in UI - if m := re.search(r'_batch(\d+)', name): + if m := re.search(r"_batch(\d+)", name): ni = m[1] - new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem) + new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem) name = (Path(new_stem) / ni).with_suffix(path.suffix) live.log_image(os.path.join(prefix, name), path) -def _log_plots(plots, prefix=''): +def _log_plots(plots, prefix=""): + """Logs plot images for training progress if they have not been previously processed.""" for name, params in plots.items(): - timestamp = params['timestamp'] + timestamp = params["timestamp"] if _processed_plots.get(name) != timestamp: _log_images(name, prefix) _processed_plots[name] = timestamp def _log_confusion_matrix(validator): + """Logs the confusion matrix for the given validator using DVCLive.""" targets = [] preds = [] matrix = validator.confusion_matrix.matrix names = list(validator.names.values()) - if validator.confusion_matrix.task == 'detect': - names += ['background'] + if validator.confusion_matrix.task == "detect": + names += ["background"] for ti, pred in enumerate(matrix.T.astype(int)): for pi, num in enumerate(pred): targets.extend([names[ti]] * num) preds.extend([names[pi]] * num) - live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True) + live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True) def on_pretrain_routine_start(trainer): + """Initializes DVCLive logger for training metadata during pre-training routine.""" try: global live live = dvclive.Live(save_dvc_exp=True, cache_images=True) - LOGGER.info( - f'DVCLive is detected and auto logging is enabled (can be disabled in the {SETTINGS.file} with `dvc: false`).' - ) + LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).") except Exception as e: - LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. 
{e}") def on_pretrain_routine_end(trainer): - _log_plots(trainer.plots, 'train') + """Logs plots related to the training process at the end of the pretraining routine.""" + _log_plots(trainer.plots, "train") def on_train_start(trainer): + """Logs the training parameters if DVCLive logging is active.""" if live: live.log_params(trainer.args) def on_train_epoch_start(trainer): + """Sets the global variable _training_epoch value to True at the start of training each epoch.""" global _training_epoch _training_epoch = True def on_fit_epoch_end(trainer): + """Logs training metrics and model info, and advances to next step on the end of each fit epoch.""" global _training_epoch if live and _training_epoch: - all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} for metric, value in all_metrics.items(): live.log_metric(metric, value) if trainer.epoch == 0: from ultralytics.utils.torch_utils import model_info_for_loggers + for metric, value in model_info_for_loggers(trainer).items(): live.log_metric(metric, value, plot=False) - _log_plots(trainer.plots, 'train') - _log_plots(trainer.validator.plots, 'val') + _log_plots(trainer.plots, "train") + _log_plots(trainer.validator.plots, "val") live.next_step() _training_epoch = False def on_train_end(trainer): + """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active.""" if live: # At the end log the best metrics. It runs validator on the best model internally. - all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} for metric, value in all_metrics.items(): live.log_metric(metric, value, plot=False) - _log_plots(trainer.plots, 'val') - _log_plots(trainer.validator.plots, 'val') + _log_plots(trainer.plots, "val") + _log_plots(trainer.validator.plots, "val") _log_confusion_matrix(trainer.validator) if trainer.best.exists(): - live.log_artifact(trainer.best, copy=True, type='model') + live.log_artifact(trainer.best, copy=True, type="model") live.end() -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_pretrain_routine_end': on_pretrain_routine_end, - 'on_train_start': on_train_start, - 'on_train_epoch_start': on_train_epoch_start, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_end': on_train_end} if dvclive else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_train_start": on_train_start, + "on_train_epoch_start": on_train_epoch_start, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if dvclive + else {} +) diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py index 7171fb9..cdb42b9 100644 --- a/ultralytics/utils/callbacks/hub.py +++ b/ultralytics/utils/callbacks/hub.py @@ -9,51 +9,67 @@ from ultralytics.utils import LOGGER, SETTINGS def on_pretrain_routine_end(trainer): """Logs info before starting timer for upload rate limit.""" - session = getattr(trainer, 'hub_session', None) + session = getattr(trainer, "hub_session", None) if session: # Start timer for upload rate limit - LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀') - session.timers = {'metrics': time(), 'ckpt': time()} # 
start timer on session.rate_limit + session.timers = { + "metrics": time(), + "ckpt": time(), + } # start timer on session.rate_limit def on_fit_epoch_end(trainer): """Uploads training progress metrics at the end of each epoch.""" - session = getattr(trainer, 'hub_session', None) + session = getattr(trainer, "hub_session", None) if session: # Upload metrics after val end - all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} + all_plots = { + **trainer.label_loss_items(trainer.tloss, prefix="train"), + **trainer.metrics, + } if trainer.epoch == 0: from ultralytics.utils.torch_utils import model_info_for_loggers + all_plots = {**all_plots, **model_info_for_loggers(trainer)} + session.metrics_queue[trainer.epoch] = json.dumps(all_plots) - if time() - session.timers['metrics'] > session.rate_limits['metrics']: + + # If any metrics fail to upload, add them to the queue to attempt uploading again. + if session.metrics_upload_failed_queue: + session.metrics_queue.update(session.metrics_upload_failed_queue) + + if time() - session.timers["metrics"] > session.rate_limits["metrics"]: session.upload_metrics() - session.timers['metrics'] = time() # reset timer + session.timers["metrics"] = time() # reset timer session.metrics_queue = {} # reset queue def on_model_save(trainer): """Saves checkpoints to Ultralytics HUB with rate limiting.""" - session = getattr(trainer, 'hub_session', None) + session = getattr(trainer, "hub_session", None) if session: # Upload checkpoints with rate limiting is_best = trainer.best_fitness == trainer.fitness - if time() - session.timers['ckpt'] > session.rate_limits['ckpt']: - LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}') + if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]: + LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}") session.upload_model(trainer.epoch, trainer.last, is_best) - session.timers['ckpt'] = time() # reset timer + session.timers["ckpt"] = time() # reset timer def on_train_end(trainer): """Upload final model and metrics to Ultralytics HUB at the end of training.""" - session = getattr(trainer, 'hub_session', None) + session = getattr(trainer, "hub_session", None) if session: # Upload final model and metrics with exponential standoff - LOGGER.info(f'{PREFIX}Syncing final model...') - session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True) + LOGGER.info(f"{PREFIX}Syncing final model...") + session.upload_model( + trainer.epoch, + trainer.best, + map=trainer.metrics.get("metrics/mAP50-95(B)", 0), + final=True, + ) session.alive = False # stop heartbeats - LOGGER.info(f'{PREFIX}Done ✅\n' - f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀') + LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀") def on_train_start(trainer): @@ -76,12 +92,17 @@ def on_export_start(exporter): events(exporter.args) -callbacks = { - 'on_pretrain_routine_end': on_pretrain_routine_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_model_save': on_model_save, - 'on_train_end': on_train_end, - 'on_train_start': on_train_start, - 'on_val_start': on_val_start, - 'on_predict_start': on_predict_start, - 'on_export_start': on_export_start} if SETTINGS['hub'] is True else {} # verify enabled +callbacks = ( + { + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_model_save": on_model_save, + "on_train_end": 
on_train_end, + "on_train_start": on_train_start, + "on_val_start": on_val_start, + "on_predict_start": on_predict_start, + "on_export_start": on_export_start, + } + if SETTINGS["hub"] is True + else {} +) # verify enabled diff --git a/ultralytics/utils/callbacks/mlflow.py b/ultralytics/utils/callbacks/mlflow.py index 8d4501b..e554620 100644 --- a/ultralytics/utils/callbacks/mlflow.py +++ b/ultralytics/utils/callbacks/mlflow.py @@ -1,70 +1,133 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +""" +MLflow Logging for Ultralytics YOLO. -from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr +This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts. +For setting up, a tracking URI should be specified. The logging can be customized using environment variables. + +Commands: + 1. To set a project name: + `export MLFLOW_EXPERIMENT_NAME=` or use the project= argument + + 2. To set a run name: + `export MLFLOW_RUN=` or use the name= argument + + 3. To start a local MLflow server: + mlflow server --backend-store-uri runs/mlflow + It will by default start a local server at http://127.0.0.1:5000. + To specify a different URI, set the MLFLOW_TRACKING_URI environment variable. + + 4. To kill all running MLflow server instances: + ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9 +""" + +from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr try: - assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['mlflow'] is True # verify integration is enabled + import os + + assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest + assert SETTINGS["mlflow"] is True # verify integration is enabled import mlflow - assert hasattr(mlflow, '__version__') # verify package is not directory + assert hasattr(mlflow, "__version__") # verify package is not directory + from pathlib import Path - import os - import re + PREFIX = colorstr("MLflow: ") + SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()} except (ImportError, AssertionError): mlflow = None def on_pretrain_routine_end(trainer): - """Logs training parameters to MLflow.""" - global mlflow, run, experiment_name + """ + Log training parameters to MLflow at the end of the pretraining routine. - if os.environ.get('MLFLOW_TRACKING_URI') is None: - mlflow = None + This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI, + experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters + from the trainer. + Args: + trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log. + + Global: + mlflow: The imported mlflow module to use for logging. + + Environment Variables: + MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'. + MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project. + MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name. + MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of the training phase. 
+ """ + global mlflow + + uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow") + LOGGER.debug(f"{PREFIX} tracking uri: {uri}") + mlflow.set_tracking_uri(uri) + + # Set experiment and run names + experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8" + run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name + mlflow.set_experiment(experiment_name) + + mlflow.autolog() + try: + active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name) + LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}") + if Path(uri).is_dir(): + LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'") + LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'") + mlflow.log_params(dict(trainer.args)) + except Exception as e: + LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run") + + +def on_train_epoch_end(trainer): + """Log training metrics at the end of each train epoch to MLflow.""" if mlflow: - mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000" - mlflow.set_tracking_uri(mlflow_location) - - experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8' - run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name - experiment = mlflow.get_experiment_by_name(experiment_name) - if experiment is None: - mlflow.create_experiment(experiment_name) - mlflow.set_experiment(experiment_name) - - prefix = colorstr('MLFlow: ') - try: - run, active_run = mlflow, mlflow.active_run() - if not active_run: - active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name) - LOGGER.info(f'{prefix}Using run_id({active_run.info.run_id}) at {mlflow_location}') - run.log_params(vars(trainer.model.args)) - except Exception as err: - LOGGER.error(f'{prefix}Failing init - {repr(err)}') - LOGGER.warning(f'{prefix}Continuing without Mlflow') + mlflow.log_metrics( + metrics={ + **SANITIZE(trainer.lr), + **SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")), + }, + step=trainer.epoch, + ) def on_fit_epoch_end(trainer): - """Logs training metrics to Mlflow.""" + """Log training metrics at the end of each fit epoch to MLflow.""" if mlflow: - metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()} - run.log_metrics(metrics=metrics_dict, step=trainer.epoch) + mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch) def on_train_end(trainer): - """Called at end of train loop to log model artifact info.""" + """Log model artifacts at the end of the training.""" if mlflow: - run.log_artifact(trainer.last) - run.log_artifact(trainer.best) - run.pyfunc.log_model(artifact_path=experiment_name, - code_path=[str(ROOT.parent)], - artifacts={'model_path': str(trainer.save_dir)}, - python_model=run.pyfunc.PythonModel()) + mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt + for f in trainer.save_dir.glob("*"): # log all other files in save_dir + if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}: + mlflow.log_artifact(str(f)) + keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() in ("true") + if keep_run_active: + LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()") + else: + mlflow.end_run() + LOGGER.debug(f"{PREFIX}mlflow run ended") + + LOGGER.info( + f"{PREFIX}results 
logged to {mlflow.get_tracking_uri()}\n" + f"{PREFIX}disable with 'yolo settings mlflow=False'" + ) -callbacks = { - 'on_pretrain_routine_end': on_pretrain_routine_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_end': on_train_end} if mlflow else {} +callbacks = ( + { + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if mlflow + else {} +) diff --git a/ultralytics/utils/callbacks/neptune.py b/ultralytics/utils/callbacks/neptune.py index 40916a3..6be8a82 100644 --- a/ultralytics/utils/callbacks/neptune.py +++ b/ultralytics/utils/callbacks/neptune.py @@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['neptune'] is True # verify integration is enabled + assert SETTINGS["neptune"] is True # verify integration is enabled import neptune from neptune.types import File - assert hasattr(neptune, '__version__') + assert hasattr(neptune, "__version__") run = None # NeptuneAI experiment logger instance @@ -23,55 +23,55 @@ def _log_scalars(scalars, step=0): run[k].append(value=v, step=step) -def _log_images(imgs_dict, group=''): +def _log_images(imgs_dict, group=""): """Log scalars to the NeptuneAI experiment logger.""" if run: for k, v in imgs_dict.items(): - run[f'{group}/{k}'].upload(File(v)) + run[f"{group}/{k}"].upload(File(v)) def _log_plot(title, plot_path): - """Log plots to the NeptuneAI experiment logger.""" """ - Log image as plot in the plot section of NeptuneAI + Log plots to the NeptuneAI experiment logger. - arguments: - title (str) Title of the plot - plot_path (PosixPath or str) Path to the saved image file - """ + Args: + title (str): Title of the plot. + plot_path (PosixPath | str): Path to the saved image file. + """ import matplotlib.image as mpimg import matplotlib.pyplot as plt img = mpimg.imread(plot_path) fig = plt.figure() - ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks ax.imshow(img) - run[f'Plots/{title}'].upload(fig) + run[f"Plots/{title}"].upload(fig) def on_pretrain_routine_start(trainer): """Callback function called before the training routine starts.""" try: global run - run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8']) - run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()} + run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"]) + run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()} except Exception as e: - LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. 
{e}") def on_train_epoch_end(trainer): """Callback function called at end of each training epoch.""" - _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) _log_scalars(trainer.lr, trainer.epoch + 1) if trainer.epoch == 1: - _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic') + _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic") def on_fit_epoch_end(trainer): """Callback function called at end of each fit (train+val) epoch.""" if run and trainer.epoch == 0: from ultralytics.utils.torch_utils import model_info_for_loggers - run['Configuration/Model'] = model_info_for_loggers(trainer) + + run["Configuration/Model"] = model_info_for_loggers(trainer) _log_scalars(trainer.metrics, trainer.epoch + 1) @@ -79,7 +79,7 @@ def on_val_end(validator): """Callback function called at end of each validation.""" if run: # Log val_labels and val_pred - _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation') + _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation") def on_train_end(trainer): @@ -87,19 +87,26 @@ def on_train_end(trainer): if run: # Log final results, CM matrix + PR plots files = [ - 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png', - *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] + "results.png", + "confusion_matrix.png", + "confusion_matrix_normalized.png", + *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")), + ] files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter for f in files: _log_plot(title=f.stem, plot_path=f) # Log the final model - run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str( - trainer.best))) + run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best))) -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_epoch_end': on_train_epoch_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_val_end': on_val_end, - 'on_train_end': on_train_end} if neptune else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_val_end": on_val_end, + "on_train_end": on_train_end, + } + if neptune + else {} +) diff --git a/ultralytics/utils/callbacks/raytune.py b/ultralytics/utils/callbacks/raytune.py index 417b331..f269455 100644 --- a/ultralytics/utils/callbacks/raytune.py +++ b/ultralytics/utils/callbacks/raytune.py @@ -3,7 +3,7 @@ from ultralytics.utils import SETTINGS try: - assert SETTINGS['raytune'] is True # verify integration is enabled + assert SETTINGS["raytune"] is True # verify integration is enabled import ray from ray import tune from ray.air import session @@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer): """Sends training metrics to Ray Tune at end of each epoch.""" if ray.tune.is_session_enabled(): metrics = trainer.metrics - metrics['epoch'] = trainer.epoch + metrics["epoch"] = trainer.epoch session.report(metrics) -callbacks = { - 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {} +callbacks = ( + { + "on_fit_epoch_end": on_fit_epoch_end, + } + if tune + else {} +) diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py index c1fce53..59024ee 100644 --- 
a/ultralytics/utils/callbacks/tensorboard.py +++ b/ultralytics/utils/callbacks/tensorboard.py @@ -1,17 +1,25 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +import contextlib from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr try: - # WARNING: do not move import due to protobuf issue in https://github.com/ultralytics/ultralytics/pull/4674 + # WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674 from torch.utils.tensorboard import SummaryWriter assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['tensorboard'] is True # verify integration is enabled + assert SETTINGS["tensorboard"] is True # verify integration is enabled WRITER = None # TensorBoard SummaryWriter instance + PREFIX = colorstr("TensorBoard: ") -except (ImportError, AssertionError, TypeError): + # Imports below only required if TensorBoard enabled + import warnings + from copy import deepcopy + from ultralytics.utils.torch_utils import de_parallel, torch + +except (ImportError, AssertionError, TypeError, AttributeError): # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows + # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed SummaryWriter = None @@ -24,20 +32,38 @@ def _log_scalars(scalars, step=0): def _log_tensorboard_graph(trainer): """Log model graph to TensorBoard.""" - try: - import warnings - from ultralytics.utils.torch_utils import de_parallel, torch + # Input image + imgsz = trainer.args.imgsz + imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz + p = next(trainer.model.parameters()) # for device, type + im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty) - imgsz = trainer.args.imgsz - imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz - p = next(trainer.model.parameters()) # for device, type - im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning + warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning + + # Try simple method first (YOLO) + with contextlib.suppress(Exception): + trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), []) - except Exception as e: - LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}') + LOGGER.info(f"{PREFIX}model graph visualization added ✅") + return + + # Fallback to TorchScript export steps (RTDETR) + try: + model = deepcopy(de_parallel(trainer.model)) + model.eval() + model = model.fuse(verbose=False) + for m in model.modules(): + if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class) + m.export = True + m.format = "torchscript" + model(im) # dry run + WRITER.add_graph(torch.jit.trace(model, im, strict=False), []) + LOGGER.info(f"{PREFIX}model graph visualization added ✅") + except Exception as e: + LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}") def on_pretrain_routine_start(trainer): @@ -46,10 +72,9 @@ def on_pretrain_routine_start(trainer): try: global WRITER WRITER = SummaryWriter(str(trainer.save_dir)) - prefix = 
colorstr('TensorBoard: ') - LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") + LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") except Exception as e: - LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}") def on_train_start(trainer): @@ -58,9 +83,10 @@ def on_train_start(trainer): _log_tensorboard_graph(trainer) -def on_batch_end(trainer): - """Logs scalar statistics at the end of a training batch.""" - _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) +def on_train_epoch_end(trainer): + """Logs scalar statistics at the end of a training epoch.""" + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) + _log_scalars(trainer.lr, trainer.epoch + 1) def on_fit_epoch_end(trainer): @@ -68,8 +94,13 @@ def on_fit_epoch_end(trainer): _log_scalars(trainer.metrics, trainer.epoch + 1) -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_start': on_train_start, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_batch_end': on_batch_end} if SummaryWriter else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_start": on_train_start, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_epoch_end": on_train_epoch_end, + } + if SummaryWriter + else {} +) diff --git a/ultralytics/utils/callbacks/wb.py b/ultralytics/utils/callbacks/wb.py index 27b3874..25a1b64 100644 --- a/ultralytics/utils/callbacks/wb.py +++ b/ultralytics/utils/callbacks/wb.py @@ -5,10 +5,13 @@ from ultralytics.utils.torch_utils import model_info_for_loggers try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['wandb'] is True # verify integration is enabled + assert SETTINGS["wandb"] is True # verify integration is enabled import wandb as wb - assert hasattr(wb, '__version__') + assert hasattr(wb, "__version__") # verify package is not directory + + import numpy as np + import pandas as pd _processed_plots = {} @@ -16,9 +19,89 @@ except (ImportError, AssertionError): wb = None +def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"): + """ + Create and log a custom metric visualization to wandb.plot.pr_curve. + + This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall + curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across + different classes. + + Args: + x (List): Values for the x-axis; expected to have length N. + y (List): Corresponding values for the y-axis; also expected to have length N. + classes (List): Labels identifying the class of each point; length N. + title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'. + x_title (str, optional): Label for the x-axis; defaults to 'Recall'. + y_title (str, optional): Label for the y-axis; defaults to 'Precision'. + + Returns: + (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization. 
+ """ + df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3) + fields = {"x": "x", "y": "y", "class": "class"} + string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title} + return wb.plot_table( + "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields + ) + + +def _plot_curve( + x, + y, + names=None, + id="precision-recall", + title="Precision Recall Curve", + x_title="Recall", + y_title="Precision", + num_x=100, + only_mean=False, +): + """ + Log a metric curve visualization. + + This function generates a metric curve based on input data and logs the visualization to wandb. + The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag. + + Args: + x (np.ndarray): Data points for the x-axis with length N. + y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes. + names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to []. + id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'. + title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'. + x_title (str, optional): Label for the x-axis. Defaults to 'Recall'. + y_title (str, optional): Label for the y-axis. Defaults to 'Precision'. + num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100. + only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to True. + + Note: + The function leverages the '_custom_table' function to generate the actual visualization. + """ + # Create new x + if names is None: + names = [] + x_new = np.linspace(x[0], x[-1], num_x).round(5) + + # Create arrays for logging + x_log = x_new.tolist() + y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist() + + if only_mean: + table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title]) + wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)}) + else: + classes = ["mean"] * len(x_log) + for i, yi in enumerate(y): + x_log.extend(x_new) # add new x + y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x + classes.extend([names[i]] * len(x_new)) # add class names + wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False) + + def _log_plots(plots, step): + """Logs plots from the input dictionary if they haven't been logged already at the specified step.""" for name, params in plots.items(): - timestamp = params['timestamp'] + timestamp = params["timestamp"] if _processed_plots.get(name) != timestamp: wb.run.log({name.stem: wb.Image(str(name))}, step=step) _processed_plots[name] = timestamp @@ -26,7 +109,7 @@ def _log_plots(plots, step): def on_pretrain_routine_start(trainer): """Initiate and start project if module is present.""" - wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args)) + wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args)) def on_fit_epoch_end(trainer): @@ -40,7 +123,7 @@ def on_fit_epoch_end(trainer): def on_train_epoch_end(trainer): """Log metrics and save images at the end of each training epoch.""" - wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) + wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1) 
wb.run.log(trainer.lr, step=trainer.epoch + 1) if trainer.epoch == 1: _log_plots(trainer.plots, step=trainer.epoch + 1) @@ -50,14 +133,31 @@ def on_train_end(trainer): """Save the best model as an artifact at end of training.""" _log_plots(trainer.validator.plots, step=trainer.epoch + 1) _log_plots(trainer.plots, step=trainer.epoch + 1) - art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model') + art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model") if trainer.best.exists(): art.add_file(trainer.best) - wb.run.log_artifact(art, aliases=['best']) + wb.run.log_artifact(art, aliases=["best"]) + for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results): + x, y, x_title, y_title = curve_values + _plot_curve( + x, + y, + names=list(trainer.validator.metrics.names.values()), + id=f"curves/{curve_name}", + title=curve_name, + x_title=x_title, + y_title=y_title, + ) + wb.run.finish() # required or run continues on dashboard -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_epoch_end': on_train_epoch_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_end': on_train_end} if wb else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if wb + else {} +) diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index 28cad0d..c44ac0b 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -1,4 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license + import contextlib import glob import inspect @@ -9,20 +10,96 @@ import re import shutil import subprocess import time +from importlib import metadata from pathlib import Path from typing import Optional import cv2 import numpy as np -import pkg_resources as pkg -import psutil import requests import torch from matplotlib import font_manager -from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, ThreadingLocked, - TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker, is_jupyter, - is_kaggle, is_online, is_pip_package, url2file) +from ultralytics.utils import ( + ASSETS, + AUTOINSTALL, + LINUX, + LOGGER, + ONLINE, + ROOT, + USER_CONFIG_DIR, + SimpleNamespace, + ThreadingLocked, + TryExcept, + clean_url, + colorstr, + downloads, + emojis, + is_colab, + is_docker, + is_github_action_running, + is_jupyter, + is_kaggle, + is_online, + is_pip_package, + url2file, +) + +PYTHON_VERSION = platform.python_version() + + +def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""): + """ + Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'. + + Args: + file_path (Path): Path to the requirements.txt file. + package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'. + + Returns: + (List[Dict[str, str]]): List of parsed requirements as dictionaries with `name` and `specifier` keys. 
+ + Example: + ```python + from ultralytics.utils.checks import parse_requirements + + parse_requirements(package='ultralytics') + ``` + """ + + if package: + requires = [x for x in metadata.distribution(package).requires if "extra == " not in x] + else: + requires = Path(file_path).read_text().splitlines() + + requirements = [] + for line in requires: + line = line.strip() + if line and not line.startswith("#"): + line = line.split("#")[0].strip() # ignore inline comments + match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line) + if match: + requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else "")) + + return requirements + + +def parse_version(version="0.0.0") -> tuple: + """ + Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This + function replaces deprecated 'pkg_resources.parse_version(v)'. + + Args: + version (str): Version string, i.e. '2.0.1+cpu' + + Returns: + (tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1) + """ + try: + return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1) + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}") + return 0, 0, 0 def is_ascii(s) -> bool: @@ -33,7 +110,7 @@ def is_ascii(s) -> bool: s (str): String to be checked. Returns: - bool: True if the string is composed only of ASCII characters, False otherwise. + (bool): True if the string is composed only of ASCII characters, False otherwise. """ # Convert list, tuple, None, etc. to string s = str(s) @@ -65,16 +142,22 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): imgsz = [imgsz] elif isinstance(imgsz, (list, tuple)): imgsz = list(imgsz) + elif isinstance(imgsz, str): # i.e. '640' or '[640,640]' + imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz) else: - raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. " - f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'") + raise TypeError( + f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. " + f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'" + ) # Apply max_dim if len(imgsz) > max_dim: - msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \ - "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'" + msg = ( + "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " + "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'" + ) if max_dim != 1: - raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}') + raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}") LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. 
{msg}") imgsz = [max(imgsz)] # Make image size a multiple of the stride @@ -82,7 +165,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): # Print warning message if image size was updated if sz != imgsz: - LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}') + LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}") # Add missing dimensions if necessary sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz @@ -90,66 +173,88 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): return sz -def check_version(current: str = '0.0.0', - required: str = '0.0.0', - name: str = 'version ', - hard: bool = False, - verbose: bool = False) -> bool: +def check_version( + current: str = "0.0.0", + required: str = "0.0.0", + name: str = "version", + hard: bool = False, + verbose: bool = False, + msg: str = "", +) -> bool: """ Check current version against the required version or range. Args: - current (str): Current version. + current (str): Current version or package name to get version from. required (str): Required version or range (in pip-style format). - name (str): Name to be used in warning message. - hard (bool): If True, raise an AssertionError if the requirement is not met. - verbose (bool): If True, print warning message if requirement is not met. + name (str, optional): Name to be used in warning message. + hard (bool, optional): If True, raise an AssertionError if the requirement is not met. + verbose (bool, optional): If True, print warning message if requirement is not met. + msg (str, optional): Extra message to display if verbose. Returns: (bool): True if requirement is met, False otherwise. Example: - # check if current version is exactly 22.04 + ```python + # Check if current version is exactly 22.04 check_version(current='22.04', required='==22.04') - # check if current version is greater than or equal to 22.04 + # Check if current version is greater than or equal to 22.04 check_version(current='22.10', required='22.04') # assumes '>=' inequality if none passed - # check if current version is less than or equal to 22.04 + # Check if current version is less than or equal to 22.04 check_version(current='22.04', required='<=22.04') - # check if current version is between 20.04 (inclusive) and 22.04 (exclusive) + # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) check_version(current='21.10', required='>20.04,<22.04') + ``` """ - current = pkg.parse_version(current) - constraints = re.findall(r'([<>!=]{1,2}\s*\d+\.\d+)', required) or [f'>={required}'] + if not current: # if current is '' or None + LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.") + return True + elif not current[0].isdigit(): # current is package name rather than version string, i.e. 
current='ultralytics' + try: + name = current # assigned package name to 'name' arg + current = metadata.version(current) # get version string from package name + except metadata.PackageNotFoundError as e: + if hard: + raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e + else: + return False + if not required: # if required is '' or None + return True + + op = "" + version = "" result = True - for constraint in constraints: - op, version = re.match(r'([<>!=]{1,2})\s*(\d+\.\d+)', constraint).groups() - version = pkg.parse_version(version) - if op == '==' and current != version: + c = parse_version(current) # '1.2.3' -> (1, 2, 3) + for r in required.strip(",").split(","): + op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04') + v = parse_version(version) # '1.2.3' -> (1, 2, 3) + if op == "==" and c != v: result = False - elif op == '!=' and current == version: + elif op == "!=" and c == v: result = False - elif op == '>=' and not (current >= version): + elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required' result = False - elif op == '<=' and not (current <= version): + elif op == "<=" and not (c <= v): result = False - elif op == '>' and not (current > version): + elif op == ">" and not (c > v): result = False - elif op == '<' and not (current < version): + elif op == "<" and not (c < v): result = False if not result: - warning_message = f'WARNING ⚠️ {name}{required} is required, but {name}{current} is currently installed' + warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}" if hard: - raise ModuleNotFoundError(emojis(warning_message)) # assert version requirements met + raise ModuleNotFoundError(emojis(warning)) # assert version requirements met if verbose: - LOGGER.warning(warning_message) + LOGGER.warning(warning) return result -def check_latest_pypi_version(package_name='ultralytics'): +def check_latest_pypi_version(package_name="ultralytics"): """ Returns the latest version of a PyPI package without downloading or installing it. @@ -161,9 +266,9 @@ def check_latest_pypi_version(package_name='ultralytics'): """ with contextlib.suppress(Exception): requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning - response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3) + response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3) if response.status_code == 200: - return response.json()['info']['version'] + return response.json()["info"]["version"] def check_pip_update_available(): @@ -176,16 +281,19 @@ def check_pip_update_available(): if ONLINE and is_pip_package(): with contextlib.suppress(Exception): from ultralytics import __version__ + latest = check_latest_pypi_version() - if pkg.parse_version(__version__) < pkg.parse_version(latest): # update is available - LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 ' - f"Update with 'pip install -U ultralytics'") + if check_version(__version__, f"<{latest}"): # check if current version is < latest version + LOGGER.info( + f"New https://pypi.org/project/ultralytics/{latest} available 😃 " + f"Update with 'pip install -U ultralytics'" + ) return True return False @ThreadingLocked() -def check_font(font='Arial.ttf'): +def check_font(font="Arial.ttf"): """ Find font locally or download to user's configuration directory if it does not already exist. 
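Note: the check_version() rewrite above drops pkg_resources in favour of parse_version() tuples. A minimal sketch of that comparison, using only the names shown in this hunk (illustrative, not part of the applied patch):

```python
import re

def parse_version(version="0.0.0") -> tuple:
    """Reduce a version string such as '2.0.1+cpu' to its numeric parts -> (2, 0, 1)."""
    return tuple(map(int, re.findall(r"\d+", version)[:3]))

# Ranges like '>20.04,<22.04' are split on ',' and each constraint compared as integer tuples
current = parse_version("21.10")
assert current > parse_version("20.04") and current < parse_version("22.04")
```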
@@ -208,13 +316,13 @@ def check_font(font='Arial.ttf'): return matches[0] # Download to USER_CONFIG_DIR if missing - url = f'https://ultralytics.com/assets/{name}' - if downloads.is_url(url): + url = f"https://ultralytics.com/assets/{name}" + if downloads.is_url(url, check=True): downloads.safe_download(url=url, file=file) return file -def check_python(minimum: str = '3.8.0') -> bool: +def check_python(minimum: str = "3.8.0") -> bool: """ Check current python version against the required minimum version. @@ -222,13 +330,13 @@ def check_python(minimum: str = '3.8.0') -> bool: minimum (str): Required minimum version of python. Returns: - None + (bool): Whether the installed Python version meets the minimum constraints. """ - return check_version(platform.python_version(), minimum, name='Python ', hard=True) + return check_version(PYTHON_VERSION, minimum, name="Python ", hard=True) @TryExcept() -def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''): +def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""): """ Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed. @@ -253,46 +361,43 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=() check_requirements(['numpy', 'ultralytics>=8.0.0']) ``` """ - prefix = colorstr('red', 'bold', 'requirements:') + + prefix = colorstr("red", "bold", "requirements:") check_python() # check python version check_torchvision() # check torch-torchvision compatibility if isinstance(requirements, Path): # requirements.txt file file = requirements.resolve() - assert file.exists(), f'{prefix} {file} not found, check failed.' - with file.open() as f: - requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude] + assert file.exists(), f"{prefix} {file} not found, check failed." 
+ requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude] elif isinstance(requirements, str): requirements = [requirements] pkgs = [] for r in requirements: - r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo' + r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo' + match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped) + name, required = match[1], match[2].strip() if match[2] else "" try: - pkg.require(r_stripped) # exception if requirements not met - except pkg.DistributionNotFound: - try: # attempt to import (slower but more accurate) - import importlib - importlib.import_module(next(pkg.parse_requirements(r_stripped)).name) - except ImportError: - pkgs.append(r) - except pkg.VersionConflict: + assert check_version(metadata.version(name), required) # exception if requirements not met + except (AssertionError, metadata.PackageNotFoundError): pkgs.append(r) - s = ' '.join(f'"{x}"' for x in pkgs) # console string + s = " ".join(f'"{x}"' for x in pkgs) # console string if s: if install and AUTOINSTALL: # check environment variable n = len(pkgs) # number of packages updates LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...") try: t = time.time() - assert is_online(), 'AutoUpdate skipped (offline)' - LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode()) + assert is_online(), "AutoUpdate skipped (offline)" + LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode()) dt = time.time() - t LOGGER.info( f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n" - f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n") + f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" + ) except Exception as e: - LOGGER.warning(f'{prefix} ❌ {e}') + LOGGER.warning(f"{prefix} ❌ {e}") return False else: return False @@ -305,134 +410,211 @@ def check_torchvision(): Checks the installed versions of PyTorch and Torchvision to ensure they're compatible. This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according - to the provided compatibility table based on https://github.com/pytorch/vision#installation. The - compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible + to the provided compatibility table based on: + https://github.com/pytorch/vision#installation. + + The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible Torchvision versions. 
""" import torchvision # Compatibility table - compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']} + compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]} # Extract only the major and minor versions - v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2]) - v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2]) + v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2]) + v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2]) if v_torch in compatibility_table: compatible_versions = compatibility_table[v_torch] - if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions): - print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n' - f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or " - "'pip install -U torch torchvision' to update both.\n" - 'For a full compatibility table see https://github.com/pytorch/vision#installation') + if all(v_torchvision != v for v in compatible_versions): + print( + f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n" + f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or " + "'pip install -U torch torchvision' to update both.\n" + "For a full compatibility table see https://github.com/pytorch/vision#installation" + ) -def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''): +def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""): """Check file(s) for acceptable suffix.""" if file and suffix: if isinstance(suffix, str): - suffix = (suffix, ) + suffix = (suffix,) for f in file if isinstance(file, (list, tuple)) else [file]: s = Path(f).suffix.lower().strip() # file suffix if len(s): - assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}' + assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}" def check_yolov5u_filename(file: str, verbose: bool = True): """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.""" - if 'yolov3' in file or 'yolov5' in file: - if 'u.yaml' in file: - file = file.replace('u.yaml', '.yaml') # i.e. yolov5nu.yaml -> yolov5n.yaml - elif '.pt' in file and 'u' not in file: + if "yolov3" in file or "yolov5" in file: + if "u.yaml" in file: + file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml + elif ".pt" in file and "u" not in file: original_file = file - file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file) # i.e. yolov5n.pt -> yolov5nu.pt - file = re.sub(r'(.*yolov5([nsmlx])6)\.pt', '\\1u.pt', file) # i.e. yolov5n6.pt -> yolov5n6u.pt - file = re.sub(r'(.*yolov3(|-tiny|-spp))\.pt', '\\1u.pt', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt + file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt + file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt + file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. 
yolov3-spp.pt -> yolov3-sppu.pt if file != original_file and verbose: LOGGER.info( f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " - f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs ' - f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n') + f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs " + f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n" + ) return file -def check_file(file, suffix='', download=True, hard=True): +def check_model_file_from_stem(model="yolov8n"): + """Return a model filename from a valid model stem.""" + if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS: + return Path(model).with_suffix(".pt") # add suffix, i.e. yolov8n -> yolov8n.pt + else: + return model + + +def check_file(file, suffix="", download=True, hard=True): """Search/download file (if necessary) and return path.""" check_suffix(file, suffix) # optional file = str(file).strip() # convert to string and strip spaces file = check_yolov5u_filename(file) # yolov5n -> yolov5nu - if not file or ('://' not in file and Path(file).exists()): # exists ('://' check required in Windows Python<3.10) + if ( + not file + or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10 + or file.lower().startswith("grpc://") + ): # file exists or gRPC Triton images return file - elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download + elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download url = file # warning: Pathlib turns :// -> :/ file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth if Path(file).exists(): - LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists else: downloads.safe_download(url=url, file=file, unzip=False) return file else: # search - files = glob.glob(str(ROOT / 'cfg' / '**' / file), recursive=True) # find file + files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file if not files and hard: raise FileNotFoundError(f"'{file}' does not exist") elif len(files) > 1 and hard: raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") - return files[0] if len(files) else [] # return file + return files[0] if len(files) else [] if hard else file # return file -def check_yaml(file, suffix=('.yaml', '.yml'), hard=True): +def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): """Search/download YAML file (if necessary) and return path, checking suffix.""" return check_file(file, suffix, hard=hard) +def check_is_path_safe(basedir, path): + """ + Check if the resolved path is under the intended directory to prevent path traversal. + + Args: + basedir (Path | str): The intended directory. + path (Path | str): The path to check. + + Returns: + (bool): True if the path is safe, False otherwise. 
+ """ + base_dir_resolved = Path(basedir).resolve() + path_resolved = Path(path).resolve() + + return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts + + def check_imshow(warn=False): """Check if environment supports image displays.""" try: if LINUX: - assert 'DISPLAY' in os.environ and not is_docker() and not is_colab() and not is_kaggle() - cv2.imshow('test', np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image + assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle() + cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image cv2.waitKey(1) cv2.destroyAllWindows() cv2.waitKey(1) return True except Exception as e: if warn: - LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}') + LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}") return False -def check_yolo(verbose=True, device=''): +def check_yolo(verbose=True, device=""): """Return a human-readable YOLO software and hardware summary.""" + import psutil + from ultralytics.utils.torch_utils import select_device if is_jupyter(): - if check_requirements('wandb', install=False): - os.system('pip uninstall -y wandb') # uninstall wandb: unwanted account creation prompt with infinite hang + if check_requirements("wandb", install=False): + os.system("pip uninstall -y wandb") # uninstall wandb: unwanted account creation prompt with infinite hang if is_colab(): - shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory + shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory if verbose: # System info gib = 1 << 30 # bytes per GiB ram = psutil.virtual_memory().total - total, used, free = shutil.disk_usage('/') - s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)' + total, used, free = shutil.disk_usage("/") + s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)" with contextlib.suppress(Exception): # clear display if ipython is installed from IPython import display + display.clear_output() else: - s = '' + s = "" select_device(device=device, newline=False) - LOGGER.info(f'Setup complete ✅ {s}') + LOGGER.info(f"Setup complete ✅ {s}") + + +def collect_system_info(): + """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.""" + + import psutil + + from ultralytics.utils import ENVIRONMENT, is_git_dir + from ultralytics.utils.torch_utils import get_cpu_info + + ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB + check_yolo() + LOGGER.info( + f"\n{'OS':<20}{platform.platform()}\n" + f"{'Environment':<20}{ENVIRONMENT}\n" + f"{'Python':<20}{PYTHON_VERSION}\n" + f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n" + f"{'RAM':<20}{ram_info:.2f} GB\n" + f"{'CPU':<20}{get_cpu_info()}\n" + f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n" + ) + + for r in parse_requirements(package="ultralytics"): + try: + current = metadata.version(r.name) + is_met = "✅ " if check_version(current, str(r.specifier), hard=True) else "❌ " + except metadata.PackageNotFoundError: + current = "(not installed)" + is_met = "❌ " + LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}") + + if is_github_action_running(): + LOGGER.info( + f"\nRUNNER_OS: 
{os.getenv('RUNNER_OS')}\n" + f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n" + f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n" + f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n" + f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n" + f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n" + ) def check_amp(model): """ - This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. - If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP - results, so AMP will be disabled during training. + This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks + fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will + be disabled during training. Args: model (nn.Module): A YOLOv8 model instance. @@ -450,7 +632,7 @@ def check_amp(model): (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False. """ device = next(model.parameters()).device # get model device - if device.type in ('cpu', 'mps'): + if device.type in ("cpu", "mps"): return False # AMP only used on CUDA devices def amp_allclose(m, im): @@ -461,23 +643,27 @@ def check_amp(model): del m return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance - im = ASSETS / 'bus.jpg' # image to check - prefix = colorstr('AMP: ') - LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...') + im = ASSETS / "bus.jpg" # image to check + prefix = colorstr("AMP: ") + LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...") warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False." try: from ultralytics import YOLO - assert amp_allclose(YOLO('yolov8n.pt'), im) - LOGGER.info(f'{prefix}checks passed ✅') + + assert amp_allclose(YOLO("yolov8n.pt"), im) + LOGGER.info(f"{prefix}checks passed ✅") except ConnectionError: - LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}') + LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}") except (AttributeError, ModuleNotFoundError): LOGGER.warning( - f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}' + f"{prefix}checks skipped ⚠️. " + f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}" ) except AssertionError: - LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to ' - f'NaN losses or zero-mAP results, so AMP will be disabled during training.') + LOGGER.warning( + f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to " + f"NaN losses or zero-mAP results, so AMP will be disabled during training." + ) return False return True @@ -485,8 +671,8 @@ def check_amp(model): def git_describe(path=ROOT): # path must be a directory """Return human-readable git description, i.e. 
v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.""" with contextlib.suppress(Exception): - return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1] - return '' + return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1] + return "" def print_args(args: Optional[dict] = None, show_file=True, show_func=False): @@ -494,7 +680,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False): def strip_auth(v): """Clean longer Ultralytics HUB URLs by stripping potential authentication information.""" - return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v + return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v x = inspect.currentframe().f_back # previous frame file, _, func, _, _ = inspect.getframeinfo(x) @@ -502,26 +688,28 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False): args, _, _, frm = inspect.getargvalues(x) args = {k: v for k, v in frm.items() if k in args} try: - file = Path(file).resolve().relative_to(ROOT).with_suffix('') + file = Path(file).resolve().relative_to(ROOT).with_suffix("") except ValueError: file = Path(file).stem - s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '') - LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items())) + s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "") + LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())) def cuda_device_count() -> int: - """Get the number of NVIDIA GPUs available in the environment. + """ + Get the number of NVIDIA GPUs available in the environment. Returns: (int): The number of NVIDIA GPUs available. """ try: # Run the nvidia-smi command and capture its output - output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'], - encoding='utf-8') + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8" + ) # Take the first line and strip any leading/trailing white space - first_line = output.strip().split('\n')[0] + first_line = output.strip().split("\n")[0] return int(first_line) except (subprocess.CalledProcessError, FileNotFoundError, ValueError): @@ -530,9 +718,14 @@ def cuda_device_count() -> int: def cuda_is_available() -> bool: - """Check if CUDA is available in the environment. + """ + Check if CUDA is available in the environment. Returns: (bool): True if one or more NVIDIA GPUs are available, False otherwise. """ return cuda_device_count() > 0 + + +# Define constants +IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12") diff --git a/ultralytics/utils/dist.py b/ultralytics/utils/dist.py index 1190098..b669e52 100644 --- a/ultralytics/utils/dist.py +++ b/ultralytics/utils/dist.py @@ -1,47 +1,53 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import os -import re import shutil import socket import sys import tempfile -from pathlib import Path from . import USER_CONFIG_DIR from .torch_utils import TORCH_1_9 def find_free_network_port() -> int: - """Finds a free port on localhost. + """ + Finds a free port on localhost. It is useful in single-node training when we don't want to connect to a real main node but have to set the `MASTER_PORT` environment variable. 
""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('127.0.0.1', 0)) + s.bind(("127.0.0.1", 0)) return s.getsockname()[1] # port def generate_ddp_file(trainer): """Generates a DDP file and returns its file name.""" - module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1) + module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1) - content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__": + content = f""" +# Ultralytics Multi-GPU training temp file (should be automatically deleted after use) +overrides = {vars(trainer.args)} + +if __name__ == "__main__": from {module} import {name} from ultralytics.utils import DEFAULT_CFG_DICT cfg = DEFAULT_CFG_DICT.copy() cfg.update(save_dir='') # handle the extra key 'save_dir' trainer = {name}(cfg=cfg, overrides=overrides) - trainer.train()''' - (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) - with tempfile.NamedTemporaryFile(prefix='_temp_', - suffix=f'{id(trainer)}.py', - mode='w+', - encoding='utf-8', - dir=USER_CONFIG_DIR / 'DDP', - delete=False) as file: + results = trainer.train() +""" + (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True) + with tempfile.NamedTemporaryFile( + prefix="_temp_", + suffix=f"{id(trainer)}.py", + mode="w+", + encoding="utf-8", + dir=USER_CONFIG_DIR / "DDP", + delete=False, + ) as file: file.write(content) return file.name @@ -49,19 +55,17 @@ def generate_ddp_file(trainer): def generate_ddp_command(world_size, trainer): """Generates and returns command for distributed training.""" import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 + if not trainer.resume: shutil.rmtree(trainer.save_dir) # remove the save_dir - file = str(Path(sys.argv[0]).resolve()) - safe_pattern = re.compile(r'^[a-zA-Z0-9_. 
/\\-]{1,128}$') # allowed characters and maximum of 100 characters - if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI - file = generate_ddp_file(trainer) - dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch' + file = generate_ddp_file(trainer) + dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch" port = find_free_network_port() - cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file] return cmd, file def ddp_cleanup(trainer, file): """Delete temp file if created.""" - if f'{id(trainer)}.py' in file: # if temp_file suffix in file + if f"{id(trainer)}.py" in file: # if temp_file suffix in file os.remove(file) diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py index 6e310bf..6191ade 100644 --- a/ultralytics/utils/downloads.py +++ b/ultralytics/utils/downloads.py @@ -15,20 +15,42 @@ import torch from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file # Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets -GITHUB_ASSETS_REPO = 'ultralytics/assets' -GITHUB_ASSETS_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '6', '-cls', '-seg', '-pose')] + \ - [f'yolov5{k}u.pt' for k in 'nsmlx'] + \ - [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \ - [f'yolo_nas_{k}.pt' for k in 'sml'] + \ - [f'sam_{k}.pt' for k in 'bl'] + \ - [f'FastSAM-{k}.pt' for k in 'sx'] + \ - [f'rtdetr-{k}.pt' for k in 'lx'] + \ - ['mobile_sam.pt'] +GITHUB_ASSETS_REPO = "ultralytics/assets" +GITHUB_ASSETS_NAMES = ( + [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] + + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")] + + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")] + + [f"yolov8{k}-world.pt" for k in "smlx"] + + [f"yolov8{k}-worldv2.pt" for k in "smlx"] + + [f"yolov9{k}.pt" for k in "ce"] + + [f"yolo_nas_{k}.pt" for k in "sml"] + + [f"sam_{k}.pt" for k in "bl"] + + [f"FastSAM-{k}.pt" for k in "sx"] + + [f"rtdetr-{k}.pt" for k in "lx"] + + ["mobile_sam.pt"] + + ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"] +) GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES] -def is_url(url, check=True): - """Check if string is URL and check if URL exists.""" +def is_url(url, check=False): + """ + Validates if the given string is a URL and optionally checks if the URL exists online. + + Args: + url (str): The string to be validated as a URL. + check (bool, optional): If True, performs an additional check to see if the URL exists online. + Defaults to True. + + Returns: + (bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online. + Returns False otherwise. + + Example: + ```python + valid = is_url("https://www.example.com") + ``` + """ with contextlib.suppress(Exception): url = str(url) result = parse.urlparse(url) @@ -40,7 +62,7 @@ def is_url(url, check=True): return False -def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')): +def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")): """ Deletes all ".DS_store" files under a specified directory. 
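Note: is_url() above changes its default from check=True to check=False, so the online existence check is now opt-in (check_font() in this diff passes check=True explicitly). A hedged usage sketch, assuming the ultralytics package from this diff is importable:

```python
from ultralytics.utils.downloads import is_url

is_url("https://ultralytics.com/assets/bus.jpg")              # syntactic URL check only (new default)
is_url("https://ultralytics.com/assets/bus.jpg", check=True)  # additionally verifies the URL responds
```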
@@ -59,18 +81,17 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')): ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They are hidden system files and can cause issues when transferring files between different operating systems. """ - # Delete Apple .DS_store files for file in files_to_delete: matches = list(Path(path).rglob(file)) - LOGGER.info(f'Deleting {file} files: {matches}') + LOGGER.info(f"Deleting {file} files: {matches}") for f in matches: f.unlink() -def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True): +def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True): """ - Zips the contents of a directory, excluding files containing strings in the exclude list. - The resulting zip file is named after the directory and placed alongside it. + Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is + named after the directory and placed alongside it. Args: directory (str | Path): The path to the directory to be zipped. @@ -96,17 +117,17 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p raise FileNotFoundError(f"Directory '{directory}' does not exist.") # Unzip with progress bar - files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)] - zip_file = directory.with_suffix('.zip') + files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] + zip_file = directory.with_suffix(".zip") compression = ZIP_DEFLATED if compress else ZIP_STORED - with ZipFile(zip_file, 'w', compression) as f: - for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress): + with ZipFile(zip_file, "w", compression) as f: + for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress): f.write(file, file.relative_to(directory)) return zip_file # return path to zip file -def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True): +def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True): """ Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list. @@ -146,51 +167,62 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)] top_level_dirs = {Path(f).parts[0] for f in files} - if len(top_level_dirs) > 1 or not files[0].endswith('/'): # zip has multiple files at top level + if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")): + # Zip has multiple files at top level path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8 - else: # zip has 1 top-level directory + else: + # Zip has 1 top-level directory extract_path = path # i.e. ../datasets path = Path(path) / list(top_level_dirs)[0] # i.e. 
../datasets/coco8 # Check if destination directory already exists and contains files if path.exists() and any(path.iterdir()) and not exist_ok: # If it exists and is not empty, return the path without unzipping - LOGGER.warning(f'WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.') + LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.") return path - for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress): - zipObj.extract(f, path=extract_path) + for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress): + # Ensure the file is within the extract_path to avoid path traversal security vulnerability + if ".." in Path(f).parts: + LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.") + continue + zipObj.extract(f, extract_path) return path # return unzip dir -def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True): +def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True): """ Check if there is sufficient disk space to download and store a file. Args: url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco128.zip'. + path (str | Path, optional): The path or drive to check the available free space on. sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 2.0. hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True. Returns: (bool): True if there is sufficient disk space, False otherwise. """ - r = requests.head(url) # response - - # Check response - assert r.status_code < 400, f'URL error for {url}: {r.status_code} {r.reason}' + try: + r = requests.head(url) # response + assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response + except Exception: + return True # requests issue, default to True # Check file size gib = 1 << 30 # bytes per GiB - data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB) - total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes + data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB) + total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes + if data * sf < free: return True # sufficient space # Insufficient space - text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, ' - f'Please free {data * sf - free:.1f} GB additional disk space and try again.') + text = ( + f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, " + f"Please free {data * sf - free:.1f} GB additional disk space and try again." 
+ ) if hard: raise MemoryError(text) LOGGER.warning(text) @@ -216,35 +248,41 @@ def get_google_drive_file_info(link): url, filename = get_google_drive_file_info(link) ``` """ - file_id = link.split('/d/')[1].split('/view')[0] - drive_url = f'https://drive.google.com/uc?export=download&id={file_id}' + file_id = link.split("/d/")[1].split("/view")[0] + drive_url = f"https://drive.google.com/uc?export=download&id={file_id}" filename = None # Start session with requests.Session() as session: response = session.get(drive_url, stream=True) - if 'quota exceeded' in str(response.content.lower()): + if "quota exceeded" in str(response.content.lower()): raise ConnectionError( - emojis(f'❌ Google Drive file download quota exceeded. ' - f'Please try again later or download this file manually at {link}.')) + emojis( + f"❌ Google Drive file download quota exceeded. " + f"Please try again later or download this file manually at {link}." + ) + ) for k, v in response.cookies.items(): - if k.startswith('download_warning'): - drive_url += f'&confirm={v}' # v is token - cd = response.headers.get('content-disposition') + if k.startswith("download_warning"): + drive_url += f"&confirm={v}" # v is token + cd = response.headers.get("content-disposition") if cd: filename = re.findall('filename="(.+)"', cd)[0] return drive_url, filename -def safe_download(url, - file=None, - dir=None, - unzip=True, - delete=False, - curl=False, - retry=3, - min_bytes=1E0, - progress=True): +def safe_download( + url, + file=None, + dir=None, + unzip=True, + delete=False, + curl=False, + retry=3, + min_bytes=1e0, + exist_ok=False, + progress=True, +): """ Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file. @@ -260,41 +298,49 @@ def safe_download(url, retry (int, optional): The number of times to retry the download in case of failure. Default: 3. min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered a successful download. Default: 1E0. + exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. progress (bool, optional): Whether to display a progress bar during the download. Default: True. 
- """ - # Check if the URL is a Google Drive link - gdrive = url.startswith('https://drive.google.com/') + Example: + ```python + from ultralytics.utils.downloads import safe_download + + link = "https://ultralytics.com/assets/bus.jpg" + path = safe_download(link) + ``` + """ + gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link if gdrive: url, file = get_google_drive_file_info(url) - f = dir / (file if gdrive else url2file(url)) if dir else Path(file) # URL converted to filename - if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10) + f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename + if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10) f = Path(url) # filename elif not f.is_file(): # URL and file do not exist - assert dir or file, 'dir or file required for download' desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'" - LOGGER.info(f'{desc}...') + LOGGER.info(f"{desc}...") f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing - check_disk_space(url) + check_disk_space(url, path=f.parent) for i in range(retry + 1): try: if curl or i > 0: # curl download with retry, continue - s = 'sS' * (not progress) # silent - r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode - assert r == 0, f'Curl return value {r}' + s = "sS" * (not progress) # silent + r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode + assert r == 0, f"Curl return value {r}" else: # urllib download - method = 'torch' - if method == 'torch': + method = "torch" + if method == "torch": torch.hub.download_url_to_file(url, f, progress=progress) else: - with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)), - desc=desc, - disable=not progress, - unit='B', - unit_scale=True, - unit_divisor=1024) as pbar: - with open(f, 'wb') as f_opened: + with request.urlopen(url) as response, TQDM( + total=int(response.getheader("Content-Length", 0)), + desc=desc, + disable=not progress, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + with open(f, "wb") as f_opened: for data in response: f_opened.write(data) pbar.update(len(data)) @@ -305,88 +351,150 @@ def safe_download(url, f.unlink() # remove partial downloads except Exception as e: if i == 0 and not is_online(): - raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e + raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e elif i >= retry: - raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e - LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...') + raise ConnectionError(emojis(f"❌ Download failure for {url}. 
Retry limit reached.")) from e + LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...") - if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'): + if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"): from zipfile import is_zipfile - unzip_dir = dir or f.parent # unzip to dir if provided else unzip in place + unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place if is_zipfile(f): - unzip_dir = unzip_file(file=f, path=unzip_dir, progress=progress) # unzip - elif f.suffix in ('.tar', '.gz'): - LOGGER.info(f'Unzipping {f} to {unzip_dir.resolve()}...') - subprocess.run(['tar', 'xf' if f.suffix == '.tar' else 'xfz', f, '--directory', unzip_dir], check=True) + unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip + elif f.suffix in (".tar", ".gz"): + LOGGER.info(f"Unzipping {f} to {unzip_dir}...") + subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True) if delete: f.unlink() # remove zip return unzip_dir -def get_github_assets(repo='ultralytics/assets', version='latest', retry=False): - """Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...]).""" - if version != 'latest': - version = f'tags/{version}' # i.e. tags/v6.2 - url = f'https://api.github.com/repos/{repo}/releases/{version}' +def get_github_assets(repo="ultralytics/assets", version="latest", retry=False): + """ + Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the + function fetches the latest release assets. + + Args: + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + version (str, optional): The release version to fetch assets from. Defaults to 'latest'. + retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False. + + Returns: + (tuple): A tuple containing the release tag and a list of asset names. + + Example: + ```python + tag, assets = get_github_assets(repo='ultralytics/assets', version='latest') + ``` + """ + + if version != "latest": + version = f"tags/{version}" # i.e. tags/v6.2 + url = f"https://api.github.com/repos/{repo}/releases/{version}" r = requests.get(url) # github api - if r.status_code != 200 and r.reason != 'rate limit exceeded' and retry: # failed and not 403 rate limit exceeded + if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded r = requests.get(url) # try again if r.status_code != 200: - LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}') - return '', [] + LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") + return "", [] data = r.json() - return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets + return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...] -def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'): - """Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.""" +def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs): + """ + Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file + locally first, then tries to download it from the specified GitHub repository release. 
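For context, a hedged usage sketch of the reworked `safe_download()` signature above; the new `exist_ok` flag is forwarded to `unzip_file()` so a re-run may overwrite a previous extraction. The URL and target directory are illustrative, not prescribed by the patch.

```python
# Assumes these helpers live in ultralytics.utils.downloads, as in this patch.
from pathlib import Path

from ultralytics.utils.downloads import safe_download

out = safe_download(
    url="https://ultralytics.com/assets/coco8.zip",  # example asset URL
    dir=Path("datasets"),  # save the zip under ./datasets
    unzip=True,            # extract after download
    delete=True,           # remove the zip once extracted
    exist_ok=True,         # allow re-extraction into a non-empty directory
    retry=3,
)
print(out)  # extraction directory when an archive was unzipped
```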
+ + Args: + file (str | Path): The filename or file path to be downloaded. + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + release (str, optional): The specific release version to be downloaded. Defaults to 'v8.1.0'. + **kwargs (any): Additional keyword arguments for the download process. + + Returns: + (str): The path to the downloaded file. + + Example: + ```python + file_path = attempt_download_asset('yolov5s.pt', repo='ultralytics/assets', release='latest') + ``` + """ from ultralytics.utils import SETTINGS # scoped for circular import # YOLOv3/5u updates file = str(file) file = checks.check_yolov5u_filename(file) - file = Path(file.strip().replace("'", '')) + file = Path(file.strip().replace("'", "")) if file.exists(): return str(file) - elif (SETTINGS['weights_dir'] / file).exists(): - return str(SETTINGS['weights_dir'] / file) + elif (SETTINGS["weights_dir"] / file).exists(): + return str(SETTINGS["weights_dir"] / file) else: # URL specified name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. - if str(file).startswith(('http:/', 'https:/')): # download - url = str(file).replace(':/', '://') # Pathlib turns :// -> :/ + download_url = f"https://github.com/{repo}/releases/download" + if str(file).startswith(("http:/", "https:/")): # download + url = str(file).replace(":/", "://") # Pathlib turns :// -> :/ file = url2file(name) # parse authentication https://url.com/file.txt?auth... if Path(file).is_file(): - LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists else: - safe_download(url=url, file=file, min_bytes=1E5) + safe_download(url=url, file=file, min_bytes=1e5, **kwargs) elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES: - safe_download(url=f'https://github.com/{repo}/releases/download/{release}/{name}', file=file, min_bytes=1E5) + safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs) else: tag, assets = get_github_assets(repo, release) if not assets: tag, assets = get_github_assets(repo) # latest release if name in assets: - safe_download(url=f'https://github.com/{repo}/releases/download/{tag}/{name}', file=file, min_bytes=1E5) + safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs) return str(file) -def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3): - """Downloads and unzips files concurrently if threads > 1, else sequentially.""" +def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False): + """ + Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are + specified. + + Args: + url (str | list): The URL or list of URLs of the files to be downloaded. + dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory. + unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True. + delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False. + curl (bool, optional): Flag to use curl for downloading. Defaults to False. + threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1. + retry (int, optional): Number of retries in case of download failure. Defaults to 3. 
+ exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. + + Example: + ```python + download('https://ultralytics.com/assets/example.zip', dir='path/to/dir', unzip=True) + ``` + """ dir = Path(dir) dir.mkdir(parents=True, exist_ok=True) # make directory if threads > 1: with ThreadPool(threads) as pool: pool.map( lambda x: safe_download( - url=x[0], dir=x[1], unzip=unzip, delete=delete, curl=curl, retry=retry, progress=threads <= 1), - zip(url, repeat(dir))) + url=x[0], + dir=x[1], + unzip=unzip, + delete=delete, + curl=curl, + retry=retry, + exist_ok=exist_ok, + progress=threads <= 1, + ), + zip(url, repeat(dir)), + ) pool.close() pool.join() else: for u in [url] if isinstance(url, (str, Path)) else url: - safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry) + safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok) diff --git a/ultralytics/utils/errors.py b/ultralytics/utils/errors.py index 5a76431..86aee1d 100644 --- a/ultralytics/utils/errors.py +++ b/ultralytics/utils/errors.py @@ -4,7 +4,19 @@ from ultralytics.utils import emojis class HUBModelError(Exception): + """ + Custom exception class for handling errors related to model fetching in Ultralytics YOLO. - def __init__(self, message='Model not found. Please check model URL and try again.'): + This exception is raised when a requested model is not found or cannot be retrieved. + The message is also processed to include emojis for better user experience. + + Attributes: + message (str): The error message displayed when the exception is raised. + + Note: + The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package. + """ + + def __init__(self, message="Model not found. Please check model URL and try again."): """Create an exception for when a model is not found.""" super().__init__(emojis(message)) diff --git a/ultralytics/utils/files.py b/ultralytics/utils/files.py index 0102c4b..719caca 100644 --- a/ultralytics/utils/files.py +++ b/ultralytics/utils/files.py @@ -30,9 +30,9 @@ class WorkingDirectory(contextlib.ContextDecorator): @contextmanager def spaces_in_path(path): """ - Context manager to handle paths with spaces in their names. - If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, - executes the context code block, then copies the file/directory back to its original location. + Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with + underscores, copies the file/directory to the new path, executes the context code block, then copies the + file/directory back to its original location. Args: path (str | Path): The original path. 
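A brief sketch of the concurrent branch documented above: with `threads > 1`, `download()` fans each URL out to `safe_download()` via a `ThreadPool`, passing the new `exist_ok` flag through. The URLs below are placeholders.

```python
from pathlib import Path

from ultralytics.utils.downloads import download

urls = [
    "https://ultralytics.com/assets/coco8.zip",      # placeholder asset URLs
    "https://ultralytics.com/assets/coco8-seg.zip",
]
download(urls, dir=Path("datasets"), unzip=True, delete=False, threads=2, exist_ok=True)
```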
@@ -45,18 +45,18 @@ def spaces_in_path(path): with ultralytics.utils.files import spaces_in_path with spaces_in_path('/path/with spaces') as new_path: - # your code here + # Your code here ``` """ # If path has spaces, replace them with underscores - if ' ' in str(path): + if " " in str(path): string = isinstance(path, str) # input type path = Path(path) # Create a temporary directory and construct the new path with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = Path(tmp_dir) / path.name.replace(' ', '_') + tmp_path = Path(tmp_dir) / path.name.replace(" ", "_") # Copy file/directory if path.is_dir(): @@ -82,7 +82,7 @@ def spaces_in_path(path): yield path -def increment_path(path, exist_ok=False, sep='', mkdir=False): +def increment_path(path, exist_ok=False, sep="", mkdir=False): """ Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. @@ -102,12 +102,12 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False): """ path = Path(path) # os-agnostic if path.exists() and not exist_ok: - path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '') + path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "") # Method 1 for n in range(2, 9999): - p = f'{path}{sep}{n}{suffix}' # increment path - if not os.path.exists(p): # + p = f"{path}{sep}{n}{suffix}" # increment path + if not os.path.exists(p): break path = Path(p) @@ -119,14 +119,14 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False): def file_age(path=__file__): """Return days since last file update.""" - dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta + dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta return dt.days # + dt.seconds / 86400 # fractional days def file_date(path=__file__): """Return human-readable file modification date, i.e. '2021-3-26'.""" t = datetime.fromtimestamp(Path(path).stat().st_mtime) - return f'{t.year}-{t.month}-{t.day}' + return f"{t.year}-{t.month}-{t.day}" def file_size(path): @@ -137,11 +137,52 @@ def file_size(path): if path.is_file(): return path.stat().st_size / mb elif path.is_dir(): - return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb + return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb return 0.0 -def get_latest_run(search_dir='.'): +def get_latest_run(search_dir="."): """Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" - last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) - return max(last_list, key=os.path.getctime) if last_list else '' + last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) + return max(last_list, key=os.path.getctime) if last_list else "" + + +def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False): + """ + Updates and re-saves specified YOLO models in an 'updated_models' subdirectory. + + Args: + model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt"). + source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory. + update_names (bool, optional): Update model names from a data YAML. 
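A quick sketch of the incrementing behaviour shown above: with `exist_ok=False`, an existing path is suffixed with the next free number.

```python
from ultralytics.utils.files import increment_path

p1 = increment_path("runs/exp", mkdir=True)  # runs/exp   (created if missing)
p2 = increment_path("runs/exp", mkdir=True)  # runs/exp2
p3 = increment_path("runs/exp", mkdir=True)  # runs/exp3
print(p1, p2, p3)
```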
+ + Example: + ```python + from ultralytics.utils.files import update_models + + model_names = (f"rtdetr-{size}.pt" for size in "lx") + update_models(model_names) + ``` + """ + from ultralytics import YOLO + from ultralytics.nn.autobackend import default_class_names + + target_dir = source_dir / "updated_models" + target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists + + for model_name in model_names: + model_path = source_dir / model_name + print(f"Loading model from {model_path}") + + # Load model + model = YOLO(model_path) + model.half() + if update_names: # update model names from a dataset YAML + model.model.names = default_class_names("coco8.yaml") + + # Define new save path + save_path = target_dir / model_name + + # Save model using model.save() + print(f"Re-saving {model_name} model to {save_path}") + model.save(save_path, use_dill=False) diff --git a/ultralytics/utils/instance.py b/ultralytics/utils/instance.py index 4e2e438..4e9ef2c 100644 --- a/ultralytics/utils/instance.py +++ b/ultralytics/utils/instance.py @@ -7,7 +7,7 @@ from typing import List import numpy as np -from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh +from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh def _ntuple(n): @@ -26,16 +26,29 @@ to_4tuple = _ntuple(4) # `xyxy` means left top and right bottom # `xywh` means center x, center y and width, height(YOLO format) # `ltwh` means left top and width, height(COCO format) -_formats = ['xyxy', 'xywh', 'ltwh'] +_formats = ["xyxy", "xywh", "ltwh"] -__all__ = 'Bboxes', # tuple or list +__all__ = ("Bboxes",) # tuple or list class Bboxes: - """Bounding Boxes class. Only numpy variables are supported.""" + """ + A class for handling bounding boxes. - def __init__(self, bboxes, format='xyxy') -> None: - assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}' + The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'. + Bounding box data should be provided in numpy arrays. + + Attributes: + bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array. + format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh'). + + Note: + This class does not handle normalization or denormalization of bounding boxes. 
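A minimal sketch of the `Bboxes` container documented above: numpy-only storage, in-place format conversion, and area computation.

```python
import numpy as np

from ultralytics.utils.instance import Bboxes

boxes = Bboxes(np.array([[10, 10, 50, 30], [0, 0, 20, 20]], dtype=np.float32), format="xyxy")
print(boxes.areas())   # [800. 400.]
boxes.convert("xywh")  # in place: xyxy -> centre-x, centre-y, width, height
print(boxes.bboxes)    # [[30. 20. 40. 20.] [10. 10. 20. 20.]]
```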
+ """ + + def __init__(self, bboxes, format="xyxy") -> None: + """Initializes the Bboxes class with bounding box data in a specified format.""" + assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}" bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes assert bboxes.ndim == 2 assert bboxes.shape[1] == 4 @@ -45,21 +58,21 @@ class Bboxes: def convert(self, format): """Converts bounding box format from one type to another.""" - assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}' + assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}" if self.format == format: return - elif self.format == 'xyxy': - func = xyxy2xywh if format == 'xywh' else xyxy2ltwh - elif self.format == 'xywh': - func = xywh2xyxy if format == 'xyxy' else xywh2ltwh + elif self.format == "xyxy": + func = xyxy2xywh if format == "xywh" else xyxy2ltwh + elif self.format == "xywh": + func = xywh2xyxy if format == "xyxy" else xywh2ltwh else: - func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh + func = ltwh2xyxy if format == "xyxy" else ltwh2xywh self.bboxes = func(self.bboxes) self.format = format def areas(self): """Return box areas.""" - self.convert('xyxy') + self.convert("xyxy") return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # def denormalize(self, w, h): @@ -111,7 +124,7 @@ class Bboxes: return len(self.bboxes) @classmethod - def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes': + def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes": """ Concatenate a list of Bboxes objects into a single Bboxes object. @@ -135,7 +148,7 @@ class Bboxes: return boxes_list[0] return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis)) - def __getitem__(self, index) -> 'Bboxes': + def __getitem__(self, index) -> "Bboxes": """ Retrieve a specific bounding box or a set of bounding boxes using indexing. @@ -156,32 +169,52 @@ class Bboxes: if isinstance(index, int): return Bboxes(self.bboxes[index].view(1, -1)) b = self.bboxes[index] - assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!' + assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!" return Bboxes(b) class Instances: + """ + Container for bounding boxes, segments, and keypoints of detected objects in an image. - def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None: + Attributes: + _bboxes (Bboxes): Internal object for handling bounding box operations. + keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. Default is None. + normalized (bool): Flag indicating whether the bounding box coordinates are normalized. + segments (ndarray): Segments array with shape [N, 1000, 2] after resampling. + + Args: + bboxes (ndarray): An array of bounding boxes with shape [N, 4]. + segments (list | ndarray, optional): A list or array of object segments. Default is None. + keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None. + bbox_format (str, optional): The format of bounding boxes ('xywh' or 'xyxy'). Default is 'xywh'. + normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True. 
+ + Examples: + ```python + # Create an Instances object + instances = Instances( + bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]), + segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])], + keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]) + ) + ``` + + Note: + The bounding box format is either 'xywh' or 'xyxy', and is determined by the `bbox_format` argument. + This class does not perform input validation, and it assumes the inputs are well-formed. + """ + + def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None: """ Args: bboxes (ndarray): bboxes with shape [N, 4]. segments (list | ndarray): segments. keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. """ - if segments is None: - segments = [] self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format) self.keypoints = keypoints self.normalized = normalized - - if len(segments) > 0: - # list[np.array(1000, 2)] * num_samples - segments = resample_segments(segments) - # (N, 1000, 2) - segments = np.stack(segments, axis=0) - else: - segments = np.zeros((0, 1000, 2), dtype=np.float32) self.segments = segments def convert_bbox(self, format): @@ -194,7 +227,7 @@ class Instances: return self._bboxes.areas() def scale(self, scale_w, scale_h, bbox_only=False): - """this might be similar with denormalize func but without normalized sign.""" + """This might be similar with denormalize func but without normalized sign.""" self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h)) if bbox_only: return @@ -230,7 +263,7 @@ class Instances: def add_padding(self, padw, padh): """Handle rect and mosaic situation.""" - assert not self.normalized, 'you should add padding with absolute coordinates.' + assert not self.normalized, "you should add padding with absolute coordinates." self._bboxes.add(offset=(padw, padh, padw, padh)) self.segments[..., 0] += padw self.segments[..., 1] += padh @@ -238,7 +271,7 @@ class Instances: self.keypoints[..., 0] += padw self.keypoints[..., 1] += padh - def __getitem__(self, index) -> 'Instances': + def __getitem__(self, index) -> "Instances": """ Retrieve a specific instance or a set of instances using indexing. 
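Worth noting from the `__init__` diff above: segments are now stored exactly as passed in (the `resample_segments()`/`np.stack()` step was removed), so the caller is presumably expected to supply an already-stacked `(N, M, 2)` array, e.g. the `(N, 1000, 2)` shape the docstring describes. A hedged construction sketch:

```python
import numpy as np

from ultralytics.utils.instance import Instances

bboxes = np.array([[0.5, 0.5, 0.2, 0.4]], dtype=np.float32)  # one box, normalized xywh
segments = np.zeros((1, 1000, 2), dtype=np.float32)          # pre-stacked segments
inst = Instances(bboxes, segments=segments, bbox_format="xywh", normalized=True)
print(len(inst), inst.segments.shape)  # 1 (1, 1000, 2)
```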
@@ -268,7 +301,7 @@ class Instances: def flipud(self, h): """Flips the coordinates of bounding boxes, segments, and keypoints vertically.""" - if self._bboxes.format == 'xyxy': + if self._bboxes.format == "xyxy": y1 = self.bboxes[:, 1].copy() y2 = self.bboxes[:, 3].copy() self.bboxes[:, 1] = h - y2 @@ -281,7 +314,7 @@ class Instances: def fliplr(self, w): """Reverses the order of the bounding boxes and segments horizontally.""" - if self._bboxes.format == 'xyxy': + if self._bboxes.format == "xyxy": x1 = self.bboxes[:, 0].copy() x2 = self.bboxes[:, 2].copy() self.bboxes[:, 0] = w - x2 @@ -295,10 +328,10 @@ class Instances: def clip(self, w, h): """Clips bounding boxes, segments, and keypoints values to stay within image boundaries.""" ori_format = self._bboxes.format - self.convert_bbox(format='xyxy') + self.convert_bbox(format="xyxy") self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w) self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h) - if ori_format != 'xyxy': + if ori_format != "xyxy": self.convert_bbox(format=ori_format) self.segments[..., 0] = self.segments[..., 0].clip(0, w) self.segments[..., 1] = self.segments[..., 1].clip(0, h) @@ -307,7 +340,11 @@ class Instances: self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h) def remove_zero_area_boxes(self): - """Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. This removes them.""" + """ + Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. + + This removes them. + """ good = self.bbox_areas > 0 if not all(good): self._bboxes = self._bboxes[good] @@ -330,7 +367,7 @@ class Instances: return len(self.bboxes) @classmethod - def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances': + def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances": """ Concatenates a list of Instances objects into a single Instances object. diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 69f08db..d0ca9c3 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -6,14 +6,17 @@ import torch.nn.functional as F from ultralytics.utils.metrics import OKS_SIGMA from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh -from ultralytics.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors - -from .metrics import bbox_iou +from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors +from .metrics import bbox_iou, probiou from .tal import bbox2dist class VarifocalLoss(nn.Module): - """Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367.""" + """ + Varifocal loss by Zhang et al. + + https://arxiv.org/abs/2008.13367. + """ def __init__(self): """Initialize the VarifocalLoss class.""" @@ -24,21 +27,25 @@ class VarifocalLoss(nn.Module): """Computes varfocal loss.""" weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label with torch.cuda.amp.autocast(enabled=False): - loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') * - weight).mean(1).sum() + loss = ( + (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight) + .mean(1) + .sum() + ) return loss class FocalLoss(nn.Module): """Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).""" - def __init__(self, ): + def __init__(self): + """Initializer for FocalLoss class with no parameters.""" super().__init__() @staticmethod def forward(pred, label, gamma=1.5, alpha=0.25): """Calculates and updates confusion matrix for object detection/classification tasks.""" - loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none') + loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none") # p_t = torch.exp(-loss) # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability @@ -54,6 +61,7 @@ class FocalLoss(nn.Module): class BboxLoss(nn.Module): + """Criterion class for computing training losses during training.""" def __init__(self, reg_max, use_dfl=False): """Initialize the BboxLoss module with regularization maximum and DFL settings.""" @@ -79,42 +87,73 @@ class BboxLoss(nn.Module): @staticmethod def _df_loss(pred_dist, target): - """Return sum of left and right DFL losses.""" - # Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391 + """ + Return sum of left and right DFL losses. + + Distribution Focal Loss (DFL) proposed in Generalized Focal Loss + https://ieeexplore.ieee.org/document/9792391 + """ tl = target.long() # target left tr = tl + 1 # target right wl = tr - target # weight left wr = 1 - wl # weight right - return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl + - F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True) + return ( + F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl + + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr + ).mean(-1, keepdim=True) + + +class RotatedBboxLoss(BboxLoss): + """Criterion class for computing training losses during training.""" + + def __init__(self, reg_max, use_dfl=False): + """Initialize the BboxLoss module with regularization maximum and DFL settings.""" + super().__init__(reg_max, use_dfl) + + def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask): + """IoU loss.""" + weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1) + iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask]) + loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum + + # DFL loss + if self.use_dfl: + target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max) + loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight + loss_dfl = loss_dfl.sum() / target_scores_sum + else: + loss_dfl = torch.tensor(0.0).to(pred_dist.device) + + return loss_iou, loss_dfl class KeypointLoss(nn.Module): """Criterion class for computing training losses.""" def __init__(self, sigmas) -> None: + """Initialize the KeypointLoss class.""" super().__init__() self.sigmas = sigmas def forward(self, pred_kpts, gt_kpts, kpt_mask, area): """Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints.""" - d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2 - kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9) + d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2) + kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9) # e = d / (2 * (area * 
self.sigmas) ** 2 + 1e-9) # from formula - e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2 # from cocoeval - return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean() + e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval + return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean() class v8DetectionLoss: """Criterion class for computing training losses.""" - def __init__(self, model): # model must be de-paralleled - + def __init__(self, model, tal_topk=10): # model must be de-paralleled + """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function.""" device = next(model.parameters()).device # get model device h = model.args # hyperparameters m = model.model[-1] # Detect() module - self.bce = nn.BCEWithLogitsLoss(reduction='none') + self.bce = nn.BCEWithLogitsLoss(reduction="none") self.hyp = h self.stride = m.stride # model strides self.nc = m.nc # number of classes @@ -124,7 +163,7 @@ class v8DetectionLoss: self.use_dfl = m.reg_max > 1 - self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0) self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device) self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device) @@ -159,7 +198,8 @@ class v8DetectionLoss: loss = torch.zeros(3, device=self.device) # box, cls, dfl feats = preds[1] if isinstance(preds, tuple) else preds pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) + (self.reg_max * 4, self.nc), 1 + ) pred_scores = pred_scores.permute(0, 2, 1).contiguous() pred_distri = pred_distri.permute(0, 2, 1).contiguous() @@ -169,30 +209,36 @@ class v8DetectionLoss: imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) - # targets - targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1) + # Targets + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) - # pboxes + # Pboxes pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) _, target_bboxes, target_scores, fg_mask, _ = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) target_scores_sum = max(target_scores.sum(), 1) - # cls loss + # Cls loss # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE - # bbox loss + # Bbox loss if fg_mask.sum(): target_bboxes /= stride_tensor - loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, - target_scores_sum, fg_mask) + loss[0], loss[2] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, 
target_scores, target_scores_sum, fg_mask + ) loss[0] *= self.hyp.box # box gain loss[1] *= self.hyp.cls # cls gain @@ -205,8 +251,8 @@ class v8SegmentationLoss(v8DetectionLoss): """Criterion class for computing training losses.""" def __init__(self, model): # model must be de-paralleled + """Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument.""" super().__init__(model) - self.nm = model.model[-1].nm # number of masks self.overlap = model.args.overlap_mask def __call__(self, preds, batch): @@ -215,9 +261,10 @@ class v8SegmentationLoss(v8DetectionLoss): feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) + (self.reg_max * 4, self.nc), 1 + ) - # b, grids, .. + # B, grids, .. pred_scores = pred_scores.permute(0, 2, 1).contiguous() pred_distri = pred_distri.permute(0, 2, 1).contiguous() pred_masks = pred_masks.permute(0, 2, 1).contiguous() @@ -226,80 +273,168 @@ class v8SegmentationLoss(v8DetectionLoss): imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) - # targets + # Targets try: - batch_idx = batch['batch_idx'].view(-1, 1) - targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1) targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) except RuntimeError as e: - raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n' - "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, " - "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a " - "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' " - 'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e + raise TypeError( + "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n" + "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, " + "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a " + "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' " + "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help." 
+ ) from e - # pboxes + # Pboxes pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) target_scores_sum = max(target_scores.sum(), 1) - # cls loss + # Cls loss # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE if fg_mask.sum(): - # bbox loss - loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor, - target_scores, target_scores_sum, fg_mask) - # masks loss - masks = batch['masks'].to(self.device).float() + # Bbox loss + loss[0], loss[3] = self.bbox_loss( + pred_distri, + pred_bboxes, + anchor_points, + target_bboxes / stride_tensor, + target_scores, + target_scores_sum, + fg_mask, + ) + # Masks loss + masks = batch["masks"].to(self.device).float() if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample - masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0] + masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0] - for i in range(batch_size): - if fg_mask[i].sum(): - mask_idx = target_gt_idx[i][fg_mask[i]] - if self.overlap: - gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0) - else: - gt_mask = masks[batch_idx.view(-1) == i][mask_idx] - xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]] - marea = xyxy2xywh(xyxyn)[:, 2:].prod(1) - mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device) - loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg - - # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove - else: - loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss + loss[1] = self.calculate_segmentation_loss( + fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap + ) # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove else: loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss loss[0] *= self.hyp.box # box gain - loss[1] *= self.hyp.box / batch_size # seg gain + loss[1] *= self.hyp.box # seg gain loss[2] *= self.hyp.cls # cls gain loss[3] *= self.hyp.dfl # dfl gain return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) - def single_mask_loss(self, gt_mask, pred, proto, xyxy, area): - """Mask loss for one image.""" - pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80) - loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none') - return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() + @staticmethod + def single_mask_loss( + gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor + ) -> torch.Tensor: + """ + Compute the instance segmentation loss for a single image. + + Args: + gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects. 
+ pred (torch.Tensor): Predicted mask coefficients of shape (n, 32). + proto (torch.Tensor): Prototype masks of shape (32, H, W). + xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4). + area (torch.Tensor): Area of each ground truth bounding box of shape (n,). + + Returns: + (torch.Tensor): The calculated mask loss for a single image. + + Notes: + The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the + predicted masks from the prototype masks and predicted mask coefficients. + """ + pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) + loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none") + return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum() + + def calculate_segmentation_loss( + self, + fg_mask: torch.Tensor, + masks: torch.Tensor, + target_gt_idx: torch.Tensor, + target_bboxes: torch.Tensor, + batch_idx: torch.Tensor, + proto: torch.Tensor, + pred_masks: torch.Tensor, + imgsz: torch.Tensor, + overlap: bool, + ) -> torch.Tensor: + """ + Calculate the loss for instance segmentation. + + Args: + fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive. + masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W). + target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors). + target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4). + batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1). + proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W). + pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32). + imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W). + overlap (bool): Whether the masks in `masks` tensor overlap. + + Returns: + (torch.Tensor): The calculated loss for instance segmentation. + + Notes: + The batch loss can be computed for improved speed at higher memory usage. 
+ For example, pred_mask can be computed as follows: + pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160) + """ + _, _, mask_h, mask_w = proto.shape + loss = 0 + + # Normalize to 0-1 + target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]] + + # Areas of target bboxes + marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2) + + # Normalize to mask size + mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device) + + for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)): + fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i + if fg_mask_i.any(): + mask_idx = target_gt_idx_i[fg_mask_i] + if overlap: + gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1) + gt_mask = gt_mask.float() + else: + gt_mask = masks[batch_idx.view(-1) == i][mask_idx] + + loss += self.single_mask_loss( + gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i] + ) + + # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + else: + loss += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss + + return loss / fg_mask.sum() class v8PoseLoss(v8DetectionLoss): """Criterion class for computing training losses.""" def __init__(self, model): # model must be de-paralleled + """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance.""" super().__init__(model) self.kpt_shape = model.model[-1].kpt_shape self.bce_pose = nn.BCEWithLogitsLoss() @@ -313,9 +448,10 @@ class v8PoseLoss(v8DetectionLoss): loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1] pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) + (self.reg_max * 4, self.nc), 1 + ) - # b, grids, .. + # B, grids, .. 
pred_scores = pred_scores.permute(0, 2, 1).contiguous() pred_distri = pred_distri.permute(0, 2, 1).contiguous() pred_kpts = pred_kpts.permute(0, 2, 1).contiguous() @@ -324,53 +460,50 @@ class v8PoseLoss(v8DetectionLoss): imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) - # targets + # Targets batch_size = pred_scores.shape[0] - batch_idx = batch['batch_idx'].view(-1, 1) - targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1) targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) - # pboxes + # Pboxes pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3) _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) target_scores_sum = max(target_scores.sum(), 1) - # cls loss + # Cls loss # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE - # bbox loss + # Bbox loss if fg_mask.sum(): target_bboxes /= stride_tensor - loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, - target_scores_sum, fg_mask) - keypoints = batch['keypoints'].to(self.device).float().clone() + loss[0], loss[4] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) + keypoints = batch["keypoints"].to(self.device).float().clone() keypoints[..., 0] *= imgsz[1] keypoints[..., 1] *= imgsz[0] - for i in range(batch_size): - if fg_mask[i].sum(): - idx = target_gt_idx[i][fg_mask[i]] - gt_kpt = keypoints[batch_idx.view(-1) == i][idx] # (n, 51) - gt_kpt[..., 0] /= stride_tensor[fg_mask[i]] - gt_kpt[..., 1] /= stride_tensor[fg_mask[i]] - area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True) - pred_kpt = pred_kpts[i][fg_mask[i]] - kpt_mask = gt_kpt[..., 2] != 0 - loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss - # kpt_score loss - if pred_kpt.shape[-1] == 3: - loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss + + loss[1], loss[2] = self.calculate_keypoints_loss( + fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts + ) loss[0] *= self.hyp.box # box gain - loss[1] *= self.hyp.pose / batch_size # pose gain - loss[2] *= self.hyp.kobj / batch_size # kobj gain + loss[1] *= self.hyp.pose # pose gain + loss[2] *= self.hyp.kobj # kobj gain loss[3] *= self.hyp.cls # cls gain loss[4] *= self.hyp.dfl # dfl gain @@ -385,12 +518,210 @@ class v8PoseLoss(v8DetectionLoss): y[..., 1] += anchor_points[:, [1]] - 0.5 return y + def calculate_keypoints_loss( + self, masks, target_gt_idx, 
keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts + ): + """ + Calculate the keypoints loss for the model. + + This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is + based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is + a binary classification loss that classifies whether a keypoint is present or not. + + Args: + masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors). + target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors). + keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim). + batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1). + stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1). + target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4). + pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim). + + Returns: + (tuple): Returns a tuple containing: + - kpts_loss (torch.Tensor): The keypoints loss. + - kpts_obj_loss (torch.Tensor): The keypoints object loss. + """ + batch_idx = batch_idx.flatten() + batch_size = len(masks) + + # Find the maximum number of keypoints in a single image + max_kpts = torch.unique(batch_idx, return_counts=True)[1].max() + + # Create a tensor to hold batched keypoints + batched_keypoints = torch.zeros( + (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device + ) + + # TODO: any idea how to vectorize this? + # Fill batched_keypoints with keypoints based on batch_idx + for i in range(batch_size): + keypoints_i = keypoints[batch_idx == i] + batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i + + # Expand dimensions of target_gt_idx to match the shape of batched_keypoints + target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1) + + # Use target_gt_idx_expanded to select keypoints from batched_keypoints + selected_keypoints = batched_keypoints.gather( + 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2]) + ) + + # Divide coordinates by stride + selected_keypoints /= stride_tensor.view(1, -1, 1, 1) + + kpts_loss = 0 + kpts_obj_loss = 0 + + if masks.any(): + gt_kpt = selected_keypoints[masks] + area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True) + pred_kpt = pred_kpts[masks] + kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True) + kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss + + if pred_kpt.shape[-1] == 3: + kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss + + return kpts_loss, kpts_obj_loss + class v8ClassificationLoss: """Criterion class for computing training losses.""" def __call__(self, preds, batch): """Compute the classification loss between predictions and true labels.""" - loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / 64 + loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean") loss_items = loss.detach() return loss, loss_items + + +class v8OBBLoss(v8DetectionLoss): + def __init__(self, model): + """ + Initializes v8OBBLoss with model, assigner, and rotated bbox loss. + + Note model must be de-paralleled. 
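A small numeric check of the `v8ClassificationLoss` change above: the old `cross_entropy(..., reduction='sum') / 64` only equals `reduction='mean'` when the batch really holds 64 samples, so the patched version removes a hidden batch-size assumption.

```python
import torch
import torch.nn.functional as F

preds = torch.randn(16, 10)            # 16 samples, 10 classes
labels = torch.randint(0, 10, (16,))
old = F.cross_entropy(preds, labels, reduction="sum") / 64  # previous behaviour
new = F.cross_entropy(preds, labels, reduction="mean")      # patched behaviour
print(old.item(), new.item())          # here old == new * 16 / 64
```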
+ """ + super().__init__(model) + self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device) + + def preprocess(self, targets, batch_size, scale_tensor): + """Preprocesses the target counts and matches with the input batch size to output a tensor.""" + if targets.shape[0] == 0: + out = torch.zeros(batch_size, 0, 6, device=self.device) + else: + i = targets[:, 0] # image index + _, counts = i.unique(return_counts=True) + counts = counts.to(dtype=torch.int32) + out = torch.zeros(batch_size, counts.max(), 6, device=self.device) + for j in range(batch_size): + matches = i == j + n = matches.sum() + if n: + bboxes = targets[matches, 2:] + bboxes[..., :4].mul_(scale_tensor) + out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1) + return out + + def __call__(self, preds, batch): + """Calculate and return the loss for the YOLO model.""" + loss = torch.zeros(3, device=self.device) # box, cls, dfl + feats, pred_angle = preds if isinstance(preds[0], list) else preds[1] + batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1 + ) + + # b, grids, .. + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_angle = pred_angle.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # targets + try: + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1) + rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item() + targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) + except RuntimeError as e: + raise TypeError( + "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n" + "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, " + "i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a " + "correctly formatted 'OBB' dataset using 'data=dota8.yaml' " + "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help." 
+ ) from e + + # Pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4) + + bboxes_for_assigner = pred_bboxes.clone().detach() + # Only the first four elements need to be scaled + bboxes_for_assigner[..., :4] *= stride_tensor + _, target_bboxes, target_scores, fg_mask, _ = self.assigner( + pred_scores.detach().sigmoid(), + bboxes_for_assigner.type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + target_scores_sum = max(target_scores.sum(), 1) + + # Cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + # Bbox loss + if fg_mask.sum(): + target_bboxes[..., :4] /= stride_tensor + loss[0], loss[2] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) + else: + loss[0] += (pred_angle * 0).sum() + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.cls # cls gain + loss[2] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + def bbox_decode(self, anchor_points, pred_dist, pred_angle): + """ + Decode predicted object bounding box coordinates from anchor points and distribution. + + Args: + anchor_points (torch.Tensor): Anchor points, (h*w, 2). + pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4). + pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1). + + Returns: + (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5). + """ + if self.use_dfl: + b, a, c = pred_dist.shape # batch, anchors, channels + pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1) + +class v10DetectLoss: + def __init__(self, model): + self.one2many = v8DetectionLoss(model, tal_topk=10) + self.one2one = v8DetectionLoss(model, tal_topk=1) + + def __call__(self, preds, batch): + one2many = preds["one2many"] + loss_one2many = self.one2many(one2many, batch) + one2one = preds["one2one"] + loss_one2one = self.one2one(one2one, batch) + return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1])) diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py index 731b55a..b598811 100644 --- a/ultralytics/utils/metrics.py +++ b/ultralytics/utils/metrics.py @@ -1,7 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Model validation metrics -""" +"""Model validation metrics.""" + import math import warnings from pathlib import Path @@ -12,7 +11,10 @@ import torch from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings -OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0 +OKS_SIGMA = ( + np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) + / 10.0 +) def bbox_ioa(box1, box2, iou=False, eps=1e-7): @@ -20,13 +22,13 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7): Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format. Args: - box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes. - box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes. - iou (bool): Calculate the standard iou if True else return inter_area/box2_area. 
+ box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes. + box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes. + iou (bool): Calculate the standard IoU if True else return inter_area/box2_area. eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. Returns: - (np.array): A numpy array of shape (n, m) representing the intersection over box2 area. + (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area. """ # Get the coordinates of bounding boxes @@ -34,10 +36,11 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7): b2_x1, b2_y1, b2_x2, b2_y2 = box2.T # Intersection area - inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \ - (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0) + inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * ( + np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1) + ).clip(0) - # box2 area + # Box2 area area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) if iou: box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) @@ -49,8 +52,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7): def box_iou(box1, box2, eps=1e-7): """ - Calculate intersection-over-union (IoU) of boxes. - Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format. Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py Args: @@ -62,6 +64,9 @@ def box_iou(box1, box2, eps=1e-7): (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2. """ + # NOTE: need float32 to get accurate iou values + box1 = torch.as_tensor(box1, dtype=torch.float32) + box2 = torch.as_tensor(box2, dtype=torch.float32) # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2) inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2) @@ -101,8 +106,9 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps # Intersection area - inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \ - (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0) + inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * ( + b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1) + ).clamp_(0) # Union Area union = w1 * h1 + w2 * h2 - inter + eps @@ -113,10 +119,12 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 - c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared - rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 + c2 = cw.pow(2) + ch.pow(2) + eps # convex diagonal squared + rho2 = ( + (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2) + ) / 4 # center dist**2 if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 - v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2) + v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2) with torch.no_grad(): alpha = v / (v 
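# Usage sketch (annotation, not part of the patch): a minimal example of the two IoU
# helpers touched above, assuming ultralytics.utils.metrics is importable from this repo.
# box_iou() returns the full NxM pairwise matrix for xyxy boxes, while bbox_iou() works
# element-wise (with broadcasting) and can return the GIoU/DIoU/CIoU variants.
import torch
from ultralytics.utils.metrics import box_iou, bbox_iou

gt = torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]])   # (N, 4) xyxy
pred = torch.tensor([[1.0, 1.0, 9.0, 9.0], [21.0, 19.0, 31.0, 29.0]])   # (M, 4) xyxy

pairwise = box_iou(gt, pred)                       # (N, M) IoU matrix
ciou = bbox_iou(gt, pred, xywh=False, CIoU=True)   # element-wise Complete-IoU
print(pairwise.shape, ciou.squeeze(-1))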
- iou + (1 + eps)) return iou - (rho2 / c2 + v * alpha) # CIoU @@ -159,16 +167,120 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7): Returns: (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities. """ - d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2 # (N, M, 17) + d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17) sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, ) kpt_mask = kpt1[..., 2] != 0 # (N, 17) - e = d / (2 * sigma) ** 2 / (area[:, None, None] + eps) / 2 # from cocoeval + e = d / (2 * sigma).pow(2) / (area[:, None, None] + eps) / 2 # from cocoeval # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula - return (torch.exp(-e) * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps) + return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps) -def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 - # return positive, negative label smoothing BCE targets +def _get_covariance_matrix(boxes): + """ + Generating covariance matrix from obbs. + + Args: + boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format. + + Returns: + (torch.Tensor): Covariance metrixs corresponding to original rotated bounding boxes. + """ + # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here. + gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1) + a, b, c = gbbs.split(1, dim=-1) + cos = c.cos() + sin = c.sin() + cos2 = cos.pow(2) + sin2 = sin.pow(2) + return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin + + +def probiou(obb1, obb2, CIoU=False, eps=1e-7): + """ + Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf. + + Args: + obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format. + obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (torch.Tensor): A tensor of shape (N, ) representing obb similarities. + """ + x1, y1 = obb1[..., :2].split(1, dim=-1) + x2, y2 = obb2[..., :2].split(1, dim=-1) + a1, b1, c1 = _get_covariance_matrix(obb1) + a2, b2, c2 = _get_covariance_matrix(obb2) + + t1 = ( + ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps) + ) * 0.25 + t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5 + t3 = ( + ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)) + / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps) + + eps + ).log() * 0.5 + bd = (t1 + t2 + t3).clamp(eps, 100.0) + hd = (1.0 - (-bd).exp() + eps).sqrt() + iou = 1 - hd + if CIoU: # only include the wh aspect ratio part + w1, h1 = obb1[..., 2:4].split(1, dim=-1) + w2, h2 = obb2[..., 2:4].split(1, dim=-1) + v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2) + with torch.no_grad(): + alpha = v / (v - iou + (1 + eps)) + return iou - v * alpha # CIoU + return iou + + +def batch_probiou(obb1, obb2, eps=1e-7): + """ + Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf. 
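# Usage sketch (annotation, not part of the patch): the new probabilistic IoU for
# oriented boxes, assuming ultralytics.utils.metrics is importable. Boxes are
# (cx, cy, w, h, r) with the angle r in radians; identical boxes score close to 1,
# and rotating one of them lowers the similarity.
import math
import torch
from ultralytics.utils.metrics import probiou

box_a = torch.tensor([[50.0, 50.0, 40.0, 20.0, 0.0]])            # axis-aligned
box_b = torch.tensor([[50.0, 50.0, 40.0, 20.0, math.pi / 6]])    # same box, rotated 30 degrees

print(probiou(box_a, box_a))   # ~1.0 (self-similarity)
print(probiou(box_a, box_b))   # lower once the rotation differs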
+ + Args: + obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format. + obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (torch.Tensor): A tensor of shape (N, M) representing obb similarities. + """ + obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1 + obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2 + + x1, y1 = obb1[..., :2].split(1, dim=-1) + x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1)) + a1, b1, c1 = _get_covariance_matrix(obb1) + a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2)) + + t1 = ( + ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps) + ) * 0.25 + t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5 + t3 = ( + ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)) + / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps) + + eps + ).log() * 0.5 + bd = (t1 + t2 + t3).clamp(eps, 100.0) + hd = (1.0 - (-bd).exp() + eps).sqrt() + return 1 - hd + + +def smooth_BCE(eps=0.1): + """ + Computes smoothed positive and negative Binary Cross-Entropy targets. + + This function calculates positive and negative label smoothing BCE targets based on a given epsilon value. + For implementation details, refer to https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441. + + Args: + eps (float, optional): The epsilon value for label smoothing. Defaults to 0.1. + + Returns: + (tuple): A tuple containing the positive and negative label smoothing BCE targets. + """ return 1.0 - 0.5 * eps, 0.5 * eps @@ -178,23 +290,23 @@ class ConfusionMatrix: Attributes: task (str): The type of task, either 'detect' or 'classify'. - matrix (np.array): The confusion matrix, with dimensions depending on the task. + matrix (np.ndarray): The confusion matrix, with dimensions depending on the task. nc (int): The number of classes. conf (float): The confidence threshold for detections. iou_thres (float): The Intersection over Union threshold. """ - def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'): + def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"): """Initialize attributes for the YOLO model.""" self.task = task - self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc)) + self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc)) self.nc = nc # number of classes - self.conf = 0.25 if conf is None else conf # argument may be None from default cfg + self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed self.iou_thres = iou_thres def process_cls_preds(self, preds, targets): """ - Update confusion matrix for classification task + Update confusion matrix for classification task. Args: preds (Array[N, min(nc,5)]): Predicted class labels. @@ -204,26 +316,39 @@ class ConfusionMatrix: for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()): self.matrix[p][t] += 1 - def process_batch(self, detections, labels): + def process_batch(self, detections, gt_bboxes, gt_cls): """ Update confusion matrix for object detection task. Args: - detections (Array[N, 6]): Detected bounding boxes and their associated information. 
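# Usage sketch (annotation, not part of the patch): smooth_BCE() above simply maps a
# label-smoothing epsilon to the positive/negative BCE targets used in place of hard
# 1.0 / 0.0 labels, assuming the ultralytics import path used elsewhere in this diff.
from ultralytics.utils.metrics import smooth_BCE

cp, cn = smooth_BCE(eps=0.1)   # positive, negative targets
print(cp, cn)                  # 0.95, 0.05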
- Each row should contain (x1, y1, x2, y2, conf, class). - labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels. - Each row should contain (class, x1, y1, x2, y2). + detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information. + Each row should contain (x1, y1, x2, y2, conf, class) + or with an additional element `angle` when it's obb. + gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format. + gt_cls (Array[M]): The class labels. """ + if gt_cls.shape[0] == 0: # Check if labels is empty + if detections is not None: + detections = detections[detections[:, 4] > self.conf] + detection_classes = detections[:, 5].int() + for dc in detection_classes: + self.matrix[dc, self.nc] += 1 # false positives + return if detections is None: - gt_classes = labels.int() + gt_classes = gt_cls.int() for gc in gt_classes: self.matrix[self.nc, gc] += 1 # background FN return detections = detections[detections[:, 4] > self.conf] - gt_classes = labels[:, 0].int() + gt_classes = gt_cls.int() detection_classes = detections[:, 5].int() - iou = box_iou(labels[:, 1:], detections[:, :4]) + is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimension + iou = ( + batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1)) + if is_obb + else box_iou(gt_bboxes, detections[:, :4]) + ) x = torch.where(iou > self.iou_thres) if x[0].shape[0]: @@ -259,11 +384,11 @@ class ConfusionMatrix: tp = self.matrix.diagonal() # true positives fp = self.matrix.sum(1) - tp # false positives # fn = self.matrix.sum(0) - tp # false negatives (missed detections) - return (tp[:-1], fp[:-1]) if self.task == 'detect' else (tp, fp) # remove background class if task=detect + return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp) # remove background class if task=detect - @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure') + @TryExcept("WARNING ⚠️ ConfusionMatrix plot failure") @plt_settings() - def plot(self, normalize=True, save_dir='', names=(), on_plot=None): + def plot(self, normalize=True, save_dir="", names=(), on_plot=None): """ Plot the confusion matrix using seaborn and save it to a file. 
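# Usage sketch (annotation, not part of the patch): the confusion-matrix update now takes
# the ground truth split into boxes and classes. A minimal two-class example, assuming
# the ultralytics.utils.metrics import path shown in this diff:
import torch
from ultralytics.utils.metrics import ConfusionMatrix

cm = ConfusionMatrix(nc=2)
detections = torch.tensor([[0.0, 0.0, 10.0, 10.0, 0.9, 0],      # (x1, y1, x2, y2, conf, cls)
                           [30.0, 30.0, 40.0, 40.0, 0.8, 1]])
gt_bboxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                          [30.0, 30.0, 40.0, 40.0]])
gt_cls = torch.tensor([0, 1])

cm.process_batch(detections, gt_bboxes, gt_cls)
print(cm.matrix)   # (nc + 1) x (nc + 1) counts; the extra row/column is background FP/FN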
@@ -275,30 +400,31 @@ class ConfusionMatrix: """ import seaborn as sn - array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns + array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True) nc, nn = self.nc, len(names) # number of classes, names sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels - ticklabels = (list(names) + ['background']) if labels else 'auto' + ticklabels = (list(names) + ["background"]) if labels else "auto" with warnings.catch_warnings(): - warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered - sn.heatmap(array, - ax=ax, - annot=nc < 30, - annot_kws={ - 'size': 8}, - cmap='Blues', - fmt='.2f' if normalize else '.0f', - square=True, - vmin=0.0, - xticklabels=ticklabels, - yticklabels=ticklabels).set_facecolor((1, 1, 1)) - title = 'Confusion Matrix' + ' Normalized' * normalize - ax.set_xlabel('True') - ax.set_ylabel('Predicted') + warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered + sn.heatmap( + array, + ax=ax, + annot=nc < 30, + annot_kws={"size": 8}, + cmap="Blues", + fmt=".2f" if normalize else ".0f", + square=True, + vmin=0.0, + xticklabels=ticklabels, + yticklabels=ticklabels, + ).set_facecolor((1, 1, 1)) + title = "Confusion Matrix" + " Normalized" * normalize + ax.set_xlabel("True") + ax.set_ylabel("Predicted") ax.set_title(title) plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png' fig.savefig(plot_fname, dpi=250) @@ -307,11 +433,9 @@ class ConfusionMatrix: on_plot(plot_fname) def print(self): - """ - Print the confusion matrix to the console. 
- """ + """Print the confusion matrix to the console.""" for i in range(self.nc + 1): - LOGGER.info(' '.join(map(str, self.matrix[i]))) + LOGGER.info(" ".join(map(str, self.matrix[i]))) def smooth(y, f=0.05): @@ -319,28 +443,28 @@ def smooth(y, f=0.05): nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd) p = np.ones(nf // 2) # ones padding yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded - return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed + return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed @plt_settings() -def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None): +def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None): """Plots a precision-recall curve.""" fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) py = np.stack(py, axis=1) if 0 < len(names) < 21: # display per-class legend if < 21 classes for i, y in enumerate(py.T): - ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision) + ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision) else: - ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) + ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision) - ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) - ax.set_xlabel('Recall') - ax.set_ylabel('Precision') + ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean()) + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") ax.set_xlim(0, 1) ax.set_ylim(0, 1) - ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left') - ax.set_title('Precision-Recall Curve') + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title("Precision-Recall Curve") fig.savefig(save_dir, dpi=250) plt.close(fig) if on_plot: @@ -348,24 +472,24 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=N @plt_settings() -def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None): +def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None): """Plots a metric-confidence curve.""" fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) if 0 < len(names) < 21: # display per-class legend if < 21 classes for i, y in enumerate(py): - ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric) + ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric) else: - ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric) + ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric) y = smooth(py.mean(0), 0.05) - ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}') + ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}") ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_xlim(0, 1) ax.set_ylim(0, 1) - ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left') - ax.set_title(f'{ylabel}-Confidence Curve') + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title(f"{ylabel}-Confidence Curve") fig.savefig(save_dir, dpi=250) plt.close(fig) if on_plot: @@ -394,8 +518,8 @@ def compute_ap(recall, precision): mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) # Integrate area under curve - method = 'interp' # methods: 'continuous', 
'interp' - if method == 'interp': + method = "interp" # methods: 'continuous', 'interp' + if method == "interp": x = np.linspace(0, 1, 101) # 101-point interp (COCO) ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate else: # 'continuous' @@ -405,16 +529,9 @@ def compute_ap(recall, precision): return ap, mpre, mrec -def ap_per_class(tp, - conf, - pred_cls, - target_cls, - plot=False, - on_plot=None, - save_dir=Path(), - names=(), - eps=1e-16, - prefix=''): +def ap_per_class( + tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix="" +): """ Computes the average precision per class for object detection evaluation. @@ -432,14 +549,18 @@ def ap_per_class(tp, Returns: (tuple): A tuple of six arrays and one array of unique classes, where: - tp (np.ndarray): True positive counts for each class. - fp (np.ndarray): False positive counts for each class. - p (np.ndarray): Precision values at each confidence threshold. - r (np.ndarray): Recall values at each confidence threshold. - f1 (np.ndarray): F1-score values at each confidence threshold. - ap (np.ndarray): Average precision for each class at different IoU thresholds. - unique_classes (np.ndarray): An array of unique classes that have data. - + tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,). + fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,). + p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,). + r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,). + f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,). + ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10). + unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,). + p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000). + r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000). + f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000). + x (np.ndarray): X-axis values for the curves. Shape: (1000,). + prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000). 
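# Usage sketch (annotation, not part of the patch): compute_ap() integrates one
# precision-recall curve with the 101-point COCO interpolation. A minimal example on a
# synthetic monotone curve, assuming the ultralytics import path used in this diff:
import numpy as np
from ultralytics.utils.metrics import compute_ap

recall = np.linspace(0.0, 1.0, 50)
precision = 1.0 - 0.5 * recall              # precision decays as recall grows
ap, mpre, mrec = compute_ap(recall, precision)
print(round(ap, 3))                          # ~0.75 for this triangle-shaped curve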
""" # Sort by objectness @@ -451,8 +572,10 @@ def ap_per_class(tp, nc = unique_classes.shape[0] # number of classes, number of detections # Create Precision-Recall curve and compute AP for each class - px, py = np.linspace(0, 1, 1000), [] # for plotting - ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) + x, prec_values = np.linspace(0, 1, 1000), [] + + # Average precision, precision and recall curves + ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) for ci, c in enumerate(unique_classes): i = pred_cls == c n_l = nt[ci] # number of labels @@ -466,63 +589,66 @@ def ap_per_class(tp, # Recall recall = tpc / (n_l + eps) # recall curve - r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases + r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases # Precision precision = tpc / (tpc + fpc) # precision curve - p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score + p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score # AP from recall-precision curve for j in range(tp.shape[1]): ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) if plot and j == 0: - py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5 + prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5 + + prec_values = np.array(prec_values) # (nc, 1000) # Compute F1 (harmonic mean of precision and recall) - f1 = 2 * p * r / (p + r + eps) + f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps) names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data names = dict(enumerate(names)) # to dict if plot: - plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot) - plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot) - plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot) - plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot) + plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot) + plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot) + plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot) + plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot) - i = smooth(f1.mean(0), 0.1).argmax() # max F1 index - p, r, f1 = p[:, i], r[:, i], f1[:, i] + i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index + p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values tp = (r * nt).round() # true positives fp = (tp / (p + eps) - tp).round() # false positives - return tp, fp, p, r, f1, ap, unique_classes.astype(int) + return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values class Metric(SimpleClass): """ - Class for computing evaluation metrics for YOLOv8 model. + Class for computing evaluation metrics for YOLOv8 model. - Attributes: - p (list): Precision for each class. Shape: (nc,). - r (list): Recall for each class. Shape: (nc,). - f1 (list): F1 score for each class. Shape: (nc,). - all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10). - ap_class_index (list): Index of class for each AP score. Shape: (nc,). 
- nc (int): Number of classes. + Attributes: + p (list): Precision for each class. Shape: (nc,). + r (list): Recall for each class. Shape: (nc,). + f1 (list): F1 score for each class. Shape: (nc,). + all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10). + ap_class_index (list): Index of class for each AP score. Shape: (nc,). + nc (int): Number of classes. - Methods: - ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or []. - ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or []. - mp(): Mean precision of all classes. Returns: Float. - mr(): Mean recall of all classes. Returns: Float. - map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float. - map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float. - map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float. - mean_results(): Mean of results, returns mp, mr, map50, map. - class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i]. - maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,). - fitness(): Model fitness as a weighted combination of metrics. Returns: Float. - update(results): Update metric attributes with new evaluation results. - """ + Methods: + ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or []. + ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or []. + mp(): Mean precision of all classes. Returns: Float. + mr(): Mean recall of all classes. Returns: Float. + map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float. + map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float. + map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float. + mean_results(): Mean of results, returns mp, mr, map50, map. + class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i]. + maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,). + fitness(): Model fitness as a weighted combination of metrics. Returns: Float. + update(results): Update metric attributes with new evaluation results. + """ def __init__(self) -> None: + """Initializes a Metric instance for computing evaluation metrics for the YOLOv8 model.""" self.p = [] # (nc, ) self.r = [] # (nc, ) self.f1 = [] # (nc, ) @@ -576,7 +702,7 @@ class Metric(SimpleClass): Returns the mean Average Precision (mAP) at an IoU threshold of 0.5. Returns: - (float): The mAP50 at an IoU threshold of 0.5. + (float): The mAP at an IoU threshold of 0.5. """ return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 @@ -586,7 +712,7 @@ class Metric(SimpleClass): Returns the mean Average Precision (mAP) at an IoU threshold of 0.75. Returns: - (float): The mAP50 at an IoU threshold of 0.75. + (float): The mAP at an IoU threshold of 0.75. 
""" return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0 @@ -605,12 +731,12 @@ class Metric(SimpleClass): return [self.mp, self.mr, self.map50, self.map] def class_result(self, i): - """class-aware result, return p[i], r[i], ap50[i], ap[i].""" + """Class-aware result, return p[i], r[i], ap50[i], ap[i].""" return self.p[i], self.r[i], self.ap50[i], self.ap[i] @property def maps(self): - """mAP of each class.""" + """MAP of each class.""" maps = np.zeros(self.nc) + self.map for i, c in enumerate(self.ap_class_index): maps[c] = self.ap[i] @@ -623,10 +749,47 @@ class Metric(SimpleClass): def update(self, results): """ + Updates the evaluation metrics of the model with a new set of results. + Args: - results (tuple): A tuple of (p, r, ap, f1, ap_class) + results (tuple): A tuple containing the following evaluation metrics: + - p (list): Precision for each class. Shape: (nc,). + - r (list): Recall for each class. Shape: (nc,). + - f1 (list): F1 score for each class. Shape: (nc,). + - all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10). + - ap_class_index (list): Index of class for each AP score. Shape: (nc,). + + Side Effects: + Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based + on the values provided in the `results` tuple. """ - self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results + ( + self.p, + self.r, + self.f1, + self.all_ap, + self.ap_class_index, + self.p_curve, + self.r_curve, + self.f1_curve, + self.px, + self.prec_values, + ) = results + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [ + [self.px, self.prec_values, "Recall", "Precision"], + [self.px, self.f1_curve, "Confidence", "F1"], + [self.px, self.p_curve, "Confidence", "Precision"], + [self.px, self.r_curve, "Confidence", "Recall"], + ] class DetMetrics(SimpleClass): @@ -657,33 +820,39 @@ class DetMetrics(SimpleClass): fitness: Computes the fitness score based on the computed detection metrics. ap_class_index: Returns a list of class indices sorted by their average precision (AP) values. results_dict: Returns a dictionary that maps detection metric keys to their computed values. 
+ curves: TODO + curves_results: TODO """ - def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None: + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: + """Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names.""" self.save_dir = save_dir self.plot = plot self.on_plot = on_plot self.names = names self.box = Metric() - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "detect" def process(self, tp, conf, pred_cls, target_cls): """Process predicted results for object detection and update metrics.""" - results = ap_per_class(tp, - conf, - pred_cls, - target_cls, - plot=self.plot, - save_dir=self.save_dir, - names=self.names, - on_plot=self.on_plot)[2:] + results = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=self.on_plot, + )[2:] self.box.nc = len(self.names) self.box.update(results) @property def keys(self): """Returns a list of keys for accessing specific metrics.""" - return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)'] + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] def mean_results(self): """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" @@ -711,7 +880,17 @@ class DetMetrics(SimpleClass): @property def results_dict(self): """Returns dictionary of computed performance metrics and statistics.""" - return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness])) + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"] + + @property + def curves_results(self): + """Returns dictionary of computed performance metrics and statistics.""" + return self.box.curves_results class SegmentMetrics(SimpleClass): @@ -743,47 +922,53 @@ class SegmentMetrics(SimpleClass): results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score. """ - def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None: + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: + """Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names.""" self.save_dir = save_dir self.plot = plot self.on_plot = on_plot self.names = names self.box = Metric() self.seg = Metric() - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "segment" - def process(self, tp_b, tp_m, conf, pred_cls, target_cls): + def process(self, tp, tp_m, conf, pred_cls, target_cls): """ Processes the detection and segmentation metrics over the given set of predictions. Args: - tp_b (list): List of True Positive boxes. + tp (list): List of True Positive boxes. tp_m (list): List of True Positive masks. conf (list): List of confidence scores. pred_cls (list): List of predicted classes. target_cls (list): List of target classes. 
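# Usage sketch (annotation, not part of the patch): a minimal end-to-end run of
# DetMetrics with synthetic matches, assuming the ultralytics import path shown in this
# diff. tp marks, per prediction, whether it is a true positive at each of the 10 IoU
# thresholds 0.50:0.05:0.95; all values below are random placeholders.
import numpy as np
from ultralytics.utils.metrics import DetMetrics

rng = np.random.default_rng(0)
n_pred, n_gt = 80, 60
tp = rng.random((n_pred, 10)) > 0.5            # synthetic TP flags per IoU threshold
conf = rng.random(n_pred)                      # confidence scores
pred_cls = rng.integers(0, 2, n_pred)          # predicted class ids
target_cls = rng.integers(0, 2, n_gt)          # ground-truth class ids

dm = DetMetrics(names={0: "person", 1: "cart"})
dm.process(tp, conf, pred_cls, target_cls)
print(dm.results_dict)                         # precision/recall/mAP50/mAP50-95 + fitness
print(dm.curves)                               # names of the newly exposed per-class curves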
""" - results_mask = ap_per_class(tp_m, - conf, - pred_cls, - target_cls, - plot=self.plot, - on_plot=self.on_plot, - save_dir=self.save_dir, - names=self.names, - prefix='Mask')[2:] + results_mask = ap_per_class( + tp_m, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Mask", + )[2:] self.seg.nc = len(self.names) self.seg.update(results_mask) - results_box = ap_per_class(tp_b, - conf, - pred_cls, - target_cls, - plot=self.plot, - on_plot=self.on_plot, - save_dir=self.save_dir, - names=self.names, - prefix='Box')[2:] + results_box = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Box", + )[2:] self.box.nc = len(self.names) self.box.update(results_box) @@ -791,8 +976,15 @@ class SegmentMetrics(SimpleClass): def keys(self): """Returns a list of keys for accessing metrics.""" return [ - 'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)', - 'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)'] + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP50(B)", + "metrics/mAP50-95(B)", + "metrics/precision(M)", + "metrics/recall(M)", + "metrics/mAP50(M)", + "metrics/mAP50-95(M)", + ] def mean_results(self): """Return the mean metrics for bounding box and segmentation results.""" @@ -820,7 +1012,26 @@ class SegmentMetrics(SimpleClass): @property def results_dict(self): """Returns results of object detection model for evaluation.""" - return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness])) + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [ + "Precision-Recall(B)", + "F1-Confidence(B)", + "Precision-Confidence(B)", + "Recall-Confidence(B)", + "Precision-Recall(M)", + "F1-Confidence(M)", + "Precision-Confidence(M)", + "Recall-Confidence(M)", + ] + + @property + def curves_results(self): + """Returns dictionary of computed performance metrics and statistics.""" + return self.box.curves_results + self.seg.curves_results class PoseMetrics(SegmentMetrics): @@ -852,7 +1063,8 @@ class PoseMetrics(SegmentMetrics): results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score. """ - def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None: + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: + """Initialize the PoseMetrics class with directory path, class names, and plotting options.""" super().__init__(save_dir, plot, names) self.save_dir = save_dir self.plot = plot @@ -860,40 +1072,45 @@ class PoseMetrics(SegmentMetrics): self.names = names self.box = Metric() self.pose = Metric() - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "pose" - def process(self, tp_b, tp_p, conf, pred_cls, target_cls): + def process(self, tp, tp_p, conf, pred_cls, target_cls): """ Processes the detection and pose metrics over the given set of predictions. Args: - tp_b (list): List of True Positive boxes. + tp (list): List of True Positive boxes. tp_p (list): List of True Positive keypoints. conf (list): List of confidence scores. 
pred_cls (list): List of predicted classes. target_cls (list): List of target classes. """ - results_pose = ap_per_class(tp_p, - conf, - pred_cls, - target_cls, - plot=self.plot, - on_plot=self.on_plot, - save_dir=self.save_dir, - names=self.names, - prefix='Pose')[2:] + results_pose = ap_per_class( + tp_p, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Pose", + )[2:] self.pose.nc = len(self.names) self.pose.update(results_pose) - results_box = ap_per_class(tp_b, - conf, - pred_cls, - target_cls, - plot=self.plot, - on_plot=self.on_plot, - save_dir=self.save_dir, - names=self.names, - prefix='Box')[2:] + results_box = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Box", + )[2:] self.box.nc = len(self.names) self.box.update(results_box) @@ -901,8 +1118,15 @@ class PoseMetrics(SegmentMetrics): def keys(self): """Returns list of evaluation metric keys.""" return [ - 'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)', - 'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)'] + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP50(B)", + "metrics/mAP50-95(B)", + "metrics/precision(P)", + "metrics/recall(P)", + "metrics/mAP50(P)", + "metrics/mAP50-95(P)", + ] def mean_results(self): """Return the mean results of box and pose.""" @@ -922,6 +1146,25 @@ class PoseMetrics(SegmentMetrics): """Computes classification metrics and speed using the `targets` and `pred` inputs.""" return self.pose.fitness() + self.box.fitness() + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [ + "Precision-Recall(B)", + "F1-Confidence(B)", + "Precision-Confidence(B)", + "Recall-Confidence(B)", + "Precision-Recall(P)", + "F1-Confidence(P)", + "Precision-Confidence(P)", + "Recall-Confidence(P)", + ] + + @property + def curves_results(self): + """Returns dictionary of computed performance metrics and statistics.""" + return self.box.curves_results + self.pose.curves_results + class ClassifyMetrics(SimpleClass): """ @@ -942,9 +1185,11 @@ class ClassifyMetrics(SimpleClass): """ def __init__(self) -> None: + """Initialize a ClassifyMetrics instance.""" self.top1 = 0 self.top5 = 0 - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "classify" def process(self, targets, pred): """Target classes and predicted classes.""" @@ -961,9 +1206,87 @@ class ClassifyMetrics(SimpleClass): @property def results_dict(self): """Returns a dictionary with model's performance metrics and fitness score.""" - return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness])) + return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness])) @property def keys(self): """Returns a list of keys for the results_dict property.""" - return ['metrics/accuracy_top1', 'metrics/accuracy_top5'] + return ["metrics/accuracy_top1", "metrics/accuracy_top5"] + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + +class OBBMetrics(SimpleClass): + def __init__(self, save_dir=Path("."), 
plot=False, on_plot=None, names=()) -> None: + self.save_dir = save_dir + self.plot = plot + self.on_plot = on_plot + self.names = names + self.box = Metric() + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + + def process(self, tp, conf, pred_cls, target_cls): + """Process predicted results for object detection and update metrics.""" + results = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=self.on_plot, + )[2:] + self.box.nc = len(self.names) + self.box.update(results) + + @property + def keys(self): + """Returns a list of keys for accessing specific metrics.""" + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] + + def mean_results(self): + """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" + return self.box.mean_results() + + def class_result(self, i): + """Return the result of evaluating the performance of an object detection model on a specific class.""" + return self.box.class_result(i) + + @property + def maps(self): + """Returns mean Average Precision (mAP) scores per class.""" + return self.box.maps + + @property + def fitness(self): + """Returns the fitness of box object.""" + return self.box.fitness() + + @property + def ap_class_index(self): + """Returns the average precision index per class.""" + return self.box.ap_class_index + + @property + def results_dict(self): + """Returns dictionary of computed performance metrics and statistics.""" + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index 9089d0f..edbb103 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -12,6 +12,7 @@ import torch.nn.functional as F import torchvision from ultralytics.utils import LOGGER +from ultralytics.utils.metrics import batch_probiou class Profile(contextlib.ContextDecorator): @@ -22,22 +23,24 @@ class Profile(contextlib.ContextDecorator): ```python from ultralytics.utils.ops import Profile - with Profile() as dt: + with Profile(device=device) as dt: pass # slow operation here print(dt) # prints "Elapsed time is 9.5367431640625e-07 s" ``` """ - def __init__(self, t=0.0): + def __init__(self, t=0.0, device: torch.device = None): """ Initialize the Profile class. Args: t (float): Initial time. Defaults to 0.0. + device (torch.device): Devices used for model inference. Defaults to None (cpu). """ self.t = t - self.cuda = torch.cuda.is_available() + self.device = device + self.cuda = bool(device and str(device).startswith("cuda")) def __enter__(self): """Start timing.""" @@ -50,12 +53,13 @@ class Profile(contextlib.ContextDecorator): self.t += self.dt # accumulate dt def __str__(self): - return f'Elapsed time is {self.t} s' + """Returns a human-readable string representing the accumulated elapsed time in the profiler.""" + return f"Elapsed time is {self.t} s" def time(self): """Get current time.""" if self.cuda: - torch.cuda.synchronize() + torch.cuda.synchronize(self.device) return time.time() @@ -71,18 +75,21 @@ def segment2box(segment, width=640, height=640): Returns: (np.ndarray): the minimum and maximum x and y values of the segment. 
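# Usage sketch (annotation, not part of the patch): the profiler now synchronizes only
# the device it was given. Minimal example, assuming the ultralytics.utils.ops import
# path from this diff; the CUDA branch is only taken when a cuda device is passed.
import torch
from ultralytics.utils.ops import Profile

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
with Profile(device=device) as dt:
    torch.randn(256, 256, device=device) @ torch.randn(256, 256, device=device)
print(dt)          # "Elapsed time is ... s"; dt.t accumulates across uses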
""" - # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy) x, y = segment.T # segment xy inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) - x, y, = x[inside], y[inside] - return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros( - 4, dtype=segment.dtype) # xyxy + x = x[inside] + y = y[inside] + return ( + np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) + if any(x) + else np.zeros(4, dtype=segment.dtype) + ) # xyxy -def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True): +def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False): """ - Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in - (img1_shape) to the shape of a different image (img0_shape). + Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally + specified in (img1_shape) to the shape of a different image (img0_shape). Args: img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width). @@ -92,24 +99,29 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True): calculated based on the size difference between the two images. padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular rescaling. + xywh (bool): The box format is xywh or not, default=False. Returns: boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) """ if ratio_pad is None: # calculate from img0_shape gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new - pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round( - (img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) # wh padding + pad = ( + round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), + round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1), + ) # wh padding else: gain = ratio_pad[0][0] pad = ratio_pad[1] if padding: - boxes[..., [0, 2]] -= pad[0] # x padding - boxes[..., [1, 3]] -= pad[1] # y padding + boxes[..., 0] -= pad[0] # x padding + boxes[..., 1] -= pad[1] # y padding + if not xywh: + boxes[..., 2] -= pad[0] # x padding + boxes[..., 3] -= pad[1] # y padding boxes[..., :4] /= gain - clip_boxes(boxes, img0_shape) - return boxes + return clip_boxes(boxes, img0_shape) def make_divisible(x, divisor): @@ -128,19 +140,41 @@ def make_divisible(x, divisor): return math.ceil(x / divisor) * divisor +def nms_rotated(boxes, scores, threshold=0.45): + """ + NMS for obbs, powered by probiou and fast-nms. + + Args: + boxes (torch.Tensor): (N, 5), xywhr. + scores (torch.Tensor): (N, ). + threshold (float): IoU threshold. 
+ + Returns: + """ + if len(boxes) == 0: + return np.empty((0,), dtype=np.int8) + sorted_idx = torch.argsort(scores, descending=True) + boxes = boxes[sorted_idx] + ious = batch_probiou(boxes, boxes).triu_(diagonal=1) + pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1) + return sorted_idx[pick] + + def non_max_suppression( - prediction, - conf_thres=0.25, - iou_thres=0.45, - classes=None, - agnostic=False, - multi_label=False, - labels=(), - max_det=300, - nc=0, # number of classes (optional) - max_time_img=0.05, - max_nms=30000, - max_wh=7680, + prediction, + conf_thres=0.25, + iou_thres=0.45, + classes=None, + agnostic=False, + multi_label=False, + labels=(), + max_det=300, + nc=0, # number of classes (optional) + max_time_img=0.05, + max_nms=30000, + max_wh=7680, + in_place=True, + rotated=False, ): """ Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box. @@ -164,7 +198,8 @@ def non_max_suppression( nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks. max_time_img (float): The maximum time (seconds) for processing one image. max_nms (int): The maximum number of boxes into torchvision.ops.nms(). - max_wh (int): The maximum box width and height in pixels + max_wh (int): The maximum box width and height in pixels. + in_place (bool): If True, the input prediction tensor will be modified in place. Returns: (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of @@ -173,15 +208,11 @@ def non_max_suppression( """ # Checks - assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0' - assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0' + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) prediction = prediction[0] # select only inference output - device = prediction.device - mps = 'mps' in device.type # Apple MPS - if mps: # MPS not fully supported yet, convert tensors to CPU before NMS - prediction = prediction.cpu() bs = prediction.shape[0] # batch size nc = nc or (prediction.shape[1] - 4) # number of classes nm = prediction.shape[1] - nc - 4 @@ -190,11 +221,15 @@ def non_max_suppression( # Settings # min_wh = 2 # (pixels) minimum box width and height - time_limit = 0.5 + max_time_img * bs # seconds to quit after + time_limit = 2.0 + max_time_img * bs # seconds to quit after multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) - prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + if not rotated: + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy t = time.time() output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs @@ -204,7 +239,7 @@ def non_max_suppression( x = x[xc[xi]] # confidence # Cat apriori labels if autolabelling - if labels and len(labels[xi]): + if labels and len(labels[xi]) and not rotated: lb = labels[xi] v = torch.zeros((len(lb), nc + nm + 4), device=x.device) v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box @@ 
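# Usage sketch (annotation, not part of the patch): the new rotated NMS, assuming
# ultralytics.utils.ops as in this diff. Boxes are (cx, cy, w, h, r) with r in radians;
# the two overlapping boxes collapse to the higher-scoring one.
import math
import torch
from ultralytics.utils.ops import nms_rotated

boxes = torch.tensor([[50.0, 50.0, 40.0, 20.0, 0.0],
                      [51.0, 50.0, 40.0, 20.0, math.pi / 36],   # ~same object, slight rotation
                      [150.0, 150.0, 30.0, 30.0, 0.0]])         # a separate object
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms_rotated(boxes, scores, threshold=0.45)
print(keep)    # indices of the surviving boxes, e.g. tensor([0, 2])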
-238,8 +273,13 @@ def non_max_suppression( # Batched NMS c = x[:, 5:6] * (0 if agnostic else max_wh) # classes - boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores - i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + scores = x[:, 4] # scores + if rotated: + boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr + i = nms_rotated(boxes, scores, iou_thres) + else: + boxes = x[:, :4] + c # boxes (offset by class) + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS i = i[:max_det] # limit detections # # Experimental @@ -247,7 +287,7 @@ def non_max_suppression( # if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) # # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) # from .metrics import box_iou - # iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + # iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix # weights = iou * scores[None] # box weights # x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes # redundant = True # require redundant detections @@ -255,10 +295,8 @@ def non_max_suppression( # i = i[iou.sum(1) > 1] # require redundancy output[xi] = x[i] - if mps: - output[xi] = output[xi].to(device) if (time.time() - t) > time_limit: - LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded') + LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded") break # time limit exceeded return output @@ -269,17 +307,21 @@ def clip_boxes(boxes, shape): Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape. Args: - boxes (torch.Tensor): the bounding boxes to clip - shape (tuple): the shape of the image + boxes (torch.Tensor): the bounding boxes to clip + shape (tuple): the shape of the image + + Returns: + (torch.Tensor | numpy.ndarray): Clipped boxes """ - if isinstance(boxes, torch.Tensor): # faster individually - boxes[..., 0].clamp_(0, shape[1]) # x1 - boxes[..., 1].clamp_(0, shape[0]) # y1 - boxes[..., 2].clamp_(0, shape[1]) # x2 - boxes[..., 3].clamp_(0, shape[0]) # y2 + if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1 + boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1 + boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2 + boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2 else: # np.array (faster grouped) boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2 boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2 + return boxes def clip_coords(coords, shape): @@ -291,19 +333,20 @@ def clip_coords(coords, shape): shape (tuple): A tuple of integers representing the size of the image in the format (height, width). Returns: - (None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries. 
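# Usage sketch (annotation, not part of the patch): clip_boxes()/clip_coords() now return
# the clipped result (and avoid in-place clamp_, noted above as an Apple MPS workaround),
# so calls can be chained. Minimal example, assuming ultralytics.utils.ops as in this diff:
import torch
from ultralytics.utils.ops import clip_boxes

boxes = torch.tensor([[-5.0, 10.0, 700.0, 500.0]])
clipped = clip_boxes(boxes, (480, 640))       # shape is (height, width)
print(clipped)                                # [[0., 10., 640., 480.]]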
+ (torch.Tensor | numpy.ndarray): Clipped coordinates """ - if isinstance(coords, torch.Tensor): # faster individually - coords[..., 0].clamp_(0, shape[1]) # x - coords[..., 1].clamp_(0, shape[0]) # y + if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y else: # np.array (faster grouped) coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y + return coords def scale_image(masks, im0_shape, ratio_pad=None): """ - Takes a mask, and resizes it to the original image size + Takes a mask, and resizes it to the original image size. Args: masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3]. @@ -321,7 +364,7 @@ def scale_image(masks, im0_shape, ratio_pad=None): gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding else: - gain = ratio_pad[0][0] + # gain = ratio_pad[0][0] pad = ratio_pad[1] top, left = int(pad[1]), int(pad[0]) # y, x bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0]) @@ -347,7 +390,7 @@ def xyxy2xywh(x): Returns: y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format. """ - assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}' + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center @@ -367,7 +410,7 @@ def xywh2xyxy(x): Returns: y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. """ - assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}' + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy dw = x[..., 2] / 2 # half-width dh = x[..., 3] / 2 # half-height @@ -392,7 +435,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. """ - assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}' + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y @@ -403,8 +446,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): """ - Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. - x, y, width and height are normalized to image dimensions + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, + width and height are normalized to image dimensions. 
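# Usage sketch (annotation, not part of the patch): the box-format converters above keep
# their behaviour; a quick round-trip example, assuming ultralytics.utils.ops as in this
# diff:
import torch
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh

xywh = torch.tensor([[50.0, 40.0, 20.0, 10.0]])   # (cx, cy, w, h)
xyxy = xywh2xyxy(xywh)                            # [[40., 35., 60., 45.]]
print(xyxy, xyxy2xywh(xyxy))                      # round-trips back to the input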
Args: x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. @@ -417,8 +460,8 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format """ if clip: - clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip - assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}' + x = clip_boxes(x, (h - eps, w - eps)) + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center @@ -445,7 +488,7 @@ def xywh2ltwh(x): def xyxy2ltwh(x): """ - Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right + Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right. Args: x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format @@ -461,7 +504,7 @@ def xyxy2ltwh(x): def ltwh2xywh(x): """ - Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center + Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center. Args: x (torch.Tensor): the input tensor @@ -477,7 +520,8 @@ def ltwh2xywh(x): def xyxyxyxy2xywhr(corners): """ - Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. + Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are + expected in degrees from 0 to 90. Args: corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8). @@ -485,66 +529,53 @@ def xyxyxyxy2xywhr(corners): Returns: (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5). """ - is_numpy = isinstance(corners, np.ndarray) - atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt) - - x1, y1, x2, y2, x3, y3, x4, y4 = corners.T - cx = (x1 + x3) / 2 - cy = (y1 + y3) / 2 - dx21 = x2 - x1 - dy21 = y2 - y1 - - w = sqrt(dx21 ** 2 + dy21 ** 2) - h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2) - - rotation = atan2(-dy21, dx21) - rotation *= 180.0 / math.pi # radians to degrees - - return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1) + is_torch = isinstance(corners, torch.Tensor) + points = corners.cpu().numpy() if is_torch else corners + points = points.reshape(len(corners), -1, 2) + rboxes = [] + for pts in points: + # NOTE: Use cv2.minAreaRect to get accurate xywhr, + # especially some objects are cut off by augmentations in dataloader. + (x, y), (w, h), angle = cv2.minAreaRect(pts) + rboxes.append([x, y, w, h, angle / 180 * np.pi]) + return ( + torch.tensor(rboxes, device=corners.device, dtype=corners.dtype) + if is_torch + else np.asarray(rboxes, dtype=points.dtype) + ) # rboxes -def xywhr2xyxyxyxy(center): +def xywhr2xyxyxyxy(rboxes): """ - Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. + Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should + be in degrees from 0 to 90. Args: - center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5). 
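# Usage sketch (annotation, not part of the patch): xyxyxyxy2xywhr() now goes through
# cv2.minAreaRect so the recovered (cx, cy, w, h, r) stays accurate even for polygons
# cut off by augmentation, and r comes back in radians. Minimal example, assuming
# ultralytics.utils.ops as in this diff:
import torch
from ultralytics.utils.ops import xyxyxyxy2xywhr

corners = torch.tensor([[10.0, 10.0, 30.0, 10.0, 30.0, 30.0, 10.0, 30.0]])  # a 20x20 square
rbox = xyxyxyxy2xywhr(corners)
print(rbox)   # [[20., 20., 20., 20., r]] with the angle r in radians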
+ rboxes (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5). Returns: - (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8). + (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2). """ - is_numpy = isinstance(center, np.ndarray) + is_numpy = isinstance(rboxes, np.ndarray) cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin) - cx, cy, w, h, rotation = center.T - rotation *= math.pi / 180.0 # degrees to radians - - dx = w / 2 - dy = h / 2 - - cos_rot = cos(rotation) - sin_rot = sin(rotation) - dx_cos_rot = dx * cos_rot - dx_sin_rot = dx * sin_rot - dy_cos_rot = dy * cos_rot - dy_sin_rot = dy * sin_rot - - x1 = cx - dx_cos_rot - dy_sin_rot - y1 = cy + dx_sin_rot - dy_cos_rot - x2 = cx + dx_cos_rot - dy_sin_rot - y2 = cy - dx_sin_rot - dy_cos_rot - x3 = cx + dx_cos_rot + dy_sin_rot - y3 = cy - dx_sin_rot + dy_cos_rot - x4 = cx - dx_cos_rot + dy_sin_rot - y4 = cy + dx_sin_rot + dy_cos_rot - - return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack( - (x1, y1, x2, y2, x3, y3, x4, y4), dim=1) + ctr = rboxes[..., :2] + w, h, angle = (rboxes[..., i : i + 1] for i in range(2, 5)) + cos_value, sin_value = cos(angle), sin(angle) + vec1 = [w / 2 * cos_value, w / 2 * sin_value] + vec2 = [-h / 2 * sin_value, h / 2 * cos_value] + vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1) + vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1) + pt1 = ctr + vec1 + vec2 + pt2 = ctr + vec1 - vec2 + pt3 = ctr - vec1 - vec2 + pt4 = ctr - vec1 + vec2 + return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2) def ltwh2xyxy(x): """ - It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right. Args: x (np.ndarray | torch.Tensor): the input image @@ -590,8 +621,9 @@ def resample_segments(segments, n=1000): s = np.concatenate((s, s[0:1, :]), axis=0) x = np.linspace(0, len(s) - 1, n) xp = np.arange(len(s)) - segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], - dtype=np.float32).reshape(2, -1).T # segment xy + segments[i] = ( + np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T + ) # segment xy return segments @@ -606,7 +638,7 @@ def crop_mask(masks, boxes): Returns: (torch.Tensor): The masks are being cropped to the bounding box. """ - n, h, w = masks.shape + _, h, w = masks.shape x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) @@ -616,8 +648,8 @@ def crop_mask(masks, boxes): def process_mask_upsample(protos, masks_in, bboxes, shape): """ - Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher - quality but is slower. + Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality + but is slower. 
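The rewritten xyxyxyxy2xywhr above delegates to cv2.minAreaRect (more robust when corners are clipped by augmentation) and stores the angle in radians, while the new xywhr2xyxyxyxy is fully vectorized and returns (..., 4, 2) corner points; note that both docstrings still speak of degrees even though the code works in radians. A NumPy-only round-trip sketch of the same idea (function names here are illustrative, not from the repo):

import cv2
import numpy as np

def corners2xywhr(corners):
    """(n, 8) corner points -> (n, 5) [cx, cy, w, h, angle] with the angle in radians."""
    pts = corners.reshape(len(corners), -1, 2).astype(np.float32)
    rboxes = []
    for p in pts:
        (x, y), (w, h), angle = cv2.minAreaRect(p)     # minAreaRect reports the angle in degrees
        rboxes.append([x, y, w, h, angle / 180 * np.pi])
    return np.asarray(rboxes, dtype=np.float32)

def xywhr2corners(rboxes):
    """(n, 5) [cx, cy, w, h, angle_rad] -> (n, 4, 2) corner points."""
    ctr = rboxes[:, None, :2]
    w, h, angle = (rboxes[:, i:i + 1] for i in (2, 3, 4))
    cos, sin = np.cos(angle), np.sin(angle)
    vec1 = np.concatenate([w / 2 * cos, w / 2 * sin], axis=-1)[:, None]   # half-extent along the box x-axis
    vec2 = np.concatenate([-h / 2 * sin, h / 2 * cos], axis=-1)[:, None]  # half-extent along the box y-axis
    return np.concatenate([ctr + vec1 + vec2, ctr + vec1 - vec2, ctr - vec1 - vec2, ctr - vec1 + vec2], axis=1)

square = np.array([[10, 10, 30, 10, 30, 20, 10, 20]], dtype=np.float32)  # axis-aligned 20x10 box
rbox = corners2xywhr(square)      # roughly [[20, 15, 20, 10, 0]] (or the 90-degree-swapped equivalent)
corners = xywhr2corners(rbox)     # the same four corners, possibly in a different cyclic order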
Args: protos (torch.Tensor): [mask_dim, mask_h, mask_w] @@ -630,7 +662,7 @@ def process_mask_upsample(protos, masks_in, bboxes, shape): """ c, mh, mw = protos.shape # CHW masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) - masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW + masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW masks = crop_mask(masks, bboxes) # CHW return masks.gt_(0.5) @@ -654,16 +686,18 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False): c, mh, mw = protos.shape # CHW ih, iw = shape masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW + width_ratio = mw / iw + height_ratio = mh / ih downsampled_bboxes = bboxes.clone() - downsampled_bboxes[:, 0] *= mw / iw - downsampled_bboxes[:, 2] *= mw / iw - downsampled_bboxes[:, 3] *= mh / ih - downsampled_bboxes[:, 1] *= mh / ih + downsampled_bboxes[:, 0] *= width_ratio + downsampled_bboxes[:, 2] *= width_ratio + downsampled_bboxes[:, 3] *= height_ratio + downsampled_bboxes[:, 1] *= height_ratio masks = crop_mask(masks, downsampled_bboxes) # CHW if upsample: - masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW + masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW return masks.gt_(0.5) @@ -707,13 +741,13 @@ def scale_masks(masks, shape, padding=True): bottom, right = (int(mh - pad[1]), int(mw - pad[0])) masks = masks[..., top:bottom, left:right] - masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW + masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW return masks def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True): """ - Rescale segment coordinates (xy) from img1_shape to img0_shape + Rescale segment coordinates (xy) from img1_shape to img0_shape. Args: img1_shape (tuple): The shape of the image that the coords are from. @@ -739,14 +773,32 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False coords[..., 1] -= pad[1] # y padding coords[..., 0] /= gain coords[..., 1] /= gain - clip_coords(coords, img0_shape) + coords = clip_coords(coords, img0_shape) if normalize: coords[..., 0] /= img0_shape[1] # width coords[..., 1] /= img0_shape[0] # height return coords -def masks2segments(masks, strategy='largest'): +def regularize_rboxes(rboxes): + """ + Regularize rotated boxes in range [0, pi/2]. + + Args: + rboxes (torch.Tensor): (N, 5), xywhr. + + Returns: + (torch.Tensor): The regularized boxes. 
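process_mask above now factors the mask-to-input ratios out into width_ratio and height_ratio before shrinking the boxes to prototype resolution; the cropping itself is unchanged. A small sketch of that box-rescaling step, assuming xyxy boxes given at network-input resolution:

import torch

def downsample_boxes(bboxes, proto_hw, input_hw):
    """Scale xyxy boxes from network-input resolution to mask-prototype resolution."""
    mh, mw = proto_hw
    ih, iw = input_hw
    width_ratio, height_ratio = mw / iw, mh / ih
    boxes = bboxes.clone()
    boxes[:, [0, 2]] *= width_ratio   # x1, x2
    boxes[:, [1, 3]] *= height_ratio  # y1, y2
    return boxes

boxes = torch.tensor([[64.0, 32.0, 320.0, 256.0]])         # xyxy at 640x640 input
print(downsample_boxes(boxes, (160, 160), (640, 640)))     # tensor([[16.,  8., 80., 64.]])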
+ """ + x, y, w, h, t = rboxes.unbind(dim=-1) + # Swap edge and angle if h >= w + w_ = torch.where(w > h, w, h) + h_ = torch.where(w > h, h, w) + t = torch.where(w > h, t, t + math.pi / 2) % math.pi + return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes + + +def masks2segments(masks, strategy="largest"): """ It takes a list of masks(n,h,w) and returns a list of segments(n,xy) @@ -758,16 +810,16 @@ def masks2segments(masks, strategy='largest'): segments (List): list of segment masks """ segments = [] - for x in masks.int().cpu().numpy().astype('uint8'): + for x in masks.int().cpu().numpy().astype("uint8"): c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] if c: - if strategy == 'concat': # concatenate all segments + if strategy == "concat": # concatenate all segments c = np.concatenate([x.reshape(-1, 2) for x in c]) - elif strategy == 'largest': # select largest segment + elif strategy == "largest": # select largest segment c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) else: c = np.zeros((0, 2)) # no segments found - segments.append(c.astype('float32')) + segments.append(c.astype("float32")) return segments @@ -794,4 +846,19 @@ def clean_str(s): Returns: (str): a string with special characters replaced by an underscore _ """ - return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s) + return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) + +def v10postprocess(preds, max_det, nc=80): + assert(4 + nc == preds.shape[-1]) + boxes, scores = preds.split([4, nc], dim=-1) + max_scores = scores.amax(dim=-1) + max_scores, index = torch.topk(max_scores, max_det, dim=-1) + index = index.unsqueeze(-1) + boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) + scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) + + scores, index = torch.topk(scores.flatten(1), max_det, dim=-1) + labels = index % nc + index = index // nc + boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) + return boxes, scores, labels \ No newline at end of file diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py index a145763..d438407 100644 --- a/ultralytics/utils/patches.py +++ b/ultralytics/utils/patches.py @@ -1,8 +1,7 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -Monkey patches to update/extend functionality of existing functions -""" +"""Monkey patches to update/extend functionality of existing functions.""" +import time from pathlib import Path import cv2 @@ -14,7 +13,8 @@ _imshow = cv2.imshow # copy to avoid recursion errors def imread(filename: str, flags: int = cv2.IMREAD_COLOR): - """Read an image from a file. + """ + Read an image from a file. Args: filename (str): Path to the file to read. @@ -27,7 +27,8 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR): def imwrite(filename: str, img: np.ndarray, params=None): - """Write an image to a file. + """ + Write an image to a file. Args: filename (str): Path to the file to write. @@ -45,31 +46,43 @@ def imwrite(filename: str, img: np.ndarray, params=None): def imshow(winname: str, mat: np.ndarray): - """Displays an image in the specified window. + """ + Displays an image in the specified window. Args: winname (str): Name of the window. mat (np.ndarray): Image to be shown. 
""" - _imshow(winname.encode('unicode_escape').decode(), mat) + _imshow(winname.encode("unicode_escape").decode(), mat) # PyTorch functions ---------------------------------------------------------------------------------------------------- _torch_save = torch.save # copy to avoid recursion errors -def torch_save(*args, **kwargs): - """Use dill (if exists) to serialize the lambda functions where pickle does not do this. +def torch_save(*args, use_dill=True, **kwargs): + """ + Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and + exponential standoff in case of save failure. Args: *args (tuple): Positional arguments to pass to torch.save. - **kwargs (dict): Keyword arguments to pass to torch.save. + use_dill (bool): Whether to try using dill for serialization if available. Defaults to True. + **kwargs (any): Keyword arguments to pass to torch.save. """ try: - import dill as pickle # noqa - except ImportError: + assert use_dill + import dill as pickle + except (AssertionError, ImportError): import pickle - if 'pickle_module' not in kwargs: - kwargs['pickle_module'] = pickle # noqa - return _torch_save(*args, **kwargs) + if "pickle_module" not in kwargs: + kwargs["pickle_module"] = pickle + + for i in range(4): # 3 retries + try: + return _torch_save(*args, **kwargs) + except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan + if i == 3: + raise e + time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 6237f13..d0215ba 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -13,7 +13,6 @@ from PIL import Image, ImageDraw, ImageFont from PIL import __version__ as pil_version from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded - from .checks import check_font, check_version, is_ascii from .files import increment_path @@ -28,20 +27,60 @@ class Colors: Attributes: palette (list of tuple): List of RGB color values. n (int): The number of colors in the palette. - pose_palette (np.array): A specific color palette array with dtype np.uint8. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. 
""" def __init__(self): """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" - hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', - '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') - self.palette = [self.hex2rgb(f'#{c}') for c in hexs] + hexs = ( + "FF3838", + "FF9D97", + "FF701F", + "FFB21D", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "2C99A8", + "00C2FF", + "344593", + "6473FF", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] self.n = len(self.palette) - self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255], - [153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255], - [255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102], - [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]], - dtype=np.uint8) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) def __call__(self, i, bgr=False): """Converts hex color codes to RGB values.""" @@ -51,7 +90,7 @@ class Colors: @staticmethod def hex2rgb(h): """Converts hex color codes to RGB values (i.e. default PIL order).""" - return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) + return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4)) colors = Colors() # create instance for 'from utils.plots import colors' @@ -71,65 +110,99 @@ class Annotator: kpt_color (List[int]): Color palette for keypoints. """ - def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs.""" - assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.' non_ascii = not is_ascii(example) # non-latin labels, i.e. 
asian, arabic, cyrillic - self.pil = pil or non_ascii + input_is_pil = isinstance(im, Image.Image) + self.pil = pil or non_ascii or input_is_pil + self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2) if self.pil: # use PIL - self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) + self.im = im if input_is_pil else Image.fromarray(im) self.draw = ImageDraw.Draw(self.im) try: - font = check_font('Arial.Unicode.ttf' if non_ascii else font) + font = check_font("Arial.Unicode.ttf" if non_ascii else font) size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) self.font = ImageFont.truetype(str(font), size) except Exception: self.font = ImageFont.load_default() # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string) - if check_version(pil_version, '9.2.0'): + if check_version(pil_version, "9.2.0"): self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height else: # use cv2 - self.im = im - self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." + self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale # Pose - self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], - [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]] + self.skeleton = [ + [16, 14], + [14, 12], + [17, 15], + [15, 13], + [12, 13], + [6, 12], + [7, 13], + [6, 7], + [6, 8], + [7, 9], + [8, 10], + [9, 11], + [2, 3], + [1, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + [5, 7], + ] self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]] self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]] - def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)): + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): """Add one xyxy box to image with label.""" if isinstance(box, torch.Tensor): box = box.tolist() if self.pil or not is_ascii(label): - self.draw.rectangle(box, width=self.lw, outline=color) # box + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: w, h = self.font.getsize(label) # text width, height - outside = box[1] - h >= 0 # label fits outside box + outside = p1[1] - h >= 0 # label fits outside box self.draw.rectangle( - (box[0], box[1] - h if outside else box[1], box[0] + w + 1, - box[1] + 1 if outside else box[1] + h + 1), + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), fill=color, ) # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 - self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font) + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) else: # cv2 - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if rotated: + p1 = [int(b) for b in box[0]] + # NOTE: cv2-version polylines needs np.asarray type. 
+ cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) + else: + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) if label: - tf = max(self.lw - 1, 1) # font thickness - w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height outside = p1[1] - h >= 3 p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText(self.im, - label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.lw / 3, - txt_color, - thickness=tf, - lineType=cv2.LINE_AA) + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False): """ @@ -154,13 +227,13 @@ class Annotator: masks = masks.unsqueeze(3) # shape(n,h,w,1) masks_color = masks * (colors * alpha) # shape(n,h,w,3) - inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1) + inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1) mcs = masks_color.max(dim=0).values # shape(n,h,w,3) im_gpu = im_gpu.flip(dims=[0]) # flip channel im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3) - im_gpu = im_gpu * inv_alph_masks[-1] + mcs - im_mask = (im_gpu * 255) + im_gpu = im_gpu * inv_alpha_masks[-1] + mcs + im_mask = im_gpu * 255 im_mask_np = im_mask.byte().cpu().numpy() self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape) if self.pil: @@ -178,13 +251,14 @@ class Annotator: kpt_line (bool, optional): If True, the function will draw lines connecting keypoints for human pose. Default is True. - Note: `kpt_line=True` currently only supports human pose plotting. + Note: + `kpt_line=True` currently only supports human pose plotting. 
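box_label above gains a rotated flag: with rotated=True the box argument is a set of four corner points and is drawn with draw.polygon (PIL) or cv2.polylines (cv2) instead of an axis-aligned rectangle. A usage sketch under the package layout shown in this diff (the import paths and call are assumptions based on these hunks, not a tested snippet):

import numpy as np
from ultralytics.utils import ops
from ultralytics.utils.plotting import Annotator

im = np.zeros((640, 640, 3), dtype=np.uint8)
rbox = np.array([[320.0, 320.0, 200.0, 100.0, 0.4]])   # cx, cy, w, h, angle (radians)
corners = ops.xywhr2xyxyxyxy(rbox)[0]                  # (4, 2) corner points
Annotator(im, line_width=2).box_label(corners, label="obb", color=(0, 255, 0), rotated=True)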
""" if self.pil: # Convert to numpy first self.im = np.asarray(self.im).copy() nkpt, ndim = kpts.shape - is_pose = nkpt == 17 and ndim == 3 + is_pose = nkpt == 17 and ndim in {2, 3} kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting for i, k in enumerate(kpts): color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i) @@ -219,9 +293,9 @@ class Annotator: """Add rectangle to image (PIL-only).""" self.draw.rectangle(xy, fill, outline, width) - def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False): + def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False): """Adds text to an image using PIL or cv2.""" - if anchor == 'bottom': # start y from font bottom + if anchor == "bottom": # start y from font bottom w, h = self.font.getsize(text) # text width, height xy[1] += 1 - h if self.pil: @@ -230,8 +304,8 @@ class Annotator: self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color) # Using `txt_color` for background and draw fg with white color txt_color = (255, 255, 255) - if '\n' in text: - lines = text.split('\n') + if "\n" in text: + lines = text.split("\n") _, h = self.font.getsize(text) for line in lines: self.draw.text(xy, line, fill=txt_color, font=self.font) @@ -240,15 +314,13 @@ class Annotator: self.draw.text(xy, text, fill=txt_color, font=self.font) else: if box_style: - tf = max(self.lw - 1, 1) # font thickness - w, h = cv2.getTextSize(text, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height + w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height outside = xy[1] - h >= 3 p2 = xy[0] + w, xy[1] - h - 3 if outside else xy[1] + h + 3 cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled # Using `txt_color` for background and draw fg with white color txt_color = (255, 255, 255) - tf = max(self.lw - 1, 1) # font thickness - cv2.putText(self.im, text, xy, 0, self.lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA) + cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA) def fromarray(self, im): """Update self.im from a numpy array.""" @@ -259,27 +331,289 @@ class Annotator: """Return annotated image as array.""" return np.asarray(self.im) + def show(self, title=None): + """Show the annotated image.""" + Image.fromarray(np.asarray(self.im)[..., ::-1]).show(title) + + def save(self, filename="image.jpg"): + """Save the annotated image to 'filename'.""" + cv2.imwrite(filename, np.asarray(self.im)) + + def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5): + """ + Draw region line. + + Args: + reg_pts (list): Region Points (for line 2 points, for region 4 points) + color (tuple): Region Color value + thickness (int): Region area thickness value + """ + cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness) + + def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2): + """ + Draw centroid point and track trails. 
+ + Args: + track (list): object tracking points for trails display + color (tuple): tracks line color + track_thickness (int): track line thickness value + """ + points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness) + cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1) + + def count_labels(self, counts=0, count_txt_size=2, color=(255, 255, 255), txt_color=(0, 0, 0)): + """ + Plot counts for object counter. + + Args: + counts (int): objects counts value + count_txt_size (int): text size for counts display + color (tuple): background color of counts display + txt_color (tuple): text color of counts display + """ + self.tf = count_txt_size + tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1 + tf = max(tl - 1, 1) + + # Get text size for in_count and out_count + t_size_in = cv2.getTextSize(str(counts), 0, fontScale=tl / 2, thickness=tf)[0] + + # Calculate positions for counts label + text_width = t_size_in[0] + text_x = (self.im.shape[1] - text_width) // 2 # Center x-coordinate + text_y = t_size_in[1] + + # Create a rounded rectangle for in_count + cv2.rectangle( + self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1 + ) + cv2.putText( + self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA + ) + + @staticmethod + def estimate_pose_angle(a, b, c): + """ + Calculate the pose angle for object. + + Args: + a (float) : The value of pose point a + b (float): The value of pose point b + c (float): The value o pose point c + + Returns: + angle (degree): Degree value of angle between three points + """ + a, b, c = np.array(a), np.array(b), np.array(c) + radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0]) + angle = np.abs(radians * 180.0 / np.pi) + if angle > 180.0: + angle = 360 - angle + return angle + + def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2): + """ + Draw specific keypoints for gym steps counting. + + Args: + keypoints (list): list of keypoints data to be plotted + indices (list): keypoints ids list to be plotted + shape (tuple): imgsz for model inference + radius (int): Keypoint radius value + """ + for i, k in enumerate(keypoints): + if i in indices: + x_coord, y_coord = k[0], k[1] + if x_coord % shape[1] != 0 and y_coord % shape[0] != 0: + if len(k) == 3: + conf = k[2] + if conf < 0.5: + continue + cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA) + return self.im + + def plot_angle_and_count_and_stage(self, angle_text, count_text, stage_text, center_kpt, line_thickness=2): + """ + Plot the pose angle, count value and step stage. 
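estimate_pose_angle above is just the signed difference of two arctan2 directions taken at the middle keypoint, folded into [0, 180] degrees. The same computation as a standalone function:

import numpy as np

def angle_at(a, b, c):
    """Angle (degrees) at vertex b formed by points a-b-c, folded into [0, 180]."""
    a, b, c = np.asarray(a), np.asarray(b), np.asarray(c)
    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    angle = abs(np.degrees(radians))
    return 360.0 - angle if angle > 180.0 else angle

print(angle_at((0, 0), (1, 0), (1, 1)))  # 90.0, e.g. a shoulder-elbow-wrist chain at a right angle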
+ + Args: + angle_text (str): angle value for workout monitoring + count_text (str): counts value for workout monitoring + stage_text (str): stage decision for workout monitoring + center_kpt (int): centroid pose index for workout monitoring + line_thickness (int): thickness for text display + """ + angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}") + font_scale = 0.6 + (line_thickness / 10.0) + + # Draw angle + (angle_text_width, angle_text_height), _ = cv2.getTextSize(angle_text, 0, font_scale, line_thickness) + angle_text_position = (int(center_kpt[0]), int(center_kpt[1])) + angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5) + angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2)) + cv2.rectangle( + self.im, + angle_background_position, + ( + angle_background_position[0] + angle_background_size[0], + angle_background_position[1] + angle_background_size[1], + ), + (255, 255, 255), + -1, + ) + cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness) + + # Draw Counts + (count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness) + count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20) + count_background_position = ( + angle_background_position[0], + angle_background_position[1] + angle_background_size[1] + 5, + ) + count_background_size = (count_text_width + 10, count_text_height + 10 + (line_thickness * 2)) + + cv2.rectangle( + self.im, + count_background_position, + ( + count_background_position[0] + count_background_size[0], + count_background_position[1] + count_background_size[1], + ), + (255, 255, 255), + -1, + ) + cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness) + + # Draw Stage + (stage_text_width, stage_text_height), _ = cv2.getTextSize(stage_text, 0, font_scale, line_thickness) + stage_text_position = (int(center_kpt[0]), int(center_kpt[1]) + angle_text_height + count_text_height + 40) + stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5) + stage_background_size = (stage_text_width + 10, stage_text_height + 10) + + cv2.rectangle( + self.im, + stage_background_position, + ( + stage_background_position[0] + stage_background_size[0], + stage_background_position[1] + stage_background_size[1], + ), + (255, 255, 255), + -1, + ) + cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness) + + def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None): + """ + Function for drawing segmented object in bounding box shape. 
+ + Args: + mask (list): masks data list for instance segmentation area plotting + mask_color (tuple): mask foreground color + det_label (str): Detection label text + track_label (str): Tracking label text + """ + cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2) + + label = f"Track ID: {track_label}" if track_label else det_label + text_size, _ = cv2.getTextSize(label, 0, 0.7, 1) + + cv2.rectangle( + self.im, + (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10), + (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), + mask_color, + -1, + ) + + cv2.putText( + self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2 + ) + + def plot_distance_and_line(self, distance_m, distance_mm, centroids, line_color, centroid_color): + """ + Plot the distance and line on frame. + + Args: + distance_m (float): Distance between two bbox centroids in meters. + distance_mm (float): Distance between two bbox centroids in millimeters. + centroids (list): Bounding box centroids data. + line_color (RGB): Distance line color. + centroid_color (RGB): Bounding box centroid color. + """ + (text_width_m, text_height_m), _ = cv2.getTextSize( + f"Distance M: {distance_m:.2f}m", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2 + ) + cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), (255, 255, 255), -1) + cv2.putText( + self.im, + f"Distance M: {distance_m:.2f}m", + (20, 50), + cv2.FONT_HERSHEY_SIMPLEX, + 0.8, + (0, 0, 0), + 2, + cv2.LINE_AA, + ) + + (text_width_mm, text_height_mm), _ = cv2.getTextSize( + f"Distance MM: {distance_mm:.2f}mm", cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2 + ) + cv2.rectangle(self.im, (15, 75), (15 + text_width_mm + 10, 75 + text_height_mm + 20), (255, 255, 255), -1) + cv2.putText( + self.im, + f"Distance MM: {distance_mm:.2f}mm", + (20, 100), + cv2.FONT_HERSHEY_SIMPLEX, + 0.8, + (0, 0, 0), + 2, + cv2.LINE_AA, + ) + + cv2.line(self.im, centroids[0], centroids[1], line_color, 3) + cv2.circle(self.im, centroids[0], 6, centroid_color, -1) + cv2.circle(self.im, centroids[1], 6, centroid_color, -1) + + def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10): + """ + Function for pinpoint human-vision eye mapping and plotting. 
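The visioneye helper being added here draws a line from a fixed vantage point to each detection's centroid. The underlying cv2 calls are simple enough to sketch independently of the Annotator class:

import cv2
import numpy as np

im = np.zeros((480, 640, 3), dtype=np.uint8)
eye = (320, 470)                                 # fixed vantage point near the bottom of the frame
box = (100, 100, 200, 180)                       # x1, y1, x2, y2
center = (box[0] + box[2]) // 2, (box[1] + box[3]) // 2
cv2.circle(im, eye, 10, (255, 0, 255), -1)       # vantage point
cv2.circle(im, center, 10, (235, 219, 11), -1)   # object centroid
cv2.line(im, eye, center, (235, 219, 11), 2)     # vision-eye line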
+ + Args: + box (list): Bounding box coordinates + center_point (tuple): center point for vision eye view + color (tuple): object centroid and line color value + pin_color (tuple): visioneye point color value + thickness (int): int value for line thickness + pins_radius (int): visioneye point radius value + """ + center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + cv2.circle(self.im, center_point, pins_radius, pin_color, -1) + cv2.circle(self.im, center_bbox, pins_radius, color, -1) + cv2.line(self.im, center_point, center_bbox, color, thickness) + @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 @plt_settings() -def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None): +def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None): """Plot training labels including class histograms and box statistics.""" import pandas as pd import seaborn as sn # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings - warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight') - warnings.filterwarnings('ignore', category=FutureWarning) + warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight") + warnings.filterwarnings("ignore", category=FutureWarning) # Plot dataset labels LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") nc = int(cls.max() + 1) # number of classes boxes = boxes[:1000000] # limit to 1M boxes - x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height']) + x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"]) # Seaborn correlogram - sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) - plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) + sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) + plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200) plt.close() # Matplotlib labels @@ -287,14 +621,14 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None): y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) for i in range(nc): y[2].patches[i].set_color([x / 255 for x in colors(i)]) - ax[0].set_ylabel('instances') + ax[0].set_ylabel("instances") if 0 < len(names) < 30: ax[0].set_xticks(range(len(names))) ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10) else: - ax[0].set_xlabel('classes') - sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) - sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) + ax[0].set_xlabel("classes") + sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9) + sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9) # Rectangles boxes[:, 0:2] = 0.5 # center @@ -303,21 +637,22 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None): for cls, box in zip(cls[:500], boxes[:500]): ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot ax[1].imshow(img) - ax[1].axis('off') + ax[1].axis("off") for a in [0, 1, 2, 3]: - for s in ['top', 'right', 'left', 'bottom']: + for s in ["top", "right", "left", "bottom"]: ax[a].spines[s].set_visible(False) - fname = save_dir / 'labels.jpg' + fname = save_dir / "labels.jpg" plt.savefig(fname, dpi=200) plt.close() if on_plot: on_plot(fname) -def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True): - """Save image crop as {file} with crop 
size multiple {gain} and {pad} pixels. Save and/or return crop. +def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True): + """ + Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop. This function takes a bounding box and an image, and then saves a cropped portion of the image according to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding @@ -353,27 +688,33 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad xyxy = ops.xywh2xyxy(b).long() - ops.clip_boxes(xyxy, im.shape) - crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)] + xyxy = ops.clip_boxes(xyxy, im.shape) + crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)] if save: file.parent.mkdir(parents=True, exist_ok=True) # make directory - f = str(increment_path(file).with_suffix('.jpg')) + f = str(increment_path(file).with_suffix(".jpg")) # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB return crop @threaded -def plot_images(images, - batch_idx, - cls, - bboxes=np.zeros(0, dtype=np.float32), - masks=np.zeros(0, dtype=np.uint8), - kpts=np.zeros((0, 51), dtype=np.float32), - paths=None, - fname='images.jpg', - names=None, - on_plot=None): +def plot_images( + images, + batch_idx, + cls, + bboxes=np.zeros(0, dtype=np.float32), + confs=None, + masks=np.zeros(0, dtype=np.uint8), + kpts=np.zeros((0, 51), dtype=np.float32), + paths=None, + fname="images.jpg", + names=None, + on_plot=None, + max_subplots=16, + save=True, + conf_thres=0.25, +): """Plot image grid with labels.""" if isinstance(images, torch.Tensor): images = images.cpu().float().numpy() @@ -389,21 +730,17 @@ def plot_images(images, batch_idx = batch_idx.cpu().numpy() max_size = 1920 # max image size - max_subplots = 16 # max image subplots, i.e. 
4x4 bs, _, h, w = images.shape # batch size, _, height, width bs = min(bs, max_subplots) # limit plot images - ns = np.ceil(bs ** 0.5) # number of subplots (square) + ns = np.ceil(bs**0.5) # number of subplots (square) if np.max(images[0]) <= 1: images *= 255 # de-normalise (optional) # Build Image mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init - for i, im in enumerate(images): - if i == max_subplots: # if last batch has fewer images than we expect - break + for i in range(bs): x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin - im = im.transpose(1, 2, 0) - mosaic[y:y + h, x:x + w, :] = im + mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0) # Resize (optional) scale = max_size / ns / max(h, w) @@ -415,40 +752,42 @@ def plot_images(images, # Annotate fs = int((h + w) * ns * 0.01) # font size annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names) - for i in range(i + 1): + for i in range(bs): x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders if paths: annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames if len(cls) > 0: idx = batch_idx == i - classes = cls[idx].astype('int') + classes = cls[idx].astype("int") + labels = confs is None if len(bboxes): - boxes = ops.xywh2xyxy(bboxes[idx, :4]).T - labels = bboxes.shape[1] == 4 # labels if no conf column - conf = None if labels else bboxes[idx, 4] # check for confidence presence (label vs pred) - - if boxes.shape[1]: - if boxes.max() <= 1.01: # if normalized with tolerance 0.01 - boxes[[0, 2]] *= w # scale to pixels - boxes[[1, 3]] *= h + boxes = bboxes[idx] + conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred) + is_obb = boxes.shape[-1] == 5 # xywhr + boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes) + if len(boxes): + if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1 + boxes[..., 0::2] *= w # scale to pixels + boxes[..., 1::2] *= h elif scale < 1: # absolute coords need scale if image scales - boxes *= scale - boxes[[0, 2]] += x - boxes[[1, 3]] += y - for j, box in enumerate(boxes.T.tolist()): + boxes[..., :4] *= scale + boxes[..., 0::2] += x + boxes[..., 1::2] += y + for j, box in enumerate(boxes.astype(np.int64).tolist()): c = classes[j] color = colors(c) c = names.get(c, c) if names else c - if labels or conf[j] > 0.25: # 0.25 conf thresh - label = f'{c}' if labels else f'{c} {conf[j]:.1f}' - annotator.box_label(box, label, color=color) + if labels or conf[j] > conf_thres: + label = f"{c}" if labels else f"{c} {conf[j]:.1f}" + annotator.box_label(box, label, color=color, rotated=is_obb) + elif len(classes): for c in classes: color = colors(c) c = names.get(c, c) if names else c - annotator.text((x, y), f'{c}', txt_color=color, box_style=True) + annotator.text((x, y), f"{c}", txt_color=color, box_style=True) # Plot keypoints if len(kpts): @@ -462,7 +801,7 @@ def plot_images(images, kpts_[..., 0] += x kpts_[..., 1] += y for j in range(len(kpts_)): - if labels or conf[j] > 0.25: # 0.25 conf thresh + if labels or conf[j] > conf_thres: annotator.kpts(kpts_[j]) # Plot masks @@ -477,8 +816,8 @@ def plot_images(images, image_masks = np.where(image_masks == index, 1.0, 0.0) im = np.asarray(annotator.im).copy() - for j, box in enumerate(boxes.T.tolist()): - if labels or conf[j] > 0.25: # 0.25 conf thresh + for j in range(len(image_masks)): + 
if labels or conf[j] > conf_thres: color = colors(classes[j]) mh, mw = image_masks[j].shape if mh != h or mw != w: @@ -488,27 +827,42 @@ def plot_images(images, else: mask = image_masks[j].astype(bool) with contextlib.suppress(Exception): - im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6 + im[y : y + h, x : x + w, :][mask] = ( + im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6 + ) annotator.fromarray(im) + if not save: + return np.asarray(annotator.im) annotator.im.save(fname) # save if on_plot: on_plot(fname) @plt_settings() -def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None): +def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None): """ - Plot training results from results CSV file. + Plot training results from a results CSV file. The function supports various types of data including segmentation, + pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located. + + Args: + file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'. + dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''. + segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False. + pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False. + classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False. + on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument. + Defaults to None. Example: ```python from ultralytics.utils.plotting import plot_results - plot_results('path/to/results.csv') + plot_results('path/to/results.csv', segment=True) ``` """ import pandas as pd from scipy.ndimage import gaussian_filter1d + save_dir = Path(file).parent if file else Path(dir) if classify: fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True) @@ -523,31 +877,121 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7] ax = ax.ravel() - files = list(save_dir.glob('results*.csv')) - assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.' + files = list(save_dir.glob("results*.csv")) + assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot." 
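The mask overlay in plot_images blends each instance color into the mosaic with a fixed 0.4/0.6 mix wherever the (possibly resized) binary mask is set. The blending rule on its own:

import numpy as np

def overlay_mask(im, mask, color, alpha=0.6):
    """Blend a solid color into im wherever mask is True (same 0.4/0.6 mix as above)."""
    im = im.copy()
    im[mask] = im[mask] * (1 - alpha) + np.array(color) * alpha
    return im

im = np.full((4, 4, 3), 200, dtype=np.float32)
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True
print(overlay_mask(im, mask, color=(255, 0, 0))[1, 1])  # [233.  80.  80.]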
for f in files: try: data = pd.read_csv(f) s = [x.strip() for x in data.columns] x = data.values[:, 0] for i, j in enumerate(index): - y = data.values[:, j].astype('float') + y = data.values[:, j].astype("float") # y[y == 0] = np.nan # don't show zero values - ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results - ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line + ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results + ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line ax[i].set_title(s[j], fontsize=12) # if j in [8, 9, 10]: # share train and val loss y axes # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) except Exception as e: - LOGGER.warning(f'WARNING: Plotting error for {f}: {e}') + LOGGER.warning(f"WARNING: Plotting error for {f}: {e}") ax[1].legend() - fname = save_dir / 'results.png' + fname = save_dir / "results.png" fig.savefig(fname, dpi=200) plt.close() if on_plot: on_plot(fname) +def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"): + """ + Plots a scatter plot with points colored based on a 2D histogram. + + Args: + v (array-like): Values for the x-axis. + f (array-like): Values for the y-axis. + bins (int, optional): Number of bins for the histogram. Defaults to 20. + cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'. + alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8. + edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'. + + Examples: + >>> v = np.random.rand(100) + >>> f = np.random.rand(100) + >>> plt_color_scatter(v, f) + """ + + # Calculate 2D histogram and corresponding colors + hist, xedges, yedges = np.histogram2d(v, f, bins=bins) + colors = [ + hist[ + min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1), + min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1), + ] + for i in range(len(v)) + ] + + # Scatter plot + plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors) + + +def plot_tune_results(csv_file="tune_results.csv"): + """ + Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key + in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots. + + Args: + csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'. 
+ + Examples: + >>> plot_tune_results('path/to/tune_results.csv') + """ + + import pandas as pd + from scipy.ndimage import gaussian_filter1d + + # Scatter plots for each hyperparameter + csv_file = Path(csv_file) + data = pd.read_csv(csv_file) + num_metrics_columns = 1 + keys = [x.strip() for x in data.columns][num_metrics_columns:] + x = data.values + fitness = x[:, 0] # fitness + j = np.argmax(fitness) # max fitness index + n = math.ceil(len(keys) ** 0.5) # columns and rows in plot + plt.figure(figsize=(10, 10), tight_layout=True) + for i, k in enumerate(keys): + v = x[:, i + num_metrics_columns] + mu = v[j] # best single result + plt.subplot(n, n, i + 1) + plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none") + plt.plot(mu, fitness.max(), "k+", markersize=15) + plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters + plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8 + if i % n != 0: + plt.yticks([]) + + file = csv_file.with_name("tune_scatter_plots.png") # filename + plt.savefig(file, dpi=200) + plt.close() + LOGGER.info(f"Saved {file}") + + # Fitness vs iteration + x = range(1, len(fitness) + 1) + plt.figure(figsize=(10, 6), tight_layout=True) + plt.plot(x, fitness, marker="o", linestyle="none", label="fitness") + plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line + plt.title("Fitness vs Iteration") + plt.xlabel("Iteration") + plt.ylabel("Fitness") + plt.grid(True) + plt.legend() + + file = csv_file.with_name("tune_fitness.png") # filename + plt.savefig(file, dpi=200) + plt.close() + LOGGER.info(f"Saved {file}") + + def output_to_target(output, max_det=300): """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.""" targets = [] @@ -556,10 +1000,21 @@ def output_to_target(output, max_det=300): j = torch.full((conf.shape[0], 1), i) targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1)) targets = torch.cat(targets, 0).numpy() - return targets[:, 0], targets[:, 1], targets[:, 2:] + return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1] -def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')): +def output_to_rotated_target(output, max_det=300): + """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.""" + targets = [] + for i, o in enumerate(output): + box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1) + j = torch.full((conf.shape[0], 1), i) + targets.append(torch.cat((j, cls, box, angle, conf), 1)) + targets = torch.cat(targets, 0).numpy() + return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1] + + +def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")): """ Visualize feature maps of a given model module during inference. @@ -570,23 +1025,23 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec n (int, optional): Maximum number of feature maps to plot. Defaults to 32. save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp'). 
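output_to_target and the new output_to_rotated_target now return the confidence column separately (batch id, class, boxes, conf) so plot_images can take confs as its own argument. A self-contained sketch of the horizontal-box version, assuming detections arrive as (x1, y1, x2, y2, conf, cls) rows as elsewhere in this file:

import torch
from ultralytics.utils import ops

def detections_to_targets(output, max_det=300):
    """Flatten per-image (x1, y1, x2, y2, conf, cls) detections into plot-ready arrays."""
    targets = []
    for i, o in enumerate(output):                            # i = image index within the batch
        box, conf, cls = o[:max_det].cpu().split((4, 1, 1), 1)
        j = torch.full((conf.shape[0], 1), i, dtype=o.dtype)
        targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
    t = torch.cat(targets, 0).numpy()
    return t[:, 0], t[:, 1], t[:, 2:-1], t[:, -1]             # batch ids, classes, xywh boxes, confidences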
""" - for m in ['Detect', 'Pose', 'Segment']: + for m in ["Detect", "Pose", "Segment"]: if m in module_type: return - batch, channels, height, width = x.shape # batch, channels, height, width + _, channels, height, width = x.shape # batch, channels, height, width if height > 1 and width > 1: f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels n = min(n, channels) # number of plots - fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols + _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols ax = ax.ravel() plt.subplots_adjust(wspace=0.05, hspace=0.05) for i in range(n): ax[i].imshow(blocks[i].squeeze()) # cmap='gray' - ax[i].axis('off') + ax[i].axis("off") - LOGGER.info(f'Saving {f}... ({n}/{channels})') - plt.savefig(f, dpi=300, bbox_inches='tight') + LOGGER.info(f"Saving {f}... ({n}/{channels})") + plt.savefig(f, dpi=300, bbox_inches="tight") plt.close() - np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save + np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py index 432e7a7..b11c2b2 100644 --- a/ultralytics/utils/tal.py +++ b/ultralytics/utils/tal.py @@ -4,65 +4,18 @@ import torch import torch.nn as nn from .checks import check_version -from .metrics import bbox_iou +from .metrics import bbox_iou, probiou +from .ops import xywhr2xyxyxyxy -TORCH_1_10 = check_version(torch.__version__, '1.10.0') - - -def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): - """ - Select the positive anchor center in gt. - - Args: - xy_centers (Tensor): shape(h*w, 2) - gt_bboxes (Tensor): shape(b, n_boxes, 4) - - Returns: - (Tensor): shape(b, n_boxes, h*w) - """ - n_anchors = xy_centers.shape[0] - bs, n_boxes, _ = gt_bboxes.shape - lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom - bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) - # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype) - return bbox_deltas.amin(3).gt_(eps) - - -def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): - """ - If an anchor box is assigned to multiple gts, the one with the highest IoI will be selected. - - Args: - mask_pos (Tensor): shape(b, n_max_boxes, h*w) - overlaps (Tensor): shape(b, n_max_boxes, h*w) - - Returns: - target_gt_idx (Tensor): shape(b, h*w) - fg_mask (Tensor): shape(b, h*w) - mask_pos (Tensor): shape(b, n_max_boxes, h*w) - """ - # (b, n_max_boxes, h*w) -> (b, h*w) - fg_mask = mask_pos.sum(-2) - if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes - mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w) - max_overlaps_idx = overlaps.argmax(1) # (b, h*w) - - is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device) - is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1) - - mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w) - fg_mask = mask_pos.sum(-2) - # Find each grid serve which gt(index) - target_gt_idx = mask_pos.argmax(-2) # (b, h*w) - return target_gt_idx, fg_mask, mask_pos +TORCH_1_10 = check_version(torch.__version__, "1.10.0") class TaskAlignedAssigner(nn.Module): """ A task-aligned assigner for object detection. 
- This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, - which combines both classification and localization information. + This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both + classification and localization information. Attributes: topk (int): The number of top candidates to consider. @@ -85,8 +38,8 @@ class TaskAlignedAssigner(nn.Module): @torch.no_grad() def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): """ - Compute the task-aligned assignment. - Reference https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py + Compute the task-aligned assignment. Reference code is available at + https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py. Args: pd_scores (Tensor): shape(bs, num_total_anchors, num_classes) @@ -103,19 +56,24 @@ class TaskAlignedAssigner(nn.Module): fg_mask (Tensor): shape(bs, num_total_anchors) target_gt_idx (Tensor): shape(bs, num_total_anchors) """ - self.bs = pd_scores.size(0) - self.n_max_boxes = gt_bboxes.size(1) + self.bs = pd_scores.shape[0] + self.n_max_boxes = gt_bboxes.shape[1] if self.n_max_boxes == 0: device = gt_bboxes.device - return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device), - torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device), - torch.zeros_like(pd_scores[..., 0]).to(device)) + return ( + torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), + torch.zeros_like(pd_bboxes).to(device), + torch.zeros_like(pd_scores).to(device), + torch.zeros_like(pd_scores[..., 0]).to(device), + torch.zeros_like(pd_scores[..., 0]).to(device), + ) - mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, - mask_gt) + mask_pos, align_metric, overlaps = self.get_pos_mask( + pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt + ) - target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) + target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) # Assigned target target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask) @@ -131,7 +89,7 @@ class TaskAlignedAssigner(nn.Module): def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt): """Get in_gts mask, (b, max_num_obj, h*w).""" - mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes) + mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes) # Get anchor_align metric, (b, max_num_obj, h*w) align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt) # Get topk_metric mask, (b, max_num_obj, h*w) @@ -157,11 +115,15 @@ class TaskAlignedAssigner(nn.Module): # (b, max_num_obj, 1, 4), (b, 1, h*w, 4) pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt] gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt] - overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0) + overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes) align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) return align_metric, overlaps + def iou_calculation(self, gt_bboxes, pd_bboxes): + """IoU calculation for horizontal bounding boxes.""" + return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, 
CIoU=True).squeeze(-1).clamp_(0) + def select_topk_candidates(self, metrics, largest=True, topk_mask=None): """ Select the top-k candidates based on the given metrics. @@ -191,9 +153,9 @@ class TaskAlignedAssigner(nn.Module): ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device) for k in range(self.topk): # Expand topk_idxs for each value of k and add 1 at the specified positions - count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones) + count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones) # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device)) - # filter invalid bboxes + # Filter invalid bboxes count_tensor.masked_fill_(count_tensor > 1, 0) return count_tensor.to(metrics.dtype) @@ -229,15 +191,17 @@ class TaskAlignedAssigner(nn.Module): target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w) # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4) - target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx] + target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx] # Assigned target scores target_labels.clamp_(0) # 10x faster than F.one_hot() - target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes), - dtype=torch.int64, - device=target_labels.device) # (b, h*w, 80) + target_scores = torch.zeros( + (target_labels.shape[0], target_labels.shape[1], self.num_classes), + dtype=torch.int64, + device=target_labels.device, + ) # (b, h*w, 80) target_scores.scatter_(2, target_labels.unsqueeze(-1), 1) fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80) @@ -245,6 +209,87 @@ class TaskAlignedAssigner(nn.Module): return target_labels, target_bboxes, target_scores + @staticmethod + def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): + """ + Select the positive anchor center in gt. + + Args: + xy_centers (Tensor): shape(h*w, 2) + gt_bboxes (Tensor): shape(b, n_boxes, 4) + + Returns: + (Tensor): shape(b, n_boxes, h*w) + """ + n_anchors = xy_centers.shape[0] + bs, n_boxes, _ = gt_bboxes.shape + lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom + bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) + # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype) + return bbox_deltas.amin(3).gt_(eps) + + @staticmethod + def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): + """ + If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected. 
+ + Args: + mask_pos (Tensor): shape(b, n_max_boxes, h*w) + overlaps (Tensor): shape(b, n_max_boxes, h*w) + + Returns: + target_gt_idx (Tensor): shape(b, h*w) + fg_mask (Tensor): shape(b, h*w) + mask_pos (Tensor): shape(b, n_max_boxes, h*w) + """ + # (b, n_max_boxes, h*w) -> (b, h*w) + fg_mask = mask_pos.sum(-2) + if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes + mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w) + max_overlaps_idx = overlaps.argmax(1) # (b, h*w) + + is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device) + is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1) + + mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w) + fg_mask = mask_pos.sum(-2) + # Find each grid serve which gt(index) + target_gt_idx = mask_pos.argmax(-2) # (b, h*w) + return target_gt_idx, fg_mask, mask_pos + + +class RotatedTaskAlignedAssigner(TaskAlignedAssigner): + def iou_calculation(self, gt_bboxes, pd_bboxes): + """IoU calculation for rotated bounding boxes.""" + return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0) + + @staticmethod + def select_candidates_in_gts(xy_centers, gt_bboxes): + """ + Select the positive anchor center in gt for rotated bounding boxes. + + Args: + xy_centers (Tensor): shape(h*w, 2) + gt_bboxes (Tensor): shape(b, n_boxes, 5) + + Returns: + (Tensor): shape(b, n_boxes, h*w) + """ + # (b, n_boxes, 5) --> (b, n_boxes, 4, 2) + corners = xywhr2xyxyxyxy(gt_bboxes) + # (b, n_boxes, 1, 2) + a, b, _, d = corners.split(1, dim=-2) + ab = b - a + ad = d - a + + # (b, n_boxes, h*w, 2) + ap = xy_centers - a + norm_ab = (ab * ab).sum(dim=-1) + norm_ad = (ad * ad).sum(dim=-1) + ap_dot_ab = (ap * ab).sum(dim=-1) + ap_dot_ad = (ap * ad).sum(dim=-1) + return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box + def make_anchors(feats, strides, grid_cell_offset=0.5): """Generate anchors from features.""" @@ -255,7 +300,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): _, _, h, w = feats[i].shape sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y - sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx) + sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) return torch.cat(anchor_points), torch.cat(stride_tensor) @@ -263,7 +308,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): def dist2bbox(distance, anchor_points, xywh=True, dim=-1): """Transform distance(ltrb) to box(xywh or xyxy).""" - lt, rb = distance.chunk(2, dim) + assert(distance.shape[dim] == 4) + lt, rb = distance.split([2, 2], dim) x1y1 = anchor_points - lt x2y2 = anchor_points + rb if xywh: @@ -277,3 +323,23 @@ def bbox2dist(anchor_points, bbox, reg_max): """Transform bbox(xyxy) to dist(ltrb).""" x1y1, x2y2 = bbox.chunk(2, -1) return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb) + + +def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1): + """ + Decode predicted object bounding box coordinates from anchor points and distribution. + + Args: + pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4). 
+ pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1). + anchor_points (torch.Tensor): Anchor points, (h*w, 2). + Returns: + (torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4). + """ + lt, rb = pred_dist.split(2, dim=dim) + cos, sin = torch.cos(pred_angle), torch.sin(pred_angle) + # (bs, h*w, 1) + xf, yf = ((rb - lt) / 2).split(1, dim=dim) + x, y = xf * cos - yf * sin, xf * sin + yf * cos + xy = torch.cat([x, y], dim=dim) + anchor_points + return torch.cat([xy, lt + rb], dim=dim) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index def7442..d476e1f 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -2,7 +2,6 @@ import math import os -import platform import random import time from contextlib import contextmanager @@ -15,17 +14,23 @@ import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +import torchvision -from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__ -from ultralytics.utils.checks import check_version +from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__ +from ultralytics.utils.checks import PYTHON_VERSION, check_version try: import thop except ImportError: thop = None -TORCH_1_9 = check_version(torch.__version__, '1.9.0') -TORCH_2_0 = check_version(torch.__version__, '2.0.0') +# Version checks (all default to version>=min_version) +TORCH_1_9 = check_version(torch.__version__, "1.9.0") +TORCH_1_13 = check_version(torch.__version__, "1.13.0") +TORCH_2_0 = check_version(torch.__version__, "2.0.0") +TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0") +TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0") +TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0") @contextmanager @@ -44,7 +49,10 @@ def smart_inference_mode(): def decorate(fn): """Applies appropriate torch decorator for inference mode based on torch version.""" - return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn) + if TORCH_1_9 and torch.is_inference_mode_enabled(): + return fn # already in inference_mode, act as a pass-through + else: + return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn) return decorate @@ -53,59 +61,102 @@ def get_cpu_info(): """Return a string with system CPU information, i.e. 'Apple M2'.""" import cpuinfo # pip install py-cpuinfo - k = 'brand_raw', 'hardware_raw', 'arch_string_raw' # info keys sorted by preference (not all keys always available) + k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available) info = cpuinfo.get_cpu_info() # info dict - string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], 'unknown') - return string.replace('(R)', '').replace('CPU ', '').replace('@ ', '') + string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown") + return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "") -def select_device(device='', batch=0, newline=False, verbose=True): - """Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'.""" - s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} ' +def select_device(device="", batch=0, newline=False, verbose=True): + """ + Selects the appropriate PyTorch device based on the provided arguments. 
+ + The function takes a string specifying the device or a torch.device object and returns a torch.device object + representing the selected device. The function also validates the number of available devices and raises an + exception if the requested device(s) are not available. + + Args: + device (str | torch.device, optional): Device string or torch.device object. + Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects + the first available GPU, or CPU if no GPU is available. + batch (int, optional): Batch size being used in your model. Defaults to 0. + newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False. + verbose (bool, optional): If True, logs the device information. Defaults to True. + + Returns: + (torch.device): Selected device. + + Raises: + ValueError: If the specified device is not available or if the batch size is not a multiple of the number of + devices when using multiple GPUs. + + Examples: + >>> select_device('cuda:0') + device(type='cuda', index=0) + + >>> select_device('cpu') + device(type='cpu') + + Note: + Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use. + """ + + if isinstance(device, torch.device): + return device + + s = f"Ultralytics YOLOv{__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} " device = str(device).lower() - for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ': - device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1' - cpu = device == 'cpu' - mps = device == 'mps' # Apple Metal Performance Shaders (MPS) + for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ": + device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1' + cpu = device == "cpu" + mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS) if cpu or mps: - os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False elif device: # non-cpu device requested - if device == 'cuda': - device = '0' - visible = os.environ.get('CUDA_VISIBLE_DEVICES', None) - os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() - if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))): + if device == "cuda": + device = "0" + visible = os.environ.get("CUDA_VISIBLE_DEVICES", None) + os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available() + if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))): LOGGER.info(s) - install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \ - 'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else '' - raise ValueError(f"Invalid CUDA 'device={device}' requested." - f" Use 'device=cpu' or pass valid CUDA device(s) if available," - f" i.e. 
'device=0' or 'device=0,1,2,3' for Multi-GPU.\n" - f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}' - f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}' - f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n" - f'{install}') + install = ( + "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no " + "CUDA devices are seen by torch.\n" + if torch.cuda.device_count() == 0 + else "" + ) + raise ValueError( + f"Invalid CUDA 'device={device}' requested." + f" Use 'device=cpu' or pass valid CUDA device(s) if available," + f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n" + f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}" + f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}" + f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n" + f"{install}" + ) if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available - devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7 + devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7 n = len(devices) # device count if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count - raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or " - f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.") - space = ' ' * (len(s) + 1) + raise ValueError( + f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or " + f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}." + ) + space = " " * (len(s) + 1) for i, d in enumerate(devices): p = torch.cuda.get_device_properties(i) s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB - arg = 'cuda:0' - elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_0: + arg = "cuda:0" + elif mps and TORCH_2_0 and torch.backends.mps.is_available(): # Prefer MPS if available - s += f'MPS ({get_cpu_info()})\n' - arg = 'mps' + s += f"MPS ({get_cpu_info()})\n" + arg = "mps" else: # revert to CPU - s += f'CPU ({get_cpu_info()})\n' - arg = 'cpu' + s += f"CPU ({get_cpu_info()})\n" + arg = "cpu" - if verbose and RANK == -1: + if verbose: LOGGER.info(s if newline else s.rstrip()) return torch.device(arg) @@ -119,14 +170,20 @@ def time_sync(): def fuse_conv_and_bn(conv, bn): """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/.""" - fusedconv = nn.Conv2d(conv.in_channels, - conv.out_channels, - kernel_size=conv.kernel_size, - stride=conv.stride, - padding=conv.padding, - dilation=conv.dilation, - groups=conv.groups, - bias=True).requires_grad_(False).to(conv.weight.device) + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) # Prepare filters w_conv = conv.weight.clone().view(conv.out_channels, -1) @@ -134,7 +191,7 @@ def fuse_conv_and_bn(conv, bn): fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) # Prepare spatial bias - b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias b_bn = 
bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) @@ -143,15 +200,21 @@ def fuse_conv_and_bn(conv, bn): def fuse_deconv_and_bn(deconv, bn): """Fuse ConvTranspose2d() and BatchNorm2d() layers.""" - fuseddconv = nn.ConvTranspose2d(deconv.in_channels, - deconv.out_channels, - kernel_size=deconv.kernel_size, - stride=deconv.stride, - padding=deconv.padding, - output_padding=deconv.output_padding, - dilation=deconv.dilation, - groups=deconv.groups, - bias=True).requires_grad_(False).to(deconv.weight.device) + fuseddconv = ( + nn.ConvTranspose2d( + deconv.in_channels, + deconv.out_channels, + kernel_size=deconv.kernel_size, + stride=deconv.stride, + padding=deconv.padding, + output_padding=deconv.output_padding, + dilation=deconv.dilation, + groups=deconv.groups, + bias=True, + ) + .requires_grad_(False) + .to(deconv.weight.device) + ) # Prepare filters w_deconv = deconv.weight.clone().view(deconv.out_channels, -1) @@ -159,7 +222,7 @@ def fuse_deconv_and_bn(deconv, bn): fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape)) # Prepare spatial bias - b_conv = torch.zeros(deconv.weight.size(1), device=deconv.weight.device) if deconv.bias is None else deconv.bias + b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) @@ -167,7 +230,11 @@ def fuse_deconv_and_bn(deconv, bn): def model_info(model, detailed=False, verbose=True, imgsz=640): - """Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].""" + """ + Model information. + + imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]. 
+ """ if not verbose: return n_p = get_num_params(model) # number of parameters @@ -175,18 +242,21 @@ def model_info(model, detailed=False, verbose=True, imgsz=640): n_l = len(list(model.modules())) # number of layers if detailed: LOGGER.info( - f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}") + f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}" + ) for i, (name, p) in enumerate(model.named_parameters()): - name = name.replace('module_list.', '') - LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' % - (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)) + name = name.replace("module_list.", "") + LOGGER.info( + "%5g %40s %9s %12g %20s %10.3g %10.3g %10s" + % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype) + ) flops = get_flops(model, imgsz) - fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else '' - fs = f', {flops:.1f} GFLOPs' if flops else '' - yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '') - model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model' - LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}') + fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else "" + fs = f", {flops:.1f} GFLOPs" if flops else "" + yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "") + model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model" + LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}") return n_l, n_p, n_g, flops @@ -204,37 +274,53 @@ def model_info_for_loggers(trainer): """ Return model info dict with useful model information. 
- Example for YOLOv8n: - {'model/parameters': 3151904, - 'model/GFLOPs': 8.746, - 'model/speed_ONNX(ms)': 41.244, - 'model/speed_TensorRT(ms)': 3.211, - 'model/speed_PyTorch(ms)': 18.755} + Example: + YOLOv8n info for loggers + ```python + results = {'model/parameters': 3151904, + 'model/GFLOPs': 8.746, + 'model/speed_ONNX(ms)': 41.244, + 'model/speed_TensorRT(ms)': 3.211, + 'model/speed_PyTorch(ms)': 18.755} + ``` """ if trainer.args.profile: # profile ONNX and TensorRT times from ultralytics.utils.benchmarks import ProfileModels + results = ProfileModels([trainer.last], device=trainer.device).profile()[0] - results.pop('model/name') + results.pop("model/name") else: # only return PyTorch times from most recent validation results = { - 'model/parameters': get_num_params(trainer.model), - 'model/GFLOPs': round(get_flops(trainer.model), 3)} - results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3) + "model/parameters": get_num_params(trainer.model), + "model/GFLOPs": round(get_flops(trainer.model), 3), + } + results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3) return results def get_flops(model, imgsz=640): """Return a YOLO model's FLOPs.""" + if not thop: + return 0.0 # if not installed return 0.0 GFLOPs + try: model = de_parallel(model) p = next(model.parameters()) - stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride - im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format - flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0 # stride GFLOPs - imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float - return flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] # expand if int/float + try: + # Use stride size for input tensor + # stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride + # im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format + # flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs + # return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs + raise Exception + except Exception: + # Use actual image size for input tensor (i.e. 
required for RTDETR models) + im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format + return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs except Exception: - return 0 + return 0.0 def get_flops_with_torch_profiler(model, imgsz=640): @@ -242,11 +328,11 @@ def get_flops_with_torch_profiler(model, imgsz=640): if TORCH_2_0: model = de_parallel(model) p = next(model.parameters()) - stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride + stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format with torch.profiler.profile(with_flops=True) as prof: model(im) - flops = sum(x.flops for x in prof.key_averages()) / 1E9 + flops = sum(x.flops for x in prof.key_averages()) / 1e9 imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs return flops @@ -266,13 +352,15 @@ def initialize_weights(model): m.inplace = True -def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) - # Scales img(bs,3,y,x) by ratio constrained to gs-multiple +def scale_img(img, ratio=1.0, same_shape=False, gs=32): + """Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally + retaining the original shape. + """ if ratio == 1.0: return img h, w = img.shape[2:] s = (int(h * ratio), int(w * ratio)) # new size - img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize + img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize if not same_shape: # pad/crop img h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w)) return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean @@ -288,7 +376,7 @@ def make_divisible(x, divisor): def copy_attr(a, b, include=(), exclude=()): """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.""" for k, v in b.__dict__.items(): - if (len(include) and k not in include) or k.startswith('_') or k in exclude: + if (len(include) and k not in include) or k.startswith("_") or k in exclude: continue else: setattr(a, k, v) @@ -296,7 +384,7 @@ def copy_attr(a, b, include=(), exclude=()): def get_latest_opset(): """Return second-most (for maturity) recently supported ONNX opset by this version of torch.""" - return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset + return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset def intersect_dicts(da, db, exclude=()): @@ -316,7 +404,7 @@ def de_parallel(model): def one_cycle(y1=0.0, y2=1.0, steps=100): """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.""" - return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 + return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1 def init_seeds(seed=0, deterministic=False): @@ -331,10 +419,10 @@ def init_seeds(seed=0, deterministic=False): if TORCH_2_0: torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible torch.backends.cudnn.deterministic = True - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + 
os.environ["PYTHONHASHSEED"] = str(seed) else: - LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.') + LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.") else: torch.use_deterministic_algorithms(False) torch.backends.cudnn.deterministic = False @@ -369,13 +457,13 @@ class ModelEMA: v += (1 - d) * msd[k].detach() # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}' - def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): + def update_attr(self, model, include=(), exclude=("process_group", "reducer")): """Updates attributes and saves stripped model with optimizer removed.""" if self.enabled: copy_attr(self.ema, model, include, exclude) -def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None: +def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None: """ Strip optimizer from 'f' to finalize training, optionally save as 's'. @@ -395,32 +483,26 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None: strip_optimizer(f) ``` """ - # Use dill (if exists) to serialize the lambda functions where pickle does not do this - try: - import dill as pickle - except ImportError: - import pickle - - x = torch.load(f, map_location=torch.device('cpu')) - if 'model' not in x: - LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.') + x = torch.load(f, map_location=torch.device("cpu")) + if "model" not in x: + LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.") return - if hasattr(x['model'], 'args'): - x['model'].args = dict(x['model'].args) # convert from IterableSimpleNamespace to dict - args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None # combine args - if x.get('ema'): - x['model'] = x['ema'] # replace model with ema - for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys + if hasattr(x["model"], "args"): + x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict + args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args + if x.get("ema"): + x["model"] = x["ema"] # replace model with ema + for k in "optimizer", "best_fitness", "ema", "updates": # keys x[k] = None - x['epoch'] = -1 - x['model'].half() # to FP16 - for p in x['model'].parameters(): + x["epoch"] = -1 + x["model"].half() # to FP16 + for p in x["model"].parameters(): p.requires_grad = False - x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys + x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys # x['model'].args = x['train_args'] - torch.save(x, s or f, pickle_module=pickle) - mb = os.path.getsize(s or f) / 1E6 # filesize + torch.save(x, s or f) + mb = os.path.getsize(s or f) / 1e6 # file size LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") @@ -441,18 +523,20 @@ def profile(input, ops, n=10, device=None): results = [] if not isinstance(device, torch.device): device = select_device(device) - LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" - f"{'input':>24s}{'output':>24s}") + LOGGER.info( + f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" + f"{'input':>24s}{'output':>24s}" + ) for x in input if isinstance(input, list) else [input]: x = x.to(device) x.requires_grad = True for m in ops if isinstance(ops, list) 
else [ops]: - m = m.to(device) if hasattr(m, 'to') else m # device - m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m + m = m.to(device) if hasattr(m, "to") else m # device + m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward try: - flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs + flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs except Exception: flops = 0 @@ -466,13 +550,13 @@ def profile(input, ops, n=10, device=None): t[2] = time_sync() except Exception: # no backward method # print(e) # for debug - t[2] = float('nan') + t[2] = float("nan") tf += (t[1] - t[0]) * 1000 / n # ms per op forward tb += (t[2] - t[1]) * 1000 / n # ms per op backward - mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB) - s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes + mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB) + s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters - LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}') + LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}") results.append([p, flops, mem, tf, tb, s_in, s_out]) except Exception as e: LOGGER.info(e) @@ -482,25 +566,23 @@ def profile(input, ops, n=10, device=None): class EarlyStopping: - """ - Early stopping class that stops training when a specified number of epochs have passed without improvement. - """ + """Early stopping class that stops training when a specified number of epochs have passed without improvement.""" def __init__(self, patience=50): """ - Initialize early stopping object + Initialize early stopping object. Args: patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. """ self.best_fitness = 0.0 # i.e. mAP self.best_epoch = 0 - self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop + self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop self.possible_stop = False # possible stop may occur next epoch def __call__(self, epoch, fitness): """ - Check whether to stop training + Check whether to stop training. Args: epoch (int): Current epoch of training @@ -519,8 +601,10 @@ class EarlyStopping: self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch stop = delta >= self.patience # stop training if patience exceeded if stop: - LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. ' - f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n' - f'To update EarlyStopping(patience={self.patience}) pass a new patience value, ' - f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.') + LOGGER.info( + f"Stopping training early as no improvement observed in last {self.patience} epochs. " + f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n" + f"To update EarlyStopping(patience={self.patience}) pass a new patience value, " + f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping." 
+ ) return stop diff --git a/ultralytics/utils/triton.py b/ultralytics/utils/triton.py new file mode 100644 index 0000000..3f873a6 --- /dev/null +++ b/ultralytics/utils/triton.py @@ -0,0 +1,92 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from typing import List +from urllib.parse import urlsplit + +import numpy as np + + +class TritonRemoteModel: + """ + Client for interacting with a remote Triton Inference Server model. + + Attributes: + endpoint (str): The name of the model on the Triton server. + url (str): The URL of the Triton server. + triton_client: The Triton client (either HTTP or gRPC). + InferInput: The input class for the Triton client. + InferRequestedOutput: The output request class for the Triton client. + input_formats (List[str]): The data types of the model inputs. + np_input_formats (List[type]): The numpy data types of the model inputs. + input_names (List[str]): The names of the model inputs. + output_names (List[str]): The names of the model outputs. + """ + + def __init__(self, url: str, endpoint: str = "", scheme: str = ""): + """ + Initialize the TritonRemoteModel. + + Arguments may be provided individually or parsed from a collective 'url' argument of the form + ://// + + Args: + url (str): The URL of the Triton server. + endpoint (str): The name of the model on the Triton server. + scheme (str): The communication scheme ('http' or 'grpc'). + """ + if not endpoint and not scheme: # Parse all args from URL string + splits = urlsplit(url) + endpoint = splits.path.strip("/").split("/")[0] + scheme = splits.scheme + url = splits.netloc + + self.endpoint = endpoint + self.url = url + + # Choose the Triton client based on the communication scheme + if scheme == "http": + import tritonclient.http as client # noqa + + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) + config = self.triton_client.get_model_config(endpoint) + else: + import tritonclient.grpc as client # noqa + + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) + config = self.triton_client.get_model_config(endpoint, as_json=True)["config"] + + # Sort output names alphabetically, i.e. 'output0', 'output1', etc. + config["output"] = sorted(config["output"], key=lambda x: x.get("name")) + + # Define model attributes + type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8} + self.InferRequestedOutput = client.InferRequestedOutput + self.InferInput = client.InferInput + self.input_formats = [x["data_type"] for x in config["input"]] + self.np_input_formats = [type_map[x] for x in self.input_formats] + self.input_names = [x["name"] for x in config["input"]] + self.output_names = [x["name"] for x in config["output"]] + + def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]: + """ + Call the model with the given inputs. + + Args: + *inputs (List[np.ndarray]): Input data to the model. + + Returns: + (List[np.ndarray]): Model outputs. 
+ """ + infer_inputs = [] + input_format = inputs[0].dtype + for i, x in enumerate(inputs): + if x.dtype != self.np_input_formats[i]: + x = x.astype(self.np_input_formats[i]) + infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", "")) + infer_input.set_data_from_numpy(x) + infer_inputs.append(infer_input) + + infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names] + outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs) + + return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names] diff --git a/ultralytics/utils/tuner.py b/ultralytics/utils/tuner.py index 015e596..305c60a 100644 --- a/ultralytics/utils/tuner.py +++ b/ultralytics/utils/tuner.py @@ -2,16 +2,13 @@ import subprocess -from ultralytics.cfg import TASK2DATA, TASK2METRIC -from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS +from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir +from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks -def run_ray_tune(model, - space: dict = None, - grace_period: int = 10, - gpu_per_trial: int = None, - max_samples: int = 10, - **train_args): +def run_ray_tune( + model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args +): """ Runs hyperparameter tuning using Ray Tune. @@ -37,49 +34,59 @@ def run_ray_tune(model, result_grid = model.tune(data='coco8.yaml', use_ray=True) ``` """ + + LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune") if train_args is None: train_args = {} try: - subprocess.run('pip install ray[tune]'.split(), check=True) + subprocess.run("pip install ray[tune]<=2.9.3".split(), check=True) # do not add single quotes here + import ray from ray import tune from ray.air import RunConfig from ray.air.integrations.wandb import WandbLoggerCallback from ray.tune.schedulers import ASHAScheduler except ImportError: - raise ModuleNotFoundError('Tuning hyperparameters requires Ray Tune. Install with: pip install "ray[tune]"') + raise ModuleNotFoundError('Ray Tune required but not found. 
To install run: pip install "ray[tune]<=2.9.3"') try: import wandb - assert hasattr(wandb, '__version__') + assert hasattr(wandb, "__version__") except (ImportError, AssertionError): wandb = False + checks.check_version(ray.__version__, "<=2.9.3", "ray") default_space = { # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), - 'lr0': tune.uniform(1e-5, 1e-1), - 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) - 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 - 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4 - 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) - 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum - 'box': tune.uniform(0.02, 0.2), # box loss gain - 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) - 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) - 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) - 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) - 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg) - 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction) - 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain) - 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg) - 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 - 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability) - 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability) - 'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability) - 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability) - 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability) + "lr0": tune.uniform(1e-5, 1e-1), + "lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) + "momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 + "weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4 + "warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) + "warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum + "box": tune.uniform(0.02, 0.2), # box loss gain + "cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) + "hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) + "hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) + "hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) + "degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg) + "translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction) + "scale": tune.uniform(0.0, 0.9), # image scale (+/- gain) + "shear": tune.uniform(0.0, 10.0), # image shear (+/- deg) + "perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 + "flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability) + "fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability) + "bgr": tune.uniform(0.0, 1.0), # image channel BGR (probability) + "mosaic": tune.uniform(0.0, 1.0), # image mixup (probability) + "mixup": tune.uniform(0.0, 1.0), # image mixup (probability) + "copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability) + } + + # Put the model in ray store + task = model.task + model_in_store = ray.put(model) def _tune(config): """ @@ -89,42 +96,50 @@ def run_ray_tune(model, config (dict): A dictionary of hyperparameters to use for training. 
Returns: - None. + None """ - model._reset_callbacks() + model_to_train = ray.get(model_in_store) # get the model from ray store for tuning + model_to_train.reset_callbacks() config.update(train_args) - model.train(**config) + results = model_to_train.train(**config) + return results.results_dict # Get search space if not space: space = default_space - LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.') + LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.") # Get dataset - data = train_args.get('data', TASK2DATA[model.task]) - space['data'] = data - if 'data' not in train_args: + data = train_args.get("data", TASK2DATA[task]) + space["data"] = data + if "data" not in train_args: LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".') # Define the trainable function with allocated resources - trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0}) + trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0}) # Define the ASHA scheduler for hyperparameter search - asha_scheduler = ASHAScheduler(time_attr='epoch', - metric=TASK2METRIC[model.task], - mode='max', - max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100, - grace_period=grace_period, - reduction_factor=3) + asha_scheduler = ASHAScheduler( + time_attr="epoch", + metric=TASK2METRIC[task], + mode="max", + max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100, + grace_period=grace_period, + reduction_factor=3, + ) # Define the callbacks for the hyperparameter search - tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else [] + tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else [] # Create the Ray Tune hyperparameter search tuner - tuner = tune.Tuner(trainable_with_resources, - param_space=space, - tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples), - run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune')) + tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve() # must be absolute dir + tune_dir.mkdir(parents=True, exist_ok=True) + tuner = tune.Tuner( + trainable_with_resources, + param_space=space, + tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples), + run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir), + ) # Run the hyperparameter search tuner.fit() diff --git a/ultralytics/yolo/__init__.py b/ultralytics/yolo/__init__.py deleted file mode 100644 index d1fa558..0000000 --- a/ultralytics/yolo/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -from . import v8 - -__all__ = 'v8', # tuple or list diff --git a/ultralytics/yolo/__pycache__/__init__.cpython-39.pyc b/ultralytics/yolo/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index e1746fd..0000000 Binary files a/ultralytics/yolo/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py deleted file mode 100644 index 5ea5519..0000000 --- a/ultralytics/yolo/cfg/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -import importlib -import sys - -from ultralytics.utils import LOGGER - -# Set modules in sys.modules under their old name -sys.modules['ultralytics.yolo.cfg'] = importlib.import_module('ultralytics.cfg') - -LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.cfg' is deprecated since '8.0.136' and will be removed in '8.1.0'. 
" - "Please use 'ultralytics.cfg' instead.") diff --git a/ultralytics/yolo/data/__init__.py b/ultralytics/yolo/data/__init__.py deleted file mode 100644 index f68391e..0000000 --- a/ultralytics/yolo/data/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -import importlib -import sys - -from ultralytics.utils import LOGGER - -# Set modules in sys.modules under their old name -sys.modules['ultralytics.yolo.data'] = importlib.import_module('ultralytics.data') -# This is for updating old cls models, or the way in following warning won't work. -sys.modules['ultralytics.yolo.data.augment'] = importlib.import_module('ultralytics.data.augment') - -DATA_WARNING = """WARNING ⚠️ 'ultralytics.yolo.data' is deprecated since '8.0.136' and will be removed in '8.1.0'. Please use 'ultralytics.data' instead. -Note this warning may be related to loading older models. You can update your model to current structure with: - import torch - ckpt = torch.load("model.pt") # applies to both official and custom models - torch.save(ckpt, "updated-model.pt") -""" -LOGGER.warning(DATA_WARNING) diff --git a/ultralytics/yolo/engine/__init__.py b/ultralytics/yolo/engine/__init__.py deleted file mode 100644 index 794efcd..0000000 --- a/ultralytics/yolo/engine/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -import importlib -import sys - -from ultralytics.utils import LOGGER - -# Set modules in sys.modules under their old name -sys.modules['ultralytics.yolo.engine'] = importlib.import_module('ultralytics.engine') - -LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.engine' is deprecated since '8.0.136' and will be removed in '8.1.0'. " - "Please use 'ultralytics.engine' instead.") diff --git a/ultralytics/yolo/engine/__pycache__/__init__.cpython-39.pyc b/ultralytics/yolo/engine/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 33d188c..0000000 Binary files a/ultralytics/yolo/engine/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py deleted file mode 100644 index 71557b0..0000000 --- a/ultralytics/yolo/utils/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -import importlib -import sys - -from ultralytics.utils import LOGGER - -# Set modules in sys.modules under their old name -sys.modules['ultralytics.yolo.utils'] = importlib.import_module('ultralytics.utils') - -UTILS_WARNING = """WARNING ⚠️ 'ultralytics.yolo.utils' is deprecated since '8.0.136' and will be removed in '8.1.0'. Please use 'ultralytics.utils' instead. -Note this warning may be related to loading older models. 
You can update your model to current structure with: - import torch - ckpt = torch.load("model.pt") # applies to both official and custom models - torch.save(ckpt, "updated-model.pt") -""" -LOGGER.warning(UTILS_WARNING) diff --git a/ultralytics/yolo/utils/__pycache__/__init__.cpython-39.pyc b/ultralytics/yolo/utils/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 4aa5370..0000000 Binary files a/ultralytics/yolo/utils/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py deleted file mode 100644 index 51adf81..0000000 --- a/ultralytics/yolo/v8/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -import importlib -import sys - -from ultralytics.utils import LOGGER - -# Set modules in sys.modules under their old name -sys.modules['ultralytics.yolo.v8'] = importlib.import_module('ultralytics.models.yolo') - -LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.v8' is deprecated since '8.0.136' and will be removed in '8.1.0'. " - "Please use 'ultralytics.models.yolo' instead.") diff --git a/ultralytics/yolo/v8/__pycache__/__init__.cpython-39.pyc b/ultralytics/yolo/v8/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 063c3ba..0000000 Binary files a/ultralytics/yolo/v8/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/utils/__pycache__/dataloaders.cpython-39.pyc b/utils/__pycache__/dataloaders.cpython-39.pyc index c77703a..8da62af 100644 Binary files a/utils/__pycache__/dataloaders.cpython-39.pyc and b/utils/__pycache__/dataloaders.cpython-39.pyc differ diff --git a/utils/__pycache__/getsource.cpython-39.pyc b/utils/__pycache__/getsource.cpython-39.pyc index 2253c09..51e6471 100644 Binary files a/utils/__pycache__/getsource.cpython-39.pyc and b/utils/__pycache__/getsource.cpython-39.pyc differ diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 406009c..eeceb2b 100644 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -300,7 +300,9 @@ class LoadImages: ret_val, im0 = self.cap.read() self.frame += 1 - im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False + + # if self.orientation == 270: + # im0 = cv2.rotate(im0, cv2.ROTATE_90_COUNTERCLOCKWISE) # for use if cv2 autorotation is False s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ' else: @@ -329,14 +331,14 @@ class LoadImages: def _cv2_rotate(self, im): # Rotate a cv2 video manually - # if self.orientation == 0: - # return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE) - # elif self.orientation == 180: - # return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE) - # elif self.orientation == 90: - # return cv2.rotate(im, cv2.ROTATE_180) - if self.orientation == 270: + if self.orientation == 0: + return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE) + elif self.orientation == 180: return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE) + elif self.orientation == 90: + return cv2.rotate(im, cv2.ROTATE_180) + # if self.orientation == 270: + # return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE) return im
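
Note on the final hunk in `utils/dataloaders.py`: the patch comments out the rotation call in the video-read path and re-activates the orientation branches for 0/180/90 inside `LoadImages._cv2_rotate`, leaving the 270 branch commented out. Below is a minimal standalone sketch of that same orientation-to-rotation mapping against plain OpenCV; the `cv2.CAP_PROP_ORIENTATION_META` lookup and the `example.mp4` path are illustrative assumptions, not part of the patch.

```python
import cv2


def rotate_for_orientation(im, orientation):
    """Apply the same orientation -> rotation mapping as the re-enabled _cv2_rotate branches."""
    if orientation == 0:
        return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
    elif orientation == 180:
        return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
    elif orientation == 90:
        return cv2.rotate(im, cv2.ROTATE_180)
    return im  # 270 and unknown values pass through unchanged, as in the patched code


cap = cv2.VideoCapture("example.mp4")  # hypothetical input; assumes an FFmpeg-backed OpenCV build
orientation = int(cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation metadata reported by the container
ok, frame = cap.read()
if ok:
    frame = rotate_for_orientation(frame, orientation)
cap.release()
```

With the read-path call commented out, frames are no longer rotated during iteration; whether that is desirable depends on whether the OpenCV build in use already honours the rotation metadata (autorotation), which is the caveat the original "for use if cv2 autorotation is False" comment refers to.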