# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle
import paddle.distributed as dist

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config
from tools.infer_kie_token_ser import SerPredictor


class ReArgsParser(ArgsParser):
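    """Extend ArgsParser with a second config channel (-c_ser/-o_ser) for the SER model."""
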
    def __init__(self):
        super(ReArgsParser, self).__init__()
        self.add_argument(
            "-c_ser", "--config_ser", help="ser configuration file to use")
        self.add_argument(
            "-o_ser",
            "--opt_ser",
            nargs='+',
            help="set ser configuration options ")

    def parse_args(self, argv=None):
        args = super(ReArgsParser, self).parse_args(argv)
        assert args.config_ser is not None, \
            "Please specify --config_ser=ser_configure_file_path."
        args.opt_ser = self._parse_opt(args.opt_ser)
        return args


def make_input(ser_inputs, ser_results):
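    """Pack SER outputs into the entity/relation arrays expected by the RE model."""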
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
    batch_size, max_seq_len = ser_inputs[0].shape[:2]
    entities = ser_inputs[8][0]
    ser_results = ser_results[0]
    assert len(entities) == len(ser_results)

    # entities
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_results, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])

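    # pack entity spans into a fixed-size array: row 0 holds the counts, rows 1..N
    # hold start offsets, end offsets and label ids; unused rows stay at -1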
    entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
    entities[0, 0] = len(start)
    entities[1:len(start) + 1, 0] = start
    entities[0, 1] = len(end)
    entities[1:len(end) + 1, 1] = end
    entities[0, 2] = len(label)
    entities[1:len(label) + 1, 2] = label

    # relations
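    # every (QUESTION, ANSWER) index pair becomes a candidate head/tail relation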
    head = []
    tail = []
    for i in range(len(label)):
        for j in range(len(label)):
            if label[i] == 1 and label[j] == 2:
                head.append(i)
                tail.append(j)

    relations = np.full([len(head) + 1, 2], fill_value=-1, dtype=np.int64)
    relations[0, 0] = len(head)
    relations[1:len(head) + 1, 0] = head
    relations[0, 1] = len(tail)
    relations[1:len(tail) + 1, 1] = tail

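    # replicate the packed arrays so every sample in the batch shares the same entities/relations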
    entities = np.expand_dims(entities, axis=0)
    entities = np.repeat(entities, batch_size, axis=0)
    relations = np.expand_dims(relations, axis=0)
    relations = np.repeat(relations, batch_size, axis=0)

    # drop ocr_info, segment_offset_id and label from the ser inputs; append the packed entities and relations
    if isinstance(ser_inputs[0], paddle.Tensor):
        entities = paddle.to_tensor(entities)
        relations = paddle.to_tensor(relations)
    ser_inputs = ser_inputs[:5] + [entities, relations]

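    # every sample in the batch reuses the same entity-index mapping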
    entity_idx_dict_batch = []
    for b in range(batch_size):
        entity_idx_dict_batch.append(entity_idx_dict)
    return ser_inputs, entity_idx_dict_batch


class SerRePredictor(object):
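    """Run the SER engine first, then feed its outputs to the relation extraction (RE) model."""
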
    def __init__(self, config, ser_config):
        global_config = config['Global']
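        # run the SER engine in the same inference mode as the RE config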
        if "infer_mode" in global_config:
            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]

        self.ser_engine = SerPredictor(ser_config)

        # init the RE model

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        self.model.eval()

    def __call__(self, data):
        ser_results, ser_inputs = self.ser_engine(data)
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
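        # when the backbone has no visual branch, drop the image input (index 4) from the RE inputs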
        if self.model.backbone.use_visual_backbone is False:
            re_input.pop(4)
        preds = self.model(re_input)
        post_result = self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch)
        return post_result


def preprocess():
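    """Load and merge the RE and SER configs, select the device and create the logger."""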
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)

    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger()

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    logger.info('{} re config {}'.format('*' * 10, '*' * 10))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
    print_dict(ser_config, logger)
    logger.info('using paddle {} and device {}'.format(paddle.__version__,
                                                       device))
    return config, ser_config, device, logger

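# Typical invocation (config and checkpoint paths below are illustrative, adjust
# them to your own setup):
#   python3 tools/infer_kie_token_ser_re.py \
#       -c configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yml \
#       -o Architecture.Backbone.checkpoints=./re_model/best_accuracy \
#          Global.infer_img=./train_data/XFUND/zh_val/image \
#       -c_ser configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yml \
#       -o_ser Architecture.Backbone.checkpoints=./ser_model/best_accuracy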

if __name__ == '__main__':
    config, ser_config, device, logger = preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

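    # infer_mode=False reads image paths and labels from the eval annotation file;
    # otherwise Global.infer_img is treated as an image path or directory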
    if config["Global"].get("infer_mode", None) is False:
        data_dir = config['Eval']['dataset']['data_dir']
        with open(config['Global']['infer_img'], "rb") as f:
            infer_imgs = f.readlines()
    else:
        infer_imgs = get_image_file_list(config['Global']['infer_img'])

    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if config["Global"].get("infer_mode", None) is False:
                data_line = info.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {'img_path': img_path, 'label': substr[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")

            result = ser_re_engine(data)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                result, ensure_ascii=False) + "\n")
            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))