3588适配

This commit is contained in:
2024-10-17 11:51:39 +08:00
parent 7d9c289325
commit 8475980895
2 changed files with 22 additions and 3 deletions

1
.gitignore vendored
View File

@ -90,6 +90,7 @@ parts/
sdist/
var/
wheels/
paper_data/
<<<<<<< HEAD
*.egg-info/
wandb/

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from flask import request, Flask
import pdb
import numpy as np
import json
import time
@ -10,6 +11,8 @@ import torch
from PIL import Image
import datetime
from torchvision import transforms
from models.experimental import attempt_load
from utils.torch_utils import select_device
# from models.modeling import VisionTransformer, CONFIGS
from gevent.pywsgi import WSGIServer
sys.path.insert(0, ".")
@ -42,7 +45,7 @@ parser.add_argument('--source', type=str, default='../module/ieemoo-ai-zhanting/
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.60, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', type=bool, default=True, help='display results')
parser.add_argument('--save-txt', type=bool, default=True, help='save results to *.txt')
parser.add_argument('--save-conf', type=bool, default=True, help='save confidences in --save-txt labels')
@ -57,6 +60,17 @@ parser.add_argument('--exist-ok', type=bool, default=True, help='existing projec
#opt = parser.parse_args()
opt, unknown = parser.parse_known_args()
def init_model(opt):
device = select_device(opt.device)
#half = device.type != 'cpu'
model = attempt_load(opt.weights, map_location=device)
stride = int(model.stride.max())
#if half:
model.half()
model.eval()
return model, stride
model,stride = init_model(opt)
@app.route("/zhanting", methods=['POST'])
def get_isempty():
data = request.get_data()
@ -73,7 +87,11 @@ def get_isempty():
file = open(image_path, 'wb')
file.write(imgdata)
img = cv2.imread(image_path)
site = np.array([[[0, 1024], [0, 571], [313, 365], [949, 367], [1277, 596], [1280, 1024]]], dtype=np.int32)
#pdb.set_trace()
if img.shape[0] == 1280:
site = np.array([[[0, 1280],[0, 671], [300, 390], [740, 390], [1024, 635], [1024,1280]]], dtype=np.int32)
else:
site = np.array([[[0, 1024],[0, 571], [313, 365], [949, 367], [1277, 596], [1280, 1024]]], dtype=np.int32)
im = np.zeros(img.shape[:2], dtype="uint8")
cv2.polylines(im, site, 1, 255)
cv2.fillPoly(im, site, 255)
@ -81,7 +99,7 @@ def get_isempty():
masked = cv2.bitwise_or(img, img, mask=mask)
img0 = masked
cv2.imwrite("../module/ieemoo-ai-zhanting/imgs/1.jpg",img0)
pred = detect.detect(opt)
pred = detect.detect(opt, model, stride)
logger.info(pred)
except Exception as e:
logger.warning(e)