Unverified commit 65f705e1, authored by Weilong Wu, committed via GitHub

[Eager] Support sharding_parallel under eager (#42910)

Parent c0001a24
@@ -162,29 +162,36 @@ def sharding_reduce_gradients(parameter_list, hcg):
         sharding_nrank = hcg.get_sharding_parallel_group().nranks
         for param in parameter_list:
             if param.trainable and (param._grad_ivar() is not None):
-                g_var = param._grad_ivar()
-                # need use trace_op to allreduce
-                # paddle.distributed.all_reduce(
-                #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
-                paddle.fluid.framework._dygraph_tracer().trace_op(
-                    type="c_allreduce_sum",
-                    inputs={'X': g_var},
-                    outputs={'Out': g_var},
-                    attrs={
-                        'ring_id': hcg.get_sharding_parallel_group().id,
-                        'use_calc_stream': True
-                    })
-                # grad / sharding_rank
-                div_factor = paddle.to_tensor(sharding_nrank, dtype=g_var.dtype)
-                paddle.fluid.framework._dygraph_tracer().trace_op(
-                    type="elementwise_div",
-                    inputs={'X': g_var,
-                            'Y': div_factor},
-                    outputs={'Out': g_var},
-                    attrs={'axis': -1})
+                if in_dygraph_mode():
+                    param.grad.scale_(1.0 / sharding_nrank)
+                    paddle.distributed.all_reduce(
+                        param.grad,
+                        group=hcg.get_sharding_parallel_group(),
+                        use_calc_stream=True)
+                elif _in_legacy_dygraph():
+                    g_var = param._grad_ivar()
+                    # need use trace_op to allreduce
+                    # paddle.distributed.all_reduce(
+                    #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
+                    paddle.fluid.framework._dygraph_tracer().trace_op(
+                        type="c_allreduce_sum",
+                        inputs={'X': g_var},
+                        outputs={'Out': g_var},
+                        attrs={
+                            'ring_id': hcg.get_sharding_parallel_group().id,
+                            'use_calc_stream': True
+                        })
+                    # grad / sharding_rank
+                    div_factor = paddle.to_tensor(
+                        sharding_nrank, dtype=g_var.dtype)
+                    paddle.fluid.framework._dygraph_tracer().trace_op(
+                        type="elementwise_div",
+                        inputs={'X': g_var,
+                                'Y': div_factor},
+                        outputs={'Out': g_var},
+                        attrs={'axis': -1})
 def broadcast_sharding_parameters(model, hcg):
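Under the eager engine, the per-parameter sync added above reduces to an in-place scale by 1 / sharding_nrank followed by an all_reduce over the sharding group; the trace_op calls remain only on the legacy dygraph path. A minimal sketch of that pattern follows (not part of this commit); it assumes the script is launched collectively on two GPUs (e.g. via paddle.distributed.launch) and that the hybrid communicate group is built by fleet.init with a sharding degree of 2.

# Hedged sketch (not from this commit) of the eager-mode gradient sync that
# sharding_reduce_gradients performs for each trainable parameter.
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

# Assumption: two GPUs, no other parallelism besides sharding.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,
}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_sharding_parallel_group()

# A toy parameter whose gradient comes from an ordinary backward pass.
w = paddle.create_parameter(shape=[2, 2], dtype='float32')
loss = (w * w).sum()
loss.backward()

# The eager path above: average the local gradient, then all-reduce it across
# the sharding group (equivalent to allreduce followed by division by nranks).
w.grad.scale_(1.0 / group.nranks)
dist.all_reduce(w.grad, group=group, use_calc_stream=True)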
@@ -25,8 +25,7 @@ class TestHybridParallel(TestMultipleGpus):
     # check sharding logic as well as the accuracy with single mode
     def test_hybrid_parallel_sharding_logic(self):
-        # self.run_mnist_2gpu(
-        #     'hybrid_parallel_sharding_model.py')
+        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
         self.run_mnist_2gpu(
             'hybrid_parallel_sharding_model.py', eager_mode=False)
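With this change the test covers both execution engines: the first run_mnist_2gpu call exercises the default eager mode enabled by this commit, while the second re-runs the same script with eager_mode=False to keep the legacy dygraph path tested. For reference, a hedged sketch of the mode predicates the library code above branches on (the import location is an assumption; the import changes are not shown in this diff):

# Hedged sketch: distinguishing the eager and legacy dygraph engines.
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode

if in_dygraph_mode():
    print("eager dygraph engine")    # path covered by the first run_mnist_2gpu call
elif _in_legacy_dygraph():
    print("legacy dygraph engine")   # path covered by eager_mode=False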