diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 19d3245982a96aedaf37e1d757ba690cbdf4178f..d366291b6bebff1e70286f6d0712d40fe81ea066 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -24,6 +24,8 @@
 import copy
 import logging
+import warnings
+
 import numpy as np
 from collections import OrderedDict
@@ -86,6 +88,11 @@ class GroupShardedOptimizerStage2(Optimizer):
         # Default information
         self._optim = optim
 
+        # sharding stage 2 comm overlap flag
+        self._reduce_overlap = False
+        # record the last task used for comm overlap for sharding stage 2
+        self._comm_task = None
+
         assert hasattr(self._optim, "_master_weights"
                        ), "Must use optimizer with _master_weights attribute"
@@ -103,6 +110,17 @@ class GroupShardedOptimizerStage2(Optimizer):
             filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
                    self._local_params))) > 0
 
+        self._broadcast_overlap = False
+        self._forward_pre_hook_remove_helper = []
+        try:
+            # The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
+            # Have to sort the params to make sure all params are in the order in which forward uses them.
+            self._broadcast_order_params = sorted(
+                self.local_params,
+                key=lambda x: int(x.name.split('.')[0].split('_')[-1]))
+        except ValueError:
+            self._broadcast_order_params = None
+
         self._group = new_group(
             _get_global_group().ranks) if group is None else group
@@ -157,6 +175,60 @@ class GroupShardedOptimizerStage2(Optimizer):
                       group=self._group,
                       sync_op=True)
 
+    def _update_task(self, task):
+        if self._reduce_overlap:
+            assert task is not None
+            # Only keep track of the last reduce task.
+            # Since all tasks are on the same stream, waiting for the last one is enough.
+            # After waiting for the last reduce task, all earlier reduce tasks have already finished.
+            self._comm_task = task
+
+    def _set_reduce_overlap(self, reduce_overlap):
+        # Enable overlapping the gradient reduce with the backward calculation.
+        self._reduce_overlap = reduce_overlap
+
+    def _set_broadcast_overlap(self,
+                               broadcast_overlap,
+                               layers=None,
+                               num_groups=None):
+        # Enable overlapping the post-optimizer broadcasts with the next batch's forward calculation.
+        self._broadcast_overlap = broadcast_overlap
+        if self._broadcast_overlap:
+            assert layers is not None, \
+                "To enable broadcast overlap forward, please pass the module to the function."
+            self._layers = layers
+            warnings.warn(
+                "Setting overlap broadcast means `paddle.device.cuda.synchronize()` "
+                "must be called manually before calling `paddle.save()` and before inference."
+            )
+            if self._broadcast_order_params is None:
+                # Params' names should match the column_linear_32.w_0 pattern to get the best performance.
+                warnings.warn(
+                    "The param name passed to the optimizer doesn't follow the .+_[0-9]+\..+ pattern, "
+                    "overlap broadcast may harm the performance.")
+                self._broadcast_order_params = self._local_params
+
+        if num_groups is None or num_groups > len(self._broadcast_order_params):
+            warnings.warn(
+                "The num_groups for broadcast is larger than the number of params to be broadcast. "
+                "It will be set to the default value: 1 (use the default sharding group)."
+            )
+            num_groups = 1
+
+        assert isinstance(
+            num_groups,
+            int) and num_groups > 0, "num_groups should be a positive integer"
+
+        self._number_of_broadcast_groups = num_groups
+        self._broadcast_groups = [
+            None for _ in range(self._number_of_broadcast_groups)
+        ]
+        self._broadcast_groups[0] = self._group
+
+        ranks = self._group.ranks
+        for i in range(1, self._number_of_broadcast_groups):
+            self._broadcast_groups[i] = new_group(ranks)
+
     def _generate_master_params(self, trainable_params):
         if self.offload:
             for param in trainable_params:
@@ -364,6 +436,13 @@ class GroupShardedOptimizerStage2(Optimizer):
         """
         A wrapper for Optimizer's step function to finish the update operation of the optimizer.
         """
+        # This method won't be called directly by opt.step()!
+        # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+        if self._broadcast_overlap:
+            # Clear the pre forward hooks in the optimizer step.
+            for hook_remove in self._forward_pre_hook_remove_helper:
+                hook_remove.remove()
+            self._forward_pre_hook_remove_helper = []
 
         if self.offload:
             params_list = [self.offload_params.buffer]
@@ -408,9 +487,52 @@ class GroupShardedOptimizerStage2(Optimizer):
         """Broadcast the parameters of the current rank to each rank"""
 
         # Exchange all the shards with the other ranks
-        for dtype_per_rank in self.param_storages.values():
-            for dst_rank, internal_storage in dtype_per_rank.items():
-                broadcast(tensor=internal_storage.buffer,
-                          src=self._group.ranks[dst_rank],
-                          group=self._group,
-                          sync_op=True)
+        if self._broadcast_overlap:
+            self._broadcast_params_overlap_forward()
+        else:
+            for dtype_per_rank in self.param_storages.values():
+                for dst_rank, internal_storage in dtype_per_rank.items():
+                    broadcast(tensor=internal_storage.buffer,
+                              src=self._group.ranks[dst_rank],
+                              group=self._group,
+                              sync_op=True)
+
+    def _forward_pre_hook_function(self, tasks):
+        # Since the layers will call the pre hook as `forward_pre_hook(self, inputs)`,
+        # the helper function needs x and y to accept those arguments.
+        def __impl__(x, y):
+            for task in tasks:
+                # Wait for the broadcast task before using the result of the broadcast.
+                task.wait()
+
+        return __impl__
+
+    @paddle.autograd.no_grad()
+    def _broadcast_params_overlap_forward(self):
+        # Exchange all the shards with the other ranks,
+        # but overlap the broadcasts with the next batch's forward calculation.
+        group_idx = 0
+
+        param2task = {}
+        for x in self._broadcast_order_params:
+            if x.trainable:
+                group = self._broadcast_groups[group_idx]
+                group_idx = (group_idx + 1) % self._number_of_broadcast_groups
+                task = broadcast(tensor=x,
+                                 src=group.ranks[self._param2rank[x.name]],
+                                 group=group,
+                                 sync_op=False)
+                assert x.name not in param2task
+                param2task[x.name] = task
+
+        for layer in self._layers.sublayers():
+            if len(layer.sublayers()) == 0:
+                # Register a forward pre hook for leaf layers. This gets the best performance.
+                tasks = []
+                for param in layer.parameters():
+                    if param.trainable:
+                        if param.name in param2task:
+                            tasks.append(param2task[param.name])
+                self._forward_pre_hook_remove_helper.append(
+                    layer.register_forward_pre_hook(
+                        self._forward_pre_hook_function(tasks)))
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index d47f50b292db26e0a0930a58623cb1d05b3ecaeb..babf9391b928dad0842b8272d6313ceaa5cb67a5 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -100,6 +100,9 @@ class GroupShardedStage2(nn.Layer):
         for optim in self._sharding_optimizers:
             self._all_params.extend(list(optim.local_params))
 
+        # sharding stage 2 comm overlap flag
+        self._reduce_overlap = False
+
         self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
@@ -306,6 +309,18 @@ class GroupShardedStage2(nn.Layer):
         for grad_storage in self._grad_storage_list:
             grad_storage.reset_checked_in()
 
+    def _set_reduce_overlap(self, reduce_overlap):
+        # Hacky way to avoid adding an extra parameter to the `group_sharded_parallel` function.
+        # User should use this like:
+        # model, optimizer, scaler = group_sharded_parallel(...)
+        # model._set_reduce_overlap(True)
+        self._reduce_overlap = reduce_overlap
+        if self._reduce_overlap:
+            assert len(
+                self._sharding_optimizers
+            ) == 1, "Only support comm overlap strategy for single optimizer"
+            self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)
+
     def _get_reduce_fn(self, index, param, dst_rank):
         """
         There are two ways to reduce gradient.
@@ -337,11 +352,12 @@
                             del tmp_grad
                             param.clear_gradient(False)
 
-                    # Synchronize the reduce parameter gradient
-                    collective.reduce(tensor=param.grad,
-                                      dst=self._group.ranks[dst_rank],
-                                      group=self._group)
-                    # TODO (Baibaifan) Asynchronous the reduce parameter gradient
+                    # Reduce the parameter gradient; run asynchronously when comm overlap is enabled
+                    self._sharding_optimizers[0]._update_task(
+                        collective.reduce(tensor=param.grad,
+                                          dst=self._group.ranks[dst_rank],
+                                          group=self._group,
+                                          sync_op=not self._reduce_overlap))
 
                     # Clear the task flow and trigger callback to clear the redundant gradient
                     # self._clear_task_flow()
@@ -385,12 +401,13 @@
                         # Reduce the bucket
                         grad_storage.sent = True
 
-                        # Synchronize the reduce parameter gradient
-                        collective.reduce(
-                            tensor=grad_storage.buffer,
-                            dst=self._group.ranks[grad_storage.destination],
-                            group=self._group)
-                        # TODO (Baibaifan) Asynchronous the reduce parameter gradient
+                        # Reduce the gradient storage buffer; run asynchronously when comm overlap is enabled
+                        self._sharding_optimizers[0]._update_task(
+                            collective.reduce(
+                                tensor=grad_storage.buffer,
+                                dst=self._group.ranks[grad_storage.destination],
+                                group=self._group,
+                                sync_op=not self._reduce_overlap))
 
                         cleanup()
@@ -528,6 +545,10 @@
             opt_step = opt.step
 
             def _opt_step(self):
+                if self._reduce_overlap:
+                    # Wait for the last reduce task. This wait must happen before the grad scale function.
+                    assert self._comm_task is not None
+                    self._comm_task.wait()
                 grad_func()
                 opt_step()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py
new file mode 100644
index 0000000000000000000000000000000000000000..81ab434e5e88df14a43ab98b60b1e45481127be0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py
@@ -0,0 +1,245 @@
+# -*- coding: UTF-8 -*-
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import numpy as np
+import argparse
+import tempfile
+import ast
+import time
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.nn import Linear
+from paddle.distributed import fleet
+from paddle.fluid.dygraph import nn
+from paddle.fluid.framework import _test_eager_guard
+
+from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2
+from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import GroupShardedStage2
+
+seed = 2022
+epoch = 2
+linear_size = 1000
+
+np.random.seed(seed)
+paddle.seed(seed)
+
+
+class MLP(fluid.Layer):
+
+    def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
+        super(MLP, self).__init__()
+
+        self._linear1 = Linear(linear_size, linear_size)
+        self._linear2 = Linear(linear_size, linear_size)
+        self._linear3 = Linear(linear_size, 10)
+
+    def forward(self, inputs):
+        y = self._linear1(inputs)
+        y = self._linear2(y)
+        y = self._linear3(y)
+        return y
+
+
+def reader_decorator(linear_size=1000):
+
+    def __reader__():
+        for _ in range(100):
+            img = np.random.rand(linear_size).astype('float32')
+            label = np.ones(1).astype('int64')
+            yield img, label
+
+    return __reader__
+
+
+def optimizer_setting(model, use_pure_fp16, opt_group=False):
+    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
+    optimizer = paddle.optimizer.AdamW(parameters=[{
+        "params": model.parameters(),
+    }] if opt_group else model.parameters(),
+                                       learning_rate=0.001,
+                                       weight_decay=0.00001,
+                                       grad_clip=clip,
+                                       multi_precision=use_pure_fp16)
+
+    return optimizer
+
+
+def train_mlp(model,
+              sharding_stage,
+              batch_size=100,
+              use_pure_fp16=False,
+              accumulate_grad=False,
+              opt_group=False,
+              save_model=False,
+              test_minimize=False):
+    if sharding_stage != "dp":
+        group = paddle.distributed.new_group([0, 1], backend="nccl")
+    if opt_group:
+        optimizer = optimizer_setting(model=model,
+                                      use_pure_fp16=use_pure_fp16,
+                                      opt_group=opt_group)
+    else:
+        optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
+
+    if sharding_stage == 2:
+        origin_model = model
+        optimizer = GroupShardedOptimizerStage2(
+            params=optimizer._parameter_list, optim=optimizer, group=group)
+        model = GroupShardedStage2(model,
+                                   optimizer,
+                                   group=group,
+                                   buffer_max_size=2**21)
+        model._set_reduce_overlap(True)
+        optimizer._set_broadcast_overlap(True, model)
+    else:
+        model = paddle.DataParallel(model)
+
+    # check optimizer.minimize() error
+    if test_minimize:
+        try:
+            optimizer.minimize()
+        except:
+            print(
+                "====== Find sharding_stage2_optimizer.minimize() error ======")
+        return
+
+    train_reader = paddle.batch(reader_decorator(),
+                                batch_size=batch_size,
+                                drop_last=True)
+
+    train_loader = paddle.io.DataLoader.from_generator(capacity=32,
+                                                       use_double_buffer=True,
+                                                       iterable=True,
+                                                       return_list=True,
+                                                       use_multiprocess=True)
+    train_loader.set_sample_list_generator(train_reader)
+
+    if sharding_stage == 2:
+        model.to(device="gpu")
+
+    for eop in range(epoch):
+        model.train()
+
+        for batch_id, data in enumerate(train_loader()):
+            img, label = data
+            label.stop_gradient = True
+            img.stop_gradient = True
+
+            out = model(img)
+            loss = paddle.nn.functional.cross_entropy(input=out, label=label)
+
+            avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+            if batch_size == 20:
+                avg_loss = avg_loss / 5
+            avg_loss.backward()
+
+            if not accumulate_grad:
+                optimizer.step()
+                optimizer.clear_grad()
+
+        if accumulate_grad:
+            optimizer.step()
+            optimizer.clear_grad()
+
+    paddle.device.cuda.synchronize()
+
+    if save_model:
+        return model, optimizer
+    return model.parameters()
+
+
+def test_dp_stage2():
+    paddle.distributed.init_parallel_env()
+    mlp = MLP()
+    state_dict = mlp.state_dict()
+    mlp1 = MLP()
+    mlp2 = MLP()
+    mlp3 = MLP()
+    mlp4 = MLP()
+    mlp5 = MLP()
+    mlp6 = MLP()
+    mlp7 = MLP()
+    mlp1.set_state_dict(state_dict)
+    mlp2.set_state_dict(state_dict)
+    mlp3.set_state_dict(state_dict)
+    mlp4.set_state_dict(state_dict)
+    mlp5.set_state_dict(state_dict)
+    mlp6.set_state_dict(state_dict)
+    mlp7.set_state_dict(state_dict)
+
+    # DP VS stage2
+    dp_params = train_mlp(mlp1,
+                          sharding_stage="dp",
+                          use_pure_fp16=False,
+                          opt_group=False)
+    stage2_params = train_mlp(mlp2,
+                              sharding_stage=2,
+                              use_pure_fp16=False,
+                              opt_group=False)
+    for i in range(len(dp_params)):
+        np.testing.assert_allclose(dp_params[i].numpy(),
+                                   stage2_params[i].numpy(),
+                                   rtol=1e-6)
+
+    # stage2 accumulate grad
+    stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True)
+    stage2_accumulate_grad = train_mlp(mlp4,
+                                       sharding_stage=2,
+                                       batch_size=20,
+                                       accumulate_grad=True)
+    for i in range(len(stage2_params)):
+        np.testing.assert_allclose(stage2_params[i].numpy(),
+                                   stage2_accumulate_grad[i].numpy(),
+                                   rtol=1e-5,
+                                   atol=1e-5)
+
+    # stage2 param list VS param group
+    stage2_params = train_mlp(mlp5,
+                              sharding_stage=2,
+                              use_pure_fp16=False,
+                              opt_group=True)
+    for i in range(len(dp_params)):
+        np.testing.assert_allclose(dp_params[i].numpy(),
+                                   stage2_params[i].numpy(),
+                                   rtol=1e-6)
+
+    # save/load model
+    output_dir = tempfile.mkdtemp()
+    model_file = os.path.join(output_dir, "model.pdmodel")
+    optimizer_file = os.path.join(output_dir, "model.pdopt")
+    model_stage2, optimizer_stage2 = train_mlp(mlp6,
+                                               sharding_stage=2,
+                                               use_pure_fp16=False,
+                                               opt_group=False,
+                                               save_model=True)
+    paddle.save(model_stage2.state_dict(), model_file)
+    paddle.save(optimizer_stage2.state_dict(), optimizer_file)
+    m_state_dict = paddle.load(model_file)
+    opt_state_dict = paddle.load(optimizer_file)
+    model_stage2.set_state_dict(m_state_dict)
+    optimizer_stage2.set_state_dict(opt_state_dict)
+    shutil.rmtree(output_dir)
+
+    # check optimizer.minimize() error
+    train_mlp(mlp7, sharding_stage=2, test_minimize=True)
+    return
+
+
+if __name__ == '__main__':
+    with _test_eager_guard():
+        test_dp_stage2()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_sharding_stage2.py
index 9d842d8719fe3a095daa0f9a193abfd2c38236fd..aeeae15fe06534b1ff3f56eace7f6bcc4b20c281 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_sharding_stage2.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_sharding_stage2.py
@@ -33,6 +33,9 @@ class TestDygraphShardingStage2(TestMultipleGpus):
         self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py',
                             eager_mode=False)
 
+    def test_dygraph_sharding_stage2_with_comm_overlap(self):
+        self.run_mnist_2gpu('dygraph_group_sharded_stage2_comm_overlap.py')
+
 
 if __name__ == "__main__":
     os.environ["FLAGS_enable_eager_mode"] = "1"
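Usage sketch: a minimal outline of how the new hooks are meant to be driven, restating what train_mlp() in the test above already does; the `group`, `model`, and `optimizer` names and the two-rank NCCL group setup follow that test and are assumptions, not part of the patch.

    # Assumes: group = paddle.distributed.new_group([0, 1], backend="nccl")
    # and a base paddle.optimizer already built for the model.
    optimizer = GroupShardedOptimizerStage2(params=optimizer._parameter_list,
                                            optim=optimizer,
                                            group=group)
    model = GroupShardedStage2(model, optimizer, group=group)
    # Overlap the gradient reduce with the backward pass (single optimizer only).
    model._set_reduce_overlap(True)
    # Overlap the post-step parameter broadcasts with the next batch's forward pass.
    optimizer._set_broadcast_overlap(True, model)

    # ... training loop: forward, backward, optimizer.step(), optimizer.clear_grad() ...

    # Required before paddle.save() or inference once broadcast overlap is enabled.
    paddle.device.cuda.synchronize()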