# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import json
import numpy as np
import time

import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, check_and_read
from ppstructure.utility import parse_args

from paddleocr import PaddleOCR

logger = get_logger()


class SerPredictor(object):
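    """Semantic entity recognition (SER) predictor for KIE.

    Uses a PaddleOCR engine for text detection/recognition and an exported
    SER model (e.g. LayoutXLM) for token classification.
    """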
    def __init__(self, args):
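        # Text detection + recognition engine; its output is what gets labeled by SER.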
        self.ocr_engine = PaddleOCR(
            use_angle_cls=args.use_angle_cls,
            det_model_dir=args.det_model_dir,
            rec_model_dir=args.rec_model_dir,
            show_log=False,
            use_gpu=args.use_gpu)

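        # Preprocessing pipeline: run OCR and encode the results into token ids
        # and boxes, pad / chunk to max_seq_len, resize and normalize the image,
        # then keep only the fields the SER model needs.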
        pre_process_list = [{
            'VQATokenLabelEncode': {
                'algorithm': args.kie_algorithm,
                'class_path': args.ser_dict_path,
                'contains_re': False,
littletomatodonkey's avatar
littletomatodonkey 已提交
55 56
                'ocr_engine': self.ocr_engine,
                'order_method': args.ocr_order_method,
            }
        }, {
            'VQATokenPad': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'VQASerTokenChunk': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'Resize': {
                'size': [224, 224]
            }
        }, {
            'NormalizeImage': {
                'std': [58.395, 57.12, 57.375],
                'mean': [123.675, 116.28, 103.53],
                'scale': '1',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': [
                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
            }
        }]
        postprocess_params = {
            'name': 'VQASerTokenLayoutLMPostProcess',
            "class_path": args.ser_dict_path,
        }

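        # Build the preprocessing operators, the SER post-process, and the
        # exported 'ser' inference predictor.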
        self.preprocess_op = create_operators(pre_process_list,
                                              {'infer_mode': True})
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'ser', logger)

    def __call__(self, img):
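        """Run SER on a single image and return (post_result, data, elapse)."""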
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        if data[0] is None:
            # Preprocessing produced nothing; keep the (result, data, elapse) signature.
            return None, None, 0
        starttime = time.time()

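        # Add a batch dimension to every ndarray field; wrap the rest in lists.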
        for idx in range(len(data)):
            if isinstance(data[idx], np.ndarray):
                data[idx] = np.expand_dims(data[idx], axis=0)
            else:
                data[idx] = [data[idx]]

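        # Feed each preprocessed field into the corresponding predictor input tensor.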
        for idx in range(len(self.input_tensor)):
            self.input_tensor[idx].copy_from_cpu(data[idx])

        self.predictor.run()

        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)
        preds = outputs[0]

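        # Decode token predictions back to entity labels for each OCR segment.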
        post_result = self.postprocess_op(
            preds, segment_offset_ids=data[6], ocr_infos=data[7])
        elapse = time.time() - starttime
        return post_result, data, elapse


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    ser_predictor = SerPredictor(args)
    count = 0
    total_time = 0

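    # One result line per image is written to <output>/infer.txt in the form
    # "<image_path>\t<json-encoded ocr_info>".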
    os.makedirs(args.output, exist_ok=True)
    with open(
            os.path.join(args.output, 'infer.txt'), mode='w',
            encoding='utf-8') as f_w:
        for image_file in image_file_list:
            img, flag, _ = check_and_read(image_file)
            if not flag:
                img = cv2.imread(image_file)
                if img is not None:
                    # cv2.imread loads BGR; convert to RGB for the OCR engine.
                    img = img[:, :, ::-1]
            if img is None:
                logger.info("error in loading image: {}".format(image_file))
                continue
            ser_res, _, elapse = ser_predictor(img)
            ser_res = ser_res[0]

            res_str = '{}\t{}\n'.format(
                image_file,
                json.dumps(
                    {
                        "ocr_info": ser_res,
                    }, ensure_ascii=False))
            f_w.write(res_str)

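            # Draw the SER result on the image and save the visualization under args.output.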
            img_res = draw_ser_results(
                image_file,
                ser_res,
                font_path=args.vis_font_path, )

            img_save_path = os.path.join(args.output,
                                         os.path.basename(image_file))
            cv2.imwrite(img_save_path, img_res)
            logger.info("save vis result to {}".format(img_save_path))
            if count > 0:
                total_time += elapse
            count += 1
            logger.info("Predict time of {}: {}".format(image_file, elapse))


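# Example invocation (a sketch; model/dict/font paths are illustrative and the
# flags below come from ppstructure/utility.parse_args):
#
#   python3 ppstructure/kie/predict_kie_token_ser.py \
#       --kie_algorithm=LayoutXLM \
#       --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
#       --image_dir=./ppstructure/docs/kie/input/zh_val_42.jpg \
#       --ser_dict_path=./train_data/XFUND/class_list_xfun.txt \
#       --vis_font_path=./doc/fonts/simfang.ttf \
#       --ocr_order_method="tb-yx"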
if __name__ == "__main__":
    main(parse_args())