diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index 02c8ca6092aa840032585abe99ae1d05a2a832ac..63d261e2e3dfe1667039551a042b2e0ea0e10d92 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 ######
-
+import os
 from functools import reduce
 
 import paddle
@@ -23,6 +23,16 @@ from paddle.distributed import fleet
 from ...utils.log_util import logger
 from ...utils.tensor_fusion_helper import fused_parameters
 
+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
+logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+
+if g_shard_norm_align_dp:
+    assert (
+        not g_shard_use_reduce
+    ), "g_shard_norm_align_dp is not supported if g_shard_use_reduce is true"
+
 
 def _is_trainable(param):
     return not param.stop_gradient
@@ -203,18 +213,22 @@ class DygraphShardingOptimizer:
             if g_var is not None:
                 g_var.scale_(1.0 / sharding_nrank)
                 param_rank = self._param2rank[param.name]
-                paddle.distributed.all_reduce(
-                    g_var,
-                    group=hcg.get_sharding_parallel_group(),
-                    sync_op=True,
-                )
-                # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
-                # paddle.distributed.reduce(
-                #     g_var,
-                #     dst=hcg.get_sharding_parallel_group().ranks[param_rank],
-                #     group=hcg.get_sharding_parallel_group(),
-                #     sync_op=True,
-                # )
+                if not g_shard_use_reduce:
+                    paddle.distributed.all_reduce(
+                        g_var,
+                        group=hcg.get_sharding_parallel_group(),
+                        sync_op=True,
+                    )
+                else:
+                    # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
+                    paddle.distributed.reduce(
+                        g_var,
+                        dst=hcg.get_sharding_parallel_group().ranks[
+                            param_rank
+                        ],
+                        group=hcg.get_sharding_parallel_group(),
+                        sync_op=True,
+                    )
 
     def _sharding_sync_parameters(self):
         """
@@ -294,11 +308,11 @@ class DygraphShardingOptimizer:
                 if hasattr(param, "main_grad") and param.main_grad is not None:
                     grad_var = param.main_grad
                 params_grads.append((param, grad_var))
-            if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'):
-                self._inner_opt._grad_clip.not_sharding_stage1 = False
-            params_grads = self._inner_opt._grad_clip(params_grads)
-            # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
-            self._set_inner_opt_attr('_grad_clip', None)
+
+            if g_shard_norm_align_dp:
+                params_grads = self._inner_opt._grad_clip(params_grads)
+                # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
+                self._set_inner_opt_attr('_grad_clip', None)
             rank_params = (
                 self._rank2params[self._sharding_rank]
                 if not self.tensor_fusion
@@ -313,8 +327,9 @@ class DygraphShardingOptimizer:
                 startup_program=None,
                 params_grads=update_params_grads,
             )
-            # restore the grad clip
-            self._set_inner_opt_attr('_grad_clip', origin_clip)
+            if g_shard_norm_align_dp:
+                # restore the grad clip
+                self._set_inner_opt_attr('_grad_clip', origin_clip)
 
         # sync parameters across sharding ranks
         self._sharding_sync_parameters()
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index d8bf0510712debb82039b9003e958921dfb9c3f8..b24247b580766dc3968c9e4db5f8a257065710b8 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import numpy as np
 
 import paddle
@@ -37,6 +39,9 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer
 
 __all__ = []
 
+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
+logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+
 
 class HybridParallelClipGrad:
     def __init__(self, clip, hcg):
@@ -44,6 +49,42 @@ class HybridParallelClipGrad:
         self._hcg = hcg
         self.not_sharding_stage1 = True
 
+    def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
+        # sharding first
+        sharding_flag = (
+            self._hcg.get_sharding_parallel_world_size() > 1
+            and self._hcg.get_data_parallel_world_size() == 1
+        )
+        mp_flag = self._hcg.get_model_parallel_world_size() > 1
+
+        # add all reduce to get global norm of distributed params_and_grads
+        if sharding_flag and not g_shard_norm_align_dp:
+            # norm of mp distributed variable
+            if mp_flag:
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_sharding_parallel_group(),
+                )
+            # not dist only reduce among sharding group and pp group later
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_sharding_parallel_group(),
+            )
+        # norm of mp distributed variable
+        if mp_flag:
+            # dist should reduce among sharding group, mp group and pp group
+            paddle.distributed.all_reduce(
+                global_norm_var_dist,
+                group=self._hcg.get_check_parallel_group(sharding_flag),
+            )
+
+        # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
+        if self._hcg.get_pipe_parallel_world_size() > 1:
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_pipe_parallel_group(),
+            )
+
     @no_grad()
     def _dygraph_clip(self, params_grads):
         sum_square_dist_fp16 = []
@@ -157,37 +198,7 @@ class HybridParallelClipGrad:
             + global_norm_not_dist_fp32
         )
 
-        # add all reduce to get global norm of distributed params_and_grads
-        if self._hcg.get_model_parallel_world_size() > 1:
-            sharding_flag = False
-            if (
-                self._hcg.get_sharding_parallel_world_size() > 1
-                and self._hcg.get_data_parallel_world_size() == 1
-            ):
-                sharding_flag = True
-            paddle.distributed.all_reduce(
-                global_norm_var_dist,
-                group=self._hcg.get_check_parallel_group(sharding_flag),
-            )
-
-        # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
-        if self._hcg.get_pipe_parallel_world_size() > 1:
-            paddle.distributed.all_reduce(
-                global_norm_var_not_dist,
-                group=self._hcg.get_pipe_parallel_group(),
-            )
-
-        # In Sharding mode, param and grad is mapping different rank in optimizer.
-        # ClipGradByGlobalNorm need allreduce to get globol norm
-        # TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
-        if (
-            self._hcg.get_sharding_parallel_world_size() > 1
-            and self.not_sharding_stage1
-        ):
-            paddle.distributed.all_reduce(
-                global_norm_var_not_dist,
-                group=self._hcg.get_sharding_parallel_group(),
-            )
+        self._global_norm(global_norm_var_dist, global_norm_var_not_dist)
 
         global_norm_var_fp32 = paddle.sqrt(
             global_norm_var_dist + global_norm_var_not_dist
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 1dcb84c66ac8191bfcb59834b5843eca462d3454..6644e2a06e5fe91e7402b5fbf3d572808267c90b 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -44,6 +44,9 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer
 
 __all__ = []
 
+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+
 # assume only the first stage and last stage need data, and data consumption is ordred
 # to be replaced by real micro dataset from reader
 
@@ -299,8 +302,12 @@ class PipelineParallel(MetaParallelBase):
             assert hasattr(self, "optimizer")
             assert hasattr(self.optimizer, "_param2rank")
             _param2rank = self.optimizer._param2rank
-
-        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
+        # Note: after sharding fully switches to the reduce operation, this can be cleaned up
+        act = (
+            HOOK_ACTION.ALL_REDUCE
+            if (dp or not g_shard_use_reduce)
+            else HOOK_ACTION.REDUCE
+        )
 
         for model in models:
             # For virtual pipeline. Will separate parameters in different chunk into
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index e7d31b4aebd07a6f619feefe7c0fa3cb86f5b73b..6c8e2fd9dc3aa349580ba94463ea836ac0f5b3f8 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -207,21 +207,19 @@ class FusedCommBuffer:
 
     def _comm_grads(self):
         assert self._all_params_checked_in
-        # Note: after sharding change to reduce operation here also need to be updated
-        # if self._act == HOOK_ACTION.ALL_REDUCE:
-        #     task = paddle.distributed.all_reduce(
-        #         self.grad_storage, group=self._comm_group, sync_op=False
-        #     )
-        # elif self._act == HOOK_ACTION.REDUCE:
-        #     task = paddle.distributed.reduce(
-        #         self.grad_storage,
-        #         dst=self._dst,
-        #         group=self._comm_group,
-        #         sync_op=False,
-        #     )
-        task = paddle.distributed.all_reduce(
-            self.grad_storage, group=self._comm_group, sync_op=False
-        )
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            task = paddle.distributed.all_reduce(
+                self.grad_storage, group=self._comm_group, sync_op=False
+            )
+
+        elif self._act == HOOK_ACTION.REDUCE:
+            task = paddle.distributed.reduce(
+                self.grad_storage,
+                dst=self._dst,
+                group=self._comm_group,
+                sync_op=False,
+            )
+
         self._task = task
 
     @imperative_base.no_grad
diff --git a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
index f3dbfb63bce36b65c7d51d022d94beef6bf898d1..cd2398f59af59ef10acaabfccbe7811069d6b960 100644
--- a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
+++ b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import unittest
 
 from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
@@ -20,6 +21,13 @@ from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
 class TestHybridParallel(TestMultipleGpus):
     # check sharding logic as well as the accuracy with single mode
     def test_hybrid_parallel_sharding_logic(self):
+        # test shard grad reduce
+        os.environ["FLAGS_shard_use_reduce"] = "1"
+        os.environ["FLAGS_shard_norm_align_dp"] = "0"
+        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
+        # test shard grad allreduce
+        os.environ["FLAGS_shard_use_reduce"] = "0"
+        os.environ["FLAGS_shard_norm_align_dp"] = "1"
         self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
 
     def test_hybrid_parallel_sharding_tensor_fusion(self):
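
Usage note (editor's sketch, not part of the diff): both flags are read once at module import time via os.environ.get, so they have to be present in the environment before the patched modules are imported; the updated test above does exactly that by assigning os.environ before run_mnist_2gpu. The sketch below shows how a training script might opt in to the new reduce path. The surrounding assumptions (that the launcher propagates the parent environment to worker processes, and that paddle.distributed.fleet is imported only after the flags are set) are mine, not part of this change.

    # Sketch: enable reduce-based sharding gradient sync (FLAGS_shard_use_reduce=1).
    # The module-level assert added in dygraph_sharding_optimizer.py forbids combining
    # it with FLAGS_shard_norm_align_dp, so the dp-aligned norm path is turned off.
    import os

    os.environ["FLAGS_shard_use_reduce"] = "1"
    os.environ["FLAGS_shard_norm_align_dp"] = "0"

    import paddle.distributed.fleet as fleet  # noqa: E402, import after setting the flags

With the reduce path enabled, each fused gradient is sent only to the sharding rank that owns the corresponding parameters (HOOK_ACTION.REDUCE, i.e. paddle.distributed.reduce with a dst rank), which is why HybridParallelClipGrad gains the _global_norm all_reduce over the sharding group above: every rank still needs the full global norm for clipping.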