diff --git a/slim/prune/compress.py b/slim/prune/compress.py index 004496cc0beff7f06efc9d4ac0551d5f77928efc..ed2426df4fa7aa7b44e58f627a472a955851d303 100644 --- a/slim/prune/compress.py +++ b/slim/prune/compress.py @@ -52,7 +52,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): +def eval_run(exe, compile_program, reader, keys, values, cls, test_feed, cfg): """ Run evaluation program, return program outputs. """ @@ -75,9 +75,16 @@ def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): feed=feed_data, fetch_list=[values[0]], return_numpy=False) - outs.append(data['gt_box']) - outs.append(data['gt_label']) - outs.append(data['is_difficult']) + + if cfg.metric == 'VOC': + outs.append(data['gt_box']) + outs.append(data['gt_label']) + outs.append(data['is_difficult']) + elif cfg.metric == 'COCO': + outs.append(data['im_info']) + outs.append(data['im_id']) + outs.append(data['im_shape']) + res = { k: (np.array(v), v.recursive_sequence_lengths()) for k, v in zip(keys, outs) @@ -195,7 +202,7 @@ def main(): #place = fluid.CPUPlace() #exe = fluid.Executor(place) results = eval_run(exe, program, eval_reader, eval_keys, eval_values, - eval_cls, test_data_feed) + eval_cls, test_data_feed, cfg) resolution = None if 'mask' in results[0]: