add yolo v10 and modify pipeline
@@ -5,24 +5,51 @@ import requests

from ultralytics.data.utils import HUBDatasetStats
from ultralytics.hub.auth import Auth
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
from ultralytics.utils import LOGGER, SETTINGS
from ultralytics.utils import LOGGER, SETTINGS, checks


def login(api_key=''):
def login(api_key: str = None, save=True) -> bool:
    """
    Log in to the Ultralytics HUB API using the provided API key.

    The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
    environment variable if successfully authenticated.

    Args:
        api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
        api_key (str, optional): API key to use for authentication.
            If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
        save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.

    Example:
        ```python
        from ultralytics import hub

        hub.login('API_KEY')
        ```

    Returns:
        (bool): True if authentication is successful, False otherwise.
    """
    Auth(api_key, verbose=True)
    checks.check_requirements("hub-sdk>=0.0.6")
    from hub_sdk import HUBClient

    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # set the redirect URL
    saved_key = SETTINGS.get("api_key")
    active_key = api_key or saved_key
    credentials = {"api_key": active_key} if active_key and active_key != "" else None  # set credentials

    client = HUBClient(credentials)  # initialize HUBClient

    if client.authenticated:
        # Successfully authenticated with HUB

        if save and client.api_key != saved_key:
            SETTINGS.update({"api_key": client.api_key})  # update settings with valid API key

        # Set message based on whether key was provided or retrieved from settings
        log_message = (
            "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
        )
        LOGGER.info(f"{PREFIX}{log_message}")

        return True
    else:
        # Failed to authenticate with HUB
        LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo hub login API_KEY'")
        return False


def logout():
@@ -36,52 +63,53 @@ def logout():
        hub.logout()
        ```
    """
    SETTINGS['api_key'] = ''
    SETTINGS["api_key"] = ""
    SETTINGS.save()
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")


def reset_model(model_id=''):
def reset_model(model_id=""):
    """Reset a trained model to an untrained state."""
    r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
    r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
    if r.status_code == 200:
        LOGGER.info(f'{PREFIX}Model reset successfully')
        LOGGER.info(f"{PREFIX}Model reset successfully")
        return
    LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
    LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")


def export_fmts_hub():
    """Returns a list of HUB-supported export formats."""
    from ultralytics.engine.exporter import export_formats
    return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']

    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]


def export_model(model_id='', format='torchscript'):
def export_model(model_id="", format="torchscript"):
    """Export a model to all formats."""
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export',
                      json={'format': format},
                      headers={'x-api-key': Auth().api_key})
    assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
    LOGGER.info(f'{PREFIX}{format} export started ✅')
    r = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
    )
    assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
    LOGGER.info(f"{PREFIX}{format} export started ✅")


def get_export(model_id='', format='torchscript'):
def get_export(model_id="", format="torchscript"):
    """Get an exported model dictionary with download URL."""
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    r = requests.post(f'{HUB_API_ROOT}/get-export',
                      json={
                          'apiKey': Auth().api_key,
                          'modelId': model_id,
                          'format': format})
    assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
    r = requests.post(
        f"{HUB_API_ROOT}/get-export",
        json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
        headers={"x-api-key": Auth().api_key},
    )
    assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
    return r.json()


def check_dataset(path='', task='detect'):
def check_dataset(path="", task="detect"):
    """
    Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is
    uploaded to the HUB. Usage examples are given below.
    Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
    to the HUB. Usage examples are given below.

    Args:
        path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
@@ -97,4 +125,4 @@ def check_dataset(path='', task='detect'):
        ```
    """
    HUBDatasetStats(path=path, task=task).get_json()
    LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
    LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
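A minimal usage sketch of the reworked `login()` flow above; the key string and dataset path below are placeholders, not part of this commit.

```python
from ultralytics import hub

# Authenticate once; with save=True (the default) a valid key is persisted to SETTINGS,
# so later sessions can log in without passing the key again.
hub.login(api_key="YOUR_API_KEY")

# Error-check a dataset Zip archive before uploading it to HUB (path and task are illustrative).
hub.check_dataset(path="path/to/coco8.zip", task="detect")

# Clear the stored key when finished.
hub.logout()
```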
Binary files not shown (6 binary files changed in this commit).
@@ -5,13 +5,27 @@ import requests

from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab

API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"


class Auth:
    """
    Manages authentication processes including API key handling, cookie-based authentication, and header generation.

    The class supports different methods of authentication:
    1. Directly using an API key.
    2. Authenticating using browser cookies (specifically in Google Colab).
    3. Prompting the user to enter an API key.

    Attributes:
        id_token (str or bool): Token used for identity verification, initialized as False.
        api_key (str or bool): API key for authentication, initialized as False.
        model_key (bool): Placeholder for model key, initialized as False.
    """

    id_token = api_key = model_key = False

    def __init__(self, api_key='', verbose=False):
    def __init__(self, api_key="", verbose=False):
        """
        Initialize the Auth class with an optional API key.

@@ -19,18 +33,18 @@ class Auth:
            api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
        """
        # Split the input API key in case it contains a combined key_model and keep only the API key part
        api_key = api_key.split('_')[0]
        api_key = api_key.split("_")[0]

        # Set API key attribute as value passed or SETTINGS API key if none passed
        self.api_key = api_key or SETTINGS.get('api_key', '')
        self.api_key = api_key or SETTINGS.get("api_key", "")

        # If an API key is provided
        if self.api_key:
            # If the provided API key matches the API key in the SETTINGS
            if self.api_key == SETTINGS.get('api_key'):
            if self.api_key == SETTINGS.get("api_key"):
                # Log that the user is already logged in
                if verbose:
                    LOGGER.info(f'{PREFIX}Authenticated ✅')
                    LOGGER.info(f"{PREFIX}Authenticated ✅")
                return
            else:
                # Attempt to authenticate with the provided API key
@@ -45,62 +59,65 @@ class Auth:

        # Update SETTINGS with the new API key after successful authentication
        if success:
            SETTINGS.update({'api_key': self.api_key})
            SETTINGS.update({"api_key": self.api_key})
            # Log that the new login was successful
            if verbose:
                LOGGER.info(f'{PREFIX}New authentication successful ✅')
                LOGGER.info(f"{PREFIX}New authentication successful ✅")
        elif verbose:
            LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
            LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo hub login API_KEY'")

    def request_api_key(self, max_attempts=3):
        """
        Prompt the user to input their API key. Returns the model ID.
        Prompt the user to input their API key.

        Returns the model ID.
        """
        import getpass

        for attempts in range(max_attempts):
            LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
            input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
            self.api_key = input_key.split('_')[0]  # remove model id if present
            LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
            input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
            self.api_key = input_key.split("_")[0]  # remove model id if present
            if self.authenticate():
                return True
        raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
        raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

    def authenticate(self) -> bool:
        """
        Attempt to authenticate with the server using either id_token or API key.

        Returns:
            bool: True if authentication is successful, False otherwise.
            (bool): True if authentication is successful, False otherwise.
        """
        try:
            if header := self.get_auth_header():
                r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
                if not r.json().get('success', False):
                    raise ConnectionError('Unable to authenticate.')
                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
                if not r.json().get("success", False):
                    raise ConnectionError("Unable to authenticate.")
                return True
            raise ConnectionError('User has not authenticated locally.')
            raise ConnectionError("User has not authenticated locally.")
        except ConnectionError:
            self.id_token = self.api_key = False  # reset invalid
            LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
            return False

    def auth_with_cookies(self) -> bool:
        """
        Attempt to fetch authentication via cookies and set id_token.
        User must be logged in to HUB and running in a supported browser.
        Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a
        supported browser.

        Returns:
            bool: True if authentication is successful, False otherwise.
            (bool): True if authentication is successful, False otherwise.
        """
        if not is_colab():
            return False  # Currently only works with Colab
        try:
            authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
            if authn.get('success', False):
                self.id_token = authn.get('data', {}).get('idToken', None)
            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
            if authn.get("success", False):
                self.id_token = authn.get("data", {}).get("idToken", None)
                self.authenticate()
                return True
            raise ConnectionError('Unable to fetch browser authentication details.')
            raise ConnectionError("Unable to fetch browser authentication details.")
        except ConnectionError:
            self.id_token = False  # reset invalid
            return False
@@ -113,7 +130,7 @@ class Auth:
            (dict): The authentication header if id_token or API key is set, None otherwise.
        """
        if self.id_token:
            return {'authorization': f'Bearer {self.id_token}'}
            return {"authorization": f"Bearer {self.id_token}"}
        elif self.api_key:
            return {'x-api-key': self.api_key}
            return {"x-api-key": self.api_key}
        # else returns None
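For reference, a short sketch of how the header produced by the updated `Auth` class can be used directly; the key is a placeholder and the request simply mirrors what `authenticate()` already does.

```python
import requests

from ultralytics.hub.auth import Auth
from ultralytics.hub.utils import HUB_API_ROOT

auth = Auth("YOUR_API_KEY")      # falls back to SETTINGS["api_key"] when called with an empty key
header = auth.get_auth_header()  # {"x-api-key": ...}, {"authorization": "Bearer ..."} or None
if header:
    r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
    print(r.status_code, r.json().get("success", False))
```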
@@ -1,29 +1,26 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import signal
import sys
import threading
import time
from http import HTTPStatus
from pathlib import Path
from time import sleep

import requests

from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request
from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
from ultralytics.hub.utils import HUB_WEB_ROOT, HELP_MSG, PREFIX, TQDM
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
from ultralytics.utils.errors import HUBModelError

AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"


class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

    Args:
        url (str): Model identifier used to initialize the HUB training session.

    Attributes:
        agent_id (str): Identifier for the instance communicating with the server.
        model_id (str): Identifier for the YOLOv5 model being trained.
        model_id (str): Identifier for the YOLO model being trained.
        model_url (str): URL for the model in Ultralytics HUB.
        api_url (str): API URL for the model in Ultralytics HUB.
        auth_header (dict): Authentication header for the Ultralytics HUB API requests.
@@ -34,110 +31,287 @@ class HUBTrainingSession:
        alive (bool): Indicates if the heartbeat loop is active.
    """

    def __init__(self, url):
    def __init__(self, identifier):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Args:
            url (str): Model identifier used to initialize the HUB training session.
                It can be a URL string or a model key with specific format.
            identifier (str): Model identifier used to initialize the HUB training session.
                It can be a URL string or a model key with specific format.

        Raises:
            ValueError: If the provided model identifier is invalid.
            ConnectionError: If connecting with global API key is not supported.
            ModuleNotFoundError: If hub-sdk package is not installed.
        """
        from hub_sdk import HUBClient

        from ultralytics.hub.auth import Auth
        self.rate_limits = {
            "metrics": 3.0,
            "ckpt": 900.0,
            "heartbeat": 300.0,
        }  # rate limits (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until upload
        self.metrics_upload_failed_queue = {}  # holds metrics for each epoch if upload failed
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py

        # Parse input
        if url.startswith(f'{HUB_WEB_ROOT}/models/'):
            url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
        if [len(x) for x in url.split('_')] == [42, 20]:
            key, model_id = url.split('_')
        elif len(url) == 20:
            key, model_id = '', url
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials
        active_key = api_key or SETTINGS.get("api_key")
        credentials = {"api_key": active_key} if active_key else None  # set credentials

        # Initialize client
        self.client = HUBClient(credentials)

        if model_id:
            self.load_model(model_id)  # load existing model
        else:
            raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
                                f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")
            self.model = self.client.model()  # load empty model

        # Authorize
        auth = Auth(key)
        self.agent_id = None  # identifies which instance is communicating with server
        self.model_id = model_id
        self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
        self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
        self.auth_header = auth.get_auth_header()
        self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
        self.timers = {}  # rate limit timers (seconds)
        self.metrics_queue = {}  # metrics queue
        self.model = self._get_model()
        self.alive = True
        self._start_heartbeat()  # start heartbeats
        self._register_signal_handlers()
        LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
    def load_model(self, model_id):
        """Loads an existing model from Ultralytics HUB using the provided model identifier."""
        self.model = self.client.model(model_id)
        if not self.model.data:  # then model does not exist
            raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

    def _register_signal_handlers(self):
        """Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
        signal.signal(signal.SIGTERM, self._handle_signal)
        signal.signal(signal.SIGINT, self._handle_signal)
        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

    def _handle_signal(self, signum, frame):
        self._set_train_args()

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])
        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def create_model(self, model_args):
        """Initializes a HUB training session with the specified model identifier."""
        payload = {
            "config": {
                "batchSize": model_args.get("batch", -1),
                "epochs": model_args.get("epochs", 300),
                "imageSize": model_args.get("imgsz", 640),
                "patience": model_args.get("patience", 100),
                "device": model_args.get("device", ""),
                "cache": model_args.get("cache", "ram"),
            },
            "dataset": {"name": model_args.get("data")},
            "lineage": {
                "architecture": {
                    "name": self.filename.replace(".pt", "").replace(".yaml", ""),
                },
                "parent": {},
            },
            "meta": {"name": self.filename},
        }

        if self.filename.endswith(".pt"):
            payload["lineage"]["parent"]["name"] = self.filename

        self.model.create_model(payload)

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])

        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def _parse_identifier(self, identifier):
        """
        Handle kill signals and prevent heartbeats from being sent on Colab after termination.
        This method does not use frame, it is included as it is passed by signal.
        """
        if self.alive is True:
            LOGGER.info(f'{PREFIX}Kill signal received! ❌')
            self._stop_heartbeat()
            sys.exit(signum)
        Parses the given identifier to determine the type of identifier and extract relevant components.

    def _stop_heartbeat(self):
        """Terminate the heartbeat loop."""
        self.alive = False
        The method supports different identifier formats:
            - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
            - An identifier containing an API key and a model ID separated by an underscore
            - An identifier that is solely a model ID of a fixed length
            - A local filename that ends with '.pt' or '.yaml'

        Args:
            identifier (str): The identifier string to be parsed.

        Returns:
            (tuple): A tuple containing the API key, model ID, and filename as applicable.

        Raises:
            HUBModelError: If the identifier format is not recognized.
        """

        # Initialize variables
        api_key, model_id, filename = None, None, None

        # Check if identifier is a HUB URL
        if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
            # Extract the model_id after the HUB_WEB_ROOT URL
            model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
        else:
            # Split the identifier based on underscores only if it's not a HUB URL
            parts = identifier.split("_")

            # Check if identifier is in the format of API key and model ID
            if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
                api_key, model_id = parts
            # Check if identifier is a single model ID
            elif len(parts) == 1 and len(parts[0]) == 20:
                model_id = parts[0]
            # Check if identifier is a local filename
            elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
                filename = identifier
            else:
                raise HUBModelError(
                    f"model='{identifier}' could not be parsed. Check format is correct. "
                    f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
                )

        return api_key, model_id, filename

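A compact restatement of the identifier shapes `_parse_identifier()` above recognises; the values are dummies of the documented lengths, not real keys or model IDs.

```python
from ultralytics.hub.utils import HUB_WEB_ROOT

candidate_identifiers = [
    f"{HUB_WEB_ROOT}/models/" + "m" * 20,  # HUB URL -> model_id
    "k" * 42 + "_" + "m" * 20,             # apiKey_modelId -> (api_key, model_id)
    "m" * 20,                              # bare 20-character model ID
    "yolov8n.pt",                          # local .pt or .yaml -> filename
]
```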
    def _set_train_args(self):
        """
        Initializes training arguments and creates a model entry on the Ultralytics HUB.

        This method sets up training arguments based on the model's state and updates them with any additional
        arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
        or requires specific file setup.

        Raises:
            ValueError: If the model is already trained, if required dataset information is missing, or if there are
                issues with the provided training arguments.
        """
        if self.model.is_trained():
            raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))

        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
            self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            self.train_args = self.model.data.get("train_args")  # new response

            # Set the model file as either a *.pt or *.yaml file
            self.model_file = (
                self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
            )

        if "data" not in self.train_args:
            # RF bug - datasets are sometimes not exported
            raise ValueError("Dataset may still be processing. Please wait a minute and try again.")

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id

    def request_queue(
        self,
        request_func,
        retry=3,
        timeout=30,
        thread=True,
        verbose=True,
        progress_total=None,
        *args,
        **kwargs,
    ):
        def retry_request():
            """Attempts to call `request_func` with retries, timeout, and optional threading."""
            t0 = time.time()  # Record the start time for the timeout
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
                    LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                    break  # Timeout reached, exit loop

                response = request_func(*args, **kwargs)
                if response is None:
                    LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                    time.sleep(2**i)  # Exponential backoff before retrying
                    continue  # Skip further processing and retry

                if progress_total:
                    self._show_upload_progress(progress_total, response)

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                    # if request related to metrics upload
                    if kwargs.get("metrics"):
                        self.metrics_upload_failed_queue = {}
                    return response  # Success, no need to retry

                if i == 0:
                    # Initial attempt, check status code and provide messages
                    message = self._get_failure_message(response, retry, timeout)

                    if verbose:
                        LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

                if not self._should_retry(response.status_code):
                    LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
                    break  # Not an error that should be retried, exit loop

                time.sleep(2**i)  # Exponential backoff for retries

            # if request related to metrics upload and exceed retries
            if response is None and kwargs.get("metrics"):
                self.metrics_upload_failed_queue.update(kwargs.get("metrics", None))

            return response

        if thread:
            # Start a new thread to run the retry_request function
            threading.Thread(target=retry_request, daemon=True).start()
        else:
            # If running in the main thread, call retry_request directly
            return retry_request()

    def _should_retry(self, status_code):
        """Determines if a request should be retried based on the HTTP status code."""
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
        return status_code in retry_codes

    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
        """
        Generate a retry message based on the response status code.

        Args:
            response: The HTTP response object.
            retry: The number of retry attempts allowed.
            timeout: The maximum timeout duration.

        Returns:
            (str): The retry message.
        """
        if self._should_retry(response.status_code):
            return f"Retrying {retry}x for {timeout}s." if retry else ""
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
            return (
                f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
                f"Please retry after {headers['Retry-After']}s."
            )
        else:
            try:
                return response.json().get("message", "No JSON message.")
            except AttributeError:
                return "Unable to read JSON."

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""
        payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
        smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

    def _get_model(self):
        """Fetch and return model data from Ultralytics HUB."""
        api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'

        try:
            response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
            data = response.json().get('data', None)

            if data.get('status', None) == 'trained':
                raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))

            if not data.get('data', None):
                raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
            self.model_id = data['id']

            if data['status'] == 'new':  # new model to start training
                self.train_args = {
                    # TODO: deprecate 'batch_size' key for 'batch' in 3Q23
                    'batch': data['batch' if ('batch' in data) else 'batch_size'],
                    'epochs': data['epochs'],
                    'imgsz': data['imgsz'],
                    'patience': data['patience'],
                    'device': data['device'],
                    'cache': data['cache'],
                    'data': data['data']}
                self.model_file = data.get('cfg') or data.get('weights')  # cfg for pretrained=False
                self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
            elif data['status'] == 'training':  # existing model to resume training
                self.train_args = {'data': data['data'], 'resume': True}
                self.model_file = data['resume']

            return data
        except requests.exceptions.ConnectionError as e:
            raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
        except Exception:
            raise

    def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
    def upload_model(
        self,
        epoch: int,
        weights: str,
        is_best: bool = False,
        map: float = 0.0,
        final: bool = False,
    ) -> None:
        """
        Upload a model checkpoint to Ultralytics HUB.

@@ -149,42 +323,33 @@ class HUBTrainingSession:
            final (bool): Indicates if the model is the final model after training.
        """
        if Path(weights).is_file():
            with open(weights, 'rb') as f:
                file = f.read()
            progress_total = Path(weights).stat().st_size if final else None  # Only show progress if final
            self.request_queue(
                self.model.upload_model,
                epoch=epoch,
                weights=weights,
                is_best=is_best,
                map=map,
                final=final,
                retry=10,
                timeout=3600,
                thread=not final,
                progress_total=progress_total,
            )
        else:
            LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
            file = None
        url = f'{self.api_url}/upload'
        # url = 'http://httpbin.org/post'  # for debug
        data = {'epoch': epoch}
        if final:
            data.update({'type': 'final', 'map': map})
            smart_request('post',
                          url,
                          data=data,
                          files={'best.pt': file},
                          headers=self.auth_header,
                          retry=10,
                          timeout=3600,
                          thread=False,
                          progress=True,
                          code=4)
        else:
            data.update({'type': 'epoch', 'isBest': bool(is_best)})
            smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
            LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")

    @threaded
    def _start_heartbeat(self):
        """Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
        while self.alive:
            r = smart_request('post',
                              f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
                              json={
                                  'agent': AGENT_NAME,
                                  'agentId': self.agent_id},
                              headers=self.auth_header,
                              retry=0,
                              code=5,
                              thread=False)  # already in a thread
            self.agent_id = r.json().get('data', {}).get('agentId', None)
            sleep(self.rate_limits['heartbeat'])
    def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
        """
        Display a progress bar to track the upload progress of a file download.

        Args:
            content_length (int): The total size of the content to be downloaded in bytes.
            response (requests.Response): The response object from the file download request.

        Returns:
            None
        """
        with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))
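To make the retry behaviour of `request_queue()` above easier to follow, here is a self-contained sketch of the same exponential-backoff pattern; the function name and the use of `requests.post` are illustrative, not part of the commit.

```python
import time
from http import HTTPStatus

import requests

RETRY_CODES = {HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT}


def post_with_backoff(url, retry=3, timeout=30, **kwargs):
    """Retry a POST with exponential backoff, mirroring HUBTrainingSession.request_queue()."""
    t0 = time.time()
    response = None
    for i in range(retry + 1):
        if (time.time() - t0) > timeout:
            break  # overall timeout reached, stop retrying
        response = requests.post(url, **kwargs)
        if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
            return response  # 2xx success
        if response.status_code not in RETRY_CODES:
            break  # non-retryable status, mirrors _should_retry()
        time.sleep(2**i)  # exponential backoff: 1s, 2s, 4s, ...
    return response
```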
@@ -10,14 +10,29 @@ from pathlib import Path

import requests

from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__,
                               colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package)
from ultralytics.utils import (
    ENVIRONMENT,
    LOGGER,
    ONLINE,
    RANK,
    SETTINGS,
    TESTS_RUNNING,
    TQDM,
    TryExcept,
    __version__,
    colorstr,
    get_git_origin_url,
    is_colab,
    is_git_dir,
    is_pip_package,
)
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES

PREFIX = colorstr('Ultralytics HUB: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.com')
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")

PREFIX = colorstr("Ultralytics HUB: ")
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."


def request_with_credentials(url: str) -> any:
@@ -34,11 +49,13 @@ def request_with_credentials(url: str) -> any:
        OSError: If the function is not run in a Google Colab environment.
    """
    if not is_colab():
        raise OSError('request_with_credentials() must run in a Colab environment')
        raise OSError("request_with_credentials() must run in a Colab environment")
    from google.colab import output  # noqa
    from IPython import display  # noqa

    display.display(
        display.Javascript("""
        display.Javascript(
            """
            window._hub_tmp = new Promise((resolve, reject) => {
                const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
                fetch("%s", {
@@ -53,8 +70,11 @@ def request_with_credentials(url: str) -> any:
                    reject(err);
                });
            });
            """ % url))
    return output.eval_js('_hub_tmp')
            """
            % url
        )
    )
    return output.eval_js("_hub_tmp")


def requests_with_progress(method, url, **kwargs):
@@ -64,22 +84,23 @@ def requests_with_progress(method, url, **kwargs):
    Args:
        method (str): The HTTP method to use (e.g. 'GET', 'POST').
        url (str): The URL to send the request to.
        **kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function.
        **kwargs (any): Additional keyword arguments to pass to the underlying `requests.request` function.

    Returns:
        (requests.Response): The response object from the HTTP request.

    Note:
        If 'progress' is set to True, the progress bar will display the download progress
        for responses with a known content length.
        - If 'progress' is set to True, the progress bar will display the download progress for responses with a known
          content length.
        - If 'progress' is a number then progress bar will display assuming content length = progress.
    """
    progress = kwargs.pop('progress', False)
    progress = kwargs.pop("progress", False)
    if not progress:
        return requests.request(method, url, **kwargs)
    response = requests.request(method, url, stream=True, **kwargs)
    total = int(response.headers.get('content-length', 0))  # total size
    total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress)  # total size
    try:
        pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024)
        pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
        for data in response.iter_content(chunk_size=1024):
            pbar.update(len(data))
        pbar.close()
@@ -101,7 +122,7 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
        code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
        verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
        progress (bool, optional): Whether to show a progress bar during the request. Default is False.
        **kwargs (dict): Keyword arguments to be passed to the requests function specified in method.
        **kwargs (any): Keyword arguments to be passed to the requests function specified in method.

    Returns:
        (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
@@ -120,25 +141,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
            if r.status_code < 300:  # return codes in the 2xx range are generally considered "good" or "successful"
                break
            try:
                m = r.json().get('message', 'No JSON message.')
                m = r.json().get("message", "No JSON message.")
            except AttributeError:
                m = 'Unable to read JSON.'
                m = "Unable to read JSON."
            if i == 0:
                if r.status_code in retry_codes:
                    m += f' Retrying {retry}x for {timeout}s.' if retry else ''
                    m += f" Retrying {retry}x for {timeout}s." if retry else ""
                elif r.status_code == 429:  # rate limit
                    h = r.headers  # response headers
                    m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
                    m = (
                        f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
                        f"Please retry after {h['Retry-After']}s."
                    )
                if verbose:
                    LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
                if r.status_code not in retry_codes:
                    return r
            time.sleep(2 ** i)  # exponential standoff
            time.sleep(2**i)  # exponential standoff
        return r

    args = method, url
    kwargs['progress'] = progress
    kwargs["progress"] = progress
    if thread:
        threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
    else:
@@ -157,29 +180,29 @@ class Events:
        enabled (bool): A flag to enable or disable Events based on certain conditions.
    """

    url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
    url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"

    def __init__(self):
        """
        Initializes the Events object with default values for events, rate_limit, and metadata.
        """
        """Initializes the Events object with default values for events, rate_limit, and metadata."""
        self.events = []  # events list
        self.rate_limit = 60.0  # rate limit (seconds)
        self.t = 0.0  # rate limit timer (seconds)
        self.metadata = {
            'cli': Path(sys.argv[0]).name == 'yolo',
            'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
            'python': '.'.join(platform.python_version_tuple()[:2]),  # i.e. 3.10
            'version': __version__,
            'env': ENVIRONMENT,
            'session_id': round(random.random() * 1E15),
            'engagement_time_msec': 1000}
        self.enabled = \
            SETTINGS['sync'] and \
            RANK in (-1, 0) and \
            not TESTS_RUNNING and \
            ONLINE and \
            (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
            "cli": Path(sys.argv[0]).name == "yolo",
            "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
            "python": ".".join(platform.python_version_tuple()[:2]),  # i.e. 3.10
            "version": __version__,
            "env": ENVIRONMENT,
            "session_id": round(random.random() * 1e15),
            "engagement_time_msec": 1000,
        }
        self.enabled = (
            SETTINGS["sync"]
            and RANK in (-1, 0)
            and not TESTS_RUNNING
            and ONLINE
            and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
        )

    def __call__(self, cfg):
        """
@@ -195,11 +218,13 @@ class Events:
        # Attempt to add to events
        if len(self.events) < 25:  # Events list limited to 25 events (drop any events past this)
            params = {
                **self.metadata, 'task': cfg.task,
                'model': cfg.model if cfg.model in GITHUB_ASSETS_NAMES else 'custom'}
            if cfg.mode == 'export':
                params['format'] = cfg.format
            self.events.append({'name': cfg.mode, 'params': params})
                **self.metadata,
                "task": cfg.task,
                "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
            }
            if cfg.mode == "export":
                params["format"] = cfg.format
            self.events.append({"name": cfg.mode, "params": params})

        # Check rate limit
        t = time.time()
@@ -208,10 +233,10 @@ class Events:
            return

        # Time is over rate limiter, send now
        data = {'client_id': SETTINGS['uuid'], 'events': self.events}  # SHA-256 anonymized UUID hash and events list
        data = {"client_id": SETTINGS["uuid"], "events": self.events}  # SHA-256 anonymized UUID hash and events list

        # POST equivalent to requests.post(self.url, json=data)
        smart_request('post', self.url, json=data, retry=0, verbose=False)
        smart_request("post", self.url, json=data, retry=0, verbose=False)

        # Reset events and rate limit timer
        self.events = []
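A small hedged example of the `progress` behaviour documented in `requests_with_progress()` above and of a plain `smart_request()` call; the URLs below are illustrative only.

```python
from ultralytics.hub.utils import requests_with_progress, smart_request

# progress=True reads the Content-Length header; passing a number overrides the assumed total size.
r = requests_with_progress("get", "https://github.com/ultralytics/assets/releases", progress=True)

# smart_request adds retries with exponential backoff and can run the request in a daemon thread.
smart_request("get", "https://api.ultralytics.com", retry=2, timeout=10, thread=False, verbose=False)
```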