predict_rec.py 6.0 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
LDOUBLEV 已提交
14 15
import os
import sys
W
WenmuZhou 已提交
16

17
# Put this script's directory and the directory two levels up (presumably
# the repository root) on the import path, so the ``tools.*`` and
# ``ppocr.*`` imports resolve when this file is run directly as a script.
# NOTE(review): on Python 3.7+ a module-level ``__dir__`` is the PEP 562
# hook consulted by dir(module); binding it to a str appears to make
# dir() on this module raise TypeError. Kept for compatibility — consider
# renaming.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([
    __dir__,
    os.path.abspath(os.path.join(__dir__, '../..')),
])
L
LDOUBLEV 已提交
20 21 22 23 24

import cv2
import numpy as np
import math
import time
W
WenmuZhou 已提交
25
import traceback
26 27 28
import paddle.fluid as fluid

import tools.infer.utility as utility
W
WenmuZhou 已提交
29 30
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
31
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
L
LDOUBLEV 已提交
32

W
WenmuZhou 已提交
33 34
# Module-level logger from ppocr.utils.logging; passed to the predictor
# factory in TextRecognizer and used by main() for results and errors.
logger = get_logger()

L
LDOUBLEV 已提交
35 36 37

class TextRecognizer(object):
    """Batch text-line recognizer backed by a Paddle inference predictor.

    Images are sorted by aspect ratio, and each is normalized to
    ``args.rec_image_shape`` (channels, height, width). They are fed to the
    predictor in batches of ``args.rec_batch_num``, and the raw predictions
    are decoded with a ``CTCLabelDecode`` post-processor.
    """

    def __init__(self, args):
        """Build the predictor and CTC post-processor from parsed CLI args.

        Args:
            args: namespace produced by ``utility.parse_args()``. Reads
                ``rec_image_shape``, ``rec_char_type``, ``rec_batch_num``,
                ``rec_algorithm``, ``use_zero_copy_run``,
                ``rec_char_dict_path`` and ``use_space_char``.
        """
        # e.g. "3,32,100" -> [3, 32, 100] == (channels, height, width)
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.use_zero_copy_run = args.use_zero_copy_run
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'rec', logger)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize one image to the model height, normalize it, and right-pad.

        Args:
            img: H x W x C image array; C must equal the configured channels.
            max_wh_ratio: largest width/height ratio in the current batch.
                Only used when ``character_type == "ch"``, where it sets
                the padded output width.

        Returns:
            float32 array of shape (C, imgH, imgW). Pixel values are mapped
            via ``(x / 255 - 0.5) / 0.5`` (i.e. to [-1, 1] for 8-bit input).
            Columns beyond the resized width are zero-padded.
        """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        if self.character_type == "ch":
            # Dynamic width that fits the widest image of the batch at the
            # model height. Scale by imgH (previously a hard-coded 32, which
            # only matched 32-pixel models). Clamp to 1 so an extremely tall
            # crop cannot produce a zero-width tensor, which cv2.resize
            # rejects.
            imgW = max(1, int(imgH * max_wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        # Preserve the aspect ratio, but never exceed the target width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        # HWC -> CHW, scale by 1/255, then shift/scale around 0.5.
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Recognize text in a list of cropped text-line images.

        Args:
            img_list: list of H x W x C image arrays.

        Returns:
            ``(rec_res, elapse)``:
            - ``rec_res[i]`` is the post-processed result for ``img_list[i]``;
              input order is preserved despite the internal sorting.
            - ``elapse`` is the total seconds spent in inference and
              post-processing (pre-processing is not timed).
        """
        img_num = len(img_list)
        # Width/height ratio of every text line. Sorting by it groups
        # similarly shaped lines into the same batch, which reduces padding
        # and speeds up recognition.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        # Independent placeholders; ``[[...]] * n`` would alias one list.
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_indices = indices[beg_img_no:end_img_no]
            # Every image in the batch is padded to the widest one.
            max_wh_ratio = max(width_list[idx] for idx in batch_indices)
            norm_img_batch = np.concatenate([
                self.resize_norm_img(img_list[idx], max_wh_ratio)[np.newaxis, :]
                for idx in batch_indices
            ])
            starttime = time.time()
            if self.use_zero_copy_run:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.zero_copy_run()
            else:
                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                self.predictor.run([norm_img_batch])
            # Only the first output tensor is consumed by the decoder.
            preds = self.output_tensors[0].copy_to_cpu()
            rec_result = self.postprocess_op(preds)
            # Scatter results back to their original (unsorted) positions.
            for idx, result in zip(batch_indices, rec_result):
                rec_res[idx] = result
            elapse += time.time() - starttime
        return rec_res, elapse
L
LDOUBLEV 已提交
120 121


122
def main(args):
    """Recognize every image found under ``args.image_dir`` and log results.

    Images that cannot be loaded are logged and skipped. If recognition
    raises, the traceback and troubleshooting hints are logged, and the
    process exits with status 1 (previously ``exit()``, i.e. status 0).
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        # check_and_read_gif reports via ``flag`` whether it loaded the file
        # itself (presumably only for GIFs); otherwise fall back to OpenCV.
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception:
        # Catch ordinary errors only; let KeyboardInterrupt/SystemExit pass.
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        # Non-zero status so callers/scripts can detect the failure.
        sys.exit(1)
    for image_file, res in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(image_file, res))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
152 153 154 155


if __name__ == "__main__":
    # Script entry point: parse the shared inference CLI flags, then run.
    cli_args = utility.parse_args()
    main(cli_args)