predict_rec.py 5.8 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
LDOUBLEV 已提交
14 15
import os
import sys
16
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Append this script's directory and then the directory two levels above it
# to the import path, so the `tools.*` and `ppocr.*` imports below resolve
# when the file is executed directly as a script.
sys.path.extend([__dir__, os.path.abspath(os.path.join(__dir__, '../..'))])

import cv2
import copy
import numpy as np
import math
import time
25 26 27 28

import paddle.fluid as fluid

import tools.infer.utility as utility
Z
zhoujun 已提交
29 30
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
31
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
L
LDOUBLEV 已提交
32 33 34 35


class TextRecognizer(object):
    """Batched text-line recognizer backed by a Paddle inference predictor.

    Input images are sorted by aspect ratio, grouped into batches of
    ``rec_batch_num``, resized/normalized/padded to a common size per batch,
    run through the predictor, and decoded by a ``CTCLabelDecode``
    post-processor. Results are returned in the caller's original order.
    """

    def __init__(self, args):
        """Build the recognizer from parsed command-line ``args``.

        Reads ``rec_image_shape`` (a comma-separated "C,H,W" string),
        ``rec_char_type``, ``rec_batch_num``, ``rec_algorithm``,
        ``use_zero_copy_run``, ``rec_char_dict_path`` and ``use_space_char``.
        """
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.use_zero_copy_run = args.use_zero_copy_run
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        # Fetch the logger here rather than relying on the module-level
        # `logger`, which is only bound when this file runs as a script;
        # otherwise importing and constructing this class raises NameError.
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'rec', get_logger())

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize, normalize and right-pad one image to the model input size.

        Args:
            img: H x W x C image array; C must equal the channel count in
                ``rec_image_shape``.
            max_wh_ratio: largest width/height ratio in the current batch.
                In "ch" mode it determines the padded target width so that
                every image in the batch ends up with the same width.

        Returns:
            float32 array of shape (C, H, target_W). The resized image occupies
            the leftmost ``resized_w`` columns and the rest is zero padding.
            Pixel values are mapped by (x / 255 - 0.5) / 0.5, i.e. to [-1, 1]
            for 8-bit input.
        """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        if self.character_type == "ch":
            # Scale with the configured height (was a hard-coded 32) so models
            # with a different input height keep the images' aspect ratio.
            imgW = int(imgH * max_wh_ratio)
        h, w = img.shape[:2]
        ratio = w / float(h)
        # Keep the aspect ratio at height imgH, but never exceed the target width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Recognize the text in every image of ``img_list``.

        Returns:
            A tuple ``(rec_res, elapse)``. ``rec_res[i]`` is the post-processor
            result for ``img_list[i]`` (placeholder ``['', 0.0]`` if no result
            was produced for it). ``elapse`` is the total number of seconds
            spent in prediction and decoding, summed over all batches
            (preprocessing excluded); it is 0.0 for an empty ``img_list``.
        """
        img_num = len(img_list)
        # Sort by aspect ratio so images sharing a batch have similar widths
        # and need less padding; `indices` maps sorted position -> input index.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        # One independent placeholder per image, filled in input order below.
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        elapse = 0.0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_indices = indices[beg_img_no:end_img_no]

            # All images in a batch are padded to the widest one's ratio.
            max_wh_ratio = 0
            for idx in batch_indices:
                h, w = img_list[idx].shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)
            norm_img_batch = np.concatenate([
                self.resize_norm_img(img_list[idx], max_wh_ratio)[np.newaxis, :]
                for idx in batch_indices
            ])

            starttime = time.time()
            if self.use_zero_copy_run:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.zero_copy_run()
            else:
                # NOTE(review): outputs are still read from the zero-copy
                # output tensors below; confirm predictor.run() populates
                # them in this mode.
                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                self.predictor.run([norm_img_batch])
            outputs = [tensor.copy_to_cpu() for tensor in self.output_tensors]
            preds = outputs[0]
            rec_result = self.postprocess_op(preds)
            # Scatter this batch's results back to the caller's original
            # order; previously each batch overwrote `rec_res` wholesale.
            for idx, res in zip(batch_indices, rec_result):
                rec_res[idx] = res
            elapse += time.time() - starttime
        return rec_res, elapse


118
def main(args):
    """Recognize every image found under ``args.image_dir`` and print results.

    Images that cannot be read are logged and skipped. If no image can be
    read, a message is logged and the function returns. If recognition
    raises, the error plus a troubleshooting hint are reported and the
    process exits with status 1.
    """
    # Obtain the logger locally: the module-level `logger` is only bound
    # under `if __name__ == "__main__"`, so calling main() from another
    # module would otherwise raise NameError.
    logger = get_logger()
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    if not img_list:
        # Avoid running the recognizer on nothing, which would otherwise be
        # reported through the misleading TPS/FAQ error path below.
        logger.info("no valid image found in {}".format(args.image_dir))
        return
    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception as e:
        print(e)
        logger.error(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        # sys.exit works without the `site` module and reports failure to
        # the shell; the builtin exit() returned status 0.
        sys.exit(1)
    for image_file, res in zip(valid_image_file_list, rec_res):
        print("Predicts of %s:%s" % (image_file, res))
    print("Total predict time for %d images, cost: %.3f" %
          (len(img_list), predict_time))


if __name__ == "__main__":
    # Create the module-level `logger` first (kept as a global for any code
    # in this file that references it), then parse the CLI and run.
    logger = get_logger()
    cli_args = utility.parse_args()
    main(cli_args)