未验证 提交 284e0d12 编写于 作者: zhenhailiu's avatar zhenhailiu 提交者: GitHub

shard grad reduce (#55495)

上级 5f376f00
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
###### ######
import os
from functools import reduce from functools import reduce
import paddle import paddle
...@@ -23,6 +23,16 @@ from paddle.distributed import fleet ...@@ -23,6 +23,16 @@ from paddle.distributed import fleet
from ...utils.log_util import logger from ...utils.log_util import logger
from ...utils.tensor_fusion_helper import fused_parameters from ...utils.tensor_fusion_helper import fused_parameters
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
if g_shard_norm_align_dp:
assert (
not g_shard_use_reduce
), "g_shard_norm_align_dp is not support if g_shard_use_reduce is true"
def _is_trainable(param): def _is_trainable(param):
return not param.stop_gradient return not param.stop_gradient
...@@ -203,18 +213,22 @@ class DygraphShardingOptimizer: ...@@ -203,18 +213,22 @@ class DygraphShardingOptimizer:
if g_var is not None: if g_var is not None:
g_var.scale_(1.0 / sharding_nrank) g_var.scale_(1.0 / sharding_nrank)
param_rank = self._param2rank[param.name] param_rank = self._param2rank[param.name]
if not g_shard_use_reduce:
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
g_var, g_var,
group=hcg.get_sharding_parallel_group(), group=hcg.get_sharding_parallel_group(),
sync_op=True, sync_op=True,
) )
else:
# TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp. # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
# paddle.distributed.reduce( paddle.distributed.reduce(
# g_var, g_var,
# dst=hcg.get_sharding_parallel_group().ranks[param_rank], dst=hcg.get_sharding_parallel_group().ranks[
# group=hcg.get_sharding_parallel_group(), param_rank
# sync_op=True, ],
# ) group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
def _sharding_sync_parameters(self): def _sharding_sync_parameters(self):
""" """
...@@ -294,8 +308,8 @@ class DygraphShardingOptimizer: ...@@ -294,8 +308,8 @@ class DygraphShardingOptimizer:
if hasattr(param, "main_grad") and param.main_grad is not None: if hasattr(param, "main_grad") and param.main_grad is not None:
grad_var = param.main_grad grad_var = param.main_grad
params_grads.append((param, grad_var)) params_grads.append((param, grad_var))
if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'):
self._inner_opt._grad_clip.not_sharding_stage1 = False if g_shard_norm_align_dp:
params_grads = self._inner_opt._grad_clip(params_grads) params_grads = self._inner_opt._grad_clip(params_grads)
# set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
self._set_inner_opt_attr('_grad_clip', None) self._set_inner_opt_attr('_grad_clip', None)
...@@ -313,6 +327,7 @@ class DygraphShardingOptimizer: ...@@ -313,6 +327,7 @@ class DygraphShardingOptimizer:
startup_program=None, startup_program=None,
params_grads=update_params_grads, params_grads=update_params_grads,
) )
if g_shard_norm_align_dp:
# restore the grad clip # restore the grad clip
self._set_inner_opt_attr('_grad_clip', origin_clip) self._set_inner_opt_attr('_grad_clip', origin_clip)
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import numpy as np import numpy as np
import paddle import paddle
...@@ -37,6 +39,9 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer ...@@ -37,6 +39,9 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer
__all__ = [] __all__ = []
g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
class HybridParallelClipGrad: class HybridParallelClipGrad:
def __init__(self, clip, hcg): def __init__(self, clip, hcg):
...@@ -44,6 +49,42 @@ class HybridParallelClipGrad: ...@@ -44,6 +49,42 @@ class HybridParallelClipGrad:
self._hcg = hcg self._hcg = hcg
self.not_sharding_stage1 = True self.not_sharding_stage1 = True
def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
# sharding first
sharding_flag = (
self._hcg.get_sharding_parallel_world_size() > 1
and self._hcg.get_data_parallel_world_size() == 1
)
mp_flag = self._hcg.get_model_parallel_world_size() > 1
# add all reduce to get global norm of distributed params_and_grads
if sharding_flag and not g_shard_norm_align_dp:
# norm of mp distributed variable
if mp_flag:
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_sharding_parallel_group(),
)
# not dist only reduce among sharding group and pp group later
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group(),
)
# norm of mp distributed variable
if mp_flag:
# dist should reduce among sharding group、mp group、pp group
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)
# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
if self._hcg.get_pipe_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group(),
)
@no_grad() @no_grad()
def _dygraph_clip(self, params_grads): def _dygraph_clip(self, params_grads):
sum_square_dist_fp16 = [] sum_square_dist_fp16 = []
...@@ -157,37 +198,7 @@ class HybridParallelClipGrad: ...@@ -157,37 +198,7 @@ class HybridParallelClipGrad:
+ global_norm_not_dist_fp32 + global_norm_not_dist_fp32
) )
# add all reduce to get global norm of distributed params_and_grads self._global_norm(global_norm_var_dist, global_norm_var_not_dist)
if self._hcg.get_model_parallel_world_size() > 1:
sharding_flag = False
if (
self._hcg.get_sharding_parallel_world_size() > 1
and self._hcg.get_data_parallel_world_size() == 1
):
sharding_flag = True
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)
# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
if self._hcg.get_pipe_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group(),
)
# In Sharding mode, param and grad is mapping different rank in optimizer.
# ClipGradByGlobalNorm need allreduce to get globol norm
# TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
if (
self._hcg.get_sharding_parallel_world_size() > 1
and self.not_sharding_stage1
):
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group(),
)
global_norm_var_fp32 = paddle.sqrt( global_norm_var_fp32 = paddle.sqrt(
global_norm_var_dist + global_norm_var_not_dist global_norm_var_dist + global_norm_var_not_dist
......
...@@ -44,6 +44,9 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer ...@@ -44,6 +44,9 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer
__all__ = [] __all__ = []
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
# assume only the first stage and last stage need data, and data consumption is ordred # assume only the first stage and last stage need data, and data consumption is ordred
# to be replaced by real micro dataset from reader # to be replaced by real micro dataset from reader
...@@ -299,8 +302,12 @@ class PipelineParallel(MetaParallelBase): ...@@ -299,8 +302,12 @@ class PipelineParallel(MetaParallelBase):
assert hasattr(self, "optimizer") assert hasattr(self, "optimizer")
assert hasattr(self.optimizer, "_param2rank") assert hasattr(self.optimizer, "_param2rank")
_param2rank = self.optimizer._param2rank _param2rank = self.optimizer._param2rank
# Note: after sharding change to reduce operation, here need to be cleared
act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE act = (
HOOK_ACTION.ALL_REDUCE
if (dp or not g_shard_use_reduce)
else HOOK_ACTION.REDUCE
)
for model in models: for model in models:
# For virtual pipeline. Will separate parameters in different chunk into # For virtual pipeline. Will separate parameters in different chunk into
......
...@@ -207,21 +207,19 @@ class FusedCommBuffer: ...@@ -207,21 +207,19 @@ class FusedCommBuffer:
def _comm_grads(self): def _comm_grads(self):
assert self._all_params_checked_in assert self._all_params_checked_in
# Note: after sharding change to reduce operation here also need to be updated if self._act == HOOK_ACTION.ALL_REDUCE:
# if self._act == HOOK_ACTION.ALL_REDUCE:
# task = paddle.distributed.all_reduce(
# self.grad_storage, group=self._comm_group, sync_op=False
# )
# elif self._act == HOOK_ACTION.REDUCE:
# task = paddle.distributed.reduce(
# self.grad_storage,
# dst=self._dst,
# group=self._comm_group,
# sync_op=False,
# )
task = paddle.distributed.all_reduce( task = paddle.distributed.all_reduce(
self.grad_storage, group=self._comm_group, sync_op=False self.grad_storage, group=self._comm_group, sync_op=False
) )
elif self._act == HOOK_ACTION.REDUCE:
task = paddle.distributed.reduce(
self.grad_storage,
dst=self._dst,
group=self._comm_group,
sync_op=False,
)
self._task = task self._task = task
@imperative_base.no_grad @imperative_base.no_grad
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import unittest import unittest
from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
...@@ -20,6 +21,13 @@ from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus ...@@ -20,6 +21,13 @@ from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestHybridParallel(TestMultipleGpus): class TestHybridParallel(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode # check sharding logic as well as the accuracy with single mode
def test_hybrid_parallel_sharding_logic(self): def test_hybrid_parallel_sharding_logic(self):
# test shard grad reduce
os.environ["FLAGS_shard_use_reduce"] = "1"
os.environ["FLAGS_shard_norm_align_dp"] = "0"
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
# test shard grad allreduce
os.environ["FLAGS_shard_use_reduce"] = "0"
os.environ["FLAGS_shard_norm_align_dp"] = "1"
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py') self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
def test_hybrid_parallel_sharding_tensor_fusion(self): def test_hybrid_parallel_sharding_tensor_fusion(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册