utility.py 19.2 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
W
WenmuZhou 已提交
16
import os
W
WenmuZhou 已提交
17
import sys
L
LDOUBLEV 已提交
18 19
import cv2
import numpy as np
L
LDOUBLEV 已提交
20 21
import json
from PIL import Image, ImageDraw, ImageFont
22
import math
W
WenmuZhou 已提交
23
from paddle import inference
L
LDOUBLEV 已提交
24 25
import time
from ppocr.utils.logging import get_logger
W
WenmuZhou 已提交
26

L
LDOUBLEV 已提交
27
# Module-level logger obtained from ppocr.utils.logging.get_logger.
# Note: create_predictor receives its own `logger` argument and does not use this one.
logger = get_logger()
L
LDOUBLEV 已提交
28 29


30 31
def str2bool(v):
    """Interpret a command-line value as a boolean.

    The strings "true", "t" and "1" (case-insensitive) map to True; every
    other string maps to False. A value that is already a bool is returned
    unchanged, so the converter is also safe for programmatic callers that
    pass real booleans.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ("true", "t", "1")
L
LDOUBLEV 已提交
32 33


W
WenmuZhou 已提交
34
def init_args():
    """Create the argument parser shared by the inference scripts.

    Boolean options use ``str2bool``: ``type=bool`` would turn any non-empty
    string, including "False", into True.

    Returns:
        argparse.ArgumentParser: parser with prediction-engine, detector,
        recognizer, end-to-end, classifier and runtime options.
    """
    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=3)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")
    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True)
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    # nargs='+' so `--label_list 0 180` yields ['0', '180']; type=list would
    # split a single argument string into its individual characters.
    parser.add_argument(
        "--label_list", type=str, nargs='+', default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--cpu_threads", type=int, default=10)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=True)

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")

    parser.add_argument("--show_log", type=str2bool, default=True)
    return parser
W
WenmuZhou 已提交
121

122

123
def parse_args():
    """Parse the process command line with the shared inference parser.

    Returns:
        argparse.Namespace: the parsed options produced by init_args().
    """
    return init_args().parse_args()


W
WenmuZhou 已提交
128 129 130 131 132
def _inference_precision(args):
    """Map ``args.precision`` to an ``inference.PrecisionType``.

    "fp16" is honoured only together with ``use_tensorrt``; "int8" maps to
    Int8; anything else, or a namespace without ``precision``, is Float32.
    """
    precision = getattr(args, "precision", "fp32")
    if precision == "fp16" and args.use_tensorrt:
        return inference.PrecisionType.Half
    if precision == "int8":
        return inference.PrecisionType.Int8
    return inference.PrecisionType.Float32


def _trt_dynamic_shape_info(args, mode):
    """Return the (min, max, opt) TensorRT dynamic-shape dicts for ``mode``."""
    if mode == "det":
        min_input_shape = {
            "x": [1, 3, 50, 50],
            "conv2d_92.tmp_0": [1, 96, 20, 20],
            "conv2d_91.tmp_0": [1, 96, 10, 10],
            "conv2d_59.tmp_0": [1, 96, 20, 20],
            "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
            "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
            "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
            "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
            "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20],
            "elementwise_add_7": [1, 56, 2, 2],
            "nearest_interp_v2_0.tmp_0": [1, 96, 2, 2]
        }
        max_input_shape = {
            "x": [1, 3, 2000, 2000],
            "conv2d_92.tmp_0": [1, 96, 400, 400],
            "conv2d_91.tmp_0": [1, 96, 200, 200],
            "conv2d_59.tmp_0": [1, 96, 400, 400],
            "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
            "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
            "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
            "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
            "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400],
            "elementwise_add_7": [1, 56, 400, 400],
            "nearest_interp_v2_0.tmp_0": [1, 96, 400, 400]
        }
        opt_input_shape = {
            "x": [1, 3, 640, 640],
            "conv2d_92.tmp_0": [1, 96, 160, 160],
            "conv2d_91.tmp_0": [1, 96, 80, 80],
            "conv2d_59.tmp_0": [1, 96, 160, 160],
            "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
            "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
            "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
            "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
            "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160],
            "elementwise_add_7": [1, 56, 40, 40],
            "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
        }
    elif mode == "rec":
        # Min batch is 1 so batches smaller than rec_batch_num (presumably a
        # final partial batch) stay inside the declared dynamic range.
        batch = args.rec_batch_num
        min_input_shape = {"x": [1, 3, 32, 10]}
        max_input_shape = {"x": [batch, 3, 32, 2000]}
        opt_input_shape = {"x": [batch, 3, 32, 320]}
    elif mode == "cls":
        # The classifier has its own batch-size option.
        batch = getattr(args, "cls_batch_num", args.rec_batch_num)
        min_input_shape = {"x": [1, 3, 48, 10]}
        max_input_shape = {"x": [batch, 3, 48, 2000]}
        opt_input_shape = {"x": [batch, 3, 48, 320]}
    else:
        min_input_shape = {"x": [1, 3, 10, 10]}
        max_input_shape = {"x": [1, 3, 1000, 1000]}
        opt_input_shape = {"x": [1, 3, 500, 500]}
    return min_input_shape, max_input_shape, opt_input_shape


def create_predictor(args, mode, logger):
    """Build a Paddle Inference predictor for one pipeline stage.

    Args:
        args: parsed options (see init_args). The model directory is read from
            ``det_model_dir`` / ``cls_model_dir`` / ``rec_model_dir`` /
            ``table_model_dir`` for the matching mode, and from
            ``e2e_model_dir`` for any other mode.
        mode(str): "det", "cls", "rec", "table", or anything else for e2e.
        logger: logger used to report a missing model directory.

    Returns:
        tuple: (predictor, input_tensor, output_tensors, config), where
        ``input_tensor`` is the handle of the last input name reported by the
        predictor and ``output_tensors`` holds one handle per output name.

    Raises:
        SystemExit: if the model directory option is None.
        ValueError: if inference.pdmodel or inference.pdiparams is missing.
    """
    model_dir_attr = {
        "det": "det_model_dir",
        "cls": "cls_model_dir",
        "rec": "rec_model_dir",
        "table": "table_model_dir",
    }.get(mode, "e2e_model_dir")
    model_dir = getattr(args, model_dir_attr)

    if model_dir is None:
        logger.info("not find {} model file path {}".format(mode, model_dir))
        # NOTE(review): exits with status 0 although this is an error case;
        # kept as-is for compatibility with existing scripts.
        sys.exit(0)
    model_file_path = os.path.join(model_dir, "inference.pdmodel")
    params_file_path = os.path.join(model_dir, "inference.pdiparams")
    if not os.path.exists(model_file_path):
        raise ValueError("not find model file path {}".format(model_file_path))
    if not os.path.exists(params_file_path):
        raise ValueError("not find params file path {}".format(
            params_file_path))

    config = inference.Config(model_file_path, params_file_path)

    if args.use_gpu:
        config.enable_use_gpu(args.gpu_mem, 0)
        if args.use_tensorrt:
            # Honour --precision instead of always building an fp32 engine.
            config.enable_tensorrt_engine(
                precision_mode=_inference_precision(args),
                max_batch_size=args.max_batch_size,
                min_subgraph_size=args.min_subgraph_size)
        min_input_shape, max_input_shape, opt_input_shape = \
            _trt_dynamic_shape_info(args, mode)
        config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                          opt_input_shape)
    else:
        config.disable_gpu()
        # default to 10 CPU math threads when the namespace has no cpu_threads
        config.set_cpu_math_library_num_threads(getattr(args, "cpu_threads", 10))
        if args.enable_mkldnn:
            # cache 10 different shapes for mkldnn to avoid memory leak
            config.set_mkldnn_cache_capacity(10)
            config.enable_mkldnn()

    # enable memory optim
    config.enable_memory_optim()
    config.disable_glog_info()

    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
    if mode == 'table':
        config.delete_pass("fc_fuse_pass")  # not supported for table
    config.switch_use_feed_fetch_ops(False)
    # Honour --ir_optim (defaults to True, the previous hard-coded value).
    config.switch_ir_optim(getattr(args, "ir_optim", True))

    # create predictor
    predictor = inference.create_predictor(config)
    input_names = predictor.get_input_names()
    # As before, only the handle of the last input is returned.
    input_tensor = predictor.get_input_handle(input_names[-1])
    output_tensors = [
        predictor.get_output_handle(output_name)
        for output_name in predictor.get_output_names()
    ]
    return predictor, input_tensor, output_tensors, config
W
WenmuZhou 已提交
259 260


J
Jethong 已提交
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
def draw_e2e_res(dt_boxes, strs, img_path):
    """Draw end-to-end OCR results (polygon outline plus text) on an image.

    Args:
        dt_boxes: one polygon per result; each is converted to int32 points of
            shape (-1, 1, 2).
        strs: the recognized text of each polygon, drawn at its first vertex
            with a Hershey font (which presumably cannot render non-ASCII text).
        img_path(str): path of the image to draw on.
    Returns:
        the image array read by cv2.imread with the results drawn on it.
    Raises:
        ValueError: if cv2.imread cannot read ``img_path``.
    """
    src_im = cv2.imread(img_path)
    if src_im is None:
        raise ValueError("can not read image file {}".format(img_path))
    for box, text in zip(dt_boxes, strs):
        points = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [points], True, color=(255, 255, 0), thickness=2)
        cv2.putText(
            src_im,
            text,
            org=(int(points[0, 0, 0]), int(points[0, 0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)
    return src_im


L
LDOUBLEV 已提交
277
def draw_text_det_res(dt_boxes, img_path):
    """Draw detected text polygons on the image at ``img_path``.

    Args:
        dt_boxes: one polygon per detection; each is converted to int32 points
            of shape (-1, 2).
        img_path(str): path of the image to draw on.
    Returns:
        the image array read by cv2.imread with the outlines drawn on it.
    Raises:
        ValueError: if cv2.imread cannot read ``img_path``.
    """
    src_im = cv2.imread(img_path)
    if src_im is None:
        raise ValueError("can not read image file {}".format(img_path))
    for box in dt_boxes:
        points = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(src_im, [points], True, color=(255, 255, 0), thickness=2)
    return src_im
L
LDOUBLEV 已提交
283 284


L
LDOUBLEV 已提交
285 286
def resize_img(img, input_size=600):
    """
    Rescale img so that its longer side (height or width) becomes input_size,
    keeping the aspect ratio. Images whose longer side is already shorter
    than input_size are enlarged.
    """
    arr = np.array(img)
    longest_side = max(arr.shape[:2])
    scale = float(input_size) / float(longest_side)
    return cv2.resize(arr, None, None, fx=scale, fy=scale)
L
LDOUBLEV 已提交
295 296


W
WenmuZhou 已提交
297 298 299 300 301
def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/fonts/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): the score of each box/text; every box scores 1 when None
        drop_score(float): boxes whose score is below drop_score (or NaN) are
            not visualized
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img. When txts is given, the image is resized so its
        longer side is 600 and concatenated horizontally with a text panel.
        When txts is None and no box is drawn, ``image`` is returned as-is.
    """
    if scores is None:
        scores = [1] * len(boxes)
    # Copy the image into an array once (on the first drawn box) instead of
    # copying the whole image again for every box.
    canvas = None
    for i, box in enumerate(boxes):
        score = scores[i]
        if score < drop_score or math.isnan(score):
            continue
        if canvas is None:
            canvas = np.array(image)
        points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        canvas = cv2.polylines(canvas, [points], True, (255, 0, 0), 2)
    if canvas is not None:
        image = canvas
    if txts is not None:
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path)
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image
336 337


W
WenmuZhou 已提交
338 339 340 341 342 343
def draw_ocr_box_txt(image,
                     boxes,
                     txts,
                     scores=None,
                     drop_score=0.5,
                     font_path="./doc/simfang.ttf"):
    """
    Render OCR results side by side.

    The left half is ``image`` blended 50/50 with a randomly coloured fill of
    every box; the right half is a white canvas with each box outline and its
    text.
    args:
        image(PIL.Image): source image (its width/height, copy() and
            Image.blend are used)
        boxes(list): four-point boxes; box[0..3] are indexed as (x, y)
        txts(list): the text of each box
        scores(list|None): score of each box; boxes scoring below drop_score
            or NaN are skipped
        drop_score(float): score threshold
        font_path(str): TrueType font used to draw the text
    return(array):
        RGB array of width 2 * image.width and height image.height
    """
    import random

    def char_height(font, char):
        # FreeTypeFont.getsize was removed in Pillow 10; getbbox()[3] equals
        # getsize()[1] (vertical offset plus glyph height) where both exist.
        if hasattr(font, "getbbox"):
            return font.getbbox(char)[3]
        return font.getsize(char)[1]

    h, w = image.height, image.width
    img_left = image.copy()
    img_right = Image.new('RGB', (w, h), (255, 255, 255))

    # A private generator seeded with 0 yields the same colours as the old
    # random.seed(0) call without resetting the program-wide random state.
    rng = random.Random(0)
    draw_left = ImageDraw.Draw(img_left)
    draw_right = ImageDraw.Draw(img_right)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and (scores[idx] < drop_score or
                                   math.isnan(scores[idx])):
            continue
        color = (rng.randint(0, 255), rng.randint(0, 255),
                 rng.randint(0, 255))
        draw_left.polygon(box, fill=color)
        draw_right.polygon(
            [
                box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
                box[2][1], box[3][0], box[3][1]
            ],
            outline=color)
        box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
            1])**2)
        box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
            1])**2)
        if box_height > 2 * box_width:
            # tall box: draw the text vertically, one character per row
            font_size = max(int(box_width * 0.9), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            cur_y = box[0][1]
            for c in txt:
                step = char_height(font, c)
                draw_right.text(
                    (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
                cur_y += step
        else:
            font_size = max(int(box_height * 0.8), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            draw_right.text(
                [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(img_right, (w, 0, w * 2, h))
    return np.array(img_show)


390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413
def str_count(s):
    """
    Return the display width of s measured in Chinese-character units.

    ASCII letters, digits (per str.isdigit) and whitespace count as half a
    unit each (their total rounded up); every other character counts as one.
    args:
        s(string): the input of string
    return(int):
        len(s) minus half the number of half-width characters, rounded up
    """
    import string
    half_width = sum(1 for c in s
                     if c in string.ascii_letters or c.isdigit() or
                     c.isspace())
    return len(s) - math.ceil(half_width / 2)


W
WenmuZhou 已提交
414 415 416 417 418 419
def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    create new blank img and draw txt on it
    args:
        texts(list): the text will be draw
        scores(list|None): corresponding score of each txt; when None, no text
            is filtered out and no score is printed
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        threshold(float): texts whose score is below threshold (or NaN) are
            skipped
        font_path: the path of font which is used to draw text
    return(array):
        RGB array; when the texts do not fit into one img_h-high panel,
        several panels are concatenated horizontally
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        # uint8 so that 255 is representable: an int8 array multiplied by 255
        # overflows int8 (and raises OverflowError on NumPy >= 2).
        blank_img = np.full((img_h, img_w), 255, dtype=np.uint8)
        # black 1-pixel column at the right edge separates adjacent panels
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    # Maximum characters per drawn line. Clamped to 1: a non-positive width
    # (img_w < 100) made the wrapping loop below never terminate.
    line_chars = max(img_w // font_size - 4, 1)
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        score = None if scores is None else scores[idx]
        if score is not None and (score < threshold or math.isnan(score)):
            continue
        # numbering only counts the texts that are actually drawn
        index += 1
        first_line = True
        # wrap long texts over several lines
        while str_count(txt) >= line_chars:
            tmp = txt
            txt = tmp[:line_chars]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = '    ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[line_chars:]
            if count >= img_h // gap - 1:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ': ' + txt
            separator = '   '
        else:
            new_txt = "  " + txt
            separator = "  "
        if score is not None:
            new_txt += separator + '%.3f' % (score)
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= img_h // gap - 1 and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)
L
LDOUBLEV 已提交
488 489


D
dyning 已提交
490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508
def base64_to_cv2(b64str):
    """Decode a base64-encoded image file into an image array.

    Args:
        b64str(str|bytes): base64 text of an encoded image file; a str is
            UTF-8 encoded before decoding.
    Returns:
        the result of cv2.imdecode(..., cv2.IMREAD_COLOR) on the decoded bytes
        (OpenCV yields None for data it cannot decode).
    """
    import base64
    if isinstance(b64str, str):
        b64str = b64str.encode('utf8')
    data = base64.b64decode(b64str)
    # np.frombuffer replaces the deprecated binary mode of np.fromstring
    data = np.frombuffer(data, np.uint8)
    return cv2.imdecode(data, cv2.IMREAD_COLOR)


def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    """Draw box outlines (colour (255, 0, 0), thickness 2) on an image.

    Args:
        image: image array or anything np.array accepts.
        boxes: polygons; each is reshaped to int64 points of shape (-1, 1, 2).
        scores: score of each box; every box scores 1 when None.
        drop_score(float): boxes whose score is below this value (or NaN, as
            in draw_ocr) are skipped.
    Returns:
        a new array with the outlines drawn, or ``image`` itself when no box
        is drawn.
    """
    if scores is None:
        scores = [1] * len(boxes)
    # Copy the image once (on the first drawn box) instead of once per box.
    canvas = None
    for box, score in zip(boxes, scores):
        if score < drop_score or math.isnan(score):
            continue
        if canvas is None:
            canvas = np.array(image)
        points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        canvas = cv2.polylines(canvas, [points], True, (255, 0, 0), 2)
    return image if canvas is None else canvas


L
LDOUBLEV 已提交
509
if __name__ == "__main__":
    # Helper module only: nothing is executed when it is run directly.
    pass