未验证 提交 4bced24a 编写于 作者: W wuhuachaocoding 提交者: GitHub

Recompute unify incubate (#46073) (#46210)

上级 be84cac7
...@@ -49,8 +49,9 @@ from functools import partial ...@@ -49,8 +49,9 @@ from functools import partial
import paddle import paddle
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str from ...utils.log_util import logger, layer_to_str
from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting from paddle.distributed import fleet
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.incubate.distributed.fleet import recompute_hybrid
__all__ = [] __all__ = []
...@@ -309,13 +310,13 @@ class PipelineLayer(Layer): ...@@ -309,13 +310,13 @@ class PipelineLayer(Layer):
self._loss_fn = loss_fn self._loss_fn = loss_fn
self._topo = topology self._topo = topology
self._recompute_interval = recompute_interval self._recompute_interval = recompute_interval
self.recompute_ctx = recompute_ctx
if recompute_interval > 0: if recompute_interval > 0:
assert recompute_ctx is not None, "recompute_ctx must be not None for recompute." assert recompute_ctx is not None, "recompute_ctx must be not None for recompute."
offload = recompute_ctx.get('offload', False) offload = recompute_ctx.get('offload', False)
partition = recompute_ctx.get('partition', False) partition = recompute_ctx.get('partition', False)
_initialize_recompute_setting(offload, partition)
logger.info( logger.info(
"Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}" "Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}"
.format(offload, partition)) .format(offload, partition))
...@@ -638,7 +639,8 @@ class PipelineLayer(Layer): ...@@ -638,7 +639,8 @@ class PipelineLayer(Layer):
input = (input, ) input = (input, )
if self._need_recompute(funcs, input): if self._need_recompute(funcs, input):
input = _hp_recompute( input = recompute_hybrid(
self.recompute_ctx,
self.forward_function(start_idx, end_idx), *input) self.forward_function(start_idx, end_idx), *input)
else: else:
input = self.forward_function(start_idx, end_idx)(*input) input = self.forward_function(start_idx, end_idx)(*input)
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
from .parallel_layers.pp_layers import PipelineLayer from .parallel_layers.pp_layers import PipelineLayer
from ..utils.hybrid_parallel_util import broadcast_mp_parameters from ..utils.hybrid_parallel_util import broadcast_mp_parameters
...@@ -61,8 +60,6 @@ class PipelineParallel(MetaParallelBase): ...@@ -61,8 +60,6 @@ class PipelineParallel(MetaParallelBase):
p2p.initialize_p2p_groups(hcg, self._using_cache) p2p.initialize_p2p_groups(hcg, self._using_cache)
_initialize_recompute_hcg(hcg)
self.global_rank = self._hcg.get_global_rank() self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0 self.micro_batch_id = 0
......
...@@ -12,6 +12,4 @@ ...@@ -12,6 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .utils import get_tensor_bytes
__all__ = [] __all__ = []
...@@ -13,12 +13,12 @@ ...@@ -13,12 +13,12 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
from .utils import paddle_2_number, number_2_dtype
from ...utils.log_util import logger from ...utils.log_util import logger
import numpy as np import numpy as np
from paddle import _C_ops, _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
from .utils import paddle_2_number, paddle_2_number, number_2_dtype
_hcg = None _hcg = None
_use_cache = False _use_cache = False
......
...@@ -12,16 +12,9 @@ ...@@ -12,16 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle import _C_ops, _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.autograd import PyLayer
from paddle.fluid import framework
from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
from ..parallel_layers.random import get_rng_state_tracker
from paddle.fluid.framework import in_dygraph_mode
__all__ = [] __all__ = []
...@@ -88,23 +81,6 @@ def get_tensor_bytes(tensor): ...@@ -88,23 +81,6 @@ def get_tensor_bytes(tensor):
return tensor.numel() * elem_size return tensor.numel() * elem_size
_hcg = None
_recompute_offload = False
_recompute_partition = False
def _initialize_recompute_setting(is_offload, is_partition):
global _recompute_offload, _recompute_partition
_recompute_offload = is_offload
_recompute_partition = is_partition
def _initialize_recompute_hcg(hcg):
global _hcg
_hcg = hcg
def _all_gather(tensor, group=None, use_calc_stream=True): def _all_gather(tensor, group=None, use_calc_stream=True):
""" """
The main difference with paddle.distributed.all_gather: The main difference with paddle.distributed.all_gather:
...@@ -117,187 +93,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True): ...@@ -117,187 +93,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
).nranks if group is None else group.nranks ).nranks if group is None else group.nranks
return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks) 'ring_id', ring_id, 'nranks', nranks)
def _split_activation(tensor):
global _hcg
mp_degree = _hcg.get_model_parallel_world_size()
mp_rank = _hcg.get_model_parallel_rank()
if mp_degree < 2:
return tensor
tensor_numel = paddle.numel(tensor)
assert tensor_numel != 0, "can't recompute zero element"
assert tensor_numel % mp_degree == 0, "The capacity of the activation () cannot be divisible by mp_degree()".format(
tensor_numel, mp_degree)
# use inplace operation to save memory
data = tensor.flatten_()
part_size = tensor_numel // mp_degree
start = part_size * mp_rank
end = start + part_size
return data[start:end]
def _merge_activation(tensor):
global _hcg
mp_degree = _hcg.get_model_parallel_world_size()
mp_rank = _hcg.get_model_parallel_rank()
mp_group = _hcg.get_model_parallel_group()
if mp_degree < 2:
return tensor
return _all_gather(tensor, group=mp_group)
class _HPRecomputeFunction(PyLayer):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
2. Offload support for activation
3. Support MP segmentation of activation to further reduce cuda memory
4. Adapt to the random state of MP
"""
@staticmethod
def forward(ctx, run_function, all_outputs, *args):
check_recompute_necessary(args)
# store for recomputing
ctx.run_function = run_function
# store the rng states
ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
).get_states_tracker()
# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
ctx.tensor_shapes = []
tensor_inputs = []
cur_device = paddle.get_device()
assert 'gpu:' in paddle.get_device(
), "Recompute with RNG is not support current device: {}.".format(
cur_device)
# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
with paddle.no_grad():
outputs = run_function(*args)
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
state = arg.stop_gradient
if _recompute_partition:
ctx.tensor_shapes.append(arg.shape)
partition = _split_activation(arg.detach()).clone()
# TODO(shenliang03) not use calculate stream to D2H to speed
arg = partition.cpu() if _recompute_offload else partition
else:
arg = arg.cpu() if _recompute_offload else arg
arg.stop_gradient = state
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
if paddle.is_tensor(outputs):
all_outputs += [outputs]
return outputs
else:
all_outputs += outputs
return tuple(outputs)
@staticmethod
def backward(ctx, *args):
with paddle.fluid.dygraph.guard():
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensor_shapes = ctx.tensor_shapes
tensors = list(ctx.saved_tensor())
device_id = paddle.distributed.ParallelEnv().device_id
for i, idx in enumerate(tensor_indices):
if _recompute_partition:
state = tensors[i].stop_gradient
tensors[i] = _merge_activation(
tensors[i]).detach().reshape_(tensor_shapes[i])
tensors[i].stop_gradient = state
inputs[idx] = tensors[i].cuda(
device_id) if _recompute_offload else tensors[i]
tracer = framework._dygraph_tracer()
tracer._has_grad = True
# need restore auto_cast state as well as w/b list
with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
outputs = (outputs, )
assert len(outputs) == len(args)
forward_outputs_with_grad = []
backward_inputs = []
for i in range(len(outputs)):
if isinstance(
outputs[i],
(core.VarBase,
core.eager.Tensor)) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs.append(args[i])
if len(forward_outputs_with_grad) == 0:
raise RuntimeError(
"none of output has stop_gradient=False, this recompute() is not necessary"
)
# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = tuple(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, (core.VarBase, core.eager.Tensor)))
return grads
def _hp_recompute(function, *args):
# NODTE(shenliang03)The current hybrid parallel recompute has limitations.
# It cannot handle the following situations:
# 1. The calculation output of recompute, there are tensors that do not require gradients.
# 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach().
# 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor
all_outputs = []
_HPRecomputeFunction.apply(function, all_outputs, *args)
if len(all_outputs) == 1:
return all_outputs[0]
else:
for output in all_outputs:
if paddle.is_tensor(output) and not is_float_tensor(output):
output.stop_gradient = True
return tuple(all_outputs)
...@@ -20,46 +20,9 @@ from .base.topology import ParallelMode ...@@ -20,46 +20,9 @@ from .base.topology import ParallelMode
from .meta_parallel import TensorParallel, model_parallel_random_seed from .meta_parallel import TensorParallel, model_parallel_random_seed
from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer
from paddle.fluid import core from paddle.fluid import core
from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
from paddle.distributed import fleet from paddle.distributed import fleet
class _RecomputeModelWrapper(paddle.nn.Layer):
def __init__(self, model, segments=2, preserve_rng_state=True):
super(_RecomputeModelWrapper, self).__init__()
assert isinstance(model, paddle.nn.Sequential), (
"The model passed to RecomputeModelWrapper must be of type "
"paddle.nn.Sequential.")
self._model = model
self._segments = segments
self._preserve_rng_state = preserve_rng_state
self._layers = list(model.children())
self._segment_size = len(self._layers) // segments
def _run_func(self, begin, end):
def do_run(input):
for i in range(begin, end):
input = self._layers[i](input)
return input
return do_run
def _checkpoint(self, func, *args, **kwargs):
return LegacyRecomputeFunction.apply(func, self._preserve_rng_state,
*args)
def forward(self, input):
end = 0
for begin in range(0, self._segment_size * (self._segments - 1),
self._segment_size):
end = begin + self._segment_size
input = self._checkpoint(self._run_func(begin, end), input)
return self._run_func(end, len(self._layers))(input)
_grad_scalar = None _grad_scalar = None
...@@ -125,7 +88,6 @@ def distributed_model(model): ...@@ -125,7 +88,6 @@ def distributed_model(model):
return model return model
amp_enable = False amp_enable = False
recompute_enable = False
strategy = fleet_env._user_defined_strategy strategy = fleet_env._user_defined_strategy
if strategy.amp == True: if strategy.amp == True:
amp_enable = True amp_enable = True
...@@ -154,10 +116,6 @@ def distributed_model(model): ...@@ -154,10 +116,6 @@ def distributed_model(model):
decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
use_dynamic_loss_scaling=use_dynamic_loss_scaling) use_dynamic_loss_scaling=use_dynamic_loss_scaling)
if strategy.recompute == True:
recompute_enable = True
model = _RecomputeModelWrapper(model)
if strategy.heter_ccl_mode == True: if strategy.heter_ccl_mode == True:
distributed_model = paddle.DataParallel( distributed_model = paddle.DataParallel(
model, model,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .recompute import recompute, recompute_sequential
from .recompute_hybrid import recompute_hybrid
__all__ = []
...@@ -207,12 +207,13 @@ class LegacyRecomputeFunction(LegacyPyLayer): ...@@ -207,12 +207,13 @@ class LegacyRecomputeFunction(LegacyPyLayer):
class RecomputeFunction(PyLayer): class RecomputeFunction(PyLayer):
@staticmethod @staticmethod
def forward(ctx, run_function, preserve_rng_state, *args): def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
# store for recomputing # store for recomputing
ctx.run_function = run_function ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state ctx.preserve_rng_state = preserve_rng_state
ctx.kwargs = kwargs
# NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
# the order of tensors in backward()'s output should be the same as tensors in forward()'s input # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
...@@ -265,7 +266,7 @@ class RecomputeFunction(PyLayer): ...@@ -265,7 +266,7 @@ class RecomputeFunction(PyLayer):
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
with paddle.no_grad(): with paddle.no_grad():
outputs = run_function(*args) outputs = run_function(*args, **kwargs)
return outputs return outputs
@staticmethod @staticmethod
...@@ -297,7 +298,8 @@ class RecomputeFunction(PyLayer): ...@@ -297,7 +298,8 @@ class RecomputeFunction(PyLayer):
level=ctx.amp_level, level=ctx.amp_level,
dtype=ctx.amp_dtype): dtype=ctx.amp_dtype):
detached_inputs = detach_variable(tuple(inputs)) detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs) outputs = ctx.run_function(*detached_inputs,
**ctx.kwargs)
else: else:
with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list, custom_white_list=ctx.amp_white_list,
...@@ -305,7 +307,7 @@ class RecomputeFunction(PyLayer): ...@@ -305,7 +307,7 @@ class RecomputeFunction(PyLayer):
level=ctx.amp_level, level=ctx.amp_level,
dtype=ctx.amp_dtype): dtype=ctx.amp_dtype):
detached_inputs = detach_variable(tuple(inputs)) detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs) outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
if isinstance(outputs, (core.VarBase, core.eager.Tensor)): if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
outputs = (outputs, ) outputs = (outputs, )
...@@ -352,13 +354,13 @@ def recompute(function, *args, **kwargs): ...@@ -352,13 +354,13 @@ def recompute(function, *args, **kwargs):
recompute intermediate activations to save then memory. recompute intermediate activations to save then memory.
Parameters: Parameters:
function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
whose intermediate activations will be released to save memory in forward stage and will be recomputed whose intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation. in backward stage for gradient calculation.
*args(Tensor): inputs to the function. *args(Tensor): inputs to the function.
**kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
indicate whether to save the forward rng. If it is True, then the last forward rng value will be indicate whether to save the forward rng. If it is True, then the last forward rng value will be
restored when the forward recalculation of backpropagation is performed. The default restored when the forward recalculation of backpropagation is performed. The default
preserve_rng_state is True. preserve_rng_state is True.
Returns: Returns:
...@@ -466,11 +468,59 @@ def recompute(function, *args, **kwargs): ...@@ -466,11 +468,59 @@ def recompute(function, *args, **kwargs):
""" """
# Hack to mix *args with **kwargs in a python 2.7-compliant way # Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop('preserve_rng_state', True) preserve = kwargs.pop('preserve_rng_state', True)
if kwargs:
raise ValueError("Unexpected keyword arguments: " +
",".join(arg for arg in kwargs))
if framework._dygraph_tracer()._has_grad: if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args) check_recompute_necessary(args)
return RecomputeFunction.apply(function, preserve, *args) return RecomputeFunction.apply(function, preserve, *args, **kwargs)
def recompute_sequential(ctx, functions, *args, **kwargs):
"""
recompute intermediate activations to save then memory for 'Sequential' models.
Parameters:
ctx(dict): include 'segments' and 'preserve_rng_state' keys, the key 'segments' (int, default 1), represents the number of chunks to create in the model,
the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng. If it is True, then the last forward rng value will be
restored when the forward recalculation of backpropagation is performed. and some keys such as 'mp_group', 'offload' and 'partition' are invalid here,
they are useful in 'recompute_hybrid' API.
functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model
whose intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation.
*args(Tensor): inputs(tuple) to the function.
**kwargs(Dict): inputs(dict) to the function.
Returns:
Output of function on args and kwargs.
Examples:
.. code-block:: python
model = paddle.nn.Sequential(...)
input = recompute_sequential({'segments' : 1}, model, input)
"""
segments = ctx.get('segments', 1)
preserve_rng_state = ctx.get('preserve_rng_state', True)
def _run_func(begin, end, funcs):
def do_run(input):
for i in range(begin, end + 1):
input = funcs[i](input)
return input
return do_run
if isinstance(functions, paddle.nn.Sequential):
functions = list(functions.children())
segment_size = len(functions) // segments
end = -1
for begin in range(0, segment_size * (segments - 1), segment_size):
end = begin + segment_size - 1
args = recompute(_run_func(begin, end, functions),
*args,
preserve_rng_state=preserve_rng_state,
**kwargs)
return _run_func(end + 1, len(functions) - 1, functions)(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core
from paddle.autograd import PyLayer
from paddle.fluid import framework
from ..meta_parallel.parallel_layers.random import get_rng_state_tracker
from paddle.fluid.framework import in_dygraph_mode
from paddle.distributed import fleet
from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
from ..meta_parallel.pp_utils import utils
__all__ = []
def _split_activation(tensor, mp_group):
mp_degree = mp_group.nranks
mp_rank = mp_group.rank
if mp_degree < 2:
return tensor
tensor_numel = paddle.numel(tensor)
assert tensor_numel != 0, "can't recompute zero element"
assert tensor_numel % mp_degree == 0, "The capacity of the activation ({}) cannot be divisible by mp_degree({})".format(
tensor_numel, mp_degree)
# use inplace operation to save memory
data = tensor.flatten_()
part_size = tensor_numel // mp_degree
start = part_size * mp_rank
end = start + part_size
return data[start:end]
def _merge_activation(tensor, mp_group):
mp_degree = mp_group.nranks
mp_rank = mp_group.rank
if mp_degree < 2:
return tensor
# adapt to new dygraph
tensor_shape = list(tensor.shape)
tensor_shape[0] *= mp_group.nranks
out = paddle.empty(tensor_shape, tensor.dtype)
task = mp_group.process_group.all_gather(tensor.cuda(), out)
task.wait()
return out
class _HPRecomputeFunction(PyLayer):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
2. Offload support for activation
3. Support MP segmentation of activation to further reduce cuda memory
4. Adapt to the random state of MP
"""
@staticmethod
def forward(ctx, run_function, all_outputs, mp_group, offload, partition,
*args, **kwargs):
check_recompute_necessary(args)
# store for recomputing
ctx.run_function = run_function
ctx.kwargs = kwargs
# store the rng states
ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
).get_states_tracker()
# save config info
ctx.mp_group = mp_group
ctx.offload = offload
ctx.partition = partition
# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
ctx.tensor_shapes = []
tensor_inputs = []
cur_device = paddle.get_device()
assert 'gpu:' in paddle.get_device(
), "Recompute with RNG is not support current device: {}.".format(
cur_device)
# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
with paddle.no_grad():
outputs = run_function(*args, **kwargs)
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
state = arg.stop_gradient
if partition:
ctx.tensor_shapes.append(arg.shape)
partition = _split_activation(arg.detach(),
mp_group).clone()
# TODO(shenliang03) not use calculate stream to D2H to speed
arg = partition.cpu() if offload else partition
else:
arg = arg.cpu() if offload else arg
arg.stop_gradient = state
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
if paddle.is_tensor(outputs):
all_outputs += [outputs]
return outputs
else:
all_outputs += outputs
return tuple(outputs)
@staticmethod
def backward(ctx, *args):
with paddle.fluid.dygraph.guard():
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensor_shapes = ctx.tensor_shapes
tensors = list(ctx.saved_tensor())
device_id = paddle.distributed.ParallelEnv().device_id
for i, idx in enumerate(tensor_indices):
if ctx.partition:
state = tensors[i].stop_gradient
tensors[i] = _merge_activation(
tensors[i],
ctx.mp_group).detach().reshape_(tensor_shapes[i])
tensors[i].stop_gradient = state
inputs[idx] = tensors[i].cuda(
device_id) if ctx.offload else tensors[i]
tracer = framework._dygraph_tracer()
tracer._has_grad = True
# need restore auto_cast state as well as w/b list
with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
outputs = (outputs, )
assert len(outputs) == len(args)
forward_outputs_with_grad = []
backward_inputs = []
for i in range(len(outputs)):
if isinstance(
outputs[i],
(core.VarBase,
core.eager.Tensor)) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs.append(args[i])
if len(forward_outputs_with_grad) == 0:
raise RuntimeError(
"none of output has stop_gradient=False, this recompute() is not necessary"
)
# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = tuple(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, (core.VarBase, core.eager.Tensor)))
return grads
def recompute_hybrid(ctx, function, *args, **kwargs):
"""
# NODTE(shenliang03)The current hybrid parallel recompute has limitations.
# It cannot handle the following situations:
# 1. The calculation output of recompute, there are tensors that do not require gradients.
# 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach().
# 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor
Parameters:
ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group), represents the avtivations are splitted
in which group. the key 'offload' (bool, optional, default=False), represents whether to offload to cpu. the key 'partition' (bool, optional, default=False),
represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in
'recompute_sequential' API.
function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
whose intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation.
*args(Tensor): inputs(tuple) to the function.
**kwargs(Dict): inputs(dict) to the function.
Returns:
Output of function on args and kwargs.
"""
mp_group = ctx.get('mp_group', None)
assert mp_group is not None, "ctx must contains mp_group and mp_group can not be None."
offload = ctx.get('offload', False)
partition = ctx.get('partition', False)
all_outputs = []
_HPRecomputeFunction.apply(function, all_outputs, mp_group, offload,
partition, *args, **kwargs)
if len(all_outputs) == 1:
return all_outputs[0]
else:
for output in all_outputs:
if paddle.is_tensor(output) and not utils.is_float_tensor(output):
output.stop_gradient = True
return tuple(all_outputs)
...@@ -15,11 +15,21 @@ ...@@ -15,11 +15,21 @@
from .fs import LocalFS # noqa: F401 from .fs import LocalFS # noqa: F401
from .fs import HDFSClient # noqa: F401 from .fs import HDFSClient # noqa: F401
from .ps_util import DistributedInfer # noqa: F401 from .ps_util import DistributedInfer # noqa: F401
from .recompute import recompute # noqa: F401 import paddle.utils.deprecated as deprecated
from paddle.distributed import fleet
import paddle
from . import log_util # noqa: F401 from . import log_util # noqa: F401
from . import hybrid_parallel_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401
__all__ = [ #noqa __all__ = [ #noqa
"LocalFS", "recompute", "DistributedInfer", "HDFSClient" "LocalFS", "recompute", "DistributedInfer", "HDFSClient"
] ]
@deprecated(since="2.4.0",
update_to="paddle.distributed.fleet.recompute",
level=1,
reason="Please use new recompute API(fleet.recompute) ")
def recompute(function, *args, **kwargs):
return fleet.recompute.recompute(function, *args, **kwargs)
...@@ -21,6 +21,7 @@ import paddle ...@@ -21,6 +21,7 @@ import paddle
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils import recompute from paddle.distributed.fleet.utils import recompute
import random import random
from paddle.incubate.distributed.fleet import recompute_sequential
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
...@@ -53,48 +54,66 @@ class Naive_fc_net(paddle.nn.Layer): ...@@ -53,48 +54,66 @@ class Naive_fc_net(paddle.nn.Layer):
def __init__(self, def __init__(self,
input_size=10, input_size=10,
recompute_blocks=[1, 3], recompute_blocks=[1, 3],
use_fleet_sq=False,
segments=1,
use_raw_recompute=False,
recompute_kwargs={}): recompute_kwargs={}):
super(Naive_fc_net, self).__init__() super(Naive_fc_net, self).__init__()
self.recompute_blocks = recompute_blocks self.recompute_blocks = recompute_blocks
self.recompute_kwargs = recompute_kwargs self.recompute_kwargs = recompute_kwargs
self.use_fleet_sq = use_fleet_sq
self.use_raw_recompute = use_raw_recompute
self.segments = segments
self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc0 = get_fc_block(0, input_size, is_last=False)
self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False)
self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False)
self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False)
self.runfunc4 = get_fc_block(4, input_size, is_last=True) self.runfunc4 = get_fc_block(4, input_size, is_last=True)
def forward(self, inputs): if self.use_fleet_sq and not use_raw_recompute:
self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1,
self.runfunc2, self.runfunc3,
self.runfunc4)
if 0 in self.recompute_blocks: self.layers = [
inputs = recompute(self.runfunc0, inputs) self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3,
else: self.runfunc4
inputs = self.runfunc0(inputs) ]
if 1 in self.recompute_blocks: # default segments = 2
inputs = recompute(self.runfunc1, inputs) if use_raw_recompute:
else: self.layers = [
inputs = self.runfunc1(inputs) paddle.nn.Sequential(self.runfunc0, self.runfunc1),
paddle.nn.Sequential(self.runfunc2, self.runfunc3,
self.runfunc4)
]
if 2 in self.recompute_blocks: def forward(self, inputs):
inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs)
else:
inputs = self.runfunc2(inputs)
if 3 in self.recompute_blocks: if self.use_fleet_sq and not self.use_raw_recompute:
inputs = recompute(self.runfunc3, inputs) return recompute_sequential({"segments": self.segments},
else: self.runfuncs, inputs)
inputs = self.runfunc3(inputs)
if 4 in self.recompute_blocks: if self.use_raw_recompute:
inputs = recompute(self.runfunc4, inputs) inputs = recompute(self.layers[0], inputs)
else: return self.layers[1](inputs)
inputs = self.runfunc4(inputs)
for i in range(len(self.layers)):
if i in self.recompute_blocks:
inputs = recompute(self.layers[i], inputs,
**self.recompute_kwargs)
else:
inputs = self.layers[i](inputs)
return inputs return inputs
def run_model(recompute_block=[], def run_model(recompute_block=[],
recompute_kwargs={}, recompute_kwargs={},
use_fleet_sq=False,
use_raw_recompute=False,
segments=1,
enable_autocast=False, enable_autocast=False,
pure_fp16=False): pure_fp16=False):
gen = paddle.seed(10) gen = paddle.seed(10)
...@@ -105,6 +124,9 @@ def run_model(recompute_block=[], ...@@ -105,6 +124,9 @@ def run_model(recompute_block=[],
batch_size, input_size = 1, 10 batch_size, input_size = 1, 10
model = Naive_fc_net(input_size, model = Naive_fc_net(input_size,
recompute_blocks=recompute_block, recompute_blocks=recompute_block,
use_fleet_sq=use_fleet_sq,
use_raw_recompute=use_raw_recompute,
segments=segments,
recompute_kwargs=recompute_kwargs) recompute_kwargs=recompute_kwargs)
loss_fn = paddle.nn.MSELoss(reduction='mean') loss_fn = paddle.nn.MSELoss(reduction='mean')
optimizer = paddle.optimizer.SGD(learning_rate=0.01, optimizer = paddle.optimizer.SGD(learning_rate=0.01,
...@@ -179,6 +201,34 @@ class TestPyLayer(unittest.TestCase): ...@@ -179,6 +201,34 @@ class TestPyLayer(unittest.TestCase):
pure_fp16=pure_fp16) pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second & fourth block using fleet
loss, param, grad = run_model(recompute_block=[1, 3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute using recompute_sequential, segments=1
loss, param, grad = run_model(recompute_block=[],
use_fleet_sq=True,
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# with base recompute, and segments=2
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[],
enable_autocast=enable_autocast,
use_raw_recompute=True,
pure_fp16=pure_fp16)
# recompute using recompute_sequential, segments=2
loss, param, grad = run_model(recompute_block=[],
use_fleet_sq=True,
segments=2,
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_dropout(self): def test_fc_net_with_dropout(self):
self.test_base_case() self.test_base_case()
...@@ -191,7 +241,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -191,7 +241,7 @@ class TestPyLayer(unittest.TestCase):
def test_recompute_kwargs(self): def test_recompute_kwargs(self):
paddle.set_device("gpu") paddle.set_device("gpu")
kwargs = {"is_test": False} kwargs = {"is_test": False}
with self.assertRaises(ValueError): with self.assertRaises(TypeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], loss_ref, param_ref, grad_ref = run_model(recompute_block=[2],
recompute_kwargs=kwargs) recompute_kwargs=kwargs)
......
...@@ -25,6 +25,7 @@ import paddle ...@@ -25,6 +25,7 @@ import paddle
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils import recompute from paddle.distributed.fleet.utils import recompute
import random import random
from paddle.incubate.distributed.fleet import recompute_sequential
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
...@@ -57,48 +58,66 @@ class Naive_fc_net(paddle.nn.Layer): ...@@ -57,48 +58,66 @@ class Naive_fc_net(paddle.nn.Layer):
def __init__(self, def __init__(self,
input_size=10, input_size=10,
recompute_blocks=[1, 3], recompute_blocks=[1, 3],
use_fleet_sq=False,
segments=1,
use_raw_recompute=False,
recompute_kwargs={}): recompute_kwargs={}):
super(Naive_fc_net, self).__init__() super(Naive_fc_net, self).__init__()
self.recompute_blocks = recompute_blocks self.recompute_blocks = recompute_blocks
self.recompute_kwargs = recompute_kwargs self.recompute_kwargs = recompute_kwargs
self.use_fleet_sq = use_fleet_sq
self.use_raw_recompute = use_raw_recompute
self.segments = segments
self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc0 = get_fc_block(0, input_size, is_last=False)
self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False)
self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False)
self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False)
self.runfunc4 = get_fc_block(4, input_size, is_last=True) self.runfunc4 = get_fc_block(4, input_size, is_last=True)
def forward(self, inputs): if self.use_fleet_sq and not use_raw_recompute:
self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1,
self.runfunc2, self.runfunc3,
self.runfunc4)
if 0 in self.recompute_blocks: self.layers = [
inputs = recompute(self.runfunc0, inputs) self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3,
else: self.runfunc4
inputs = self.runfunc0(inputs) ]
if 1 in self.recompute_blocks: # default segments = 2
inputs = recompute(self.runfunc1, inputs) if use_raw_recompute:
else: self.layers = [
inputs = self.runfunc1(inputs) paddle.nn.Sequential(self.runfunc0, self.runfunc1),
paddle.nn.Sequential(self.runfunc2, self.runfunc3,
self.runfunc4)
]
if 2 in self.recompute_blocks: def forward(self, inputs):
inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs)
else:
inputs = self.runfunc2(inputs)
if 3 in self.recompute_blocks: if self.use_fleet_sq and not self.use_raw_recompute:
inputs = recompute(self.runfunc3, inputs) return paddle.incubate.distributed.fleet.recompute_sequential(
else: {"segments": self.segments}, self.runfuncs, inputs)
inputs = self.runfunc3(inputs)
if 4 in self.recompute_blocks: if self.use_raw_recompute:
inputs = recompute(self.runfunc4, inputs) inputs = recompute(self.layers[0], inputs)
else: return self.layers[1](inputs)
inputs = self.runfunc4(inputs)
for i in range(len(self.layers)):
if i in self.recompute_blocks:
inputs = recompute(self.layers[i], inputs,
**self.recompute_kwargs)
else:
inputs = self.layers[i](inputs)
return inputs return inputs
def run_model(recompute_block=[], def run_model(recompute_block=[],
recompute_kwargs={}, recompute_kwargs={},
use_fleet_sq=False,
use_raw_recompute=False,
segments=1,
enable_autocast=False, enable_autocast=False,
pure_fp16=False): pure_fp16=False):
gen = paddle.seed(10) gen = paddle.seed(10)
...@@ -109,6 +128,9 @@ def run_model(recompute_block=[], ...@@ -109,6 +128,9 @@ def run_model(recompute_block=[],
batch_size, input_size = 1, 10 batch_size, input_size = 1, 10
model = Naive_fc_net(input_size, model = Naive_fc_net(input_size,
recompute_blocks=recompute_block, recompute_blocks=recompute_block,
use_fleet_sq=use_fleet_sq,
use_raw_recompute=use_raw_recompute,
segments=segments,
recompute_kwargs=recompute_kwargs) recompute_kwargs=recompute_kwargs)
loss_fn = paddle.nn.MSELoss(reduction='mean') loss_fn = paddle.nn.MSELoss(reduction='mean')
optimizer = paddle.optimizer.SGD(learning_rate=0.01, optimizer = paddle.optimizer.SGD(learning_rate=0.01,
...@@ -183,6 +205,28 @@ class TestPyLayer(unittest.TestCase): ...@@ -183,6 +205,28 @@ class TestPyLayer(unittest.TestCase):
pure_fp16=pure_fp16) pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute_sequential with segments=1 using fleet
loss, param, grad = run_model(recompute_block=[],
use_fleet_sq=True,
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# with base recompute, and segments=2
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[],
enable_autocast=enable_autocast,
use_raw_recompute=True,
pure_fp16=pure_fp16)
# recompute using paddle.incubate.distributed.fleet.recompute_sequential, segments=2
loss, param, grad = run_model(recompute_block=[],
use_fleet_sq=True,
segments=2,
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_dropout(self): def test_fc_net_with_dropout(self):
self.test_base_case() self.test_base_case()
...@@ -201,7 +245,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -201,7 +245,7 @@ class TestPyLayer(unittest.TestCase):
def test_recompute_kwargs(self): def test_recompute_kwargs(self):
paddle.set_device("gpu") paddle.set_device("gpu")
kwargs = {"is_test": False} kwargs = {"is_test": False}
with self.assertRaises(ValueError): with self.assertRaises(TypeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], loss_ref, param_ref, grad_ref = run_model(recompute_block=[2],
recompute_kwargs=kwargs) recompute_kwargs=kwargs)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils import recompute
from paddle.incubate.distributed.fleet import recompute_hybrid
import random
from paddle.distributed import fleet
import paddle.fluid.layers as layers
def get_fc_block(block_idx, input_size, is_last=False):
block_name = "block_" + str(block_idx)
block = paddle.nn.Sequential(
(block_name + "_fc_0",
paddle.nn.Linear(input_size, input_size, bias_attr=False)),
(block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
(block_name + "_relu_1", paddle.nn.ReLU()),
(block_name + "_fc_1",
paddle.nn.Linear(input_size, input_size, bias_attr=False)),
(block_name + "_relu_2", paddle.nn.ReLU()),
)
if is_last:
block.add_sublayer(block_name + "_fc_2",
paddle.nn.Linear(input_size, 1,
bias_attr=False)) # add sublayer
else:
block.add_sublayer(block_name + "_fc_2",
paddle.nn.Linear(input_size,
input_size,
bias_attr=False)) # add sublayer
return block
class Naive_fc_net(paddle.nn.Layer):
def __init__(self,
input_size=10,
recompute_blocks=[1, 3],
offload=False,
partition=False,
recompute_kwargs={}):
super(Naive_fc_net, self).__init__()
self.recompute_blocks = recompute_blocks
self.recompute_kwargs = recompute_kwargs
self.offload = offload
self.partition = partition
self.runfunc0 = get_fc_block(0, input_size, is_last=False)
self.runfunc1 = get_fc_block(1, input_size, is_last=False)
self.runfunc2 = get_fc_block(2, input_size, is_last=False)
self.runfunc3 = get_fc_block(3, input_size, is_last=False)
self.runfunc4 = get_fc_block(4, input_size, is_last=True)
self.layers = [
self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3,
self.runfunc4
]
def forward(self, inputs):
for i in range(len(self.layers)):
if i in self.recompute_blocks:
inputs = recompute_hybrid(
{
"mp_group": fleet.fleet._hcg.get_model_parallel_group(),
"offload": self.offload,
"partition": self.partition
}, self.layers[i], inputs, **self.recompute_kwargs)
else:
inputs = self.layers[i](inputs)
return inputs
def run_model(recompute_block=[],
recompute_kwargs={},
offload=False,
partition=False,
enable_autocast=False,
pure_fp16=False):
gen = paddle.seed(10)
gen.manual_seed(10)
np.random.seed(10)
random.seed(10)
batch_size, input_size = 1, 10
model = Naive_fc_net(input_size,
recompute_blocks=recompute_block,
offload=offload,
partition=partition,
recompute_kwargs=recompute_kwargs)
loss_fn = paddle.nn.MSELoss(reduction='mean')
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
parameters=model.parameters())
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
if enable_autocast:
scaler = paddle.amp.GradScaler()
scaler = fleet.distributed_scaler(scaler)
loss_ = []
param_ = []
grad_ = []
for step in range(10):
x_data = np.random.randn(batch_size, input_size).astype(np.float32)
x = paddle.to_tensor(x_data)
# x.stop_gradient = False
level = 'O2' if pure_fp16 else 'O1'
with paddle.amp.auto_cast(True, level=level):
y_pred = model(x)
loss = y_pred.mean()
if enable_autocast:
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)
else:
loss_.append(np.asarray(loss).tolist())
loss.backward()
optimizer.step()
param_.append(np.asarray(model.parameters()[9]).tolist())
grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())
optimizer.clear_grad()
return loss_, param_, grad_
class TestPyLayer(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
self.data_parallel_size = 1
self.pipeline_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
fleet.init(is_collective=True, strategy=strategy)
def test_base_case(self, enable_autocast=False, pure_fp16=False):
def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
self.assertEqual(loss_ref, loss)
self.assertEqual(param_ref, param)
self.assertEqual(grad_ref, grad)
# without recompute
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
# with recompute, offload=False, partition=False
loss, param, grad = run_model(recompute_block=[1, 3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# with recompute, offload=True, partition=False
loss, param, grad = run_model(recompute_block=[1, 2, 3],
offload=True,
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# with recompute, offload=False, partition=True
loss, param, grad = run_model(recompute_block=[1],
partition=True,
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# with recompute, offload=True, partition=True
loss, param, grad = run_model(recompute_block=[1, 3, 4],
offload=True,
partition=True,
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_dropout(self):
self.test_base_case()
def test_fc_net_with_amp(self):
self.test_base_case(enable_autocast=True)
def test_fc_net_with_fp16(self):
self.test_base_case(enable_autocast=True, pure_fp16=True)
def test_recompute_kwargs(self):
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(TypeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2],
recompute_kwargs=kwargs)
if __name__ == '__main__':
unittest.main()
...@@ -26,5 +26,11 @@ class TestPipelineParallel(TestMultipleGpus): ...@@ -26,5 +26,11 @@ class TestPipelineParallel(TestMultipleGpus):
self.run_mnist_2gpu('hybrid_parallel_pp_alexnet.py') self.run_mnist_2gpu('hybrid_parallel_pp_alexnet.py')
class TestModelParallelWithRecompute(TestMultipleGpus):
def test_model_parallel_with_recompute(self):
self.run_mnist_2gpu("dygraph_recompute_hybrid.py")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.recompute import recompute_sequential, recompute_hybrid
__all__ = ["recompute_sequential", "recompute_hybrid"]
...@@ -34,9 +34,9 @@ from paddle.distributed import fleet ...@@ -34,9 +34,9 @@ from paddle.distributed import fleet
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate
from .utils import count_by_gate from .utils import count_by_gate
from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute
from paddle import fluid from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.incubate.distributed.fleet import recompute_hybrid
def _local_scatter(inp, pos): def _local_scatter(inp, pos):
...@@ -424,8 +424,8 @@ class MoELayer(nn.Layer): ...@@ -424,8 +424,8 @@ class MoELayer(nn.Layer):
if self.recompute_interval <= 0 or x.shape[0] == 0: if self.recompute_interval <= 0 or x.shape[0] == 0:
x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) x = experts_fwd(x, fwd_expert_count.numpy(), self.experts)
else: else:
x = _hp_recompute(experts_fwd, x, fwd_expert_count.numpy(), x = recompute_hybrid(self.recompute_ctx, experts_fwd, x,
self.experts) fwd_expert_count.numpy(), self.experts)
out_batch_size = inp.shape[0] out_batch_size = inp.shape[0]
if len(gate.shape) == 2: if len(gate.shape) == 2:
......
...@@ -19,7 +19,7 @@ from paddle.fluid.clip import ClipGradByGlobalNorm ...@@ -19,7 +19,7 @@ from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.fluid.initializer import Constant from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.optimizer import Optimizer from paddle.fluid.optimizer import Optimizer
from paddle.distributed import get_rank, get_world_size import paddle.distributed as dist
from paddle.distributed.collective import new_group from paddle.distributed.collective import new_group
from paddle.fluid.executor import global_scope from paddle.fluid.executor import global_scope
from paddle.fluid.framework import name_scope from paddle.fluid.framework import name_scope
...@@ -288,8 +288,8 @@ class DistributedFusedLamb(Optimizer): ...@@ -288,8 +288,8 @@ class DistributedFusedLamb(Optimizer):
step = self._get_or_create_step() step = self._get_or_create_step()
rank = get_rank() rank = dist.get_rank()
nranks = get_world_size() nranks = dist.get_world_size()
if self._nproc_per_node is None: if self._nproc_per_node is None:
nproc_per_node = nranks nproc_per_node = nranks
else: else:
......
...@@ -296,6 +296,7 @@ packages=['paddle', ...@@ -296,6 +296,7 @@ packages=['paddle',
'paddle.distributed.launch.plugins', 'paddle.distributed.launch.plugins',
'paddle.distributed.launch.utils', 'paddle.distributed.launch.utils',
'paddle.distributed.fleet.base', 'paddle.distributed.fleet.base',
'paddle.distributed.fleet.recompute',
'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.elastic',
'paddle.distributed.fleet.meta_optimizers', 'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.meta_optimizers.sharding', 'paddle.distributed.fleet.meta_optimizers.sharding',
...@@ -380,6 +381,7 @@ packages=['paddle', ...@@ -380,6 +381,7 @@ packages=['paddle',
'paddle.incubate.optimizer.functional', 'paddle.incubate.optimizer.functional',
'paddle.incubate.autograd', 'paddle.incubate.autograd',
'paddle.incubate.distributed', 'paddle.incubate.distributed',
'paddle.incubate.distributed.fleet',
'paddle.incubate.distributed.models', 'paddle.incubate.distributed.models',
'paddle.incubate.distributed.models.moe', 'paddle.incubate.distributed.models.moe',
'paddle.incubate.distributed.models.moe.gate', 'paddle.incubate.distributed.models.moe.gate',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册