# infer_det.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from copy import deepcopy
import json

import os
import sys

# Put this script's directory and its parent on sys.path; the imports below
# (`program`, `ppocr.*`) are presumably resolved from there.
# NOTE: a module-level *string* named `__dir__` would shadow the PEP 562 hook
# and make `dir(<this module>)` raise TypeError, hence the private name.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_THIS_DIR)
sys.path.append(os.path.join(_THIS_DIR, '..'))


def set_paddle_flags(**kwargs):
    """Export each keyword argument as an environment variable if it is unset.

    Values are converted with ``str``. A variable that already exists in
    ``os.environ`` is never overwritten, so settings from the caller's
    environment win over the defaults passed here.
    """
    for flag_name, flag_value in kwargs.items():
        if flag_name in os.environ:
            continue
        os.environ[flag_name] = str(flag_value)


# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)

# These imports must stay below the set_paddle_flags() call above.
from paddle import fluid
from ppocr.utils.utility import create_module, get_image_file_list
import program
from ppocr.utils.save_load import init_model
from ppocr.data.reader_main import reader_main
import cv2

from ppocr.utils.utility import initial_logger
# Module-level logger shared by the functions below.
logger = initial_logger()


def draw_det_res(dt_boxes, config, img, img_name):
    """Draw detected boxes on an image and save it under ``det_results``.

    Args:
        dt_boxes: detected boxes. Each one is converted with
            ``astype(np.int32).reshape((-1, 1, 2))``, so it must be a numpy
            array of polygon vertices (presumably shape ``(N, 2)``; confirm
            against the post-processor output).
        config: parsed config dict. ``det_results`` is created in the
            directory of ``config['Global']['save_res_path']``.
        img: image as returned by ``cv2.imread``. Boxes are drawn onto it
            in place. ``None`` (unreadable file) is logged and skipped.
        img_name: source image path. Only its basename is used for the
            output file name.

    Nothing is written when ``dt_boxes`` is empty.
    """
    if len(dt_boxes) == 0:
        return
    if img is None:
        # cv2.imread returns None for unreadable files; polylines would raise.
        logger.warning("Cannot read image {}, skip drawing.".format(img_name))
        return

    src_im = img
    for box in dt_boxes:
        box = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)

    # os.path.join keeps the path relative when save_res_path has no
    # directory part; string concatenation produced "/det_results/" (root).
    save_det_path = os.path.join(
        os.path.dirname(config['Global']['save_res_path']), "det_results")
    if not os.path.exists(save_det_path):
        os.makedirs(save_det_path)
    save_path = os.path.join(save_det_path, os.path.basename(img_name))
    cv2.imwrite(save_path, src_im)
    logger.info("The detected Image saved in {}".format(save_path))


def gen_im_detection(src_im, detections):
    """
    Generate image with detection results.

    Returns a copy of ``src_im``; the input image is not modified. For every
    polygon in ``detections``, the label ``'0'`` is drawn at its first vertex
    (so the point order is visible), then the closed polygon outline is drawn.
    Colors are (255, 0, 0) for the label and (0, 0, 255) for the outline,
    i.e. blue/red assuming OpenCV's BGR channel order.

    Args:
        src_im: image array; colour (H, W, C) or single-channel (H, W).
        detections: iterable of polygons. Each may be a list or an array of
            x, y vertex coordinates; it is reshaped to ``(-1, 2)``.
    """
    im_detection = src_im.copy()

    # shape[:2] also works for single-channel (grayscale) images.
    h, w = im_detection.shape[:2]
    # Scale line width and font with image size, never below 1.
    thickness = int(max((h + w) / 2000, 1))

    for poly in detections:
        # Normalize to an (N, 2) array so list input works for both the
        # first-point lookup and the polyline call.
        poly = np.asarray(poly).reshape((-1, 2))

        # Draw the first point
        first_x, first_y = int(poly[0, 0]), int(poly[0, 1])
        cv2.putText(im_detection, '0', org=(first_x, first_y),
                    fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=thickness,
                    color=(255, 0, 0), thickness=thickness)

        cv2.polylines(im_detection, poly.reshape((1, -1, 2)).astype(np.int32),
                      isClosed=True, color=(0, 0, 255), thickness=thickness)

    return im_detection


def main():
    """Run text detection over the configured test set and save the results.

    Driven by the config loaded from ``FLAGS.config``; ``FLAGS`` is assigned
    in this script's ``__main__`` block.

    1. Build the detection network in test mode and initialize it.
    2. Load weights from ``Global.checkpoints``.
    3. For every batch yielded by ``reader_main(config=config, mode='test')``:
       run the network, post-process the outputs into boxes, and write one
       line per image to ``Global.save_res_path``::

           <image path>\\t[{"transcription": "", "points": [...]}, ...]

    4. Re-read each image, draw its boxes and save it via ``draw_det_res``.

    Raises:
        Exception: if ``Global.checkpoints`` is missing or empty, or if
            ``Global.algorithm`` is not one of EAST, DB, SAST.
    """
    config = program.load_config(FLAGS.config)
    # NOTE(review): the return value is unused, so merge_config presumably
    # applies the FLAGS.opt overrides to the loaded config in place; confirm.
    program.merge_config(FLAGS.opt)
    logger.info(config)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    program.check_gpu(use_gpu)

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    det_model = create_module(config['Architecture']['function'])(params=config)

    startup_prog = fluid.Program()
    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, startup_prog):
        with fluid.unique_name.guard():
            _, eval_outputs = det_model(mode="test")
            fetch_name_list = list(eval_outputs.keys())
            eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list]

    eval_prog = eval_prog.clone(for_test=True)
    exe.run(startup_prog)

    # load checkpoints
    checkpoints = config['Global'].get('checkpoints')
    if not checkpoints:
        # The old message formatted the empty value itself ("None not exists!").
        raise Exception(
            "Global.checkpoints must be set for inference, got {!r}".format(
                checkpoints))
    fluid.load(eval_prog, checkpoints, exe)
    logger.info("Finish initing model from {}".format(checkpoints))

    # The post-processor depends only on the config, so build it once
    # instead of once per batch.
    postprocess_params = deepcopy(config["PostProcess"])
    postprocess_params.update(config['Global'])
    postprocess = create_module(postprocess_params['function'])(
        params=postprocess_params)
    algorithm = config['Global']['algorithm']

    save_res_path = config['Global']['save_res_path']
    save_res_dir = os.path.dirname(save_res_path)
    # dirname() is "" for a bare file name, and os.makedirs("") would raise.
    if save_res_dir and not os.path.exists(save_res_dir):
        os.makedirs(save_res_dir)

    with open(save_res_path, "wb") as fout:
        test_reader = reader_main(config=config, mode='test')
        tackling_num = 0
        for data in test_reader():
            img_num = len(data)
            tackling_num += img_num
            logger.info("tackling_num:%d", tackling_num)
            # Each sample is indexed as (image, ratio, image path).
            img_list = [sample[0] for sample in data]
            ratio_list = [sample[1] for sample in data]
            img_name_list = [sample[2] for sample in data]

            outs = exe.run(eval_prog,
                           feed={'image': np.concatenate(img_list, axis=0)},
                           fetch_list=eval_fetch_list)

            dic = _build_postprocess_input(algorithm, outs)
            dt_boxes_list = postprocess(dic, ratio_list)
            for ino in range(img_num):
                dt_boxes = dt_boxes_list[ino]
                img_name = img_name_list[ino]
                dt_boxes_json = [{"transcription": "", "points": box.tolist()}
                                 for box in dt_boxes]
                otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
                fout.write(otstr.encode())

                src_img = cv2.imread(img_name)
                draw_det_res(dt_boxes, config, src_img, img_name)
    logger.info("success!")


def _build_postprocess_input(algorithm, outs):
    """Map fetched network outputs to the dict the post-processor expects.

    ``outs`` is the result of ``exe.run`` and is indexed positionally per
    algorithm. Raises Exception for algorithms other than EAST, DB and SAST.
    """
    if algorithm == 'EAST':
        return {'f_score': outs[0], 'f_geo': outs[1]}
    if algorithm == 'DB':
        return {'maps': outs[0]}
    if algorithm == 'SAST':
        return {'f_score': outs[0], 'f_border': outs[1],
                'f_tvo': outs[2], 'f_tco': outs[3]}
    raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']")


if __name__ == '__main__':
    # main() reads the parsed command-line options via the module-level FLAGS.
    FLAGS = program.ArgsParser().parse_args()
    main()