From 491e4df317a766ecdad1c63e66653fa0d2b4cf31 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com> Date: Mon, 19 Sep 2022 14:08:49 +0800 Subject: [PATCH] Recompute unify incubate (#46073) --- .../parallel_layers/pp_layers.py | 8 +- .../fleet/meta_parallel/pipeline_parallel.py | 3 - .../fleet/meta_parallel/pp_utils/__init__.py | 2 - .../pp_utils/p2p_communication.py | 2 +- .../fleet/meta_parallel/pp_utils/utils.py | 208 --------------- python/paddle/distributed/fleet/model.py | 42 --- .../distributed/fleet/recompute/__init__.py | 18 ++ .../fleet/{utils => recompute}/recompute.py | 68 ++++- .../fleet/recompute/recompute_hybrid.py | 250 ++++++++++++++++++ .../distributed/fleet/utils/__init__.py | 12 +- .../fleet/test_dygraph_recompute.py | 94 +++++-- .../fleet/test_dygraph_recompute_for_eager.py | 88 ++++-- .../unittests/dygraph_recompute_hybrid.py | 221 ++++++++++++++++ .../tests/unittests/test_pipeline_parallel.py | 6 + .../incubate/distributed/fleet/__init__.py | 17 ++ .../distributed/models/moe/moe_layer.py | 6 +- .../optimizer/distributed_fused_lamb.py | 6 +- python/setup.py.in | 2 + 18 files changed, 734 insertions(+), 319 deletions(-) create mode 100644 python/paddle/distributed/fleet/recompute/__init__.py rename python/paddle/distributed/fleet/{utils => recompute}/recompute.py (88%) create mode 100644 python/paddle/distributed/fleet/recompute/recompute_hybrid.py create mode 100755 python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py create mode 100644 python/paddle/incubate/distributed/fleet/__init__.py diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index f6878ec1d86..e3c92ee1db7 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -49,8 +49,9 @@ from functools import partial import paddle from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str -from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting +from paddle.distributed import fleet from paddle.fluid.framework import in_dygraph_mode +from paddle.incubate.distributed.fleet import recompute_hybrid __all__ = [] @@ -309,13 +310,13 @@ class PipelineLayer(Layer): self._loss_fn = loss_fn self._topo = topology self._recompute_interval = recompute_interval + self.recompute_ctx = recompute_ctx if recompute_interval > 0: assert recompute_ctx is not None, "recompute_ctx must be not None for recompute." offload = recompute_ctx.get('offload', False) partition = recompute_ctx.get('partition', False) - _initialize_recompute_setting(offload, partition) logger.info( "Start Recompute for PipeLineParallel. 
recompute_offload: {}, recompute_partition: {}" .format(offload, partition)) @@ -638,7 +639,8 @@ class PipelineLayer(Layer): input = (input, ) if self._need_recompute(funcs, input): - input = _hp_recompute( + input = recompute_hybrid( + self.recompute_ctx, self.forward_function(start_idx, end_idx), *input) else: input = self.forward_function(start_idx, end_idx)(*input) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 876f9ffaed3..537885bfad3 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -14,7 +14,6 @@ import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters @@ -61,8 +60,6 @@ class PipelineParallel(MetaParallelBase): p2p.initialize_p2p_groups(hcg, self._using_cache) - _initialize_recompute_hcg(hcg) - self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py index 786eb20487a..04575bfb231 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py @@ -12,6 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .utils import get_tensor_bytes - __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index ce5c1cfe9eb..3b4094f0475 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -13,12 +13,12 @@ # limitations under the License. import paddle -from .utils import paddle_2_number, number_2_dtype from ...utils.log_util import logger import numpy as np from paddle import _C_ops, _legacy_C_ops import paddle.fluid.core as core from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode +from .utils import paddle_2_number, paddle_2_number, number_2_dtype _hcg = None _use_cache = False diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 8ec7f0f037b..683cc51d279 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -12,16 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib - import paddle from paddle.fluid import core from paddle import _C_ops, _legacy_C_ops -from paddle.autograd import PyLayer -from paddle.fluid import framework -from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker -from ..parallel_layers.random import get_rng_state_tracker -from paddle.fluid.framework import in_dygraph_mode __all__ = [] @@ -88,23 +81,6 @@ def get_tensor_bytes(tensor): return tensor.numel() * elem_size -_hcg = None -_recompute_offload = False -_recompute_partition = False - - -def _initialize_recompute_setting(is_offload, is_partition): - global _recompute_offload, _recompute_partition - - _recompute_offload = is_offload - _recompute_partition = is_partition - - -def _initialize_recompute_hcg(hcg): - global _hcg - _hcg = hcg - - def _all_gather(tensor, group=None, use_calc_stream=True): """ The main difference with paddle.distributed.all_gather: @@ -117,187 +93,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True): ).nranks if group is None else group.nranks return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'nranks', nranks) - - -def _split_activation(tensor): - global _hcg - - mp_degree = _hcg.get_model_parallel_world_size() - mp_rank = _hcg.get_model_parallel_rank() - if mp_degree < 2: - return tensor - - tensor_numel = paddle.numel(tensor) - assert tensor_numel != 0, "can't recompute zero element" - assert tensor_numel % mp_degree == 0, "The capacity of the activation () cannot be divisible by mp_degree()".format( - tensor_numel, mp_degree) - - # use inplace operation to save memory - data = tensor.flatten_() - - part_size = tensor_numel // mp_degree - start = part_size * mp_rank - end = start + part_size - return data[start:end] - - -def _merge_activation(tensor): - global _hcg - mp_degree = _hcg.get_model_parallel_world_size() - mp_rank = _hcg.get_model_parallel_rank() - mp_group = _hcg.get_model_parallel_group() - if mp_degree < 2: - return tensor - return _all_gather(tensor, group=mp_group) - - -class _HPRecomputeFunction(PyLayer): - """ - Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: - 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. - 2. Offload support for activation - 3. Support MP segmentation of activation to further reduce cuda memory - 4. 
Adapt to the random state of MP - """ - - @staticmethod - def forward(ctx, run_function, all_outputs, *args): - check_recompute_necessary(args) - - # store for recomputing - ctx.run_function = run_function - - # store the rng states - ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( - ).get_states_tracker() - - # save input for backward - ctx.inputs = [] - ctx.tensor_indices = [] - ctx.tensor_shapes = [] - tensor_inputs = [] - - cur_device = paddle.get_device() - assert 'gpu:' in paddle.get_device( - ), "Recompute with RNG is not support current device: {}.".format( - cur_device) - - # TODO support AMP - tracer = framework._dygraph_tracer() - ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True - if tracer._amp_level == core.AmpLevel.O2: - ctx.amp_level = 'O2' - elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): - ctx.amp_level = 'O1' - else: - raise ValueError("unsupported amp level: {}".format( - tracer._amp_level)) - ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() - - with paddle.no_grad(): - outputs = run_function(*args) - - for i, arg in enumerate(args): - if paddle.is_tensor(arg): - state = arg.stop_gradient - if _recompute_partition: - ctx.tensor_shapes.append(arg.shape) - partition = _split_activation(arg.detach()).clone() - # TODO(shenliang03) not use calculate stream to D2H to speed - arg = partition.cpu() if _recompute_offload else partition - else: - arg = arg.cpu() if _recompute_offload else arg - arg.stop_gradient = state - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - - ctx.save_for_backward(*tensor_inputs) - - if paddle.is_tensor(outputs): - all_outputs += [outputs] - return outputs - else: - all_outputs += outputs - return tuple(outputs) - - @staticmethod - def backward(ctx, *args): - with paddle.fluid.dygraph.guard(): - # Restore inputs - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensor_shapes = ctx.tensor_shapes - tensors = list(ctx.saved_tensor()) - - device_id = paddle.distributed.ParallelEnv().device_id - for i, idx in enumerate(tensor_indices): - if _recompute_partition: - state = tensors[i].stop_gradient - tensors[i] = _merge_activation( - tensors[i]).detach().reshape_(tensor_shapes[i]) - tensors[i].stop_gradient = state - inputs[idx] = tensors[i].cuda( - device_id) if _recompute_offload else tensors[i] - - tracer = framework._dygraph_tracer() - tracer._has_grad = True - - # need restore auto_cast state as well as w/b list - with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, - ctx.fwd_cuda_rng_state_tracker): - with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, - custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list, - level=ctx.amp_level): - detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) - - if isinstance(outputs, (core.VarBase, core.eager.Tensor)): - outputs = (outputs, ) - assert len(outputs) == len(args) - - forward_outputs_with_grad = [] - backward_inputs = [] - - for i in range(len(outputs)): - if isinstance( - outputs[i], - (core.VarBase, - core.eager.Tensor)) and not outputs[i].stop_gradient: - forward_outputs_with_grad.append(outputs[i]) - backward_inputs.append(args[i]) - - if len(forward_outputs_with_grad) == 0: - raise RuntimeError( - "none of output has stop_gradient=False, this recompute() is not necessary" - ) - - # actually backward - 
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) - grads = tuple(inp._grad_ivar() for inp in detached_inputs - if isinstance(inp, (core.VarBase, core.eager.Tensor))) - return grads - - -def _hp_recompute(function, *args): - # NODTE(shenliang03)The current hybrid parallel recompute has limitations. - # It cannot handle the following situations: - # 1. The calculation output of recompute, there are tensors that do not require gradients. - # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach(). - # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor - - all_outputs = [] - _HPRecomputeFunction.apply(function, all_outputs, *args) - - if len(all_outputs) == 1: - return all_outputs[0] - else: - for output in all_outputs: - if paddle.is_tensor(output) and not is_float_tensor(output): - output.stop_gradient = True - - return tuple(all_outputs) diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index fea2614fe84..40633788f12 100644 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -20,46 +20,9 @@ from .base.topology import ParallelMode from .meta_parallel import TensorParallel, model_parallel_random_seed from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer from paddle.fluid import core -from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar from paddle.distributed import fleet - -class _RecomputeModelWrapper(paddle.nn.Layer): - - def __init__(self, model, segments=2, preserve_rng_state=True): - super(_RecomputeModelWrapper, self).__init__() - assert isinstance(model, paddle.nn.Sequential), ( - "The model passed to RecomputeModelWrapper must be of type " - "paddle.nn.Sequential.") - self._model = model - self._segments = segments - self._preserve_rng_state = preserve_rng_state - self._layers = list(model.children()) - self._segment_size = len(self._layers) // segments - - def _run_func(self, begin, end): - - def do_run(input): - for i in range(begin, end): - input = self._layers[i](input) - return input - - return do_run - - def _checkpoint(self, func, *args, **kwargs): - return LegacyRecomputeFunction.apply(func, self._preserve_rng_state, - *args) - - def forward(self, input): - end = 0 - for begin in range(0, self._segment_size * (self._segments - 1), - self._segment_size): - end = begin + self._segment_size - input = self._checkpoint(self._run_func(begin, end), input) - return self._run_func(end, len(self._layers))(input) - - _grad_scalar = None @@ -125,7 +88,6 @@ def distributed_model(model): return model amp_enable = False - recompute_enable = False strategy = fleet_env._user_defined_strategy if strategy.amp == True: amp_enable = True @@ -154,10 +116,6 @@ def distributed_model(model): decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, use_dynamic_loss_scaling=use_dynamic_loss_scaling) - if strategy.recompute == True: - recompute_enable = True - model = _RecomputeModelWrapper(model) - if strategy.heter_ccl_mode == True: distributed_model = paddle.DataParallel( model, diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py new file mode 100644 index 00000000000..7e5bcdb1db2 --- /dev/null +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .recompute import recompute, recompute_sequential +from .recompute_hybrid import recompute_hybrid + +__all__ = [] diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py similarity index 88% rename from python/paddle/distributed/fleet/utils/recompute.py rename to python/paddle/distributed/fleet/recompute/recompute.py index 2dddb1d9fb4..28ded25a0e6 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -207,12 +207,13 @@ class LegacyRecomputeFunction(LegacyPyLayer): class RecomputeFunction(PyLayer): @staticmethod - def forward(ctx, run_function, preserve_rng_state, *args): + def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker # store for recomputing ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state + ctx.kwargs = kwargs # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input # the order of tensors in backward()'s output should be the same as tensors in forward()'s input @@ -265,7 +266,7 @@ class RecomputeFunction(PyLayer): ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): - outputs = run_function(*args) + outputs = run_function(*args, **kwargs) return outputs @staticmethod @@ -297,7 +298,8 @@ class RecomputeFunction(PyLayer): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, + **ctx.kwargs) else: with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, @@ -305,7 +307,7 @@ class RecomputeFunction(PyLayer): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): outputs = (outputs, ) @@ -352,7 +354,7 @@ def recompute(function, *args, **kwargs): recompute intermediate activations to save then memory. Parameters: - function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model + function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. *args(Tensor): inputs to the function. 
@@ -466,11 +468,59 @@ def recompute(function, *args, **kwargs): """ # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop('preserve_rng_state', True) - if kwargs: - raise ValueError("Unexpected keyword arguments: " + - ",".join(arg for arg in kwargs)) if framework._dygraph_tracer()._has_grad: check_recompute_necessary(args) - return RecomputeFunction.apply(function, preserve, *args) + return RecomputeFunction.apply(function, preserve, *args, **kwargs) + + +def recompute_sequential(ctx, functions, *args, **kwargs): + """ + recompute intermediate activations to save then memory for 'Sequential' models. + + Parameters: + ctx(dict): include 'segments' and 'preserve_rng_state' keys, the key 'segments' (int, default 1), represents the number of chunks to create in the model, + the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng. If it is True, then the last forward rng value will be + restored when the forward recalculation of backpropagation is performed. and some keys such as 'mp_group', 'offload' and 'partition' are invalid here, + they are useful in 'recompute_hybrid' API. + functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model + whose intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + *args(Tensor): inputs(tuple) to the function. + **kwargs(Dict): inputs(dict) to the function. + + Returns: + Output of function on args and kwargs. + + Examples: + .. code-block:: python + + model = paddle.nn.Sequential(...) + input = recompute_sequential({'segments' : 1}, model, input) + """ + segments = ctx.get('segments', 1) + preserve_rng_state = ctx.get('preserve_rng_state', True) + + def _run_func(begin, end, funcs): + + def do_run(input): + for i in range(begin, end + 1): + input = funcs[i](input) + return input + + return do_run + + if isinstance(functions, paddle.nn.Sequential): + functions = list(functions.children()) + + segment_size = len(functions) // segments + + end = -1 + for begin in range(0, segment_size * (segments - 1), segment_size): + end = begin + segment_size - 1 + args = recompute(_run_func(begin, end, functions), + *args, + preserve_rng_state=preserve_rng_state, + **kwargs) + return _run_func(end + 1, len(functions) - 1, functions)(args) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py new file mode 100644 index 00000000000..4883cad2511 --- /dev/null +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -0,0 +1,250 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import contextlib + +import paddle +from paddle import _C_ops, _legacy_C_ops +from paddle.fluid import core +from paddle.autograd import PyLayer +from paddle.fluid import framework +from ..meta_parallel.parallel_layers.random import get_rng_state_tracker +from paddle.fluid.framework import in_dygraph_mode +from paddle.distributed import fleet +from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker +from ..meta_parallel.pp_utils import utils + +__all__ = [] + + +def _split_activation(tensor, mp_group): + + mp_degree = mp_group.nranks + mp_rank = mp_group.rank + if mp_degree < 2: + return tensor + + tensor_numel = paddle.numel(tensor) + assert tensor_numel != 0, "can't recompute zero element" + assert tensor_numel % mp_degree == 0, "The capacity of the activation ({}) cannot be divisible by mp_degree({})".format( + tensor_numel, mp_degree) + + # use inplace operation to save memory + data = tensor.flatten_() + + part_size = tensor_numel // mp_degree + start = part_size * mp_rank + end = start + part_size + return data[start:end] + + +def _merge_activation(tensor, mp_group): + mp_degree = mp_group.nranks + mp_rank = mp_group.rank + if mp_degree < 2: + return tensor + + # adapt to new dygraph + tensor_shape = list(tensor.shape) + tensor_shape[0] *= mp_group.nranks + out = paddle.empty(tensor_shape, tensor.dtype) + task = mp_group.process_group.all_gather(tensor.cuda(), out) + task.wait() + return out + + +class _HPRecomputeFunction(PyLayer): + """ + Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: + 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. + 2. Offload support for activation + 3. Support MP segmentation of activation to further reduce cuda memory + 4. 
Adapt to the random state of MP + """ + + @staticmethod + def forward(ctx, run_function, all_outputs, mp_group, offload, partition, + *args, **kwargs): + check_recompute_necessary(args) + + # store for recomputing + ctx.run_function = run_function + + ctx.kwargs = kwargs + + # store the rng states + ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( + ).get_states_tracker() + + # save config info + ctx.mp_group = mp_group + ctx.offload = offload + ctx.partition = partition + + # save input for backward + ctx.inputs = [] + ctx.tensor_indices = [] + ctx.tensor_shapes = [] + tensor_inputs = [] + + cur_device = paddle.get_device() + assert 'gpu:' in paddle.get_device( + ), "Recompute with RNG is not support current device: {}.".format( + cur_device) + + # TODO support AMP + tracer = framework._dygraph_tracer() + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' + else: + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) + ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() + + with paddle.no_grad(): + outputs = run_function(*args, **kwargs) + + for i, arg in enumerate(args): + if paddle.is_tensor(arg): + state = arg.stop_gradient + if partition: + ctx.tensor_shapes.append(arg.shape) + partition = _split_activation(arg.detach(), + mp_group).clone() + # TODO(shenliang03) not use calculate stream to D2H to speed + arg = partition.cpu() if offload else partition + else: + arg = arg.cpu() if offload else arg + arg.stop_gradient = state + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + if paddle.is_tensor(outputs): + all_outputs += [outputs] + return outputs + else: + all_outputs += outputs + return tuple(outputs) + + @staticmethod + def backward(ctx, *args): + with paddle.fluid.dygraph.guard(): + # Restore inputs + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensor_shapes = ctx.tensor_shapes + tensors = list(ctx.saved_tensor()) + + device_id = paddle.distributed.ParallelEnv().device_id + for i, idx in enumerate(tensor_indices): + if ctx.partition: + state = tensors[i].stop_gradient + tensors[i] = _merge_activation( + tensors[i], + ctx.mp_group).detach().reshape_(tensor_shapes[i]) + tensors[i].stop_gradient = state + inputs[idx] = tensors[i].cuda( + device_id) if ctx.offload else tensors[i] + + tracer = framework._dygraph_tracer() + tracer._has_grad = True + + # need restore auto_cast state as well as w/b list + with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, + ctx.fwd_cuda_rng_state_tracker): + with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) + + if isinstance(outputs, (core.VarBase, core.eager.Tensor)): + outputs = (outputs, ) + assert len(outputs) == len(args) + + forward_outputs_with_grad = [] + backward_inputs = [] + + for i in range(len(outputs)): + if isinstance( + outputs[i], + (core.VarBase, + core.eager.Tensor)) and not outputs[i].stop_gradient: + forward_outputs_with_grad.append(outputs[i]) + backward_inputs.append(args[i]) + + if 
len(forward_outputs_with_grad) == 0: + raise RuntimeError( + "none of output has stop_gradient=False, this recompute() is not necessary" + ) + + # actually backward + paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) + grads = tuple(inp._grad_ivar() for inp in detached_inputs + if isinstance(inp, (core.VarBase, core.eager.Tensor))) + return grads + + +def recompute_hybrid(ctx, function, *args, **kwargs): + """ + # NODTE(shenliang03)The current hybrid parallel recompute has limitations. + # It cannot handle the following situations: + # 1. The calculation output of recompute, there are tensors that do not require gradients. + # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach(). + # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor + + Parameters: + ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group), represents the avtivations are splitted + in which group. the key 'offload' (bool, optional, default=False), represents whether to offload to cpu. the key 'partition' (bool, optional, default=False), + represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in + 'recompute_sequential' API. + function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model + whose intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + *args(Tensor): inputs(tuple) to the function. + + **kwargs(Dict): inputs(dict) to the function. + + Returns: + Output of function on args and kwargs. + + """ + mp_group = ctx.get('mp_group', None) + assert mp_group is not None, "ctx must contains mp_group and mp_group can not be None." + + offload = ctx.get('offload', False) + partition = ctx.get('partition', False) + + all_outputs = [] + _HPRecomputeFunction.apply(function, all_outputs, mp_group, offload, + partition, *args, **kwargs) + + if len(all_outputs) == 1: + return all_outputs[0] + else: + for output in all_outputs: + if paddle.is_tensor(output) and not utils.is_float_tensor(output): + output.stop_gradient = True + + return tuple(all_outputs) diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 1bf90a22e37..93fc890d05a 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -15,11 +15,21 @@ from .fs import LocalFS # noqa: F401 from .fs import HDFSClient # noqa: F401 from .ps_util import DistributedInfer # noqa: F401 -from .recompute import recompute # noqa: F401 +import paddle.utils.deprecated as deprecated +from paddle.distributed import fleet +import paddle from . import log_util # noqa: F401 from . 
import hybrid_parallel_util # noqa: F401 __all__ = [ #noqa "LocalFS", "recompute", "DistributedInfer", "HDFSClient" ] + + +@deprecated(since="2.4.0", + update_to="paddle.distributed.fleet.recompute", + level=1, + reason="Please use new recompute API(fleet.recompute) ") +def recompute(function, *args, **kwargs): + return fleet.recompute.recompute(function, *args, **kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py index 11ca15fd331..f5f59cf1027 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py @@ -21,6 +21,7 @@ import paddle from paddle.autograd import PyLayer from paddle.distributed.fleet.utils import recompute import random +from paddle.incubate.distributed.fleet import recompute_sequential import paddle.fluid.layers as layers @@ -53,48 +54,66 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet_sq=False, + segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - def forward(self, inputs): + if self.use_fleet_sq and not use_raw_recompute: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) - if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) - else: - inputs = self.runfunc0(inputs) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] - if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] - if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) + def forward(self, inputs): - if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) + if self.use_fleet_sq and not self.use_raw_recompute: + return recompute_sequential({"segments": self.segments}, + self.runfuncs, inputs) - if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + if self.use_raw_recompute: + inputs = recompute(self.layers[0], inputs) + return self.layers[1](inputs) + + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet_sq=False, + use_raw_recompute=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -105,6 +124,9 @@ def run_model(recompute_block=[], batch_size, 
input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -179,6 +201,34 @@ class TestPyLayer(unittest.TestCase): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute second & fourth block using fleet + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute using recompute_sequential, segments=1 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() @@ -191,7 +241,7 @@ class TestPyLayer(unittest.TestCase): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py index bc97d53485b..4b0c73370d3 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py @@ -25,6 +25,7 @@ import paddle from paddle.autograd import PyLayer from paddle.distributed.fleet.utils import recompute import random +from paddle.incubate.distributed.fleet import recompute_sequential import paddle.fluid.layers as layers @@ -57,48 +58,66 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet_sq=False, + segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - def forward(self, inputs): + if self.use_fleet_sq and not use_raw_recompute: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) - if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) - else: - inputs = self.runfunc0(inputs) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + 
self.runfunc4 + ] - if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] - if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) + def forward(self, inputs): - if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) + if self.use_fleet_sq and not self.use_raw_recompute: + return paddle.incubate.distributed.fleet.recompute_sequential( + {"segments": self.segments}, self.runfuncs, inputs) - if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + if self.use_raw_recompute: + inputs = recompute(self.layers[0], inputs) + return self.layers[1](inputs) + + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet_sq=False, + use_raw_recompute=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -109,6 +128,9 @@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -183,6 +205,28 @@ class TestPyLayer(unittest.TestCase): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute_sequential with segments=1 using fleet + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using paddle.incubate.distributed.fleet.recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() @@ -201,7 +245,7 @@ class TestPyLayer(unittest.TestCase): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py new file mode 100755 index 00000000000..cc90f0433b9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py @@ -0,0 +1,221 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.autograd import PyLayer +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.distributed.fleet import recompute_hybrid +import random +from paddle.distributed import fleet + +import paddle.fluid.layers as layers + + +def get_fc_block(block_idx, input_size, is_last=False): + block_name = "block_" + str(block_idx) + block = paddle.nn.Sequential( + (block_name + "_fc_0", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + (block_name + "_relu_1", paddle.nn.ReLU()), + (block_name + "_fc_1", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_relu_2", paddle.nn.ReLU()), + ) + if is_last: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, 1, + bias_attr=False)) # add sublayer + else: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, + input_size, + bias_attr=False)) # add sublayer + return block + + +class Naive_fc_net(paddle.nn.Layer): + + def __init__(self, + input_size=10, + recompute_blocks=[1, 3], + offload=False, + partition=False, + recompute_kwargs={}): + super(Naive_fc_net, self).__init__() + self.recompute_blocks = recompute_blocks + self.recompute_kwargs = recompute_kwargs + self.offload = offload + self.partition = partition + + self.runfunc0 = get_fc_block(0, input_size, is_last=False) + self.runfunc1 = get_fc_block(1, input_size, is_last=False) + self.runfunc2 = get_fc_block(2, input_size, is_last=False) + self.runfunc3 = get_fc_block(3, input_size, is_last=False) + self.runfunc4 = get_fc_block(4, input_size, is_last=True) + + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] + + def forward(self, inputs): + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute_hybrid( + { + "mp_group": fleet.fleet._hcg.get_model_parallel_group(), + "offload": self.offload, + "partition": self.partition + }, self.layers[i], inputs, **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) + + return inputs + + +def run_model(recompute_block=[], + recompute_kwargs={}, + offload=False, + partition=False, + enable_autocast=False, + pure_fp16=False): + gen = paddle.seed(10) + gen.manual_seed(10) + np.random.seed(10) + random.seed(10) + + batch_size, input_size = 1, 10 + model = Naive_fc_net(input_size, + recompute_blocks=recompute_block, + offload=offload, + partition=partition, + recompute_kwargs=recompute_kwargs) + loss_fn = paddle.nn.MSELoss(reduction='mean') + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + if enable_autocast: + scaler = paddle.amp.GradScaler() + scaler = fleet.distributed_scaler(scaler) + + loss_ = [] + param_ = [] + grad_ = [] + for step in range(10): + + x_data = np.random.randn(batch_size, input_size).astype(np.float32) + x = paddle.to_tensor(x_data) + # x.stop_gradient = 
False + level = 'O2' if pure_fp16 else 'O1' + with paddle.amp.auto_cast(True, level=level): + y_pred = model(x) + loss = y_pred.mean() + if enable_autocast: + scaler.scale(loss).backward() + scaler.minimize(optimizer, loss) + else: + loss_.append(np.asarray(loss).tolist()) + loss.backward() + optimizer.step() + + param_.append(np.asarray(model.parameters()[9]).tolist()) + grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) + + optimizer.clear_grad() + return loss_, param_, grad_ + + +class TestPyLayer(unittest.TestCase): + + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 1 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + fleet.init(is_collective=True, strategy=strategy) + + def test_base_case(self, enable_autocast=False, pure_fp16=False): + + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + + # without recompute + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + + # with recompute, offload=False, partition=False + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=False + loss, param, grad = run_model(recompute_block=[1, 2, 3], + offload=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=False, partition=True + loss, param, grad = run_model(recompute_block=[1], + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=True + loss, param, grad = run_model(recompute_block=[1, 3, 4], + offload=True, + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + def test_fc_net_with_dropout(self): + self.test_base_case() + + def test_fc_net_with_amp(self): + self.test_base_case(enable_autocast=True) + + def test_fc_net_with_fp16(self): + self.test_base_case(enable_autocast=True, pure_fp16=True) + + def test_recompute_kwargs(self): + paddle.set_device("gpu") + kwargs = {"is_test": False} + with self.assertRaises(TypeError): + loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], + recompute_kwargs=kwargs) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py index 8773e8d47ed..10243a0faa9 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py @@ -26,5 +26,11 @@ class TestPipelineParallel(TestMultipleGpus): self.run_mnist_2gpu('hybrid_parallel_pp_alexnet.py') +class TestModelParallelWithRecompute(TestMultipleGpus): + + def test_model_parallel_with_recompute(self): + self.run_mnist_2gpu("dygraph_recompute_hybrid.py") + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/distributed/fleet/__init__.py 
b/python/paddle/incubate/distributed/fleet/__init__.py new file mode 100644 index 00000000000..94e1a7c8bbe --- /dev/null +++ b/python/paddle/incubate/distributed/fleet/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.fleet.recompute import recompute_sequential, recompute_hybrid + +__all__ = ["recompute_sequential", "recompute_hybrid"] diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index f25b00cb4be..65a7eec9973 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -34,9 +34,9 @@ from paddle.distributed import fleet from paddle.autograd import PyLayer from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .utils import count_by_gate -from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute from paddle import fluid from paddle.fluid.framework import in_dygraph_mode +from paddle.incubate.distributed.fleet import recompute_hybrid def _local_scatter(inp, pos): @@ -424,8 +424,8 @@ class MoELayer(nn.Layer): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = _hp_recompute(experts_fwd, x, fwd_expert_count.numpy(), - self.experts) + x = recompute_hybrid(self.recompute_ctx, experts_fwd, x, + fwd_expert_count.numpy(), self.experts) out_batch_size = inp.shape[0] if len(gate.shape) == 2: diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py index d230b6afca2..f8e3b55aba6 100644 --- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py +++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py @@ -19,7 +19,7 @@ from paddle.fluid.clip import ClipGradByGlobalNorm from paddle.fluid.initializer import Constant from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.optimizer import Optimizer -from paddle.distributed import get_rank, get_world_size +import paddle.distributed as dist from paddle.distributed.collective import new_group from paddle.fluid.executor import global_scope from paddle.fluid.framework import name_scope @@ -288,8 +288,8 @@ class DistributedFusedLamb(Optimizer): step = self._get_or_create_step() - rank = get_rank() - nranks = get_world_size() + rank = dist.get_rank() + nranks = dist.get_world_size() if self._nproc_per_node is None: nproc_per_node = nranks else: diff --git a/python/setup.py.in b/python/setup.py.in index c51a3765da1..084fd5ac042 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -296,6 +296,7 @@ packages=['paddle', 'paddle.distributed.launch.plugins', 'paddle.distributed.launch.utils', 'paddle.distributed.fleet.base', + 'paddle.distributed.fleet.recompute', 'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.meta_optimizers', 
'paddle.distributed.fleet.meta_optimizers.sharding', @@ -380,6 +381,7 @@ packages=['paddle', 'paddle.incubate.optimizer.functional', 'paddle.incubate.autograd', 'paddle.incubate.distributed', + 'paddle.incubate.distributed.fleet', 'paddle.incubate.distributed.models', 'paddle.incubate.distributed.models.moe', 'paddle.incubate.distributed.models.moe.gate', -- GitLab
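
Usage sketch for the new paddle.incubate.distributed.fleet.recompute_sequential API exposed by this patch. This is a minimal single-card illustration, not part of the patch itself: the toy Sequential model, tensor shapes, and segment count are hypothetical, and a GPU device is assumed because the RNG-state handling in recompute is CUDA-only in this version.

    import paddle
    from paddle.incubate.distributed.fleet import recompute_sequential

    # assumption: recompute with preserved RNG state requires a CUDA device here
    paddle.set_device("gpu")

    # toy model for illustration only
    model = paddle.nn.Sequential(
        paddle.nn.Linear(10, 10),
        paddle.nn.ReLU(),
        paddle.nn.Linear(10, 10),
        paddle.nn.ReLU(),
        paddle.nn.Linear(10, 1),
    )

    x = paddle.randn([4, 10])
    x.stop_gradient = False

    # run the Sequential in 2 segments; checkpointed segments drop their
    # intermediate activations in forward and recompute them during backward
    y = recompute_sequential({"segments": 2, "preserve_rng_state": True}, model, x)
    y.mean().backward()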
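
Usage sketch for recompute_hybrid. Again a hypothetical illustration rather than patch content: it assumes the script is launched with paddle.distributed.launch on two GPUs so that an mp_degree=2 hybrid group exists (mirroring the dygraph_recompute_hybrid.py test added in this patch); the block, shapes, and ctx values are toy examples.

    import paddle
    from paddle.distributed import fleet
    from paddle.incubate.distributed.fleet import recompute_hybrid

    # hybrid-parallel setup with model parallelism of degree 2 (launched on 2 GPUs)
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)

    # toy block for illustration only
    block = paddle.nn.Sequential(
        paddle.nn.Linear(10, 10),
        paddle.nn.ReLU(),
    )
    x = paddle.randn([4, 10])
    x.stop_gradient = False

    ctx = {
        # group whose ranks share and optionally partition the saved activations
        "mp_group": fleet.fleet._hcg.get_model_parallel_group(),
        "offload": False,    # True: keep saved activations on CPU until backward
        "partition": False,  # True: split saved activations across the mp group
    }
    y = recompute_hybrid(ctx, block, x)
    y.mean().backward()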