Unverified commit 64f25adf, authored by zhenhailiu, committed by GitHub

new_frl_shard_reduce (#55353)

* new_frl_shard_reduce

* add mp guard

* add test
Parent e1545af4
@@ -24,6 +24,16 @@ from paddle.fluid.dygraph import base as imperative_base
 from ...utils.log_util import logger

+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
+logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+if g_shard_norm_align_dp:
+    assert (
+        not g_shard_use_reduce
+    ), "g_shard_norm_align_dp is not support if g_shard_use_reduce is true"
+

 def _is_trainable(param):
     return not param.stop_gradient
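For anyone trying the patch locally, here is a minimal sketch of the two flag combinations the new import-time guard allows. The flag names come from the diff above; the chosen values, and the note that they must be set before the optimizer modules are imported (the flags are read at import time), are my reading of the code rather than something stated in the commit.

```python
import os

# New path added by this PR: reduce each gradient only to the rank that
# owns the parameter in the sharding group.
os.environ["FLAGS_shard_use_reduce"] = "1"
os.environ["FLAGS_shard_norm_align_dp"] = "0"

# Legacy/default path: all_reduce gradients so the global norm is computed
# the same way as under plain data parallelism.
# os.environ["FLAGS_shard_use_reduce"] = "0"
# os.environ["FLAGS_shard_norm_align_dp"] = "1"

# Mirrors the assert added above: aligning the norm with DP requires the
# all_reduce path, so the two flags cannot both be enabled.
use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
if norm_align_dp:
    assert not use_reduce, "norm alignment with DP requires the all_reduce path"
```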
@@ -160,18 +170,22 @@ class DygraphShardingOptimizer:
                 if g_var is not None:
                     g_var.scale_(1.0 / sharding_nrank)
                     param_rank = self._param2rank[param.name]
-                    paddle.distributed.all_reduce(
-                        g_var,
-                        group=hcg.get_sharding_parallel_group(),
-                        sync_op=True,
-                    )
-                    # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
-                    # paddle.distributed.reduce(
-                    #     g_var,
-                    #     dst=hcg.get_sharding_parallel_group().ranks[param_rank],
-                    #     group=hcg.get_sharding_parallel_group(),
-                    #     sync_op=True,
-                    # )
+                    if not g_shard_use_reduce:
+                        paddle.distributed.all_reduce(
+                            g_var,
+                            group=hcg.get_sharding_parallel_group(),
+                            sync_op=True,
+                        )
+                    else:
+                        # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
+                        paddle.distributed.reduce(
+                            g_var,
+                            dst=hcg.get_sharding_parallel_group().ranks[
+                                param_rank
+                            ],
+                            group=hcg.get_sharding_parallel_group(),
+                            sync_op=True,
+                        )

     def _sharding_sync_parameters(self):
         """
@@ -247,8 +261,7 @@ class DygraphShardingOptimizer:
             if hasattr(param, "main_grad") and param.main_grad is not None:
                 grad_var = param.main_grad
             params_grads.append((param, grad_var))
-        if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'):
-            self._inner_opt._grad_clip.not_sharding_stage1 = False
-        params_grads = self._inner_opt._grad_clip(params_grads)
-        # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
-        self._set_inner_opt_attr('_grad_clip', None)
+        if g_shard_norm_align_dp:
+            params_grads = self._inner_opt._grad_clip(params_grads)
+            # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
+            self._set_inner_opt_attr('_grad_clip', None)
@@ -263,6 +276,7 @@ class DygraphShardingOptimizer:
             startup_program=None,
             params_grads=update_params_grads,
         )
-        # restore the grad clip
-        self._set_inner_opt_attr('_grad_clip', origin_clip)
+        if g_shard_norm_align_dp:
+            # restore the grad clip
+            self._set_inner_opt_attr('_grad_clip', origin_clip)
......
@@ -41,12 +41,14 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer

 __all__ = []

+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
+logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+

 class HybridParallelClipGrad:
     def __init__(self, clip, hcg):
         self._clip = clip
         self._hcg = hcg
-        self.not_sharding_stage1 = True
         self._vpp_chunk_num = None
         self._force_align_vpp_grad_sum_order = distutils.util.strtobool(
             os.getenv('FLAGS_force_align_vpp_grad_sum_order', '0')
@@ -85,6 +87,7 @@ class HybridParallelClipGrad:
         for p, g in params_grads:
             if g is None:
                 continue
+
             not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
                 hasattr(p, 'is_firstly_shared')
                 and getattr(p, 'is_firstly_shared', True)
@@ -294,14 +297,28 @@ class HybridParallelClipGrad:
     def _comm_and_clip(
         self, params_grads, global_norm_var_dist, global_norm_var_not_dist
     ):
-        # add all reduce to get global norm of distributed params_and_grads
-        if self._hcg.get_model_parallel_world_size() > 1:
-            sharding_flag = False
-            if (
-                self._hcg.get_sharding_parallel_world_size() > 1
-                and self._hcg.get_data_parallel_world_size() == 1
-            ):
-                sharding_flag = True
-            paddle.distributed.all_reduce(
-                global_norm_var_dist,
-                group=self._hcg.get_check_parallel_group(sharding_flag),
+        # sharding first
+        sharding_flag = (
+            self._hcg.get_sharding_parallel_world_size() > 1
+            and self._hcg.get_data_parallel_world_size() == 1
+        )
+        mp_flag = self._hcg.get_model_parallel_world_size() > 1
+        # add all reduce to get global norm of distributed params_and_grads
+        if sharding_flag and not g_shard_norm_align_dp:
+            # norm of mp distributed variable
+            if mp_flag:
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_sharding_parallel_group(),
+                )
+            # not dist only reduce among sharding group and pp group later
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_sharding_parallel_group(),
+            )
+        # norm of mp distributed variable
+        if mp_flag:
+            # dist should reduce among sharding group、mp group、pp group
+            paddle.distributed.all_reduce(
+                global_norm_var_dist,
+                group=self._hcg.get_check_parallel_group(sharding_flag),
@@ -314,18 +331,6 @@ class HybridParallelClipGrad:
                 group=self._hcg.get_pipe_parallel_group(),
             )

-        # In Sharding mode, param and grad is mapping different rank in optimizer.
-        # ClipGradByGlobalNorm need allreduce to get globol norm
-        # TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
-        if (
-            self._hcg.get_sharding_parallel_world_size() > 1
-            and self.not_sharding_stage1
-        ):
-            paddle.distributed.all_reduce(
-                global_norm_var_not_dist,
-                group=self._hcg.get_sharding_parallel_group(),
-            )
-
         global_norm_var_fp32 = paddle.sqrt(
             global_norm_var_dist + global_norm_var_not_dist
         )
......
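For reference, both the old and the new communication layouts are assembling the same scalar, the global gradient norm used for clipping. In the variable names of the hunks above (my paraphrase of the code, not a formula from the commit):

```latex
\text{global\_norm}
  = \sqrt{\underbrace{\sum_{p \in \text{dist}} \lVert g_p \rVert_2^2}_{\texttt{global\_norm\_var\_dist}}
        + \underbrace{\sum_{p \notin \text{dist}} \lVert g_p \rVert_2^2}_{\texttt{global\_norm\_var\_not\_dist}}}
```

What moves in this PR is only where the partial sums are all-reduced: when sharding is active and FLAGS_shard_norm_align_dp is off, each rank contributes the squared norms of the parameters it owns and the sharding-group all_reduce sums them before the usual mp/pp reductions. The TODO kept in the optimizer file notes that this ordering can still differ numerically from the DP-aligned computation, which is why the old behaviour remains selectable via the flag.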
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import paddle
 from paddle import framework
@@ -29,6 +31,9 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

 __all__ = []

+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+
 # assume only the first stage and last stage need data, and data consumption are ordred;
 # to be replaced by real micro dataset from reader
@@ -261,8 +266,12 @@ class PipelineParallel(MetaParallelBase):
         assert hasattr(self, "optimizer")
         assert hasattr(self.optimizer, "_param2rank")
         _param2rank = self.optimizer._param2rank
-        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
+        # Note: after sharding change to reduce operation, here need to be cleared
+        act = (
+            HOOK_ACTION.ALL_REDUCE
+            if (dp or not g_shard_use_reduce)
+            else HOOK_ACTION.REDUCE
+        )

         for model in models:
             # For virtual pipeline. Will separate parameters in different chunk into
......
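A tiny self-contained sketch of the action-selection rule introduced above; the string constants stand in for the real HOOK_ACTION enum and the helper name is made up. REDUCE is chosen only on the sharding path with FLAGS_shard_use_reduce enabled; plain data-parallel buffers, or the flag left at its default, keep the old all_reduce behaviour.

```python
ALL_REDUCE, REDUCE = "all_reduce", "reduce"  # stand-ins for HOOK_ACTION members


def pick_action(dp: bool, shard_use_reduce: bool) -> str:
    # mirrors: act = ALL_REDUCE if (dp or not g_shard_use_reduce) else REDUCE
    return ALL_REDUCE if (dp or not shard_use_reduce) else REDUCE


assert pick_action(dp=True, shard_use_reduce=True) == ALL_REDUCE    # data-parallel buffers
assert pick_action(dp=False, shard_use_reduce=False) == ALL_REDUCE  # sharding, legacy flag
assert pick_action(dp=False, shard_use_reduce=True) == REDUCE       # sharding, new path
print("selection matches the diff")
```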
@@ -246,21 +246,19 @@ class FusedCommBuffer:
     def _comm_grads(self):
         assert self._all_params_checked_in
-        # Note: after sharding change to reduce operation here also need to be updated
-        # if self._act == HOOK_ACTION.ALL_REDUCE:
-        #     task = paddle.distributed.all_reduce(
-        #         self.grad_storage, group=self._comm_group, sync_op=False
-        #     )
-        # elif self._act == HOOK_ACTION.REDUCE:
-        #     task = paddle.distributed.reduce(
-        #         self.grad_storage,
-        #         dst=self._dst,
-        #         group=self._comm_group,
-        #         sync_op=False,
-        #     )
-        task = paddle.distributed.all_reduce(
-            self.grad_storage, group=self._comm_group, sync_op=False
-        )
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            task = paddle.distributed.all_reduce(
+                self.grad_storage, group=self._comm_group, sync_op=False
+            )
+        elif self._act == HOOK_ACTION.REDUCE:
+            task = paddle.distributed.reduce(
+                self.grad_storage,
+                dst=self._dst,
+                group=self._comm_group,
+                sync_op=False,
+            )
         self._task = task

     @imperative_base.no_grad
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import unittest

 from test_parallel_dygraph_dataparallel import TestMultipleGpus
@@ -21,6 +22,13 @@ class TestHybridParallel(TestMultipleGpus):
     # check sharding logic as well as the accuracy with single mode
     def test_hybrid_parallel_sharding_logic(self):
+        # test shard grad reduce
+        os.environ["FLAGS_shard_use_reduce"] = "1"
+        os.environ["FLAGS_shard_norm_align_dp"] = "0"
+        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
+        # test shard grad allreduce
+        os.environ["FLAGS_shard_use_reduce"] = "0"
+        os.environ["FLAGS_shard_norm_align_dp"] = "1"
         self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')

     def test_hybrid_parallel_sharding_state_dict(self):
......