add yolo v10 and modify pipeline

This commit is contained in:
王庆刚
2025-03-28 13:19:54 +08:00
parent 183299c06b
commit 798c596acc
471 changed files with 19109 additions and 7342 deletions

View File

@ -5,24 +5,51 @@ import requests
from ultralytics.data.utils import HUBDatasetStats
from ultralytics.hub.auth import Auth
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
from ultralytics.utils import LOGGER, SETTINGS
from ultralytics.utils import LOGGER, SETTINGS, checks
def login(api_key=''):
def login(api_key: str = None, save=True) -> bool:
"""
Log in to the Ultralytics HUB API using the provided API key.
The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
environment variable if successfully authenticated.
Args:
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
api_key (str, optional): API key to use for authentication.
If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
Example:
```python
from ultralytics import hub
hub.login('API_KEY')
```
Returns:
(bool): True if authentication is successful, False otherwise.
"""
Auth(api_key, verbose=True)
checks.check_requirements("hub-sdk>=0.0.6")
from hub_sdk import HUBClient
api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL
saved_key = SETTINGS.get("api_key")
active_key = api_key or saved_key
credentials = {"api_key": active_key} if active_key and active_key != "" else None # set credentials
client = HUBClient(credentials) # initialize HUBClient
if client.authenticated:
# Successfully authenticated with HUB
if save and client.api_key != saved_key:
SETTINGS.update({"api_key": client.api_key}) # update settings with valid API key
# Set message based on whether key was provided or retrieved from settings
log_message = (
"New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
)
LOGGER.info(f"{PREFIX}{log_message}")
return True
else:
# Failed to authenticate with HUB
LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo hub login API_KEY'")
return False
def logout():
@ -36,52 +63,53 @@ def logout():
hub.logout()
```
"""
SETTINGS['api_key'] = ''
SETTINGS["api_key"] = ""
SETTINGS.save()
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
def reset_model(model_id=''):
def reset_model(model_id=""):
"""Reset a trained model to an untrained state."""
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
if r.status_code == 200:
LOGGER.info(f'{PREFIX}Model reset successfully')
LOGGER.info(f"{PREFIX}Model reset successfully")
return
LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")
def export_fmts_hub():
"""Returns a list of HUB-supported export formats."""
from ultralytics.engine.exporter import export_formats
return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
def export_model(model_id='', format='torchscript'):
def export_model(model_id="", format="torchscript"):
"""Export a model to all formats."""
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export',
json={'format': format},
headers={'x-api-key': Auth().api_key})
assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
LOGGER.info(f'{PREFIX}{format} export started ✅')
r = requests.post(
f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
)
assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
LOGGER.info(f"{PREFIX}{format} export started ✅")
def get_export(model_id='', format='torchscript'):
def get_export(model_id="", format="torchscript"):
"""Get an exported model dictionary with download URL."""
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
r = requests.post(f'{HUB_API_ROOT}/get-export',
json={
'apiKey': Auth().api_key,
'modelId': model_id,
'format': format})
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
r = requests.post(
f"{HUB_API_ROOT}/get-export",
json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
headers={"x-api-key": Auth().api_key},
)
assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
return r.json()
def check_dataset(path='', task='detect'):
def check_dataset(path="", task="detect"):
"""
Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is
uploaded to the HUB. Usage examples are given below.
Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
to the HUB. Usage examples are given below.
Args:
path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
@ -97,4 +125,4 @@ def check_dataset(path='', task='detect'):
```
"""
HUBDatasetStats(path=path, task=task).get_json()
LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")

View File

@ -5,13 +5,27 @@ import requests
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
class Auth:
"""
Manages authentication processes including API key handling, cookie-based authentication, and header generation.
The class supports different methods of authentication:
1. Directly using an API key.
2. Authenticating using browser cookies (specifically in Google Colab).
3. Prompting the user to enter an API key.
Attributes:
id_token (str or bool): Token used for identity verification, initialized as False.
api_key (str or bool): API key for authentication, initialized as False.
model_key (bool): Placeholder for model key, initialized as False.
"""
id_token = api_key = model_key = False
def __init__(self, api_key='', verbose=False):
def __init__(self, api_key="", verbose=False):
"""
Initialize the Auth class with an optional API key.
@ -19,18 +33,18 @@ class Auth:
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
"""
# Split the input API key in case it contains a combined key_model and keep only the API key part
api_key = api_key.split('_')[0]
api_key = api_key.split("_")[0]
# Set API key attribute as value passed or SETTINGS API key if none passed
self.api_key = api_key or SETTINGS.get('api_key', '')
self.api_key = api_key or SETTINGS.get("api_key", "")
# If an API key is provided
if self.api_key:
# If the provided API key matches the API key in the SETTINGS
if self.api_key == SETTINGS.get('api_key'):
if self.api_key == SETTINGS.get("api_key"):
# Log that the user is already logged in
if verbose:
LOGGER.info(f'{PREFIX}Authenticated ✅')
LOGGER.info(f"{PREFIX}Authenticated ✅")
return
else:
# Attempt to authenticate with the provided API key
@ -45,62 +59,65 @@ class Auth:
# Update SETTINGS with the new API key after successful authentication
if success:
SETTINGS.update({'api_key': self.api_key})
SETTINGS.update({"api_key": self.api_key})
# Log that the new login was successful
if verbose:
LOGGER.info(f'{PREFIX}New authentication successful ✅')
LOGGER.info(f"{PREFIX}New authentication successful ✅")
elif verbose:
LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo hub login API_KEY'")
def request_api_key(self, max_attempts=3):
"""
Prompt the user to input their API key. Returns the model ID.
Prompt the user to input their API key.
Returns the model ID.
"""
import getpass
for attempts in range(max_attempts):
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
self.api_key = input_key.split('_')[0] # remove model id if present
LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
self.api_key = input_key.split("_")[0] # remove model id if present
if self.authenticate():
return True
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
def authenticate(self) -> bool:
"""
Attempt to authenticate with the server using either id_token or API key.
Returns:
bool: True if authentication is successful, False otherwise.
(bool): True if authentication is successful, False otherwise.
"""
try:
if header := self.get_auth_header():
r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
if not r.json().get('success', False):
raise ConnectionError('Unable to authenticate.')
r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
if not r.json().get("success", False):
raise ConnectionError("Unable to authenticate.")
return True
raise ConnectionError('User has not authenticated locally.')
raise ConnectionError("User has not authenticated locally.")
except ConnectionError:
self.id_token = self.api_key = False # reset invalid
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
return False
def auth_with_cookies(self) -> bool:
"""
Attempt to fetch authentication via cookies and set id_token.
User must be logged in to HUB and running in a supported browser.
Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a
supported browser.
Returns:
bool: True if authentication is successful, False otherwise.
(bool): True if authentication is successful, False otherwise.
"""
if not is_colab():
return False # Currently only works with Colab
try:
authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
if authn.get('success', False):
self.id_token = authn.get('data', {}).get('idToken', None)
authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
if authn.get("success", False):
self.id_token = authn.get("data", {}).get("idToken", None)
self.authenticate()
return True
raise ConnectionError('Unable to fetch browser authentication details.')
raise ConnectionError("Unable to fetch browser authentication details.")
except ConnectionError:
self.id_token = False # reset invalid
return False
@ -113,7 +130,7 @@ class Auth:
(dict): The authentication header if id_token or API key is set, None otherwise.
"""
if self.id_token:
return {'authorization': f'Bearer {self.id_token}'}
return {"authorization": f"Bearer {self.id_token}"}
elif self.api_key:
return {'x-api-key': self.api_key}
return {"x-api-key": self.api_key}
# else returns None

View File

@ -1,29 +1,26 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import signal
import sys
import threading
import time
from http import HTTPStatus
from pathlib import Path
from time import sleep
import requests
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request
from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
from ultralytics.hub.utils import HUB_WEB_ROOT, HELP_MSG, PREFIX, TQDM
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
from ultralytics.utils.errors import HUBModelError
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
class HUBTrainingSession:
"""
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
Args:
url (str): Model identifier used to initialize the HUB training session.
Attributes:
agent_id (str): Identifier for the instance communicating with the server.
model_id (str): Identifier for the YOLOv5 model being trained.
model_id (str): Identifier for the YOLO model being trained.
model_url (str): URL for the model in Ultralytics HUB.
api_url (str): API URL for the model in Ultralytics HUB.
auth_header (dict): Authentication header for the Ultralytics HUB API requests.
@ -34,110 +31,287 @@ class HUBTrainingSession:
alive (bool): Indicates if the heartbeat loop is active.
"""
def __init__(self, url):
def __init__(self, identifier):
"""
Initialize the HUBTrainingSession with the provided model identifier.
Args:
url (str): Model identifier used to initialize the HUB training session.
It can be a URL string or a model key with specific format.
identifier (str): Model identifier used to initialize the HUB training session.
It can be a URL string or a model key with specific format.
Raises:
ValueError: If the provided model identifier is invalid.
ConnectionError: If connecting with global API key is not supported.
ModuleNotFoundError: If hub-sdk package is not installed.
"""
from hub_sdk import HUBClient
from ultralytics.hub.auth import Auth
self.rate_limits = {
"metrics": 3.0,
"ckpt": 900.0,
"heartbeat": 300.0,
} # rate limits (seconds)
self.metrics_queue = {} # holds metrics for each epoch until upload
self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
# Parse input
if url.startswith(f'{HUB_WEB_ROOT}/models/'):
url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
if [len(x) for x in url.split('_')] == [42, 20]:
key, model_id = url.split('_')
elif len(url) == 20:
key, model_id = '', url
api_key, model_id, self.filename = self._parse_identifier(identifier)
# Get credentials
active_key = api_key or SETTINGS.get("api_key")
credentials = {"api_key": active_key} if active_key else None # set credentials
# Initialize client
self.client = HUBClient(credentials)
if model_id:
self.load_model(model_id) # load existing model
else:
raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")
self.model = self.client.model() # load empty model
# Authorize
auth = Auth(key)
self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id
self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
self.auth_header = auth.get_auth_header()
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
self.timers = {} # rate limit timers (seconds)
self.metrics_queue = {} # metrics queue
self.model = self._get_model()
self.alive = True
self._start_heartbeat() # start heartbeats
self._register_signal_handlers()
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
def load_model(self, model_id):
"""Loads an existing model from Ultralytics HUB using the provided model identifier."""
self.model = self.client.model(model_id)
if not self.model.data: # then model does not exist
raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling
def _register_signal_handlers(self):
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
signal.signal(signal.SIGTERM, self._handle_signal)
signal.signal(signal.SIGINT, self._handle_signal)
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
def _handle_signal(self, signum, frame):
self._set_train_args()
# Start heartbeats for HUB to monitor agent
self.model.start_heartbeat(self.rate_limits["heartbeat"])
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
def create_model(self, model_args):
    """
    Create a new model entry on Ultralytics HUB from local training arguments.

    Builds the creation payload from `model_args` and `self.filename`, submits it through
    `self.model.create_model()` and, if a model id is available afterwards, records the model URL, starts the
    heartbeat and logs where the model can be viewed.

    Args:
        model_args (dict): Training arguments. The keys 'batch', 'epochs', 'imgsz', 'patience', 'device', 'cache'
            and 'data' are read; all but 'data' fall back to a default when missing.

    Returns:
        None. Returns early, leaving `self.model_url` unset, when `self.model.id` is empty after the request.
    """
    training_config = {
        "batchSize": model_args.get("batch", -1),
        "epochs": model_args.get("epochs", 300),
        "imageSize": model_args.get("imgsz", 640),
        "patience": model_args.get("patience", 100),
        "device": model_args.get("device", ""),
        "cache": model_args.get("cache", "ram"),
    }
    dataset = {"name": model_args.get("data")}

    # Architecture name is the filename with the '.pt' / '.yaml' markers stripped out
    architecture = {"name": self.filename.replace(".pt", "").replace(".yaml", "")}

    # Only *.pt weights are recorded as the parent of the new model
    parent = {"name": self.filename} if self.filename.endswith(".pt") else {}

    payload = {
        "config": training_config,
        "dataset": dataset,
        "lineage": {"architecture": architecture, "parent": parent},
        "meta": {"name": self.filename},
    }
    self.model.create_model(payload)

    # Model could not be created
    # TODO: improve error handling
    if not self.model.id:
        return

    self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

    # Start heartbeats for HUB to monitor agent
    heartbeat_interval = self.rate_limits["heartbeat"]
    self.model.start_heartbeat(heartbeat_interval)

    LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
def _parse_identifier(self, identifier):
"""
Handle kill signals and prevent heartbeats from being sent on Colab after termination.
This method does not use frame, it is included as it is passed by signal.
"""
if self.alive is True:
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
self._stop_heartbeat()
sys.exit(signum)
Parses the given identifier to determine the type of identifier and extract relevant components.
def _stop_heartbeat(self):
"""Terminate the heartbeat loop."""
self.alive = False
The method supports different identifier formats:
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
- An identifier containing an API key and a model ID separated by an underscore
- An identifier that is solely a model ID of a fixed length
- A local filename that ends with '.pt' or '.yaml'
Args:
identifier (str): The identifier string to be parsed.
Returns:
(tuple): A tuple containing the API key, model ID, and filename as applicable.
Raises:
HUBModelError: If the identifier format is not recognized.
"""
# Initialize variables
api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
# Extract the model_id after the HUB_WEB_ROOT URL
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
else:
# Split the identifier based on underscores only if it's not a HUB URL
parts = identifier.split("_")
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
api_key, model_id = parts
# Check if identifier is a single model ID
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
)
return api_key, model_id, filename
def _set_train_args(self):
"""
Initializes training arguments and creates a model entry on the Ultralytics HUB.
This method sets up training arguments based on the model's state and updates them with any additional
arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
or requires specific file setup.
Raises:
ValueError: If the model is already trained, if required dataset information is missing, or if there are
issues with the provided training arguments.
"""
if self.model.is_trained():
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
if self.model.is_resumable():
# Model has saved weights
self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
self.model_file = self.model.get_weights_url("last")
else:
# Model has no saved weights
self.train_args = self.model.data.get("train_args") # new response
# Set the model file as either a *.pt or *.yaml file
self.model_file = (
self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
)
if "data" not in self.train_args:
# RF bug - datasets are sometimes not exported
raise ValueError("Dataset may still be processing. Please wait a minute and try again.")
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
self.model_id = self.model.id
def request_queue(
    self,
    request_func,
    retry=3,
    timeout=30,
    thread=True,
    verbose=True,
    progress_total=None,
    *args,
    **kwargs,
):
    """
    Run `request_func` with retries, exponential backoff and an overall timeout, optionally in a daemon thread.

    Args:
        request_func (callable): Function performing the request; called as `request_func(*args, **kwargs)`.
        retry (int): Number of retries after the first attempt (the loop runs `retry + 1` times at most).
        timeout (float): Seconds after which no further attempt is started.
        thread (bool): If True, run in a background daemon thread and return immediately.
        verbose (bool): If True, log the failure message produced on the first failed attempt.
        progress_total (int, optional): When truthy, passed with the response to `self._show_upload_progress`.
        *args: Extra positional arguments forwarded to `request_func`.
        **kwargs: Keyword arguments forwarded to `request_func`. A truthy 'metrics' entry marks the request as a
            metrics upload for the failed-upload queue bookkeeping.

    Returns:
        The last response obtained (or None if there was none) when `thread` is False; None when `thread` is
        True, because the result of the background call is not collected.
    """

    def retry_request():
        """Attempts to call `request_func` with retries, timeout, and optional threading."""
        # Pre-bind so the checks after the loop are safe when no request was issued at all
        # (retry < 0, or the timeout had already elapsed on the first pass).
        response = None
        t0 = time.time()  # Record the start time for the timeout
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                break  # Timeout reached, exit loop

            response = request_func(*args, **kwargs)
            if response is None:
                LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                time.sleep(2**i)  # Exponential backoff before retrying
                continue  # Skip further processing and retry

            if progress_total:
                self._show_upload_progress(progress_total, response)

            if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                # if request related to metrics upload
                if kwargs.get("metrics"):
                    self.metrics_upload_failed_queue = {}
                return response  # Success, no need to retry

            if i == 0:
                # Initial attempt, check status code and provide messages
                message = self._get_failure_message(response, retry, timeout)

                if verbose:
                    LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

            if not self._should_retry(response.status_code):
                LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})")
                break  # Not an error that should be retried, exit loop

            time.sleep(2**i)  # Exponential backoff for retries

        # if request related to metrics upload and exceed retries
        if response is None and kwargs.get("metrics"):
            self.metrics_upload_failed_queue.update(kwargs.get("metrics", None))

        return response

    if thread:
        # Start a new thread to run the retry_request function
        threading.Thread(target=retry_request, daemon=True).start()
    else:
        # If running in the main thread, call retry_request directly
        return retry_request()
def _should_retry(self, status_code):
"""Determines if a request should be retried based on the HTTP status code."""
retry_codes = {
HTTPStatus.REQUEST_TIMEOUT,
HTTPStatus.BAD_GATEWAY,
HTTPStatus.GATEWAY_TIMEOUT,
}
return status_code in retry_codes
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
"""
Generate a retry message based on the response status code.
Args:
response: The HTTP response object.
retry: The number of retry attempts allowed.
timeout: The maximum timeout duration.
Returns:
(str): The retry message.
"""
if self._should_retry(response.status_code):
return f"Retrying {retry}x for {timeout}s." if retry else ""
elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
headers = response.headers
return (
f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
f"Please retry after {headers['Retry-After']}s."
)
else:
try:
return response.json().get("message", "No JSON message.")
except AttributeError:
return "Unable to read JSON."
def upload_metrics(self):
"""Upload model metrics to Ultralytics HUB."""
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)
def _get_model(self):
"""Fetch and return model data from Ultralytics HUB."""
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
try:
response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
data = response.json().get('data', None)
if data.get('status', None) == 'trained':
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
if not data.get('data', None):
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
self.model_id = data['id']
if data['status'] == 'new': # new model to start training
self.train_args = {
# TODO: deprecate 'batch_size' key for 'batch' in 3Q23
'batch': data['batch' if ('batch' in data) else 'batch_size'],
'epochs': data['epochs'],
'imgsz': data['imgsz'],
'patience': data['patience'],
'device': data['device'],
'cache': data['cache'],
'data': data['data']}
self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
elif data['status'] == 'training': # existing model to resume training
self.train_args = {'data': data['data'], 'resume': True}
self.model_file = data['resume']
return data
except requests.exceptions.ConnectionError as e:
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
except Exception:
raise
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
def upload_model(
self,
epoch: int,
weights: str,
is_best: bool = False,
map: float = 0.0,
final: bool = False,
) -> None:
"""
Upload a model checkpoint to Ultralytics HUB.
@ -149,42 +323,33 @@ class HUBTrainingSession:
final (bool): Indicates if the model is the final model after training.
"""
if Path(weights).is_file():
with open(weights, 'rb') as f:
file = f.read()
progress_total = Path(weights).stat().st_size if final else None # Only show progress if final
self.request_queue(
self.model.upload_model,
epoch=epoch,
weights=weights,
is_best=is_best,
map=map,
final=final,
retry=10,
timeout=3600,
thread=not final,
progress_total=progress_total,
)
else:
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
file = None
url = f'{self.api_url}/upload'
# url = 'http://httpbin.org/post' # for debug
data = {'epoch': epoch}
if final:
data.update({'type': 'final', 'map': map})
smart_request('post',
url,
data=data,
files={'best.pt': file},
headers=self.auth_header,
retry=10,
timeout=3600,
thread=False,
progress=True,
code=4)
else:
data.update({'type': 'epoch', 'isBest': bool(is_best)})
smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
@threaded
def _start_heartbeat(self):
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
while self.alive:
r = smart_request('post',
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
json={
'agent': AGENT_NAME,
'agentId': self.agent_id},
headers=self.auth_header,
retry=0,
code=5,
thread=False) # already in a thread
self.agent_id = r.json().get('data', {}).get('agentId', None)
sleep(self.rate_limits['heartbeat'])
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
    """
    Render a byte-count progress bar sized to `content_length` while consuming the chunks of `response`.

    Args:
        content_length (int): Total number of bytes the progress bar represents.
        response (requests.Response): Response whose content is iterated in 1024-byte chunks; the bar advances
            by the length of each chunk.

    Returns:
        None
    """
    progress_bar = TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024)
    with progress_bar as pbar:
        for chunk in response.iter_content(chunk_size=1024):
            pbar.update(len(chunk))

View File

@ -10,14 +10,29 @@ from pathlib import Path
import requests
from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__,
colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package)
from ultralytics.utils import (
ENVIRONMENT,
LOGGER,
ONLINE,
RANK,
SETTINGS,
TESTS_RUNNING,
TQDM,
TryExcept,
__version__,
colorstr,
get_git_origin_url,
is_colab,
is_git_dir,
is_pip_package,
)
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
PREFIX = colorstr('Ultralytics HUB: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.com')
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")
PREFIX = colorstr("Ultralytics HUB: ")
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."
def request_with_credentials(url: str) -> any:
@ -34,11 +49,13 @@ def request_with_credentials(url: str) -> any:
OSError: If the function is not run in a Google Colab environment.
"""
if not is_colab():
raise OSError('request_with_credentials() must run in a Colab environment')
raise OSError("request_with_credentials() must run in a Colab environment")
from google.colab import output # noqa
from IPython import display # noqa
display.display(
display.Javascript("""
display.Javascript(
"""
window._hub_tmp = new Promise((resolve, reject) => {
const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
fetch("%s", {
@ -53,8 +70,11 @@ def request_with_credentials(url: str) -> any:
reject(err);
});
});
""" % url))
return output.eval_js('_hub_tmp')
"""
% url
)
)
return output.eval_js("_hub_tmp")
def requests_with_progress(method, url, **kwargs):
@ -64,22 +84,23 @@ def requests_with_progress(method, url, **kwargs):
Args:
method (str): The HTTP method to use (e.g. 'GET', 'POST').
url (str): The URL to send the request to.
**kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function.
**kwargs (any): Additional keyword arguments to pass to the underlying `requests.request` function.
Returns:
(requests.Response): The response object from the HTTP request.
Note:
If 'progress' is set to True, the progress bar will display the download progress
for responses with a known content length.
- If 'progress' is set to True, the progress bar will display the download progress for responses with a known
content length.
- If 'progress' is a number then progress bar will display assuming content length = progress.
"""
progress = kwargs.pop('progress', False)
progress = kwargs.pop("progress", False)
if not progress:
return requests.request(method, url, **kwargs)
response = requests.request(method, url, stream=True, **kwargs)
total = int(response.headers.get('content-length', 0)) # total size
total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress) # total size
try:
pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024)
pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))
pbar.close()
@ -101,7 +122,7 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
progress (bool, optional): Whether to show a progress bar during the request. Default is False.
**kwargs (dict): Keyword arguments to be passed to the requests function specified in method.
**kwargs (any): Keyword arguments to be passed to the requests function specified in method.
Returns:
(requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
@ -120,25 +141,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful"
break
try:
m = r.json().get('message', 'No JSON message.')
m = r.json().get("message", "No JSON message.")
except AttributeError:
m = 'Unable to read JSON.'
m = "Unable to read JSON."
if i == 0:
if r.status_code in retry_codes:
m += f' Retrying {retry}x for {timeout}s.' if retry else ''
m += f" Retrying {retry}x for {timeout}s." if retry else ""
elif r.status_code == 429: # rate limit
h = r.headers # response headers
m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
m = (
f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
f"Please retry after {h['Retry-After']}s."
)
if verbose:
LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
if r.status_code not in retry_codes:
return r
time.sleep(2 ** i) # exponential standoff
time.sleep(2**i) # exponential standoff
return r
args = method, url
kwargs['progress'] = progress
kwargs["progress"] = progress
if thread:
threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
else:
@ -157,29 +180,29 @@ class Events:
enabled (bool): A flag to enable or disable Events based on certain conditions.
"""
url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
def __init__(self):
"""
Initializes the Events object with default values for events, rate_limit, and metadata.
"""
"""Initializes the Events object with default values for events, rate_limit, and metadata."""
self.events = [] # events list
self.rate_limit = 60.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds)
self.metadata = {
'cli': Path(sys.argv[0]).name == 'yolo',
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
'python': '.'.join(platform.python_version_tuple()[:2]), # i.e. 3.10
'version': __version__,
'env': ENVIRONMENT,
'session_id': round(random.random() * 1E15),
'engagement_time_msec': 1000}
self.enabled = \
SETTINGS['sync'] and \
RANK in (-1, 0) and \
not TESTS_RUNNING and \
ONLINE and \
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
"cli": Path(sys.argv[0]).name == "yolo",
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
"python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
"version": __version__,
"env": ENVIRONMENT,
"session_id": round(random.random() * 1e15),
"engagement_time_msec": 1000,
}
self.enabled = (
SETTINGS["sync"]
and RANK in (-1, 0)
and not TESTS_RUNNING
and ONLINE
and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
)
def __call__(self, cfg):
"""
@ -195,11 +218,13 @@ class Events:
# Attempt to add to events
if len(self.events) < 25: # Events list limited to 25 events (drop any events past this)
params = {
**self.metadata, 'task': cfg.task,
'model': cfg.model if cfg.model in GITHUB_ASSETS_NAMES else 'custom'}
if cfg.mode == 'export':
params['format'] = cfg.format
self.events.append({'name': cfg.mode, 'params': params})
**self.metadata,
"task": cfg.task,
"model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
}
if cfg.mode == "export":
params["format"] = cfg.format
self.events.append({"name": cfg.mode, "params": params})
# Check rate limit
t = time.time()
@ -208,10 +233,10 @@ class Events:
return
# Time is over rate limiter, send now
data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list
data = {"client_id": SETTINGS["uuid"], "events": self.events} # SHA-256 anonymized UUID hash and events list
# POST equivalent to requests.post(self.url, json=data)
smart_request('post', self.url, json=data, retry=0, verbose=False)
smart_request("post", self.url, json=data, retry=0, verbose=False)
# Reset events and rate limit timer
self.events = []