From d8b4ca920d888595d91a4ac4079c8eb63a6867d7 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Sun, 9 Oct 2022 15:30:07 +0800
Subject: [PATCH] [dygraph sharding stage 2] sharding broadcast overlap (#46656)

---
 .../group_sharded_optimizer_stage2.py         | 98 +++++++++++++++++--
 .../sharding/group_sharded_stage2.py          | 18 ++--
 ...graph_group_sharded_stage2_comm_overlap.py |  6 +-
 3 files changed, 102 insertions(+), 20 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index caeb54bc402..073937eafdf 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -24,6 +24,8 @@
 import copy
 import logging
+import warnings
+
 import numpy as np
 from collections import OrderedDict
@@ -87,7 +89,7 @@ class GroupShardedOptimizerStage2(Optimizer):
         self._optim = optim
 
         # sharing stage 2 comm overlap flag
-        self._comm_overlap = False
+        self._reduce_overlap = False
         # record the last task used for comm overlap for sharding stage 2
         self._comm_task = None
 
@@ -108,6 +110,17 @@ class GroupShardedOptimizerStage2(Optimizer):
                 filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
                        self._local_params))) > 0
 
+        self._broadcast_overlap = False
+        self._forward_pre_hook_remove_helper = []
+        try:
+            # The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
+            # Have to sort the params to make sure all params are in the order used in the forward pass.
+            self._broadcast_order_params = sorted(
+                self.local_params,
+                key=lambda x: int(x.name.split('.')[0].split('_')[-1]))
+        except ValueError:
+            self._broadcast_order_params = None
+
         self._group = new_group(
             _get_global_group().ranks) if group is None else group
 
@@ -163,15 +176,34 @@ class GroupShardedOptimizerStage2(Optimizer):
                       sync_op=True)
 
     def _update_task(self, task):
-        if self._comm_overlap:
+        if self._reduce_overlap:
             assert task is not None
             # Only track of the last reduce task.
             # Since all tasks are on the same stream, only need to wait the last one.
             # After waiting for the last reduce task, all reduce tasks before have already finished.
             self._comm_task = task
 
-    def _set_comm_overlap(self, comm_overlap):
-        self._comm_overlap = comm_overlap
+    def _set_reduce_overlap(self, reduce_overlap):
+        # Enable the gradient reduces to overlap with the backward calculation.
+        self._reduce_overlap = reduce_overlap
+
+    def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
+        # Enable the post-optimizer broadcasts to overlap with the forward calculation of the next batch.
+        self._broadcast_overlap = broadcast_overlap
+        if self._broadcast_overlap:
+            assert layers is not None, \
+                "To enable broadcast overlap forward, please pass the module to the function."
+            self._layers = layers
+            warnings.warn(
+                "Setting overlap broadcast means the `paddle.device.cuda.synchronize()` "
+                "must be called manually before calling `paddle.save()` and before inference."
+            )
+            if self._broadcast_order_params is None:
+                # Params' names should be like column_linear_32.w_0 pattern to get the best performance.
+                warnings.warn(
+                    "The param name passed to the optimizer doesn't follow the .+_[0-9]+\..+ pattern, "
+                    "so overlap broadcast may harm the performance.")
+                self._broadcast_order_params = self._local_params
 
     def _generate_master_params(self, trainable_params):
         if self.offload:
@@ -382,6 +414,12 @@ class GroupShardedOptimizerStage2(Optimizer):
         """
         # This method won't be called directly by opt.step()!
         # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+        if self._broadcast_overlap:
+            # Clear the forward pre hooks in the optimizer step.
+            for hook_remove in self._forward_pre_hook_remove_helper:
+                hook_remove.remove()
+            self._forward_pre_hook_remove_helper = []
+
         if self.offload:
             params_list = [self.offload_params.buffer]
 
@@ -425,9 +463,49 @@ class GroupShardedOptimizerStage2(Optimizer):
         """Broadcast the parameters of the current rank to each rank"""
 
         # Exchange all the shards with the other ranks
-        for dtype_per_rank in self.param_storages.values():
-            for dst_rank, internal_storage in dtype_per_rank.items():
-                broadcast(tensor=internal_storage.buffer,
-                          src=self._group.ranks[dst_rank],
-                          group=self._group,
-                          sync_op=True)
+        if self._broadcast_overlap:
+            self._broadcast_params_overlap_forward()
+        else:
+            for dtype_per_rank in self.param_storages.values():
+                for dst_rank, internal_storage in dtype_per_rank.items():
+                    broadcast(tensor=internal_storage.buffer,
+                              src=self._group.ranks[dst_rank],
+                              group=self._group,
+                              sync_op=True)
+
+    def _forward_pre_hook_function(self, tasks):
+        # Since the layers will call the pre hook as `forward_pre_hook(self, inputs)`,
+        # the helper function needs x and y to take those params.
+        def __impl__(x, y):
+            for task in tasks:
+                # Wait for the broadcast task before using the result of the broadcast.
+                task.wait()
+
+        return __impl__
+
+    @paddle.autograd.no_grad()
+    def _broadcast_params_overlap_forward(self):
+        # Exchange all the shards with the other ranks,
+        # but overlap the broadcast with the next batch's calculation.
+        param2task = {}
+        for x in self._broadcast_order_params:
+            if x.trainable:
+                task = broadcast(
+                    tensor=x,
+                    src=self._group.ranks[self._param2rank[x.name]],
+                    group=self._group,
+                    sync_op=False)
+                assert x.name not in param2task
+                param2task[x.name] = task
+
+        for layer in self._layers.sublayers():
+            if len(layer.sublayers()) == 0:
+                # Register forward pre hooks for leaf layers. This gives the best performance.
+                tasks = []
+                for param in layer.parameters():
+                    if param.trainable:
+                        if param.name in param2task:
+                            tasks.append(param2task[param.name])
+                self._forward_pre_hook_remove_helper.append(
+                    layer.register_forward_pre_hook(
+                        self._forward_pre_hook_function(tasks)))
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 7af9333aa5e..709cdadb2c2 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -101,7 +101,7 @@ class GroupShardedStage2(nn.Layer):
             self._all_params.extend(list(optim.local_params))
 
         # sharing stage 2 comm overlap flag
-        self._comm_overlap = False
+        self._reduce_overlap = False
 
         self._trainable_params = []
         self._grad_reduced = []
@@ -309,17 +309,17 @@ class GroupShardedStage2(nn.Layer):
         for grad_storage in self._grad_storage_list:
             grad_storage.reset_checked_in()
 
-    def _set_comm_overlap(self, comm_overlap):
+    def _set_reduce_overlap(self, reduce_overlap):
         # Hacky way to not add an extra parameter to the `group_sharded_parallel` funct.
         # User should use this like:
         # model, optimizer, scaler = group_sharded_parallel(...)
-        # model._set_comm_overlap(True)
-        self._comm_overlap = comm_overlap
-        if self._comm_overlap:
+        # model._set_reduce_overlap(True)
+        self._reduce_overlap = reduce_overlap
+        if self._reduce_overlap:
             assert len(
                 self._sharding_optimizers
             ) == 1, "Only support comm overlap strategy for single optimizer"
-            self._sharding_optimizers[0]._set_comm_overlap(comm_overlap)
+            self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)
 
     def _get_reduce_fn(self, index, param, dst_rank):
         """
@@ -357,7 +357,7 @@ class GroupShardedStage2(nn.Layer):
                     collective.reduce(tensor=param.grad,
                                       dst=self._group.ranks[dst_rank],
                                       group=self._group,
-                                      sync_op=not self._comm_overlap))
+                                      sync_op=not self._reduce_overlap))
 
             # Clear the task flow and trigger callback to clear the redundant gradient
             # self._clear_task_flow()
@@ -407,7 +407,7 @@ class GroupShardedStage2(nn.Layer):
                         tensor=grad_storage.buffer,
                         dst=self._group.ranks[grad_storage.destination],
                         group=self._group,
-                        sync_op=not self._comm_overlap))
+                        sync_op=not self._reduce_overlap))
 
                 cleanup()
 
@@ -545,7 +545,7 @@ class GroupShardedStage2(nn.Layer):
             opt_step = opt.step
 
             def _opt_step(self):
-                if self._comm_overlap:
+                if self._reduce_overlap:
                     # Wait for the last reduce task. This wait must before grad scale function.
                    assert self._comm_task is not None
                     self._comm_task.wait()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py
index 883874c9cd2..7f16d926f53 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py
@@ -92,13 +92,15 @@ def train_mlp(model,
     optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
 
     if sharding_stage == 2:
+        origin_model = model
         optimizer = GroupShardedOptimizerStage2(
             params=optimizer._parameter_list, optim=optimizer, group=group)
 
         model = GroupShardedStage2(model,
                                    optimizer,
                                    group=group,
                                    buffer_max_size=2**21)
-        model._set_comm_overlap(True)
+        model._set_reduce_overlap(True)
+        optimizer._set_broadcast_overlap(True, model)
     else:
         model = paddle.DataParallel(model)
 
@@ -149,6 +151,8 @@ def train_mlp(model,
             optimizer.step()
             optimizer.clear_grad()
 
+    paddle.device.cuda.synchronize()
+
     if save_model:
         return model, optimizer
     return model.parameters()
-- 
GitLab
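
Usage note: the sketch below shows how the two overlap switches added by this patch are meant to be enabled together, following the test changes and the "User should use this like" comment above. It is a minimal, illustrative example rather than part of the patch: the network, optimizer choice, batch data, and file name are placeholders, and it assumes at least two GPUs launched through `python -m paddle.distributed.launch`.

import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel


def main():
    # Initialize the collective environment (one process per GPU).
    fleet.init(is_collective=True)

    # Placeholder network and optimizer; any dygraph Layer works the same way.
    model = paddle.nn.Sequential(paddle.nn.Linear(1024, 1024),
                                 paddle.nn.Linear(1024, 10))
    opt = paddle.optimizer.Momentum(learning_rate=0.001,
                                    parameters=model.parameters())

    # Wrap with group sharded stage 2 ("os_g": optimizer state + gradient sharding).
    model, opt, scaler = group_sharded_parallel(model, opt, level="os_g")

    # Overlap gradient reduce with backward, and the post-optimizer parameter
    # broadcast with the next batch's forward, as exercised by the test above.
    # Both are private helpers introduced by this patch.
    model._set_reduce_overlap(True)
    opt._set_broadcast_overlap(True, model)

    for _ in range(10):
        x = paddle.randn([4, 1024])       # placeholder batch
        loss = model(x).mean()
        loss.backward()
        opt.step()
        opt.clear_grad()

    # As the warning added in this patch notes, synchronize manually before
    # saving or running inference so all asynchronous broadcasts have finished.
    paddle.device.cuda.synchronize()
    paddle.save(model.state_dict(), "model.pdparams")


if __name__ == "__main__":
    main()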