提交 278f6d80 编写于 作者: D dongshuilong

fig goooglenet distributed eval bug

上级 69d9a477
......@@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0):
# calc metric
if engine.eval_metric_func is not None:
if paddle.distributed.get_world_size() > 1:
pred_list = []
label_list = []
paddle.distributed.all_gather(label_list, batch[1])
labels = paddle.concat(label_list, 0)
if isinstance(out, dict):
out = out["logits"]
if isinstance(out, list):
pred = []
for x in out:
pred_list = []
paddle.distributed.all_gather(pred_list, x)
pred_x = paddle.concat(pred_list, 0)
pred.append(pred_x)
else:
pred_list = []
paddle.distributed.all_gather(pred_list, out)
paddle.distributed.all_gather(label_list, batch[1])
pred = paddle.concat(pred_list, 0)
labels = paddle.concat(label_list, 0)
if accum_samples > total_samples:
pred = pred[:total_samples + current_samples -
accum_samples]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册