Unverified · Commit 2e9e65d8 · Authored by wuhuachaocoding · Committed by GitHub

【cherry-pick】update Recompute doc (#47784)

* cherry-pick recompute doc update.

* update.
Parent ff642c68
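The API documented by this change is activation recomputation (gradient checkpointing): a wrapped block frees its intermediate activations during the forward pass and recomputes them in the backward pass to save memory. A minimal usage sketch, assuming a GPU build of Paddle; the block sizes and input shape below are illustrative and not taken from the diff:

    import paddle
    from paddle.distributed.fleet.utils import recompute

    # required: gpu
    # Two small blocks; only the first is wrapped with recompute.
    block0 = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.ReLU())
    block1 = paddle.nn.Linear(10, 1)

    x = paddle.rand(shape=[4, 10], dtype="float32")
    x.stop_gradient = False

    # Forward: block0's activations are dropped and rebuilt during backward.
    h = recompute(block0, x, **{"preserve_rng_state": True})
    loss = block1(h).mean()
    loss.backward()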
@@ -41,16 +41,23 @@ def detach_variable(inputs):

 def check_recompute_necessary(inputs):
-    if not any(input_.stop_gradient == False for input_ in inputs
-               if isinstance(input_, (core.eager.Tensor, paddle.Tensor))):
+    if not any(
+        input_.stop_gradient == False
+        for input_ in inputs
+        if isinstance(input_, (core.eager.Tensor, paddle.Tensor))
+    ):
         logger.warning(
             "[Recompute]: None of the inputs to current recompute block need grad, "
-            "therefore there is NO need to recompute this block in backward !")
+            "therefore there is NO need to recompute this block in backward !"
+        )


 @contextlib.contextmanager
 def swith_rng_state_tracker(rng_state, tracker):
-    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+        get_rng_state_tracker,
+    )

     orig_cuda_rng_state = paddle.get_cuda_rng_state()
     orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()

@@ -64,10 +71,11 @@ def swith_rng_state_tracker(rng_state, tracker):

 class LegacyRecomputeFunction(LegacyPyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )

         # store for recomputing
         ctx.run_function = run_function
@@ -96,30 +104,37 @@ class LegacyRecomputeFunction(LegacyPyLayer):
             cur_device = paddle.get_device()
             if 'gpu:' not in cur_device:
                 raise RuntimeError(
-                    "Recompute with RNG perserve is not support current device: {}."
-                    .format(cur_device))
+                    "Recompute with RNG perserve is not support current device: {}.".format(
+                        cur_device
+                    )
+                )
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
-            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
-            ).get_states_tracker()
+            ctx.fwd_cuda_rng_state_tracker = (
+                get_rng_state_tracker().get_states_tracker()
+            )

         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        ctx.is_fw_autocast = (
+            False if tracer._amp_level == core.AmpLevel.O0 else True
+        )
         if tracer._amp_level == core.AmpLevel.O2:
             ctx.amp_level = 'O2'
         elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
             ctx.amp_level = 'O1'
         else:
-            raise ValueError("unsupported amp level: {}".format(
-                tracer._amp_level))
+            raise ValueError(
+                "unsupported amp level: {}".format(tracer._amp_level)
+            )
         if tracer._amp_dtype == 'float16':
             ctx.amp_dtype = 'float16'
         elif tracer._amp_dtype in ('bfloat16', 'float32'):
             ctx.amp_dtype = 'bfloat16'
         else:
-            raise ValueError("unsupported amp dtype: {}".format(
-                tracer._amp_dtype))
+            raise ValueError(
+                "unsupported amp dtype: {}".format(tracer._amp_dtype)
+            )
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

@@ -129,7 +144,10 @@ class LegacyRecomputeFunction(LegacyPyLayer):
     @staticmethod
     def backward(ctx, *args):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )

         with paddle.fluid.dygraph.guard():
             # TODO need to check the recompute calling is vaild or not
@@ -147,27 +165,31 @@ class LegacyRecomputeFunction(LegacyPyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
-                                             ctx.fwd_cuda_rng_state_tracker):
-                    with paddle.amp.auto_cast(
-                            enable=ctx.is_fw_autocast,
-                            custom_white_list=ctx.amp_white_list,
-                            custom_black_list=ctx.amp_black_list,
-                            level=ctx.amp_level,
-                            dtype=ctx.amp_dtype):
+                with swith_rng_state_tracker(
+                    ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
+                ):
+                    with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list,
+                        level=ctx.amp_level,
+                        dtype=ctx.amp_dtype,
+                    ):
                         detached_inputs = detach_variable(tuple(inputs))
                         outputs = ctx.run_function(*detached_inputs)
             else:
-                with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
-                                          custom_white_list=ctx.amp_white_list,
-                                          custom_black_list=ctx.amp_black_list,
-                                          level=ctx.amp_level,
-                                          dtype=ctx.amp_dtype):
+                with paddle.amp.auto_cast(
+                    enable=ctx.is_fw_autocast,
+                    custom_white_list=ctx.amp_white_list,
+                    custom_black_list=ctx.amp_black_list,
+                    level=ctx.amp_level,
+                    dtype=ctx.amp_dtype,
+                ):
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs)

             if isinstance(outputs, core.VarBase):
-                outputs = (outputs, )
+                outputs = (outputs,)
             assert len(outputs) == len(args)

             # run backward() with only tensor that requires grad

@@ -178,8 +200,10 @@ class LegacyRecomputeFunction(LegacyPyLayer):
             # the following backward_inputs_with_grad is used to avoid this case.
             backward_inputs_with_grad = []
             for i in range(len(outputs)):
-                if isinstance(outputs[i],
-                              core.VarBase) and not outputs[i].stop_gradient:
+                if (
+                    isinstance(outputs[i], core.VarBase)
+                    and not outputs[i].stop_gradient
+                ):
                     forward_outputs_with_grad.append(outputs[i])
                     backward_inputs_with_grad.append(args[i])
@@ -190,19 +214,24 @@ class LegacyRecomputeFunction(LegacyPyLayer):
             # actually backward
             with paddle.amp.auto_cast(enable=False):
-                paddle.autograd.backward(forward_outputs_with_grad,
-                                         backward_inputs_with_grad)
+                paddle.autograd.backward(
+                    forward_outputs_with_grad, backward_inputs_with_grad
+                )

-            grads = list(inp._grad_ivar() for inp in detached_inputs
-                         if isinstance(inp, core.VarBase))
+            grads = list(
+                inp._grad_ivar()
+                for inp in detached_inputs
+                if isinstance(inp, core.VarBase)
+            )
             return grads


 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )

         # store for recomputing
         ctx.run_function = run_function
@@ -232,30 +261,37 @@ class RecomputeFunction(PyLayer):
             cur_device = paddle.get_device()
             if 'gpu:' not in cur_device:
                 raise RuntimeError(
-                    "Recompute with RNG perserve is not support current device: {}."
-                    .format(cur_device))
+                    "Recompute with RNG perserve is not support current device: {}.".format(
+                        cur_device
+                    )
+                )
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
-            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
-            ).get_states_tracker()
+            ctx.fwd_cuda_rng_state_tracker = (
+                get_rng_state_tracker().get_states_tracker()
+            )

         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        ctx.is_fw_autocast = (
+            False if tracer._amp_level == core.AmpLevel.O0 else True
+        )
         if tracer._amp_level == core.AmpLevel.O2:
             ctx.amp_level = 'O2'
         elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
             ctx.amp_level = 'O1'
         else:
-            raise ValueError("unsupported amp level: {}".format(
-                tracer._amp_level))
+            raise ValueError(
+                "unsupported amp level: {}".format(tracer._amp_level)
+            )
         if tracer._amp_dtype == 'float16':
             ctx.amp_dtype = 'float16'
         elif tracer._amp_dtype in ('bfloat16', 'float32'):
             ctx.amp_dtype = 'bfloat16'
         else:
-            raise ValueError("unsupported amp dtype: {}".format(
-                tracer._amp_dtype))
+            raise ValueError(
+                "unsupported amp dtype: {}".format(tracer._amp_dtype)
+            )
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

@@ -265,7 +301,10 @@ class RecomputeFunction(PyLayer):
     @staticmethod
     def backward(ctx, *args):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )

         with paddle.fluid.dygraph.guard():
             # TODO need to check the recompute calling is vaild or not
@@ -283,28 +322,33 @@ class RecomputeFunction(PyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
-                                             ctx.fwd_cuda_rng_state_tracker):
-                    with paddle.amp.auto_cast(
-                            enable=ctx.is_fw_autocast,
-                            custom_white_list=ctx.amp_white_list,
-                            custom_black_list=ctx.amp_black_list,
-                            level=ctx.amp_level,
-                            dtype=ctx.amp_dtype):
+                with swith_rng_state_tracker(
+                    ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
+                ):
+                    with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list,
+                        level=ctx.amp_level,
+                        dtype=ctx.amp_dtype,
+                    ):
                         detached_inputs = detach_variable(tuple(inputs))
-                        outputs = ctx.run_function(*detached_inputs,
-                                                   **ctx.kwargs)
+                        outputs = ctx.run_function(
+                            *detached_inputs, **ctx.kwargs
+                        )
             else:
-                with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
-                                          custom_white_list=ctx.amp_white_list,
-                                          custom_black_list=ctx.amp_black_list,
-                                          level=ctx.amp_level,
-                                          dtype=ctx.amp_dtype):
+                with paddle.amp.auto_cast(
+                    enable=ctx.is_fw_autocast,
+                    custom_white_list=ctx.amp_white_list,
+                    custom_black_list=ctx.amp_black_list,
+                    level=ctx.amp_level,
+                    dtype=ctx.amp_dtype,
+                ):
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

             if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
-                outputs = (outputs, )
+                outputs = (outputs,)
             assert len(outputs) == len(args)

             # run backward() with only tensor that requires grad

@@ -315,10 +359,10 @@ class RecomputeFunction(PyLayer):
             # the following backward_inputs_with_grad is used to avoid this case.
             backward_inputs_with_grad = []
             for i in range(len(outputs)):
-                if isinstance(
-                        outputs[i],
-                        (core.VarBase,
-                         core.eager.Tensor)) and not outputs[i].stop_gradient:
+                if (
+                    isinstance(outputs[i], (core.VarBase, core.eager.Tensor))
+                    and not outputs[i].stop_gradient
+                ):
                     forward_outputs_with_grad.append(outputs[i])
                     backward_inputs_with_grad.append(args[i])
@@ -329,17 +373,22 @@ class RecomputeFunction(PyLayer):
             # actually backward
             with paddle.amp.auto_cast(enable=False):
-                paddle.autograd.backward(forward_outputs_with_grad,
-                                         backward_inputs_with_grad)
+                paddle.autograd.backward(
+                    forward_outputs_with_grad, backward_inputs_with_grad
+                )

             if in_dygraph_mode():
-                grads = tuple(
-                    inp._grad_ivar() for inp in detached_inputs
-                    if isinstance(inp, (core.VarBase, core.eager.Tensor)))
+                grads = tuple(
+                    inp._grad_ivar()
+                    for inp in detached_inputs
+                    if isinstance(inp, (core.VarBase, core.eager.Tensor))
+                )
             else:
-                grads = list(
-                    inp._grad_ivar() for inp in detached_inputs
-                    if isinstance(inp, (core.VarBase, core.eager.Tensor)))
+                grads = list(
+                    inp._grad_ivar()
+                    for inp in detached_inputs
+                    if isinstance(inp, (core.VarBase, core.eager.Tensor))
+                )
             return grads
@@ -363,13 +412,10 @@ def recompute(function, *args, **kwargs):
     Examples:
         .. code-block:: python

-            import numpy as np
             import paddle
             from paddle.distributed.fleet.utils import recompute
             import random

             # required: gpu

             def get_fc_block(block_idx, input_size, is_last=False):
                 block_name = "block_" + str(block_idx)
                 block = paddle.nn.Sequential(

@@ -391,10 +437,7 @@ def recompute(function, *args, **kwargs):
                         block_name + "_fc_2",
                         paddle.nn.Linear(input_size, input_size, bias_attr=False)
                     )
                 return block

             class Naive_fc_net(paddle.nn.Layer):
                 def __init__(self, input_size=10,
                              recompute_blocks=[1, 3],

@@ -408,7 +451,6 @@ def recompute(function, *args, **kwargs):
                     self.runfunc3 = get_fc_block(3, input_size, is_last=False)
                     self.runfunc4 = get_fc_block(4, input_size, is_last=True)
                     self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]

                 def forward(self, inputs):
                     nums = len(self.total_func)
                     for i in range(nums):

@@ -417,15 +459,12 @@ def recompute(function, *args, **kwargs):
                         else:
                             inputs = self.total_func[i](inputs)
                     return inputs

             def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
                 gen = paddle.seed(10)
                 gen.manual_seed(10)
-                np.random.seed(10)
                 random.seed(10)

                 if cuda_state:
                     paddle.set_cuda_rng_state(cuda_state)

                 batch_size, input_size = 1, 10
                 model = Naive_fc_net(
                     input_size,

@@ -436,29 +475,24 @@ def recompute(function, *args, **kwargs):
                 param_ = []
                 grad_ = []

                 for _ in range(5):
-                    x_data = np.random.randn(batch_size, input_size).astype(np.float32)
-                    x = paddle.to_tensor(x_data)
+                    x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
                     y_pred = model(x)
                     loss = y_pred.mean()
-                    loss_.append(np.asarray(loss).tolist())
+                    loss_.append(loss.item())
                     loss.backward()
                     optimizer.step()
-                    param_.append(np.asarray(model.parameters()[9]).tolist())
-                    grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())
+                    param_.append(model.parameters()[9])
+                    grad_.append(model.parameters()[3]._grad_ivar())
                     optimizer.clear_grad()

                 return loss_, param_, grad_

             cuda_state = paddle.get_cuda_rng_state()
             # without recompute
             loss_ref, param_ref, grad_ref = run_model(
                 cuda_state, recompute_block=[]
             )

             loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])

             print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
             # The result of the recompute_loss should be the same as the normal_loss.
     """
     # Hack to mix *args with **kwargs in a python 2.7-compliant way
     preserve = kwargs.pop('preserve_rng_state', True)
@@ -497,7 +531,6 @@ def recompute_sequential(ctx, functions, *args, **kwargs):
     preserve_rng_state = ctx.get('preserve_rng_state', True)

     def _run_func(begin, end, funcs):
         def do_run(input):
             for i in range(begin, end + 1):
                 input = funcs[i](input)

@@ -513,8 +546,10 @@ def recompute_sequential(ctx, functions, *args, **kwargs):
     end = -1
     for begin in range(0, segment_size * (segments - 1), segment_size):
         end = begin + segment_size - 1
-        args = recompute(_run_func(begin, end, functions),
-                         *args,
-                         preserve_rng_state=preserve_rng_state,
-                         **kwargs)
+        args = recompute(
+            _run_func(begin, end, functions),
+            *args,
+            preserve_rng_state=preserve_rng_state,
+            **kwargs
+        )
     return _run_func(end + 1, len(functions) - 1, functions)(args)
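The recompute_sequential helper shown above splits a sequence of layers into segments and wraps every segment except the last in recompute. A usage sketch under stated assumptions: the import path and the "segments" key in the ctx dict follow the released fleet.utils API (only "preserve_rng_state" is visible in this hunk), and a paddle.nn.Sequential is passed as the functions argument, since the helper indexes it and takes its length:

    import paddle
    from paddle.distributed.fleet.utils import recompute_sequential  # assumed import path

    # required: gpu
    model = paddle.nn.Sequential(
        paddle.nn.Linear(10, 10),
        paddle.nn.ReLU(),
        paddle.nn.Linear(10, 10),
        paddle.nn.ReLU(),
        paddle.nn.Linear(10, 1),
    )
    x = paddle.rand(shape=[4, 10], dtype="float32")
    x.stop_gradient = False

    # {"segments": 2} splits the five layers into two chunks: the first chunk
    # runs under recompute, the trailing chunk runs normally (see the loop above).
    y = recompute_sequential({"segments": 2}, model, x)
    y.mean().backward()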
@@ -22,14 +22,108 @@ import paddle
 from . import log_util  # noqa: F401
 from . import hybrid_parallel_util  # noqa: F401

-__all__ = [  #noqa
-    "LocalFS", "recompute", "DistributedInfer", "HDFSClient"
-]
+__all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"]  # noqa


+@deprecated(since="2.4.0",
+            update_to="paddle.distributed.fleet.recompute",
+            level=1,
+            reason="Please use new recompute API(fleet.recompute) ")
 def recompute(function, *args, **kwargs):
+    """
+    recompute intermediate activations to save then memory.
+
+    Parameters:
+        function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
+              whose intermediate activations will be released to save memory in forward stage and will be recomputed
+              in backward stage for gradient calculation.
+        *args(Tensor): inputs to the function.
+        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
+              indicate whether to save the forward rng. If it is True, then the last forward rng value will be
+              restored when the forward recalculation of backpropagation is performed. The default
+              preserve_rng_state is True.
+
+    Returns:
+        Output of function on args.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.distributed.fleet.utils import recompute
+            import random
+
+            # required: gpu
+
+            def get_fc_block(block_idx, input_size, is_last=False):
+                block_name = "block_" + str(block_idx)
+                block = paddle.nn.Sequential(
+                    (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
+                    (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
+                    (block_name + "_relu_1", paddle.nn.ReLU()),
+                    (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
+                    (block_name + "_relu_2", paddle.nn.ReLU()),
+                )
+                if is_last:
+                    block.add_sublayer(
+                        block_name + "_fc_2",
+                        paddle.nn.Linear(
+                            input_size, 1, bias_attr=False
+                        )
+                    )
+                else:
+                    block.add_sublayer(
+                        block_name + "_fc_2",
+                        paddle.nn.Linear(input_size, input_size, bias_attr=False)
+                    )
+                return block
+
+            class Naive_fc_net(paddle.nn.Layer):
+                def __init__(self, input_size=10,
+                             recompute_blocks=[1, 3],
+                             recompute_kwargs={}):
+                    super(Naive_fc_net, self).__init__()
+                    self.recompute_blocks = recompute_blocks
+                    self.recompute_kwargs = recompute_kwargs
+                    self.runfunc0 = get_fc_block(0, input_size, is_last=False)
+                    self.runfunc1 = get_fc_block(1, input_size, is_last=False)
+                    self.runfunc2 = get_fc_block(2, input_size, is_last=False)
+                    self.runfunc3 = get_fc_block(3, input_size, is_last=False)
+                    self.runfunc4 = get_fc_block(4, input_size, is_last=True)
+                    self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]
+
+                def forward(self, inputs):
+                    nums = len(self.total_func)
+                    for i in range(nums):
+                        if i in self.recompute_blocks:
+                            inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True})
+                        else:
+                            inputs = self.total_func[i](inputs)
+                    return inputs
+
+            def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
+                gen = paddle.seed(10)
+                gen.manual_seed(10)
+                random.seed(10)
+
+                if cuda_state:
+                    paddle.set_cuda_rng_state(cuda_state)
+
+                batch_size, input_size = 1, 10
+                model = Naive_fc_net(
+                    input_size,
+                    recompute_blocks=recompute_block,
+                    recompute_kwargs=recompute_kwargs)
+                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+                loss_ = []
+                param_ = []
+                grad_ = []
+
+                for _ in range(5):
+                    x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
+                    y_pred = model(x)
+                    loss = y_pred.mean()
+                    loss_.append(loss.item())
+                    loss.backward()
+                    optimizer.step()
+                    param_.append(model.parameters()[9])
+                    grad_.append(model.parameters()[3]._grad_ivar())
+                    optimizer.clear_grad()
+
+                return loss_, param_, grad_
+
+            cuda_state = paddle.get_cuda_rng_state()
+            # without recompute
+            loss_ref, param_ref, grad_ref = run_model(
+                cuda_state, recompute_block=[]
+            )
+
+            loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
+
+            print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
+            # The result of the recompute_loss should be the same as the normal_loss.
+
+    """
     return fleet.recompute.recompute(function, *args, **kwargs)
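The wrapper above is kept for backward compatibility: its body forwards to fleet.recompute.recompute, the location that the @deprecated metadata names as the replacement. A small sketch of calling through that path directly; the layer and input are illustrative, and a GPU device is assumed because preserve_rng_state defaults to True:

    import paddle
    from paddle.distributed import fleet

    # required: gpu
    layer = paddle.nn.Linear(8, 8)
    x = paddle.rand(shape=[2, 8], dtype="float32")
    x.stop_gradient = False

    # Same call the deprecated fleet.utils wrapper makes internally,
    # but without the deprecation warning.
    y = fleet.recompute.recompute(layer, x)
    y.mean().backward()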