# infer_det.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from copy import deepcopy
import json

import os
import sys

# Make this script's directory and its parent importable, presumably so the
# local `program` module and the `ppocr` package imported further down
# resolve no matter what the current working directory is.
# The variable deliberately is NOT called `__dir__`: a module-level `__dir__`
# is the PEP 562 hook that `dir(module)` calls, so binding it to a string
# makes `dir()` on this module raise TypeError.
# abspath() also keeps the parent entry correct when __file__ has no
# directory part (e.g. `python infer_det.py`), where joining '' with '..'
# would give a CWD-relative '..'.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_THIS_DIR)
sys.path.append(os.path.abspath(os.path.join(_THIS_DIR, '..')))


def set_paddle_flags(**kwargs):
    """Export each keyword argument as an environment variable if unset.

    Values are converted with ``str``. Variables that already exist in
    ``os.environ`` are left untouched, so values exported by the user take
    precedence over the defaults passed here.
    """
    for flag_name, default in kwargs.items():
        if flag_name in os.environ:
            continue
        os.environ[flag_name] = str(default)


# NOTE(paddle-dev): All of these flags should be set before `import paddle`.
# Otherwise, they would not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)

# These imports intentionally come after the set_paddle_flags(...) call:
# Paddle flags only take effect if exported before paddle is imported.
from paddle import fluid
from ppocr.utils.utility import create_module, get_image_file_list
from ppocr.utils.utility import initial_logger
import program
from ppocr.utils.save_load import init_model
from ppocr.data.reader_main import reader_main
import cv2

logger = initial_logger()


def draw_det_res(dt_boxes, config, img, img_name):
    """Draw detected boxes on an image and save it under ``det_results/``.

    Args:
        dt_boxes: detected boxes; each one is cast to int32 and reshaped to
            ``(-1, 1, 2)`` point arrays for ``cv2.polylines``.
        config: parsed config; the output directory is
            ``<dirname(config['Global']['save_res_path'])>/det_results``.
        img: image array to draw on (``cv2.polylines`` draws into it in
            place), or ``None`` if the caller's ``cv2.imread`` failed.
        img_name: path of the source image; only its basename is used as the
            output file name.

    Nothing is written when ``dt_boxes`` is empty or ``img`` is ``None``.
    """
    if len(dt_boxes) == 0:
        return
    if img is None:
        # cv2.imread returns None for unreadable files; polylines would crash.
        logger.warning(
            "Skip drawing detections: image {} could not be read".format(
                img_name))
        return

    src_im = img
    for box in dt_boxes:
        box = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)

    # os.path.join rather than `+ "/det_results/"`: when save_res_path has no
    # directory part, dirname() is '' and concatenation would point at the
    # filesystem root ("/det_results/").
    save_det_path = os.path.join(
        os.path.dirname(config['Global']['save_res_path']), "det_results")
    if not os.path.exists(save_det_path):
        os.makedirs(save_det_path)
    save_path = os.path.join(save_det_path, os.path.basename(img_name))
    cv2.imwrite(save_path, src_im)
    logger.info("The detected Image saved in {}".format(save_path))


def main():
    """Run text-detection inference over the test reader and save results.

    Loads the config named by the module-level ``FLAGS`` (bound in the
    ``__main__`` guard), builds the detection program, and loads weights from
    ``Global.checkpoints``. Then, for every batch yielded by
    ``reader_main(config=config, mode='test')``, it:

    * writes one ``<img_name>\\t<json boxes>\\n`` line per image to
      ``Global.save_res_path``;
    * saves a visualisation through ``draw_det_res``.

    Raises:
        Exception: if ``Global.checkpoints`` is not set, or if
            ``Global.algorithm`` is neither ``'EAST'`` nor ``'DB'``.
    """
    config = program.load_config(FLAGS.config)
    program.merge_config(FLAGS.opt)
    print(config)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    program.check_gpu(use_gpu)

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    det_model = create_module(config['Architecture']['function'])(params=config)

    startup_prog = fluid.Program()
    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, startup_prog):
        with fluid.unique_name.guard():
            _, eval_outputs = det_model(mode="test")
            fetch_name_list = list(eval_outputs.keys())
            eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list]

    eval_prog = eval_prog.clone(for_test=True)
    exe.run(startup_prog)

    # load checkpoints
    checkpoints = config['Global'].get('checkpoints')
    if not checkpoints:
        # The old message formatted the falsy value ("None not exists!").
        raise Exception("Global.checkpoints is not set; a trained model "
                        "is required for inference")
    fluid.load(eval_prog, checkpoints, exe)
    logger.info("Finish initing model from {}".format(checkpoints))

    save_res_path = config['Global']['save_res_path']
    save_res_dir = os.path.dirname(save_res_path)
    # dirname() is '' for a bare file name; os.makedirs('') would raise.
    if save_res_dir and not os.path.exists(save_res_dir):
        os.makedirs(save_res_dir)

    # The post-processor depends only on the config, so build it once instead
    # of deep-copying the config and re-creating it for every batch.
    # NOTE(review): assumes the post-processor keeps no per-call state.
    global_params = config['Global']
    postprocess_params = deepcopy(config["PostProcess"])
    postprocess_params.update(global_params)
    postprocess = create_module(postprocess_params['function'])(
        params=postprocess_params)
    algorithm = global_params['algorithm']

    with open(save_res_path, "wb") as fout:
        test_reader = reader_main(config=config, mode='test')
        tackling_num = 0
        for data in test_reader():
            img_num = len(data)
            tackling_num += img_num
            logger.info("tackling_num:%d", tackling_num)
            # Each sample is indexed as (image, ratio, image_name, ...).
            img_list = [sample[0] for sample in data]
            ratio_list = [sample[1] for sample in data]
            img_name_list = [sample[2] for sample in data]

            img_batch = np.concatenate(img_list, axis=0)
            outs = exe.run(eval_prog,
                           feed={'image': img_batch},
                           fetch_list=eval_fetch_list)

            dic = _det_outputs_to_dict(algorithm, outs)
            dt_boxes_list = postprocess(dic, ratio_list)
            for ino, img_name in enumerate(img_name_list):
                dt_boxes = dt_boxes_list[ino]
                fout.write(_det_result_line(img_name, dt_boxes))
                src_img = cv2.imread(img_name)
                draw_det_res(dt_boxes, config, src_img, img_name)

    logger.info("success!")


def _det_outputs_to_dict(algorithm, outs):
    """Map fetched network outputs to the dict the post-processor consumes."""
    if algorithm == 'EAST':
        return {'f_score': outs[0], 'f_geo': outs[1]}
    if algorithm == 'DB':
        return {'maps': outs[0]}
    raise Exception("only support algorithm: ['EAST', 'DB']")


def _det_result_line(img_name, dt_boxes):
    """Build one encoded result line ``<img_name>\\t<json boxes>\\n``.

    Each box becomes ``{"transcription": "", "points": box.tolist()}``. The
    line is returned as bytes (``str.encode()``) because the results file is
    opened in binary mode.
    """
    boxes_json = [{"transcription": "", "points": box.tolist()}
                  for box in dt_boxes]
    return (img_name + "\t" + json.dumps(boxes_json) + "\n").encode()


if __name__ == '__main__':
    # main() reads the parsed command-line options through the module-level
    # FLAGS global, so FLAGS has to be bound before main() runs.
    FLAGS = program.ArgsParser().parse_args()
    main()