utility.py 17.8 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
W
WenmuZhou 已提交
16
import os
W
WenmuZhou 已提交
17
import sys
L
LDOUBLEV 已提交
18 19
import cv2
import numpy as np
L
LDOUBLEV 已提交
20 21
import json
from PIL import Image, ImageDraw, ImageFont
22
import math
W
WenmuZhou 已提交
23
from paddle import inference
24 25 26
import time
from ppocr.utils.logging import get_logger
logger = get_logger()
L
LDOUBLEV 已提交
27 28


29 30
def str2bool(v):
    """
    Interpret a command-line string as a boolean
    args:
        v(str): the raw option value
    return(bool):
        True when the lower-cased value is "true", "t" or "1", else False
    """
    truthy_words = ("true", "t", "1")
    lowered = v.lower()
    return lowered in truthy_words
L
LDOUBLEV 已提交
31 32


33
# Command-line options of the inference tools.
# Each entry is [name, type, default]; parse_args() registers one
# ``--name`` option per entry.
inference_args_list = [
    # name     type      default
    # params for prediction engine
    ['use_gpu', str2bool, True],
    ['use_tensorrt', str2bool, False],
    ['use_fp16', str2bool, False],
    ['use_pdserving', str2bool, False],
    ['use_mp', str2bool, False],
    ['enable_mkldnn', str2bool, False],
    ['ir_optim', str2bool, True],
    ['total_process_num', int, 1],
    ['process_id', int, 0],
    ['gpu_mem', int, 500],
    ['cpu_threads', int, 10],
    # params for text detector
    ['image_dir', str, None],
    ['det_algorithm', str, 'DB'],
    ['det_model_dir', str, None],
    ['det_limit_side_len', float, 960],
    ['det_limit_type', str, 'max'],
    # DB params
    ['det_db_thresh', float, 0.3],
    ['det_db_box_thresh', float, 0.5],
    ['det_db_unclip_ratio', float, 1.6],
    ['max_batch_size', int, 10],
    ['use_dilation', str2bool, False],
    ['det_db_score_mode', str, 'fast'],
    # EAST params
    ['det_east_score_thresh', float, 0.8],
    ['det_east_cover_thresh', float, 0.1],
    ['det_east_nms_thresh', float, 0.2],
    # SAST params
    ['det_sast_score_thresh', float, 0.5],
    ['det_sast_nms_thresh', float, 0.2],
    ['det_sast_polygon', str2bool, False],
    # params for text recognizer
    ['rec_algorithm', str, 'CRNN'],
    ['rec_model_dir', str, None],
    ['rec_image_shape', str, '3, 32, 320'],
    ['rec_char_type', str, "ch"],
    ['rec_batch_num', int, 6],
    ['max_text_length', int, 25],
    ['rec_char_dict_path', str, './ppocr/utils/ppocr_keys_v1.txt'],
    ['use_space_char', str2bool, True],
    ['vis_font_path', str, './doc/fonts/simfang.ttf'],
    ['drop_score', float, 0.5],
    # params for e2e
    ['e2e_algorithm', str, 'PGNet'],
    ['e2e_model_dir', str, None],
    ['e2e_limit_side_len', float, 768],
    ['e2e_limit_type', str, 'max'],
    # PGNet params
    ['e2e_pgnet_score_thresh', float, 0.5],
    ['e2e_char_dict_path', str, './ppocr/utils/ic15_dict.txt'],
    ['e2e_pgnet_valid_set', str, 'totaltext'],
    ['e2e_pgnet_polygon', str2bool, True],
    ['e2e_pgnet_mode', str, 'fast'],
    # params for text classifier
    ['use_angle_cls', str2bool, False],
    ['cls_model_dir', str, None],
    ['cls_image_shape', str, '3, 48, 192'],
    # NOTE(review): with ``type=list`` argparse calls list() on the raw
    # command-line string, which yields its single characters; only the
    # default below is a real list of labels.
    ['label_list', list, ['0', '180']],
    ['cls_batch_num', int, 6],
    ['cls_thresh', float, 0.9],
]
W
WenmuZhou 已提交
98

99

100 101 102
def parse_args():
    """
    Build the command-line parser from inference_args_list and parse the
    process arguments
    return(argparse.Namespace):
        one attribute per [name, type, default] entry of inference_args_list
    """
    parser = argparse.ArgumentParser()
    for name, arg_type, default in inference_args_list:
        parser.add_argument('--' + name, type=arg_type, default=default)
    return parser.parse_args()


W
WenmuZhou 已提交
107 108 109 110 111
def _trt_dynamic_shape_info(args, mode, model_file_path):
    """
    Select the TensorRT dynamic-shape ranges for one model
    args:
        args: parsed options; rec_batch_num is read for "rec" and "cls"
        mode(str): "det", "rec", "cls" or anything else (generic fallback)
        model_file_path(str): path of the model file; the substrings
            "server" / "mobile" choose the detector variant
    return(tuple):
        (min_input_shape, max_input_shape, opt_input_shape) dicts mapping
        tensor names to shapes
    """
    is_det = mode == "det"
    # "server" is tested first so a path containing both substrings keeps
    # resolving to the server shapes, as it always did.
    if is_det and "server" in model_file_path:
        min_input_shape = {
            "x": [1, 3, 50, 50],
            "conv2d_59.tmp_0": [1, 96, 20, 20],
            "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
            "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
            "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
            "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
        }
        max_input_shape = {
            "x": [1, 3, 2000, 2000],
            "conv2d_59.tmp_0": [1, 96, 400, 400],
            "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
            "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
            "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
            "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
        }
        opt_input_shape = {
            "x": [1, 3, 640, 640],
            "conv2d_59.tmp_0": [1, 96, 160, 160],
            "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
            "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
            "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
            "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
        }
    elif is_det and "mobile" in model_file_path:
        # This branch used to be a separate ``if`` whose result was always
        # overwritten by the generic fallback below.
        min_input_shape = {
            "x": [1, 3, 50, 50],
            "conv2d_92.tmp_0": [1, 96, 20, 20],
            "conv2d_91.tmp_0": [1, 96, 10, 10],
            "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
            "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
            "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
            "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
            "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20],
            "elementwise_add_7": [1, 56, 2, 2],
            "nearest_interp_v2_0.tmp_0": [1, 96, 2, 2]
        }
        max_input_shape = {
            "x": [1, 3, 2000, 2000],
            "conv2d_92.tmp_0": [1, 96, 400, 400],
            "conv2d_91.tmp_0": [1, 96, 200, 200],
            "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
            "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
            "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
            "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
            "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400],
            "elementwise_add_7": [1, 56, 400, 400],
            "nearest_interp_v2_0.tmp_0": [1, 96, 400, 400]
        }
        opt_input_shape = {
            "x": [1, 3, 640, 640],
            "conv2d_92.tmp_0": [1, 96, 160, 160],
            "conv2d_91.tmp_0": [1, 96, 80, 80],
            "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
            "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
            "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
            "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
            "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160],
            "elementwise_add_7": [1, 56, 40, 40],
            "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
        }
    elif mode == "rec":
        min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
        max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
        opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
    elif mode == "cls":
        # NOTE(review): the classifier reuses rec_batch_num rather than
        # cls_batch_num - kept as is, confirm whether that is intended.
        min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]}
        max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]}
        opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
    else:
        min_input_shape = {"x": [1, 3, 10, 10]}
        max_input_shape = {"x": [1, 3, 1000, 1000]}
        opt_input_shape = {"x": [1, 3, 500, 500]}
    return min_input_shape, max_input_shape, opt_input_shape


def create_predictor(args, mode, logger):
    """
    Create a paddle inference predictor for one model
    args:
        args: parsed options (see inference_args_list)
        mode(str): "det", "cls" or "rec" select the matching *_model_dir;
            any other value selects e2e_model_dir
        logger: object with an info(msg) method, used to report a missing
            model directory or file
    return(tuple):
        (predictor, input_tensor, output_tensors) where input_tensor is the
        handle of the last input name and output_tensors holds one handle
        per output name
    The process is terminated with sys.exit(0) when the model directory is
    None or inference.pdmodel / inference.pdiparams does not exist.
    """
    if mode == "det":
        model_dir = args.det_model_dir
    elif mode == 'cls':
        model_dir = args.cls_model_dir
    elif mode == 'rec':
        model_dir = args.rec_model_dir
    else:
        model_dir = args.e2e_model_dir

    if model_dir is None:
        logger.info("not find {} model file path {}".format(mode, model_dir))
        sys.exit(0)
    model_file_path = model_dir + "/inference.pdmodel"
    params_file_path = model_dir + "/inference.pdiparams"
    if not os.path.exists(model_file_path):
        logger.info("not find model file path {}".format(model_file_path))
        sys.exit(0)
    if not os.path.exists(params_file_path):
        logger.info("not find params file path {}".format(params_file_path))
        sys.exit(0)

    config = inference.Config(model_file_path, params_file_path)

    if args.use_gpu:
        config.enable_use_gpu(args.gpu_mem, 0)
        if args.use_tensorrt:
            # args.use_fp16 is not consulted: precision is always Float32
            config.enable_tensorrt_engine(
                precision_mode=inference.PrecisionType.Float32,
                max_batch_size=args.max_batch_size,
                min_subgraph_size=3)  # skip the minimum trt subgraph
        min_input_shape, max_input_shape, opt_input_shape = \
            _trt_dynamic_shape_info(args, mode, model_file_path)
        config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                          opt_input_shape)
    else:
        config.disable_gpu()
        if hasattr(args, "cpu_threads"):
            config.set_cpu_math_library_num_threads(args.cpu_threads)
        else:
            config.set_cpu_math_library_num_threads(
                10)  # default cpu threads as 10
        if args.enable_mkldnn:
            # cache 10 different shapes for mkldnn to avoid memory leak
            config.set_mkldnn_cache_capacity(10)
            config.enable_mkldnn()

    # enable memory optim
    config.enable_memory_optim()
    config.disable_glog_info()

    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
    config.switch_use_feed_fetch_ops(False)

    # create predictor
    predictor = inference.create_predictor(config)
    # only the handle of the last input name survives the loop
    for name in predictor.get_input_names():
        input_tensor = predictor.get_input_handle(name)
    output_tensors = [
        predictor.get_output_handle(output_name)
        for output_name in predictor.get_output_names()
    ]
    return predictor, input_tensor, output_tensors


J
Jethong 已提交
247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
def draw_e2e_res(dt_boxes, strs, img_path):
    """
    Draw end-to-end results (polygons and their texts) on an image file
    args:
        dt_boxes: one point array per text; each is cast to int32 and
            reshaped to (-1, 1, 2)
        strs: the text belonging to each entry of dt_boxes
        img_path(str): path handed to cv2.imread
    return:
        the image loaded by cv2.imread with polygons and texts drawn on it
    """
    canvas = cv2.imread(img_path)
    for polygon, text in zip(dt_boxes, strs):
        polygon = polygon.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(canvas, [polygon], True, color=(255, 255, 0), thickness=2)
        # the text is anchored at the first point of the polygon
        anchor = (int(polygon[0, 0, 0]), int(polygon[0, 0, 1]))
        cv2.putText(
            canvas,
            text,
            org=anchor,
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)
    return canvas


L
LDOUBLEV 已提交
263
def draw_text_det_res(dt_boxes, img_path):
    """
    Draw detected text polygons on an image file
    args:
        dt_boxes: iterable of polygons; each is converted to an int32 array
            and reshaped to (-1, 2)
        img_path(str): path handed to cv2.imread
    return:
        the image loaded by cv2.imread with every polygon outlined
    """
    src_im = cv2.imread(img_path)
    for box in dt_boxes:
        points = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(src_im, [points], True, color=(255, 255, 0), thickness=2)
    return src_im
L
LDOUBLEV 已提交
269 270


L
LDOUBLEV 已提交
271 272
def resize_img(img, input_size=600):
    """
    resize img so that its longest side equals input_size, keeping the
    aspect ratio (one scale factor is applied to both axes)
    args:
        img: anything np.array() turns into an image array
        input_size(int): target length of the longest side
    return(array):
        the resized image
    """
    img = np.array(img)
    longest_side = np.max(img.shape[0:2])
    im_scale = float(input_size) / float(longest_side)
    return cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
L
LDOUBLEV 已提交
281 282


W
WenmuZhou 已提交
283 284 285 286 287
def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/fonts/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): txts corresponding scores; every box gets score 1
            when omitted
        drop_score(float): boxes whose score is below drop_score (or NaN)
            are not drawn
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img; when txts is given, the image resized to a
        longest side of 600 with the text panel appended on its right
    """
    if scores is None:
        scores = [1] * len(boxes)
    # Work on one array copy: the caller's image is never modified and the
    # result is an array even when no box passes the score filter.
    image = np.array(image)
    for i, box in enumerate(boxes):
        score = scores[i]
        if score < drop_score or math.isnan(score):
            continue
        points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(image, [points], True, (255, 0, 0), 2)
    if txts is not None:
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path)
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image
322 323


W
WenmuZhou 已提交
324 325 326 327 328 329
def draw_ocr_box_txt(image,
                     boxes,
                     txts,
                     scores=None,
                     drop_score=0.5,
                     font_path="./doc/simfang.ttf"):
    """
    Visualize OCR results as two panels: the image blended with filled
    boxes on the left, the recognized texts drawn inside box outlines on a
    white canvas on the right
    args:
        image(Image): PIL image (its height, width and copy() are used)
        boxes(list): boxes; the first four points of each are used
        txts(list): the text of each box
        scores(list|None): scores of the boxes; no filtering when None
        drop_score(float): boxes whose score is below it are skipped
        font_path: the path of font which is used to draw text
    return(array):
        the two panels side by side, twice as wide as image
    """
    import random

    h, w = image.height, image.width
    img_left = image.copy()
    img_right = Image.new('RGB', (w, h), (255, 255, 255))

    # A private generator seeded with 0 yields the same colour sequence on
    # every call without reseeding the shared module-level generator.
    rng = random.Random(0)
    draw_left = ImageDraw.Draw(img_left)
    draw_right = ImageDraw.Draw(img_right)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (rng.randint(0, 255), rng.randint(0, 255),
                 rng.randint(0, 255))
        draw_left.polygon(box, fill=color)
        draw_right.polygon(
            [
                box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
                box[2][1], box[3][0], box[3][1]
            ],
            outline=color)
        box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
            1])**2)
        box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
            1])**2)
        if box_height > 2 * box_width:
            # tall box: write the text top-down, one character per row
            font_size = max(int(box_width * 0.9), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            cur_y = box[0][1]
            for c in txt:
                # getsize() no longer exists in Pillow 10; the bottom edge
                # reported by getbbox() equals the height getsize() gave.
                if hasattr(font, "getbbox"):
                    char_height = font.getbbox(c)[3]
                else:
                    char_height = font.getsize(c)[1]
                draw_right.text(
                    (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
                cur_y += char_height
        else:
            font_size = max(int(box_height * 0.8), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            draw_right.text(
                [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(img_right, (w, 0, w * 2, h))
    return np.array(img_show)


376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399
def str_count(s):
    """
    Measure a string in full-width (Chinese) character units: every ASCII
    letter, digit or whitespace character counts as half a unit, any other
    character as one unit.
    args:
        s(string): the input of string
    return(int):
        len(s) minus half the number of half-width characters (rounded up)
    """
    import string
    half_width_count = sum(
        1 for c in s
        if c in string.ascii_letters or c.isdigit() or c.isspace())
    return len(s) - math.ceil(half_width_count / 2)


W
WenmuZhou 已提交
400 401 402 403 404 405
def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    create new blank img and draw txt on it
    args:
        texts(list): the text will be draw
        scores(list|None): corresponding score of each txt; when None every
            text is drawn and no score is printed
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        threshold(float): texts whose score is below it (or NaN) are skipped
        font_path: the path of font which is used to draw text
    return(array):
        the text panel; when the texts need more than one panel, the panels
        are concatenated side by side
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        # uint8 rather than int8: 255 does not fit into a signed byte
        blank_img = np.full((img_h, img_w), 255, dtype=np.uint8)
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    # Characters per wrapped row. Clamped to 1 so that the wrap loop always
    # consumes text; a non-positive value would make it spin forever.
    max_chars = max(img_w // font_size - 4, 1)
    last_row = img_h // gap - 1
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        score = None if scores is None else scores[idx]
        if score is not None and (score < threshold or math.isnan(score)):
            continue
        index += 1  # running number of the texts actually drawn
        first_line = True
        while str_count(txt) >= max_chars:
            tmp = txt
            txt = tmp[:max_chars]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = '    ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[max_chars:]
            if count >= last_row:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ': ' + txt
            separator = '   '
        else:
            new_txt = "  " + txt
            separator = "  "
        if score is not None:
            new_txt += separator + '%.3f' % score
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= last_row and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)
L
LDOUBLEV 已提交
474 475


D
dyning 已提交
476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494
def base64_to_cv2(b64str):
    """
    Decode a base64 encoded image file into a cv2 image
    args:
        b64str(str|bytes): base64 text of the encoded image file
    return:
        the result of cv2.imdecode(..., cv2.IMREAD_COLOR)
    """
    import base64
    if isinstance(b64str, str):
        b64str = b64str.encode('utf8')
    raw = base64.b64decode(b64str)
    # frombuffer replaces the deprecated binary mode of np.fromstring
    buf = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)


def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    """
    Outline boxes on an image
    args:
        image(Image|array): the image to draw on; it is copied first
        boxes(list): boxes, each reshaped to (-1, 1, 2) points
        scores(list|None): score of each box; every box gets score 1 when
            omitted
        drop_score(float): boxes whose score is below it are not drawn
    return(array):
        a copy of image with the kept boxes outlined
    """
    if scores is None:
        scores = [1] * len(boxes)
    # Copy once instead of once per box; the result is an array even when
    # every box is filtered out.
    image = np.array(image)
    for (box, score) in zip(boxes, scores):
        if score < drop_score:
            continue
        points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(image, [points], True, (255, 0, 0), 2)
    return image


L
LDOUBLEV 已提交
495
if __name__ == '__main__':
    # intentionally empty: executing this module directly does no extra work
    pass