Unverified commit 65f705e1, authored by Weilong Wu, committed by GitHub

[Eager] Support sharding_parallel under eager (#42910)

Parent c0001a24
...@@ -162,29 +162,36 @@ def sharding_reduce_gradients(parameter_list, hcg):
         sharding_nrank = hcg.get_sharding_parallel_group().nranks
         for param in parameter_list:
             if param.trainable and (param._grad_ivar() is not None):
-
-                g_var = param._grad_ivar()
-
-                # need use trace_op to allreduce
-                # paddle.distributed.all_reduce(
-                #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
-                paddle.fluid.framework._dygraph_tracer().trace_op(
-                    type="c_allreduce_sum",
-                    inputs={'X': g_var},
-                    outputs={'Out': g_var},
-                    attrs={
-                        'ring_id': hcg.get_sharding_parallel_group().id,
-                        'use_calc_stream': True
-                    })
-
-                # grad / sharding_rank
-                div_factor = paddle.to_tensor(sharding_nrank, dtype=g_var.dtype)
-                paddle.fluid.framework._dygraph_tracer().trace_op(
-                    type="elementwise_div",
-                    inputs={'X': g_var,
-                            'Y': div_factor},
-                    outputs={'Out': g_var},
-                    attrs={'axis': -1})
+                if in_dygraph_mode():
+                    param.grad.scale_(1.0 / sharding_nrank)
+                    paddle.distributed.all_reduce(
+                        param.grad,
+                        group=hcg.get_sharding_parallel_group(),
+                        use_calc_stream=True)
+
+                elif _in_legacy_dygraph():
+                    g_var = param._grad_ivar()
+                    # need use trace_op to allreduce
+                    # paddle.distributed.all_reduce(
+                    #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
+                    paddle.fluid.framework._dygraph_tracer().trace_op(
+                        type="c_allreduce_sum",
+                        inputs={'X': g_var},
+                        outputs={'Out': g_var},
+                        attrs={
+                            'ring_id': hcg.get_sharding_parallel_group().id,
+                            'use_calc_stream': True
+                        })
+
+                    # grad / sharding_rank
+                    div_factor = paddle.to_tensor(
+                        sharding_nrank, dtype=g_var.dtype)
+                    paddle.fluid.framework._dygraph_tracer().trace_op(
+                        type="elementwise_div",
+                        inputs={'X': g_var,
+                                'Y': div_factor},
+                        outputs={'Out': g_var},
+                        attrs={'axis': -1})
 
 
 def broadcast_sharding_parameters(model, hcg):
......
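For readers skimming the hunk above: in eager dygraph mode the commit replaces the legacy trace_op calls with an in-place scale followed by an all_reduce over the sharding group, which together average each gradient across the sharding ranks. Below is a minimal standalone sketch of that eager path, not the library code; the function name reduce_gradients_eager and the sharding_group argument are illustrative stand-ins for the hcg-based helper shown in the diff.

import paddle.distributed as dist


def reduce_gradients_eager(parameter_list, sharding_group):
    """Average each trainable parameter's gradient across the sharding group."""
    nranks = sharding_group.nranks
    for param in parameter_list:
        if param.trainable and param.grad is not None:
            # Scale first so the SUM all_reduce below yields the mean gradient.
            param.grad.scale_(1.0 / nranks)
            dist.all_reduce(
                param.grad,
                group=sharding_group,
                use_calc_stream=True)

In the diff itself, the group comes from hcg.get_sharding_parallel_group() and this reduction runs after backward, before the optimizer step.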
...@@ -25,8 +25,7 @@ class TestHybridParallel(TestMultipleGpus):
     # check sharding logic as well as the accuracy with single mode
     def test_hybrid_parallel_sharding_logic(self):
-        # self.run_mnist_2gpu(
-        #     'hybrid_parallel_sharding_model.py')
+        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
         self.run_mnist_2gpu(
             'hybrid_parallel_sharding_model.py', eager_mode=False)
......
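For clarity, this is how the test method reads after the change (reconstructed from the hunk above; imports and the rest of the class are omitted): the previously commented-out eager-mode run is re-enabled alongside the legacy-dygraph run.

class TestHybridParallel(TestMultipleGpus):

    # check sharding logic as well as the accuracy with single mode
    def test_hybrid_parallel_sharding_logic(self):
        # eager dygraph mode (default)
        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
        # legacy dygraph mode
        self.run_mnist_2gpu(
            'hybrid_parallel_sharding_model.py', eager_mode=False)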