提交 2dd6872e 编写于 作者: C chenguowei01

update train.py

上级 a491c19c
...@@ -87,6 +87,7 @@ def train(model, ...@@ -87,6 +87,7 @@ def train(model,
labels = data[1].astype('int64') labels = data[1].astype('int64')
if nranks > 1: if nranks > 1:
loss = ddp_model(images, labels) loss = ddp_model(images, labels)
# apply_collective_grads sum grads over multiple gpus.
loss = ddp_model.scale_loss(loss) loss = ddp_model.scale_loss(loss)
loss.backward() loss.backward()
ddp_model.apply_collective_grads() ddp_model.apply_collective_grads()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册