Unverified · Commit 75cc7057 authored by zhenhailiu, committed by GitHub

dp and sharding coexist (#56096)

* dp and sharding coexist

* dp
Parent 77da9106
@@ -100,6 +100,15 @@ class DygraphShardingOptimizer:
            elif not hasattr(p, "main_grad"):
                p.clear_gradient(set_to_zero)

    def filter_parameters(self, parameter_list, hcg):
        sharding_parallel_rank = hcg.get_sharding_parallel_rank()
        parameter_list = [
            param
            for param in parameter_list
            if self._param2rank[param.name] == sharding_parallel_rank
        ]
        return parameter_list

    def _partition_parameters(self):
        """
        Partitions parameters among sharding ranks.
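The new filter_parameters helper keeps only the parameters whose optimizer state is owned by the current sharding rank, so the later data-parallel synchronization touches each gradient exactly once. Below is a minimal standalone sketch of the same idea; the greedy size-balanced partition and the parameter names are illustrative assumptions, not Paddle's exact assignment policy.

# Minimal sketch: assign parameters to sharding ranks, then filter the local
# subset the way DygraphShardingOptimizer.filter_parameters does.
# The greedy partitioning below is illustrative, not Paddle's exact policy.

def partition_parameters(param_sizes, sharding_world_size):
    """Greedily place each parameter on the currently lightest rank."""
    rank_load = [0] * sharding_world_size
    param2rank = {}
    for name, size in sorted(param_sizes.items(), key=lambda kv: -kv[1]):
        rank = rank_load.index(min(rank_load))
        param2rank[name] = rank
        rank_load[rank] += size
    return param2rank

def filter_parameters(parameter_names, param2rank, sharding_parallel_rank):
    """Keep only the parameters owned by this sharding rank."""
    return [p for p in parameter_names if param2rank[p] == sharding_parallel_rank]

if __name__ == "__main__":
    sizes = {"embedding": 1000, "linear_0.w": 400, "linear_0.b": 4, "linear_1.w": 400}
    p2r = partition_parameters(sizes, sharding_world_size=2)
    for rank in range(2):
        print(rank, filter_parameters(sizes, p2r, rank))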
@@ -293,19 +293,17 @@ class HybridParallelClipGrad:
            params_grads, global_norm_var_dist, global_norm_var_not_dist
        )

    def _comm_and_clip(
        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
    ):
        # sharding first
        sharding_flag = (
            self._hcg.get_sharding_parallel_world_size() > 1
            and self._hcg.get_data_parallel_world_size() == 1
        )

    def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
        sharding_flag = self._hcg.get_sharding_parallel_world_size() > 1
        dp_flag = self._hcg.get_data_parallel_world_size() > 1
        mp_flag = self._hcg.get_model_parallel_world_size() > 1
        # add all reduce to get global norm of distributed params_and_grads
        pp_flag = self._hcg.get_pipe_parallel_world_size() > 1

        # not g_shard_norm_align_dp, grads are sharded among sharding group
        if sharding_flag and not g_shard_norm_align_dp:
            # norm of mp distributed variable
            if mp_flag:
                # dist should reduce among the sharding group here, and among the mp and pp groups later
                paddle.distributed.all_reduce(
                    global_norm_var_dist,
                    group=self._hcg.get_sharding_parallel_group(),
@@ -315,21 +313,40 @@ class HybridParallelClipGrad:
                global_norm_var_not_dist,
                group=self._hcg.get_sharding_parallel_group(),
            )

        # norm of mp distributed variable
        if mp_flag:
            # dist should reduce among sharding group, mp group and pp group
            # the else branch alone would suffice, but this branch is kept for numerical-precision backward compatibility
            if not (dp_flag and sharding_flag):
                paddle.distributed.all_reduce(
                    global_norm_var_dist,
                    group=self._hcg.get_check_parallel_group(sharding_flag),
                )
            else:
                # global_norm_var_dist should be all-reduced across the model parallel group and the pp group
                paddle.distributed.all_reduce(
                    global_norm_var_dist,
                    group=self._hcg.get_model_parallel_group(),
                )
                if pp_flag:
                    paddle.distributed.all_reduce(
                        global_norm_var_dist,
                        group=self._hcg.get_pipe_parallel_group(),
                    )

        # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
        if self._hcg.get_pipe_parallel_world_size() > 1:
        if pp_flag:
            paddle.distributed.all_reduce(
                global_norm_var_not_dist,
                group=self._hcg.get_pipe_parallel_group(),
            )

    def _comm_and_clip(
        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
    ):
        self._global_norm(global_norm_var_dist, global_norm_var_not_dist)

        global_norm_var_fp32 = paddle.sqrt(
            global_norm_var_dist + global_norm_var_not_dist
        )
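The refactored _global_norm only changes where the squared norms get summed; the final norm is still a single sqrt over the accumulated squares. A toy sketch of why that is valid, with numpy slices standing in for sharded gradients and a plain sum standing in for paddle.distributed.all_reduce (illustration only, not the actual distributed code):

import numpy as np

# Toy sketch: summing squared norms across ranks and taking one sqrt at the
# end recovers the global gradient norm. Each "rank" here is just a slice of
# the gradients; sum() stands in for an all_reduce with SUM.
np.random.seed(0)
full_grad = np.random.randn(8, 16)

# Shard the gradient rows across 4 hypothetical sharding ranks.
shards = np.split(full_grad, 4, axis=0)

# Each rank computes only its local squared norm ...
local_sq_norms = [float((s ** 2).sum()) for s in shards]

# ... the all-reduce sums them, and a single sqrt yields the global norm.
global_norm = np.sqrt(sum(local_sq_norms))

assert np.isclose(global_norm, np.linalg.norm(full_grad))
print("global grad norm:", global_norm)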
@@ -554,15 +571,21 @@ class HybridParallelOptimizer:
    @no_grad()
    @framework.dygraph_only
    def step(self):
        parameters_list = obtain_optimizer_parameters_list(self._inner_opt)
        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
        dp_parameter_list = parameter_list
        if self._sharding_enable:
            assert isinstance(self._inner_opt, DygraphShardingOptimizer)
            self._inner_opt.reduce_gradients(list(parameters_list), self._hcg)
            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
            # the later dp sync does not need to use the global parameter list
            if not g_shard_norm_align_dp:
                dp_parameter_list = self._inner_opt.filter_parameters(
                    parameter_list, self._hcg
                )

        if self._dp_enable:
            fused_allreduce_gradients(list(parameters_list), self._hcg)
            fused_allreduce_gradients(dp_parameter_list, self._hcg)

        self._step(parameters_list)
        self._step(parameter_list)

    @no_grad()
    def minimize(
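With sharding and dp coexisting (and g_shard_norm_align_dp disabled), step() now reduces gradients inside the sharding group over the full parameter list first, then restricts the dp allreduce to the locally owned shard. A sketch of that ordering with stub classes; the groups, ownership map, and print statements are made up purely for illustration:

# Sketch of the ordering step() follows when sharding and dp coexist.
# The "optimizer" and the collectives are stubs that only record which
# gradients each step would touch.

class StubShardingOptimizer:
    def __init__(self, param2rank, sharding_rank):
        self._param2rank = param2rank
        self._sharding_rank = sharding_rank

    def reduce_gradients(self, parameter_list):
        # sharding: reduce every gradient within the sharding group
        print("sharding reduce over:", parameter_list)

    def filter_parameters(self, parameter_list):
        return [p for p in parameter_list
                if self._param2rank[p] == self._sharding_rank]

def fused_allreduce_gradients(parameter_list):
    # dp: allreduce only the gradients this sharding rank actually owns
    print("dp allreduce over:   ", parameter_list)

params = ["w0", "w1", "w2", "w3"]
opt = StubShardingOptimizer({"w0": 0, "w1": 1, "w2": 0, "w3": 1}, sharding_rank=0)

opt.reduce_gradients(params)               # sharding first, over the global list
dp_params = opt.filter_parameters(params)  # then shrink to the locally owned shard
fused_allreduce_gradients(dp_params)       # dp sync only the owned gradients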
@@ -574,14 +597,20 @@ class HybridParallelOptimizer:
        parameter_list = (
            parameters if parameters else self._inner_opt._parameter_list
        )
        parameter_list = list(parameter_list)
        dp_parameter_list = parameter_list
        # Here sharding should use the global parameter list
        if self._sharding_enable:
            assert isinstance(self._inner_opt, DygraphShardingOptimizer)
            self._inner_opt.reduce_gradients(list(parameter_list), self._hcg)
            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
            # the later dp sync does not need to use the global parameter list
            if not g_shard_norm_align_dp:
                dp_parameter_list = self._inner_opt.filter_parameters(
                    parameter_list, self._hcg
                )

        if self._dp_enable:
            fused_allreduce_gradients(list(parameter_list), self._hcg)
            fused_allreduce_gradients(dp_parameter_list, self._hcg)

        return self._inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set
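For context, the dp and sharding dimensions that this change lets coexist are typically configured through fleet's hybrid_configs. A minimal launch sketch under that assumption follows; the degrees are illustrative, must multiply to the number of launched processes, and the script is meant to run under paddle.distributed.launch:

# Sketch: enabling data parallelism and sharding together in a hybrid strategy.
# Run with: python -m paddle.distributed.launch --nproc_per_node=4 train.py
import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,        # data-parallel replicas
    "mp_degree": 1,        # tensor/model parallel
    "pp_degree": 1,        # pipeline parallel
    "sharding_degree": 2,  # optimizer-state sharding within each dp replica
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(parameters=model.parameters())

# fleet wraps the model/optimizer with their hybrid parallel variants,
# including the sharding-aware optimizer when sharding_degree > 1.
model = fleet.distributed_model(model)
opt = fleet.distributed_optimizer(opt)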