From 677dfad0572b4f7f398e2f78c3aa188c05f29c39 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 6 Feb 2020 16:11:28 +0800 Subject: [PATCH] fix prune.py (#212) * fix prune.py --- slim/prune/prune.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/slim/prune/prune.py b/slim/prune/prune.py index 4ccdb274a..d5fd5f8b3 100644 --- a/slim/prune/prune.py +++ b/slim/prune/prune.py @@ -34,7 +34,6 @@ from ppdet.utils.stats import TrainingStats from ppdet.utils.cli import ArgsParser from ppdet.utils.check import check_gpu, check_version import ppdet.utils.checkpoint as checkpoint -from ppdet.modeling.model_input import create_feed import logging FORMAT = '%(asctime)s-%(levelname)s: %(message)s' @@ -142,9 +141,9 @@ def main(): if cfg.metric == 'COCO': extra_keys = ['im_info', 'im_id', 'im_shape'] if cfg.metric == 'VOC': - extra_keys = ['gt_box', 'gt_label', 'is_difficult'] + extra_keys = ['gt_bbox', 'gt_class', 'is_difficult'] if cfg.metric == 'WIDERFACE': - extra_keys = ['im_id', 'im_shape', 'gt_box'] + extra_keys = ['im_id', 'im_shape', 'gt_bbox'] eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, extra_keys) @@ -306,8 +305,14 @@ def main(): if 'mask' in results[0]: resolution = model.mask_head.resolution box_ap_stats = eval_results( - results, eval_feed, cfg.metric, cfg.num_classes, resolution, - is_bbox_normalized, FLAGS.output_eval, map_type) + results, + cfg.metric, + cfg.num_classes, + resolution, + is_bbox_normalized, + FLAGS.output_eval, + map_type, + dataset=dataset) # use tb_paddle to log mAP if FLAGS.use_tb: -- GitLab