diff --git a/dygraph/core/train.py b/dygraph/core/train.py index 0bbcabf15b6f72e1a3dbece7cd53ee90ca863450..9563f0c3a2e712746ee85253b86a999deeb89b41 100644 --- a/dygraph/core/train.py +++ b/dygraph/core/train.py @@ -87,6 +87,7 @@ def train(model, labels = data[1].astype('int64') if nranks > 1: loss = ddp_model(images, labels) + # apply_collective_grads sums grads over multiple GPUs. loss = ddp_model.scale_loss(loss) loss.backward() ddp_model.apply_collective_grads()