Unverified commit 65f705e1, authored by Weilong Wu, committed by GitHub

[Eager] Support sharding_parallel under eager (#42910)

Parent: c0001a24
@@ -162,9 +162,15 @@ def sharding_reduce_gradients(parameter_list, hcg):
         sharding_nrank = hcg.get_sharding_parallel_group().nranks
         for param in parameter_list:
             if param.trainable and (param._grad_ivar() is not None):
-                g_var = param._grad_ivar()
-                # need use trace_op to allreduce
-                # paddle.distributed.all_reduce(
-                #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
+                if in_dygraph_mode():
+                    param.grad.scale_(1.0 / sharding_nrank)
+                    paddle.distributed.all_reduce(
+                        param.grad,
+                        group=hcg.get_sharding_parallel_group(),
+                        use_calc_stream=True)
+                elif _in_legacy_dygraph():
+                    g_var = param._grad_ivar()
+                    # need use trace_op to allreduce
+                    # paddle.distributed.all_reduce(
+                    #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
@@ -178,7 +184,8 @@ def sharding_reduce_gradients(parameter_list, hcg):
-                    })
-                # grad / sharding_rank
-                div_factor = paddle.to_tensor(sharding_nrank, dtype=g_var.dtype)
-                paddle.fluid.framework._dygraph_tracer().trace_op(
-                    type="elementwise_div",
-                    inputs={'X': g_var,
+                        })
+                    # grad / sharding_rank
+                    div_factor = paddle.to_tensor(
+                        sharding_nrank, dtype=g_var.dtype)
+                    paddle.fluid.framework._dygraph_tracer().trace_op(
+                        type="elementwise_div",
+                        inputs={'X': g_var,
...
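
For intuition, the eager branch above pre-scales each local gradient by 1.0 / sharding_nrank and then all-reduces with SUM, which leaves every rank holding the mean gradient, the same result the legacy branch obtains from c_allreduce_sum followed by elementwise_div. Below is a minimal standalone sketch of that arithmetic, using plain numpy rather than Paddle; the 2-rank group size and the gradient values are illustrative only.

import numpy as np

sharding_nrank = 2  # assumed sharding group size, for illustration only

# Hypothetical per-rank gradients for one parameter.
local_grads = [np.array([2.0, 4.0]), np.array([6.0, 8.0])]

# Step 1: every rank scales its local gradient,
# mirroring param.grad.scale_(1.0 / sharding_nrank).
scaled = [g * (1.0 / sharding_nrank) for g in local_grads]

# Step 2: all_reduce with SUM gives every rank the same total.
reduced = np.sum(scaled, axis=0)

# Equivalent to averaging the raw gradients across the group.
assert np.allclose(reduced, np.mean(local_grads, axis=0))
print(reduced)  # [4. 6.]
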
@@ -25,8 +25,7 @@ class TestHybridParallel(TestMultipleGpus):
     # check sharding logic as well as the accuracy with single mode
     def test_hybrid_parallel_sharding_logic(self):
-        # self.run_mnist_2gpu(
-        #     'hybrid_parallel_sharding_model.py')
+        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
         self.run_mnist_2gpu(
             'hybrid_parallel_sharding_model.py', eager_mode=False)
...
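
Taken together, the hunks above make sharding gradient reduction dispatch on the active dygraph flavor at runtime, and the test change re-enables the eager-mode run alongside the legacy one. A condensed sketch of that dispatch follows; the helper function and the import path of the two mode checks are assumptions inferred from the diff, not verified Paddle API.

import paddle
# Assumed import path for the two mode checks used in the patch.
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

def _reduce_one_grad(param, group, nranks):
    # Hypothetical helper mirroring the inner loop of
    # sharding_reduce_gradients after this patch.
    if in_dygraph_mode():
        # Eager path: in-place pre-scale, then SUM across the sharding
        # group, so every rank ends up with the mean gradient.
        param.grad.scale_(1.0 / nranks)
        paddle.distributed.all_reduce(
            param.grad, group=group, use_calc_stream=True)
    elif _in_legacy_dygraph():
        # Legacy path: trace c_allreduce_sum then elementwise_div,
        # as shown in the hunks above.
        pass
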