Unverified commit 64f25adf, authored by zhenhailiu, committed by GitHub

new_frl_shard_reduce (#55353)

* new_frl_shard_redece

* add mp guard

* add test
Parent e1545af4
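Note: this PR gates two behaviours behind environment flags. FLAGS_shard_use_reduce switches the sharding gradient sync from all_reduce to a reduce toward the parameter's owning rank, and FLAGS_shard_norm_align_dp keeps the grad-clip global-norm calculation aligned with plain data parallelism; the new assertion rejects enabling both at once. A minimal sketch (not part of the patch; launcher-side variable names are illustrative) of how the flags are read and constrained:

# Hedged sketch mirroring the flag parsing added in the diff below.
import os

# reduce-based gradient sync (the new path in this PR)
os.environ["FLAGS_shard_use_reduce"] = "1"
os.environ["FLAGS_shard_norm_align_dp"] = "0"

use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
# the patch asserts this combination cannot both be enabled:
assert not (norm_align_dp and use_reduce)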
@@ -24,6 +24,16 @@ from paddle.fluid.dygraph import base as imperative_base
 from ...utils.log_util import logger

+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
+logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+
+if g_shard_norm_align_dp:
+    assert (
+        not g_shard_use_reduce
+    ), "g_shard_norm_align_dp is not support if g_shard_use_reduce is true"
+

 def _is_trainable(param):
     return not param.stop_gradient
@@ -160,18 +170,22 @@ class DygraphShardingOptimizer:
             if g_var is not None:
                 g_var.scale_(1.0 / sharding_nrank)
                 param_rank = self._param2rank[param.name]
-                paddle.distributed.all_reduce(
-                    g_var,
-                    group=hcg.get_sharding_parallel_group(),
-                    sync_op=True,
-                )
-                # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
-                # paddle.distributed.reduce(
-                #     g_var,
-                #     dst=hcg.get_sharding_parallel_group().ranks[param_rank],
-                #     group=hcg.get_sharding_parallel_group(),
-                #     sync_op=True,
-                # )
+                if not g_shard_use_reduce:
+                    paddle.distributed.all_reduce(
+                        g_var,
+                        group=hcg.get_sharding_parallel_group(),
+                        sync_op=True,
+                    )
+                else:
+                    # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
+                    paddle.distributed.reduce(
+                        g_var,
+                        dst=hcg.get_sharding_parallel_group().ranks[
+                            param_rank
+                        ],
+                        group=hcg.get_sharding_parallel_group(),
+                        sync_op=True,
+                    )

     def _sharding_sync_parameters(self):
         """
@@ -247,11 +261,10 @@ class DygraphShardingOptimizer:
             if hasattr(param, "main_grad") and param.main_grad is not None:
                 grad_var = param.main_grad
             params_grads.append((param, grad_var))
-        if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'):
-            self._inner_opt._grad_clip.not_sharding_stage1 = False
-        params_grads = self._inner_opt._grad_clip(params_grads)
-        # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
-        self._set_inner_opt_attr('_grad_clip', None)
+        if g_shard_norm_align_dp:
+            params_grads = self._inner_opt._grad_clip(params_grads)
+            # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
+            self._set_inner_opt_attr('_grad_clip', None)
         update_param_names = [
             p.name for p in self._rank2params[self._sharding_rank]
         ]
@@ -263,8 +276,9 @@
             startup_program=None,
             params_grads=update_params_grads,
         )
-        # restore the grad clip
-        self._set_inner_opt_attr('_grad_clip', origin_clip)
+        if g_shard_norm_align_dp:
+            # restore the grad clip
+            self._set_inner_opt_attr('_grad_clip', origin_clip)

         # sync parameters across sharding ranks
         self._sharding_sync_parameters()
......
@@ -41,12 +41,14 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer
 __all__ = []

+g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
+logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+

 class HybridParallelClipGrad:
     def __init__(self, clip, hcg):
         self._clip = clip
         self._hcg = hcg
-        self.not_sharding_stage1 = True
         self._vpp_chunk_num = None
         self._force_align_vpp_grad_sum_order = distutils.util.strtobool(
             os.getenv('FLAGS_force_align_vpp_grad_sum_order', '0')
@@ -85,6 +87,7 @@ class HybridParallelClipGrad:
         for p, g in params_grads:
             if g is None:
                 continue
+
             not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
                 hasattr(p, 'is_firstly_shared')
                 and getattr(p, 'is_firstly_shared', True)
@@ -294,14 +297,28 @@ class HybridParallelClipGrad:
     def _comm_and_clip(
         self, params_grads, global_norm_var_dist, global_norm_var_not_dist
     ):
+        # sharding first
+        sharding_flag = (
+            self._hcg.get_sharding_parallel_world_size() > 1
+            and self._hcg.get_data_parallel_world_size() == 1
+        )
+        mp_flag = self._hcg.get_model_parallel_world_size() > 1
         # add all reduce to get global norm of distributed params_and_grads
-        if self._hcg.get_model_parallel_world_size() > 1:
-            sharding_flag = False
-            if (
-                self._hcg.get_sharding_parallel_world_size() > 1
-                and self._hcg.get_data_parallel_world_size() == 1
-            ):
-                sharding_flag = True
+        if sharding_flag and not g_shard_norm_align_dp:
+            # norm of mp distributed variable
+            if mp_flag:
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_sharding_parallel_group(),
+                )
+            # not dist only reduce among sharding group and pp group later
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_sharding_parallel_group(),
+            )
+
+        # norm of mp distributed variable
+        if mp_flag:
+            # dist should reduce among sharding group、mp group、pp group
             paddle.distributed.all_reduce(
                 global_norm_var_dist,
                 group=self._hcg.get_check_parallel_group(sharding_flag),
@@ -314,18 +331,6 @@
                 group=self._hcg.get_pipe_parallel_group(),
             )

-        # In Sharding mode, param and grad is mapping different rank in optimizer.
-        # ClipGradByGlobalNorm need allreduce to get globol norm
-        # TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
-        if (
-            self._hcg.get_sharding_parallel_world_size() > 1
-            and self.not_sharding_stage1
-        ):
-            paddle.distributed.all_reduce(
-                global_norm_var_not_dist,
-                group=self._hcg.get_sharding_parallel_group(),
-            )
-
         global_norm_var_fp32 = paddle.sqrt(
             global_norm_var_dist + global_norm_var_not_dist
         )
......
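Note: in _comm_and_clip, when the reduce-based path is active (g_shard_norm_align_dp off and sharding without dp), each sharding rank only holds the gradients it owns, so the squared-norm partials are first summed over the sharding group before the usual mp/pp reductions; the block removed at the end was the old, not_sharding_stage1-guarded version of that sharding-group all-reduce. A toy numeric check (made-up values) of the final combination step:

# Each sharding rank contributes the squared norm of the gradients it owns;
# summing the partials and taking sqrt reproduces the full global norm.
import math

partial_sq_norms = [4.0, 9.0, 16.0]              # per-rank sum of g**2
global_norm_var_dist = sum(partial_sq_norms)     # after the sharding-group all_reduce
global_norm_var_not_dist = 0.0                   # no non-distributed grads in this toy
print(math.sqrt(global_norm_var_dist + global_norm_var_not_dist))  # sqrt(29) ≈ 5.385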
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and

+import os
+
 import paddle
 from paddle import framework
@@ -29,6 +31,9 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
 __all__ = []

+g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+
 # assume only the first stage and last stage need data, and data consumption are ordred;
 # to be replaced by real micro dataset from reader
@@ -261,8 +266,12 @@ class PipelineParallel(MetaParallelBase):
         assert hasattr(self, "optimizer")
         assert hasattr(self.optimizer, "_param2rank")
         _param2rank = self.optimizer._param2rank

-        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
+        # Note: after sharding change to reduce operation, here need to be cleared
+        act = (
+            HOOK_ACTION.ALL_REDUCE
+            if (dp or not g_shard_use_reduce)
+            else HOOK_ACTION.REDUCE
+        )

         for model in models:
             # For virtual pipeline. Will separate parameters in different chunk into
......
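Note: the fused gradient buffers in pipeline parallel now pick their hook action from both the dp flag and g_shard_use_reduce. A small sketch of that decision, using stand-in constants for the HOOK_ACTION members:

ALL_REDUCE, REDUCE = "all_reduce", "reduce"   # stand-ins for HOOK_ACTION members

def pick_act(dp: bool, g_shard_use_reduce: int) -> str:
    # all_reduce for data-parallel buffers, or when the reduce path is disabled
    return ALL_REDUCE if (dp or not g_shard_use_reduce) else REDUCE

assert pick_act(dp=True, g_shard_use_reduce=1) == ALL_REDUCE
assert pick_act(dp=False, g_shard_use_reduce=0) == ALL_REDUCE
assert pick_act(dp=False, g_shard_use_reduce=1) == REDUCE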
@@ -246,21 +246,19 @@ class FusedCommBuffer:
     def _comm_grads(self):
         assert self._all_params_checked_in

-        # Note: after sharding change to reduce operation here also need to be updated
-        # if self._act == HOOK_ACTION.ALL_REDUCE:
-        #     task = paddle.distributed.all_reduce(
-        #         self.grad_storage, group=self._comm_group, sync_op=False
-        #     )
-        # elif self._act == HOOK_ACTION.REDUCE:
-        #     task = paddle.distributed.reduce(
-        #         self.grad_storage,
-        #         dst=self._dst,
-        #         group=self._comm_group,
-        #         sync_op=False,
-        #     )
-        task = paddle.distributed.all_reduce(
-            self.grad_storage, group=self._comm_group, sync_op=False
-        )
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            task = paddle.distributed.all_reduce(
+                self.grad_storage, group=self._comm_group, sync_op=False
+            )
+
+        elif self._act == HOOK_ACTION.REDUCE:
+            task = paddle.distributed.reduce(
+                self.grad_storage,
+                dst=self._dst,
+                group=self._comm_group,
+                sync_op=False,
+            )
+
         self._task = task

     @imperative_base.no_grad
......
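Note: FusedCommBuffer._comm_grads now dispatches on the hook action, and for HOOK_ACTION.REDUCE the destination is the rank that owns the buffer's parameters, as recorded in the optimizer's _param2rank mapping. A toy illustration (hypothetical names and ranks) of how an owner index resolves to a global destination rank, mirroring hcg.get_sharding_parallel_group().ranks[param_rank] in the sharding optimizer hunk above:

param2rank = {"linear_0.w_0": 0, "linear_0.b_0": 1}   # hypothetical owner index per param
sharding_group_ranks = [4, 5]                          # hypothetical global ranks of the group

def dst_rank(param_name: str) -> int:
    # map the owner index within the sharding group to its global rank
    return sharding_group_ranks[param2rank[param_name]]

print(dst_rank("linear_0.b_0"))   # 5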
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import unittest

 from test_parallel_dygraph_dataparallel import TestMultipleGpus
@@ -21,6 +22,13 @@ class TestHybridParallel(TestMultipleGpus):
     # check sharding logic as well as the accuracy with single mode
     def test_hybrid_parallel_sharding_logic(self):
+        # test shard grad reduce
+        os.environ["FLAGS_shard_use_reduce"] = "1"
+        os.environ["FLAGS_shard_norm_align_dp"] = "0"
+        self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
+        # test shard grad allreduce
+        os.environ["FLAGS_shard_use_reduce"] = "0"
+        os.environ["FLAGS_shard_norm_align_dp"] = "1"
         self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')

     def test_hybrid_parallel_sharding_state_dict(self):
......