masterUpdate
This commit is contained in:
@ -15,6 +15,7 @@ sys.path.insert(0, ".")
|
||||
|
||||
import logging.config
|
||||
from skywalking import agent, config
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
SW_SERVER = os.environ.get('SW_AGENT_COLLECTOR_BACKEND_SERVICES')
|
||||
SW_SERVICE_NAME = os.environ.get('SW_AGENT_NAME')
|
||||
if SW_SERVER and SW_SERVICE_NAME:
|
||||
@ -96,7 +97,7 @@ class Predictor(object):
|
||||
probs = torch.nn.Softmax(dim=-1)(part_logits)
|
||||
topN = torch.argsort(probs, dim=-1, descending=True).tolist()
|
||||
clas_ids = topN[0][0]
|
||||
clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1
|
||||
#clas_ids = 0 if 0==int(clas_ids) or 2 == int(clas_ids) or 3 == int(clas_ids) else 1
|
||||
#print("cur_img result: class id: %d, score: %0.3f" % (clas_ids, probs[0, clas_ids].item()))
|
||||
result["success"] = "true"
|
||||
result["rst_cls"] = str(clas_ids)
|
||||
|
Reference in New Issue
Block a user