From 744fb7b7b2255ebb244824cfafe4830d5ec3e5b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=BA=86=E5=88=9A?= Date: Mon, 13 Jan 2025 18:11:56 +0800 Subject: [PATCH] cpu to device select --- __pycache__/track_reid.cpython-39.pyc | Bin 15754 -> 15754 bytes contrast/feat_extract/config.py | 4 ++-- .../__pycache__/experimental.cpython-39.pyc | Bin 4801 -> 4791 bytes models/experimental.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/__pycache__/track_reid.cpython-39.pyc b/__pycache__/track_reid.cpython-39.pyc index b351d96205a6562293bf23518cc255ea9339f62a..8dad08455f154116127627cf8339ed6dbec4d810 100644 GIT binary patch delta 113 zcmeCG?yBZX{8ntt+Kfyf$OI;l*~lV{Y%h4~8B`}b*m?l~ZyFb9 delta 113 zcmeCG?yBZXEF91xD7cBq) diff --git a/contrast/feat_extract/config.py b/contrast/feat_extract/config.py index f0cc387..2fc5c78 100644 --- a/contrast/feat_extract/config.py +++ b/contrast/feat_extract/config.py @@ -76,8 +76,8 @@ class Config: lr_decay = 0.95 # 0.98 weight_decay = 5e-4 loss = 'cross_entropy' # ['focal_loss', 'cross_entropy'] - device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') - # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + # device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') pin_memory = True # if memory is large, set it True to speed up a bit num_workers = 4 # dataloader diff --git a/models/__pycache__/experimental.cpython-39.pyc b/models/__pycache__/experimental.cpython-39.pyc index 0abfeeb961c732f1f28ebc153f5ee08c8bc35115..babf43124971654a2ae21900a01f68159ffdfec4 100644 GIT binary patch delta 371 zcmXAky-UMT6vf|7HR(rQ+QiRHqS#jHqJxWwh;&i#0~NKkSYmG}Qn4DJI2fl`>>>yk zTtsO`$0F#ge}R)jHwQ=m0|lRYhXeQg&Kb^g@+rw3&ZA78^L}1=Y;j$&MzWKIY*~~` zV$F3*)!qHd!MKV=s7;Pq zMG{Xd8h?J6YWg@tDA141sILGMgI*Y;GC71XO7npUT{d5ldEsL6@KvN5%m<8Bl`C*X zgCV!sDidimJa)q{ejD-u7$)eHXF?`gbk7IuF6=%JH&|9SeYk?7uX%<&*emHf@9FPb zbU|(%Rod$R!gl9E4aEWs9~%nTH0#hg}8eqo-jrLjXV<~r2LwyIL=*3=JoN@gbj delta 486 zcmYLFzi-n}5Wcgi?Krj*he8|ZQfOL2#K6#{su++5sQ4uzA=;*jLe4}@L!7!Vh^jBD zz!miXQqUVqQ9L0=gb=Lk44KWpfy4q5FLdCB)9JhK`|ih=(zlXbw`~(-{Cxh4K7X_= zW$KEzRP}tnd&{qfoy|_Ox*N2Er;DE|%Ad@2Qt8xdrLS{+n4yUTyZ5+GS`mgO#{`Gj z1o}F&B&J8o4rAKTCZj1kT(3F*@Qi3#SJ4!2t=@h%QV`fIHz+Ha7syJmiQH&ioP8d9{#cC=fm{zZ zX8I6Gnss z-1)r%be5+6>yF@9iSPx+a(;RGPKffIWhK=-pZN`zPTUU2Z>MUb+bd&gJ*jtV?VwR- zn?Wa4#d)WPYvPS_6IaC#N7ywf4p(IQfZ;>2mwShg#JT*TwmQbw#QXdlu8Z&ayI2)7 hZUr~Q4fmS%Y}{NIuO!eAU)=}zQsfF-GVc|Z{{rQVh#CL@ diff --git a/models/experimental.py b/models/experimental.py index 11f75e2..a20cc3d 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -76,7 +76,7 @@ def attempt_load(weights, device=None, inplace=True, fuse=True): model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: - ckpt = torch.load(attempt_download(w), map_location='cpu') # load + ckpt = torch.load(attempt_download(w), map_location=device) # load ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model # Model compatibility updates