提交 121b5378 编写于 作者: C cuicheng01

fix bugs to adapt to the new framework

上级 a3a4390a
...@@ -81,7 +81,8 @@ def classification_eval(engine, epoch_id=0): ...@@ -81,7 +81,8 @@ def classification_eval(engine, epoch_id=0):
# gather Tensor when distributed # gather Tensor when distributed
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
label_list = [] label_list = []
label = batch[1].cuda() if engine.config["Global"][ device_id = paddle.distributed.ParallelEnv().device_id
label = batch[1].cuda(device_id) if engine.config["Global"][
"device"] == "gpu" else batch[1] "device"] == "gpu" else batch[1]
paddle.distributed.all_gather(label_list, label) paddle.distributed.all_gather(label_list, label)
labels = paddle.concat(label_list, 0) labels = paddle.concat(label_list, 0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册