diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 74f377fd875de30df1b4438d60f6a0f79c832f43..ce489352d3bcfbc7d191991bcf7f85d4be6c33ac 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -68,6 +68,8 @@ message PpConfig {
 
 message DygraphShardingConfig {
   optional bool tensor_fusion = 1 [ default = false ];
+  optional int32 accumulate_steps = 2 [ default = 1 ];
+  optional bool comm_overlap = 3 [ default = false ];
 }
 
 message HybridConfig {
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index 63d261e2e3dfe1667039551a042b2e0ea0e10d92..ccb5bfdcd030ad0a2b92a4cdd27c612fd0ef62d7 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -78,12 +78,23 @@ class DygraphShardingOptimizer:
         self.tensor_fusion = strategy.hybrid_configs[
             'sharding_configs'
         ].tensor_fusion
+        self.accumulate_steps = strategy.hybrid_configs[
+            'sharding_configs'
+        ].accumulate_steps
+        self.comm_overlap = strategy.hybrid_configs[
+            'sharding_configs'
+        ].comm_overlap
 
         pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
-        if self.tensor_fusion:
+        if self.tensor_fusion or self.comm_overlap:
             assert (
                 not pp_overlap
             ), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time."
 
+        self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
+        self._rank2decay = {}
+        self._rank2fused = {}
+        self._comm_buffers = []
+
         self._rank2params = self._partition_parameters()
         self._param2rank = self._map_param_to_rank()
@@ -95,25 +106,22 @@ class DygraphShardingOptimizer:
                 '_param_groups', self._rank2params[self._sharding_rank]
             )
         else:
-            self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
-            self._rank2decay = {}
-            self._rank2fused = {}
             self._tensor_fusion()
 
             decay_params = [
                 p.name for p in self._rank2decay[self._sharding_rank]
             ]
-            all_params = self._rank2fused[self._sharding_rank]
+            fused_params = self._rank2fused[self._sharding_rank]
             apply_decay_param_fun = lambda x: x in decay_params
 
-            params = []
+            all_fused_params = []
             for v in self._rank2fused.values():
-                params += v
-            self._parameter_list = params
-            self._param_groups = params
+                all_fused_params += v
+            self._parameter_list = all_fused_params
+            self._param_groups = all_fused_params
 
-            self._set_inner_opt_attr('_parameter_list', all_params)
-            self._set_inner_opt_attr('_param_groups', all_params)
+            self._set_inner_opt_attr('_parameter_list', fused_params)
+            self._set_inner_opt_attr('_param_groups', fused_params)
             origin_decay_param_fun = getattr(
                 self._inner_opt, '_apply_decay_param_fun', None
             )
@@ -145,11 +153,23 @@ class DygraphShardingOptimizer:
                 p.clear_gradient(set_to_zero)
 
     def _tensor_fusion(self):
+        comm_group = self._hcg.get_sharding_parallel_group()
         for i in range(self._sharding_world_size):
             params = self._rank2params[i]
-            decay_fused, all_fused = fused_parameters(
-                params, self._use_main_grad
+            dst = comm_group.ranks[i]
+            # TODO(sharding dev): make scale_after_comm a field to be configured by user
+            decay_fused, all_fused, all_buffer = fused_parameters(
+                params,
+                use_main_grad=self._use_main_grad,
+                fuse_param=True,
+                comm_overlap=self.comm_overlap,
+                comm_group=comm_group,
+                dst=dst,
+                acc_step=self.accumulate_steps,
+                scale_after_comm=False,
             )
+            if self.comm_overlap:
+                self._comm_buffers += all_buffer
             self._rank2decay[i] = decay_fused
             self._rank2fused[i] = all_fused
             for p in all_fused:
@@ -199,6 +219,10 @@ class DygraphShardingOptimizer:
     def reduce_gradients(self, parameter_list, hcg):
         # TODO merge grad / nrank with dp
         logger.debug("sharding start gradients sync")
+        if self.comm_overlap:
+            for buffer in self._comm_buffers:
+                buffer.scale_grads()
+            return
         with framework.no_grad():
             sharding_nrank = hcg.get_sharding_parallel_group().nranks
             for param in parameter_list:
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 6644e2a06e5fe91e7402b5fbf3d572808267c90b..2038a4c4e460697fa8eab34c06a8d35305c3da8c 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -37,11 +37,11 @@ else:
     from .pp_utils import p2p_communication as p2p
 
 from paddle.distributed.fleet.utils.tensor_fusion_helper import (
+    HOOK_ACTION,
+    FusedCommBuffer,
     assign_group_by_size,
 )
 
-from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer
-
 __all__ = []
 
 g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
@@ -334,9 +334,11 @@ class PipelineParallel(MetaParallelBase):
 
         for dst in fused_parameter_group:
             parameter_list = fused_parameter_group[dst]
-            if not dp:
+            if act != HOOK_ACTION.ALL_REDUCE:
                 # parse the relative dst rank to absolute dst rank for sharding
                 dst = comm_group.ranks[dst]
+            else:
+                dst = -1
             var_groups = assign_group_by_size(parameter_list)
             for group_idx, parameters in var_groups.items():
                 buffer = FusedCommBuffer(
@@ -515,7 +517,7 @@ class PipelineParallel(MetaParallelBase):
         if self._comm_overlap:
             assert len(self._comm_buffers) > 0
             for buffer in self._comm_buffers:
-                buffer.scale_and_split_grads()
+                buffer.scale_grads()
 
         if self._enable_timer:
             self.timers("allreduce_shared_weight_gradients").start()
@@ -1256,7 +1258,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
         if self._comm_overlap:
             assert len(self._comm_buffers) > 0
             for buffer in self._comm_buffers:
-                buffer.scale_and_split_grads()
+                buffer.scale_grads()
 
         if static_scheduler:
             self._reset_counter()
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index 6c8e2fd9dc3aa349580ba94463ea836ac0f5b3f8..33b8c3d95d5824f4143e6b28a22396dbf64781bb 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -15,19 +15,10 @@
 
 import paddle
 from paddle import _legacy_C_ops
-from paddle.distributed.fleet.utils.tensor_fusion_helper import (
-    flatten_dense_tensors,
-)
-from paddle.framework import base as imperative_base
 
 __all__ = []
 
 
-class HOOK_ACTION:
-    ALL_REDUCE = 0
-    REDUCE = 1
-
-
 FLOAT_TYPE_DICT = {
     paddle.float16: "float16",
     paddle.float32: "float32",
@@ -116,118 +107,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
         'nranks',
         nranks,
     )
-
-
-class FusedCommBuffer:
-    def __init__(self, id, params, comm_group, acc_steps=1, act=None, dst=-1):
-        self._id = id
-        self._params = params
-        self._acc_steps = acc_steps
-        self._comm_group = comm_group
-
-        self.use_main_grad = hasattr(self._params[0], "main_grad")
-
-        self._task = None
-        self._params_step_dict = {}
-        self._params_checked_in = 0
-        self._params_to_addr = {}
-
-        self._act = act
-        if self._act == HOOK_ACTION.ALL_REDUCE:
-            assert dst == -1
-        elif self._act == HOOK_ACTION.REDUCE:
-            assert dst != -1
-        else:
-            raise ValueError(
-                "The act should be allreudce for dp or reduce for sharding."
-            )
-        self._dst = dst
-
-        self._init_step_dict()
-
-        self.grad_storage = flatten_dense_tensors(
-            self._params,
-            use_main_grad=self.use_main_grad,
-            fuse_param=False,
-            warp_buffer=False,
-        ).buffer
-
-        self._record_addr()
-
-    def _record_addr(self):
-        for param in self._params:
-            addr = (
-                param.main_grad.data_ptr()
-                if self.use_main_grad
-                else param.grad.data_ptr()
-            )
-            self._params_to_addr[param.name] = addr
-
-    def _init_step_dict(self):
-        for p in self._params:
-            self._params_step_dict[p.name] = 0
-
-    def _reset_params_checked_in(self):
-        self._task = None
-        self._init_step_dict()
-        self._params_checked_in = 0
-
-    @property
-    def _all_params_checked_in(self):
-        return (
-            len(self._params) == self._params_checked_in
-            and len(self._params_step_dict) == 0
-        )
-
-    def add_grad(self, param):
-        assert param.name in self._params_step_dict
-        current_ptr = (
-            param.main_grad.data_ptr()
-            if self.use_main_grad
-            else param.grad.data_ptr()
-        )
-        if self._params_to_addr[param.name] != current_ptr:
-            raise ValueError(
-                "The address of the grad/main_grad of the param has been changed during training, "
-                "which is not allowed for dp/sharding overlap with pp. "
-                "This may be caused by some non-inplace operations on the grad/main_grad. "
-                "Please use the inplace version of the operations or disable the overlapping."
-            )
-
-        self._params_step_dict[param.name] += 1
-
-        if self._params_step_dict[param.name] == self._acc_steps:
-            self._params_checked_in += 1
-            self._params_step_dict.pop(param.name)
-
-        if self._all_params_checked_in:
-            self._comm_grads()
-
-    @imperative_base.no_grad
-    def _comm_grads(self):
-        assert self._all_params_checked_in
-
-        if self._act == HOOK_ACTION.ALL_REDUCE:
-            task = paddle.distributed.all_reduce(
-                self.grad_storage, group=self._comm_group, sync_op=False
-            )
-
-        elif self._act == HOOK_ACTION.REDUCE:
-            task = paddle.distributed.reduce(
-                self.grad_storage,
-                dst=self._dst,
-                group=self._comm_group,
-                sync_op=False,
-            )
-
-        self._task = task
-
-    @imperative_base.no_grad
-    def scale_and_split_grads(self):
-        assert self._task is not None
-        self._task.wait()
-
-        scale_factor = 1.0 / self._comm_group.nranks
-        self.grad_storage.scale_(scale_factor)
-
-        self._reset_params_checked_in()
diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
index 403f9d5d9a6c151a0a83847b4e927a31bac18b04..f2720b04ea093948f9d67f614144db2e405a8381 100644
--- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -12,13 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
+import os
 from collections import OrderedDict
 
 import numpy as np
 
 import paddle
+from paddle.framework import base as imperative_base
 from paddle.framework import core
 
+
+class HOOK_ACTION:
+    ALL_REDUCE = 0
+    REDUCE = 1
+
+
 alignment = {
     "gpu": 256,
 }
@@ -101,23 +109,204 @@ def flatten_dense_tensors(
     return grad_storage
 
 
-def obtain_storage(parameters, use_main_grad, clip, dist):
+def bw_hook_func(buffer, param):
+    @paddle.autograd.no_grad()
+    def fused_comm(*_):
+        buffer.add_grad(param)
+
+    return fused_comm
+
+
+class FusedCommBuffer:
+    def __init__(
+        self,
+        id,
+        params,
+        comm_group,
+        acc_steps=1,
+        act=None,
+        dst=-1,
+        use_main_grad=None,
+        fuse_param=False,
+        scale_after_comm=True,
+    ):
+        self._id = id
+        self._params = params
+        self._acc_steps = acc_steps
+        self._comm_group = comm_group
+        self._scale_after_comm = scale_after_comm
+        self._fuse_param = fuse_param
+
+        self.use_main_grad = (
+            use_main_grad
+            if use_main_grad is not None
+            else hasattr(self._params[0], "main_grad")
+        )
+
+        self._task = None
+        self._params_step_dict = {}
+        self._params_checked_in = 0
+        self._grads_to_addr = {}
+
+        self._act = act
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            assert dst == -1
+        elif self._act == HOOK_ACTION.REDUCE:
+            assert dst != -1
+        else:
+            raise ValueError(
+                "The act should be allreduce for dp or reduce for sharding."
+            )
+        self._dst = dst
+
+        self._init_step_dict()
+
+        if self._fuse_param:
+            self.param_storage, self.grad_storage = flatten_dense_tensors(
+                self._params,
+                use_main_grad=use_main_grad,
+                fuse_param=True,
+                warp_buffer=True,
+            )
+            self.param_storage = self.param_storage.buffer
+            self.grad_storage = self.grad_storage.buffer
+        else:
+            self.param_storage = None
+            self.grad_storage = flatten_dense_tensors(
+                self._params,
+                use_main_grad=self.use_main_grad,
+                fuse_param=False,
+                warp_buffer=False,
+            ).buffer
+
+        self._record_addr()
+
+    def _record_addr(self):
+        for param in self._params:
+            addr = (
+                param.main_grad.data_ptr()
+                if self.use_main_grad
+                else param.grad.data_ptr()
+            )
+            self._grads_to_addr[param.name] = addr
+
+    def _init_step_dict(self):
+        for p in self._params:
+            self._params_step_dict[p.name] = 0
+
+    def _reset_params_checked_in(self):
+        self._task = None
+        self._init_step_dict()
+        self._params_checked_in = 0
+
+    @property
+    def _all_params_checked_in(self):
+        return (
+            len(self._params) == self._params_checked_in
+            and len(self._params_step_dict) == 0
+        )
+
+    def add_grad(self, param):
+        assert param.name in self._params_step_dict
+        current_ptr = (
+            param.main_grad.data_ptr()
+            if self.use_main_grad
+            else param.grad.data_ptr()
+        )
+        if self._grads_to_addr[param.name] != current_ptr:
+            raise ValueError(
+                "The address of the grad/main_grad of the param has been changed during training, "
+                "which is not allowed for dp/sharding overlap with pp. "
+                "This may be caused by some non-inplace operations on the grad/main_grad. "
+                "Please use the inplace version of the operations or disable the overlapping."
+            )
+
+        self._params_step_dict[param.name] += 1
+
+        if self._params_step_dict[param.name] == self._acc_steps:
+            self._params_checked_in += 1
+            self._params_step_dict.pop(param.name)
+
+        if self._all_params_checked_in:
+            self._comm_grads()
+
+    @imperative_base.no_grad
+    def _comm_grads(self):
+        assert self._all_params_checked_in
+
+        if not self._scale_after_comm:
+            scale_factor = 1.0 / self._comm_group.nranks
+            self.grad_storage.scale_(scale_factor)
+
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            task = paddle.distributed.all_reduce(
+                self.grad_storage, group=self._comm_group, sync_op=False
+            )
+
+        elif self._act == HOOK_ACTION.REDUCE:
+            task = paddle.distributed.reduce(
+                self.grad_storage,
+                dst=self._dst,
+                group=self._comm_group,
+                sync_op=False,
+            )
+
+        self._task = task
+
+    @imperative_base.no_grad
+    def scale_grads(self):
+        assert self._task is not None
+        self._task.wait()
+
+        if self._scale_after_comm:
+            scale_factor = 1.0 / self._comm_group.nranks
+            self.grad_storage.scale_(scale_factor)
+
+        self._reset_params_checked_in()
+
+
+def obtain_storage(
+    parameters,
+    use_main_grad=False,
+    clip=True,
+    dist=False,
+    fuse_param=True,
+    comm_overlap=False,
+    act=None,
+    comm_group=None,
+    dst=-1,
+    acc_steps=1,
+    scale_after_comm=False,
+):
     if len(parameters) < 1:
-        return []
+        return [], []
 
     var_groups = assign_group_by_size(parameters, group_size=256 * 1024 * 1024)
     storage = []
+    buffers = []
     for group_idx, parameters in var_groups.items():
-        param_storage, grad_storage = flatten_dense_tensors(
+        comm_buffer = FusedCommBuffer(
+            group_idx,
             parameters,
+            comm_group=comm_group,
+            acc_steps=acc_steps,
+            act=act,
+            dst=dst,
             use_main_grad=use_main_grad,
-            fuse_param=True,
-            warp_buffer=True,
+            fuse_param=fuse_param,
+            scale_after_comm=scale_after_comm,
         )
-        param_storage.buffer.need_clip = clip
-        param_storage.buffer.is_distributed = dist
-        storage.append(param_storage.buffer)
-    return storage
+        if fuse_param:
+            param_buffer = comm_buffer.param_storage
+            param_buffer.need_clip = clip
+            param_buffer.is_distributed = dist
+            storage.append(param_buffer)
+        if comm_overlap:
+            for param in parameters:
+                param._register_backward_hook(bw_hook_func(comm_buffer, param))
+        buffers.append(comm_buffer)
+
+    return storage, buffers
 
 
 def filter_params(params, is_fp32, is_distributed, need_clip):
@@ -155,7 +344,38 @@ def filter_params(params, is_fp32, is_distributed, need_clip):
     return params, dtype
 
 
-def fused_parameters(parameters, use_main_grad):
+def fused_parameters(
+    parameters,
+    use_main_grad=False,
+    fuse_param=True,
+    comm_overlap=False,
+    comm_group=None,
+    dst=-1,
+    acc_step=1,
+    scale_after_comm=False,
+):
+    """
+    Fuse gradients. Also fuse parameters when fuse_param is enabled, and prepare comm buffers when comm_overlap is enabled.
+    :param parameters: all parameters to be fused.
+    :param use_main_grad: whether the gradients to fuse are the main_grad (fp32) gradients
+    :param comm_overlap: whether to overlap gradient communication with the backward computation
+    :param comm_group: the communication group used for comm overlap
+    :param dst: the destination rank used for comm overlap
+    :param acc_step: the number of gradient accumulation steps, used for comm overlap
+    :param fuse_param: whether to fuse the parameters in addition to the gradients
+    :param scale_after_comm: when comm overlap is enabled, whether gradients are scaled after (instead of before) the communication
+    :return: the fused storages of decay params and of all params, plus the comm buffers used for comm overlap
+    """
+    g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+    act = (
+        HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE
+    )
+    if comm_overlap:
+        assert comm_group is not None
+    if act == HOOK_ACTION.REDUCE:
+        assert dst != -1
+    elif act == HOOK_ACTION.ALL_REDUCE:
+        dst = -1
     param_groups = []
     attrs = []
 
@@ -178,6 +398,7 @@ def fused_parameters(parameters, use_main_grad):
 
     decay_fused = []
     all_fused = []
+    all_buffers = []
     for params, attr in zip(param_groups, attrs):
         decay_params = []
         other_params = []
@@ -190,14 +411,36 @@ def fused_parameters(parameters, use_main_grad):
         is_distributed = attr[1]
         need_clip = attr[2]
 
-        decay = obtain_storage(
-            decay_params, use_main_grad, need_clip, is_distributed
+        decay, decay_buffers = obtain_storage(
+            decay_params,
+            use_main_grad=use_main_grad,
+            clip=need_clip,
+            dist=is_distributed,
+            fuse_param=fuse_param,
+            comm_overlap=comm_overlap,
+            act=act,
+            comm_group=comm_group,
+            dst=dst,
+            acc_steps=acc_step,
+            scale_after_comm=scale_after_comm,
         )
-        other = obtain_storage(
-            other_params, use_main_grad, need_clip, is_distributed
+        other, other_buffers = obtain_storage(
+            other_params,
+            fuse_param=fuse_param,
+            comm_overlap=comm_overlap,
+            use_main_grad=use_main_grad,
+            clip=need_clip,
+            dist=is_distributed,
+            act=act,
+            comm_group=comm_group,
+            dst=dst,
+            acc_steps=acc_step,
+            scale_after_comm=scale_after_comm,
        )
         decay_fused += decay
         all_fused += decay
         all_fused += other
+        all_buffers += decay_buffers
+        all_buffers += other_buffers
 
-    return decay_fused, all_fused
+    return decay_fused, all_fused, all_buffers
diff --git a/test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py b/test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py
index 310313119b4c369f706f3921dd974c65537bac7c..e70656a4ce6084232f293c4b40d341874bf26819 100644
--- a/test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py
+++ b/test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py
@@ -99,6 +99,8 @@ class TestDistSharding(unittest.TestCase):
             "pp_degree": 1,
         }
         self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
+        self.strategy.hybrid_configs["sharding_configs"].comm_overlap = True
+        self.strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1
         fleet.init(is_collective=True, strategy=self.strategy)
         self.data = np.random.randint(
             0,
diff --git a/test/legacy_test/test_fused_comm_buffer.py b/test/legacy_test/test_fused_comm_buffer.py
index ad771b6dfe5a2bbe220cb424a16b9ee77c18ce7b..25d9a2748bd0edf6133a0cd879092b7004cd63e6 100644
--- a/test/legacy_test/test_fused_comm_buffer.py
+++ b/test/legacy_test/test_fused_comm_buffer.py
@@ -15,7 +15,7 @@
 import unittest
 
 import paddle
-from paddle.distributed.fleet.meta_parallel.pp_utils.utils import (
+from paddle.distributed.fleet.utils.tensor_fusion_helper import (
     HOOK_ACTION,
     FusedCommBuffer,
 )
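
Usage note (not part of the patch): the two new `sharding_configs` fields are exercised end to end by the updated `hybrid_parallel_sharding_model_with_fusion.py` test above. A minimal sketch of enabling them from user code looks like the following; the hybrid degree values are illustrative and mirror that test.

```python
# Minimal sketch: turn on sharding tensor fusion plus communication overlap
# via the new DygraphShardingConfig fields. Degrees below are illustrative.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,
}
strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
strategy.hybrid_configs["sharding_configs"].comm_overlap = True
strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1

fleet.init(is_collective=True, strategy=strategy)
# Model/optimizer construction continues as usual; DygraphShardingOptimizer
# reads accumulate_steps and comm_overlap from the strategy in its __init__.
```

Setting `accumulate_steps` to the real gradient-accumulation count matters because `FusedCommBuffer` only launches its reduce once every parameter in a group has checked its gradient in `acc_steps` times.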
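For reviewers tracing the new API surface, here is a hedged sketch of the reworked `fused_parameters()` contract and the `scale_grads()` drain step, mirroring `DygraphShardingOptimizer._tensor_fusion` and `reduce_gradients` above. The process-group setup and the `Linear` model are assumptions for illustration only, and the script is assumed to run under `paddle.distributed.launch`.

```python
# Hedged sketch: fused_parameters() now returns the comm buffers as a third
# value, and the caller drains them with scale_grads() before the optimizer
# step. The group/model setup below is illustrative, not part of the patch.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils.tensor_fusion_helper import fused_parameters

dist.init_parallel_env()
sharding_group = dist.new_group(list(range(dist.get_world_size())))
dst_rank = sharding_group.ranks[0]  # illustrative: rank that owns this shard

model = paddle.nn.Linear(1024, 1024)
decay_fused, all_fused, comm_buffers = fused_parameters(
    list(model.parameters()),
    use_main_grad=False,
    fuse_param=True,
    comm_overlap=True,
    comm_group=sharding_group,
    dst=dst_rank,
    acc_step=1,
    scale_after_comm=False,
)

# ... run forward/backward; the registered backward hooks check each grad into
# its FusedCommBuffer and issue the asynchronous reduce/all_reduce once every
# grad in the group has arrived ...

# Before the optimizer step, wait on the outstanding communication. With
# scale_after_comm=False the 1/nranks scaling was applied before the reduce,
# so scale_grads() only waits on the task and resets the buffer bookkeeping.
for buffer in comm_buffers:
    buffer.scale_grads()
```

The `scale_after_comm` flag only moves where the `1.0 / comm_group.nranks` scaling happens (before or after the collective); the sharding optimizer in this patch hard-codes it to `False` and leaves making it configurable as a TODO.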