Unverified commit 75cc7057, authored by zhenhailiu, committed by GitHub

dp and sharding coexist (#56096)

* dp and sharding coexist

* dp
Parent 77da9106
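
For context, a minimal sketch (not part of this commit) of a hybrid-parallel setup in which data parallelism and sharding coexist. The degree values and the 4-GPU collective launch are assumptions for illustration only.

```python
# Hedged sketch: an illustrative dp + sharding hybrid configuration.
# Assumes 4 GPUs launched with `python -m paddle.distributed.launch`;
# the degree values are examples, not taken from this commit.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,        # data-parallel replicas
    "mp_degree": 1,        # tensor (model) parallel degree
    "pp_degree": 1,        # pipeline parallel degree
    "sharding_degree": 2,  # optimizer-state sharding degree
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
# With both dp and sharding degrees > 1, the code paths touched by this commit are exercised.
print(hcg.get_data_parallel_world_size(), hcg.get_sharding_parallel_world_size())
```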
@@ -100,6 +100,15 @@ class DygraphShardingOptimizer:
             elif not hasattr(p, "main_grad"):
                 p.clear_gradient(set_to_zero)
 
+    def filter_parameters(self, parameter_list, hcg):
+        sharding_parallel_rank = hcg.get_sharding_parallel_rank()
+        parameter_list = [
+            param
+            for param in parameter_list
+            if self._param2rank[param.name] == sharding_parallel_rank
+        ]
+        return parameter_list
+
     def _partition_parameters(self):
         """
         Partitions parameters among sharding ranks.
...
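
The new `filter_parameters` helper keeps only the parameters whose owning sharding rank (looked up in `_param2rank`) matches the local sharding rank, so the later dp all-reduce only has to touch locally owned gradients. Below is a framework-free toy illustration of that filtering rule; the parameter names and the rank assignment are made up.

```python
# Standalone illustration of the filtering rule; names and owners are hypothetical.
param2rank = {"w0": 0, "b0": 0, "w1": 1, "b1": 1}   # sharding owner per parameter name
parameter_list = ["w0", "b0", "w1", "b1"]

def filter_for_rank(params, param2rank, sharding_rank):
    # keep only parameters owned by this sharding rank
    return [p for p in params if param2rank[p] == sharding_rank]

assert filter_for_rank(parameter_list, param2rank, 0) == ["w0", "b0"]
assert filter_for_rank(parameter_list, param2rank, 1) == ["w1", "b1"]
```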
@@ -293,19 +293,17 @@ class HybridParallelClipGrad:
             params_grads, global_norm_var_dist, global_norm_var_not_dist
         )
 
-    def _comm_and_clip(
-        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
-    ):
-        # sharding first
-        sharding_flag = (
-            self._hcg.get_sharding_parallel_world_size() > 1
-            and self._hcg.get_data_parallel_world_size() == 1
-        )
+    def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
+        sharding_flag = self._hcg.get_sharding_parallel_world_size() > 1
+        dp_flag = self._hcg.get_data_parallel_world_size() > 1
         mp_flag = self._hcg.get_model_parallel_world_size() > 1
-        # add all reduce to get global norm of distributed params_and_grads
+        pp_flag = self._hcg.get_pipe_parallel_world_size() > 1
+
+        # when not g_shard_norm_align_dp, grads are sharded among the sharding group
         if sharding_flag and not g_shard_norm_align_dp:
             # norm of mp distributed variable
             if mp_flag:
+                # dist norm reduces among the sharding group here; mp and pp groups reduce later
                 paddle.distributed.all_reduce(
                     global_norm_var_dist,
                     group=self._hcg.get_sharding_parallel_group(),
@@ -315,21 +313,40 @@ class HybridParallelClipGrad:
                 global_norm_var_not_dist,
                 group=self._hcg.get_sharding_parallel_group(),
             )
         # norm of mp distributed variable
         if mp_flag:
-            # dist should reduce among sharding group, mp group, pp group
-            paddle.distributed.all_reduce(
-                global_norm_var_dist,
-                group=self._hcg.get_check_parallel_group(sharding_flag),
-            )
+            # the else branch alone would suffice; this branch is kept for numerical-precision backward compatibility
+            if not (dp_flag and sharding_flag):
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_check_parallel_group(sharding_flag),
+                )
+            else:
+                # global_norm_var_dist should all-reduce among the model parallel group and the pp group
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_model_parallel_group(),
+                )
+                if pp_flag:
+                    paddle.distributed.all_reduce(
+                        global_norm_var_dist,
+                        group=self._hcg.get_pipe_parallel_group(),
+                    )
 
         # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
-        if self._hcg.get_pipe_parallel_world_size() > 1:
+        if pp_flag:
             paddle.distributed.all_reduce(
                 global_norm_var_not_dist,
                 group=self._hcg.get_pipe_parallel_group(),
             )
 
+    def _comm_and_clip(
+        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
+    ):
+        self._global_norm(global_norm_var_dist, global_norm_var_not_dist)
+
         global_norm_var_fp32 = paddle.sqrt(
             global_norm_var_dist + global_norm_var_not_dist
         )
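
The refactored `_global_norm` relies on squared gradient norms being additive across shards and groups: all-reducing the squared partial norms and taking a single square root at the end reproduces the norm of the full gradient. A plain-Python toy check of that identity (the gradient values and the two-rank split are made up):

```python
# Toy check of the norm identity behind _global_norm; no Paddle required.
import math

full_grad = [3.0, 4.0, 12.0]            # hypothetical full gradient vector
shards = {0: [3.0, 4.0], 1: [12.0]}     # hypothetical split over 2 sharding ranks

local_sq = {rank: sum(g * g for g in gs) for rank, gs in shards.items()}
global_sq = sum(local_sq.values())      # what an all-reduce(SUM) of squared norms yields

# sqrt(sum of squared shard norms) equals the norm of the full gradient (13.0 here)
assert math.isclose(math.sqrt(global_sq), math.sqrt(sum(g * g for g in full_grad)))
```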
@@ -554,15 +571,21 @@ class HybridParallelOptimizer:
     @no_grad()
     @framework.dygraph_only
     def step(self):
-        parameters_list = obtain_optimizer_parameters_list(self._inner_opt)
+        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
+        dp_parameter_list = parameter_list
         if self._sharding_enable:
             assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(list(parameters_list), self._hcg)
+            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
+            # the later dp sync does not need the global parameter list
+            if not g_shard_norm_align_dp:
+                dp_parameter_list = self._inner_opt.filter_parameters(
+                    parameter_list, self._hcg
+                )
 
         if self._dp_enable:
-            fused_allreduce_gradients(list(parameters_list), self._hcg)
+            fused_allreduce_gradients(dp_parameter_list, self._hcg)
 
-        self._step(parameters_list)
+        self._step(parameter_list)
 
     @no_grad()
     def minimize(
@@ -574,14 +597,20 @@ class HybridParallelOptimizer:
         parameter_list = (
             parameters if parameters else self._inner_opt._parameter_list
         )
+        parameter_list = list(parameter_list)
+        dp_parameter_list = parameter_list
 
         # Here sharding should use global parameter list
         if self._sharding_enable:
             assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(list(parameter_list), self._hcg)
+            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
+            # the later dp sync does not need the global parameter list
+            if not g_shard_norm_align_dp:
+                dp_parameter_list = self._inner_opt.filter_parameters(
+                    parameter_list, self._hcg
+                )
 
         if self._dp_enable:
-            fused_allreduce_gradients(list(parameter_list), self._hcg)
+            fused_allreduce_gradients(dp_parameter_list, self._hcg)
 
         return self._inner_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set
...
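
For reference, a hedged end-to-end sketch of the optimizer path this commit changes, building on the illustrative fleet setup above: with dp and sharding both enabled, `step()` now reduces sharded gradients first and then all-reduces only the locally owned slice across the dp group. The model, optimizer, and tensor shapes below are placeholders, and a multi-GPU collective launch is assumed.

```python
# Hedged usage sketch; assumes the illustrative dp_degree=2, sharding_degree=2
# fleet.init(...) shown earlier has already run on every rank.
import paddle
import paddle.distributed.fleet as fleet

model = paddle.nn.Linear(8, 8)                         # placeholder model
opt = paddle.optimizer.AdamW(parameters=model.parameters())

model = fleet.distributed_model(model)
opt = fleet.distributed_optimizer(opt)                 # hybrid-parallel wrapper

x = paddle.randn([4, 8])
loss = model(x).mean()
loss.backward()
opt.step()     # sharding reduce -> dp all-reduce on filtered params -> inner step
opt.clear_grad()
```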