utility.py 13.6 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
W
WenmuZhou 已提交
16
import os
W
WenmuZhou 已提交
17
import sys
L
LDOUBLEV 已提交
18 19
import cv2
import numpy as np
L
LDOUBLEV 已提交
20 21
import json
from PIL import Image, ImageDraw, ImageFont
22
import math
W
WenmuZhou 已提交
23 24
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
L
LDOUBLEV 已提交
25 26 27 28 29 30 31


def parse_args():
    """
    Build the command line parser of the inference tools and parse sys.argv.

    Every boolean switch is converted with ``str2bool``: argparse hands the
    raw string to the ``type`` callable, and ``bool("False")`` is True, so
    plain ``type=bool`` can never switch an option off from the command line.

    return(argparse.Namespace):
        the parsed command line options
    """

    def str2bool(v):
        # only "true", "t" and "1" (case-insensitive) are read as True
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=8000)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_max_side_len", type=float, default=960)

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    # str2bool instead of bool, otherwise "--det_sast_polygon False" is True
    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/simfang.ttf")

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    # NOTE(review): type=list splits a command line value into single
    # characters ("0180" -> ['0', '1', '8', '0']); only the default is a
    # real list of labels. Kept as is to not change the CLI - confirm intent.
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=30)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)

    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    return parser.parse_args()


W
WenmuZhou 已提交
90 91 92 93 94 95 96 97 98 99 100
def create_predictor(args, mode, logger):
    """
    Create a paddle predictor for one of the OCR models.
    args:
        args: parsed options; the *_model_dir, use_gpu, gpu_mem,
            enable_mkldnn and use_zero_copy_run attributes are read
        mode(str): "det" selects det_model_dir, "cls" selects cls_model_dir,
            anything else selects rec_model_dir
        logger: object whose info(msg) method receives the error messages
    return(tuple):
        the predictor, the input tensor of the last input name and the
        list of output tensors
    The process exits with status 0 when the model directory is not set or
    the "model" / "params" file is missing inside it.
    """
    model_dir = (args.det_model_dir if mode == "det" else
                 args.cls_model_dir if mode == 'cls' else args.rec_model_dir)
    if model_dir is None:
        logger.info("not find {} model file path {}".format(mode, model_dir))
        sys.exit(0)

    model_file_path = model_dir + "/model"
    params_file_path = model_dir + "/params"
    required_files = (
        (model_file_path, "not find model file path {}"),
        (params_file_path, "not find params file path {}"), )
    for file_path, message in required_files:
        if not os.path.exists(file_path):
            logger.info(message.format(file_path))
            sys.exit(0)

    config = AnalysisConfig(model_file_path, params_file_path)

    if not args.use_gpu:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(6)
        if args.enable_mkldnn:
            # cache 10 different shapes for mkldnn to avoid memory leak
            config.set_mkldnn_cache_capacity(10)
            config.enable_mkldnn()
    else:
        # presumably (memory pool size, device id) - confirm with paddle docs
        config.enable_use_gpu(args.gpu_mem, 0)

    config.disable_glog_info()

    zero_copy = args.use_zero_copy_run
    if zero_copy:
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
    # feed/fetch ops are only kept when zero copy run is off
    config.switch_use_feed_fetch_ops(not zero_copy)

    predictor = create_paddle_predictor(config)
    # every input tensor is fetched, only the last one is returned
    for input_name in predictor.get_input_names():
        input_tensor = predictor.get_input_tensor(input_name)
    output_tensors = [
        predictor.get_output_tensor(output_name)
        for output_name in predictor.get_output_names()
    ]
    return predictor, input_tensor, output_tensors


L
LDOUBLEV 已提交
143
def draw_text_det_res(dt_boxes, img_path):
    """
    Read an image from disk and outline every detected text box on it.
    args:
        dt_boxes(list): detected boxes, each one is reshaped to (-1, 2)
            int32 points
        img_path(str): path of the image, passed to cv2.imread
    return(array):
        the image with a closed polyline of color (255, 255, 0) per box
    """
    canvas = cv2.imread(img_path)
    polygons = (np.reshape(np.array(raw_box).astype(np.int32), (-1, 2))
                for raw_box in dt_boxes)
    for points in polygons:
        cv2.polylines(canvas, [points], True, color=(255, 255, 0), thickness=2)
    return canvas
L
LDOUBLEV 已提交
149 150


L
LDOUBLEV 已提交
151 152
def resize_img(img, input_size=600):
    """
    Scale an image uniformly by input_size / (its longest side).
    args:
        img(Image|array): the image, converted with np.array
        input_size(int): target length of the longest side
    return(array):
        the rescaled image
    """
    pixels = np.array(img)
    longest_side = max(pixels.shape[0:2])
    ratio = float(input_size) / float(longest_side)
    # same factor on both axes keeps the aspect ratio
    return cv2.resize(pixels, None, None, fx=ratio, fy=ratio)
L
LDOUBLEV 已提交
161 162


W
WenmuZhou 已提交
163 164 165 166 167 168
def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): scores corresponding to the boxes / txts, every box
            gets the score 1 when omitted
        drop_score(float): boxes whose score is below drop_score (or NaN)
            are not drawn
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img; when txts is given the image is resized with
        resize_img(input_size=600) and the rendered texts are appended on
        its right side
    """
    if scores is None:
        scores = [1] * len(boxes)
    # Copy once up front: the caller's image is never drawn on and the
    # result is an array even when no box passes the score filter.
    image = np.array(image)
    for i, raw_box in enumerate(boxes):
        score = scores[i]
        if score < drop_score or math.isnan(score):
            continue
        box = np.reshape(np.array(raw_box), [-1, 1, 2]).astype(np.int64)
        cv2.polylines(image, [box], True, (255, 0, 0), 2)
    if txts is None:
        return image

    img = np.array(resize_img(image, input_size=600))
    # the text panel gets the height of the resized image so that both
    # halves can be concatenated side by side
    txt_img = text_visual(
        txts,
        scores,
        img_h=img.shape[0],
        img_w=600,
        threshold=drop_score,
        font_path=font_path)
    return np.concatenate([np.array(img), np.array(txt_img)], axis=1)
202 203


204 205 206 207
def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
    """
    Draw the boxes on the image and the recognized texts on a blank canvas
    args:
        image(Image): PIL image, its height, width and copy() are used
        boxes(list): boxes, the four corners of a box are read as
            box[k][0], box[k][1] for k in 0..3
        txts(list): the text of each box
        font_path: the path of font which is used to draw text
    return(array):
        image of twice the input width: on the left the input blended with
        the filled boxes, on the right the box outlines with their texts
    """
    import random

    def side_length(p, q):
        # euclidean distance between two corners
        return math.sqrt((p[0] - q[0])**2 + (p[1] - q[1])**2)

    def char_height(font, char):
        # ImageFont.getsize is gone in Pillow >= 10; the bottom coordinate
        # of getbbox is the same value
        if hasattr(font, "getsize"):
            return font.getsize(char)[1]
        return font.getbbox(char)[3]

    h, w = image.height, image.width
    img_left = image.copy()
    img_right = Image.new('RGB', (w, h), (255, 255, 255))

    # private generator seeded with 0: the colors are the same on every
    # call and the global random state of the caller is left alone
    rng = random.Random(0)
    draw_left = ImageDraw.Draw(img_left)
    draw_right = ImageDraw.Draw(img_right)
    for (box, txt) in zip(boxes, txts):
        color = (rng.randint(0, 255), rng.randint(0, 255),
                 rng.randint(0, 255))
        draw_left.polygon(box, fill=color)
        draw_right.polygon(
            [
                box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
                box[2][1], box[3][0], box[3][1]
            ],
            outline=color)
        box_height = side_length(box[0], box[3])
        box_width = side_length(box[0], box[1])
        if box_height > 2 * box_width:
            # tall box: write the text top-down, one character per row
            font_size = max(int(box_width * 0.9), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            cur_y = box[0][1]
            for c in txt:
                draw_right.text(
                    (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
                cur_y += char_height(font, c)
        else:
            font_size = max(int(box_height * 0.8), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            draw_right.text(
                [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(img_right, (w, 0, w * 2, h))
    return np.array(img_show)


251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
def str_count(s):
    """
    Measure the display length of a string in full-width characters.
    ASCII letters, digits and whitespace count as half a character (the
    number of them is halved and rounded up), every other character counts
    as one.

    args:
        s(string): the input of string
    return(int):
        len(s) minus half (rounded up) of the half-width characters
    """
    import string
    half_width = sum(1 for c in s
                     if c in string.ascii_letters or c.isdigit() or
                     c.isspace())
    return len(s) - math.ceil(half_width / 2)


W
WenmuZhou 已提交
276 277 278 279 280 281
def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    create new blank img and draw txt on it
    args:
        texts(list): the text will be draw
        scores(list|None): corresponding score of each txt; with None no
            text is filtered out and no score is printed
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        threshold(float): texts whose score is below it (or NaN) are skipped
        font_path: the path of font which is used to draw text
    return(array):
        the rendered text image; when the texts do not fit into one
        img_h x img_w page, further pages are appended to the right
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        # uint8 holds 255; int8 * 255 overflows (an error under the NumPy 2
        # promotion rules). The rendered page is the same white canvas.
        blank_img = np.full((img_h, img_w), 255, dtype=np.uint8)
        # black separator column on the right border
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    # characters per row; at least 1, otherwise the wrap loop below would
    # never consume the text on very narrow images
    max_chars = max(img_w // font_size - 4, 1)
    # a new page is started once this row has been written
    last_row = img_h // gap - 1
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        score = None if scores is None else scores[idx]
        if score is not None and (score < threshold or math.isnan(score)):
            continue
        # index numbers only the texts that are really drawn
        index += 1
        first_line = True
        while str_count(txt) >= max_chars:
            head, txt = txt[:max_chars], txt[max_chars:]
            if first_line:
                new_txt = str(index) + ': ' + head
                first_line = False
            else:
                new_txt = '    ' + head
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            if count >= last_row:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ': ' + txt
            if score is not None:
                new_txt += '   ' + '%.3f' % (score)
        else:
            new_txt = "  " + txt
            if score is not None:
                new_txt += "  " + '%.3f' % (score)
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= last_row and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        return np.array(txt_img_list[0])
    return np.concatenate(txt_img_list, axis=1)
L
LDOUBLEV 已提交
351 352


D
dyning 已提交
353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371
def base64_to_cv2(b64str):
    """
    Decode a base64 encoded image file into an image array.
    args:
        b64str(str|bytes): base64 text of the encoded image file
    return(array):
        the result of cv2.imdecode with cv2.IMREAD_COLOR
    """
    import base64
    if isinstance(b64str, str):
        b64str = b64str.encode('utf8')
    raw = base64.b64decode(b64str)
    # np.fromstring on binary data is deprecated (removed in NumPy 2);
    # frombuffer reads the same bytes without copying them
    buf = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)


def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    """
    Outline the boxes on a copy of the image.
    args:
        image(Image|array): the image, copied with np.array before drawing
        boxes(list): boxes, each one is reshaped to (-1, 1, 2) points
        scores(list): score of each box, every box gets 1 when omitted
        drop_score(float): boxes with a score below it are not drawn
    return:
        the image with the boxes drawn as closed polylines of color
        (255, 0, 0); the input itself when no box was drawn
    """
    scores = [1] * len(boxes) if scores is None else scores
    for raw_box, box_score in zip(boxes, scores):
        if not box_score < drop_score:
            points = np.array(raw_box).reshape([-1, 1, 2]).astype(np.int64)
            image = cv2.polylines(
                np.array(image), [points], True, (255, 0, 0), 2)
    return image


L
LDOUBLEV 已提交
372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388
if __name__ == '__main__':
    # Demo: visualize the annotation on the first line of ./doc/predict.txt
    # and write the result into the current directory.
    test_img = "./doc/test_v2"
    predict_txt = "./doc/predict.txt"
    # only the first line is used; the context manager closes the file
    with open(predict_txt, 'r') as f:
        first_line = f.readline()
    # line format: <image path>\t<json list of annotations>
    img_path, anno = first_line.strip().split('\t')
    img_name = os.path.basename(img_path)
    img_path = os.path.join(test_img, img_name)
    # draw_ocr concatenates with a 3 channel text panel, so force RGB
    image = Image.open(img_path).convert("RGB")

    data = json.loads(anno)
    boxes, txts, scores = [], [], []
    for dic in data:
        boxes.append(dic['points'])
        txts.append(dic['transcription'])
        scores.append(round(dic['scores'], 3))

    new_img = draw_ocr(image, boxes, txts, scores)

    # the result is RGB while cv2.imwrite expects BGR: reverse the channels
    cv2.imwrite(img_name, new_img[:, :, ::-1])