Unverified commit 284e0d12, authored by zhenhailiu, committed by GitHub

shard grad reduce (#55495)

Parent 5f376f00
@@ -13,7 +13,7 @@
# limitations under the License.
######
import os
from functools import reduce
import paddle
@@ -23,6 +23,16 @@ from paddle.distributed import fleet
from ...utils.log_util import logger
from ...utils.tensor_fusion_helper import fused_parameters
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
if g_shard_norm_align_dp:
assert (
not g_shard_use_reduce
), "g_shard_norm_align_dp is not support if g_shard_use_reduce is true"
def _is_trainable(param):
return not param.stop_gradient
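The two flags added above are read once at import time: FLAGS_shard_use_reduce switches the per-parameter gradient sync from all_reduce to reduce, while FLAGS_shard_norm_align_dp keeps the global-norm calculation identical to pure data parallelism, which only works with the all_reduce path; hence the assertion that both cannot be enabled together. A minimal sketch of the valid combinations, assuming only what is visible in this diff (check_flags is an illustrative helper, not part of the change):

# Minimal sketch of the flag semantics visible in this diff (not Paddle code).
def check_flags(shard_use_reduce, shard_norm_align_dp):
    if shard_norm_align_dp:
        assert not shard_use_reduce, (
            "g_shard_norm_align_dp is not supported when g_shard_use_reduce is true"
        )

check_flags(shard_use_reduce=0, shard_norm_align_dp=1)  # default combination
check_flags(shard_use_reduce=1, shard_norm_align_dp=0)  # new reduce-based path
# check_flags(shard_use_reduce=1, shard_norm_align_dp=1)  # would raise AssertionError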
@@ -203,18 +213,22 @@ class DygraphShardingOptimizer:
if g_var is not None:
g_var.scale_(1.0 / sharding_nrank)
param_rank = self._param2rank[param.name]
paddle.distributed.all_reduce(
g_var,
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
# TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
# paddle.distributed.reduce(
# g_var,
# dst=hcg.get_sharding_parallel_group().ranks[param_rank],
# group=hcg.get_sharding_parallel_group(),
# sync_op=True,
# )
if not g_shard_use_reduce:
paddle.distributed.all_reduce(
g_var,
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
else:
# TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
paddle.distributed.reduce(
g_var,
dst=hcg.get_sharding_parallel_group().ranks[
param_rank
],
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
def _sharding_sync_parameters(self):
"""
@@ -294,11 +308,11 @@ class DygraphShardingOptimizer:
if hasattr(param, "main_grad") and param.main_grad is not None:
grad_var = param.main_grad
params_grads.append((param, grad_var))
if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'):
self._inner_opt._grad_clip.not_sharding_stage1 = False
params_grads = self._inner_opt._grad_clip(params_grads)
# set inner_opt._grad_clip to None to avoid clipping the gradients again inside inner_opt._apply_optimize
self._set_inner_opt_attr('_grad_clip', None)
if g_shard_norm_align_dp:
params_grads = self._inner_opt._grad_clip(params_grads)
# set inner_opt._grad_clip to None to avoid clipping the gradients again inside inner_opt._apply_optimize
self._set_inner_opt_attr('_grad_clip', None)
rank_params = (
self._rank2params[self._sharding_rank]
if not self.tensor_fusion
@@ -313,8 +327,9 @@
startup_program=None,
params_grads=update_params_grads,
)
# restore the grad clip
self._set_inner_opt_attr('_grad_clip', origin_clip)
if g_shard_norm_align_dp:
# restore the grad clip
self._set_inner_opt_attr('_grad_clip', origin_clip)
# sync parameters across sharding ranks
self._sharding_sync_parameters()
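Under sharding, _param2rank assigns every parameter to exactly one rank, and only that rank applies the optimizer step to it. That is why, with FLAGS_shard_use_reduce enabled, a reduce to the owning rank is sufficient where an all_reduce was used before. A framework-free sketch of the equivalence as seen by the owning rank (simulate_all_reduce and simulate_reduce are illustrative stand-ins, not Paddle APIs):

# Illustrative simulation, not Paddle code: 2 sharding ranks, 2 parameters.
grads = {  # per-rank local gradients for every parameter
    "w0": [1.0, 3.0],  # rank 0 holds 1.0, rank 1 holds 3.0
    "w1": [2.0, 4.0],
}
param2rank = {"w0": 0, "w1": 1}  # the rank that will update each parameter

def simulate_all_reduce(values):
    s = sum(values)
    return [s for _ in values]  # every rank ends up with the full sum

def simulate_reduce(values, dst):
    s = sum(values)
    return [s if r == dst else v for r, v in enumerate(values)]  # only dst gets the sum

# all_reduce path (FLAGS_shard_use_reduce=0): every rank gets every summed grad
all_reduced = {p: simulate_all_reduce(v) for p, v in grads.items()}
# reduce path (FLAGS_shard_use_reduce=1): only the owning rank gets the summed grad
reduced = {p: simulate_reduce(v, param2rank[p]) for p, v in grads.items()}

for p, owner in param2rank.items():
    # the rank that actually updates p sees the same gradient on both paths
    assert reduced[p][owner] == all_reduced[p][owner]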
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
@@ -37,6 +39,9 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer
__all__ = []
g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
class HybridParallelClipGrad:
def __init__(self, clip, hcg):
@@ -44,6 +49,42 @@ class HybridParallelClipGrad:
self._hcg = hcg
self.not_sharding_stage1 = True
def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
# sharding first
sharding_flag = (
self._hcg.get_sharding_parallel_world_size() > 1
and self._hcg.get_data_parallel_world_size() == 1
)
mp_flag = self._hcg.get_model_parallel_world_size() > 1
# add all reduce to get global norm of distributed params_and_grads
if sharding_flag and not g_shard_norm_align_dp:
# norm of mp distributed variable
if mp_flag:
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_sharding_parallel_group(),
)
# non-distributed vars are reduced only within the sharding group here and within the pp group later
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group(),
)
# norm of mp distributed variable
if mp_flag:
# dist vars should be reduced across the sharding group, mp group and pp group
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)
# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
if self._hcg.get_pipe_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group(),
)
@no_grad()
def _dygraph_clip(self, params_grads):
sum_square_dist_fp16 = []
@@ -157,37 +198,7 @@
+ global_norm_not_dist_fp32
)
# add all reduce to get global norm of distributed params_and_grads
if self._hcg.get_model_parallel_world_size() > 1:
sharding_flag = False
if (
self._hcg.get_sharding_parallel_world_size() > 1
and self._hcg.get_data_parallel_world_size() == 1
):
sharding_flag = True
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)
# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
if self._hcg.get_pipe_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group(),
)
# In sharding mode, params and grads are mapped to different ranks in the optimizer.
# ClipGradByGlobalNorm needs an allreduce to get the global norm
# TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
if (
self._hcg.get_sharding_parallel_world_size() > 1
and self.not_sharding_stage1
):
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group(),
)
self._global_norm(global_norm_var_dist, global_norm_var_not_dist)
global_norm_var_fp32 = paddle.sqrt(
global_norm_var_dist + global_norm_var_not_dist
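With the reduce path, a sharding rank only holds fully summed gradients for its own parameters, so HybridParallelClipGrad can no longer reproduce the data-parallel global norm locally; the new _global_norm helper therefore all-reduces the per-rank sums of squares over the sharding group (and the mp/pp groups) before _dygraph_clip takes the square root. A small numpy check of the identity this relies on, using only standard numpy calls:

import numpy as np

rng = np.random.default_rng(0)
full_grad = rng.standard_normal(8)

# Shard the flattened gradient across two sharding ranks.
shards = np.split(full_grad, 2)

# Each rank contributes the sum of squares of the gradients it owns ...
partial_sq = [float(np.sum(s * s)) for s in shards]
# ... and an all_reduce(SUM) over the sharding group adds them up.
global_sq = sum(partial_sq)

# sqrt of the reduced sum equals the norm of the unsharded gradient,
# which is what gradient clipping sees under plain data parallelism.
assert np.isclose(np.sqrt(global_sq), np.linalg.norm(full_grad))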
@@ -44,6 +44,9 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer
__all__ = []
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
# assume only the first stage and last stage need data, and data consumption is ordered
# to be replaced by real micro dataset from reader
@@ -299,8 +302,12 @@ class PipelineParallel(MetaParallelBase):
assert hasattr(self, "optimizer")
assert hasattr(self.optimizer, "_param2rank")
_param2rank = self.optimizer._param2rank
act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
# Note: once sharding fully switches to the reduce operation, this fallback needs to be cleaned up
act = (
HOOK_ACTION.ALL_REDUCE
if (dp or not g_shard_use_reduce)
else HOOK_ACTION.REDUCE
)
for model in models:
# For virtual pipeline. Will separate parameters in different chunk into
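The hook action selected above reduces to a simple rule: pure data parallelism, or sharding with FLAGS_shard_use_reduce disabled, keeps ALL_REDUCE; only sharding with the flag enabled switches the fused buffer to REDUCE. A minimal restatement of that rule (the constants below are stand-ins, not imported from Paddle):

ALL_REDUCE, REDUCE = 0, 1  # stand-ins for HOOK_ACTION.ALL_REDUCE / HOOK_ACTION.REDUCE

def select_hook_action(dp, shard_use_reduce):
    # mirrors: act = ALL_REDUCE if (dp or not g_shard_use_reduce) else REDUCE
    return ALL_REDUCE if (dp or not shard_use_reduce) else REDUCE

assert select_hook_action(dp=True, shard_use_reduce=True) == ALL_REDUCE
assert select_hook_action(dp=False, shard_use_reduce=False) == ALL_REDUCE
assert select_hook_action(dp=False, shard_use_reduce=True) == REDUCE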
@@ -207,21 +207,19 @@ class FusedCommBuffer:
def _comm_grads(self):
assert self._all_params_checked_in
# Note: after sharding changes to the reduce operation, this also needs to be updated
# if self._act == HOOK_ACTION.ALL_REDUCE:
# task = paddle.distributed.all_reduce(
# self.grad_storage, group=self._comm_group, sync_op=False
# )
# elif self._act == HOOK_ACTION.REDUCE:
# task = paddle.distributed.reduce(
# self.grad_storage,
# dst=self._dst,
# group=self._comm_group,
# sync_op=False,
# )
task = paddle.distributed.all_reduce(
self.grad_storage, group=self._comm_group, sync_op=False
)
if self._act == HOOK_ACTION.ALL_REDUCE:
task = paddle.distributed.all_reduce(
self.grad_storage, group=self._comm_group, sync_op=False
)
elif self._act == HOOK_ACTION.REDUCE:
task = paddle.distributed.reduce(
self.grad_storage,
dst=self._dst,
group=self._comm_group,
sync_op=False,
)
self._task = task
@imperative_base.no_grad
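FusedCommBuffer packs many parameter gradients into a single contiguous grad_storage so that one collective call covers the whole bucket; the change above merely lets that call be either all_reduce or reduce(dst=...) depending on the registered hook action. A framework-free sketch of the packing idea (shapes and values are made up for illustration):

import numpy as np

# Two small per-parameter gradients on one rank.
grad_a = np.array([1.0, 2.0])
grad_b = np.array([3.0, 4.0, 5.0])

# "Fuse" them into one flat storage so a single collective moves everything at once.
grad_storage = np.concatenate([grad_a, grad_b])

# Pretend three ranks produced such storages; an all_reduce sums them element-wise,
# while a reduce would leave this sum only on the destination rank.
storages = [grad_storage, grad_storage * 2, grad_storage * 3]
summed = np.sum(storages, axis=0)

# Unpack the fused result back into per-parameter views.
fused_a, fused_b = summed[:2], summed[2:]
assert np.allclose(fused_a, grad_a * 6) and np.allclose(fused_b, grad_b * 6)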
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
@@ -20,6 +21,13 @@ from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestHybridParallel(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_hybrid_parallel_sharding_logic(self):
# test shard grad reduce
os.environ["FLAGS_shard_use_reduce"] = "1"
os.environ["FLAGS_shard_norm_align_dp"] = "0"
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
# test shard grad allreduce
os.environ["FLAGS_shard_use_reduce"] = "0"
os.environ["FLAGS_shard_norm_align_dp"] = "1"
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
def test_hybrid_parallel_sharding_tensor_fusion(self):
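The test exercises both code paths by toggling the environment variables before launching hybrid_parallel_sharding_model.py on two GPUs. Since g_shard_use_reduce and g_shard_norm_align_dp are read with os.environ.get at module import time, the variables must be set before the training process imports the sharding modules; a minimal sketch of that pattern (the script name is taken from the test, everything else is illustrative):

import os
import subprocess
import sys

# Set the flags first; the sharding modules read them when they are imported.
env = dict(os.environ, FLAGS_shard_use_reduce="1", FLAGS_shard_norm_align_dp="0")

# Launch the training script in a fresh process so the flags are visible
# at import time inside that process.
subprocess.run(
    [sys.executable, "hybrid_parallel_sharding_model.py"], env=env, check=False
)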