diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index f6878ec1d8627c616ecee7771580c31293cce5d9..e3c92ee1db7c7c54324146e6f8b21d6ac36253b9 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -49,8 +49,9 @@ from functools import partial import paddle from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str -from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting +from paddle.distributed import fleet from paddle.fluid.framework import in_dygraph_mode +from paddle.incubate.distributed.fleet import recompute_hybrid __all__ = [] @@ -309,13 +310,13 @@ class PipelineLayer(Layer): self._loss_fn = loss_fn self._topo = topology self._recompute_interval = recompute_interval + self.recompute_ctx = recompute_ctx if recompute_interval > 0: assert recompute_ctx is not None, "recompute_ctx must be not None for recompute." offload = recompute_ctx.get('offload', False) partition = recompute_ctx.get('partition', False) - _initialize_recompute_setting(offload, partition) logger.info( "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}" .format(offload, partition)) @@ -638,7 +639,8 @@ class PipelineLayer(Layer): input = (input, ) if self._need_recompute(funcs, input): - input = _hp_recompute( + input = recompute_hybrid( + self.recompute_ctx, self.forward_function(start_idx, end_idx), *input) else: input = self.forward_function(start_idx, end_idx)(*input) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 876f9ffaed32bbf55e118ff0289c919f74d62450..537885bfad349779fe1b791c5693f2c8c5b114bb 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -14,7 +14,6 @@ import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters @@ -61,8 +60,6 @@ class PipelineParallel(MetaParallelBase): p2p.initialize_p2p_groups(hcg, self._using_cache) - _initialize_recompute_hcg(hcg) - self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py index 786eb20487a52e884db35795e006681d513d0b1c..04575bfb231946e87150ec121ad0be8cd3ea599f 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py @@ -12,6 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
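With the change above, PipelineLayer simply stores recompute_ctx and hands it to recompute_hybrid, so the caller now supplies the model-parallel group alongside the offload/partition flags. A minimal sketch of how a pipeline model might opt in — not part of this patch; the layer descriptors and sizes are illustrative, and it assumes fleet.init(is_collective=True, strategy=...) has already created the hybrid communicate group:

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

    hcg = fleet.get_hybrid_communicate_group()
    model = PipelineLayer(
        layers=[LayerDesc(paddle.nn.Linear, 1024, 1024) for _ in range(8)],  # illustrative stack
        topology=hcg.topology(),
        recompute_interval=1,  # recompute every pipeline segment
        recompute_ctx={
            "mp_group": hcg.get_model_parallel_group(),  # required by recompute_hybrid
            "offload": False,    # keep checkpointed activations on GPU
            "partition": False,  # do not shard activations across mp ranks
        },
    )
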
-from .utils import get_tensor_bytes
-
 __all__ = []
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index ce5c1cfe9eb8537e6140c250fdfdbfe409f4d72b..3b4094f047552dd5d822f4f4c576779f9a4e16f3 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import paddle
-from .utils import paddle_2_number, number_2_dtype
 from ...utils.log_util import logger
 import numpy as np
 from paddle import _C_ops, _legacy_C_ops
 import paddle.fluid.core as core
 from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
+from .utils import paddle_2_number, number_2_dtype
 
 _hcg = None
 _use_cache = False
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index 8ec7f0f037b0676e6a3da47f2610c9e5991262ce..683cc51d279079a2e09941f9d9ebe4313e126b2c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -12,16 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
-
 import paddle
 from paddle.fluid import core
 from paddle import _C_ops, _legacy_C_ops
-from paddle.autograd import PyLayer
-from paddle.fluid import framework
-from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
-from ..parallel_layers.random import get_rng_state_tracker
-from paddle.fluid.framework import in_dygraph_mode
 
 __all__ = []
 
@@ -88,23 +81,6 @@ def get_tensor_bytes(tensor):
     return tensor.numel() * elem_size
 
 
-_hcg = None
-_recompute_offload = False
-_recompute_partition = False
-
-
-def _initialize_recompute_setting(is_offload, is_partition):
-    global _recompute_offload, _recompute_partition
-
-    _recompute_offload = is_offload
-    _recompute_partition = is_partition
-
-
-def _initialize_recompute_hcg(hcg):
-    global _hcg
-    _hcg = hcg
-
-
 def _all_gather(tensor, group=None, use_calc_stream=True):
     """
     The main difference with paddle.distributed.all_gather:
@@ -117,187 +93,3 @@
     ).nranks if group is None else group.nranks
     return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
                                      'ring_id', ring_id, 'nranks', nranks)
-
-
-def _split_activation(tensor):
-    global _hcg
-
-    mp_degree = _hcg.get_model_parallel_world_size()
-    mp_rank = _hcg.get_model_parallel_rank()
-    if mp_degree < 2:
-        return tensor
-
-    tensor_numel = paddle.numel(tensor)
-    assert tensor_numel != 0, "can't recompute zero element"
-    assert tensor_numel % mp_degree == 0, "The capacity of the activation () cannot be divisible by mp_degree()".format(
-        tensor_numel, mp_degree)
-
-    # use inplace operation to save memory
-    data = tensor.flatten_()
-
-    part_size = tensor_numel // mp_degree
-    start = part_size * mp_rank
-    end = start + part_size
-    return data[start:end]
-
-
-def _merge_activation(tensor):
-    global _hcg
-    mp_degree = _hcg.get_model_parallel_world_size()
-    mp_rank = _hcg.get_model_parallel_rank()
-    mp_group = _hcg.get_model_parallel_group()
-    if mp_degree < 2:
-        return tensor
-    return _all_gather(tensor, group=mp_group)
-
-
-class _HPRecomputeFunction(PyLayer):
-    """
-    Compared with
paddle.distributed.fleet.utils.recompute, there are the following differences: - 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. - 2. Offload support for activation - 3. Support MP segmentation of activation to further reduce cuda memory - 4. Adapt to the random state of MP - """ - - @staticmethod - def forward(ctx, run_function, all_outputs, *args): - check_recompute_necessary(args) - - # store for recomputing - ctx.run_function = run_function - - # store the rng states - ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( - ).get_states_tracker() - - # save input for backward - ctx.inputs = [] - ctx.tensor_indices = [] - ctx.tensor_shapes = [] - tensor_inputs = [] - - cur_device = paddle.get_device() - assert 'gpu:' in paddle.get_device( - ), "Recompute with RNG is not support current device: {}.".format( - cur_device) - - # TODO support AMP - tracer = framework._dygraph_tracer() - ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True - if tracer._amp_level == core.AmpLevel.O2: - ctx.amp_level = 'O2' - elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): - ctx.amp_level = 'O1' - else: - raise ValueError("unsupported amp level: {}".format( - tracer._amp_level)) - ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() - - with paddle.no_grad(): - outputs = run_function(*args) - - for i, arg in enumerate(args): - if paddle.is_tensor(arg): - state = arg.stop_gradient - if _recompute_partition: - ctx.tensor_shapes.append(arg.shape) - partition = _split_activation(arg.detach()).clone() - # TODO(shenliang03) not use calculate stream to D2H to speed - arg = partition.cpu() if _recompute_offload else partition - else: - arg = arg.cpu() if _recompute_offload else arg - arg.stop_gradient = state - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - - ctx.save_for_backward(*tensor_inputs) - - if paddle.is_tensor(outputs): - all_outputs += [outputs] - return outputs - else: - all_outputs += outputs - return tuple(outputs) - - @staticmethod - def backward(ctx, *args): - with paddle.fluid.dygraph.guard(): - # Restore inputs - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensor_shapes = ctx.tensor_shapes - tensors = list(ctx.saved_tensor()) - - device_id = paddle.distributed.ParallelEnv().device_id - for i, idx in enumerate(tensor_indices): - if _recompute_partition: - state = tensors[i].stop_gradient - tensors[i] = _merge_activation( - tensors[i]).detach().reshape_(tensor_shapes[i]) - tensors[i].stop_gradient = state - inputs[idx] = tensors[i].cuda( - device_id) if _recompute_offload else tensors[i] - - tracer = framework._dygraph_tracer() - tracer._has_grad = True - - # need restore auto_cast state as well as w/b list - with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, - ctx.fwd_cuda_rng_state_tracker): - with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, - custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list, - level=ctx.amp_level): - detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) - - if isinstance(outputs, (core.VarBase, core.eager.Tensor)): - outputs = (outputs, ) - assert len(outputs) == len(args) - - forward_outputs_with_grad = [] - backward_inputs = [] - - for i in range(len(outputs)): - if isinstance( - outputs[i], - (core.VarBase, - core.eager.Tensor)) and 
not outputs[i].stop_gradient: - forward_outputs_with_grad.append(outputs[i]) - backward_inputs.append(args[i]) - - if len(forward_outputs_with_grad) == 0: - raise RuntimeError( - "none of output has stop_gradient=False, this recompute() is not necessary" - ) - - # actually backward - paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) - grads = tuple(inp._grad_ivar() for inp in detached_inputs - if isinstance(inp, (core.VarBase, core.eager.Tensor))) - return grads - - -def _hp_recompute(function, *args): - # NODTE(shenliang03)The current hybrid parallel recompute has limitations. - # It cannot handle the following situations: - # 1. The calculation output of recompute, there are tensors that do not require gradients. - # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach(). - # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor - - all_outputs = [] - _HPRecomputeFunction.apply(function, all_outputs, *args) - - if len(all_outputs) == 1: - return all_outputs[0] - else: - for output in all_outputs: - if paddle.is_tensor(output) and not is_float_tensor(output): - output.stop_gradient = True - - return tuple(all_outputs) diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index fea2614fe84c3f9b0efe68222873a5da1e1f4175..40633788f12d45be4bc30f25044e700656034d93 100644 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -20,46 +20,9 @@ from .base.topology import ParallelMode from .meta_parallel import TensorParallel, model_parallel_random_seed from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer from paddle.fluid import core -from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar from paddle.distributed import fleet - -class _RecomputeModelWrapper(paddle.nn.Layer): - - def __init__(self, model, segments=2, preserve_rng_state=True): - super(_RecomputeModelWrapper, self).__init__() - assert isinstance(model, paddle.nn.Sequential), ( - "The model passed to RecomputeModelWrapper must be of type " - "paddle.nn.Sequential.") - self._model = model - self._segments = segments - self._preserve_rng_state = preserve_rng_state - self._layers = list(model.children()) - self._segment_size = len(self._layers) // segments - - def _run_func(self, begin, end): - - def do_run(input): - for i in range(begin, end): - input = self._layers[i](input) - return input - - return do_run - - def _checkpoint(self, func, *args, **kwargs): - return LegacyRecomputeFunction.apply(func, self._preserve_rng_state, - *args) - - def forward(self, input): - end = 0 - for begin in range(0, self._segment_size * (self._segments - 1), - self._segment_size): - end = begin + self._segment_size - input = self._checkpoint(self._run_func(begin, end), input) - return self._run_func(end, len(self._layers))(input) - - _grad_scalar = None @@ -125,7 +88,6 @@ def distributed_model(model): return model amp_enable = False - recompute_enable = False strategy = fleet_env._user_defined_strategy if strategy.amp == True: amp_enable = True @@ -154,10 +116,6 @@ def distributed_model(model): decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, use_dynamic_loss_scaling=use_dynamic_loss_scaling) - if strategy.recompute == True: - recompute_enable = True - model = _RecomputeModelWrapper(model) - if strategy.heter_ccl_mode == True: 
distributed_model = paddle.DataParallel( model, diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5bcdb1db277661551bdde6e2c1a37e8d2bf906 --- /dev/null +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .recompute import recompute, recompute_sequential +from .recompute_hybrid import recompute_hybrid + +__all__ = [] diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py similarity index 88% rename from python/paddle/distributed/fleet/utils/recompute.py rename to python/paddle/distributed/fleet/recompute/recompute.py index 2dddb1d9fb492969b4bda0b0dd19d94f75199219..28ded25a0e6e0f96d97f396923c3160fc1b007d1 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -207,12 +207,13 @@ class LegacyRecomputeFunction(LegacyPyLayer): class RecomputeFunction(PyLayer): @staticmethod - def forward(ctx, run_function, preserve_rng_state, *args): + def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker # store for recomputing ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state + ctx.kwargs = kwargs # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input # the order of tensors in backward()'s output should be the same as tensors in forward()'s input @@ -265,7 +266,7 @@ class RecomputeFunction(PyLayer): ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): - outputs = run_function(*args) + outputs = run_function(*args, **kwargs) return outputs @staticmethod @@ -297,7 +298,8 @@ class RecomputeFunction(PyLayer): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, + **ctx.kwargs) else: with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, @@ -305,7 +307,7 @@ class RecomputeFunction(PyLayer): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): outputs = (outputs, ) @@ -352,7 +354,7 @@ def recompute(function, *args, **kwargs): recompute intermediate activations to save then memory. 
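RecomputeFunction now threads **kwargs through both the no-grad forward pass and the recomputation in backward, and recompute() below forwards them instead of rejecting them with a ValueError. A minimal sketch of the new behaviour — the Block layer and its 'scale' keyword are illustrative, not from this patch, and recompute with preserve_rng_state=True assumes a CUDA device:

    import paddle
    from paddle.distributed.fleet.utils import recompute

    class Block(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = paddle.nn.Linear(10, 10)

        def forward(self, x, scale=1.0):
            # 'scale' arrives through the **kwargs forwarded by RecomputeFunction
            return self.fc(x) * scale

    block = Block()
    x = paddle.randn([4, 10])
    x.stop_gradient = False
    y = recompute(block, x, scale=0.5)  # keyword is forwarded rather than raising ValueError
    y.mean().backward()                 # backward re-runs forward with the same kwargs
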
     Parameters:
-        function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model
+        function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
               whose intermediate activations will be released to save memory in forward stage and will be recomputed
               in backward stage for gradient calculation.
         *args(Tensor): inputs to the function.
@@ -466,11 +468,59 @@
     """
     # Hack to mix *args with **kwargs in a python 2.7-compliant way
     preserve = kwargs.pop('preserve_rng_state', True)
-    if kwargs:
-        raise ValueError("Unexpected keyword arguments: " +
-                         ",".join(arg for arg in kwargs))
 
     if framework._dygraph_tracer()._has_grad:
         check_recompute_necessary(args)
 
-    return RecomputeFunction.apply(function, preserve, *args)
+    return RecomputeFunction.apply(function, preserve, *args, **kwargs)
+
+
+def recompute_sequential(ctx, functions, *args, **kwargs):
+    """
+    recompute intermediate activations to save the memory for 'Sequential' models.
+
+    Parameters:
+        ctx(dict): includes 'segments' and 'preserve_rng_state' keys. The key 'segments' (int, default 1) represents the number of chunks to create in the model;
+                   the key 'preserve_rng_state' (bool, optional, default=True) indicates whether to save the forward rng state. If True, the last forward rng state will be
+                   restored when the forward recomputation is performed during backpropagation. Keys such as 'mp_group', 'offload' and 'partition' are invalid here;
+                   they are used by the 'recompute_hybrid' API.
+        functions(paddle.nn.Sequential): layer or sequence of layers that describes part of the forward pass of the model
+                   whose intermediate activations will be released to save memory in the forward stage and will be recomputed
+                   in the backward stage for gradient calculation.
+        *args(Tensor): inputs(tuple) to the function.
+        **kwargs(Dict): inputs(dict) to the function.
+
+    Returns:
+        Output of function on args and kwargs.
+
+    Examples:
+        .. code-block:: python
+
+            model = paddle.nn.Sequential(...)
+            input = recompute_sequential({'segments': 1}, model, input)
+    """
+    segments = ctx.get('segments', 1)
+    preserve_rng_state = ctx.get('preserve_rng_state', True)
+
+    def _run_func(begin, end, funcs):
+
+        def do_run(input):
+            for i in range(begin, end + 1):
+                input = funcs[i](input)
+            return input
+
+        return do_run
+
+    if isinstance(functions, paddle.nn.Sequential):
+        functions = list(functions.children())
+
+    segment_size = len(functions) // segments
+
+    end = -1
+    for begin in range(0, segment_size * (segments - 1), segment_size):
+        end = begin + segment_size - 1
+        args = recompute(_run_func(begin, end, functions),
+                         *args,
+                         preserve_rng_state=preserve_rng_state,
+                         **kwargs)
+    return _run_func(end + 1, len(functions) - 1, functions)(args)
diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..4883cad2511bb83a77ab81635e7bb7096c0e6298
--- /dev/null
+++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib + +import paddle +from paddle import _C_ops, _legacy_C_ops +from paddle.fluid import core +from paddle.autograd import PyLayer +from paddle.fluid import framework +from ..meta_parallel.parallel_layers.random import get_rng_state_tracker +from paddle.fluid.framework import in_dygraph_mode +from paddle.distributed import fleet +from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker +from ..meta_parallel.pp_utils import utils + +__all__ = [] + + +def _split_activation(tensor, mp_group): + + mp_degree = mp_group.nranks + mp_rank = mp_group.rank + if mp_degree < 2: + return tensor + + tensor_numel = paddle.numel(tensor) + assert tensor_numel != 0, "can't recompute zero element" + assert tensor_numel % mp_degree == 0, "The capacity of the activation ({}) cannot be divisible by mp_degree({})".format( + tensor_numel, mp_degree) + + # use inplace operation to save memory + data = tensor.flatten_() + + part_size = tensor_numel // mp_degree + start = part_size * mp_rank + end = start + part_size + return data[start:end] + + +def _merge_activation(tensor, mp_group): + mp_degree = mp_group.nranks + mp_rank = mp_group.rank + if mp_degree < 2: + return tensor + + # adapt to new dygraph + tensor_shape = list(tensor.shape) + tensor_shape[0] *= mp_group.nranks + out = paddle.empty(tensor_shape, tensor.dtype) + task = mp_group.process_group.all_gather(tensor.cuda(), out) + task.wait() + return out + + +class _HPRecomputeFunction(PyLayer): + """ + Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: + 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. + 2. Offload support for activation + 3. Support MP segmentation of activation to further reduce cuda memory + 4. 
Adapt to the random state of MP + """ + + @staticmethod + def forward(ctx, run_function, all_outputs, mp_group, offload, partition, + *args, **kwargs): + check_recompute_necessary(args) + + # store for recomputing + ctx.run_function = run_function + + ctx.kwargs = kwargs + + # store the rng states + ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( + ).get_states_tracker() + + # save config info + ctx.mp_group = mp_group + ctx.offload = offload + ctx.partition = partition + + # save input for backward + ctx.inputs = [] + ctx.tensor_indices = [] + ctx.tensor_shapes = [] + tensor_inputs = [] + + cur_device = paddle.get_device() + assert 'gpu:' in paddle.get_device( + ), "Recompute with RNG is not support current device: {}.".format( + cur_device) + + # TODO support AMP + tracer = framework._dygraph_tracer() + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' + else: + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) + ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() + + with paddle.no_grad(): + outputs = run_function(*args, **kwargs) + + for i, arg in enumerate(args): + if paddle.is_tensor(arg): + state = arg.stop_gradient + if partition: + ctx.tensor_shapes.append(arg.shape) + partition = _split_activation(arg.detach(), + mp_group).clone() + # TODO(shenliang03) not use calculate stream to D2H to speed + arg = partition.cpu() if offload else partition + else: + arg = arg.cpu() if offload else arg + arg.stop_gradient = state + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + if paddle.is_tensor(outputs): + all_outputs += [outputs] + return outputs + else: + all_outputs += outputs + return tuple(outputs) + + @staticmethod + def backward(ctx, *args): + with paddle.fluid.dygraph.guard(): + # Restore inputs + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensor_shapes = ctx.tensor_shapes + tensors = list(ctx.saved_tensor()) + + device_id = paddle.distributed.ParallelEnv().device_id + for i, idx in enumerate(tensor_indices): + if ctx.partition: + state = tensors[i].stop_gradient + tensors[i] = _merge_activation( + tensors[i], + ctx.mp_group).detach().reshape_(tensor_shapes[i]) + tensors[i].stop_gradient = state + inputs[idx] = tensors[i].cuda( + device_id) if ctx.offload else tensors[i] + + tracer = framework._dygraph_tracer() + tracer._has_grad = True + + # need restore auto_cast state as well as w/b list + with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, + ctx.fwd_cuda_rng_state_tracker): + with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) + + if isinstance(outputs, (core.VarBase, core.eager.Tensor)): + outputs = (outputs, ) + assert len(outputs) == len(args) + + forward_outputs_with_grad = [] + backward_inputs = [] + + for i in range(len(outputs)): + if isinstance( + outputs[i], + (core.VarBase, + core.eager.Tensor)) and not outputs[i].stop_gradient: + forward_outputs_with_grad.append(outputs[i]) + backward_inputs.append(args[i]) + + if 
len(forward_outputs_with_grad) == 0:
+                raise RuntimeError(
+                    "none of the outputs has stop_gradient=False, this recompute() is not necessary"
+                )
+
+            # actually backward
+            paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
+            grads = tuple(inp._grad_ivar() for inp in detached_inputs
+                          if isinstance(inp, (core.VarBase, core.eager.Tensor)))
+            return grads
+
+
+def recompute_hybrid(ctx, function, *args, **kwargs):
+    """
+    NOTE(shenliang03): The current hybrid parallel recompute has limitations.
+    It cannot handle the following situations:
+    1. The calculation output of recompute contains tensors that do not require gradients.
+    2. The forward output tensor has no gradient. This problem can be solved temporarily by detach().
+    3. Here, we only use the float dtype to distinguish whether a gradient is needed for an output tensor.
+
+    Parameters:
+        ctx(dict): includes 'mp_group', 'offload', and 'partition' keys. The key 'mp_group' (Group) is the group in which the activations are split;
+                   the key 'offload' (bool, optional, default=False) represents whether to offload activations to the CPU; the key 'partition' (bool, optional, default=False)
+                   represents whether to split activations in the mp_group. Keys such as 'segments' and 'preserve_rng_state' are invalid here;
+                   they are used by the 'recompute_sequential' API.
+        function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
+                   whose intermediate activations will be released to save memory in the forward stage and will be recomputed
+                   in the backward stage for gradient calculation.
+        *args(Tensor): inputs(tuple) to the function.
+        **kwargs(Dict): inputs(dict) to the function.
+
+    Returns:
+        Output of function on args and kwargs.
+
+    """
+    mp_group = ctx.get('mp_group', None)
+    assert mp_group is not None, "ctx must contain mp_group and mp_group cannot be None."
+
+    offload = ctx.get('offload', False)
+    partition = ctx.get('partition', False)
+
+    all_outputs = []
+    _HPRecomputeFunction.apply(function, all_outputs, mp_group, offload,
+                               partition, *args, **kwargs)
+
+    if len(all_outputs) == 1:
+        return all_outputs[0]
+    else:
+        for output in all_outputs:
+            if paddle.is_tensor(output) and not utils.is_float_tensor(output):
+                output.stop_gradient = True
+
+        return tuple(all_outputs)
diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py
index 1bf90a22e375c7068653d78891237a710bd8d666..93fc890d05af5b6a7e3bcfa51be0e115271e929c 100644
--- a/python/paddle/distributed/fleet/utils/__init__.py
+++ b/python/paddle/distributed/fleet/utils/__init__.py
@@ -15,11 +15,21 @@ from .fs import LocalFS  # noqa: F401
 from .fs import HDFSClient  # noqa: F401
 from .ps_util import DistributedInfer  # noqa: F401
-from .recompute import recompute  # noqa: F401
+import paddle.utils.deprecated as deprecated
+from paddle.distributed import fleet
+import paddle
 from . import log_util  # noqa: F401
 from .
import hybrid_parallel_util # noqa: F401 __all__ = [ #noqa "LocalFS", "recompute", "DistributedInfer", "HDFSClient" ] + + +@deprecated(since="2.4.0", + update_to="paddle.distributed.fleet.recompute", + level=1, + reason="Please use new recompute API(fleet.recompute) ") +def recompute(function, *args, **kwargs): + return fleet.recompute.recompute(function, *args, **kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py index 11ca15fd33104b64cc9fb2ca6b6aee14e2f6d2cb..f5f59cf10279ae93b13a6937439e1b2e3d336056 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py @@ -21,6 +21,7 @@ import paddle from paddle.autograd import PyLayer from paddle.distributed.fleet.utils import recompute import random +from paddle.incubate.distributed.fleet import recompute_sequential import paddle.fluid.layers as layers @@ -53,48 +54,66 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet_sq=False, + segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - def forward(self, inputs): + if self.use_fleet_sq and not use_raw_recompute: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) - if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) - else: - inputs = self.runfunc0(inputs) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] - if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] - if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) + def forward(self, inputs): - if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) + if self.use_fleet_sq and not self.use_raw_recompute: + return recompute_sequential({"segments": self.segments}, + self.runfuncs, inputs) - if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + if self.use_raw_recompute: + inputs = recompute(self.layers[0], inputs) + return self.layers[1](inputs) + + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet_sq=False, + use_raw_recompute=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -105,6 +124,9 
@@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -179,6 +201,34 @@ class TestPyLayer(unittest.TestCase): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute second & fourth block using fleet + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute using recompute_sequential, segments=1 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() @@ -191,7 +241,7 @@ class TestPyLayer(unittest.TestCase): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py index bc97d53485be99597bf0902870137fc1c6c0361d..4b0c73370d36125b2da9a46f99401bb34911237e 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py @@ -25,6 +25,7 @@ import paddle from paddle.autograd import PyLayer from paddle.distributed.fleet.utils import recompute import random +from paddle.incubate.distributed.fleet import recompute_sequential import paddle.fluid.layers as layers @@ -57,48 +58,66 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet_sq=False, + segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - def forward(self, inputs): + if self.use_fleet_sq and not use_raw_recompute: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) - if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) - else: - inputs = 
self.runfunc0(inputs) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] - if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] - if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) + def forward(self, inputs): - if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) + if self.use_fleet_sq and not self.use_raw_recompute: + return paddle.incubate.distributed.fleet.recompute_sequential( + {"segments": self.segments}, self.runfuncs, inputs) - if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + if self.use_raw_recompute: + inputs = recompute(self.layers[0], inputs) + return self.layers[1](inputs) + + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet_sq=False, + use_raw_recompute=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -109,6 +128,9 @@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -183,6 +205,28 @@ class TestPyLayer(unittest.TestCase): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute_sequential with segments=1 using fleet + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using paddle.incubate.distributed.fleet.recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() @@ -201,7 +245,7 @@ class TestPyLayer(unittest.TestCase): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py new file mode 100755 index 0000000000000000000000000000000000000000..cc90f0433b92172eb5bf7c74c8e18b885582a88e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py @@ -0,0 +1,221 @@ +# Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.autograd import PyLayer +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.distributed.fleet import recompute_hybrid +import random +from paddle.distributed import fleet + +import paddle.fluid.layers as layers + + +def get_fc_block(block_idx, input_size, is_last=False): + block_name = "block_" + str(block_idx) + block = paddle.nn.Sequential( + (block_name + "_fc_0", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + (block_name + "_relu_1", paddle.nn.ReLU()), + (block_name + "_fc_1", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_relu_2", paddle.nn.ReLU()), + ) + if is_last: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, 1, + bias_attr=False)) # add sublayer + else: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, + input_size, + bias_attr=False)) # add sublayer + return block + + +class Naive_fc_net(paddle.nn.Layer): + + def __init__(self, + input_size=10, + recompute_blocks=[1, 3], + offload=False, + partition=False, + recompute_kwargs={}): + super(Naive_fc_net, self).__init__() + self.recompute_blocks = recompute_blocks + self.recompute_kwargs = recompute_kwargs + self.offload = offload + self.partition = partition + + self.runfunc0 = get_fc_block(0, input_size, is_last=False) + self.runfunc1 = get_fc_block(1, input_size, is_last=False) + self.runfunc2 = get_fc_block(2, input_size, is_last=False) + self.runfunc3 = get_fc_block(3, input_size, is_last=False) + self.runfunc4 = get_fc_block(4, input_size, is_last=True) + + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] + + def forward(self, inputs): + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute_hybrid( + { + "mp_group": fleet.fleet._hcg.get_model_parallel_group(), + "offload": self.offload, + "partition": self.partition + }, self.layers[i], inputs, **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) + + return inputs + + +def run_model(recompute_block=[], + recompute_kwargs={}, + offload=False, + partition=False, + enable_autocast=False, + pure_fp16=False): + gen = paddle.seed(10) + gen.manual_seed(10) + np.random.seed(10) + random.seed(10) + + batch_size, input_size = 1, 10 + model = Naive_fc_net(input_size, + recompute_blocks=recompute_block, + offload=offload, + partition=partition, + recompute_kwargs=recompute_kwargs) + loss_fn = paddle.nn.MSELoss(reduction='mean') + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + if enable_autocast: + scaler = paddle.amp.GradScaler() + scaler = fleet.distributed_scaler(scaler) + + loss_ = [] + param_ = [] + grad_ 
= [] + for step in range(10): + + x_data = np.random.randn(batch_size, input_size).astype(np.float32) + x = paddle.to_tensor(x_data) + # x.stop_gradient = False + level = 'O2' if pure_fp16 else 'O1' + with paddle.amp.auto_cast(True, level=level): + y_pred = model(x) + loss = y_pred.mean() + if enable_autocast: + scaler.scale(loss).backward() + scaler.minimize(optimizer, loss) + else: + loss_.append(np.asarray(loss).tolist()) + loss.backward() + optimizer.step() + + param_.append(np.asarray(model.parameters()[9]).tolist()) + grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) + + optimizer.clear_grad() + return loss_, param_, grad_ + + +class TestPyLayer(unittest.TestCase): + + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 1 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + fleet.init(is_collective=True, strategy=strategy) + + def test_base_case(self, enable_autocast=False, pure_fp16=False): + + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + + # without recompute + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + + # with recompute, offload=False, partition=False + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=False + loss, param, grad = run_model(recompute_block=[1, 2, 3], + offload=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=False, partition=True + loss, param, grad = run_model(recompute_block=[1], + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=True + loss, param, grad = run_model(recompute_block=[1, 3, 4], + offload=True, + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + def test_fc_net_with_dropout(self): + self.test_base_case() + + def test_fc_net_with_amp(self): + self.test_base_case(enable_autocast=True) + + def test_fc_net_with_fp16(self): + self.test_base_case(enable_autocast=True, pure_fp16=True) + + def test_recompute_kwargs(self): + paddle.set_device("gpu") + kwargs = {"is_test": False} + with self.assertRaises(TypeError): + loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], + recompute_kwargs=kwargs) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py index 8773e8d47ed3c0e1fa7b4e4442af0c07c0e6111f..10243a0faa9445ea6d14086d2e7d243be67be70f 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py @@ -26,5 +26,11 @@ class TestPipelineParallel(TestMultipleGpus): self.run_mnist_2gpu('hybrid_parallel_pp_alexnet.py') +class TestModelParallelWithRecompute(TestMultipleGpus): + + def 
test_model_parallel_with_recompute(self): + self.run_mnist_2gpu("dygraph_recompute_hybrid.py") + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/distributed/fleet/__init__.py b/python/paddle/incubate/distributed/fleet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94e1a7c8bbe77bbf763f462d48fc5812243817f9 --- /dev/null +++ b/python/paddle/incubate/distributed/fleet/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.fleet.recompute import recompute_sequential, recompute_hybrid + +__all__ = ["recompute_sequential", "recompute_hybrid"] diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index f25b00cb4beefcff029e3833ac985908989724d0..65a7eec9973001d4ade517dddbf1a3bfd32e3b8d 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -34,9 +34,9 @@ from paddle.distributed import fleet from paddle.autograd import PyLayer from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .utils import count_by_gate -from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute from paddle import fluid from paddle.fluid.framework import in_dygraph_mode +from paddle.incubate.distributed.fleet import recompute_hybrid def _local_scatter(inp, pos): @@ -424,8 +424,8 @@ class MoELayer(nn.Layer): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = _hp_recompute(experts_fwd, x, fwd_expert_count.numpy(), - self.experts) + x = recompute_hybrid(self.recompute_ctx, experts_fwd, x, + fwd_expert_count.numpy(), self.experts) out_batch_size = inp.shape[0] if len(gate.shape) == 2: diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py index d230b6afca2995f1dad428e95f87335e71cdb49d..f8e3b55aba6213658bfb730202087711e2de6ab8 100644 --- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py +++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py @@ -19,7 +19,7 @@ from paddle.fluid.clip import ClipGradByGlobalNorm from paddle.fluid.initializer import Constant from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.optimizer import Optimizer -from paddle.distributed import get_rank, get_world_size +import paddle.distributed as dist from paddle.distributed.collective import new_group from paddle.fluid.executor import global_scope from paddle.fluid.framework import name_scope @@ -288,8 +288,8 @@ class DistributedFusedLamb(Optimizer): step = self._get_or_create_step() - rank = get_rank() - nranks = get_world_size() + rank = dist.get_rank() + nranks = dist.get_world_size() if self._nproc_per_node is None: nproc_per_node = nranks else: diff --git a/python/setup.py.in b/python/setup.py.in index 
c51a3765da12f6b7fddefea4342b5580417e19e2..084fd5ac042c60f801c07abb1daea1ed95a50a69 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -296,6 +296,7 @@ packages=['paddle', 'paddle.distributed.launch.plugins', 'paddle.distributed.launch.utils', 'paddle.distributed.fleet.base', + 'paddle.distributed.fleet.recompute', 'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.meta_optimizers', 'paddle.distributed.fleet.meta_optimizers.sharding', @@ -380,6 +381,7 @@ packages=['paddle', 'paddle.incubate.optimizer.functional', 'paddle.incubate.autograd', 'paddle.incubate.distributed', + 'paddle.incubate.distributed.fleet', 'paddle.incubate.distributed.models', 'paddle.incubate.distributed.models.moe', 'paddle.incubate.distributed.models.moe.gate',
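
Taken together, the patch exposes two public entry points under paddle.incubate.distributed.fleet. A minimal end-to-end sketch of both — the Sequential model is illustrative, both APIs assume a CUDA device, and recompute_hybrid additionally assumes fleet.init(is_collective=True, strategy=...) has created a hybrid communicate group (as in the new unit test):

    import paddle
    from paddle.distributed import fleet
    from paddle.incubate.distributed.fleet import recompute_sequential, recompute_hybrid

    model = paddle.nn.Sequential(
        paddle.nn.Linear(10, 10),
        paddle.nn.ReLU(),
        paddle.nn.Linear(10, 1),
    )
    x = paddle.randn([4, 10])
    x.stop_gradient = False

    # Checkpoint the Sequential in two chunks; 'preserve_rng_state' keeps dropout-style
    # randomness reproducible when the chunks are re-run during backward.
    out = recompute_sequential({"segments": 2, "preserve_rng_state": True}, model, x)
    out.mean().backward()

    # Hybrid-parallel variant: activations of the wrapped block can be offloaded to CPU
    # and/or partitioned across the model-parallel group supplied in the ctx dict.
    hcg = fleet.get_hybrid_communicate_group()
    out = recompute_hybrid(
        {
            "mp_group": hcg.get_model_parallel_group(),
            "offload": True,
            "partition": False,
        },
        model,
        x,
    )
    out.mean().backward()

The deprecated shim added in fleet/utils/__init__.py keeps paddle.distributed.fleet.utils.recompute importable, so existing callers continue to work while being pointed at the new fleet.recompute location.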