# predict_rec.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import sys
import time
import traceback

# Add this file's directory and the directory two levels above it to the
# import path. This must run before the `tools.*` / `ppocr.*` imports below;
# presumably the second entry is the repository root those packages live in.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import cv2
import numpy as np

import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

# Module-level logger used by TextRecognizer and main().
logger = get_logger()

class TextRecognizer(object):
    """Recognize text in cropped text-line images, one batch at a time.

    Instances are callable: ``rec_res, elapse = recognizer(img_list)``.
    """

    def __init__(self, args):
        """Build the post-processor and the inference predictor.

        Args:
            args: Parsed options object. The attributes read here are
                ``rec_image_shape`` (a comma-separated ``"C,H,W"`` string),
                ``rec_char_type``, ``rec_batch_num``, ``rec_algorithm``,
                ``rec_char_dict_path`` and ``use_space_char``. The whole
                object is also forwarded to ``utility.create_predictor``.
        """
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'rec', logger)

    def _canvas_width(self, max_wh_ratio):
        """Return the width of the padded canvas used for one batch.

        For ``character_type == "ch"`` the width follows the widest image of
        the batch (configured height times ``max_wh_ratio``); otherwise the
        configured width from ``rec_image_shape`` is used.
        """
        _, imgH, imgW = self.rec_image_shape
        if self.character_type == "ch":
            # Scale by the configured height rather than a literal 32 so the
            # widest image still fits when rec_image_shape uses another height.
            return int(imgH * max_wh_ratio)
        return imgW

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize, normalize and right-pad one image to the batch canvas.

        Args:
            img: Image array indexed as (height, width, channel); its channel
                count must equal the first entry of ``rec_image_shape``.
            max_wh_ratio: Largest width/height ratio in the current batch.

        Returns:
            A float32 array of shape (C, H, canvas_width). The image is
            resized to height H keeping its aspect ratio (capped at the canvas
            width), scaled as ``(x / 255 - 0.5) / 0.5`` and zero-padded on the
            right.
        """
        imgC, imgH, _ = self.rec_image_shape
        assert imgC == img.shape[2], \
            "image has {} channels, expected {}".format(img.shape[2], imgC)
        imgW = self._canvas_width(max_wh_ratio)
        h, w = img.shape[:2]
        ratio = w / float(h)
        # Keep the aspect ratio, but never exceed the canvas width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        # HWC -> CHW, then scale to [0, 1] and shift/scale around 0.5.
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Run recognition on every image in ``img_list``.

        Args:
            img_list: Sequence of image arrays accepted by
                ``resize_norm_img``.

        Returns:
            A tuple ``(rec_res, elapse)``. ``rec_res`` has one entry per input
            image, in input order, holding whatever the post-processor
            produced for that image; slots that receive no result keep the
            placeholder ``['', 0.0]``. ``elapse`` is the accumulated
            wall-clock time in seconds spent on inference and
            post-processing; preprocessing is not counted.
        """
        img_num = len(img_list)
        # Aspect ratio (width / height) of every text bar.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        # Sorting can speed up the recognition process: images of similar
        # ratio end up in the same batch.
        indices = np.argsort(np.array(width_list))

        # One independent placeholder per image; `[x] * n` would alias a
        # single list across all slots.
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_indices = indices[beg_img_no:end_img_no]
            max_wh_ratio = max(width_list[idx] for idx in batch_indices)
            # np.stack adds the batch axis and returns a fresh array.
            norm_img_batch = np.stack([
                self.resize_norm_img(img_list[idx], max_wh_ratio)
                for idx in batch_indices
            ])
            starttime = time.time()
            self.input_tensor.copy_from_cpu(norm_img_batch)
            self.predictor.run()
            outputs = [
                output_tensor.copy_to_cpu()
                for output_tensor in self.output_tensors
            ]
            preds = outputs[0]
            rec_result = self.postprocess_op(preds)
            # Write results back at the images' original positions.
            for idx, result in zip(batch_indices, rec_result):
                rec_res[idx] = result
            elapse += time.time() - starttime
        return rec_res, elapse


def main(args):
    """Recognize text in every readable image under ``args.image_dir``.

    Images that cannot be loaded are logged and skipped. Each prediction and
    the total predict time are written to the module logger.

    Args:
        args: Parsed options object; ``image_dir`` is read here and the whole
            object is passed to ``TextRecognizer``.

    Raises:
        SystemExit: With status 1 when recognition raises; the traceback and
            a troubleshooting hint are logged first.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        # Try the GIF reader first; fall back to OpenCV when it reports
        # that it did not handle the file.
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception:
        # `except Exception` lets KeyboardInterrupt / SystemExit propagate.
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        # Report the failure to the calling process with a non-zero status.
        sys.exit(1)
    for image_file, result in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(image_file, result))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))

if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run main().
    cli_args = utility.parse_args()
    main(cli_args)