From 9daba606d9e6c3c0b8f6acdb5fba505240e5bc2b Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Sat, 22 Jul 2023 10:20:15 +0800
Subject: [PATCH] make sharding reduce mode by default (#55529)

* make sharding reduce mode by default

* Update dygraph_sharding_optimizer.py

* Update hybrid_parallel_optimizer.py

* Update pipeline_parallel.py
---
 .../dygraph_optimizer/dygraph_sharding_optimizer.py      | 6 ++----
 .../dygraph_optimizer/hybrid_parallel_optimizer.py       | 3 +--
 .../distributed/fleet/meta_parallel/pipeline_parallel.py | 3 +--
 3 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index e20e988f916..a2f6ba3d932 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -24,10 +24,8 @@ from paddle.fluid.dygraph import base as imperative_base
 
 from ...utils.log_util import logger
 
-g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
-logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
-g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
-logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0))
 
 if g_shard_norm_align_dp:
     assert (
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index 5fc568d4f8e..a903d8bdaa5 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -41,8 +41,7 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer
 
 __all__ = []
 
-g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
-logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0))
 
 
 class HybridParallelClipGrad:
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 12c8055e7bd..3710a9014c4 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -31,8 +31,7 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
 
 __all__ = []
 
-g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
-logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
 
 
 # assume only the first stage and last stage need data, and data consumption are ordred;
--
GitLab
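
Note: the patch changes default values rather than removing the flags. With the environment variables unset, `FLAGS_shard_use_reduce` now resolves to 1 and `FLAGS_shard_norm_align_dp` to 0, and the corresponding `logger.info` lines are dropped. A minimal sketch of the flag-reading pattern after this change, mirroring the lines in the diff (the surrounding module code is omitted):

```python
import os

# Module-level flags as read after this patch; the previous defaults are noted.
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))        # was 0 before the patch
g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0))  # was 1 before the patch

# To restore the pre-patch behaviour, export the flags before launching, e.g.:
#   FLAGS_shard_use_reduce=0 FLAGS_shard_norm_align_dp=1 \
#       python -m paddle.distributed.launch train.py
# (train.py is a hypothetical entry point used only for illustration.)
```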