Unverified commit 65f705e1, authored by Weilong Wu, committed by GitHub

[Eager] Support sharding_parallel under eager (#42910)

Parent c0001a24
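For orientation, sharding parallelism in dygraph is driven through fleet's hybrid configuration, and the sharding_reduce_gradients helper patched below is what averages each parameter's gradient across the sharding group. A minimal setup sketch follows; the model, optimizer, and degree values are illustrative placeholders, not taken from this PR:

    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,        # data-parallel degree
        "mp_degree": 1,        # tensor (model) parallel degree
        "pp_degree": 1,        # pipeline-parallel degree
        "sharding_degree": 2,  # optimizer-state sharding degree
    }
    fleet.init(is_collective=True, strategy=strategy)

    model = paddle.nn.Linear(64, 64)  # placeholder model
    opt = paddle.optimizer.AdamW(parameters=model.parameters())
    model = fleet.distributed_model(model)
    opt = fleet.distributed_optimizer(opt)

Such a script is launched on two GPUs with "python -m paddle.distributed.launch --gpus 0,1 train.py"; the unit test touched at the end of this diff exercises the same kind of two-GPU run via hybrid_parallel_sharding_model.py.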
@@ -162,9 +162,15 @@ def sharding_reduce_gradients(parameter_list, hcg):
        sharding_nrank = hcg.get_sharding_parallel_group().nranks
        for param in parameter_list:
            if param.trainable and (param._grad_ivar() is not None):
                if in_dygraph_mode():
                    param.grad.scale_(1.0 / sharding_nrank)
                    paddle.distributed.all_reduce(
                        param.grad,
                        group=hcg.get_sharding_parallel_group(),
                        use_calc_stream=True)
                elif _in_legacy_dygraph():
                    g_var = param._grad_ivar()
                    # need use trace_op to allreduce
                    # paddle.distributed.all_reduce(
                    #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
@@ -178,7 +184,8 @@ def sharding_reduce_gradients(parameter_list, hcg):
                        })
                    # grad / sharding_rank
                    div_factor = paddle.to_tensor(sharding_nrank, dtype=g_var.dtype)
                    div_factor = paddle.to_tensor(
                        sharding_nrank, dtype=g_var.dtype)
                    paddle.fluid.framework._dygraph_tracer().trace_op(
                        type="elementwise_div",
                        inputs={'X': g_var,
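The eager-mode branch added above implements gradient averaging in two steps: scale the sharded gradient by 1 / nranks, then all-reduce (sum) it over the sharding communication group, which together yield the group mean. Below is a standalone sketch of the same pattern, assuming two trainer processes started with paddle.distributed.launch; the group and tensor names are placeholders, and the use_calc_stream argument matches the Paddle version targeted by this PR (newer releases use sync_op instead):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    group = dist.new_group(ranks=[0, 1])      # stand-in for hcg.get_sharding_parallel_group()

    grad = paddle.ones([4], dtype='float32')  # stand-in for param.grad
    grad.scale_(1.0 / group.nranks)           # pre-divide so the summed result is the mean
    dist.all_reduce(grad, group=group, use_calc_stream=True)

The legacy _in_legacy_dygraph branch keeps the older behavior of tracing the allreduce and elementwise_div ops directly through the dygraph tracer, as shown in the hunks above.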
@@ -25,8 +25,7 @@ class TestHybridParallel(TestMultipleGpus):
    # check sharding logic as well as the accuracy with single mode
    def test_hybrid_parallel_sharding_logic(self):
        # self.run_mnist_2gpu(
        #     'hybrid_parallel_sharding_model.py')
        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
        self.run_mnist_2gpu(
            'hybrid_parallel_sharding_model.py', eager_mode=False)