From 589d13c5e1f29559bbd7744275e0eec19b6ad5e2 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Thu, 12 Aug 2021 11:45:41 +0800 Subject: [PATCH] [HybridParallel]Add Recompute for PipeLineParallel (#34607) * add recompute for pp * add recompute offload * add recompute partition --- paddle/fluid/pybind/op_function_generator.cc | 1 - python/paddle/distributed/collective.py | 11 +- .../parallel_layers/pp_layers.py | 53 +++- .../meta_parallel/parallel_layers/random.py | 12 + .../fleet/meta_parallel/pipeline_parallel.py | 7 +- .../fleet/meta_parallel/pp_utils/utils.py | 230 +++++++++++++++++- .../unittests/hybrid_parallel_pp_recompute.py | 172 +++++++++++++ ...test_parallel_dygraph_pipeline_parallel.py | 3 + 8 files changed, 476 insertions(+), 13 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_pp_recompute.py diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index d81783c677..07a3fc8a8d 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -136,7 +136,6 @@ std::map> op_passing_outs_map = { {"c_reduce_min", {"Out"}}, {"c_reduce_prod", {"Out"}}, {"c_reduce", {"Out"}}, - {"c_allgather", {"Out"}}, {"c_scatter", {"Out"}}, {"barrier", {"Out"}}, {"fake_quantize_dequantize_moving_average_abs_max", diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index f1dcf55a56..e5dfb34f24 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -632,14 +632,13 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): ring_id = 0 if group is None else group.id nranks = _get_global_group().nranks if group is None else group.nranks - op_type = 'c_allgather' - helper = LayerHelper(op_type, **locals()) - out = helper.create_variable_for_type_inference(dtype=tensor.dtype) - if in_dygraph_mode(): - _C_ops.c_allgather(tensor, out, 'use_calc_stream', use_calc_stream, - 'ring_id', ring_id, 'nranks', nranks) + out = _C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, + 'ring_id', ring_id, 'nranks', nranks) else: + op_type = 'c_allgather' + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=tensor.dtype) if not isinstance(tensor_list, list): raise ValueError("The type of 'tensor_list' for all_gather " "should be list.") diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index f546adc65e..5ea3659bed 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -23,6 +23,7 @@ from functools import partial import paddle from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str +from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting __all__ = [] @@ -134,7 +135,10 @@ class PipelineLayer(Layer): num_stages=None, topology=None, loss_fn=None, - seg_method="uniform"): + seg_method="uniform", + recompute_interval=0, + recompute_offload=False, + recompute_partition=False): super(PipelineLayer, self).__init__() if num_stages is None and topology is None: raise ValueError("should provide num_stages or topology") @@ -147,6 +151,16 @@ class PipelineLayer(Layer): self.layers = layers self._loss_fn = loss_fn self._topo = topology + self._recompute_interval = recompute_interval + self._recompute_offload = recompute_offload + self._recompute_partition = recompute_partition + + if recompute_interval > 0: + logger.info( + "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}". + format(recompute_offload, recompute_partition)) + _initialize_recompute_setting(recompute_offload, recompute_partition) + world_size = dist.get_world_size() self.global_rank = dist.get_rank() @@ -312,11 +326,44 @@ class PipelineLayer(Layer): else: self.run_function.append(layer) + def forward_function(self, start, end): + def execute_func(*x): + if len(x) == 1: + x = x[0] + for idx, layer in enumerate(self.run_function[start:end]): + x = layer(x) + return x + + return execute_func + def forward(self, input): - for layer in self.run_function: - input = layer(input) + if self._recompute_interval == 0: + input = self.forward_function(0, len(self.run_function))(input) + else: + num_layers = len(self.run_function) + for start_idx in range(0, num_layers, self._recompute_interval): + end_idx = min(start_idx + self._recompute_interval, num_layers) + funcs = self.run_function[start_idx:end_idx] + + if not isinstance(input, tuple): + input = (input, ) + + if self._need_recompute(funcs, input): + input = _hp_recompute( + self.forward_function(start_idx, end_idx), *input) + else: + input = self.forward_function(start_idx, end_idx)(*input) + return input + def _need_recompute(self, funcs, inputs): + if not any(input_.stop_gradient == False for input_ in inputs + if isinstance(input_, paddle.Tensor)): + return False + + params = [f.parameters() for f in funcs if isinstance(f, Layer)] + return any(len(list(p)) > 0 for p in params) + def save_state_dict(self, path): if self._topo.get_coord(self.global_rank).data != 0: return diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py index 70daa3b253..ec80ba7103 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py @@ -20,6 +20,9 @@ __all__ = [] MODEL_PARALLEL_RNG = 'model_parallel_rng' +# This file is inspired by Megatron to control random states for MP: +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py + class RNGStatesTracker: """ @@ -46,6 +49,15 @@ class RNGStatesTracker: self.states_[name] = paddle.get_cuda_rng_state() paddle.set_cuda_rng_state(orig_rng_state) + def get_states_tracker(self): + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states_tracker(self, states): + self.states_ = states + @contextlib.contextmanager def rng_state(self, name=MODEL_PARALLEL_RNG): if name not in self.states_: diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 16ea7de294..fc7b39ede2 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -14,7 +14,7 @@ import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import is_float_tensor +from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters @@ -48,6 +48,8 @@ class PipelineParallel(MetaParallelBase): p2p.initialize_p2p_groups(hcg) + _initialize_recompute_hcg(hcg) + self.is_first_stage = self.stage_id == 0 self.is_last_stage = (self.stage_id == (self.num_stages - 1)) self.global_rank = self._hcg.get_global_rank() @@ -213,6 +215,9 @@ class PipelineParallel(MetaParallelBase): if self.is_first_stage: assert len(inputs) == 2, "length of input should be 2" if isinstance(inputs[0], tuple): + assert len( + inputs[0] + ) > 1, "If you use tuple for input data, it should have at least two inputs." batch_size = inputs[0][0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size, ( "batch_size needs to be divisible by micro_batch_size. Currently, " diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 8c204820b1..728080a7cd 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc +import contextlib + import paddle -from ...utils import log_util as hp_util +from paddle.fluid import core +from paddle import _C_ops +import paddle.distributed as dist +from paddle.autograd import PyLayer +from paddle.fluid import framework +from paddle.distributed.fleet.utils.recompute import check_recompute_necessary, detach_variable +from ..parallel_layers.random import get_rng_state_tracker __all__ = [] @@ -79,3 +86,222 @@ def get_tensor_bytes(tensor): else: raise ValueError("unknown data type: {}".format(tensor.dtype)) return tensor.numel() * elem_size + + +_hcg = None +_recompute_offload = False +_recompute_partition = False + + +def _initialize_recompute_setting(is_offload, is_partition): + global _recompute_offload, _recompute_partition + + _recompute_offload = is_offload + _recompute_partition = is_partition + + +def _initialize_recompute_hcg(hcg): + global _hcg + _hcg = hcg + + +def _all_gather(tensor, group=None, use_calc_stream=True): + """ + The main difference with paddle.distributed.all_gather: + no need to pass in tensor_list, the returned tensor is spliced + """ + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + nranks = paddle.distributed.collective._get_global_group( + ).nranks if group is None else group.nranks + return _C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, + 'ring_id', ring_id, 'nranks', nranks) + + +def _split_activation(tensor): + global _hcg + + mp_degree = _hcg.get_model_parallel_world_size() + mp_rank = _hcg.get_model_parallel_rank() + if mp_degree < 2: + return tensor + + tensor_numel = paddle.numel(tensor) + assert tensor_numel != 0, "can't recompute zero element" + assert tensor_numel % mp_degree == 0, "The capacity of the activation () cannot be divisible by mp_degree()".format( + tensor_numel, mp_degree) + + # use inplace operation to save memory + data = tensor.flatten_() + part_size = tensor_numel // mp_degree + start = part_size * mp_rank + end = start + part_size + return data[start:end] + + +def _merge_activation(tensor): + global _hcg + mp_degree = _hcg.get_model_parallel_world_size() + mp_rank = _hcg.get_model_parallel_rank() + mp_group = _hcg.get_model_parallel_group() + if mp_degree < 2: + return tensor + return _all_gather(tensor, group=mp_group) + + +@contextlib.contextmanager +def _swith_rng_state_tracker(rng_state, tracker): + orig_cuda_rng_state = paddle.get_cuda_rng_state() + orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker() + + paddle.set_cuda_rng_state(rng_state) + get_rng_state_tracker().set_states_tracker(tracker) + try: + yield + finally: + paddle.set_cuda_rng_state(orig_cuda_rng_state) + get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker) + + +class _HPRecomputeFunction(PyLayer): + """ + Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: + 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. + 2. Offload support for activation + 3. Support MP segmentation of activation to further reduce cuda memory + 4. Adapt to the random state of MP + """ + + @staticmethod + def forward(ctx, run_function, all_outputs, *args): + check_recompute_necessary(args) + + # store for recomputing + ctx.run_function = run_function + + # store the rng states + ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( + ).get_states_tracker() + + # save input for backward + ctx.inputs = [] + ctx.tensor_indices = [] + ctx.tensor_shapes = [] + tensor_inputs = [] + + cur_device = paddle.get_device() + assert 'gpu:' in paddle.get_device( + ), "Recompute with RNG is not support current device: {}.".format( + cur_device) + + # TODO support AMP + tracer = framework._dygraph_tracer() + ctx.is_fw_autocast = tracer._enable_autocast + ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() + + with paddle.no_grad(): + outputs = run_function(*args) + + for i, arg in enumerate(args): + if paddle.is_tensor(arg): + state = arg.stop_gradient + if _recompute_partition: + ctx.tensor_shapes.append(arg.shape) + partition = _split_activation(arg.detach()).clone() + # TODO(shenliang03) not use calculate stream to D2H to speed + arg = partition.cpu() if _recompute_offload else partition + else: + arg = arg.cpu() if _recompute_offload else arg + arg.stop_gradient = state + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + if paddle.is_tensor(outputs): + all_outputs += [outputs] + return outputs + else: + all_outputs += outputs + return tuple(outputs) + + @staticmethod + def backward(ctx, *args): + with paddle.fluid.dygraph.guard(): + # Restore inputs + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensor_shapes = ctx.tensor_shapes + tensors = list(ctx.saved_tensor()) + + device_id = dist.ParallelEnv().device_id + for i, idx in enumerate(tensor_indices): + if _recompute_partition: + state = tensors[i].stop_gradient + tensors[i] = _merge_activation(tensors[i]).detach( + ).reshape_(tensor_shapes[i]) + tensors[i].stop_gradient = state + inputs[idx] = tensors[i].cuda( + device_id) if _recompute_offload else tensors[i] + + tracer = framework._dygraph_tracer() + tracer._has_grad = True + + # need restore auto_cast state as well as w/b list + with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state, + ctx.fwd_cuda_rng_state_tracker): + with paddle.amp.auto_cast( + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, core.VarBase): + outputs = (outputs, ) + assert len(outputs) == len(args) + + forward_outputs_with_grad = [] + backward_inputs = [] + + for i in range(len(outputs)): + if isinstance(outputs[i], + core.VarBase) and not outputs[i].stop_gradient: + forward_outputs_with_grad.append(outputs[i]) + backward_inputs.append(args[i]) + + if len(forward_outputs_with_grad) == 0: + raise RuntimeError( + "none of output has stop_gradient=False, this recompute() is not necessary" + ) + + # actually backward + paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) + grads = list(inp._grad_ivar() for inp in detached_inputs + if isinstance(inp, core.VarBase)) + return grads + + +def _hp_recompute(function, *args): + # NODTE(shenliang03)The current hybrid parallel recompute has limitations. + # It cannot handle the following situations: + # 1. The calculation output of recompute, there are tensors that do not require gradients. + # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach(). + # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor + + all_outputs = [] + _HPRecomputeFunction.apply(function, all_outputs, *args) + + if len(all_outputs) == 1: + return all_outputs[0] + else: + for output in all_outputs: + if paddle.is_tensor(output) and not is_float_tensor(output): + output.stop_gradient = True + + return tuple(all_outputs) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_recompute.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_recompute.py new file mode 100644 index 0000000000..ebcac70a3b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_recompute.py @@ -0,0 +1,172 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from paddle.fluid import layers +import paddle.nn.functional as F +from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc +from paddle.fluid.dygraph.layers import Layer +import paddle.nn as nn + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 8 +length = 8 +micro_batch_size = 2 +vocab_size = 128 +hidden_size = 16 +d_model = hidden_size +dim_feedforward = 4 * d_model + + +class EmbeddingNet(Layer): + def __init__(self): + super(EmbeddingNet, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(vocab_size, hidden_size) + + def forward(self, x): + w_emb = self.word_embeddings(x) + p_emb = self.position_embeddings(x) + w_emb = w_emb + p_emb + return w_emb + + +class TransformerNet(Layer): + def __init__(self): + super(TransformerNet, self).__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.q_proj = nn.Linear(d_model, d_model) + self.k_proj = nn.Linear(d_model, d_model) + self.v_proj = nn.Linear(d_model, d_model) + + self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, x): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5) + weights = F.softmax(product) + + weights = F.dropout(weights, 0.2) + tgt = layers.matmul(weights, v) + residual = tgt + tgt = self.norm1(tgt) + tgt = residual + tgt + + out = self.linear2(F.gelu(self.linear1(tgt), approximate=True)) + return out + + +class EmbeddingPipe(EmbeddingNet): + def forward(self, x): + return super().forward(x) + + +class TransformerNetPipe(TransformerNet): + def forward(self, x): + output = super().forward(x) + return output + + +class CriterionPipe(Layer): + def __init__(self): + super(CriterionPipe, self).__init__() + + def forward(self, out, label): + loss = out.mean() + return loss + + +class ModelPipe(PipelineLayer): + def __init__(self, topology): + self.descs = [] + self.descs.append(LayerDesc(EmbeddingPipe)) + + for x in range(2): + self.descs.append(LayerDesc(TransformerNetPipe)) + + super().__init__( + layers=self.descs, + loss_fn=CriterionPipe(), + topology=topology, + seg_method="layer:TransformerNetPipe", + recompute_interval=1, + recompute_partition=False, + recompute_offload=False) + + +class TestDistPPTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + topology = hcg.topology() + set_random_seed(1024, dp_id, rank_id) + + model = ModelPipe(topology) + scheduler = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer = paddle.optimizer.SGD(learning_rate=scheduler, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + for step_id in range(5): + x_data = np.random.randint(0, vocab_size, size=[batch_size, length]) + x = paddle.to_tensor(x_data) + x.stop_gradient = True + loss = model.train_batch([x, x], optimizer, scheduler) + # TODO(shenliang03) add utest for loss + print("loss: ", loss) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py index 003e0c1685..35fd49dfff 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py @@ -39,6 +39,9 @@ class TestHybridPipeParallel(TestMultipleGpus): def test_hybrid_parallel_transformer(self): self.run_mnist_2gpu('hybrid_parallel_pp_save_load.py') + def test_hybrid_parallel_transformer(self): + self.run_mnist_2gpu('hybrid_parallel_pp_recompute.py') + if __name__ == "__main__": unittest.main() -- GitLab