提交 a66c504c 编写于 作者: L Liufang Sang 提交者: qingqing01

Fix coco eval error in compress.py (#53)

上级 50381eda
...@@ -59,7 +59,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) ...@@ -59,7 +59,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): def eval_run(exe, compile_program, reader, keys, values, cls, test_feed, cfg):
""" """
Run evaluation program, return program outputs. Run evaluation program, return program outputs.
""" """
...@@ -82,9 +82,12 @@ def eval_run(exe, compile_program, reader, keys, values, cls, test_feed): ...@@ -82,9 +82,12 @@ def eval_run(exe, compile_program, reader, keys, values, cls, test_feed):
feed=feed_data, feed=feed_data,
fetch_list=[values[0]], fetch_list=[values[0]],
return_numpy=False) return_numpy=False)
if cfg.metric == 'VOC':
outs.append(data['gt_box']) outs.append(data['gt_box'])
outs.append(data['gt_label']) outs.append(data['gt_label'])
outs.append(data['is_difficult']) outs.append(data['is_difficult'])
elif cfg.metric == 'COCO':
outs.append(data['im_id'])
res = { res = {
k: (np.array(v), v.recursive_sequence_lengths()) k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(keys, outs) for k, v in zip(keys, outs)
...@@ -208,7 +211,7 @@ def main(): ...@@ -208,7 +211,7 @@ def main():
#place = fluid.CPUPlace() #place = fluid.CPUPlace()
#exe = fluid.Executor(place) #exe = fluid.Executor(place)
results = eval_run(exe, program, eval_reader, eval_keys, eval_values, results = eval_run(exe, program, eval_reader, eval_keys, eval_values,
eval_cls, test_data_feed) eval_cls, test_data_feed, cfg)
resolution = None resolution = None
if 'mask' in results[0]: if 'mask' in results[0]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册