# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps
import tools.program as program


def to_tensor(data):
    """Wrap every transformed field as a batch containing a single sample.

    Args:
        data: iterable of per-field values produced by the transform
            pipeline (one entry per kept key).

    Returns:
        A list with one entry per element of ``data``, in the same order.
        Entries that are ``np.ndarray``, ``paddle.Tensor`` or Python numbers
        become ``paddle.to_tensor([value])``. Every other entry (e.g. strings
        or nested Python structures) is returned as the one-element list
        ``[value]``.
    """
    import numbers

    # Each position is visited exactly once, so no per-index grouping or
    # membership bookkeeping is needed: convert while iterating.
    return [
        paddle.to_tensor([value])
        if isinstance(value, (np.ndarray, paddle.Tensor, numbers.Number))
        else [value]
        for value in data
    ]


class SerPredictor(object):
    """Token-level semantic entity recognition (SER) predictor.

    Builds the SER model and its post-processor from the config, creates a
    ``PaddleOCR`` engine that is handed to the label transform ops, and
    exposes ``__call__`` to run read -> transform -> model -> post-process
    on a single sample.
    """

    # Keys retained, in this order, by the ``KeepKeys`` transform. The batch
    # returned by the pipeline is positional, so the indices below depend on
    # this exact ordering.
    _KEEP_KEYS = ('input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                  'image', 'labels', 'segment_offset_id', 'ocr_info',
                  'entities')
    # Positions of the entries the post-processor consumes.
    _SEGMENT_OFFSET_IDX = _KEEP_KEYS.index('segment_offset_id')
    _OCR_INFO_IDX = _KEEP_KEYS.index('ocr_info')

    def __init__(self, config):
        """Build the model, post-processor, OCR engine and data pipeline.

        Args:
            config: parsed config dict. Reads the ``Global``,
                ``Architecture``, ``PostProcess`` and
                ``Eval.dataset.transforms`` sections. The transform op dicts
                and ``Global['infer_mode']`` are updated in place.
        """
        global_config = config['Global']
        self.algorithm = config['Architecture']["algorithm"]

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        from paddleocr import PaddleOCR

        # OCR engine injected into the label transform ops below.
        self.ocr_engine = PaddleOCR(
            use_angle_cls=False,
            show_log=False,
            rec_model_dir=global_config.get("kie_rec_model_dir", None),
            det_model_dir=global_config.get("kie_det_model_dir", None),
            use_gpu=global_config['use_gpu'])

        # create data ops
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                op[op_name]['ocr_engine'] = self.ocr_engine
            elif op_name == 'KeepKeys':
                # Fresh list per op so the class constant is never shared
                # with (or mutated through) the config.
                op[op_name]['keep_keys'] = list(self._KEEP_KEYS)
            transforms.append(op)

        # Default to inference mode when the flag is missing or None.
        if global_config.get("infer_mode", None) is None:
            global_config['infer_mode'] = True
        self.ops = create_operators(transforms, global_config)
        self.model.eval()

    def __call__(self, data):
        """Run SER on one sample.

        Args:
            data: dict with at least ``img_path``. The raw image bytes are
                stored into ``data["image"]`` (the dict is mutated).

        Returns:
            Tuple ``(post_result, batch)``: the post-processor output and the
            tensorized batch list that was fed to the model.
        """
        with open(data["img_path"], 'rb') as f:
            data["image"] = f.read()

        batch = transform(data, self.ops)
        batch = to_tensor(batch)
        preds = self.model(batch)

        post_result = self.post_process_class(
            preds,
            segment_offset_ids=batch[self._SEGMENT_OFFSET_IDX],
            ocr_infos=batch[self._OCR_INFO_IDX])
        return post_result, batch


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    save_res_path = config['Global']['save_res_path']
    os.makedirs(save_res_path, exist_ok=True)

    ser_engine = SerPredictor(config)

    # When infer_mode is explicitly False, ``infer_img`` is a label file of
    # "<image path relative to Eval.dataset.data_dir>\t<label>" lines.
    # Otherwise it is an image file or directory of images.
    from_label_file = config["Global"].get("infer_mode", None) is False

    if from_label_file:
        data_dir = config['Eval']['dataset']['data_dir']
        with open(config['Global']['infer_img'], "r", encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) which have no
            # tab-separated label field.
            infer_imgs = [line for line in f if line.strip()]
    else:
        infer_imgs = get_image_file_list(config['Global']['infer_img'])

    with open(
            os.path.join(save_res_path, "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if from_label_file:
                substr = info.strip("\r\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {'img_path': img_path, 'label': substr[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            save_img_path = os.path.join(
                save_res_path,
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")

            result, _ = ser_engine(data)
            # A single sample is fed per call; keep its (only) result.
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ocr_info": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_ser_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            # 1-based progress so the last image reports [N/N].
            logger.info("process: [{}/{}], save result to {}".format(
                idx + 1, len(infer_imgs), save_img_path))