predict_rec.py 5.8 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
LDOUBLEV 已提交
14 15
import os
import sys
W
WenmuZhou 已提交
16

17
__dir__ = os.path.dirname(os.path.abspath(__file__))
L
LDOUBLEV 已提交
18
sys.path.append(__dir__)
19
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
L
LDOUBLEV 已提交
20 21 22 23 24

import cv2
import numpy as np
import math
import time
W
WenmuZhou 已提交
25
import traceback
26 27 28
import paddle.fluid as fluid

import tools.infer.utility as utility
W
WenmuZhou 已提交
29 30
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
31
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
L
LDOUBLEV 已提交
32

W
WenmuZhou 已提交
33 34
logger = get_logger()

L
LDOUBLEV 已提交
35 36 37

class TextRecognizer(object):
    """Batch text-line recognizer backed by a Paddle inference predictor.

    Input images are sorted by aspect ratio and grouped into batches of
    ``rec_batch_num``. Each image is resized, normalized and right-padded to
    a common width per batch, the batch is run through the predictor, and
    the first output is decoded with the ``CTCLabelDecode`` post-processor.
    """

    def __init__(self, args):
        """Build the recognizer from parsed command-line ``args``.

        Reads ``rec_image_shape`` (a "C,H,W" string), ``rec_char_type``,
        ``rec_batch_num``, ``rec_algorithm``, ``rec_char_dict_path`` and
        ``use_space_char`` from ``args``, builds the CTC post-processor and
        creates the 'rec' predictor via ``utility.create_predictor``.
        """
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'rec', logger)

    def _target_width(self, max_wh_ratio):
        """Return the padded canvas width for a batch.

        For ``character_type == "ch"`` the width follows the widest image in
        the batch. It is scaled by the configured model height (not a
        hard-coded 32), so models with other input heights get a correctly
        proportioned canvas. Otherwise the configured width from
        ``rec_image_shape`` is used.
        """
        _, imgH, imgW = self.rec_image_shape
        if self.character_type == "ch":
            # Never return 0: a zero width would make cv2.resize fail for a
            # batch of extremely narrow images.
            imgW = max(1, int(imgH * max_wh_ratio))
        return imgW

    @staticmethod
    def _resized_width(h, w, imgH, imgW):
        """Width of an ``h`` x ``w`` image scaled to height ``imgH``, capped at ``imgW``."""
        ratio = w / float(h)
        return min(imgW, int(math.ceil(imgH * ratio)))

    @staticmethod
    def _normalize_and_pad(resized_image, imgC, imgH, imgW):
        """Convert an HWC image to a zero-padded CHW float32 array.

        Pixel values are mapped with ``(x / 255 - 0.5) / 0.5``. This assumes
        inputs in [0, 255], which gives outputs in [-1, 1]. The image is
        placed at the left of a ``(imgC, imgH, imgW)`` zero canvas.
        """
        if resized_image.ndim == 2:
            # cv2.resize drops a trailing channel axis of size 1; restore it
            # so single-channel models work.
            resized_image = resized_image[:, :, np.newaxis]
        resized_w = resized_image.shape[1]
        chw = resized_image.astype('float32').transpose((2, 0, 1)) / 255
        chw = (chw - 0.5) / 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = chw
        return padding_im

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize, normalize and pad one HWC image for recognition.

        Args:
            img: image whose channel count must equal the configured ``C``.
            max_wh_ratio: largest width/height ratio in the current batch.

        Returns:
            float32 array of shape ``(C, H, batch_width)``.

        Raises:
            AssertionError: if the channel count does not match.
        """
        imgC, imgH, _ = self.rec_image_shape
        assert imgC == img.shape[2]
        imgW = self._target_width(max_wh_ratio)
        h, w = img.shape[:2]
        resized_w = self._resized_width(h, w, imgH, imgW)
        resized_image = cv2.resize(img, (resized_w, imgH))
        return self._normalize_and_pad(resized_image, imgC, imgH, imgW)

    def __call__(self, img_list):
        """Recognize text in every image of ``img_list``.

        Returns:
            ``(rec_res, elapse)``. ``rec_res[i]`` is the post-processor's
            result for ``img_list[i]``; it stays ``['', 0.0]`` if no result
            is produced. ``elapse`` is the summed wall time of predictor
            runs plus decoding, excluding preprocessing.
        """
        img_num = len(img_list)
        # Width/height ratio of every text line.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        # Sorting by ratio groups similarly shaped lines into a batch, which
        # keeps padding (and wasted compute) small.
        indices = np.argsort(np.array(width_list))

        # Independent default entries; ``[[...]] * n`` would alias one list.
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_indices = indices[beg_img_no:end_img_no]
            max_wh_ratio = max(width_list[i] for i in batch_indices)
            norm_img_batch = np.concatenate([
                self.resize_norm_img(img_list[i], max_wh_ratio)[np.newaxis, :]
                for i in batch_indices
            ])

            starttime = time.time()
            self.input_tensor.copy_from_cpu(norm_img_batch)
            self.predictor.run()
            outputs = [tensor.copy_to_cpu() for tensor in self.output_tensors]
            preds = outputs[0]
            rec_result = self.postprocess_op(preds)
            # Scatter results back to the caller's original image order.
            for rno, res in enumerate(rec_result):
                rec_res[batch_indices[rno]] = res
            elapse += time.time() - starttime
        return rec_res, elapse
L
LDOUBLEV 已提交
115 116


117
def main(args):
    """Recognize every image found under ``args.image_dir`` and log results.

    Unreadable images are logged and skipped. If recognition raises, the
    traceback and a hint about TPS models are logged and the process exits
    with a non-zero status.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        # check_and_read_gif presumably decodes GIF files (flag=True);
        # everything else falls back to cv2.imread.
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
        # propagate normally.
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        # sys.exit does not depend on the site module, and a non-zero status
        # tells callers the run failed.
        sys.exit(1)
    for image_file, res in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(image_file, res))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
147 148 149 150


if __name__ == "__main__":
    # Script entry point: parse CLI options, then run recognition.
    cli_args = utility.parse_args()
    main(cli_args)