diff --git a/slim/quantization/compress.py b/slim/quantization/compress.py index 79e18947a101486277d531e1f88cd99c2f183d44..6ec072b280d358116544639bd99de4966d05734c 100644 --- a/slim/quantization/compress.py +++ b/slim/quantization/compress.py @@ -59,7 +59,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): +def eval_run(exe, compile_program, reader, keys, values, cls, test_feed, cfg): """ Run evaluation program, return program outputs. """ @@ -82,9 +82,12 @@ def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): feed=feed_data, fetch_list=[values[0]], return_numpy=False) - outs.append(data['gt_box']) - outs.append(data['gt_label']) - outs.append(data['is_difficult']) + if cfg.metric == 'VOC': + outs.append(data['gt_box']) + outs.append(data['gt_label']) + outs.append(data['is_difficult']) + elif cfg.metric == 'COCO': + outs.append(data['im_id']) res = { k: (np.array(v), v.recursive_sequence_lengths()) for k, v in zip(keys, outs) @@ -208,7 +211,7 @@ def main(): #place = fluid.CPUPlace() #exe = fluid.Executor(place) results = eval_run(exe, program, eval_reader, eval_keys, eval_values, - eval_cls, test_data_feed) + eval_cls, test_data_feed, cfg) resolution = None if 'mask' in results[0]: