From 4af18109099b62215f9ad060e40a07b4d372b397 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 27 May 2020 14:55:58 +0800 Subject: [PATCH] opt visualized func and add docker usage in cpu --- doc/installation.md | 5 +- tools/infer/utility.py | 146 ++++++++++++++++++++++++++++++++++------- 2 files changed, 128 insertions(+), 23 deletions(-) diff --git a/doc/installation.md b/doc/installation.md index 833bdba3..5dabac50 100644 --- a/doc/installation.md +++ b/doc/installation.md @@ -15,6 +15,9 @@ cd /home/Projects # 首次运行需创建一个docker容器,再次运行时不需要运行当前命令 # 创建一个名字为ppocr的docker容器,并将当前目录映射到容器的/paddle目录下 +如果您希望在CPU环境下使用docker,使用docker而不是nvidia-docker创建docker +sudo docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash + 如果您的机器安装的是CUDA9,请运行以下命令创建容器 sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash @@ -24,7 +27,7 @@ sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidu 您也可以访问[DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags/)获取与您机器适配的镜像。 # ctrl+P+Q可退出docker,重新进入docker使用如下命令 -sudo nvidia-docker container exec -it ppocr /bin/bash +sudo docker container exec -it ppocr /bin/bash ``` 注意:如果docker pull过慢,可以按照如下步骤手动下载后加载docker,以cuda9 docker为例,使用cuda10 docker只需要将cuda9改为cuda10即可。 diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 04ddbcd7..2966a3ed 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -23,6 +23,7 @@ import cv2 import numpy as np import json from PIL import Image, ImageDraw, ImageFont +import math def parse_args(): @@ -127,6 +128,18 @@ def resize_img(img, input_size=600): def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5): + """ + Visualize the results of OCR detection and recognition + args: + image(Image): image from Image.open + boxes(list): boxes with shape(N, 4, 2) + txts(list): the texts + scores(list): 
the scores corresponding to txts + draw_txt(bool): whether draw text or not + drop_score(float): only results with scores greater than drop_score will be visualized + return(array): + the visualized img + """ from PIL import Image, ImageDraw, ImageFont img = image.copy() @@ -154,35 +167,123 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5): fill='red') if draw_txt: - txt_color = (0, 0, 0) - img = np.array(resize_img(img)) - _h = img.shape[0] - blank_img = np.ones(shape=[_h, 600], dtype=np.int8) * 255 + img = np.array(resize_img(img, input_size=600)) + txt_img = text_visual( + txts, scores, img_h=img.shape[0], img_w=600, threshold=drop_score) + img = np.concatenate([np.array(img), np.array(txt_img)], axis=1) + + return img + + +def str_count(s): + """ + Count the number of Chinese characters, + a single English character and a single number + equal to half the length of Chinese characters. + + args: + s(string): the input of string + return(int): + the number of Chinese characters + """ + import string + count_zh = count_pu = 0 + s_len = len(s) + en_dg_count = 0 + for c in s: + if c in string.ascii_letters or c.isdigit() or c.isspace(): + en_dg_count += 1 + elif c.isalpha(): + count_zh += 1 + else: + count_pu += 1 + return s_len - math.ceil(en_dg_count / 2) + + +def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.): + """ + create new blank img and draw txt on it + args: + texts(list): the texts to be drawn + scores(list|None): corresponding score of each txt + img_h(int): the height of blank img + img_w(int): the width of blank img + return(array): + + """ + if scores is not None: + assert len(texts) == len( + scores), "The number of txts and corresponding scores must match" + + def create_blank_img(): + blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255 + blank_img[:, img_w - 1:] = 0 blank_img = Image.fromarray(blank_img).convert("RGB") draw_txt = ImageDraw.Draw(blank_img) + return blank_img, draw_txt - font_size = 20 - gap = 20 - 
title = "index text score" - font = ImageFont.truetype( - "./doc/simfang.ttf", font_size, encoding="utf-8") - - draw_txt.text((20, 0), title, txt_color, font=font) - count = 0 - for idx, txt in enumerate(txts): - if scores[idx] < drop_score: - continue - font = ImageFont.truetype( - "./doc/simfang.ttf", font_size, encoding="utf-8") - new_txt = str(idx) + ': ' + txt + ' ' + '%.3f' % (scores[idx]) - draw_txt.text( - (20, gap * (count + 1)), new_txt, txt_color, font=font) + blank_img, draw_txt = create_blank_img() + + font_size = 20 + txt_color = (0, 0, 0) + font = ImageFont.truetype( + "../../doc/simfang.ttf", font_size, encoding="utf-8") + + gap = font_size + 5 + txt_img_list = [] + count, index = 0, 0 + for idx, txt in enumerate(texts): + index += 1 + if scores[idx] < threshold: + index -= 1 + continue + first_line = True + while str_count(txt) >= img_w // font_size - 4: + tmp = txt + txt = tmp[:img_w // font_size - 4] + if first_line: + new_txt = str(index) + ': ' + txt + first_line = False + else: + new_txt = ' ' + txt + draw_txt.text((0, gap * (count + 1)), new_txt, txt_color, font=font) + txt = tmp[img_w // font_size - 4:] count += 1 - img = np.concatenate([np.array(img), np.array(blank_img)], axis=1) - return img + if count >= img_h // gap - 1: + txt_img_list.append(np.array(blank_img)) + blank_img, draw_txt = create_blank_img() + count = 0 + + if first_line: + new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx]) + else: + new_txt = " " + txt + " " + '%.3f' % (scores[idx]) + draw_txt.text((0, gap * (count + 1)), new_txt, txt_color, font=font) + count += 1 + # whether add new blank img or not + if count >= img_h // gap - 1 and idx + 1 < len(texts): + txt_img_list.append(np.array(blank_img)) + blank_img, draw_txt = create_blank_img() + count = 0 + + txt_img_list.append(np.array(blank_img)) + if len(txt_img_list) == 1: + blank_img = np.array(txt_img_list[0]) + else: + blank_img = np.concatenate(txt_img_list, axis=1) + # cv2.imwrite("./draw_txt.jpg", 
np.array(blank_img)) + return np.array(blank_img) if __name__ == '__main__': + text = [ + "旨在打造一套丰富领先、且实用的工具库助力使、用者训练出更好的模型,并应用落地", + "以下代码实现了文本检测、识别串联推理,在执行预测时,需要通过参数image_dir指定单张图像或者图像集合", + "上述DB模型的训练和评估,需设置后处理参数box_thresh=0.6,unclip_ratio=1.5,使用不同数据集" + ] + img = text_visual(text, scores=[0.999, 0.999, 0.999], img_h=100) + cv2.imwrite("./draw_txt.jpg", np.array(img)) + """ test_img = "./doc/test_v2" predict_txt = "./doc/predict.txt" f = open(predict_txt, 'r') @@ -202,3 +303,4 @@ if __name__ == '__main__': new_img = draw_ocr(image, boxes, txts, scores, draw_txt=True) cv2.imwrite(img_name, new_img) + """ -- GitLab