582 lines
24 KiB
Python
582 lines
24 KiB
Python
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||
|
||
import glob
|
||
import math
|
||
import os
|
||
import time
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from threading import Thread
|
||
from urllib.parse import urlparse
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import requests
|
||
import torch
|
||
from PIL import Image
|
||
|
||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
||
from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
|
||
from ultralytics.utils.checks import check_requirements
|
||
|
||
import subprocess
|
||
import json
|
||
|
||
@dataclass
class SourceTypes:
    """
    Boolean flags describing the kind of input source used for predictions.

    All flags default to False; they are set by the code that inspects the prediction source.

    Attributes:
        stream (bool): Flag for stream-type sources.
        screenshot (bool): Flag for screen-capture sources.
        from_img (bool): Flag for in-memory image sources.
        tensor (bool): Flag for torch.Tensor sources.
    """

    stream: bool = False  # stream-type source
    screenshot: bool = False  # screen-capture source
    from_img: bool = False  # in-memory image source
    tensor: bool = False  # torch.Tensor source
|
||
|
||
|
||
class LoadStreams:
    """
    Stream Loader for various types of video streams, Supports RTSP, RTMP, HTTP, and TCP streams.

    Attributes:
        sources (list): Cleaned source names (via `ops.clean_str`), one per stream.
        vid_stride (int): Video frame-rate stride, defaults to 1.
        buffer (bool): Whether to buffer input streams, defaults to False.
        running (bool): Flag to indicate if the streaming thread is running.
        mode (str): Set to 'stream' indicating real-time capture.
        imgs (list): List of image frames for each stream.
        fps (list): List of FPS for each stream.
        frames (list): List of total frames for each stream (inf for streams with unknown length).
        threads (list): List of threads for each stream.
        shape (list): List of shapes for each stream.
        caps (list): List of cv2.VideoCapture objects for each stream.
        bs (int): Batch size for processing (the number of streams).

    Methods:
        __init__: Initialize the stream loader.
        update: Read stream frames in daemon thread.
        close: Close stream loader and release resources.
        __iter__: Returns an iterator object for the class.
        __next__: Returns source paths, transformed, and original images for processing.
        __len__: Return the length of the sources object.

    Example:
        ```bash
        yolo predict source='rtsp://example.com/media.mp4'
        ```
    """

    def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
        """
        Open every stream, read its first frame, and start one daemon reader thread per stream.

        Args:
            sources (str): A single stream source, or a path to a text file whose whitespace-separated entries are
                individual sources.
            vid_stride (int): Keep every `vid_stride`-th frame.
            buffer (bool): If True, queue frames (up to 30 per stream); otherwise keep only the latest frame.

        Raises:
            NotImplementedError: If a local webcam ('0') is requested inside Colab or Kaggle.
            ConnectionError: If a stream cannot be resolved, opened, or its first frame cannot be read.
        """
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.buffer = buffer  # buffer input streams
        self.running = True  # running flag for Thread
        self.mode = "stream"
        self.vid_stride = vid_stride  # video frame-rate stride

        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.bs = n
        self.fps = [0] * n  # frames per second
        self.frames = [0] * n
        self.threads = [None] * n
        self.caps = [None] * n  # video capture objects
        self.imgs = [[] for _ in range(n)]  # images
        self.shape = [[] for _ in range(n)]  # image shapes
        self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f"{i + 1}/{n}: {s}... "
            if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
                s = get_best_youtube_url(s)
                if s is None:  # the yt-dlp path returns None when no suitable format exists
                    raise ConnectionError(f"{st}No suitable YouTube video stream found")
            # A purely numeric source such as '0' selects a local webcam by index; int() avoids eval() on user input
            s = int(s) if s.isnumeric() else s
            if s == 0 and (is_colab() or is_kaggle()):
                raise NotImplementedError(
                    "'source=0' webcam not supported in Colab and Kaggle notebooks. "
                    "Try running 'source=0' in a local environment."
                )
            self.caps[i] = cv2.VideoCapture(s)  # store video capture object
            if not self.caps[i].isOpened():
                raise ConnectionError(f"{st}Failed to open {s}")
            w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = self.caps[i].get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
                "inf"
            )  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            success, im = self.caps[i].read()  # guarantee first frame
            if not success or im is None:
                raise ConnectionError(f"{st}Failed to read images from {s}")
            self.imgs[i].append(im)
            self.shape[i] = im.shape
            self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)
            LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info("")  # newline

    def update(self, i, cap, stream):
        """
        Read frames of stream `i` in a daemon thread until stopped, the capture closes, or frames run out.

        Args:
            i (int): Stream index into `self.imgs`, `self.shape` and `self.frames`.
            cap (cv2.VideoCapture): Capture object for the stream.
            stream (str | int): Original source, used to re-open the capture after a failed read.
        """
        n, f = 0, self.frames[i]  # frames grabbed so far, total frames (inf for live streams)
        while self.running and cap.isOpened() and n < (f - 1):
            if len(self.imgs[i]) < 30:  # keep a <=30-image buffer
                n += 1
                cap.grab()  # .read() = .grab() followed by .retrieve()
                if n % self.vid_stride == 0:  # only decode every vid_stride-th frame
                    success, im = cap.retrieve()
                    if not success:
                        im = np.zeros(self.shape[i], dtype=np.uint8)
                        LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
                        cap.open(stream)  # re-open stream if signal was lost
                    if self.buffer:
                        self.imgs[i].append(im)
                    else:
                        self.imgs[i] = [im]  # replace the list so only the newest frame is kept
            else:
                time.sleep(0.01)  # wait until the buffer is empty

    def close(self):
        """Stop the reader threads (waiting up to 5 s each), release all captures, and close OpenCV windows."""
        self.running = False  # stop flag for Thread
        for thread in self.threads:
            if thread is not None and thread.is_alive():
                thread.join(timeout=5)  # Add timeout
        for cap in self.caps:  # Iterate through the stored VideoCapture objects
            if cap is None:
                continue
            try:
                cap.release()  # release video capture
            except Exception as e:
                LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}")
        cv2.destroyAllWindows()

    def __iter__(self):
        """Iterates through YOLO image feed and re-opens unresponsive streams."""
        self.count = -1
        return self

    def __next__(self):
        """
        Return one frame from every stream.

        Returns:
            (tuple): (source names, list of frames one per stream, list of empty info strings).

        Raises:
            StopIteration: When a reader thread has died or 'q' is pressed in an OpenCV window.
        """
        self.count += 1

        images = []
        for i, x in enumerate(self.imgs):
            # Wait until a frame is available in each buffer
            while not x:
                if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"):  # q to quit
                    self.close()
                    raise StopIteration
                time.sleep(1 / min(self.fps))
                x = self.imgs[i]  # re-read: the reader thread may have replaced the list object
                if not x:
                    LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}")

            # Get and remove the first frame from imgs buffer
            if self.buffer:
                images.append(x.pop(0))

            # Get the last frame, and clear the rest from the imgs buffer
            else:
                images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
                x.clear()

        return self.sources, images, [""] * self.bs

    def __len__(self):
        """Return the number of streams, i.e. the number of frames in each batch."""
        return self.bs
|
||
|
||
|
||
class LoadScreenshots:
    """
    YOLOv8 screenshot dataloader.

    Captures a whole monitor, or a rectangle on it, with the `mss` library and yields one frame per iteration.
    Suitable for use with `yolo predict source=screen`.

    Attributes:
        screen (int): Index into `mss().monitors` of the screen to capture.
        left (int): Absolute left coordinate of the capture area.
        top (int): Absolute top coordinate of the capture area.
        width (int): Width of the capture area.
        height (int): Height of the capture area.
        mode (str): Always 'stream', indicating real-time capture.
        frame (int): Number of frames captured so far.
        sct (mss.mss): Screen capture object from `mss` library.
        bs (int): Batch size, always 1.
        fps (int): Nominal frame rate, fixed at 30.
        monitor (dict): Capture region (left/top/width/height) passed to `sct.grab`.

    Methods:
        __iter__: Returns an iterator object.
        __next__: Captures the next screenshot and returns it.
    """

    def __init__(self, source):
        """
        Parse `source` and compute the capture region.

        Args:
            source (str): Whitespace-separated tokens; the first token is ignored, and the rest are integers in one of
                the forms `screen_number`, `left top width height`, or `screen_number left top width height`
                (pixels; left/top are offsets from the monitor's origin). Any other count of extra tokens falls back
                to the full screen 0.
        """
        check_requirements("mss")
        import mss  # noqa

        _, *tokens = source.split()
        screen, left, top, width, height = 0, None, None, None, None  # default: all of screen 0
        if len(tokens) == 1:
            (screen,) = map(int, tokens)
        elif len(tokens) == 4:
            left, top, width, height = map(int, tokens)
        elif len(tokens) == 5:
            screen, left, top, width, height = map(int, tokens)
        self.screen = screen
        self.mode = "stream"
        self.frame = 0
        self.sct = mss.mss()
        self.bs = 1
        self.fps = 30

        # Translate the requested box into absolute desktop coordinates of the chosen monitor
        mon = self.sct.monitors[self.screen]
        self.top = mon["top"] + top if top is not None else mon["top"]
        self.left = mon["left"] + left if left is not None else mon["left"]
        self.width = width or mon["width"]
        self.height = height or mon["height"]
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        """Return self; screenshots are produced indefinitely by `__next__`."""
        return self

    def __next__(self):
        """Grab the capture region and return ([screen id], [BGR frame], [description string])."""
        shot = self.sct.grab(self.monitor)
        bgr = np.asarray(shot)[:, :, :3]  # BGRA to BGR
        self.frame += 1
        desc = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
        return [str(self.screen)], [bgr], [desc]  # screen, img, string
|
||
|
||
|
||
class LoadImagesAndVideos:
    """
    YOLOv8 image/video dataloader.

    This class manages the loading and pre-processing of image and video data for YOLOv8. It supports loading from
    various formats, including single image files, video files, and lists of image and video paths.

    Attributes:
        files (list): List of image and video file paths (all images first, then all videos).
        nf (int): Total number of files (images and videos).
        ni (int): Number of image files.
        video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
        mode (str): Current mode, 'image' or 'video'.
        vid_stride (int): Stride for video frame-rate, defaults to 1.
        bs (int): Batch size.
        cap (cv2.VideoCapture): Video capture object for OpenCV.
        frame (int): Frame counter for video.
        frames (int): Total number of frames in the video, divided by `vid_stride`.
        rotation (int): Clockwise rotation in degrees, in [0, 360), read from the current video's metadata; 0 if
            unknown.
        count (int): Counter for iteration, initialized at 0 during `__iter__()`.

    Methods:
        _new_video(path): Create a new cv2.VideoCapture object for a given video path.
        get_rotation(filename): Read a video's rotation metadata with ffprobe.
    """

    def __init__(self, path, batch=1, vid_stride=1):
        """
        Collect image and video files from `path` and open the first video, if any.

        Args:
            path (str | list | tuple): A file, directory, glob pattern, *.txt file listing one source per line, or a
                list/tuple of such sources.
            batch (int): Maximum number of items returned per `__next__` call.
            vid_stride (int): Keep every `vid_stride`-th video frame.

        Raises:
            FileNotFoundError: If a listed source does not exist or no supported images/videos are found.
        """
        parent = None
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
            parent = Path(path).parent
            path = Path(path).read_text().splitlines()  # list of sources
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            a = str(Path(p).absolute())  # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
            if "*" in a:
                files.extend(sorted(glob.glob(a, recursive=True)))  # glob
            elif os.path.isdir(a):
                files.extend(sorted(glob.glob(os.path.join(a, "*.*"))))  # dir
            elif os.path.isfile(a):
                files.append(a)  # files (absolute or relative to CWD)
            elif parent and (parent / p).is_file():
                files.append(str((parent / p).absolute()))  # files (relative to *.txt file parent)
            else:
                raise FileNotFoundError(f"{p} does not exist")

        # Images are ordered before videos; files with unsupported extensions are skipped
        images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.ni = ni  # number of images
        self.video_flag = [False] * ni + [True] * nv
        self.mode = "image"
        self.vid_stride = vid_stride  # video frame-rate stride
        self.bs = batch
        self.rotation = 0  # updated per video by _new_video()
        if videos:
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        if self.nf == 0:
            # `path` (not the loop variable) so this also works when an empty list was passed
            raise FileNotFoundError(
                f"No images or videos found in {path}. "
                f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
            )

    def __iter__(self):
        """Returns an iterator object for VideoStream or ImageFolder."""
        self.count = 0
        return self

    def __next__(self):
        """
        Return the next batch of images or video frames along with their paths and info strings.

        A batch never mixes images and videos: it is cut short when the image list ends.

        Returns:
            (tuple): (paths, images, info strings), each a list of the same length (at most `self.bs`).

        Raises:
            StopIteration: When all files have been consumed.
            FileNotFoundError: If an image cannot be read.
        """
        paths, imgs, info = [], [], []
        while len(imgs) < self.bs:
            if self.count >= self.nf:  # end of file list
                if len(imgs) > 0:
                    return paths, imgs, info  # return last partial batch
                else:
                    raise StopIteration

            path = self.files[self.count]
            if self.video_flag[self.count]:
                self.mode = "video"
                if not self.cap or not self.cap.isOpened():
                    self._new_video(path)

                for _ in range(self.vid_stride):
                    success = self.cap.grab()
                    if not success:
                        break  # end of video or failure

                if success:
                    success, im0 = self.cap.retrieve()
                    if success:
                        # Undo the container's rotation metadata (probed once per video in _new_video)
                        im0 = self._apply_rotation(im0)
                        self.frame += 1
                        paths.append(path)
                        imgs.append(im0)
                        info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ")
                        if self.frame == self.frames:  # end of video
                            self.count += 1
                            self.cap.release()
                else:
                    # Move to the next file if the current video ended or failed to open
                    self.count += 1
                    if self.cap:
                        self.cap.release()
                    if self.count < self.nf:
                        self._new_video(self.files[self.count])
            else:
                self.mode = "image"
                im0 = cv2.imread(path)  # BGR
                if im0 is None:
                    raise FileNotFoundError(f"Image Not Found {path}")
                paths.append(path)
                imgs.append(im0)
                info.append(f"image {self.count + 1}/{self.nf} {path}: ")
                self.count += 1  # move to the next file
                if self.count >= self.ni:  # end of image list
                    break

        return paths, imgs, info

    def _new_video(self, path):
        """
        Open `path` with OpenCV, reset the frame counter, and probe the video's rotation once.

        Raises:
            FileNotFoundError: If OpenCV cannot open the video.
        """
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
        if not self.cap.isOpened():
            raise FileNotFoundError(f"Failed to open video {path}")
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        # Probe metadata once per video; probing per frame would spawn an ffprobe process for every frame
        self.rotation = self.get_rotation(path)

    def _apply_rotation(self, im):
        """
        Rotate frame `im` clockwise by `self.rotation` degrees so it is displayed upright.

        Frames are returned unchanged when the rotation is 0 or not a multiple of 90.
        NOTE(review): some OpenCV builds may already honor rotation metadata (CAP_PROP_ORIENTATION_AUTO); if frames
        come out double-rotated, disable one of the two mechanisms.
        """
        if self.rotation == 90:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        if self.rotation == 180:
            return cv2.rotate(im, cv2.ROTATE_180)
        if self.rotation == 270:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        return im

    def get_rotation(self, filename):
        """
        Return the rotation stored in a video's `rotate` stream tag, in degrees.

        Runs `ffprobe` on the first video stream. This is best effort and never raises. It returns 0 when ffprobe is
        not installed, exits with an error, or the tag is missing or malformed.
        NOTE(review): recent FFmpeg versions may expose rotation only as display-matrix side data rather than the
        `rotate` tag; in that case this returns 0.

        Args:
            filename (str): Path to the video file.

        Returns:
            (int): Rotation normalized to [0, 360), e.g. 0, 90, 180 or 270.
        """
        cmd = [
            "ffprobe",  # note: ffprobe, not ffmpeg
            "-v", "error",
            "-select_streams", "v:0",
            "-show_entries", "stream_tags=rotate",
            "-of", "json",
            str(filename),
        ]
        try:
            result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError:  # ffprobe missing or not executable
            return 0
        if result.returncode != 0:
            return 0
        return self._parse_rotation(result.stdout)

    @staticmethod
    def _parse_rotation(stdout):
        """Extract the `rotate` tag from ffprobe JSON output as degrees in [0, 360); return 0 if absent or invalid."""
        try:
            metadata = json.loads(stdout)
            streams = metadata.get("streams") or [{}]  # empty list when no stream matched
            rotation = int((streams[0].get("tags") or {}).get("rotate", 0))
        except (ValueError, TypeError, AttributeError, IndexError):  # bad JSON / unexpected structure / non-int tag
            return 0
        return rotation % 360  # e.g. -90 -> 270

    def __len__(self):
        """Returns the number of batches in the object."""
        return math.ceil(self.nf / self.bs)  # number of batches
|
||
|
||
|
||
class LoadPilAndNumpy:
    """
    Load images from PIL and Numpy arrays for batch processing.

    Every input is validated and converted to a contiguous BGR numpy array. The whole batch is then yielded once by
    a single-pass iterator.

    Attributes:
        paths (list): The image's `filename` attribute when present, otherwise an autogenerated 'image{i}.jpg'.
        im0 (list): List of images stored as Numpy arrays.
        mode (str): Always 'image'.
        bs (int): Batch size, equivalent to the length of `im0`.

    Methods:
        _single_check(im): Validate and format a single image to a Numpy array.
    """

    def __init__(self, im0):
        """Wrap a single image or a list of images (PIL or numpy) as one batch."""
        images = im0 if isinstance(im0, list) else [im0]
        self.paths = []
        for idx, img in enumerate(images):
            self.paths.append(getattr(img, "filename", f"image{idx}.jpg"))
        self.im0 = list(map(self._single_check, images))
        self.mode = "image"
        self.bs = len(self.im0)

    @staticmethod
    def _single_check(im):
        """Validate `im`; convert PIL images to contiguous BGR numpy arrays and pass numpy arrays through unchanged."""
        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
        if not isinstance(im, Image.Image):
            return im  # numpy input is returned as-is
        rgb = im if im.mode == "RGB" else im.convert("RGB")
        return np.ascontiguousarray(np.asarray(rgb)[:, :, ::-1])  # RGB -> BGR, contiguous

    def __len__(self):
        """Number of images in the batch."""
        return len(self.im0)

    def __iter__(self):
        """Reset the single-pass iterator and return self."""
        self.count = 0
        return self

    def __next__(self):
        """Return (paths, images, empty info strings) on the first call; stop on the second."""
        if self.count == 1:  # the whole batch is delivered in one step
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, ["" for _ in range(self.bs)]
|
||
|
||
|
||
class LoadTensor:
    """
    Load images from torch.Tensor data.

    This class manages the loading and pre-processing of image data from PyTorch tensors for further processing.

    Attributes:
        im0 (torch.Tensor): The validated 4-D (BCHW) input tensor.
        bs (int): Batch size, inferred from the shape of `im0`.
        mode (str): Current mode, set to 'image'.
        paths (list): One autogenerated name per batch item ('image{i}.jpg' unless the item has a `filename`).
        count (int): Counter for iteration, initialized at 0 during `__iter__()`.

    Methods:
        _single_check(im, stride): Validate and possibly modify the input tensor.
    """

    def __init__(self, im0) -> None:
        """Initialize Tensor Dataloader from a BCHW (or CHW, treated as batch of 1) tensor."""
        self.im0 = self._single_check(im0)
        self.bs = self.im0.shape[0]
        self.mode = "image"
        # Enumerate the checked (always 4-D) tensor so there is exactly one path per batch item; enumerating the raw
        # input would yield one path per channel for a 3-D CHW tensor.
        self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(self.im0)]

    @staticmethod
    def _single_check(im, stride=32):
        """
        Validate and format an image tensor.

        Args:
            im (torch.Tensor): BCHW tensor (a 3-D CHW tensor is accepted with a warning and unsqueezed).
            stride (int): H and W must be divisible by this value.

        Returns:
            (torch.Tensor): 4-D tensor; inputs with values above 1 are converted to float and divided by 255.

        Raises:
            ValueError: If the tensor is not 3-D/4-D or its spatial size is not divisible by `stride`.
        """
        s = (
            f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
            f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
        )
        if len(im.shape) != 4:
            if len(im.shape) != 3:
                raise ValueError(s)
            LOGGER.warning(s)
            im = im.unsqueeze(0)
        if im.shape[2] % stride or im.shape[3] % stride:
            raise ValueError(s)
        # torch.finfo() only accepts floating dtypes; integer tensors need no rounding tolerance
        eps = torch.finfo(im.dtype).eps if im.is_floating_point() else 0.0  # torch.float32 eps is 1.2e-07
        if im.max() > 1.0 + eps:
            LOGGER.warning(
                f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
                f"Dividing input by 255."
            )
            im = im.float() / 255.0

        return im

    def __iter__(self):
        """Returns an iterator object."""
        self.count = 0
        return self

    def __next__(self):
        """Return (paths, tensor, empty info strings) once, then raise StopIteration."""
        if self.count == 1:
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, [""] * self.bs

    def __len__(self):
        """Returns the batch size."""
        return self.bs
|
||
|
||
|
||
def autocast_list(source):
    """
    Normalize a mixed list of image sources into a list of PIL images and numpy arrays.

    Strings and Paths are opened with PIL; values starting with "http" are streamed with `requests` first. PIL
    images and numpy arrays are passed through unchanged.

    Args:
        source (iterable): Items of type str, Path, PIL.Image.Image or np.ndarray.

    Returns:
        (list): The converted items, in input order.

    Raises:
        TypeError: If an item has any other type.
    """
    files = []
    for item in source:
        if isinstance(item, (Image.Image, np.ndarray)):  # already an in-memory image
            files.append(item)
            continue
        if not isinstance(item, (str, Path)):
            raise TypeError(
                f"type {type(item).__name__} is not a supported Ultralytics prediction source type. \n"
                f"See https://docs.ultralytics.com/modes/predict for supported source types."
            )
        is_remote = str(item).startswith("http")
        handle = requests.get(item, stream=True).raw if is_remote else item
        files.append(Image.open(handle))
    return files
|
||
|
||
|
||
def get_best_youtube_url(url, use_pafy=True):
    """
    Retrieves the URL of the best quality MP4 video stream from a given YouTube video.

    This function uses the pafy or yt_dlp library to extract the video info from YouTube. With pafy it returns the
    best MP4 video stream. With yt_dlp it scans the formats from last to first and returns the first MP4 format that
    has a video codec, no audio codec, and is at least 1920 wide or 1080 tall.

    Args:
        url (str): The URL of the YouTube video.
        use_pafy (bool): Use the pafy package, default=True, otherwise use yt_dlp package.

    Returns:
        (str | None): The URL of the selected video stream, or None if yt_dlp finds no suitable format.
    """
    if use_pafy:
        check_requirements(("pafy", "youtube_dl==2020.12.2"))
        import pafy  # noqa

        return pafy.new(url).getbestvideo(preftype="mp4").url
    else:
        check_requirements("yt-dlp")
        import yt_dlp

        with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
            info_dict = ydl.extract_info(url, download=False)  # extract info
            for f in reversed(info_dict.get("formats", [])):  # reversed because best is usually last
                # Video-only *.mp4 format at least 1920 wide or 1080 tall
                good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
                # .get() so formats lacking codec/ext keys are skipped instead of raising KeyError
                has_video = f.get("vcodec", "none") != "none"
                no_audio = f.get("acodec") == "none"
                if good_size and has_video and no_audio and f.get("ext") == "mp4":
                    return f.get("url")
        return None
|
||
|
||
|
||
# Define constants
# Loader classes defined above; each yields (paths/sources, images, info strings) batches when iterated
LOADERS = (
    LoadStreams,
    LoadPilAndNumpy,
    LoadImagesAndVideos,
    LoadScreenshots,
)
|