diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py index e7bd434b94fd32c19daa99defefe979058e99355..7e95bfe7f331aadd3b5aeb6adc99a1a5706d7c6b 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py +++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py @@ -18,7 +18,11 @@ import numpy as np from paddle import framework import paddle from paddle.fluid import core -from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups +from paddle.fluid.dygraph.parallel import ( + _split_tensors, + sync_params_buffers, + build_groups, +) from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph from collections import OrderedDict from .log_util import logger @@ -26,7 +30,7 @@ from .log_util import logger __all__ = [] -def _apply_collective_grads(parameters, comm_group): +def _apply_collective_grads(parameters, comm_group, bucket_size, scale=None): grad_var_set = set() grad_vars = [] sparse_grad_vars = [] @@ -34,52 +38,70 @@ def _apply_collective_grads(parameters, comm_group): for param in parameters: if param.trainable and (param._grad_ivar() is not None): g_var = param._grad_ivar() - assert not g_var._is_sparse( + assert ( + not g_var._is_sparse() ), "Now, it doesn't support sparse parameters" grad_vars.append(g_var) assert g_var not in grad_var_set grad_var_set.add(g_var) - coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024) + coalesced_grads_and_vars = build_groups(grad_vars, bucket_size) + + nranks = ( + paddle.distributed.get_world_size() + if comm_group is None + else comm_group.nranks + ) + + scale = nranks if scale is None else 1.0 / scale + scale = None if scale == 1.0 else scale - nranks = paddle.distributed.get_world_size( - ) if comm_group is None else comm_group.nranks for coalesced_grad, _, _ in coalesced_grads_and_vars: # need to div nranks - div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype) - paddle.fluid.framework._dygraph_tracer().trace_op( - type="elementwise_div", - inputs={ - 'X': coalesced_grad, - 'Y': div_factor - }, - outputs={'Out': coalesced_grad}, - attrs={'axis': -1}) + if scale is not None: + div_factor = paddle.to_tensor(scale, dtype=coalesced_grad.dtype) + paddle.fluid.framework._dygraph_tracer().trace_op( + type="elementwise_div", + inputs={'X': coalesced_grad, 'Y': div_factor}, + outputs={'Out': coalesced_grad}, + attrs={'axis': -1}, + ) paddle.distributed.all_reduce(coalesced_grad, group=comm_group) _split_tensors(coalesced_grads_and_vars) -def _apply_collective_grads_eager(parameters, comm_group): +def _apply_collective_grads_eager( + parameters, comm_group, bucket_size, scale=None +): grad_var_set = set() grad_vars = [] for param in parameters: if param.trainable and (param._grad_ivar() is not None): g_var = param._grad_ivar() - assert not g_var.is_sparse( + assert ( + not g_var.is_sparse() ), "Now, it doesn't support sparse parameters" grad_vars.append(g_var) assert g_var not in grad_var_set grad_var_set.add(g_var) - coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024) + coalesced_grads_and_vars = build_groups(grad_vars, bucket_size) + + nranks = ( + paddle.distributed.get_world_size() + if comm_group is None + else comm_group.nranks + ) + + scale = 1.0 / nranks if scale is None else scale + scale = None if scale == 1.0 else scale - nranks = paddle.distributed.get_world_size( - ) if comm_group is None else comm_group.nranks for coalesced_grad, _, _ in coalesced_grads_and_vars: # need to div nranks - coalesced_grad.scale_(1.0 / nranks) + if scale is not None: + coalesced_grad.scale_(scale) paddle.distributed.all_reduce(coalesced_grad, group=comm_group) _split_tensors(coalesced_grads_and_vars) @@ -91,20 +113,18 @@ def _broadcast_data_help(data, shape, dtype, hcg): mp_rank = hcg.get_model_parallel_rank() shape_gpu = paddle.to_tensor(shape, dtype="int32") - paddle.distributed.broadcast(shape_gpu, - src=src_rank, - group=model_parallel_group, - sync_op=True) + paddle.distributed.broadcast( + shape_gpu, src=src_rank, group=model_parallel_group, sync_op=True + ) if mp_rank != 0: input_data = paddle.zeros(shape_gpu, dtype=dtype) else: input_data = data - paddle.distributed.broadcast(input_data, - src=src_rank, - group=model_parallel_group, - sync_op=True) + paddle.distributed.broadcast( + input_data, src=src_rank, group=model_parallel_group, sync_op=True + ) if mp_rank != 0: if in_dygraph_mode(): @@ -113,7 +133,8 @@ def _broadcast_data_help(data, shape, dtype, hcg): else: data.value().get_tensor()._clear() data.value().get_tensor()._share_data_with( - input_data.value().get_tensor()) + input_data.value().get_tensor() + ) def broadcast_input_data(hcg, *inputs, **kwargs): @@ -121,8 +142,11 @@ def broadcast_input_data(hcg, *inputs, **kwargs): for v in inputs: if isinstance(v, (core.VarBase, core.eager.Tensor)): with framework.no_grad(): - if "gpu" in cur_device and in_dygraph_mode() \ - and not v.place.is_gpu_place(): + if ( + "gpu" in cur_device + and in_dygraph_mode() + and not v.place.is_gpu_place() + ): v_gpu = v.cuda(int(cur_device.split(":")[1])) v._clear_data() v_gpu._share_buffer_to(v) @@ -133,8 +157,11 @@ def broadcast_input_data(hcg, *inputs, **kwargs): for k, v in kwargs.items(): if isinstance(v, (core.VarBase, core.eager.Tensor)): with framework.no_grad(): - if "gpu" in cur_device and in_dygraph_mode() \ - and not v.place.is_gpu_place(): + if ( + "gpu" in cur_device + and in_dygraph_mode() + and not v.place.is_gpu_place() + ): v_gpu = v.cuda(int(cur_device.split(":")[1])) v._clear_data() v_gpu._share_buffer_to(v) @@ -148,28 +175,35 @@ def broadcast_input_data(hcg, *inputs, **kwargs): def broadcast_mp_parameters(model, hcg): model_parallel_group = hcg.get_model_parallel_group() src_rank = hcg.get_model_parallel_group_src_rank() - sync_params_buffers(model, - model_parallel_group, - src_rank, - is_model_parallel=True) + sync_params_buffers( + model, model_parallel_group, src_rank, is_model_parallel=True + ) def broadcast_dp_parameters(model, hcg): data_parallel_group = hcg.get_data_parallel_group() src_rank = hcg.get_data_parallel_group_src_rank() - sync_params_buffers(model, - data_parallel_group, - src_rank, - is_model_parallel=False) + sync_params_buffers( + model, data_parallel_group, src_rank, is_model_parallel=False + ) + + +def fused_allreduce_gradients_with_group( + parameter_list, group, bucket_size=128 * 1024 * 1024, scale=None +): + apply_func = ( + _apply_collective_grads_eager + if in_dygraph_mode() + else _apply_collective_grads + ) + with framework.no_grad(): + apply_func(parameter_list, group, bucket_size) def fused_allreduce_gradients(parameter_list, hcg): data_parallel_group = None if hcg is None else hcg.get_data_parallel_group() logger.debug("dp start fuse allreduce gradients") - apply_func = _apply_collective_grads_eager if in_dygraph_mode( - ) else _apply_collective_grads - with framework.no_grad(): - apply_func(parameter_list, data_parallel_group) + fused_allreduce_gradients_with_group(parameter_list, data_parallel_group) def sharding_reduce_gradients(parameter_list, hcg): @@ -186,7 +220,8 @@ def sharding_reduce_gradients(parameter_list, hcg): paddle.distributed.all_reduce( param.grad, group=hcg.get_sharding_parallel_group(), - sync_op=True) + sync_op=True, + ) elif _in_legacy_dygraph(): g_var = param._grad_ivar() @@ -199,20 +234,20 @@ def sharding_reduce_gradients(parameter_list, hcg): outputs={'Out': g_var}, attrs={ 'ring_id': hcg.get_sharding_parallel_group().id, - 'use_calc_stream': True - }) + 'use_calc_stream': True, + }, + ) # grad / sharding_rank - div_factor = paddle.to_tensor(sharding_nrank, - dtype=g_var.dtype) + div_factor = paddle.to_tensor( + sharding_nrank, dtype=g_var.dtype + ) paddle.fluid.framework._dygraph_tracer().trace_op( type="elementwise_div", - inputs={ - 'X': g_var, - 'Y': div_factor - }, + inputs={'X': g_var, 'Y': div_factor}, outputs={'Out': g_var}, - attrs={'axis': -1}) + attrs={'axis': -1}, + ) def broadcast_sharding_parameters(model, hcg): @@ -220,7 +255,6 @@ def broadcast_sharding_parameters(model, hcg): logger.debug("sharding start init parameters sync") sharding_parallel_group = hcg.get_sharding_parallel_group() src_rank = hcg.get_sharding_parallel_group_src_rank() - sync_params_buffers(model, - sharding_parallel_group, - src_rank, - is_model_parallel=False) + sync_params_buffers( + model, sharding_parallel_group, src_rank, is_model_parallel=False + )