From f17b971d4157583b86ea8f519903fdc30cace4d3 Mon Sep 17 00:00:00 2001 From: whs Date: Thu, 21 Nov 2019 13:13:33 +0800 Subject: [PATCH] Make demo of slim support COCO dataset. (#33) --- slim/prune/compress.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/slim/prune/compress.py b/slim/prune/compress.py index 004496cc0..ed2426df4 100644 --- a/slim/prune/compress.py +++ b/slim/prune/compress.py @@ -52,7 +52,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): +def eval_run(exe, compile_program, reader, keys, values, cls, test_feed, cfg): """ Run evaluation program, return program outputs. """ @@ -75,9 +75,16 @@ def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): feed=feed_data, fetch_list=[values[0]], return_numpy=False) - outs.append(data['gt_box']) - outs.append(data['gt_label']) - outs.append(data['is_difficult']) + + if cfg.metric == 'VOC': + outs.append(data['gt_box']) + outs.append(data['gt_label']) + outs.append(data['is_difficult']) + elif cfg.metric == 'COCO': + outs.append(data['im_info']) + outs.append(data['im_id']) + outs.append(data['im_shape']) + res = { k: (np.array(v), v.recursive_sequence_lengths()) for k, v in zip(keys, outs) @@ -195,7 +202,7 @@ def main(): #place = fluid.CPUPlace() #exe = fluid.Executor(place) results = eval_run(exe, program, eval_reader, eval_keys, eval_values, - eval_cls, test_data_feed) + eval_cls, test_data_feed, cfg) resolution = None if 'mask' in results[0]: -- GitLab