Unverified commit 9daba606, authored by sneaxiy, committed by GitHub

make sharding reduce mode by default (#55529)

* make sharding reduce mode by default

* Update dygraph_sharding_optimizer.py

* Update hybrid_parallel_optimizer.py

* Update pipeline_parallel.py
Parent 8d42540f
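
This commit flips two environment-flag defaults that are read once at module import time: FLAGS_shard_use_reduce goes from 0 to 1 (sharding now synchronizes gradients in reduce mode by default) and FLAGS_shard_norm_align_dp goes from 1 to 0; the logger.info lines that echoed the flag values are dropped. A minimal sketch of how a run could opt back into the old behavior, assuming the flags are still read from os.environ exactly as in the hunks below:

```python
import os

# Hypothetical opt-out: restore the pre-#55529 defaults. This must run
# before the fleet modules are imported, because each flag is read once
# at module import time.
os.environ["FLAGS_shard_use_reduce"] = "0"      # fall back to allreduce
os.environ["FLAGS_shard_norm_align_dp"] = "1"   # re-align grad norm with pure DP

import paddle.distributed.fleet as fleet  # noqa: E402 -- import after env setup
```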
dygraph_sharding_optimizer.py
@@ -24,10 +24,8 @@ from paddle.fluid.dygraph import base as imperative_base
 from ...utils.log_util import logger

-g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
-logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
-g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
-logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0))

 if g_shard_norm_align_dp:
     assert (
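
With the new defaults the `if g_shard_norm_align_dp:` guard above stays inactive. The first flag is the substantive change: in sharded training only the rank that owns a parameter's slice of the optimizer state needs that parameter's summed gradient, so a reduce to the owner can replace an allreduce across the group. A conceptual sketch of the distinction using paddle.distributed collectives; the sync_grad helper is hypothetical, for illustration, not the code path this file actually takes:

```python
import paddle.distributed as dist

def sync_grad(grad, dst_rank, group, use_reduce=True):
    # Hypothetical illustration of what the flag selects between.
    if use_reduce:
        # Reduce mode: the summed gradient lands only on the rank that
        # owns this parameter's shard of the optimizer state.
        dist.reduce(grad, dst=dst_rank, op=dist.ReduceOp.SUM, group=group)
    else:
        # Old default: every rank in the group receives the full sum,
        # which costs roughly twice the communication volume.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
```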
hybrid_parallel_optimizer.py
@@ -41,8 +41,7 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer

 __all__ = []

-g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
-logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0))

 class HybridParallelClipGrad:
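
The flag's name and its placement directly above HybridParallelClipGrad suggest it controls whether the gradient-clipping norm is computed so as to match pure data parallelism; with the new 0 default that path is off unless explicitly requested. Since every switch here goes through `int(os.environ.get(...))`, its parsing semantics are easy to pin down; a self-contained sketch of that pattern (flag name reused from the diff, helper name hypothetical):

```python
import os

def read_flag(name: str, default: int) -> int:
    # Same pattern as the diff: unset -> default, set -> int() of the string.
    # Note that int("") or int("true") would raise ValueError; only values
    # like "0" or "1" are expected.
    return int(os.environ.get(name, default))

os.environ.pop("FLAGS_shard_norm_align_dp", None)
assert read_flag("FLAGS_shard_norm_align_dp", 0) == 0   # new default: off

os.environ["FLAGS_shard_norm_align_dp"] = "1"           # explicit opt-in
assert read_flag("FLAGS_shard_norm_align_dp", 0) == 1
```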
pipeline_parallel.py
@@ -31,8 +31,7 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

 __all__ = []

-g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
-logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))

 # assume only the first stage and last stage need data, and data consumption are ordred;
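
pipeline_parallel.py reads the same FLAGS_shard_use_reduce, so pipeline-parallel gradient synchronization (the FusedCommBuffer machinery imported above) and the sharding optimizer see one consistent mode; flipping the default in all three files at once keeps them in agreement. Because each module reads the flag independently at import time, it has to be set before the first of them is imported. A hypothetical sanity check mirroring those reads:

```python
import os

def assert_reduce_mode(expected: bool = True) -> None:
    # Hypothetical helper: verify what the process-wide flag will resolve
    # to, mirroring the int(os.environ.get(...)) reads in all three modules.
    actual = bool(int(os.environ.get("FLAGS_shard_use_reduce", 1)))
    assert actual == expected, (
        f"FLAGS_shard_use_reduce resolves to {actual}; export it before "
        "importing the fleet modules to change the mode."
    )

assert_reduce_mode()  # passes under the new default
```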