infer.py 40.8 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
G
Guanghua Yu 已提交
17
import glob
18 19
import json
from pathlib import Path
Q
qingqing01 已提交
20 21 22 23
from functools import reduce

import cv2
import numpy as np
C
cnn 已提交
24
import math
Q
qingqing01 已提交
25 26 27 28
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor

W
wangguanzhong 已提交
29 30 31 32 33
import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)

34
from benchmark_utils import PaddleInferBenchmark
35
from picodet_postprocess import PicoDetPostProcess
F
Feng Ni 已提交
36
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
W
wangguanzhong 已提交
37
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
G
Guanghua Yu 已提交
38
from visualize import visualize_box_mask
39
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid
G
Guanghua Yu 已提交
40

Q
qingqing01 已提交
41 42
# Global set of the model architecture names this deploy script supports.
SUPPORT_MODELS = {
    'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
    'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
    'StrongBaseline', 'STGCN', 'YOLOX', 'PPHGNet', 'PPLCNet'
}


W
wangguanzhong 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    """Emit a benchmark log for a finished detection run.

    Args:
        detector: detector whose accumulated `cpu_mem`, `gpu_mem`, `gpu_util`,
            `det_times` and `config` attributes are reported.
        img_list (list): images that were processed; its length is used to
            average the accumulated memory / utilisation counters.
        model_info: model description forwarded to `PaddleInferBenchmark`.
        batch_size (int): batch size recorded in the log.
        name (str|None): passed to the benchmark logger when it is invoked.
    """
    img_count = len(img_list)
    # The counters are sums over all images, so report per-image averages.
    mems = dict(
        cpu_rss_mb=detector.cpu_mem / img_count,
        gpu_rss_mb=detector.gpu_mem / img_count,
        gpu_util=detector.gpu_util * 100 / img_count)
    perf_info = detector.det_times.report(average=True)
    data_info = dict(
        batch_size=batch_size,
        shape="dynamic_shape",
        data_num=perf_info['img_num'])
    benchmark_logger = PaddleInferBenchmark(detector.config, model_info,
                                            data_info, perf_info, mems)
    benchmark_logger(name)


Q
qingqing01 已提交
66 67 68
class Detector(object):
    """
    Base detector: wraps a Paddle Inference predictor together with the
    pre-processing, post-processing, timing and result-export helpers.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to turn on mkldnn bfloat16
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
                                    Used by action model.

    Attributes:
        pred_config (object): config of model, built by `set_config(model_dir)`
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 enable_mkldnn_bfloat16=False,
                 output_dir='output',
                 threshold=0.5,
                 delete_shuffle_pass=False):
        # The model config must be parsed first: the predictor needs its
        # min_subgraph_size / use_dynamic_shape settings.
        self.pred_config = self.set_config(model_dir)
        self.predictor, self.config = load_predictor(
            model_dir,
            run_mode=run_mode,
            batch_size=batch_size,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            delete_shuffle_pass=delete_shuffle_pass)
        self.det_times = Timer()
        # Accumulated resource counters, only updated in benchmark mode.
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.threshold = threshold

    def set_config(self, model_dir):
        """Build the model config from `model_dir`."""
        return PredictConfig(model_dir)

    def preprocess(self, image_list):
        """Run the configured preprocess ops and feed the predictor inputs.

        Args:
            image_list (list): images of one batch, each passed as-is to the
                module-level `preprocess` function.
        Returns:
            inputs (dict): the dict built by `create_inputs`; its arrays have
                already been copied into the predictor's input handles.
        """
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # NOTE(security): `op_type` comes from the model config and is
            # resolved with eval(); only load configs from trusted sources.
            preprocess_ops.append(eval(op_type)(**new_op_info))

        input_im_lst = []
        input_im_info_lst = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            input_im_lst.append(im)
            input_im_info_lst.append(im_info)
        inputs = create_inputs(input_im_lst, input_im_info_lst)

        for input_name in self.predictor.get_input_names():
            input_tensor = self.predictor.get_input_handle(input_name)
            # A predictor input named 'x' is fed with the image array.
            key = 'image' if input_name == 'x' else input_name
            input_tensor.copy_from_cpu(inputs[key])

        return inputs

    def postprocess(self, inputs, result):
        """Normalise the raw predictor output.

        Args:
            inputs (dict): preprocess output (unused here).
            result (dict): raw result holding at least 'boxes_num'.
        Returns:
            result (dict): the result without `None` entries; when no box was
                detected, 'boxes' is an empty [0, 6] array.
        """
        np_boxes_num = result['boxes_num']
        assert isinstance(np_boxes_num, np.ndarray), \
            '`np_boxes_num` should be a `numpy.ndarray`'

        if np_boxes_num.sum() <= 0:
            print('[WARNNING] No object detected.')
            result = {'boxes': np.zeros([0, 6]), 'boxes_num': np_boxes_num}
        # Drop absent outputs (e.g. 'masks' for models without a mask head).
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def filter_box(self, result, threshold):
        """Keep only the boxes whose score (column 1) exceeds `threshold`.

        Args:
            result (dict): holds 'boxes' (2-D array) and 'boxes_num'
                (number of boxes per image).
            threshold (float): score threshold, exclusive.
        Returns:
            dict: filtered 'boxes' and the per-image 'boxes_num'.
        """
        np_boxes_num = result['boxes_num']
        boxes = result['boxes']
        start_idx = 0
        filter_boxes = []
        filter_num = []
        for boxes_num in np_boxes_num:
            boxes_i = boxes[start_idx:start_idx + boxes_num, :]
            keep = boxes_i[:, 1] > threshold
            filter_boxes_i = boxes_i[keep, :]
            filter_boxes.append(filter_boxes_i)
            filter_num.append(filter_boxes_i.shape[0])
            start_idx += boxes_num
        if not filter_boxes:
            # Empty batch: np.concatenate([]) would raise, return empty result.
            return {'boxes': boxes[:0], 'boxes_num': np.array(filter_num)}
        boxes = np.concatenate(filter_boxes)
        filter_num = np.array(filter_num)
        filter_res = {'boxes': boxes, 'boxes_num': filter_num}
        return filter_res

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeats number for prediction
        Returns:
            result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                            matrix element:[class, score, x_min, y_min, x_max, y_max]
                            MaskRCNN's result include 'masks': np.ndarray:
                            shape: [N, im_h, im_w]
        '''
        # Defaults are returned unchanged when `repeats` is 0.
        np_boxes_num, np_boxes, np_masks = np.array([0]), None, None
        for _ in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            boxes_tensor = self.predictor.get_output_handle(output_names[0])
            np_boxes = boxes_tensor.copy_to_cpu()
            boxes_num = self.predictor.get_output_handle(output_names[1])
            np_boxes_num = boxes_num.copy_to_cpu()
            if self.pred_config.mask:
                masks_tensor = self.predictor.get_output_handle(output_names[2])
                np_masks = masks_tensor.copy_to_cpu()
        result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
        return result

    def merge_batch_result(self, batch_result):
        """Merge per-batch result dicts into one dict.

        A single batch is returned as-is. Otherwise values are collected per
        key and concatenated, except 'masks' and 'segm' which stay as lists.
        """
        if len(batch_result) == 1:
            return batch_result[0]
        res_key = batch_result[0].keys()
        results = {k: [] for k in res_key}
        for res in batch_result:
            for k, v in res.items():
                results[k].append(v)
        for k, v in results.items():
            if k not in ['masks', 'segm']:
                results[k] = np.concatenate(v)
        return results

    def get_timer(self):
        """Return the timer that accumulates the stage timings."""
        return self.det_times

    def _infer_batch(self, batch_image_list, run_benchmark, repeats,
                     img_count):
        """Run the timed preprocess -> predict -> postprocess pipeline once.

        Args:
            batch_image_list (list): images of the batch.
            run_benchmark (bool): when True each stage is warmed up before it
                is timed and memory statistics are accumulated.
            repeats (int): prediction repeats used in benchmark mode.
            img_count (int): amount added to the timer's image counter.
        Returns:
            result (dict): output of `postprocess`.
        """
        if run_benchmark:
            # preprocess
            self.preprocess(batch_image_list)  # warmup
            self.det_times.preprocess_time_s.start()
            inputs = self.preprocess(batch_image_list)
            self.det_times.preprocess_time_s.end()

            # model prediction
            result = self.predict(repeats=50)  # warmup
            self.det_times.inference_time_s.start()
            result = self.predict(repeats=repeats)
            self.det_times.inference_time_s.end(repeats=repeats)

            # postprocess
            self.postprocess(inputs, result)  # warmup
            self.det_times.postprocess_time_s.start()
            result = self.postprocess(inputs, result)
            self.det_times.postprocess_time_s.end()
            self.det_times.img_num += img_count

            cm, gm, gu = get_current_memory_mb()
            self.cpu_mem += cm
            self.gpu_mem += gm
            self.gpu_util += gu
        else:
            # preprocess
            self.det_times.preprocess_time_s.start()
            inputs = self.preprocess(batch_image_list)
            self.det_times.preprocess_time_s.end()

            # model prediction
            self.det_times.inference_time_s.start()
            result = self.predict()
            self.det_times.inference_time_s.end()

            # postprocess
            self.det_times.postprocess_time_s.start()
            result = self.postprocess(inputs, result)
            self.det_times.postprocess_time_s.end()
            self.det_times.img_num += img_count
        return result

    def predict_image_slice(self,
                            img_list,
                            slice_size=[640, 640],
                            overlap_ratio=[0.25, 0.25],
                            combine_method='nms',
                            match_threshold=0.6,
                            match_metric='ios',
                            run_benchmark=False,
                            repeats=1,
                            visual=True,
                            save_results=False):
        """Detect on overlapping slices of each image and merge the boxes.

        Args:
            img_list (list): images to process, one at a time.
            slice_size (list): [height, width] of a slice (never mutated).
            overlap_ratio (list): [height, width] overlap ratios (never mutated).
            combine_method (str): 'nms' or 'concat'.
            match_threshold (float): threshold forwarded to `multiclass_nms`.
            match_metric (str): metric forwarded to `multiclass_nms`.
            run_benchmark (bool): warm up and collect memory statistics.
            repeats (int): prediction repeats in benchmark mode.
            visual (bool): draw the merged result with `visualize`.
            save_results (bool): write COCO-style json files.
        Returns:
            results (dict): merged results of all images.
        Raises:
            ValueError: if `combine_method` is not 'nms' or 'concat'.
        """
        # slice infer only support bs=1
        results = []
        try:
            from sahi.slicing import slice_image
        except Exception:
            print(
                'sahi not found, plaese install sahi. '
                'for example: `pip install sahi`, see https://github.com/obss/sahi.'
            )
            raise
        num_classes = len(self.pred_config.labels)
        for i, ori_image in enumerate(img_list):
            slice_image_result = slice_image(
                image=ori_image,
                slice_height=slice_size[0],
                slice_width=slice_size[1],
                overlap_height_ratio=overlap_ratio[0],
                overlap_width_ratio=overlap_ratio[1])
            sub_img_num = len(slice_image_result)
            print('sub_img_num', sub_img_num)

            # All slices of one image form a single batch.
            batch_image_list = [
                slice_image_result.images[_ind] for _ind in range(sub_img_num)
            ]
            result = self._infer_batch(
                batch_image_list, run_benchmark, repeats, img_count=1)

            merged_bboxs = []
            st = 0  # start index of the current slice's boxes
            for _ind in range(sub_img_num):
                # Box ranges are cumulative: the end index is the start index
                # plus this slice's box count.
                ed = st + result['boxes_num'][_ind]
                shift_amount = slice_image_result.starting_pixels[_ind]
                slice_boxes = result['boxes'][st:ed]  # view into result
                # Shift both box corners by the slice's starting pixel.
                slice_boxes[:, 2:4] = slice_boxes[:, 2:4] + shift_amount
                slice_boxes[:, 4:6] = slice_boxes[:, 4:6] + shift_amount
                merged_bboxs.append(slice_boxes)
                st = ed

            merged_results = {'boxes': []}
            if combine_method == 'nms':
                final_boxes = multiclass_nms(
                    np.concatenate(merged_bboxs), num_classes, match_threshold,
                    match_metric)
                merged_results['boxes'] = np.concatenate(final_boxes)
            elif combine_method == 'concat':
                merged_results['boxes'] = np.concatenate(merged_bboxs)
            else:
                raise ValueError(
                    "Now only support 'nms' or 'concat' to fuse detection results."
                )
            merged_results['boxes_num'] = np.array(
                [len(merged_results['boxes'])], dtype=np.int32)

            if visual:
                visualize(
                    [ori_image],  # should be list
                    merged_results,
                    self.pred_config.labels,
                    output_dir=self.output_dir,
                    threshold=self.threshold)

            results.append(merged_results)
            print('Test iter {}'.format(i))

        results = self.merge_batch_result(results)
        if save_results:
            # parents=True so nested output directories can be created too.
            Path(self.output_dir).mkdir(parents=True, exist_ok=True)
            self.save_coco_results(
                img_list, results, use_coco_category=FLAGS.use_coco_category)
        return results

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      save_results=False):
        """Detect on `image_list` in batches of `self.batch_size`.

        Args:
            image_list (list): images to process.
            run_benchmark (bool): warm up and collect memory statistics;
                visualization is skipped in this mode.
            repeats (int): prediction repeats in benchmark mode.
            visual (bool): draw each batch result with `visualize`.
            save_results (bool): write COCO-style json files.
        Returns:
            results (dict): merged results of all batches.
        """
        batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
        results = []
        for i in range(batch_loop_cnt):
            start_index = i * self.batch_size
            end_index = min((i + 1) * self.batch_size, len(image_list))
            batch_image_list = image_list[start_index:end_index]
            result = self._infer_batch(
                batch_image_list,
                run_benchmark,
                repeats,
                img_count=len(batch_image_list))

            if visual and not run_benchmark:
                visualize(
                    batch_image_list,
                    result,
                    self.pred_config.labels,
                    output_dir=self.output_dir,
                    threshold=self.threshold)
            results.append(result)
            print('Test iter {}'.format(i))
        results = self.merge_batch_result(results)
        if save_results:
            # parents=True so nested output directories can be created too.
            Path(self.output_dir).mkdir(parents=True, exist_ok=True)
            self.save_coco_results(
                image_list, results, use_coco_category=FLAGS.use_coco_category)
        return results

    def predict_video(self, video_file, camera_id):
        """Detect frame by frame and write an annotated video.

        Args:
            video_file (str): path of the input video, used when
                `camera_id` is -1.
            camera_id (int): camera device id; -1 means "read `video_file`".
                With a camera the frames are also shown in a window until
                'q' is pressed.
        """
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        try:
            while True:
                ret, frame = capture.read()
                if not ret:
                    break
                print('detect frame: %d' % (index))
                index += 1
                # The predictor gets the frame with its channel order reversed.
                results = self.predict_image([frame[:, :, ::-1]], visual=False)

                im = visualize_box_mask(
                    frame,
                    results,
                    self.pred_config.labels,
                    threshold=self.threshold)
                im = np.array(im)
                writer.write(im)
                if camera_id != -1:
                    cv2.imshow('Mask Detection', im)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
        finally:
            # Release both handles even if detection fails mid-stream.
            writer.release()
            capture.release()

    def save_coco_results(self, image_list, results, use_coco_category=False):
        """Dump detection results as COCO-style json files in `output_dir`.

        Writes `bbox.json` when there are box results and `mask.json` when
        there are mask results.

        Args:
            image_list (list): image paths, indexed like `results['boxes_num']`.
            results (dict): holds 'boxes_num' and optionally 'boxes'/'masks'.
            use_coco_category (bool): map class ids through
                `coco_clsid2catid` and use the numeric file stem as image id;
                otherwise the image index and raw class id are used.
        """
        bbox_results = []
        mask_results = []
        idx = 0
        print("Start saving coco json files...")
        for i, box_num in enumerate(results['boxes_num']):
            file_name = os.path.split(image_list[i])[-1]
            if use_coco_category:
                img_id = int(os.path.splitext(file_name)[0])
            else:
                img_id = i

            if 'boxes' in results:
                boxes = results['boxes'][idx:idx + box_num].tolist()
                bbox_results.extend([{
                    'image_id': img_id,
                    'category_id': coco_clsid2catid[int(box[0])]
                    if use_coco_category else int(box[0]),
                    'file_name': file_name,
                    'bbox': [box[2], box[3], box[4] - box[2],
                             box[5] - box[3]],  # xyxy -> xywh
                    'score': box[1]
                } for box in boxes])

            if 'masks' in results:
                import pycocotools.mask as mask_util

                boxes = results['boxes'][idx:idx + box_num].tolist()
                masks = results['masks'][i][:box_num].astype(np.uint8)
                seg_res = []
                for box, mask in zip(boxes, masks):
                    rle = mask_util.encode(
                        np.array(
                            mask[:, :, None], dtype=np.uint8, order="F"))[0]
                    if 'counts' in rle:
                        # json cannot serialise bytes
                        rle['counts'] = rle['counts'].decode("utf8")
                    seg_res.append({
                        'image_id': img_id,
                        'category_id': coco_clsid2catid[int(box[0])]
                        if use_coco_category else int(box[0]),
                        'file_name': file_name,
                        'segmentation': rle,
                        'score': box[1]
                    })
                mask_results.extend(seg_res)

            idx += box_num

        if bbox_results:
            bbox_file = os.path.join(self.output_dir, "bbox.json")
            with open(bbox_file, 'w') as f:
                json.dump(bbox_results, f)
            print(f"The bbox result is saved to {bbox_file}")
        if mask_results:
            mask_file = os.path.join(self.output_dir, "mask.json")
            with open(mask_file, 'w') as f:
                json.dump(mask_results, f)
            print(f"The mask result is saved to {mask_file}")
514

Q
qingqing01 已提交
515

G
Guanghua Yu 已提交
516 517 518 519
class DetectorSOLOv2(Detector):
    """
    Detector variant whose `predict` reads the four SOLOv2 outputs.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        enable_mkldnn_bfloat16 (bool): Whether to turn on mkldnn bfloat16
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            enable_mkldnn_bfloat16=False,
            output_dir='./',
            threshold=0.5, ):
        super(DetectorSOLOv2, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold, )

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
                            'label': label of segm, shape:[N]
                            'score': confidence score of segm, shape:[N]
                            'boxes_num': first output of the predictor
        '''
        # `np_boxes_num` gets the same default as the base class so that
        # `repeats == 0` returns a result instead of hitting an unbound name.
        np_boxes_num = np.array([0])
        np_label, np_score, np_segms = None, None, None
        for _ in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            # Outputs are read by position: boxes_num, label, score, segm.
            np_boxes_num = self.predictor.get_output_handle(output_names[
                0]).copy_to_cpu()
            np_label = self.predictor.get_output_handle(output_names[
                1]).copy_to_cpu()
            np_score = self.predictor.get_output_handle(output_names[
                2]).copy_to_cpu()
            np_segms = self.predictor.get_output_handle(output_names[
                3]).copy_to_cpu()

        result = dict(
            segm=np_segms,
            label=np_label,
            score=np_score,
            boxes_num=np_boxes_num)
        return result
G
Guanghua Yu 已提交
594 595


596 597 598 599 600
class DetectorPicoDet(Detector):
    """
    Detector variant that decodes raw PicoDet head outputs with
    `PicoDetPostProcess`.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to turn on MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            enable_mkldnn_bfloat16=False,
            output_dir='./',
            threshold=0.5, ):
        super(DetectorPicoDet, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold, )

    def postprocess(self, inputs, result):
        """Decode the raw outputs returned by `predict` into final boxes.

        Args:
            inputs (dict): preprocess output; 'image', 'im_shape' and
                'scale_factor' are read.
            result (dict): output of `predict`: 'boxes' holds the score
                list and 'boxes_num' holds the raw box list.
        Returns:
            result (dict): 'boxes' and 'boxes_num' from `PicoDetPostProcess`.
        """
        # postprocess output of predictor
        np_score_list = result['boxes']
        np_boxes_list = result['boxes_num']
        postprocessor = PicoDetPostProcess(
            inputs['image'].shape[2:],
            inputs['im_shape'],
            inputs['scale_factor'],
            strides=self.pred_config.fpn_stride,
            nms_threshold=self.pred_config.nms['nms_threshold'])
        np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
        result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
        return result

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): raw predictor outputs of the last run.
                            'boxes': list built from the first half of the outputs (scores)
                            'boxes_num': list built from the second half of the outputs (boxes)
                            Both are decoded into final boxes by `postprocess`.
        '''
        np_score_list, np_boxes_list = [], []
        for _ in range(repeats):
            self.predictor.run()
            # Only the outputs of the last run are kept.
            np_score_list.clear()
            np_boxes_list.clear()
            output_names = self.predictor.get_output_names()
            # First half of the outputs are scores, second half are boxes.
            num_outs = len(output_names) // 2
            for out_idx in range(num_outs):
                np_score_list.append(
                    self.predictor.get_output_handle(output_names[out_idx])
                    .copy_to_cpu())
                np_boxes_list.append(
                    self.predictor.get_output_handle(output_names[
                        out_idx + num_outs]).copy_to_cpu())
        result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
        return result
681 682


C
cnn 已提交
683
def create_inputs(imgs, im_info):
    """generate input for different model type
    Args:
        imgs (list(numpy)): list of images (np.ndarray); when batching, each
            image is unpacked as (channel, height, width)
        im_info (list(dict)): list of image info; 'im_shape' and
            'scale_factor' are read from every entry
    Returns:
        inputs (dict): input of model with float32 arrays under 'image',
            'im_shape' and 'scale_factor', each with a leading batch axis
    """
    inputs = {}

    if len(imgs) == 1:
        # single image: only a batch axis is added, no padding is required
        inputs['image'] = np.array((imgs[0], )).astype('float32')
        inputs['im_shape'] = np.array(
            (im_info[0]['im_shape'], )).astype('float32')
        inputs['scale_factor'] = np.array(
            (im_info[0]['scale_factor'], )).astype('float32')
        return inputs

    inputs['im_shape'] = np.concatenate(
        [np.array((e['im_shape'], )).astype('float32') for e in im_info],
        axis=0)
    inputs['scale_factor'] = np.concatenate(
        [np.array((e['scale_factor'], )).astype('float32') for e in im_info],
        axis=0)

    # zero-pad every image at the bottom/right up to the largest height and
    # width in the batch so that they can be stacked into one array
    max_shape_h = max(e.shape[1] for e in imgs)
    max_shape_w = max(e.shape[2] for e in imgs)
    padding_imgs = []
    for img in imgs:
        im_c, im_h, im_w = img.shape
        padding_im = np.zeros(
            (im_c, max_shape_h, max_shape_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = img
        padding_imgs.append(padding_im)
    inputs['image'] = np.stack(padding_imgs, axis=0)
    return inputs


class PredictConfig():
    """set config of preprocess, postprocess and visualize
    Args:
        model_dir (str): directory that contains infer_cfg.yml
    Raises:
        ValueError: the 'arch' in infer_cfg.yml is not a supported model
    """

    def __init__(self, model_dir):
        # parsing Yaml config for Preprocess
        deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        # optional entries with a default
        self.mask = yml_conf.get('mask', False)
        self.tracker = yml_conf.get('tracker')
        # optional entries without a default: the attribute is only created
        # when the key is present in the config
        if 'NMS' in yml_conf:
            self.nms = yml_conf['NMS']
        if 'fpn_stride' in yml_conf:
            self.fpn_stride = yml_conf['fpn_stride']
        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print(
                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
            )
        self.print_config()

    def check_model(self, yml_conf):
        """Check that yml_conf['arch'] contains one of SUPPORT_MODELS.

        The match is a substring test, so an arch name only has to contain a
        supported model name.

        Returns:
            bool: True when the arch is supported
        Raises:
            ValueError: loaded model not in supported model type
        """
        arch = yml_conf['arch']
        if any(support_model in arch for support_model in SUPPORT_MODELS):
            return True
        raise ValueError("Unsupported arch: {}, expect {}".format(
            arch, SUPPORT_MODELS))

    def print_config(self):
        """Print the model arch and the type of every preprocess op."""
        print('-----------  Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')


def load_predictor(model_dir,
                   run_mode='paddle',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False,
                   enable_mkldnn_bfloat16=False,
                   delete_shuffle_pass=False):
    """set AnalysisConfig, generate AnalysisPredictor
    Args:
        model_dir (str): directory holding model.pdmodel/model.pdiparams or
            inference.pdmodel/inference.pdiparams
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): batch size used for the TensorRT workspace size,
            max batch size and dynamic shape settings
        device (str): Choose the device you want to run, it can be:
            CPU/GPU/XPU/NPU, default is CPU. Any other value takes the CPU path.
        min_subgraph_size (int): min_subgraph_size passed to the TensorRT engine
        use_dynamic_shape (bool): use dynamic shape or not
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): number of cpu math library threads (CPU path only)
        enable_mkldnn (bool): try to enable mkldnn (CPU path only)
        enable_mkldnn_bfloat16 (bool): also enable mkldnn bfloat16; only
            looked at when enable_mkldnn is set
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
                                    Used by action model.
    Returns:
        tuple: (predictor, config), the created predictor and the Config it
            was created from
    Raises:
        ValueError: predict by TensorRT need device == 'GPU'.
        ValueError: no inference model file is found in model_dir.
    """
    if device != 'GPU' and run_mode != 'paddle':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
            .format(run_mode, device))
    infer_model = os.path.join(model_dir, 'model.pdmodel')
    infer_params = os.path.join(model_dir, 'model.pdiparams')
    if not os.path.exists(infer_model):
        # fall back to the alternative file naming
        infer_model = os.path.join(model_dir, 'inference.pdmodel')
        infer_params = os.path.join(model_dir, 'inference.pdiparams')
        if not os.path.exists(infer_model):
            raise ValueError(
                "Cannot find any inference model in dir: {},".format(model_dir))
    config = Config(infer_model, infer_params)
    if device == 'GPU':
        # initial GPU memory(M), device ID
        config.enable_use_gpu(200, 0)
        # optimize graph and fuse op
        config.switch_ir_optim(True)
    elif device == 'XPU':
        if config.lite_engine_enabled():
            config.enable_lite_engine()
        config.enable_xpu(10 * 1024 * 1024)
    elif device == 'NPU':
        if config.lite_engine_enabled():
            config.enable_lite_engine()
        config.enable_npu()
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            # best effort: fall back to plain CPU inference when mkldnn
            # cannot be enabled in this environment
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if enable_mkldnn_bfloat16:
                    config.enable_mkldnn_bfloat16()
            except Exception:
                print(
                    "The current environment does not support `mkldnn`, so disable mkldnn."
                )

    precision_map = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }
    if run_mode in precision_map:
        config.enable_tensorrt_engine(
            workspace_size=(1 << 25) * batch_size,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)

        if use_dynamic_shape:

            def _square_input_shape(side):
                # shapes of the two inputs for a square image of size `side`
                return {
                    'image': [batch_size, 3, side, side],
                    'scale_factor': [batch_size, 2]
                }

            min_input_shape = _square_input_shape(trt_min_shape)
            max_input_shape = _square_input_shape(trt_max_shape)
            opt_input_shape = _square_input_shape(trt_opt_shape)
            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                              opt_input_shape)
            print('trt set dynamic shape done!')

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    if delete_shuffle_pass:
        config.delete_pass("shuffle_channel_detect_pass")
    predictor = create_predictor(config)
    return predictor, config
Q
qingqing01 已提交
892 893


G
Guanghua Yu 已提交
894 895 896 897 898
def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode

    Args:
        infer_dir (str|None): directory searched (non-recursively) for
            jpg/jpeg/png/bmp files, lower or upper case extension
        infer_img (str|None): path of a single image; has priority over
            infer_dir
    Returns:
        list(str): [infer_img] when it is given, otherwise the sorted
            absolute paths of the images found in infer_dir
    Raises:
        AssertionError: neither argument is set, infer_img is not a file,
            infer_dir is not a directory, or no image is found
    """
    assert infer_img is not None or infer_dir is not None, \
        "--image_file or --image_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
            "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
            "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    # a set removes duplicates when the same file matches several patterns
    images = set()
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    # sort so that the processing order is reproducible between runs
    images = sorted(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))

    return images


W
wangguanzhong 已提交
925
def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
    """Draw the predict result of every image and save it to output_dir.

    Args:
        image_list (list): images to draw on; each entry is passed to
            visualize_box_mask and its file name is reused for the output
        result (dict): batched predict result. 'boxes_num' gives the number
            of entries per image; 'boxes', 'masks', 'segm', 'label' and
            'score' are sliced per image when present.
        labels: passed through to visualize_box_mask
        output_dir (str): directory the drawn images are written to, created
            if needed
        threshold (float): passed through to visualize_box_mask
    """
    # visualize the predict result
    per_box_keys = ('boxes', 'masks', 'segm', 'label', 'score')
    start_idx = 0
    for idx, image_file in enumerate(image_list):
        im_bboxes_num = result['boxes_num'][idx]
        end_idx = start_idx + im_bboxes_num
        # the results of all images are concatenated along the first axis,
        # so a slice on that axis selects the entries of the current image
        im_results = {
            key: result[key][start_idx:end_idx]
            for key in per_box_keys if key in result
        }
        start_idx = end_idx

        im = visualize_box_mask(
            image_file, im_results, labels, threshold=threshold)
        img_name = os.path.split(image_file)[-1]
        # exist_ok avoids the race between checking and creating the directory
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, img_name)
        im.save(out_path, quality=95)
        print("save result to: " + out_path)
Q
qingqing01 已提交
956 957 958 959 960 961 962 963 964 965


def print_arguments(args):
    """Print all attributes of ``args`` as ``name: value`` lines, sorted by name."""
    banner_top = '-----------  Running Arguments -----------'
    banner_bottom = '------------------------------------------'
    entries = ['%s: %s' % pair for pair in sorted(vars(args).items())]
    for line in [banner_top] + entries + [banner_bottom]:
        print(line)


def main():
    """Build the detector selected by the model's arch and run it.

    Reads the command line options from the module level FLAGS. The input is
    a video file / camera stream when one is given, otherwise images from
    --image_file or --image_dir. Timing info is printed for image input, or
    a benchmark log is written when FLAGS.run_benchmark is set.
    """
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']
    # pick the detector class directly instead of eval()-ing its name
    if arch == 'SOLOv2':
        detector_cls = DetectorSOLOv2
    elif arch == 'PicoDet':
        detector_cls = DetectorPicoDet
    else:
        detector_cls = Detector

    detector = detector_cls(
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
        threshold=FLAGS.threshold,
        output_dir=FLAGS.output_dir)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        if FLAGS.slice_infer:
            detector.predict_image_slice(
                img_list,
                FLAGS.slice_size,
                FLAGS.overlap_ratio,
                FLAGS.combine_method,
                FLAGS.match_threshold,
                FLAGS.match_metric,
                visual=FLAGS.save_images,
                save_results=FLAGS.save_results)
        else:
            detector.predict_image(
                img_list,
                FLAGS.run_benchmark,
                repeats=100,
                visual=FLAGS.save_images,
                save_results=FLAGS.save_results)
        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                # last path component of model_dir
                'model_name': model_dir.strip('/').split('/')[-1],
                # suffix of run_mode, e.g. 'fp16' for 'trt_fp16'
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='DET')
Q
qingqing01 已提交
1026 1027 1028 1029


if __name__ == '__main__':
    # script entry: parse and validate the command line options, then run
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU'
                            ], "device should be CPU, GPU, XPU or NPU"
    assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"

    # bfloat16 is an mkldnn feature, so it requires enable_mkldnn as well
    assert not (
        not FLAGS.enable_mkldnn and FLAGS.enable_mkldnn_bfloat16
    ), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'

    main()