infer_det.py 5.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from copy import deepcopy
import json

L
LDOUBLEV 已提交
23 24
import os
import sys
25
__dir__ = os.path.dirname(os.path.abspath(__file__))
L
LDOUBLEV 已提交
26
sys.path.append(__dir__)
27
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
28 29 30 31 32 33 34 35 36 37


def set_paddle_flags(**kwargs):
    """Export default flag values as environment variables.

    Every keyword argument is written to ``os.environ`` under the keyword's
    name, with the value converted through ``str``. A variable that is
    already present in the environment is left untouched, so a value
    exported by the user always wins over the default passed here.
    """
    for flag_name in kwargs:
        if flag_name in os.environ:
            # Respect whatever the environment already provides.
            continue
        os.environ[flag_name] = str(kwargs[flag_name])


# NOTE(paddle-dev): All of these flags must be set before
# `import paddle`; otherwise they will not take effect.
# Runs at import time, ahead of the `from paddle import fluid` line below.
# set_paddle_flags() only fills in variables that are not already present in
# the environment, so a value exported by the user is kept.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)

from paddle import fluid
L
LDOUBLEV 已提交
44
from ppocr.utils.utility import create_module, get_image_file_list
45 46 47
import program
from ppocr.utils.save_load import init_model
from ppocr.data.reader_main import reader_main
L
LDOUBLEV 已提交
48
import cv2
49 50 51 52 53

from ppocr.utils.utility import initial_logger
logger = initial_logger()


L
LDOUBLEV 已提交
54
def draw_det_res(dt_boxes, config, img, img_name):
    """Draw the detected boxes on an image and save the visualisation.

    The boxes are drawn as closed polylines directly on ``img`` (the array
    is modified in place), and the result is written to a ``det_results``
    directory located next to ``config['Global']['save_res_path']``, under
    the base name of ``img_name``.

    Args:
        dt_boxes: sequence of box arrays; each one is cast to int32 and
            reshaped to (-1, 1, 2) before drawing. Nothing is drawn or
            saved when the sequence is empty.
        config: configuration dict; only ``config['Global']['save_res_path']``
            is read.
        img: image array to draw on, e.g. the result of ``cv2.imread``.
            ``None`` (unreadable image) is skipped with a warning.
        img_name: path of the source image; only its base name is used.
    """
    if len(dt_boxes) == 0:
        return
    if img is None:
        # cv2.imread returns None instead of raising when it cannot read a
        # file; drawing on None would crash inside cv2.polylines.
        logger.warning("Can not read image {}, skip drawing".format(img_name))
        return

    # Kept local, as in the original, so the empty-result path does not
    # require OpenCV.
    import cv2

    src_im = img
    for box in dt_boxes:
        polygon = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(
            src_im, [polygon], True, color=(255, 255, 0), thickness=2)

    # os.path.join keeps the directory relative when save_res_path has no
    # directory part; plain concatenation with "/det_results/" would point
    # at the filesystem root in that case.
    save_det_path = os.path.join(
        os.path.dirname(config['Global']['save_res_path']), "det_results")
    if not os.path.exists(save_det_path):
        os.makedirs(save_det_path)

    save_path = os.path.join(save_det_path, os.path.basename(img_name))
    # cv2.imwrite reports failure through its return value.
    if cv2.imwrite(save_path, src_im):
        logger.info("The detected Image saved in {}".format(save_path))
    else:
        logger.warning(
            "Failed to save the detected image to {}".format(save_path))


def main():
    """Run detection inference over the test set and dump the results.

    Reads the config file and command-line overrides from the module-level
    ``FLAGS`` (populated in the ``__main__`` guard), builds the detection
    network in test mode, restores the weights from ``Global.checkpoints``
    and then, for every batch yielded by the test reader:

    * runs the network and post-processes the raw outputs into boxes;
    * appends one tab-separated line per image to ``Global.save_res_path``
      (image path, then a JSON list of
      ``{"transcription": "", "points": ...}`` entries);
    * re-reads the image from disk and passes it to ``draw_det_res``.

    Raises:
        Exception: if ``Global.checkpoints`` is missing or empty, or (while
            processing a batch) if ``Global.algorithm`` is not one of
            'EAST', 'DB' or 'SAST'.
    """
    config = program.load_config(FLAGS.config)
    program.merge_config(FLAGS.opt)
    logger.info(config)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    program.check_gpu(use_gpu)

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    det_model = create_module(config['Architecture']['function'])(params=config)

    startup_prog = fluid.Program()
    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, startup_prog):
        with fluid.unique_name.guard():
            _, eval_outputs = det_model(mode="test")
            fetch_name_list = list(eval_outputs.keys())
            eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list]

    eval_prog = eval_prog.clone(for_test=True)
    exe.run(startup_prog)

    # load checkpoints
    checkpoints = config['Global'].get('checkpoints')
    if not checkpoints:
        raise Exception("{} not exists!".format(checkpoints))
    fluid.load(eval_prog, checkpoints, exe)
    logger.info("Finish initing model from {}".format(checkpoints))

    save_res_path = config['Global']['save_res_path']
    save_res_dir = os.path.dirname(save_res_path)
    # dirname is '' when save_res_path is a bare file name, and
    # os.makedirs('') raises; the current directory needs no creation.
    if save_res_dir and not os.path.exists(save_res_dir):
        os.makedirs(save_res_dir)

    algorithm = config['Global']['algorithm']

    with open(save_res_path, "wb") as fout:
        test_reader = reader_main(config=config, mode='test')

        # The post-process module depends only on the config, so it is
        # built once here instead of once per batch.
        postprocess_params = deepcopy(config["PostProcess"])
        postprocess_params.update(config['Global'])
        postprocess = create_module(postprocess_params['function'])(
            params=postprocess_params)

        tackling_num = 0
        for data in test_reader():
            img_num = len(data)
            tackling_num = tackling_num + img_num
            logger.info("tackling_num:%d", tackling_num)

            # Each sample is indexed as (image, ratio, image name).
            img_list = [sample[0] for sample in data]
            ratio_list = [sample[1] for sample in data]
            img_name_list = [sample[2] for sample in data]

            img_list = np.concatenate(img_list, axis=0)
            outs = exe.run(eval_prog,
                           feed={'image': img_list},
                           fetch_list=eval_fetch_list)

            # Map the positional network outputs to the names expected by
            # the post-process module of each algorithm.
            if algorithm == 'EAST':
                dic = {'f_score': outs[0], 'f_geo': outs[1]}
            elif algorithm == 'DB':
                dic = {'maps': outs[0]}
            elif algorithm == 'SAST':
                dic = {
                    'f_score': outs[0],
                    'f_border': outs[1],
                    'f_tvo': outs[2],
                    'f_tco': outs[3],
                }
            else:
                raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']")

            dt_boxes_list = postprocess(dic, ratio_list)
            for ino, img_name in enumerate(img_name_list):
                dt_boxes = dt_boxes_list[ino]
                dt_boxes_json = [{
                    "transcription": "",
                    "points": box.tolist()
                } for box in dt_boxes]
                otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
                fout.write(otstr.encode())

                # Draw on a fresh copy read from disk, not on the
                # preprocessed network input.
                src_img = cv2.imread(img_name)
                draw_det_res(dt_boxes, config, src_img, img_name)

    logger.info("success!")


if __name__ == '__main__':
    # FLAGS is intentionally assigned at module level: main() takes no
    # arguments and reads FLAGS.config / FLAGS.opt as a global.
    parser = program.ArgsParser()
    FLAGS = parser.parse_args()
    main()