132 lines
5.5 KiB
Python
132 lines
5.5 KiB
Python
import argparse
|
|
import cv2
|
|
import numpy as np
|
|
|
|
|
|
def parse_args(argv=None):
    """Build the command-line parser for the OCR / barcode inference tools and parse it.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, in which case argparse
            reads ``sys.argv`` exactly as before, so existing
            ``parse_args()`` callers are unaffected.

    Returns:
        argparse.Namespace holding every option registered below.
    """

    def str2bool(v):
        # argparse's ``type=bool`` would map any non-empty string (even
        # "False") to True, so boolean options go through this helper.
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()

    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=False)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--use_fp16", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=500)

    # params for text detector
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--ocr_det_model_dir", type=str, default='models/ocr_det_infer')
    parser.add_argument("--barcode_det_model_dir", type=str, default='models/barcode_det_infer')
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
    parser.add_argument("--max_batch_size", type=int, default=10)
    # Previously type=bool, which parsed "--use_dilation False" as True.
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    # Previously type=bool, which parsed "--det_sast_polygon False" as True.
    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str, default='models/rec_infer/')
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument("--rec_char_dict_path", type=str, default="models/rec_infer/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="models/rec_infer/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=True)
    parser.add_argument("--cls_model_dir", type=str, default='models/cls_infer')
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    # One label per token, e.g. "--label_list 0 180". The previous type=list
    # split a single token into characters ("180" -> ['1', '8', '0']).
    parser.add_argument("--label_list", type=str, nargs="+", default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    return parser.parse_args(argv)
|
|
|
|
|
|
def get_rotate_crop_image(img, points):
    """Cut the quadrilateral ``points`` out of ``img`` and warp it to an upright rectangle.

    Args:
        img: Source image array, passed straight to ``cv2.warpPerspective``.
        points: Four (x, y) corner points (any numeric dtype). The warp maps
            points[0] -> (0, 0), points[1] -> (w, 0), points[2] -> (w, h) and
            points[3] -> (0, h), so the corners are presumably expected in
            clockwise order starting at the top-left. TODO: confirm against
            the detector's output order. The caller's array is not modified.

    Returns:
        The warped crop, of width/height equal to the longer of each pair of
        opposite quadrilateral edges (truncated to int). If the crop is at
        least 1.5 times as tall as it is wide, it is rotated 90 degrees with
        ``np.rot90``.
    """
    # cv2.getPerspectiveTransform asserts float32 input. Coercing here lets
    # callers pass float64 or integer boxes instead of crashing. np.asarray
    # returns the same object for float32 input, and nothing below writes
    # into ``points``.
    points = np.asarray(points, dtype=np.float32)

    # Output size: the longer edge of each pair of opposite sides.
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))

    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)

    dst_img_height, dst_img_width = dst_img.shape[0:2]
    # Tall, narrow crops (presumably vertical text) are rotated so the result
    # is wider than it is tall.
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
|
|
|
|
|
|
def sorted_boxes(dt_boxes, y_thresh=10):
    """Sort text boxes top-to-bottom, then left-to-right within a line.

    Boxes are first ordered by the (y, x) of their first point. Then each box
    whose first-point y is within ``y_thresh`` of its predecessor's, but which
    lies further left, is moved backwards past it. As a result, boxes on the
    same visual line end up in left-to-right order.

    Args:
        dt_boxes: Sequence of detected boxes, e.g. an array of shape
            [N, 4, 2]. Only each box's first point ``box[0]`` (x, y) is used
            as the sort key.
        y_thresh: Maximum vertical distance between the first points of two
            adjacent boxes for them to count as the same line. Defaults to 10,
            the previously hard-coded value.

    Returns:
        list of the same box objects in reading order (empty list for no boxes).
    """
    _boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    for i in range(len(_boxes) - 1):
        # Insertion-style pass: keep moving box i+1 leftwards while it is on
        # the same line as its predecessor and to its left. The old single
        # forward pass only ever compared each box with its immediate
        # neighbour, leaving three or more same-line boxes out of order.
        for j in range(i, -1, -1):
            if (abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < y_thresh
                    and _boxes[j + 1][0][0] < _boxes[j][0][0]):
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes
|
|
|
|
|
|
def print_draw_crop_rec_res(img_crop_list, rec_res, output_dir="./output"):
    """Debug helper: save every crop as a JPEG and print its recognition result.

    Args:
        img_crop_list: Cropped images, each written with ``cv2.imwrite``.
        rec_res: Recognition results indexed in parallel with
            ``img_crop_list``. It must have at least as many entries, or
            ``IndexError`` is raised.
        output_dir: Directory the crops are written to, as
            ``img_crop_<index>.jpg``. It is created if missing. Defaults to
            ``"./output"``, the previously hard-coded location.
    """
    import os

    # cv2.imwrite does not create directories. For a missing directory it
    # just returns False, so every crop used to be dropped silently.
    os.makedirs(output_dir, exist_ok=True)
    for bno, img_crop in enumerate(img_crop_list):
        path = os.path.join(output_dir, "img_crop_%d.jpg" % bno)
        if not cv2.imwrite(path, img_crop):
            print("warning: failed to write %s" % path)
        print(bno, rec_res[bno])
|