predict_system.py 8.9 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15
import os
import sys
L
LDOUBLEV 已提交
16
import subprocess
W
WenmuZhou 已提交
17

18
__dir__ = os.path.dirname(os.path.abspath(__file__))
19
sys.path.append(__dir__)
20
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
L
LDOUBLEV 已提交
21

L
LDOUBLEV 已提交
22 23
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

L
LDOUBLEV 已提交
24 25 26 27
import cv2
import copy
import numpy as np
import time
W
WenmuZhou 已提交
28
import logging
L
LDOUBLEV 已提交
29
from PIL import Image
W
WenmuZhou 已提交
30 31 32
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
W
WenmuZhou 已提交
33
import tools.infer.predict_cls as predict_cls
W
WenmuZhou 已提交
34 35
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
W
WenmuZhou 已提交
36
from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb, get_rotate_crop_image
L
LDOUBLEV 已提交
37
import tools.infer.benchmark_utils as benchmark_utils
W
WenmuZhou 已提交
38 39
logger = get_logger()

L
LDOUBLEV 已提交
40 41 42

class TextSystem(object):
    """
    End-to-end OCR pipeline: text detection, optional angle classification
    of the cropped regions, then text recognition.

    args:
        args: parsed options. This class reads ``show_log``,
            ``use_angle_cls`` and ``drop_score`` and forwards ``args`` to
            the detector / recognizer / classifier constructors.
    """

    def __init__(self, args):
        # With show_log off, raise the level to INFO so the per-stage
        # logger.debug() messages emitted in __call__ are suppressed.
        if not args.show_log:
            logger.setLevel(logging.INFO)

        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        # The classifier is only built when angle classification is enabled,
        # so self.text_classifier does not exist otherwise.
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

    def print_draw_crop_rec_res(self, img_crop_list, rec_res):
        """
        Debug helper: write every cropped region to ./output and log the
        recognition result that has the same index.
        args:
            img_crop_list(list): cropped text-region images
            rec_res(list): recognition results, indexed like img_crop_list
        """
        # cv2.imwrite does not create missing directories.
        os.makedirs("./output", exist_ok=True)
        for bno, img_crop in enumerate(img_crop_list):
            cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop)
            # The first positional argument of logger.info is the format
            # string; the values to log go in the remaining arguments.
            logger.info("%s, %s", bno, rec_res[bno])

    def __call__(self, img, cls=True):
        """
        Run the full pipeline on one image.
        args:
            img(array): input image; it is copied before cropping
            cls(bool): set to False to skip angle classification even when
                the classifier is enabled
        return:
            (boxes, rec_res): the detected boxes and their recognition
            results whose score is >= drop_score, in sorted_boxes() order;
            (None, None) when the detector returns None for the boxes
        """
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)

        # Check for None before len(): the guard must run first to be useful.
        if dt_boxes is None:
            return None, None
        logger.debug("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))

        dt_boxes = sorted_boxes(dt_boxes)

        img_crop_list = []
        for box in dt_boxes:
            # Hand a copy to the crop helper so dt_boxes stays untouched.
            tmp_box = copy.deepcopy(box)
            img_crop_list.append(get_rotate_crop_image(ori_im, tmp_box))

        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
                img_crop_list)
            logger.debug("cls num  : {}, elapse : {}".format(
                len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        logger.debug("rec_res num  : {}, elapse : {}".format(
            len(rec_res), elapse))
        # self.print_draw_crop_rec_res(img_crop_list, rec_res)

        # Keep only the boxes whose recognition score reaches drop_score.
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res
L
LDOUBLEV 已提交
92 93


94 95 96 97
def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array): detected text boxes with shape [N, 4, 2]
    return:
        list of the N boxes (each with shape [4, 2]) in reading order
    """
    # Coarse ordering on the first point of every box: y first, then x.
    # NOTE(review): the first point is presumably the top-left corner -
    # confirm against the detector's post-processing.
    _boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    # Boxes whose first-point y differs by less than 10 are treated as one
    # text line and must be ordered by x. A single pass of adjacent swaps is
    # not enough for three or more boxes on a line, so every box is bubbled
    # back until it meets a box on another line or one further to the left.
    for i in range(len(_boxes) - 1):
        for j in range(i, -1, -1):
            same_line = abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
            if same_line and _boxes[j + 1][0][0] < _boxes[j][0][0]:
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes


115
def _log_benchmark(tag, predictor, time_dict, model_dir, batch_size, mems,
                   args):
    """
    Assemble the model / data / perf dictionaries for one predictor and emit
    them through benchmark_utils.PaddleInferBenchmark.
    args:
        tag(str): label handed to the benchmark logger ("Det" or "Rec")
        predictor: object whose ``config`` attribute is forwarded
        time_dict(dict): timing report providing 'img_num',
            'preprocess_time', 'inference_time', 'postprocess_time' and
            'total_time'
        model_dir(str): model directory; its last '/'-separated component
            is reported as the model name
        batch_size(int): batch size to report
        mems(dict|None): memory statistics, or None when none were sampled
        args: parsed options; ``precision`` and ``save_log_path`` are read
    """
    model_info = {
        'model_name': model_dir.split('/')[-1],
        'precision': args.precision
    }
    data_info = {
        'batch_size': batch_size,
        'shape': 'dynamic_shape',
        'data_num': time_dict['img_num']
    }
    perf_info = {
        'preprocess_time_s': time_dict['preprocess_time'],
        'inference_time_s': time_dict['inference_time'],
        'postprocess_time_s': time_dict['postprocess_time'],
        'total_time_s': time_dict['total_time']
    }
    benchmark_log = benchmark_utils.PaddleInferBenchmark(
        predictor.config, model_info, data_info, perf_info, mems,
        args.save_log_path)
    benchmark_log(tag)


def main(args):
    """
    Run detection + recognition on this process's share of the images in
    ``args.image_dir``, log the recognized text, save a visualization of
    every image to ./inference_results/ and emit the Det / Rec benchmark
    reports.
    args:
        args: parsed options. Read here: image_dir, process_id,
            total_process_num, vis_font_path, drop_score, warmup, benchmark,
            precision, det_model_dir, rec_model_dir, rec_batch_num and
            save_log_path.
    """
    image_file_list = get_image_file_list(args.image_dir)
    # Every process takes each total_process_num-th file, starting at its
    # own process_id.
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    draw_img_save = "./inference_results/"

    # warm up 10 times
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for _ in range(10):
            text_sys(img)

    total_time = 0
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
    _st = time.time()
    count = 0
    for idx, image_file in enumerate(image_file_list):
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        total_time += elapse
        # TextSystem.__call__ can return (None, None); treat that as
        # "nothing recognized" so the loops below do not iterate over None.
        if dt_boxes is None or rec_res is None:
            dt_boxes, rec_res = [], []

        # Memory / utilization is sampled on every 20th image only.
        if args.benchmark and idx % 20 == 0:
            cm, gm, gu = get_current_memory_mb(0)
            cpu_mem += cm
            gpu_mem += gm
            gpu_util += gu
            count += 1

        logger.info(
            str(idx) + "  Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            txts = [res[0] for res in rec_res]
            scores = [res[1] for res in rec_res]

            draw_img = draw_ocr_box_txt(
                image,
                dt_boxes,
                txts,
                scores,
                drop_score=drop_score,
                font_path=font_path)
            # exist_ok avoids the exists()/makedirs() race between the
            # worker processes started in multi-process mode.
            os.makedirs(draw_img_save, exist_ok=True)
            if flag:
                # Input was read by the gif reader: save the result under
                # the same name with the last three characters set to "png".
                image_file = image_file[:-3] + "png"
            save_path = os.path.join(draw_img_save,
                                     os.path.basename(image_file))
            # draw_img's channel order is reversed before cv2.imwrite.
            cv2.imwrite(save_path, draw_img[:, :, ::-1])
            logger.info("The visualized image saved in {}".format(save_path))

    logger.info("The predict total time is {}".format(time.time() - _st))
    logger.info("\nThe predict total time is {}".format(total_time))

    # count stays 0 when no image was sampled (empty shard, or the sampled
    # images failed to load); report no memory stats instead of dividing by 0.
    if args.benchmark and count > 0:
        mems = {
            'cpu_rss_mb': cpu_mem / count,
            'gpu_rss_mb': gpu_mem / count,
            'gpu_util': gpu_util * 100 / count
        }
    else:
        mems = None
    det_time_dict = text_sys.text_detector.det_times.report(average=True)
    rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True)

    # construct det log information
    _log_benchmark("Det", text_sys.text_detector, det_time_dict,
                   args.det_model_dir, 1, mems, args)

    # construct rec log information
    _log_benchmark("Rec", text_sys.text_recognizer, rec_time_dict,
                   args.rec_model_dir, args.rec_batch_num, mems, args)


if __name__ == "__main__":
    args = utility.parse_args()
    if not args.use_mp:
        # Single-process mode: run the pipeline in this interpreter.
        main(args)
    else:
        # Multi-process mode: re-launch this script once per worker with the
        # original command line plus the worker's process_id, and with
        # use_mp switched off so the children do not spawn again.
        workers = [
            subprocess.Popen(
                [sys.executable, "-u"] + sys.argv + [
                    "--process_id={}".format(worker_id),
                    "--use_mp={}".format(False)
                ],
                stdout=sys.stdout,
                stderr=sys.stdout)
            for worker_id in range(args.total_process_num)
        ]
        # All workers are started before any of them is waited on.
        for worker in workers:
            worker.wait()