# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import threading
from PIL import Image

# Make this file's directory and the directory two levels above it importable.
# This has to run before the `tools.*` / `ppocr.*` imports below
# (presumably those packages live under that upper directory -- confirm).
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

# Set before `paddle` is imported below; keep this ordering.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import math
import time
import traceback
import paddle

import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

# Module-wide logger shared by TextRecognizer and main().
logger = get_logger()


class TextRecognizer(object):
    """Recognize text in cropped text-line images with an inference predictor.

    An instance is callable: ``recognizer(img_list)`` preprocesses the images
    in batches of ``rec_batch_num``, runs the predictor and decodes the raw
    outputs with the post-process op matching ``rec_algorithm``.
    """

    # Post-process decoder for each recognition algorithm. Any algorithm that
    # is not listed here is decoded with the default (CTC) decoder.
    _DECODER_NAMES = {
        "SRN": "SRNLabelDecode",
        "RARE": "AttnLabelDecode",
        "NRTR": "NRTRLabelDecode",
    }
    _DEFAULT_DECODER_NAME = "CTCLabelDecode"

    # Arguments passed to process_image_srn() for every SRN image.
    _SRN_NUM_HEADS = 8
    _SRN_MAX_TEXT_LENGTH = 25

    def __init__(self, args):
        """Build the post-process op and the predictor from parsed options.

        Args:
            args: option namespace. Reads ``rec_image_shape`` (comma-separated
                "C,H,W" string), ``rec_char_type``, ``rec_batch_num``,
                ``rec_algorithm``, ``rec_char_dict_path``, ``use_space_char``
                and ``benchmark``; when ``benchmark`` is truthy also
                ``precision`` and ``use_gpu``.
        """
        # Serializes __call__ across threads sharing this instance.
        self.lock = threading.RLock()
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm

        # Only the decoder name depends on the algorithm; the remaining
        # parameters are the same for every decoder.
        postprocess_params = {
            'name': self._DECODER_NAMES.get(self.rec_algorithm,
                                            self._DEFAULT_DECODER_NAME),
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'rec', logger)

        self.benchmark = args.benchmark
        if args.benchmark:
            # Imported lazily so auto_log is only required in benchmark mode.
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="rec",
                model_precision=args.precision,
                batch_size=args.rec_batch_num,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize and normalize one image for the non-SRN algorithms.

        Args:
            img: image array indexed as (height, width, channel).
            max_wh_ratio: largest width/height ratio in the current batch;
                it determines the padded width shared by the whole batch.

        Returns:
            A float32 array in channel-first layout. For "NRTR" this is a
            grayscale image resized to 100x32 and scaled as ``x / 128 - 1``.
            Otherwise the image is resized to the model height keeping its
            aspect ratio, scaled as ``(x / 255 - 0.5) / 0.5`` and
            zero-padded on the right up to the batch width.
        """
        imgC, imgH, imgW = self.rec_image_shape
        if self.rec_algorithm == 'NRTR':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            image_pil = Image.fromarray(np.uint8(img))
            # Image.ANTIALIAS was an alias of LANCZOS and was removed in
            # Pillow 10; LANCZOS is the same filter under its current name.
            img = image_pil.resize([100, 32], Image.LANCZOS)
            img = np.array(img)
            norm_img = np.expand_dims(img, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            return norm_img.astype(np.float32) / 128. - 1.

        assert imgC == img.shape[2]
        # Never go narrower than the configured shape's own aspect ratio.
        max_wh_ratio = max(max_wh_ratio, imgW / imgH)
        # Use the configured height rather than a hard-coded 32 so the padded
        # width is also correct for models whose input height is not 32.
        imgW = int(imgH * max_wh_ratio)
        h, w = img.shape[:2]
        ratio = w / float(h)
        # Keep the aspect ratio, but never exceed the padded batch width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_srn(self, img, image_shape):
        """Resize one image for SRN and place it on a zero canvas.

        The image is resized to height ``imgH`` and a width of 1x, 2x or 3x
        ``imgH`` (the first multiple of the image height that covers its
        width), or to ``imgW`` for wider images, then converted to grayscale.

        Args:
            img: BGR image array indexed as (height, width, channel).
            image_shape: (imgC, imgH, imgW) of the model input.

        Returns:
            float32 array of shape (1, imgH, imgW).
        """
        imgC, imgH, imgW = image_shape
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        for multiple in (1, 2, 3):
            if im_wid <= im_hei * multiple:
                # Clamp to the canvas width so the paste below cannot
                # overflow when imgH * multiple is larger than imgW.
                target_w = min(imgH * multiple, imgW)
                break
        else:
            target_w = imgW

        img_new = cv2.resize(img, (target_w, imgH))
        img_np = cv2.cvtColor(np.asarray(img_new), cv2.COLOR_BGR2GRAY)

        img_black = np.zeros((imgH, imgW))
        img_black[:, 0:img_np.shape[1]] = img_np
        # Add the leading channel axis: (imgH, imgW) -> (1, imgH, imgW).
        return img_black[np.newaxis, :, :].astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        """Build the auxiliary SRN inputs, which depend only on the shapes.

        Args:
            image_shape: (imgC, imgH, imgW) of the model input.
            num_heads: number of copies of each attention-bias mask along
                axis 1.
            max_text_length: side length of the square attention-bias masks.

        Returns:
            ``[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2]`` where the two position arrays are int64
            with shapes (1, imgH/8 * imgW/8, 1) and (1, max_text_length, 1),
            and the two biases are float32 with shape
            (1, num_heads, max_text_length, max_text_length) holding -1e9
            strictly above (bias1) or strictly below (bias2) the diagonal
            and zero elsewhere.
        """
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.arange(feature_dim).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.arange(max_text_length).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        # A float32 scalar keeps the product float32; multiplying by the list
        # [-1e9] silently promoted the result to float64.
        neg_inf = np.float32(-1e9)

        def tile_bias(mask):
            mask = mask.reshape([-1, 1, max_text_length, max_text_length])
            return np.tile(mask,
                           [1, num_heads, 1, 1]).astype('float32') * neg_inf

        gsrm_slf_attn_bias1 = tile_bias(np.triu(gsrm_attn_bias_data, 1))
        gsrm_slf_attn_bias2 = tile_bias(np.tril(gsrm_attn_bias_data, -1))

        # Add the leading batch axis.
        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        """Prepare all five SRN inputs for a single image.

        Returns:
            Tuple ``(norm_img, encoder_word_pos, gsrm_word_pos,
            gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)``; the image and the
            biases are float32, the position arrays int64, and every array
            has a leading axis of size 1.
        """
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
         gsrm_slf_attn_bias2] = self.srn_other_inputs(
             image_shape, num_heads, max_text_length)

        # Explicit casts pin the dtypes handed to the predictor.
        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def _run_and_fetch(self):
        """Run the predictor and copy every output tensor back to the host."""
        self.predictor.run()
        return [tensor.copy_to_cpu() for tensor in self.output_tensors]

    def __call__(self, img_list):
        """Recognize the text of every image in ``img_list``.

        Args:
            img_list: list of image arrays indexed as (height, width, ...).

        Returns:
            ``(rec_res, elapsed)``: ``rec_res[i]`` is the post-process result
            for ``img_list[i]`` (``['', 0.0]`` for any image the decoder
            returned nothing for) and ``elapsed`` is the wall-clock time in
            seconds spent on batching, inference and decoding.
        """
        # `with` releases the lock even when preprocessing, inference or
        # decoding raises; a bare acquire()/release() pair would leave it
        # held and block every other thread using this recognizer.
        with self.lock:
            img_num = len(img_list)
            # Aspect ratio (width / height) of every text image.
            width_list = [
                img.shape[1] / float(img.shape[0]) for img in img_list
            ]
            # Batching images of similar aspect ratio keeps padding small.
            indices = np.argsort(np.array(width_list))
            # One independent placeholder per image (no aliased list object).
            rec_res = [['', 0.0] for _ in range(img_num)]
            batch_num = self.rec_batch_num
            is_srn = self.rec_algorithm == "SRN"
            st = time.time()
            if self.benchmark:
                self.autolog.times.start()

            for beg_img_no in range(0, img_num, batch_num):
                end_img_no = min(img_num, beg_img_no + batch_num)
                batch_imgs = [
                    img_list[indices[ino]]
                    for ino in range(beg_img_no, end_img_no)
                ]

                # Widest aspect ratio in the batch sets the padded width.
                max_wh_ratio = 0
                for img in batch_imgs:
                    h, w = img.shape[0:2]
                    max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)

                norm_img_batch = []
                srn_aux_inputs = ()
                for img in batch_imgs:
                    if not is_srn:
                        norm_img = self.resize_norm_img(img, max_wh_ratio)
                        norm_img_batch.append(norm_img[np.newaxis, :])
                    else:
                        srn_inputs = self.process_image_srn(
                            img, self.rec_image_shape, self._SRN_NUM_HEADS,
                            self._SRN_MAX_TEXT_LENGTH)
                        norm_img_batch.append(srn_inputs[0])
                        # The auxiliary inputs depend only on the shape
                        # constants, so they are the same for every image;
                        # a single copy (leading axis 1) is fed per batch,
                        # exactly as before.
                        # NOTE(review): presumably the model broadcasts these
                        # over the image batch -- confirm for rec_batch_num > 1.
                        srn_aux_inputs = srn_inputs[1:]
                norm_img_batch = np.concatenate(norm_img_batch)
                norm_img_batch = norm_img_batch.copy()
                if self.benchmark:
                    self.autolog.times.stamp()

                if is_srn:
                    inputs = [norm_img_batch] + list(srn_aux_inputs)
                    input_names = self.predictor.get_input_names()
                    for i, input_name in enumerate(input_names):
                        input_tensor = self.predictor.get_input_handle(
                            input_name)
                        input_tensor.copy_from_cpu(inputs[i])
                else:
                    self.input_tensor.copy_from_cpu(norm_img_batch)
                outputs = self._run_and_fetch()
                if self.benchmark:
                    self.autolog.times.stamp()

                if is_srn:
                    preds = {"predict": outputs[2]}
                elif len(outputs) != 1:
                    preds = outputs
                else:
                    preds = outputs[0]

                rec_result = self.postprocess_op(preds)
                # Undo the aspect-ratio sort: write results at input positions.
                for rno, result in enumerate(rec_result):
                    rec_res[indices[beg_img_no + rno]] = result
                if self.benchmark:
                    self.autolog.times.end(stamp=True)
        return rec_res, time.time() - st


def main(args):
    """Run text recognition on every readable image under ``args.image_dir``.

    Logs one "Predicts of <file>:<result>" line per successfully loaded
    image. Files that cannot be loaded are logged and skipped. If recognition
    raises, the traceback is logged and the process exits with status 1.

    Args:
        args: parsed options. Reads ``image_dir``, ``warmup``,
            ``rec_batch_num`` and ``benchmark`` here; the remaining fields
            are consumed by ``TextRecognizer``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []

    # Warm up with two full batches of random 32x320 images.
    if args.warmup:
        img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
        for _ in range(2):
            text_recognizer([img] * int(args.rec_batch_num))

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            # Not handled by the GIF reader: fall back to OpenCV.
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)

    try:
        rec_res, _ = text_recognizer(img_list)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        # sys.exit instead of the site-provided exit(), and a non-zero status
        # so callers can tell that recognition failed.
        sys.exit(1)

    for image_file, prediction in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(image_file, prediction))
    if args.benchmark:
        text_recognizer.autolog.report()


if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run recognition.
    cli_args = utility.parse_args()
    main(cli_args)