# predict_vqa_token_ser.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

# Directory containing this script. Deliberately not named ``__dir__``: a
# module-level ``__dir__`` is the PEP 562 hook that ``dir(module)`` calls,
# so binding it to a string would make ``dir()`` on this module raise.
_cur_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_cur_dir)
# Make the directory two levels up importable (inserted first so it wins),
# which is where the ``tools``, ``ppocr`` and ``ppstructure`` imports below
# are expected to resolve from.
sys.path.insert(0, os.path.abspath(os.path.join(_cur_dir, '../..')))

# Set before the framework imports below; presumably Paddle reads this
# allocator flag at import/initialization time — confirm if reordering.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import json
import numpy as np
import time

import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppstructure.utility import parse_args

from paddleocr import PaddleOCR

# Module-wide logger from ppocr.utils.logging; handed to create_predictor by
# SerPredictor and used by main for progress and error messages.
logger = get_logger()


class SerPredictor(object):
    """VQA token SER predictor built on an exported 'ser' inference model.

    Text boxes come from a PaddleOCR engine inside the preprocessing
    pipeline. The resulting sample is fed to the Paddle predictor, and the
    raw output is decoded with ``VQASerTokenLayoutLMPostProcess``.
    """

    # Token sequence length used by both the padding and chunking operators.
    _MAX_SEQ_LEN = 512

    # Keys retained, in order, by the final KeepKeys operator. The model
    # inputs are taken positionally from the front of this list, and the
    # post-processing extras are looked up by their position in it.
    _KEEP_KEYS = [
        'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image',
        'labels', 'segment_offset_id', 'ocr_info', 'entities'
    ]
    _SEGMENT_OFFSET_IDX = _KEEP_KEYS.index('segment_offset_id')
    _OCR_INFO_IDX = _KEEP_KEYS.index('ocr_info')

    def __init__(self, args):
        """Build the OCR engine, pre/post-processing and the predictor.

        Args:
            args: parsed CLI namespace. Reads ``use_gpu``,
                ``vqa_algorithm``, ``ser_dict_path`` and
                ``ocr_order_method`` here. ``utility.create_predictor``
                reads further fields.
        """
        # OCR runs inside VQATokenLabelEncode to produce boxes and text.
        self.ocr_engine = PaddleOCR(
            use_angle_cls=False, show_log=False, use_gpu=args.use_gpu)

        pre_process_list = [{
            'VQATokenLabelEncode': {
                'algorithm': args.vqa_algorithm,
                'class_path': args.ser_dict_path,
                'contains_re': False,
                'ocr_engine': self.ocr_engine,
                'order_method': args.ocr_order_method,
            }
        }, {
            'VQATokenPad': {
                'max_seq_len': self._MAX_SEQ_LEN,
                'return_attention_mask': True
            }
        }, {
            'VQASerTokenChunk': {
                'max_seq_len': self._MAX_SEQ_LEN,
                'return_attention_mask': True
            }
        }, {
            'Resize': {
                'size': [224, 224]
            }
        }, {
            'NormalizeImage': {
                'std': [58.395, 57.12, 57.375],
                'mean': [123.675, 116.28, 103.53],
                'scale': '1',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                # Pass a copy so the class-level list is never shared.
                'keep_keys': list(self._KEEP_KEYS)
            }
        }]
        postprocess_params = {
            'name': 'VQASerTokenLayoutLMPostProcess',
            "class_path": args.ser_dict_path,
        }

        self.preprocess_op = create_operators(pre_process_list,
                                              {'infer_mode': True})
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'ser', logger)

    def __call__(self, img):
        """Run SER inference on one image.

        Args:
            img: image array. ``main`` passes an RGB ndarray.

        Returns:
            ``(post_result, elapse)``: the post-processing output and the
            seconds spent on inference plus post-processing (preprocessing
            is not timed). Returns ``(None, 0)`` when preprocessing yields
            no usable sample.
        """
        data = transform({'image': img}, self.preprocess_op)
        # transform() may hand back None (presumably when an operator drops
        # the sample). Check that before indexing so the failure is
        # reported instead of raising TypeError.
        if data is None or data[0] is None:
            return None, 0

        starttime = time.time()

        # The leading KeepKeys outputs map positionally onto the model
        # inputs; each one gets a batch dimension of 1.
        for idx, input_tensor in enumerate(self.input_tensor):
            input_tensor.copy_from_cpu(np.expand_dims(data[idx], axis=0))

        self.predictor.run()

        # Only the first model output is consumed by post-processing.
        preds = self.output_tensors[0].copy_to_cpu()

        post_result = self.postprocess_op(
            preds,
            segment_offset_ids=[data[self._SEGMENT_OFFSET_IDX]],
            ocr_infos=[data[self._OCR_INFO_IDX]])
        elapse = time.time() - starttime
        return post_result, elapse


def main(args):
    """Run SER inference over every image found under ``args.image_dir``.

    For each image that loads and predicts successfully:
    - writes ``<image_path>\\t<json>`` to ``<args.output>/infer.txt``;
    - saves a visualization with the same basename under ``args.output``.
    Images that fail to load or predict are logged and skipped.
    """
    image_file_list = get_image_file_list(args.image_dir)
    ser_predictor = SerPredictor(args)
    count = 0
    total_time = 0

    os.makedirs(args.output, exist_ok=True)
    with open(
            os.path.join(args.output, 'infer.txt'), mode='w',
            encoding='utf-8') as f_w:
        for image_file in image_file_list:
            img, flag = check_and_read_gif(image_file)
            if not flag:
                img = cv2.imread(image_file)
                # cv2.imread returns None for unreadable files. Guard the
                # channel flip so such files reach the error log below
                # instead of raising TypeError.
                if img is not None:
                    # OpenCV decodes to BGR; flip channels to RGB.
                    img = img[:, :, ::-1]
            if img is None:
                logger.info("error in loading image:{}".format(image_file))
                continue
            ser_res, elapse = ser_predictor(img)
            # The predictor returns (None, 0) when preprocessing fails.
            if ser_res is None:
                logger.info("error in predicting image:{}".format(image_file))
                continue
            ser_res = ser_res[0]

            res_str = '{}\t{}\n'.format(
                image_file,
                json.dumps(
                    {
                        "ocr_info": ser_res,
                    }, ensure_ascii=False))
            f_w.write(res_str)

            img_res = draw_ser_results(
                image_file,
                ser_res,
                font_path=args.vis_font_path, )

            img_save_path = os.path.join(args.output,
                                         os.path.basename(image_file))
            cv2.imwrite(img_save_path, img_res)
            logger.info("save vis result to {}".format(img_save_path))
            # The first prediction is excluded from the total (warm-up).
            if count > 0:
                total_time += elapse
            count += 1
            logger.info("Predict time of {}: {}".format(image_file, elapse))

    if count > 1:
        logger.info("Average predict time (excluding first image): {}".format(
            total_time / (count - 1)))


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run batch inference.
    main(parse_args())