更新 detacttracking
This commit is contained in:
21
detecttracking/tracking/trackers/reid/test.py
Normal file
21
detecttracking/tracking/trackers/reid/test.py
Normal file
@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Fri Jan 19 16:10:39 2024
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
import torch
|
||||
from model.resnet_pre import resnet18
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
model_path = "best.pth"
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model = resnet18().to(device)
|
||||
model.load_state_dict(torch.load(model_path, map_location=device))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user