提交 2dd6872e 编写于 作者: C chenguowei01

update train.py

上级 a491c19c
......@@ -87,6 +87,7 @@ def train(model,
labels = data[1].astype('int64')
if nranks > 1:
loss = ddp_model(images, labels)
# apply_collective_grads sum grads over multiple gpus.
loss = ddp_model.scale_loss(loss)
loss.backward()
ddp_model.apply_collective_grads()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册