masterUpdate

This commit is contained in:
2022-12-29 13:45:07 +08:00
parent b84a92f67a
commit 615c91feb6
2 changed files with 18 additions and 6 deletions

View File

@ -15,6 +15,7 @@ sys.path.insert(0, ".")
import logging.config
from skywalking import agent, config
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
SW_SERVER = os.environ.get('SW_AGENT_COLLECTOR_BACKEND_SERVICES')
SW_SERVICE_NAME = os.environ.get('SW_AGENT_NAME')
if SW_SERVER and SW_SERVICE_NAME:
@ -96,7 +97,7 @@ class Predictor(object):
probs = torch.nn.Softmax(dim=-1)(part_logits)
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
clas_ids = topN[0][0]
clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1
#clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1
#print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item()))
result["success"] = "true"
result["rst_cls"] = str(clas_ids)