From c036c5c0b9f28bcd7a48592b9f5dc78046837924 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Fri, 28 Oct 2022 22:45:35 +0800
Subject: [PATCH] Add fused_allreduce_gradients_with_group for PPFleetX
 (#47447)

* add fused_allreduce_gradients_with_group

* add scale

* fix ci
---
 .../fleet/utils/hybrid_parallel_util.py       | 50 +++++++++++++------
 1 file changed, 34 insertions(+), 16 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index fec3e455f8..c88a967035 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -26,7 +26,7 @@ from .log_util import logger
 __all__ = []
 
 
-def _apply_collective_grads(parameters, comm_group):
+def _apply_collective_grads(parameters, comm_group, bucket_size, scale=None):
     grad_var_set = set()
     grad_vars = []
     sparse_grad_vars = []
@@ -41,28 +41,35 @@ def _apply_collective_grads(parameters, comm_group):
             assert g_var not in grad_var_set
             grad_var_set.add(g_var)
 
-    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)
+    coalesced_grads_and_vars = build_groups(grad_vars, bucket_size)
 
     nranks = (
         paddle.distributed.get_world_size()
         if comm_group is None
         else comm_group.nranks
     )
+
+    scale = nranks if scale is None else 1.0 / scale
+    scale = None if scale == 1.0 else scale
+
     for coalesced_grad, _, _ in coalesced_grads_and_vars:
         # need to div nranks
-        div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
-        paddle.fluid.framework._dygraph_tracer().trace_op(
-            type="elementwise_div",
-            inputs={'X': coalesced_grad, 'Y': div_factor},
-            outputs={'Out': coalesced_grad},
-            attrs={'axis': -1},
-        )
+        if scale is not None:
+            div_factor = paddle.to_tensor(scale, dtype=coalesced_grad.dtype)
+            paddle.fluid.framework._dygraph_tracer().trace_op(
+                type="elementwise_div",
+                inputs={'X': coalesced_grad, 'Y': div_factor},
+                outputs={'Out': coalesced_grad},
+                attrs={'axis': -1},
+            )
 
         paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
 
     _split_tensors(coalesced_grads_and_vars)
 
 
-def _apply_collective_grads_eager(parameters, comm_group):
+def _apply_collective_grads_eager(
+    parameters, comm_group, bucket_size, scale=None
+):
     grad_var_set = set()
     grad_vars = []
 
@@ -76,16 +83,21 @@ def _apply_collective_grads_eager(parameters, comm_group):
             assert g_var not in grad_var_set
             grad_var_set.add(g_var)
 
-    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)
+    coalesced_grads_and_vars = build_groups(grad_vars, bucket_size)
 
     nranks = (
         paddle.distributed.get_world_size()
         if comm_group is None
         else comm_group.nranks
     )
+
+    scale = 1.0 / nranks if scale is None else scale
+    scale = None if scale == 1.0 else scale
+
     for coalesced_grad, _, _ in coalesced_grads_and_vars:
         # need to div nranks
-        coalesced_grad.scale_(1.0 / nranks)
+        if scale is not None:
+            coalesced_grad.scale_(scale)
         paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
 
     _split_tensors(coalesced_grads_and_vars)
@@ -172,16 +184,22 @@ def broadcast_dp_parameters(model, hcg):
         )
 
 
-def fused_allreduce_gradients(parameter_list, hcg):
-    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
-    logger.debug("dp start fuse allreduce gradients")
+def fused_allreduce_gradients_with_group(
+    parameter_list, group, bucket_size=128 * 1024 * 1024, scale=None
+):
     apply_func = (
         _apply_collective_grads_eager
         if in_dygraph_mode()
         else _apply_collective_grads
     )
     with framework.no_grad():
-        apply_func(parameter_list, data_parallel_group)
+        apply_func(parameter_list, group, bucket_size, scale)
+
+
+def fused_allreduce_gradients(parameter_list, hcg):
+    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
+    logger.debug("dp start fuse allreduce gradients")
+    fused_allreduce_gradients_with_group(parameter_list, data_parallel_group)
 
 
 def sharding_reduce_gradients(parameter_list, hcg):
-- 
GitLab
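
For context, a minimal usage sketch of the new helper added by this patch. It is not taken from the patch itself: the model, the training snippet, and the process group built with paddle.distributed.new_group are illustrative assumptions.

    # Hypothetical usage of fused_allreduce_gradients_with_group (dygraph mode).
    import paddle
    import paddle.distributed as dist
    from paddle.distributed.fleet.utils.hybrid_parallel_util import (
        fused_allreduce_gradients_with_group,
    )

    dist.init_parallel_env()
    # Illustrative group covering all ranks; any communication group works.
    group = dist.new_group(ranks=list(range(dist.get_world_size())))

    model = paddle.nn.Linear(8, 8)                    # stand-in model
    loss = model(paddle.randn([4, 8])).mean()
    loss.backward()

    # Coalesce gradients into 128 MB buckets, all-reduce them over `group`,
    # and scale by 1/nranks (the default when `scale` is None).
    fused_allreduce_gradients_with_group(
        list(model.parameters()),
        group,
        bucket_size=128 * 1024 * 1024,
        scale=None,
    )

With scale=None the coalesced gradients are divided by the number of ranks in the group before the all-reduce, matching the existing fused_allreduce_gradients behavior; passing scale=1.0 skips the scaling step entirely, since both helpers reset a scale of 1.0 to None.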