predict_det.py 13.8 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
LDOUBLEV 已提交
14 15
import os
import sys
W
WenmuZhou 已提交
16

17
__dir__ = os.path.dirname(os.path.abspath(__file__))
L
LDOUBLEV 已提交
18
sys.path.append(__dir__)
littletomatodonkey's avatar
littletomatodonkey 已提交
19
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
L
LDOUBLEV 已提交
20

L
LDOUBLEV 已提交
21 22
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

23 24 25 26 27
import cv2
import numpy as np
import time
import sys

L
LDOUBLEV 已提交
28
import tools.infer.utility as utility
W
WenmuZhou 已提交
29
from ppocr.utils.logging import get_logger
30
from ppocr.utils.utility import get_image_file_list, check_and_read
W
WenmuZhou 已提交
31 32
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
L
LDOUBLEV 已提交
33
import json
W
WenmuZhou 已提交
34 35
logger = get_logger()

L
LDOUBLEV 已提交
36 37 38

class TextDetector(object):
    """Text detection predictor for a single image at a time.

    Builds the preprocessing pipeline and the post-processor matching
    ``args.det_algorithm``. It then creates the inference predictor via
    ``utility.create_predictor`` (an ONNX session when ``args.use_onnx`` is
    set, otherwise a Paddle predictor). Calling an instance on an image
    returns ``(dt_boxes, elapsed_seconds)``.
    """

    def __init__(self, args):
        """Configure pre/post-processing and create the predictor.

        Args:
            args: parsed inference arguments (see ``utility.parse_args``).
                The ``det_*`` fields select the algorithm and its thresholds,
                ``use_onnx`` selects the backend, and ``benchmark`` enables
                ``auto_log`` timing.

        Exits the process with status 1 if ``args.det_algorithm`` is not
        one of DB, DB++, EAST, SAST, PSE, FCE or CT.
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.use_onnx = args.use_onnx
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm in ("DB", "DB++"):
            # DB and DB++ share the same post-processor and thresholds.
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
            if self.det_algorithm == "DB++":
                # DB++ normalizes with its own mean and a unit std.
                pre_process_list[1] = {
                    'NormalizeImage': {
                        'std': [1.0, 1.0, 1.0],
                        'mean':
                        [0.48109378172549, 0.45752457890196, 0.40787054090196],
                        'scale': '1./255.',
                        'order': 'hwc'
                    }
                }
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        elif self.det_algorithm == "PSE":
            postprocess_params['name'] = 'PSEPostProcess'
            postprocess_params["thresh"] = args.det_pse_thresh
            postprocess_params["box_thresh"] = args.det_pse_box_thresh
            postprocess_params["min_area"] = args.det_pse_min_area
            postprocess_params["box_type"] = args.det_pse_box_type
            postprocess_params["scale"] = args.det_pse_scale
            self.det_pse_box_type = args.det_pse_box_type
        elif self.det_algorithm == "FCE":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'rescale_img': [1080, 736]
                }
            }
            postprocess_params['name'] = 'FCEPostProcess'
            postprocess_params["scales"] = args.scales
            postprocess_params["alpha"] = args.alpha
            postprocess_params["beta"] = args.beta
            postprocess_params["fourier_degree"] = args.fourier_degree
            postprocess_params["box_type"] = args.det_fce_box_type
        elif self.det_algorithm == "CT":
            pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
            postprocess_params['name'] = 'CTPostProcess'
        else:
            # An unsupported algorithm is a configuration error, so exit
            # with a non-zero status instead of reporting success.
            logger.error("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(1)

        self.postprocess_op = build_post_process(postprocess_params)
        (self.predictor, self.input_tensor, self.output_tensors,
         self.config) = utility.create_predictor(args, 'det', logger)

        if self.use_onnx:
            # A model exported with a fixed input size dictates the resize
            # target. Dynamic axes can be reported as strings or None, so
            # only positive integer dims are used.
            img_h, img_w = self.input_tensor.shape[2:]
            if self._is_static_dim(img_h) and self._is_static_dim(img_w):
                pre_process_list[0] = {
                    'DetResizeForTest': {
                        'image_shape': [img_h, img_w]
                    }
                }
        # Build the operators once, after any ONNX-specific resize override.
        self.preprocess_op = create_operators(pre_process_list)

        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    @staticmethod
    def _is_static_dim(dim):
        """Return True if ``dim`` is a concrete positive integer size.

        Symbolic (string) or unknown (None) dimensions return False.
        """
        return bool(isinstance(dim, (int, np.integer)) and dim > 0)

    def order_points_clockwise(self, pts):
        """Order four (x, y) points as [min x+y, min y-x, max x+y, max y-x].

        In image coordinates (y grows downward) this is top-left,
        top-right, bottom-right, bottom-left. The middle two are taken
        from the points remaining after removing the min/max x+y points.

        NOTE(review): ``np.argmin``/``np.argmax`` return the first index on
        ties. For boxes rotated about 45 degrees the same point can
        therefore be chosen twice.

        Returns:
            A new (4, 2) float32 array.
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp points to the image and truncate them to integer values.

        Each x is clamped to [0, img_width - 1] and each y to
        [0, img_height - 1]. ``points`` is modified in place and returned.
        """
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and filter four-point boxes.

        Boxes whose first edge or fourth edge (after ordering) has an
        integer length of 3 pixels or less are discarded.

        Args:
            dt_boxes: iterable of (4, 2) point arrays.
            image_shape: shape of the original image; only the first two
                entries (height, width) are used.

        Returns:
            A numpy array of the kept boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip polygon boxes to the image without reordering or filtering.

        Each box is clipped in place and the boxes are returned as a numpy
        array.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        """Detect text regions in one image.

        Args:
            img: image array. The script below passes ``cv2.imread`` output;
                presumably HxWxC uint8 BGR. Confirm for other callers.

        Returns:
            ``(dt_boxes, elapse)``. ``dt_boxes`` is a numpy array of boxes
            clipped to the image and ``elapse`` is the wall-clock seconds
            spent. Returns ``(None, 0)`` when preprocessing yields no image.
        """
        # Only the original shape is needed for clipping, so no copy is made.
        ori_shape = img.shape
        data = {'image': img}

        st = time.time()

        if self.args.benchmark:
            self.autolog.times.start()

        data = transform(data, self.preprocess_op)
        # Guard before unpacking, in case transform returns None
        # (presumably when an operator rejects the sample).
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        if self.args.benchmark:
            self.autolog.times.stamp()
        if self.use_onnx:
            input_dict = {self.input_tensor.name: img}
            outputs = self.predictor.run(self.output_tensors, input_dict)
        else:
            self.input_tensor.copy_from_cpu(img)
            self.predictor.run()
            outputs = [
                output_tensor.copy_to_cpu()
                for output_tensor in self.output_tensors
            ]
        # Stamp the end of inference for both backends, so that ONNX runs
        # report inference and post-processing time separately.
        if self.args.benchmark:
            self.autolog.times.stamp()

        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs[0]
            preds['f_score'] = outputs[1]
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs[0]
            preds['f_score'] = outputs[1]
            preds['f_tco'] = outputs[2]
            preds['f_tvo'] = outputs[3]
        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
            preds['maps'] = outputs[0]
        elif self.det_algorithm == 'FCE':
            for i, output in enumerate(outputs):
                preds['level_{}'.format(i)] = output
        elif self.det_algorithm == "CT":
            preds['maps'] = outputs[0]
            preds['score'] = outputs[1]
        else:
            raise NotImplementedError

        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']
        # Polygon outputs have a variable number of points, so they are only
        # clipped. Quadrilaterals are also ordered and size-filtered.
        if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
                self.det_algorithm in ["PSE", "FCE", "CT"] and
                self.postprocess_op.box_type == 'poly'):
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_shape)

        if self.args.benchmark:
            self.autolog.times.end(stamp=True)
        et = time.time()
        return dt_boxes, et - st


def format_det_result(name, dt_boxes):
    """Return one ``det_results.txt`` line: ``name<TAB><json boxes><NEWLINE>``.

    Each element of ``dt_boxes`` is converted with ``.tolist()`` before JSON
    encoding (the detector returns numpy arrays).
    """
    return name + "\t" + json.dumps([box.tolist() for box in dt_boxes]) + "\n"


def get_det_vis_filename(image_file, flag_gif, flag_pdf, page_index):
    """Return the base file name used for a detection visualization.

    GIF inputs are saved as ``<stem>.png`` and each PDF page as
    ``<stem>_<page_index>.png``. Any other input keeps its own base name.
    Only the extension is replaced, so ``.pdf`` elsewhere in the path is
    left untouched.
    """
    base_name = os.path.basename(image_file)
    stem = os.path.splitext(base_name)[0]
    if flag_gif:
        return stem + ".png"
    if flag_pdf:
        return "{}_{}.png".format(stem, page_index)
    return base_name


if __name__ == "__main__":
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextDetector(args)
    total_time = 0
    draw_img_save_dir = args.draw_img_save_dir
    os.makedirs(draw_img_save_dir, exist_ok=True)

    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for _ in range(2):
            text_detector(img)

    save_results = []
    for idx, image_file in enumerate(image_file_list):
        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if not flag_pdf:
            if img is None:
                # Warning rather than debug, so skipped files stay visible
                # at the default log level.
                logger.warning("error in loading image:{}".format(image_file))
                continue
            imgs = [img]
        else:
            page_num = args.page_num
            if page_num > len(img) or page_num == 0:
                page_num = len(img)
            imgs = img[:page_num]
        multi_page = len(imgs) > 1
        for index, img in enumerate(imgs):
            st = time.time()
            dt_boxes, _ = text_detector(img)
            elapse = time.time() - st
            total_time += elapse
            if dt_boxes is None:
                # The detector returns None when preprocessing drops the
                # image; record it as "no boxes" instead of crashing.
                dt_boxes = np.array([])
            if multi_page:
                result_name = os.path.basename(image_file) + '_' + str(index)
                log_prefix = "{}_{}".format(idx, index)
            else:
                result_name = os.path.basename(image_file)
                log_prefix = "{}".format(idx)
            save_pred = format_det_result(result_name, dt_boxes)
            save_results.append(save_pred)
            logger.info(save_pred)
            logger.info("{} The predict time of {}: {}".format(
                log_prefix, image_file, elapse))
            # PDF pages are drawn from the decoded page image, other inputs
            # from their file path.
            draw_source = img if flag_pdf else image_file
            src_im = utility.draw_text_det_res(dt_boxes, draw_source, flag_pdf)
            img_path = os.path.join(
                draw_img_save_dir,
                "det_res_{}".format(
                    get_det_vis_filename(image_file, flag_gif, flag_pdf,
                                         index)))
            cv2.imwrite(img_path, src_im)
            logger.info("The visualized image saved in {}".format(img_path))

    with open(
            os.path.join(draw_img_save_dir, "det_results.txt"),
            'w',
            encoding='utf-8') as f:
        f.writelines(save_results)
    if args.benchmark:
        text_detector.autolog.report()