From 8a878423716e3a746a4c2d8ff55fc9d3e9c4e235 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 3 Nov 2020 17:07:03 +0800 Subject: [PATCH] add infer (#1640) --- ppdet/modeling/bbox.py | 5 +- ppdet/py_op/post_process.py | 2 - ppdet/utils/check.py | 2 +- ppdet/utils/eval_utils.py | 92 ++++++++------- ppdet/utils/visualizer.py | 14 +-- tools/eval.py | 21 +++- tools/infer.py | 215 ++++++++++++++++++++++++++++++++++++ 7 files changed, 295 insertions(+), 56 deletions(-) create mode 100755 tools/infer.py diff --git a/ppdet/modeling/bbox.py b/ppdet/modeling/bbox.py index cfae09e61..68fb55246 100644 --- a/ppdet/modeling/bbox.py +++ b/ppdet/modeling/bbox.py @@ -97,10 +97,7 @@ class BBoxPostProcessYOLO(object): scores_list.append(paddle.transpose(scores, perm=[0, 2, 1])) yolo_boxes = paddle.concat(boxes_list, axis=1) yolo_scores = paddle.concat(scores_list, axis=2) - bbox = self.nms(bboxes=yolo_boxes, scores=yolo_scores) - # TODO: parse the lod of nmsed_bbox - # default batch size is 1 - bbox_num = np.array([int(bbox.shape[0])], dtype=np.int32) + bbox, bbox_num = self.nms(bboxes=yolo_boxes, scores=yolo_scores) return bbox, bbox_num diff --git a/ppdet/py_op/post_process.py b/ppdet/py_op/post_process.py index 2049bffd0..7ae75bcd6 100755 --- a/ppdet/py_op/post_process.py +++ b/ppdet/py_op/post_process.py @@ -136,8 +136,6 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map): k = 0 for i in range(len(bbox_nums)): image_id = int(image_id[i][0]) - if bboxes.shape == (1, 1): - continue det_nums = bbox_nums[i] for j in range(det_nums): dt = bboxes[k] diff --git a/ppdet/utils/check.py b/ppdet/utils/check.py index 324734ce0..3ebbe889d 100644 --- a/ppdet/utils/check.py +++ b/ppdet/utils/check.py @@ -47,7 +47,7 @@ def check_gpu(use_gpu): pass -def check_version(version='1.7.0'): +def check_version(version='2.0'): """ Log error and exit when the installed version of paddlepaddle is not satisfied. diff --git a/ppdet/utils/eval_utils.py b/ppdet/utils/eval_utils.py index b0bb21a15..91a725615 100644 --- a/ppdet/utils/eval_utils.py +++ b/ppdet/utils/eval_utils.py @@ -3,6 +3,11 @@ from __future__ import division from __future__ import print_function import os +import sys +import json +from ppdet.py_op.post_process import get_det_res, get_seg_res +import logging +logger = logging.getLogger(__name__) def json_eval_results(metric, json_directory=None, dataset=None): @@ -28,49 +33,62 @@ def json_eval_results(metric, json_directory=None, dataset=None): logger.info("{} not exists!".format(v_json)) -def coco_eval_results(outs_res=None, include_mask=False, dataset=None): - print("start evaluate bbox using coco api") - import io - import six - import json - from pycocotools.coco import COCO - from pycocotools.cocoeval import COCOeval - from ppdet.py_op.post_process import get_det_res, get_seg_res - anno_file = os.path.join(dataset.dataset_dir, dataset.anno_path) - cocoGt = COCO(anno_file) - catid = { - i + dataset.with_background: v - for i, v in enumerate(cocoGt.getCatIds()) - } - - if outs_res is not None and len(outs_res) > 0: - det_res = [] - for outs in outs_res: - det_res += get_det_res(outs['bbox'], outs['bbox_num'], - outs['im_id'], catid) +def get_infer_results(outs_res, eval_type, catid): + """ + Get result at the stage of inference. + The output format is dictionary containing bbox or mask result. - with io.open("bbox.json", 'w') as outfile: - encode_func = unicode if six.PY2 else str - outfile.write(encode_func(json.dumps(det_res))) + For example, bbox result is a list and each element contains + image_id, category_id, bbox and score. + """ + if outs_res is None or len(outs_res) == 0: + raise ValueError( + 'The number of valid detection result if zero. Please use reasonable model and check input data.' + ) + infer_res = {} - cocoDt = cocoGt.loadRes("bbox.json") - cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') - cocoEval.evaluate() - cocoEval.accumulate() - cocoEval.summarize() + if 'bbox' in eval_type: + box_res = [] + for outs in outs_res: + box_res += get_det_res(outs['bbox'], outs['bbox_num'], + outs['im_id'], catid) + infer_res['bbox'] = box_res - if outs_res is not None and len(outs_res) > 0 and include_mask: + if 'mask' in eval_type: seg_res = [] for outs in outs_res: seg_res += get_seg_res(outs['mask'], outs['bbox_num'], outs['im_id'], catid) + infer_res['mask'] = seg_res + + return infer_res + + +def eval_results(res, metric, anno_file): + """ + Evalute the inference result + """ + eval_res = [] + if metric == 'COCO': + from ppdet.utils.coco_eval import cocoapi_eval + + if 'bbox' in res: + with open("bbox.json", 'w') as f: + json.dump(res['bbox'], f) + logger.info('The bbox result is saved to bbox.json.') + + bbox_stats = cocoapi_eval('bbox.json', 'bbox', anno_file=anno_file) + eval_res.append(bbox_stats) + sys.stdout.flush() + if 'mask' in res: + with open("mask.json", 'w') as f: + json.dump(res['mask'], f) + logger.info('The mask result is saved to mask.json.') - with io.open("mask.json", 'w') as outfile: - encode_func = unicode if six.PY2 else str - outfile.write(encode_func(json.dumps(seg_res))) + seg_stats = cocoapi_eval('mask.json', 'mask', anno_file=anno_file) + eval_res.append(seg_stats) + sys.stdout.flush() + else: + raise NotImplemented("Only COCO metric is supported now.") - cocoSg = cocoGt.loadRes("mask.json") - cocoEval = COCOeval(cocoGt, cocoSg, 'segm') - cocoEval.evaluate() - cocoEval.accumulate() - cocoEval.summarize() + return eval_res diff --git a/ppdet/utils/visualizer.py b/ppdet/utils/visualizer.py index 0658c8c35..e41db7b36 100644 --- a/ppdet/utils/visualizer.py +++ b/ppdet/utils/visualizer.py @@ -26,18 +26,18 @@ __all__ = ['visualize_results'] def visualize_results(image, + bbox_res, + mask_res, im_id, catid2name, - threshold=0.5, - bbox_results=None, - mask_results=None): + threshold=0.5): """ Visualize bbox and mask results """ - if mask_results: - image = draw_mask(image, im_id, mask_results, threshold) - if bbox_results: - image = draw_bbox(image, im_id, catid2name, bbox_results, threshold) + if bbox_res is not None: + image = draw_bbox(image, im_id, catid2name, bbox_res, threshold) + if mask_res is not None: + image = draw_mask(image, im_id, mask_res, threshold) return image diff --git a/tools/eval.py b/tools/eval.py index f5f10c679..9a5534bfa 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -18,7 +18,7 @@ from paddle.distributed import ParallelEnv from ppdet.core.workspace import load_config, merge_config, create from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.cli import ArgsParser -from ppdet.utils.eval_utils import coco_eval_results +from ppdet.utils.eval_utils import get_infer_results, eval_results from ppdet.data.reader import create_reader from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt import logging @@ -79,11 +79,22 @@ def run(FLAGS, cfg): cost_time = time.time() - start_time logger.info('Total sample number: {}, averge FPS: {}'.format( sample_num, sample_num / cost_time)) + + eval_type = ['bbox'] + if getattr(cfg, 'MaskHead', None): + eval_type.append('mask') # Metric - coco_eval_results( - outs_res, - include_mask=True if getattr(cfg, 'MaskHead', None) else False, - dataset=cfg['EvalReader']['dataset']) + # TODO: support other metric + dataset = cfg.EvalReader['dataset'] + from ppdet.utils.coco_eval import get_category_info + anno_file = dataset.get_anno() + with_background = dataset.with_background + use_default_label = dataset.use_default_label + clsid2catid, catid2name = get_category_info(anno_file, with_background, + use_default_label) + + infer_res = get_infer_results(outs_res, eval_type, clsid2catid) + eval_results(infer_res, cfg.metric, anno_file) def main(): diff --git a/tools/infer.py b/tools/infer.py new file mode 100755 index 000000000..c16f0d5c7 --- /dev/null +++ b/tools/infer.py @@ -0,0 +1,215 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os, sys +# add python path of PadleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +# ignore numba warning +import warnings +warnings.filterwarnings('ignore') +import glob +import numpy as np +from PIL import Image +import paddle +from paddle.distributed import ParallelEnv +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.utils.check import check_gpu, check_version, check_config +from ppdet.utils.visualizer import visualize_results +from ppdet.utils.cli import ArgsParser +from ppdet.data.reader import create_reader +from ppdet.utils.checkpoint import load_dygraph_ckpt +from ppdet.utils.eval_utils import get_infer_results +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = ArgsParser() + parser.add_argument( + "--infer_dir", + type=str, + default=None, + help="Directory for images to perform inference on.") + parser.add_argument( + "--infer_img", + type=str, + default=None, + help="Image path, has higher priority over --infer_dir") + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory for storing the output visualization files.") + parser.add_argument( + "--draw_threshold", + type=float, + default=0.5, + help="Threshold to reserve the result for visualization.") + parser.add_argument( + "--use_vdl", + type=bool, + default=False, + help="whether to record the data to VisualDL.") + parser.add_argument( + '--vdl_log_dir', + type=str, + default="vdl_log_dir/image", + help='VisualDL logging directory for image.') + args = parser.parse_args() + return args + + +def get_save_image_name(output_dir, image_path): + """ + Get save image name from source image path. + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + image_name = os.path.split(image_path)[-1] + name, ext = os.path.splitext(image_name) + return os.path.join(output_dir, "{}".format(name)) + ext + + +def get_test_images(infer_dir, infer_img): + """ + Get image path list in TEST mode + """ + assert infer_img is not None or infer_dir is not None, \ + "--infer_img or --infer_dir should be set" + assert infer_img is None or os.path.isfile(infer_img), \ + "{} is not a file".format(infer_img) + assert infer_dir is None or os.path.isdir(infer_dir), \ + "{} is not a directory".format(infer_dir) + + # infer_img has a higher priority + if infer_img and os.path.isfile(infer_img): + return [infer_img] + + images = set() + infer_dir = os.path.abspath(infer_dir) + assert os.path.isdir(infer_dir), \ + "infer_dir {} is not a directory".format(infer_dir) + exts = ['jpg', 'jpeg', 'png', 'bmp'] + exts += [ext.upper() for ext in exts] + for ext in exts: + images.update(glob.glob('{}/*.{}'.format(infer_dir, ext))) + images = list(images) + + assert len(images) > 0, "no image found in {}".format(infer_dir) + logger.info("Found {} inference images in total.".format(len(images))) + + return images + + +def run(FLAGS, cfg): + + # Model + main_arch = cfg.architecture + model = create(cfg.architecture) + + dataset = cfg.TestReader['dataset'] + test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) + dataset.set_images(test_images) + + # TODO: support other metrics + imid2path = dataset.get_imid2path() + + from ppdet.utils.coco_eval import get_category_info + anno_file = dataset.get_anno() + with_background = dataset.with_background + use_default_label = dataset.use_default_label + clsid2catid, catid2name = get_category_info(anno_file, with_background, + use_default_label) + + # Init Model + model = load_dygraph_ckpt(model, ckpt=cfg.weights) + + # Data Reader + test_reader = create_reader(cfg.TestReader) + + # Run Infer + for iter_id, data in enumerate(test_reader()): + # forward + model.eval() + outs = model(data, cfg.TestReader['inputs_def']['fields'], 'infer') + + batch_res = get_infer_results([outs], outs.keys(), clsid2catid) + logger.info('Infer iter {}'.format(iter_id)) + bbox_res = None + mask_res = None + + im_ids = outs['im_id'] + bbox_num = outs['bbox_num'] + start = 0 + for i, im_id in enumerate(im_ids): + im_id = im_ids[i] + image_path = imid2path[int(im_id)] + image = Image.open(image_path).convert('RGB') + end = start + bbox_num[i] + + # use VisualDL to log original image + if FLAGS.use_vdl: + original_image_np = np.array(image) + vdl_writer.add_image( + "original/frame_{}".format(vdl_image_frame), + original_image_np, vdl_image_step) + + if 'bbox' in batch_res: + bbox_res = batch_res['bbox'][start:end] + if 'mask' in batch_res: + mask_res = batch_res['mask'][start:end] + + image = visualize_results(image, bbox_res, mask_res, + int(im_id), catid2name, + FLAGS.draw_threshold) + + # use VisualDL to log image with bbox + if FLAGS.use_vdl: + infer_image_np = np.array(image) + vdl_writer.add_image("bbox/frame_{}".format(vdl_image_frame), + infer_image_np, vdl_image_step) + vdl_image_step += 1 + if vdl_image_step % 10 == 0: + vdl_image_step = 0 + vdl_image_frame += 1 + + # save image with detection + save_name = get_save_image_name(FLAGS.output_dir, image_path) + logger.info("Detection bbox results save in {}".format(save_name)) + image.save(save_name, quality=95) + start = end + + +def main(): + FLAGS = parse_args() + + cfg = load_config(FLAGS.config) + merge_config(FLAGS.opt) + check_config(cfg) + check_gpu(cfg.use_gpu) + check_version() + + run(FLAGS, cfg) + + +if __name__ == '__main__': + main() -- GitLab