From a66c504ca6ed67478e4f1797b11de012eb82d3a1 Mon Sep 17 00:00:00 2001 From: Liufang Sang Date: Fri, 29 Nov 2019 11:25:50 +0800 Subject: [PATCH] Fix coco eval error in compress.py (#53) --- slim/quantization/compress.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/slim/quantization/compress.py b/slim/quantization/compress.py index 79e18947a..6ec072b28 100644 --- a/slim/quantization/compress.py +++ b/slim/quantization/compress.py @@ -59,7 +59,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): +def eval_run(exe, compile_program, reader, keys, values, cls, test_feed, cfg): """ Run evaluation program, return program outputs. """ @@ -82,9 +82,12 @@ def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): feed=feed_data, fetch_list=[values[0]], return_numpy=False) - outs.append(data['gt_box']) - outs.append(data['gt_label']) - outs.append(data['is_difficult']) + if cfg.metric == 'VOC': + outs.append(data['gt_box']) + outs.append(data['gt_label']) + outs.append(data['is_difficult']) + elif cfg.metric == 'COCO': + outs.append(data['im_id']) res = { k: (np.array(v), v.recursive_sequence_lengths()) for k, v in zip(keys, outs) @@ -208,7 +211,7 @@ def main(): #place = fluid.CPUPlace() #exe = fluid.Executor(place) results = eval_run(exe, program, eval_reader, eval_keys, eval_values, - eval_cls, test_data_feed) + eval_cls, test_data_feed, cfg) resolution = None if 'mask' in results[0]: -- GitLab