diff --git a/slim/prune/prune.py b/slim/prune/prune.py index 4ccdb274af92c508ad7ea70269f748ec56be2934..d5fd5f8b35309405f00bfa428650417bd9d20489 100644 --- a/slim/prune/prune.py +++ b/slim/prune/prune.py @@ -34,7 +34,6 @@ from ppdet.utils.stats import TrainingStats from ppdet.utils.cli import ArgsParser from ppdet.utils.check import check_gpu, check_version import ppdet.utils.checkpoint as checkpoint -from ppdet.modeling.model_input import create_feed import logging FORMAT = '%(asctime)s-%(levelname)s: %(message)s' @@ -142,9 +141,9 @@ def main(): if cfg.metric == 'COCO': extra_keys = ['im_info', 'im_id', 'im_shape'] if cfg.metric == 'VOC': - extra_keys = ['gt_box', 'gt_label', 'is_difficult'] + extra_keys = ['gt_bbox', 'gt_class', 'is_difficult'] if cfg.metric == 'WIDERFACE': - extra_keys = ['im_id', 'im_shape', 'gt_box'] + extra_keys = ['im_id', 'im_shape', 'gt_bbox'] eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, extra_keys) @@ -306,8 +305,14 @@ def main(): if 'mask' in results[0]: resolution = model.mask_head.resolution box_ap_stats = eval_results( - results, eval_feed, cfg.metric, cfg.num_classes, resolution, - is_bbox_normalized, FLAGS.output_eval, map_type) + results, + cfg.metric, + cfg.num_classes, + resolution, + is_bbox_normalized, + FLAGS.output_eval, + map_type, + dataset=dataset) # use tb_paddle to log mAP if FLAGS.use_tb: