From 65f705e1011f63c349813d7368d55b35df03ad82 Mon Sep 17 00:00:00 2001
From: Weilong Wu
Date: Mon, 23 May 2022 12:06:39 +0800
Subject: [PATCH] [Eager] Support sharding_parallel under eager (#42910)

---
 .../fleet/utils/hybrid_parallel_util.py       | 53 +++++++++++--------
 ...test_parallel_dygraph_sharding_parallel.py |  3 +-
 2 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 1285e1f332..d0b5c915e1 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -162,29 +162,36 @@ def sharding_reduce_gradients(parameter_list, hcg):
     sharding_nrank = hcg.get_sharding_parallel_group().nranks
     for param in parameter_list:
         if param.trainable and (param._grad_ivar() is not None):
-
-            g_var = param._grad_ivar()
-
-            # need use trace_op to allreduce
-            # paddle.distributed.all_reduce(
-            #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
-            paddle.fluid.framework._dygraph_tracer().trace_op(
-                type="c_allreduce_sum",
-                inputs={'X': g_var},
-                outputs={'Out': g_var},
-                attrs={
-                    'ring_id': hcg.get_sharding_parallel_group().id,
-                    'use_calc_stream': True
-                })
-
-            # grad / sharding_rank
-            div_factor = paddle.to_tensor(sharding_nrank, dtype=g_var.dtype)
-            paddle.fluid.framework._dygraph_tracer().trace_op(
-                type="elementwise_div",
-                inputs={'X': g_var,
-                        'Y': div_factor},
-                outputs={'Out': g_var},
-                attrs={'axis': -1})
+            if in_dygraph_mode():
+                param.grad.scale_(1.0 / sharding_nrank)
+                paddle.distributed.all_reduce(
+                    param.grad,
+                    group=hcg.get_sharding_parallel_group(),
+                    use_calc_stream=True)
+
+            elif _in_legacy_dygraph():
+                g_var = param._grad_ivar()
+                # need use trace_op to allreduce
+                # paddle.distributed.all_reduce(
+                #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
+                paddle.fluid.framework._dygraph_tracer().trace_op(
+                    type="c_allreduce_sum",
+                    inputs={'X': g_var},
+                    outputs={'Out': g_var},
+                    attrs={
+                        'ring_id': hcg.get_sharding_parallel_group().id,
+                        'use_calc_stream': True
+                    })
+
+                # grad / sharding_rank
+                div_factor = paddle.to_tensor(
+                    sharding_nrank, dtype=g_var.dtype)
+                paddle.fluid.framework._dygraph_tracer().trace_op(
+                    type="elementwise_div",
+                    inputs={'X': g_var,
+                            'Y': div_factor},
+                    outputs={'Out': g_var},
+                    attrs={'axis': -1})
 
 
 def broadcast_sharding_parameters(model, hcg):
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py
index e12d1826f2..503bd9d0f9 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py
@@ -25,8 +25,7 @@ class TestHybridParallel(TestMultipleGpus):
 
     # check sharding logic as well as the accuracy with single mode
    def test_hybrid_parallel_sharding_logic(self):
-        # self.run_mnist_2gpu(
-        #     'hybrid_parallel_sharding_model.py')
+        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
         self.run_mnist_2gpu(
             'hybrid_parallel_sharding_model.py', eager_mode=False)
 
-- 
GitLab
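
A minimal sketch (not part of the patch) of the eager-mode path this change adds to sharding_reduce_gradients, shown as a standalone helper for readers who want the data flow without the diff context: scale each trainable parameter's gradient in place by 1 / nranks of the sharding group, then all-reduce it over that group. The helper name eager_sharding_reduce_gradients is hypothetical; it assumes fleet/process-group initialization and a completed backward pass have already happened, and it uses the same Paddle 2.3-era paddle.distributed.all_reduce signature (with use_calc_stream) that the patch itself calls.

# Sketch only; names other than the Paddle APIs used in the patch are assumptions.
import paddle
import paddle.distributed as dist


def eager_sharding_reduce_gradients(parameter_list, hcg):
    """Average gradients across the sharding parallel group (eager mode)."""
    group = hcg.get_sharding_parallel_group()
    for param in parameter_list:
        if param.trainable and param.grad is not None:
            # Scale in place first so the sum all-reduce below yields the mean.
            param.grad.scale_(1.0 / group.nranks)
            # Synchronous sum all-reduce on the calc stream, as in the patch.
            dist.all_reduce(param.grad, group=group, use_calc_stream=True)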