From 847598089513e138ae53ce712d3ab427709905a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=99=A8?= Date: Thu, 17 Oct 2024 11:51:39 +0800 Subject: [PATCH] =?UTF-8?q?3588=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + ieemoo-ai-zhanting.py | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 70e1997..958b303 100755 --- a/.gitignore +++ b/.gitignore @@ -90,6 +90,7 @@ parts/ sdist/ var/ wheels/ +paper_data/ <<<<<<< HEAD *.egg-info/ wandb/ diff --git a/ieemoo-ai-zhanting.py b/ieemoo-ai-zhanting.py index a5d1584..629ab41 100755 --- a/ieemoo-ai-zhanting.py +++ b/ieemoo-ai-zhanting.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from flask import request, Flask +import pdb import numpy as np import json import time @@ -10,6 +11,8 @@ import torch from PIL import Image import datetime from torchvision import transforms +from models.experimental import attempt_load +from utils.torch_utils import select_device # from models.modeling import VisionTransformer, CONFIGS from gevent.pywsgi import WSGIServer sys.path.insert(0, ".") @@ -42,7 +45,7 @@ parser.add_argument('--source', type=str, default='../module/ieemoo-ai-zhanting/ parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.60, help='object confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') -parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') +parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--view-img', type=bool, default=True, help='display results') parser.add_argument('--save-txt', type=bool, default=True, help='save results to *.txt') parser.add_argument('--save-conf', type=bool, default=True, help='save confidences in --save-txt labels') @@ -57,6 +60,17 @@ parser.add_argument('--exist-ok', type=bool, default=True, help='existing projec #opt = parser.parse_args() opt, unknown = parser.parse_known_args() +def init_model(opt): + device = select_device(opt.device) + #half = device.type != 'cpu' + model = attempt_load(opt.weights, map_location=device) + stride = int(model.stride.max()) + #if half: + model.half() + model.eval() + return model, stride +model,stride = init_model(opt) + @app.route("/zhanting", methods=['POST']) def get_isempty(): data = request.get_data() @@ -73,7 +87,11 @@ def get_isempty(): file = open(image_path, 'wb') file.write(imgdata) img = cv2.imread(image_path) - site = np.array([[[0, 1024], [0, 571], [313, 365], [949, 367], [1277, 596], [1280, 1024]]], dtype=np.int32) + #pdb.set_trace() + if img.shape[0] == 1280: + site = np.array([[[0, 1280],[0, 671], [300, 390], [740, 390], [1024, 635], [1024,1280]]], dtype=np.int32) + else: + site = np.array([[[0, 1024],[0, 571], [313, 365], [949, 367], [1277, 596], [1280, 1024]]], dtype=np.int32) im = np.zeros(img.shape[:2], dtype="uint8") cv2.polylines(im, site, 1, 255) cv2.fillPoly(im, site, 255) @@ -81,7 +99,7 @@ def get_isempty(): masked = cv2.bitwise_or(img, img, mask=mask) img0 = masked cv2.imwrite("../module/ieemoo-ai-zhanting/imgs/1.jpg",img0) - pred = detect.detect(opt) + pred = detect.detect(opt, model, stride) logger.info(pred) except Exception as e: logger.warning(e)