predict_det.py 10.5 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
LDOUBLEV 已提交
14 15
import os
import sys
W
WenmuZhou 已提交
16

17
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Make this script's directory and the directory two levels above it
# (presumably the repository root) importable, so the project imports below
# resolve when the file is executed directly. Order matters: this directory
# is appended first, then the parent-of-parent.
sys.path.extend([
    __dir__,
    os.path.abspath(os.path.join(__dir__, '../..')),
])
L
LDOUBLEV 已提交
20

L
LDOUBLEV 已提交
21 22
# Request the 'auto_growth' allocator strategy via the environment (presumably
# consumed by Paddle when it initialises, which is why this is set before
# tools.infer.utility is imported below -- confirm).
os.environ.update(FLAGS_allocator_strategy='auto_growth')

23 24 25 26 27
import cv2
import numpy as np
import time
import sys

L
LDOUBLEV 已提交
28
import tools.infer.utility as utility
W
WenmuZhou 已提交
29
from ppocr.utils.logging import get_logger
L
LDOUBLEV 已提交
30
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
W
WenmuZhou 已提交
31 32
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
L
LDOUBLEV 已提交
33

L
LDOUBLEV 已提交
34 35
import tools.infer.benchmark_utils as benchmark_utils

W
WenmuZhou 已提交
36 37
# Module-level logger obtained from ppocr's get_logger(); used by TextDetector
# and by the script entry point at the bottom of this file.
logger = get_logger()

L
LDOUBLEV 已提交
38 39 40

class TextDetector(object):
    """Text detector backed by a Paddle inference predictor.

    ``__init__`` builds the pre-processing operator list and the
    post-processor for ``args.det_algorithm`` ("DB", "EAST" or "SAST") and
    creates the predictor with ``utility.create_predictor``. Calling the
    instance on an image runs pre-processing, inference and post-processing
    and returns the filtered detection boxes.
    """

    def __init__(self, args):
        """Build the operators and the predictor from parsed arguments.

        Args:
            args: argument namespace (the script entry point passes the result
                of ``utility.parse_args()``). Reads ``det_algorithm``,
                ``det_limit_side_len``, ``det_limit_type`` and the
                algorithm-specific ``det_db_*`` / ``use_dilation`` /
                ``det_east_*`` / ``det_sast_*`` options.

        Exits the process with status 1 if ``det_algorithm`` is not one of
        "DB", "EAST" or "SAST".
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm == "DB":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            # SAST replaces the default resize with a long-side resize.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        else:
            # An unsupported algorithm is a configuration error: exit with a
            # non-zero status (exit code 0 would report success to the shell).
            logger.error("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(1)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
            args, 'det', logger)

        self.det_times = utility.Timer()

    def order_points_clockwise(self, pts):
        """Order the four corners of a quadrilateral clockwise.

        Reference: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py

        The points are split into the two left-most and two right-most by x.
        Within each pair, the one with the smaller y is taken as the top
        corner.

        Args:
            pts: array of shape (4, 2) holding (x, y) corner coordinates.

        Returns:
            float32 array of shape (4, 2) ordered top-left, top-right,
            bottom-right, bottom-left.
        """
        # sort the points based on their x-coordinates
        x_sorted = pts[np.argsort(pts[:, 0]), :]

        # grab the left-most and right-most points from the sorted
        # x-coordinate points
        left_most = x_sorted[:2, :]
        right_most = x_sorted[2:, :]

        # sort each side by y so the top corner comes first
        left_most = left_most[np.argsort(left_most[:, 1]), :]
        (tl, bl) = left_most

        right_most = right_most[np.argsort(right_most[:, 1]), :]
        (tr, br) = right_most

        return np.array([tl, tr, br, bl], dtype="float32")

    def clip_det_res(self, points, img_height, img_width):
        """Clamp point coordinates into the image, in place.

        Every x is clamped to ``[0, img_width - 1]`` and every y to
        ``[0, img_height - 1]``. The result is then floored to an integral
        value, which equals ``int()`` truncation because the clamped value is
        either non-negative or exactly the integer upper bound.

        Args:
            points: array of shape (N, 2) with (x, y) columns; modified in
                place.
            img_height: image height in pixels.
            img_width: image width in pixels.

        Returns:
            ``points`` (the same array object).
        """
        points[:, 0] = np.floor(np.clip(points[:, 0], 0, img_width - 1))
        points[:, 1] = np.floor(np.clip(points[:, 1], 0, img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and drop tiny quadrilateral boxes.

        Each box is reordered with ``order_points_clockwise`` and clipped to
        the image. Boxes whose first-edge width or height is <= 3 pixels are
        discarded.

        Args:
            dt_boxes: iterable of (4, 2) point arrays.
            image_shape: shape of the original image; only the first two
                entries (height, width) are used.

        Returns:
            numpy array of the kept boxes (shape (0,) if none are kept).
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        return np.array(dt_boxes_new)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip every box to the image without reordering or filtering.

        Used for polygon outputs, whose point count is not fixed at four.
        Each box is clipped in place by ``clip_det_res``.

        Args:
            dt_boxes: iterable of (N, 2) point arrays.
            image_shape: shape of the original image; only (height, width)
                are used.

        Returns:
            numpy array built from the clipped boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        return np.array(dt_boxes_new)

    def __call__(self, img):
        """Detect text regions in ``img``.

        Args:
            img: input image array (presumably HxWxC as read by
                ``cv2.imread`` -- confirm).

        Returns:
            ``(dt_boxes, elapsed)``, where ``elapsed`` is
            ``self.det_times.total_time.value()``. Returns ``(None, 0)`` when
            pre-processing yields no image.
        """
        # Only the original shape is needed later (to clip the boxes), so
        # record it rather than copying the whole image.
        ori_shape = img.shape
        data = {'image': img}

        self.det_times.total_time.start()
        self.det_times.preprocess_time.start()
        data = transform(data, self.preprocess_op)
        # Check the whole result before unpacking: unpacking a None result
        # would raise TypeError instead of taking the early-return path.
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        self.det_times.preprocess_time.end()
        self.det_times.inference_time.start()

        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()
        outputs = [
            output_tensor.copy_to_cpu()
            for output_tensor in self.output_tensors
        ]
        self.det_times.inference_time.end()

        # Map the raw outputs onto the keys the post-processor expects.
        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs[0]
            preds['f_score'] = outputs[1]
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs[0]
            preds['f_score'] = outputs[1]
            preds['f_tco'] = outputs[2]
            preds['f_tvo'] = outputs[3]
        elif self.det_algorithm == 'DB':
            preds['maps'] = outputs[0]
        else:
            raise NotImplementedError

        self.det_times.postprocess_time.start()

        self.predictor.try_shrink_memory()
        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']
        if self.det_algorithm == "SAST" and self.det_sast_polygon:
            # Polygon outputs: clip only (no 4-point reordering/filtering).
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_shape)

        self.det_times.postprocess_time.end()
        self.det_times.total_time.end()
        self.det_times.img_num += 1
        return dt_boxes, self.det_times.total_time.value()
L
LDOUBLEV 已提交
210 211 212 213


if __name__ == "__main__":
    # Script entry point: run detection on every image under args.image_dir,
    # save a visualization per image and optionally emit a benchmark log.
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextDetector(args)
    count = 0
    draw_img_save = "./inference_results"
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0

    # warmup 10 times
    fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32)
    for _ in range(10):
        text_detector(fake_img)
    # Warmup calls advance the detector's timers and image counter; start a
    # fresh Timer so the report below reflects the real images only.
    text_detector.det_times = utility.Timer()

    os.makedirs(draw_img_save, exist_ok=True)

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        st = time.time()
        dt_boxes, _ = text_detector(img)
        elapse = time.time() - st
        count += 1

        if args.benchmark:
            cm, gm, gu = utility.get_current_memory_mb(0)
            cpu_mem += cm
            gpu_mem += gm
            gpu_util += gu

        logger.info("Predict time of {}: {}".format(image_file, elapse))
        if dt_boxes is None:
            # The detector returns (None, 0) when pre-processing fails.
            logger.info("no detection result for image:{}".format(image_file))
            continue
        src_im = utility.draw_text_det_res(dt_boxes, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "det_res_{}".format(img_name_pure))
        # Actually write the visualization; previously only its path was
        # logged and the file was never saved.
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))

    # print the information about memory and time-spent
    if args.benchmark:
        # max(count, 1) avoids ZeroDivisionError when no image was processed.
        num_images = max(count, 1)
        mems = {
            'cpu_rss_mb': cpu_mem / num_images,
            'gpu_rss_mb': gpu_mem / num_images,
            'gpu_util': gpu_util * 100 / num_images
        }
    else:
        mems = None
    logger.info("The predict time about detection module is as follows: ")
    det_time_dict = text_detector.det_times.report(average=True)

    if args.benchmark:
        # construct log information
        model_info = {
            # normpath strips a trailing separator so the last path
            # component is used even for "dir/" style arguments.
            'model_name':
            os.path.basename(os.path.normpath(args.det_model_dir)),
            'precision': args.precision
        }
        data_info = {
            'batch_size': 1,
            'shape': 'dynamic_shape',
            'data_num': det_time_dict['img_num']
        }
        perf_info = {
            'preprocess_time_s': det_time_dict['preprocess_time'],
            'inference_time_s': det_time_dict['inference_time'],
            'postprocess_time_s': det_time_dict['postprocess_time'],
            'total_time_s': det_time_dict['total_time']
        }
        benchmark_log = benchmark_utils.PaddleInferBenchmark(
            text_detector.config, model_info, data_info, perf_info, mems)
        benchmark_log("Det")