Unverified commit ed31dac6 authored by Chen Weihang, committed by GitHub

remove scale loss and coll grads, test=document_fix (#27874)

Parent 6898746f
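For context, every docstring example touched by this commit follows the same dynamic-graph data-parallel pattern, and after the removal it reduces to a plain backward/step/clear_grad loop. The sketch below is illustrative only, not the exact docstring text: it assumes the paddle.distributed API visible in the hunks (init_parallel_env, DataParallel, Adam) plus a toy Linear model and MSE loss, and would normally be launched through paddle.distributed.spawn or paddle.distributed.launch.

    # Minimal data-parallel training step after this change (sketch; the
    # Linear model, MSELoss, and tensor shapes are placeholders).
    # scale_loss() / apply_collective_grads() are no longer called explicitly.
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()                      # set up the parallel environment

    layer = paddle.nn.Linear(10, 1)               # toy model
    dp_layer = paddle.DataParallel(layer)         # wrap for data parallelism
    loss_fn = paddle.nn.MSELoss()
    adam = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=dp_layer.parameters())

    inputs = paddle.randn([10, 10], 'float32')
    labels = paddle.randn([10, 1], 'float32')

    outputs = dp_layer(inputs)
    loss = loss_fn(outputs, labels)

    loss.backward()                               # gradients are synchronized here
    adam.step()
    adam.clear_grad()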
@@ -630,9 +630,7 @@ class Fleet(object):
     print("loss:", loss.numpy())
-    loss = dp_layer.scale_loss(loss)
     loss.backward()
-    dp_layer.apply_collective_grads()
     adam.step()
     adam.clear_grad()
@@ -842,9 +840,7 @@ class Fleet(object):
     print("loss:", loss.numpy())
-    loss = dp_layer.scale_loss(loss)
     loss.backward()
-    dp_layer.apply_collective_grads()
     adam.step()
     adam.clear_grad()
@@ -903,9 +899,7 @@ class Fleet(object):
     print("loss:", loss.numpy())
-    loss = dp_layer.scale_loss(loss)
     loss.backward()
-    dp_layer.apply_collective_grads()
     adam.step()
     adam.clear_grad()
@@ -92,9 +92,7 @@ def init_parallel_env():
     labels = paddle.randn([10, 1], 'float32')
     loss = loss_fn(outputs, labels)
-    loss = dp_layer.scale_loss(loss)
     loss.backward()
-    dp_layer.apply_collective_grads()
     adam.step()
     adam.clear_grad()
@@ -314,9 +314,7 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
     if print_result is True:
         print("loss:", loss.numpy())
-    loss = dp_layer.scale_loss(loss)
     loss.backward()
-    dp_layer.apply_collective_grads()
     adam.step()
     adam.clear_grad()
@@ -397,9 +397,7 @@ class DataParallel(layers.Layer):
     labels = paddle.randn([10, 1], 'float32')
     loss = loss_fn(outputs, labels)
-    loss = dp_layer.scale_loss(loss)
     loss.backward()
-    dp_layer.apply_collective_grads()
     adam.step()
     adam.clear_grad()