diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index a6e8661f7a6eae312220a886fa6e0990acb164e8..926e6aab81e56dfead1d0569f52500a87d719a00 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -49,8 +49,9 @@ from functools import partial import paddle from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str -from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting +from paddle.distributed import fleet from paddle.fluid.framework import in_dygraph_mode +from paddle.incubate.distributed.fleet import recompute_hybrid __all__ = [] @@ -309,13 +310,13 @@ class PipelineLayer(Layer): self._loss_fn = loss_fn self._topo = topology self._recompute_interval = recompute_interval + self.recompute_ctx = recompute_ctx if recompute_interval > 0: assert recompute_ctx is not None, "recompute_ctx must be not None for recompute." offload = recompute_ctx.get('offload', False) partition = recompute_ctx.get('partition', False) - _initialize_recompute_setting(offload, partition) logger.info( "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}" .format(offload, partition)) @@ -638,7 +639,8 @@ class PipelineLayer(Layer): input = (input, ) if self._need_recompute(funcs, input): - input = _hp_recompute( + input = recompute_hybrid( + self.recompute_ctx, self.forward_function(start_idx, end_idx), *input) else: input = self.forward_function(start_idx, end_idx)(*input) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 876f9ffaed32bbf55e118ff0289c919f74d62450..537885bfad349779fe1b791c5693f2c8c5b114bb 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -14,7 +14,6 @@ import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters @@ -61,8 +60,6 @@ class PipelineParallel(MetaParallelBase): p2p.initialize_p2p_groups(hcg, self._using_cache) - _initialize_recompute_hcg(hcg) - self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py index 786eb20487a52e884db35795e006681d513d0b1c..04575bfb231946e87150ec121ad0be8cd3ea599f 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py @@ -12,6 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .utils import get_tensor_bytes
-
 __all__ = []
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index ce5c1cfe9eb8537e6140c250fdfdbfe409f4d72b..3b4094f047552dd5d822f4f4c576779f9a4e16f3 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import paddle
-from .utils import paddle_2_number, number_2_dtype
 from ...utils.log_util import logger
 import numpy as np
 from paddle import _C_ops, _legacy_C_ops
 import paddle.fluid.core as core
 from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
+from .utils import paddle_2_number, number_2_dtype
 
 _hcg = None
 _use_cache = False
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index bb774b8a0e5f896ef784e05b2d7426d2d6a966fa..c2008abb71c5371acfa55479e7cfb033380656a7 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -12,16 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
-
 import paddle
 from paddle.fluid import core
 from paddle import _C_ops, _legacy_C_ops
-from paddle.autograd import PyLayer
-from paddle.fluid import framework
-from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
-from ..parallel_layers.random import get_rng_state_tracker
-from paddle.fluid.framework import in_dygraph_mode
 
 __all__ = []
 
@@ -88,23 +81,6 @@ def get_tensor_bytes(tensor):
     return tensor.numel() * elem_size
 
 
-_hcg = None
-_recompute_offload = False
-_recompute_partition = False
-
-
-def _initialize_recompute_setting(is_offload, is_partition):
-    global _recompute_offload, _recompute_partition
-
-    _recompute_offload = is_offload
-    _recompute_partition = is_partition
-
-
-def _initialize_recompute_hcg(hcg):
-    global _hcg
-    _hcg = hcg
-
-
 def _all_gather(tensor, group=None, use_calc_stream=True):
     """
     The main difference with paddle.distributed.all_gather:
@@ -117,187 +93,3 @@
     ).nranks if group is None else group.nranks
     return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
                                      'ring_id', ring_id, 'nranks', nranks)
-
-
-def _split_activation(tensor):
-    global _hcg
-
-    mp_degree = _hcg.get_model_parallel_world_size()
-    mp_rank = _hcg.get_model_parallel_rank()
-    if mp_degree < 2:
-        return tensor
-
-    tensor_numel = paddle.numel(tensor)
-    assert tensor_numel != 0, "can't recompute zero element"
-    assert tensor_numel % mp_degree == 0, "The capacity of the activation () cannot be divisible by mp_degree()".format(
-        tensor_numel, mp_degree)
-
-    # use inplace operation to save memory
-    data = tensor.flatten_()
-
-    part_size = tensor_numel // mp_degree
-    start = part_size * mp_rank
-    end = start + part_size
-    return data[start:end]
-
-
-def _merge_activation(tensor):
-    global _hcg
-    mp_degree = _hcg.get_model_parallel_world_size()
-    mp_rank = _hcg.get_model_parallel_rank()
-    mp_group = _hcg.get_model_parallel_group()
-    if mp_degree < 2:
-        return tensor
-    return _all_gather(tensor, group=mp_group)
-
-
-class _HPRecomputeFunction(PyLayer):
-    """
-    Compared with 
paddle.distributed.fleet.utils.recompute, there are the following differences: - 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. - 2. Offload support for activation - 3. Support MP segmentation of activation to further reduce cuda memory - 4. Adapt to the random state of MP - """ - - @staticmethod - def forward(ctx, run_function, all_outputs, *args): - check_recompute_necessary(args) - - # store for recomputing - ctx.run_function = run_function - - # store the rng states - ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( - ).get_states_tracker() - - # save input for backward - ctx.inputs = [] - ctx.tensor_indices = [] - ctx.tensor_shapes = [] - tensor_inputs = [] - - cur_device = paddle.get_device() - assert 'gpu:' in paddle.get_device( - ), "Recompute with RNG is not support current device: {}.".format( - cur_device) - - # TODO support AMP - tracer = framework._dygraph_tracer() - ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True - if tracer._amp_level == core.AmpLevel.O2: - ctx.amp_level = 'O2' - elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): - ctx.amp_level = 'O1' - else: - raise ValueError("unsupported amp level: {}".format( - tracer._amp_level)) - ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() - - with paddle.no_grad(): - outputs = run_function(*args) - - for i, arg in enumerate(args): - if paddle.is_tensor(arg): - state = arg.stop_gradient - if _recompute_partition: - ctx.tensor_shapes.append(arg.shape) - partition = _split_activation(arg.detach()).clone() - # TODO(shenliang03) not use calculate stream to D2H to speed - arg = partition.cpu() if _recompute_offload else partition - else: - arg = arg.cpu() if _recompute_offload else arg - arg.stop_gradient = state - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - - ctx.save_for_backward(*tensor_inputs) - - if paddle.is_tensor(outputs): - all_outputs += [outputs] - return outputs - else: - all_outputs += outputs - return tuple(outputs) - - @staticmethod - def backward(ctx, *args): - with paddle.fluid.dygraph.guard(): - # Restore inputs - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensor_shapes = ctx.tensor_shapes - tensors = list(ctx.saved_tensor()) - - device_id = paddle.distributed.ParallelEnv().device_id - for i, idx in enumerate(tensor_indices): - if _recompute_partition: - state = tensors[i].stop_gradient - tensors[i] = _merge_activation( - tensors[i]).detach().reshape_(tensor_shapes[i]) - tensors[i].stop_gradient = state - inputs[idx] = tensors[i].cuda( - device_id) if _recompute_offload else tensors[i] - - tracer = framework._dygraph_tracer() - tracer._has_grad = True - - # need restore auto_cast state as well as w/b list - with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, - ctx.fwd_cuda_rng_state_tracker): - with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, - custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list, - level=ctx.amp_level): - detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) - - if isinstance(outputs, (core.VarBase, core.eager.Tensor)): - outputs = (outputs, ) - assert len(outputs) == len(args) - - forward_outputs_with_grad = [] - backward_inputs = [] - - for i in range(len(outputs)): - if isinstance( - outputs[i], - (core.VarBase, - core.eager.Tensor)) and 
not outputs[i].stop_gradient: - forward_outputs_with_grad.append(outputs[i]) - backward_inputs.append(args[i]) - - if len(forward_outputs_with_grad) == 0: - raise RuntimeError( - "none of output has stop_gradient=False, this recompute() is not necessary" - ) - - # actually backward - paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) - grads = tuple(inp._grad_ivar() for inp in detached_inputs - if isinstance(inp, (core.VarBase, core.eager.Tensor))) - return grads - - -def _hp_recompute(function, *args): - # NODTE(shenliang03)The current hybrid parallel recompute has limitations. - # It cannot handle the following situations: - # 1. The calculation output of recompute, there are tensors that do not require gradients. - # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach(). - # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor - - all_outputs = [] - _HPRecomputeFunction.apply(function, all_outputs, *args) - - if len(all_outputs) == 1: - return all_outputs[0] - else: - for output in all_outputs: - if paddle.is_tensor(output) and not is_float_tensor(output): - output.stop_gradient = True - - return tuple(all_outputs) diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index fea2614fe84c3f9b0efe68222873a5da1e1f4175..40633788f12d45be4bc30f25044e700656034d93 100644 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -20,46 +20,9 @@ from .base.topology import ParallelMode from .meta_parallel import TensorParallel, model_parallel_random_seed from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer from paddle.fluid import core -from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar from paddle.distributed import fleet - -class _RecomputeModelWrapper(paddle.nn.Layer): - - def __init__(self, model, segments=2, preserve_rng_state=True): - super(_RecomputeModelWrapper, self).__init__() - assert isinstance(model, paddle.nn.Sequential), ( - "The model passed to RecomputeModelWrapper must be of type " - "paddle.nn.Sequential.") - self._model = model - self._segments = segments - self._preserve_rng_state = preserve_rng_state - self._layers = list(model.children()) - self._segment_size = len(self._layers) // segments - - def _run_func(self, begin, end): - - def do_run(input): - for i in range(begin, end): - input = self._layers[i](input) - return input - - return do_run - - def _checkpoint(self, func, *args, **kwargs): - return LegacyRecomputeFunction.apply(func, self._preserve_rng_state, - *args) - - def forward(self, input): - end = 0 - for begin in range(0, self._segment_size * (self._segments - 1), - self._segment_size): - end = begin + self._segment_size - input = self._checkpoint(self._run_func(begin, end), input) - return self._run_func(end, len(self._layers))(input) - - _grad_scalar = None @@ -125,7 +88,6 @@ def distributed_model(model): return model amp_enable = False - recompute_enable = False strategy = fleet_env._user_defined_strategy if strategy.amp == True: amp_enable = True @@ -154,10 +116,6 @@ def distributed_model(model): decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, use_dynamic_loss_scaling=use_dynamic_loss_scaling) - if strategy.recompute == True: - recompute_enable = True - model = _RecomputeModelWrapper(model) - if strategy.heter_ccl_mode == True: 
distributed_model = paddle.DataParallel( model, diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5bcdb1db277661551bdde6e2c1a37e8d2bf906 --- /dev/null +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .recompute import recompute, recompute_sequential +from .recompute_hybrid import recompute_hybrid + +__all__ = [] diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py similarity index 86% rename from python/paddle/distributed/fleet/utils/recompute.py rename to python/paddle/distributed/fleet/recompute/recompute.py index f0c74159488a782a556c8fab5f303ed0cead31b4..28ded25a0e6e0f96d97f396923c3160fc1b007d1 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -207,12 +207,13 @@ class LegacyRecomputeFunction(LegacyPyLayer): class RecomputeFunction(PyLayer): @staticmethod - def forward(ctx, run_function, preserve_rng_state, *args): + def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker # store for recomputing ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state + ctx.kwargs = kwargs # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input # the order of tensors in backward()'s output should be the same as tensors in forward()'s input @@ -265,7 +266,7 @@ class RecomputeFunction(PyLayer): ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): - outputs = run_function(*args) + outputs = run_function(*args, **kwargs) return outputs @staticmethod @@ -297,7 +298,8 @@ class RecomputeFunction(PyLayer): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, + **ctx.kwargs) else: with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, @@ -305,7 +307,7 @@ class RecomputeFunction(PyLayer): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): outputs = (outputs, ) @@ -352,13 +354,13 @@ def recompute(function, *args, **kwargs): recompute intermediate activations to save then memory. 
     Parameters:
-        function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model
-              whose intermediate activations will be released to save memory in forward stage and will be recomputed
-              in backward stage for gradient calculation.
-        *args(Tensor): inputs to the function.
-        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
-              indicate whether to save the forward rng. If it is True, then the last forward rng value will be
-              restored when the forward recalculation of backpropagation is performed. The default
+        function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
+              whose intermediate activations will be released to save memory in forward stage and will be recomputed
+              in backward stage for gradient calculation.
+        *args(Tensor): inputs to the function.
+        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
+              indicate whether to save the forward rng. If it is True, then the last forward rng value will be
+              restored when the forward recalculation of backpropagation is performed. The default
              preserve_rng_state is True.
 
     Returns:
@@ -466,11 +468,59 @@
     """
     # Hack to mix *args with **kwargs in a python 2.7-compliant way
     preserve = kwargs.pop('preserve_rng_state', True)
-    if kwargs:
-        raise ValueError("Unexpected keyword arguments: " +
-                         ",".join(arg for arg in kwargs))
 
     if framework._dygraph_tracer()._has_grad:
         check_recompute_necessary(args)
 
-    return RecomputeFunction.apply(function, preserve, *args)
+    return RecomputeFunction.apply(function, preserve, *args, **kwargs)
+
+
+def recompute_sequential(ctx, functions, *args, **kwargs):
+    """
+    recompute intermediate activations to save memory for 'Sequential' models.
+
+    Parameters:
+        ctx(dict): includes the 'segments' and 'preserve_rng_state' keys. The key 'segments' (int, default 1) is the number of chunks the model is split into;
+                   the key 'preserve_rng_state' (bool, optional, default=True) indicates whether to save the forward rng. If it is True, the last forward rng value will be
+                   restored when the forward recalculation of backpropagation is performed. Keys such as 'mp_group', 'offload' and 'partition' are ignored here;
+                   they are used by the 'recompute_hybrid' API.
+        functions(paddle.nn.Sequential): layer or sequence of layers that describes part of the forward pass of the model
+              whose intermediate activations will be released to save memory in forward stage and will be recomputed
+              in backward stage for gradient calculation.
+        *args(Tensor): inputs(tuple) to the function.
+        **kwargs(Dict): inputs(dict) to the function.
+
+    Returns:
+        Output of function on args and kwargs.
+
+    Examples:
+        .. code-block:: python
+
+            model = paddle.nn.Sequential(...)
+ input = recompute_sequential({'segments' : 1}, model, input) + """ + segments = ctx.get('segments', 1) + preserve_rng_state = ctx.get('preserve_rng_state', True) + + def _run_func(begin, end, funcs): + + def do_run(input): + for i in range(begin, end + 1): + input = funcs[i](input) + return input + + return do_run + + if isinstance(functions, paddle.nn.Sequential): + functions = list(functions.children()) + + segment_size = len(functions) // segments + + end = -1 + for begin in range(0, segment_size * (segments - 1), segment_size): + end = begin + segment_size - 1 + args = recompute(_run_func(begin, end, functions), + *args, + preserve_rng_state=preserve_rng_state, + **kwargs) + return _run_func(end + 1, len(functions) - 1, functions)(args) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..4883cad2511bb83a77ab81635e7bb7096c0e6298 --- /dev/null +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -0,0 +1,250 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib + +import paddle +from paddle import _C_ops, _legacy_C_ops +from paddle.fluid import core +from paddle.autograd import PyLayer +from paddle.fluid import framework +from ..meta_parallel.parallel_layers.random import get_rng_state_tracker +from paddle.fluid.framework import in_dygraph_mode +from paddle.distributed import fleet +from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker +from ..meta_parallel.pp_utils import utils + +__all__ = [] + + +def _split_activation(tensor, mp_group): + + mp_degree = mp_group.nranks + mp_rank = mp_group.rank + if mp_degree < 2: + return tensor + + tensor_numel = paddle.numel(tensor) + assert tensor_numel != 0, "can't recompute zero element" + assert tensor_numel % mp_degree == 0, "The capacity of the activation ({}) cannot be divisible by mp_degree({})".format( + tensor_numel, mp_degree) + + # use inplace operation to save memory + data = tensor.flatten_() + + part_size = tensor_numel // mp_degree + start = part_size * mp_rank + end = start + part_size + return data[start:end] + + +def _merge_activation(tensor, mp_group): + mp_degree = mp_group.nranks + mp_rank = mp_group.rank + if mp_degree < 2: + return tensor + + # adapt to new dygraph + tensor_shape = list(tensor.shape) + tensor_shape[0] *= mp_group.nranks + out = paddle.empty(tensor_shape, tensor.dtype) + task = mp_group.process_group.all_gather(tensor.cuda(), out) + task.wait() + return out + + +class _HPRecomputeFunction(PyLayer): + """ + Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: + 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. + 2. Offload support for activation + 3. 
Support MP segmentation of activation to further reduce cuda memory + 4. Adapt to the random state of MP + """ + + @staticmethod + def forward(ctx, run_function, all_outputs, mp_group, offload, partition, + *args, **kwargs): + check_recompute_necessary(args) + + # store for recomputing + ctx.run_function = run_function + + ctx.kwargs = kwargs + + # store the rng states + ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( + ).get_states_tracker() + + # save config info + ctx.mp_group = mp_group + ctx.offload = offload + ctx.partition = partition + + # save input for backward + ctx.inputs = [] + ctx.tensor_indices = [] + ctx.tensor_shapes = [] + tensor_inputs = [] + + cur_device = paddle.get_device() + assert 'gpu:' in paddle.get_device( + ), "Recompute with RNG is not support current device: {}.".format( + cur_device) + + # TODO support AMP + tracer = framework._dygraph_tracer() + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' + else: + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) + ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() + + with paddle.no_grad(): + outputs = run_function(*args, **kwargs) + + for i, arg in enumerate(args): + if paddle.is_tensor(arg): + state = arg.stop_gradient + if partition: + ctx.tensor_shapes.append(arg.shape) + partition = _split_activation(arg.detach(), + mp_group).clone() + # TODO(shenliang03) not use calculate stream to D2H to speed + arg = partition.cpu() if offload else partition + else: + arg = arg.cpu() if offload else arg + arg.stop_gradient = state + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + if paddle.is_tensor(outputs): + all_outputs += [outputs] + return outputs + else: + all_outputs += outputs + return tuple(outputs) + + @staticmethod + def backward(ctx, *args): + with paddle.fluid.dygraph.guard(): + # Restore inputs + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensor_shapes = ctx.tensor_shapes + tensors = list(ctx.saved_tensor()) + + device_id = paddle.distributed.ParallelEnv().device_id + for i, idx in enumerate(tensor_indices): + if ctx.partition: + state = tensors[i].stop_gradient + tensors[i] = _merge_activation( + tensors[i], + ctx.mp_group).detach().reshape_(tensor_shapes[i]) + tensors[i].stop_gradient = state + inputs[idx] = tensors[i].cuda( + device_id) if ctx.offload else tensors[i] + + tracer = framework._dygraph_tracer() + tracer._has_grad = True + + # need restore auto_cast state as well as w/b list + with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, + ctx.fwd_cuda_rng_state_tracker): + with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) + + if isinstance(outputs, (core.VarBase, core.eager.Tensor)): + outputs = (outputs, ) + assert len(outputs) == len(args) + + forward_outputs_with_grad = [] + backward_inputs = [] + + for i in range(len(outputs)): + if isinstance( + outputs[i], + (core.VarBase, + core.eager.Tensor)) and not outputs[i].stop_gradient: + 
                    forward_outputs_with_grad.append(outputs[i])
+                    backward_inputs.append(args[i])
+
+            if len(forward_outputs_with_grad) == 0:
+                raise RuntimeError(
+                    "none of the outputs has stop_gradient=False, this recompute() is not necessary"
+                )
+
+            # actually backward
+            paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
+            grads = tuple(inp._grad_ivar() for inp in detached_inputs
+                          if isinstance(inp, (core.VarBase, core.eager.Tensor)))
+            return grads
+
+
+def recompute_hybrid(ctx, function, *args, **kwargs):
+    """
+    # NOTE(shenliang03): The current hybrid parallel recompute has limitations.
+    # It cannot handle the following situations:
+    # 1. Among the outputs of recompute, there are tensors that do not require gradients.
+    # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach().
+    # 3. Here, we only use float dtype to distinguish whether a gradient is needed in the output tensors.
+
+    Parameters:
+        ctx(dict): includes the 'mp_group', 'offload', and 'partition' keys. The key 'mp_group' (Group) specifies the group across which the activations are split;
+                   the key 'offload' (bool, optional, default=False) indicates whether to offload the activations to the CPU; the key 'partition' (bool, optional, default=False)
+                   indicates whether to split the activations in the mp_group. Keys such as 'segments' and 'preserve_rng_state' are ignored here;
+                   they are used by the 'recompute_sequential' API.
+        function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
+              whose intermediate activations will be released to save memory in forward stage and will be recomputed
+              in backward stage for gradient calculation.
+        *args(Tensor): inputs(tuple) to the function.
+
+        **kwargs(Dict): inputs(dict) to the function.
+
+    Returns:
+        Output of function on args and kwargs.
+
+    """
+    mp_group = ctx.get('mp_group', None)
+    assert mp_group is not None, "ctx must contain the key 'mp_group', and it can not be None."
+
+    offload = ctx.get('offload', False)
+    partition = ctx.get('partition', False)
+
+    all_outputs = []
+    _HPRecomputeFunction.apply(function, all_outputs, mp_group, offload,
+                               partition, *args, **kwargs)
+
+    if len(all_outputs) == 1:
+        return all_outputs[0]
+    else:
+        for output in all_outputs:
+            if paddle.is_tensor(output) and not utils.is_float_tensor(output):
+                output.stop_gradient = True
+
+        return tuple(all_outputs)
diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py
index 1bf90a22e375c7068653d78891237a710bd8d666..93fc890d05af5b6a7e3bcfa51be0e115271e929c 100644
--- a/python/paddle/distributed/fleet/utils/__init__.py
+++ b/python/paddle/distributed/fleet/utils/__init__.py
@@ -15,11 +15,21 @@
 from .fs import LocalFS  # noqa: F401
 from .fs import HDFSClient  # noqa: F401
 from .ps_util import DistributedInfer  # noqa: F401
-from .recompute import recompute  # noqa: F401
+import paddle.utils.deprecated as deprecated
+from paddle.distributed import fleet
+import paddle
 from . import log_util  # noqa: F401
 from . 
import hybrid_parallel_util # noqa: F401 __all__ = [ #noqa "LocalFS", "recompute", "DistributedInfer", "HDFSClient" ] + + +@deprecated(since="2.4.0", + update_to="paddle.distributed.fleet.recompute", + level=1, + reason="Please use new recompute API(fleet.recompute) ") +def recompute(function, *args, **kwargs): + return fleet.recompute.recompute(function, *args, **kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py index 11ca15fd33104b64cc9fb2ca6b6aee14e2f6d2cb..f5f59cf10279ae93b13a6937439e1b2e3d336056 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py @@ -21,6 +21,7 @@ import paddle from paddle.autograd import PyLayer from paddle.distributed.fleet.utils import recompute import random +from paddle.incubate.distributed.fleet import recompute_sequential import paddle.fluid.layers as layers @@ -53,48 +54,66 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet_sq=False, + segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - def forward(self, inputs): + if self.use_fleet_sq and not use_raw_recompute: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) - if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) - else: - inputs = self.runfunc0(inputs) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] - if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] - if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) + def forward(self, inputs): - if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) + if self.use_fleet_sq and not self.use_raw_recompute: + return recompute_sequential({"segments": self.segments}, + self.runfuncs, inputs) - if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + if self.use_raw_recompute: + inputs = recompute(self.layers[0], inputs) + return self.layers[1](inputs) + + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet_sq=False, + use_raw_recompute=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -105,6 +124,9 
@@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -179,6 +201,34 @@ class TestPyLayer(unittest.TestCase): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute second & fourth block using fleet + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute using recompute_sequential, segments=1 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() @@ -191,7 +241,7 @@ class TestPyLayer(unittest.TestCase): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py index bc97d53485be99597bf0902870137fc1c6c0361d..4b0c73370d36125b2da9a46f99401bb34911237e 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py @@ -25,6 +25,7 @@ import paddle from paddle.autograd import PyLayer from paddle.distributed.fleet.utils import recompute import random +from paddle.incubate.distributed.fleet import recompute_sequential import paddle.fluid.layers as layers @@ -57,48 +58,66 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet_sq=False, + segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - def forward(self, inputs): + if self.use_fleet_sq and not use_raw_recompute: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) - if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) - else: - inputs = 
self.runfunc0(inputs) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] - if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] - if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) + def forward(self, inputs): - if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) + if self.use_fleet_sq and not self.use_raw_recompute: + return paddle.incubate.distributed.fleet.recompute_sequential( + {"segments": self.segments}, self.runfuncs, inputs) - if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + if self.use_raw_recompute: + inputs = recompute(self.layers[0], inputs) + return self.layers[1](inputs) + + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet_sq=False, + use_raw_recompute=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -109,6 +128,9 @@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -183,6 +205,28 @@ class TestPyLayer(unittest.TestCase): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute_sequential with segments=1 using fleet + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using paddle.incubate.distributed.fleet.recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() @@ -201,7 +245,7 @@ class TestPyLayer(unittest.TestCase): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py new file mode 100755 index 0000000000000000000000000000000000000000..cc90f0433b92172eb5bf7c74c8e18b885582a88e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py @@ -0,0 +1,221 @@ +# Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.autograd import PyLayer +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.distributed.fleet import recompute_hybrid +import random +from paddle.distributed import fleet + +import paddle.fluid.layers as layers + + +def get_fc_block(block_idx, input_size, is_last=False): + block_name = "block_" + str(block_idx) + block = paddle.nn.Sequential( + (block_name + "_fc_0", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + (block_name + "_relu_1", paddle.nn.ReLU()), + (block_name + "_fc_1", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_relu_2", paddle.nn.ReLU()), + ) + if is_last: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, 1, + bias_attr=False)) # add sublayer + else: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, + input_size, + bias_attr=False)) # add sublayer + return block + + +class Naive_fc_net(paddle.nn.Layer): + + def __init__(self, + input_size=10, + recompute_blocks=[1, 3], + offload=False, + partition=False, + recompute_kwargs={}): + super(Naive_fc_net, self).__init__() + self.recompute_blocks = recompute_blocks + self.recompute_kwargs = recompute_kwargs + self.offload = offload + self.partition = partition + + self.runfunc0 = get_fc_block(0, input_size, is_last=False) + self.runfunc1 = get_fc_block(1, input_size, is_last=False) + self.runfunc2 = get_fc_block(2, input_size, is_last=False) + self.runfunc3 = get_fc_block(3, input_size, is_last=False) + self.runfunc4 = get_fc_block(4, input_size, is_last=True) + + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] + + def forward(self, inputs): + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = recompute_hybrid( + { + "mp_group": fleet.fleet._hcg.get_model_parallel_group(), + "offload": self.offload, + "partition": self.partition + }, self.layers[i], inputs, **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) + + return inputs + + +def run_model(recompute_block=[], + recompute_kwargs={}, + offload=False, + partition=False, + enable_autocast=False, + pure_fp16=False): + gen = paddle.seed(10) + gen.manual_seed(10) + np.random.seed(10) + random.seed(10) + + batch_size, input_size = 1, 10 + model = Naive_fc_net(input_size, + recompute_blocks=recompute_block, + offload=offload, + partition=partition, + recompute_kwargs=recompute_kwargs) + loss_fn = paddle.nn.MSELoss(reduction='mean') + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + if enable_autocast: + scaler = paddle.amp.GradScaler() + scaler = fleet.distributed_scaler(scaler) + + loss_ = [] + param_ = [] + grad_ 
= [] + for step in range(10): + + x_data = np.random.randn(batch_size, input_size).astype(np.float32) + x = paddle.to_tensor(x_data) + # x.stop_gradient = False + level = 'O2' if pure_fp16 else 'O1' + with paddle.amp.auto_cast(True, level=level): + y_pred = model(x) + loss = y_pred.mean() + if enable_autocast: + scaler.scale(loss).backward() + scaler.minimize(optimizer, loss) + else: + loss_.append(np.asarray(loss).tolist()) + loss.backward() + optimizer.step() + + param_.append(np.asarray(model.parameters()[9]).tolist()) + grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) + + optimizer.clear_grad() + return loss_, param_, grad_ + + +class TestPyLayer(unittest.TestCase): + + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 1 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + fleet.init(is_collective=True, strategy=strategy) + + def test_base_case(self, enable_autocast=False, pure_fp16=False): + + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + + # without recompute + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + + # with recompute, offload=False, partition=False + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=False + loss, param, grad = run_model(recompute_block=[1, 2, 3], + offload=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=False, partition=True + loss, param, grad = run_model(recompute_block=[1], + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=True + loss, param, grad = run_model(recompute_block=[1, 3, 4], + offload=True, + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + def test_fc_net_with_dropout(self): + self.test_base_case() + + def test_fc_net_with_amp(self): + self.test_base_case(enable_autocast=True) + + def test_fc_net_with_fp16(self): + self.test_base_case(enable_autocast=True, pure_fp16=True) + + def test_recompute_kwargs(self): + paddle.set_device("gpu") + kwargs = {"is_test": False} + with self.assertRaises(TypeError): + loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], + recompute_kwargs=kwargs) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py index 8773e8d47ed3c0e1fa7b4e4442af0c07c0e6111f..10243a0faa9445ea6d14086d2e7d243be67be70f 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py @@ -26,5 +26,11 @@ class TestPipelineParallel(TestMultipleGpus): self.run_mnist_2gpu('hybrid_parallel_pp_alexnet.py') +class TestModelParallelWithRecompute(TestMultipleGpus): + + def 
test_model_parallel_with_recompute(self): + self.run_mnist_2gpu("dygraph_recompute_hybrid.py") + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/distributed/fleet/__init__.py b/python/paddle/incubate/distributed/fleet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94e1a7c8bbe77bbf763f462d48fc5812243817f9 --- /dev/null +++ b/python/paddle/incubate/distributed/fleet/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.fleet.recompute import recompute_sequential, recompute_hybrid + +__all__ = ["recompute_sequential", "recompute_hybrid"] diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index ebf300abf95457d69c8437c4c3abd93835b88969..086ae2d57d2d233f03273d19ca9015dccff6bad0 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -34,9 +34,9 @@ from paddle.distributed import fleet from paddle.autograd import PyLayer from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .utils import count_by_gate -from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute from paddle import fluid from paddle.fluid.framework import in_dygraph_mode +from paddle.incubate.distributed.fleet import recompute_hybrid def _local_scatter(inp, pos): @@ -424,8 +424,8 @@ class MoELayer(nn.Layer): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = _hp_recompute(experts_fwd, x, fwd_expert_count.numpy(), - self.experts) + x = recompute_hybrid(self.recompute_ctx, experts_fwd, x, + fwd_expert_count.numpy(), self.experts) out_batch_size = inp.shape[0] if len(gate.shape) == 2: diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py index d230b6afca2995f1dad428e95f87335e71cdb49d..f8e3b55aba6213658bfb730202087711e2de6ab8 100644 --- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py +++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py @@ -19,7 +19,7 @@ from paddle.fluid.clip import ClipGradByGlobalNorm from paddle.fluid.initializer import Constant from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.optimizer import Optimizer -from paddle.distributed import get_rank, get_world_size +import paddle.distributed as dist from paddle.distributed.collective import new_group from paddle.fluid.executor import global_scope from paddle.fluid.framework import name_scope @@ -288,8 +288,8 @@ class DistributedFusedLamb(Optimizer): step = self._get_or_create_step() - rank = get_rank() - nranks = get_world_size() + rank = dist.get_rank() + nranks = dist.get_world_size() if self._nproc_per_node is None: nproc_per_node = nranks else: diff --git a/python/setup.py.in b/python/setup.py.in index 
c51a3765da12f6b7fddefea4342b5580417e19e2..084fd5ac042c60f801c07abb1daea1ed95a50a69 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -296,6 +296,7 @@ packages=['paddle', 'paddle.distributed.launch.plugins', 'paddle.distributed.launch.utils', 'paddle.distributed.fleet.base', + 'paddle.distributed.fleet.recompute', 'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.meta_optimizers', 'paddle.distributed.fleet.meta_optimizers.sharding', @@ -380,6 +381,7 @@ packages=['paddle', 'paddle.incubate.optimizer.functional', 'paddle.incubate.autograd', 'paddle.incubate.distributed', + 'paddle.incubate.distributed.fleet', 'paddle.incubate.distributed.models', 'paddle.incubate.distributed.models.moe', 'paddle.incubate.distributed.models.moe.gate',
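
Usage sketch for the new recompute_sequential entry point, which replaces the removed _RecomputeModelWrapper path in fleet/model.py. This is a minimal illustration, assuming a GPU build of Paddle that already contains this change; the toy Sequential, shapes, and segment count below are illustrative only.

import paddle
from paddle.incubate.distributed.fleet import recompute_sequential

# Toy Sequential model; any paddle.nn.Sequential works here (illustrative sizes).
model = paddle.nn.Sequential(
    paddle.nn.Linear(10, 10),
    paddle.nn.ReLU(),
    paddle.nn.Linear(10, 1),
)

x = paddle.randn([4, 10])
x.stop_gradient = False

# Split the Sequential into 2 chunks; each chunk's activations are dropped in the
# forward pass and recomputed during the backward pass (requires a GPU device
# because the forward CUDA RNG state is saved and restored).
y = recompute_sequential({"segments": 2, "preserve_rng_state": True}, model, x)
y.mean().backward()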
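
Similarly, a hedged sketch of recompute_hybrid under model parallelism, mirroring the new dygraph_recompute_hybrid.py test. It assumes two GPUs and that the script is launched with `python -m paddle.distributed.launch --gpus 0,1 <script>.py`; the block and shapes are illustrative only.

import paddle
from paddle.distributed import fleet
from paddle.incubate.distributed.fleet import recompute_hybrid

# Two-way model parallelism, no data or pipeline parallelism (as in the new test).
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

block = paddle.nn.Sequential(
    paddle.nn.Linear(10, 10),
    paddle.nn.ReLU(),
    paddle.nn.Linear(10, 10),
)

x = paddle.randn([4, 10])
x.stop_gradient = False

# ctx carries the model-parallel group plus the offload/partition switches; the
# block's activations are recomputed in backward instead of being kept alive.
# Fetching the group via fleet.fleet._hcg follows the new unit test.
ctx = {
    "mp_group": fleet.fleet._hcg.get_model_parallel_group(),
    "offload": False,
    "partition": False,
}
y = recompute_hybrid(ctx, block, x)
y.mean().backward()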