From 960a0ca6079bca50b759c66c43dbb2bee9492d42 Mon Sep 17 00:00:00 2001 From: zhangxinnan Date: Sun, 4 Sep 2022 10:01:16 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0draw=5Focr=5Fbox=5Ftxt2:?= =?UTF-8?q?=E8=A7=A3=E5=86=B3=E6=96=87=E6=9C=AC=E6=A1=86=E5=80=BE=E6=96=9C?= =?UTF-8?q?=E6=97=B6=E6=96=87=E5=AD=97=E7=94=BB=E5=9C=A8=E5=A4=96=E8=BE=B9?= =?UTF-8?q?=E7=9A=84=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/infer/predict_system.py | 4 +- tools/infer/utility.py | 71 +++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index e0f2c41f..ebf0b0ba 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -34,7 +34,7 @@ import tools.infer.predict_det as predict_det import tools.infer.predict_cls as predict_cls from ppocr.utils.utility import get_image_file_list, check_and_read from ppocr.utils.logging import get_logger -from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image +from tools.infer.utility import draw_ocr_box_txt2, get_rotate_crop_image logger = get_logger() @@ -189,7 +189,7 @@ def main(args): txts = [rec_res[i][0] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))] - draw_img = draw_ocr_box_txt( + draw_img = draw_ocr_box_txt2( image, boxes, txts, diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 9baf66d7..8045ec44 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -447,6 +447,77 @@ def draw_ocr_box_txt(image, return np.array(img_show) +def draw_ocr_box_txt2(image, + boxes, + txts=None, + scores=None, + drop_score=0.5, + font_path="./doc/fonts/simfang.ttf"): + h, w = image.height, image.width + img_left = image.copy() + img_right = np.ones((h, w, 3), dtype=np.uint8) * 255 + import random + random.seed(0) + + draw_left = ImageDraw.Draw(img_left) + if txts is None or len(txts) != len(boxes): + txts = [None] * len(boxes) + for idx, (box, txt) in enumerate(zip(boxes, txts)): + if scores is not None and scores[idx] < drop_score: + continue + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + draw_left.polygon(box, fill=color) + img_right_text = draw_box_txt_fine((w, h), box, txt, font_path) + pts = np.array(box, np.int32).reshape((-1, 1, 2)) + cv2.polylines(img_right_text, [pts], True, color, 1) + img_right = cv2.bitwise_and(img_right, img_right_text) + img_left = Image.blend(image, img_left, 0.5) + img_show = Image.new('RGB', (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h)) + return np.array(img_show) + + +def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"): + box_height = int(math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2)) + box_width = int(math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2)) + + if box_height > 2 * box_width and box_height > 30: + img_text = Image.new('RGB', (box_height, box_width), (255, 255, 255)) + draw_text = ImageDraw.Draw(img_text) + if txt: + font = create_font(txt, (box_height, box_width), font_path) + draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font) + img_text = img_text.transpose(Image.ROTATE_270) + else: + img_text = Image.new('RGB', (box_width, box_height), (255, 255, 255)) + draw_text = ImageDraw.Draw(img_text) + if txt: + font = create_font(txt, (box_width, box_height), font_path) + draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font) + + pts1 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]) + pts2 = np.array(box, dtype=np.float32) + M = cv2.getPerspectiveTransform(pts1, pts2) + + img_text = np.array(img_text, dtype=np.uint8) + img_right_text = cv2.warpPerspective(img_text, M, img_size, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(255, 255, 255)) + return img_right_text + + +def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"): + font_size = int(sz[1] * 0.99) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + length = font.getsize(txt)[0] + if length > sz[0]: + font_size = int(font_size * sz[0] / length) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + return font + + def str_count(s): """ Count the number of Chinese characters, -- GitLab