Unverified commit 2e9e65d8, authored by wuhuachaocoding, committed by GitHub

[cherry-pick] update Recompute doc (#47784)

* cherry-pick recompute doc update.

* update.
Parent ff642c68
@@ -41,16 +41,23 @@ def detach_variable(inputs):
 def check_recompute_necessary(inputs):
-    if not any(input_.stop_gradient == False for input_ in inputs
-               if isinstance(input_, (core.eager.Tensor, paddle.Tensor))):
+    if not any(
+        input_.stop_gradient == False
+        for input_ in inputs
+        if isinstance(input_, (core.eager.Tensor, paddle.Tensor))
+    ):
         logger.warning(
             "[Recompute]: None of the inputs to the current recompute block needs grad, "
-            "therefore there is no need to recompute this block in backward!")
+            "therefore there is no need to recompute this block in backward!"
+        )
 
 
 @contextlib.contextmanager
 def swith_rng_state_tracker(rng_state, tracker):
-    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+        get_rng_state_tracker,
+    )
 
     orig_cuda_rng_state = paddle.get_cuda_rng_state()
     orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
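The context manager above follows a save/switch/restore pattern: record the live CUDA RNG state, install the state captured at forward time, and restore the original on exit. A minimal sketch of the same idea, covering only the global CUDA RNG state (assumes a CUDA device; switch_cuda_rng_state is a hypothetical helper, not a Paddle API):

.. code-block:: python

    import contextlib
    import paddle

    @contextlib.contextmanager
    def switch_cuda_rng_state(state):
        # save the live state, install the recorded one, always restore on exit
        orig = paddle.get_cuda_rng_state()
        try:
            paddle.set_cuda_rng_state(state)
            yield
        finally:
            paddle.set_cuda_rng_state(orig)

Replaying the recorded state this way is what makes stochastic layers such as Dropout produce identical masks in the original and the recomputed forward pass.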
@@ -64,10 +71,11 @@ def swith_rng_state_tracker(rng_state, tracker):
 
 class LegacyRecomputeFunction(LegacyPyLayer):
-
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )
 
         # store for recomputing
         ctx.run_function = run_function
@@ -96,30 +104,37 @@ class LegacyRecomputeFunction(LegacyPyLayer):
             cur_device = paddle.get_device()
             if 'gpu:' not in cur_device:
                 raise RuntimeError(
-                    "Recompute with RNG preserve is not supported on the current device: {}."
-                    .format(cur_device))
+                    "Recompute with RNG preserve is not supported on the current device: {}.".format(
+                        cur_device
+                    )
+                )
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
-            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
-            ).get_states_tracker()
+            ctx.fwd_cuda_rng_state_tracker = (
+                get_rng_state_tracker().get_states_tracker()
+            )
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        ctx.is_fw_autocast = (
+            False if tracer._amp_level == core.AmpLevel.O0 else True
+        )
         if tracer._amp_level == core.AmpLevel.O2:
             ctx.amp_level = 'O2'
         elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
             ctx.amp_level = 'O1'
         else:
-            raise ValueError("unsupported amp level: {}".format(
-                tracer._amp_level))
+            raise ValueError(
+                "unsupported amp level: {}".format(tracer._amp_level)
+            )
 
         if tracer._amp_dtype == 'float16':
             ctx.amp_dtype = 'float16'
         elif tracer._amp_dtype in ('bfloat16', 'float32'):
             ctx.amp_dtype = 'bfloat16'
         else:
-            raise ValueError("unsupported amp dtype: {}".format(
-                tracer._amp_dtype))
+            raise ValueError(
+                "unsupported amp dtype: {}".format(tracer._amp_dtype)
+            )
 
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
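The block above snapshots the tracer's autocast configuration so the backward-time replay can reproduce it exactly. A minimal sketch of that capture-and-replay idea (the names saved_amp and replay_forward are illustrative, not Paddle APIs; the captured values are examples):

.. code-block:: python

    import paddle

    # settings captured at forward time (illustrative values)
    saved_amp = dict(
        enable=True,
        custom_white_list=None,
        custom_black_list=None,
        level='O1',
        dtype='float16',
    )

    def replay_forward(fn, *args):
        # rerun the forward under the same autocast configuration
        with paddle.amp.auto_cast(**saved_amp):
            return fn(*args)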
@@ -129,7 +144,10 @@ class LegacyRecomputeFunction(LegacyPyLayer):
 
     @staticmethod
     def backward(ctx, *args):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )
 
         with paddle.fluid.dygraph.guard():
             # TODO: need to check whether the recompute call is valid or not
@@ -147,27 +165,31 @@ class LegacyRecomputeFunction(LegacyPyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
-                                             ctx.fwd_cuda_rng_state_tracker):
+                with swith_rng_state_tracker(
+                    ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
+                ):
                     with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
                         custom_white_list=ctx.amp_white_list,
                         custom_black_list=ctx.amp_black_list,
                         level=ctx.amp_level,
-                        dtype=ctx.amp_dtype):
+                        dtype=ctx.amp_dtype,
+                    ):
                         detached_inputs = detach_variable(tuple(inputs))
                         outputs = ctx.run_function(*detached_inputs)
             else:
-                with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
+                with paddle.amp.auto_cast(
+                    enable=ctx.is_fw_autocast,
                     custom_white_list=ctx.amp_white_list,
                     custom_black_list=ctx.amp_black_list,
                     level=ctx.amp_level,
-                    dtype=ctx.amp_dtype):
+                    dtype=ctx.amp_dtype,
+                ):
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs)
 
             if isinstance(outputs, core.VarBase):
-                outputs = (outputs, )
+                outputs = (outputs,)
 
             assert len(outputs) == len(args)
 
             # run backward() with only tensor that requires grad
@@ -178,8 +200,10 @@ class LegacyRecomputeFunction(LegacyPyLayer):
             # the following backward_inputs_with_grad is used to avoid this case.
             backward_inputs_with_grad = []
             for i in range(len(outputs)):
-                if isinstance(outputs[i],
-                              core.VarBase) and not outputs[i].stop_gradient:
+                if (
+                    isinstance(outputs[i], core.VarBase)
+                    and not outputs[i].stop_gradient
+                ):
                     forward_outputs_with_grad.append(outputs[i])
                     backward_inputs_with_grad.append(args[i])
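This filtering pairs each differentiable output with its incoming gradient before the actual backward call. A standalone sketch of the same pattern (tensor names are illustrative):

.. code-block:: python

    import paddle

    x = paddle.rand([3])
    x.stop_gradient = False
    a = x * 2.0               # depends on x, can carry gradient
    b = paddle.rand([3])      # stop_gradient=True, cannot
    outs = [a, b]
    incoming = [paddle.ones([3]), paddle.ones([3])]

    # keep only (output, incoming grad) pairs that can propagate gradient,
    # mirroring forward_outputs_with_grad / backward_inputs_with_grad above
    pairs = [(o, g) for o, g in zip(outs, incoming) if not o.stop_gradient]
    fwd, grads_in = map(list, zip(*pairs))
    paddle.autograd.backward(fwd, grads_in)
    print(x.grad)             # gradient flowed only through `a`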
@@ -190,19 +214,24 @@ class LegacyRecomputeFunction(LegacyPyLayer):
 
             # actually backward
             with paddle.amp.auto_cast(enable=False):
-                paddle.autograd.backward(forward_outputs_with_grad,
-                                         backward_inputs_with_grad)
+                paddle.autograd.backward(
+                    forward_outputs_with_grad, backward_inputs_with_grad
+                )
 
-            grads = list(inp._grad_ivar() for inp in detached_inputs
-                         if isinstance(inp, core.VarBase))
+            grads = list(
+                inp._grad_ivar()
+                for inp in detached_inputs
+                if isinstance(inp, core.VarBase)
+            )
             return grads
 
 
 class RecomputeFunction(PyLayer):
-
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )
 
         # store for recomputing
         ctx.run_function = run_function
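Stepping back, RecomputeFunction implements the classic activation-checkpointing contract: forward runs the block without keeping activations, backward replays the block and routes the incoming gradients through the replayed graph. A deliberately simplified sketch of that contract (no RNG or AMP handling, single tensor in and out; SimpleRecompute is illustrative, not part of Paddle):

.. code-block:: python

    import paddle
    from paddle.autograd import PyLayer

    class SimpleRecompute(PyLayer):
        @staticmethod
        def forward(ctx, run_function, x):
            ctx.run_function = run_function
            ctx.save_for_backward(x)
            with paddle.no_grad():      # no graph kept: activations are dropped
                return run_function(x)

        @staticmethod
        def backward(ctx, dy):
            (x,) = ctx.saved_tensor()
            x = x.detach()
            x.stop_gradient = False
            with paddle.set_grad_enabled(True):   # replay with grad tracking
                y = ctx.run_function(x)
            paddle.autograd.backward([y], [dy])
            return x.grad

    layer = paddle.nn.Linear(8, 8)
    x = paddle.rand([2, 8])
    x.stop_gradient = False
    y = SimpleRecompute.apply(layer, x)
    y.mean().backward()

The real implementation below layers RNG-state replay, AMP replay, and multi-input/multi-output handling on top of this same skeleton.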
@@ -232,30 +261,37 @@ class RecomputeFunction(PyLayer):
             cur_device = paddle.get_device()
             if 'gpu:' not in cur_device:
                 raise RuntimeError(
-                    "Recompute with RNG preserve is not supported on the current device: {}."
-                    .format(cur_device))
+                    "Recompute with RNG preserve is not supported on the current device: {}.".format(
+                        cur_device
+                    )
+                )
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
-            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
-            ).get_states_tracker()
+            ctx.fwd_cuda_rng_state_tracker = (
+                get_rng_state_tracker().get_states_tracker()
+            )
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        ctx.is_fw_autocast = (
+            False if tracer._amp_level == core.AmpLevel.O0 else True
+        )
         if tracer._amp_level == core.AmpLevel.O2:
             ctx.amp_level = 'O2'
         elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
             ctx.amp_level = 'O1'
         else:
-            raise ValueError("unsupported amp level: {}".format(
-                tracer._amp_level))
+            raise ValueError(
+                "unsupported amp level: {}".format(tracer._amp_level)
+            )
 
         if tracer._amp_dtype == 'float16':
             ctx.amp_dtype = 'float16'
         elif tracer._amp_dtype in ('bfloat16', 'float32'):
             ctx.amp_dtype = 'bfloat16'
         else:
-            raise ValueError("unsupported amp dtype: {}".format(
-                tracer._amp_dtype))
+            raise ValueError(
+                "unsupported amp dtype: {}".format(tracer._amp_dtype)
+            )
 
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
@@ -265,7 +301,10 @@ class RecomputeFunction(PyLayer):
 
     @staticmethod
    def backward(ctx, *args):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )
 
         with paddle.fluid.dygraph.guard():
             # TODO: need to check whether the recompute call is valid or not
@@ -283,28 +322,33 @@ class RecomputeFunction(PyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
-                                             ctx.fwd_cuda_rng_state_tracker):
+                with swith_rng_state_tracker(
+                    ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
+                ):
                     with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
                         custom_white_list=ctx.amp_white_list,
                         custom_black_list=ctx.amp_black_list,
                         level=ctx.amp_level,
-                        dtype=ctx.amp_dtype):
+                        dtype=ctx.amp_dtype,
+                    ):
                         detached_inputs = detach_variable(tuple(inputs))
-                        outputs = ctx.run_function(*detached_inputs,
-                                                   **ctx.kwargs)
+                        outputs = ctx.run_function(
+                            *detached_inputs, **ctx.kwargs
+                        )
             else:
-                with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
+                with paddle.amp.auto_cast(
+                    enable=ctx.is_fw_autocast,
                     custom_white_list=ctx.amp_white_list,
                     custom_black_list=ctx.amp_black_list,
                     level=ctx.amp_level,
-                    dtype=ctx.amp_dtype):
+                    dtype=ctx.amp_dtype,
+                ):
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
 
             if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
-                outputs = (outputs, )
+                outputs = (outputs,)
 
             assert len(outputs) == len(args)
 
             # run backward() with only tensor that requires grad
@@ -315,10 +359,10 @@ class RecomputeFunction(PyLayer):
             # the following backward_inputs_with_grad is used to avoid this case.
             backward_inputs_with_grad = []
             for i in range(len(outputs)):
-                if isinstance(
-                        outputs[i],
-                        (core.VarBase,
-                         core.eager.Tensor)) and not outputs[i].stop_gradient:
+                if (
+                    isinstance(outputs[i], (core.VarBase, core.eager.Tensor))
+                    and not outputs[i].stop_gradient
+                ):
                     forward_outputs_with_grad.append(outputs[i])
                     backward_inputs_with_grad.append(args[i])
@@ -329,17 +373,22 @@ class RecomputeFunction(PyLayer):
 
             # actually backward
             with paddle.amp.auto_cast(enable=False):
-                paddle.autograd.backward(forward_outputs_with_grad,
-                                         backward_inputs_with_grad)
+                paddle.autograd.backward(
+                    forward_outputs_with_grad, backward_inputs_with_grad
+                )
 
         if in_dygraph_mode():
             grads = tuple(
-                inp._grad_ivar() for inp in detached_inputs
-                if isinstance(inp, (core.VarBase, core.eager.Tensor)))
+                inp._grad_ivar()
+                for inp in detached_inputs
+                if isinstance(inp, (core.VarBase, core.eager.Tensor))
+            )
         else:
             grads = list(
-                inp._grad_ivar() for inp in detached_inputs
-                if isinstance(inp, (core.VarBase, core.eager.Tensor)))
+                inp._grad_ivar()
+                for inp in detached_inputs
+                if isinstance(inp, (core.VarBase, core.eager.Tensor))
+            )
         return grads
@@ -363,13 +412,10 @@ def recompute(function, *args, **kwargs):
     Examples:
         .. code-block:: python
 
-            import numpy as np
             import paddle
             from paddle.distributed.fleet.utils import recompute
             import random
-
             # required: gpu
-
             def get_fc_block(block_idx, input_size, is_last=False):
                 block_name = "block_" + str(block_idx)
                 block = paddle.nn.Sequential(
@@ -391,10 +437,7 @@ def recompute(function, *args, **kwargs):
                         block_name + "_fc_2",
                         paddle.nn.Linear(input_size, input_size, bias_attr=False)
                     )
-
                 return block
-
-
             class Naive_fc_net(paddle.nn.Layer):
                 def __init__(self, input_size=10,
                              recompute_blocks=[1, 3],
@@ -408,7 +451,6 @@ def recompute(function, *args, **kwargs):
                     self.runfunc3 = get_fc_block(3, input_size, is_last=False)
                     self.runfunc4 = get_fc_block(4, input_size, is_last=True)
                     self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]
-
                 def forward(self, inputs):
                     nums = len(self.total_func)
                     for i in range(nums):
@@ -417,15 +459,12 @@ def recompute(function, *args, **kwargs):
                     else:
                         inputs = self.total_func[i](inputs)
                     return inputs
-
             def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
                 gen = paddle.seed(10)
                 gen.manual_seed(10)
-                np.random.seed(10)
                 random.seed(10)
                 if cuda_state:
                     paddle.set_cuda_rng_state(cuda_state)
-
                 batch_size, input_size = 1, 10
                 model = Naive_fc_net(
                     input_size,
@@ -436,29 +475,24 @@ def recompute(function, *args, **kwargs):
                 param_ = []
                 grad_ = []
                 for _ in range(5):
-                    x_data = np.random.randn(batch_size, input_size).astype(np.float32)
-                    x = paddle.to_tensor(x_data)
+                    x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
                     y_pred = model(x)
                     loss = y_pred.mean()
-                    loss_.append(np.asarray(loss).tolist())
+                    loss_.append(loss.item())
                     loss.backward()
                     optimizer.step()
-                    param_.append(np.asarray(model.parameters()[9]).tolist())
-                    grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())
+                    param_.append(model.parameters()[9])
+                    grad_.append(model.parameters()[3]._grad_ivar())
                     optimizer.clear_grad()
                 return loss_, param_, grad_
 
             cuda_state = paddle.get_cuda_rng_state()
             # without recompute
             loss_ref, param_ref, grad_ref = run_model(
                 cuda_state, recompute_block=[]
             )
 
             loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
             print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
             # The result of the recompute_loss should be the same as the normal_loss.
     """
     # Hack to mix *args with **kwargs in a python 2.7-compliant way
     preserve = kwargs.pop('preserve_rng_state', True)
@@ -497,7 +531,6 @@ def recompute_sequential(ctx, functions, *args, **kwargs):
     preserve_rng_state = ctx.get('preserve_rng_state', True)
 
     def _run_func(begin, end, funcs):
-
         def do_run(input):
             for i in range(begin, end + 1):
                 input = funcs[i](input)
@@ -513,8 +546,10 @@ def recompute_sequential(ctx, functions, *args, **kwargs):
     end = -1
     for begin in range(0, segment_size * (segments - 1), segment_size):
         end = begin + segment_size - 1
-        args = recompute(_run_func(begin, end, functions),
-                         *args,
-                         preserve_rng_state=preserve_rng_state,
-                         **kwargs)
+        args = recompute(
+            _run_func(begin, end, functions),
+            *args,
+            preserve_rng_state=preserve_rng_state,
+            **kwargs
+        )
     return _run_func(end + 1, len(functions) - 1, functions)(args)
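recompute_sequential splits a layer sequence into `segments` chunks and checkpoints each chunk through recompute. A hedged usage sketch (it assumes recompute_sequential is importable from paddle.distributed.fleet.utils and that the ctx dict accepts segments and preserve_rng_state keys, as the code above suggests; passing preserve_rng_state=False lets it run on CPU, while the GPU default is True):

.. code-block:: python

    import paddle
    from paddle.distributed.fleet.utils import recompute_sequential

    model = paddle.nn.Sequential(
        paddle.nn.Linear(10, 10),
        paddle.nn.ReLU(),
        paddle.nn.Linear(10, 10),
        paddle.nn.ReLU(),
    )
    x = paddle.rand([4, 10])
    x.stop_gradient = False

    # two segments of two layers each; every segment is recomputed in backward
    y = recompute_sequential({"segments": 2, "preserve_rng_state": False}, model, x)
    y.mean().backward()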
@@ -22,14 +22,108 @@ import paddle
 from . import log_util  # noqa: F401
 from . import hybrid_parallel_util  # noqa: F401
 
-__all__ = [  # noqa
-    "LocalFS", "recompute", "DistributedInfer", "HDFSClient"
-]
+__all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"]  # noqa
-@deprecated(since="2.4.0",
-            update_to="paddle.distributed.fleet.recompute",
-            level=1,
-            reason="Please use new recompute API(fleet.recompute) ")
+@deprecated(
+    since="2.4.0",
+    update_to="paddle.distributed.fleet.recompute",
+    level=1,
+    reason="Please use new recompute API(fleet.recompute) ",
+)
 def recompute(function, *args, **kwargs):
     """
+    recompute intermediate activations to save memory.
+
+    Parameters:
+        function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
+            whose intermediate activations will be released to save memory in the forward stage and recomputed
+            in the backward stage for gradient calculation.
+        *args(Tensor): inputs to the function.
+        **kwargs(Dict): kwargs should only contain the key-value pair of preserve_rng_state, which indicates
+            whether to save the forward RNG state. If True, the RNG state captured in the forward pass is
+            restored when the forward is recomputed during backpropagation. The default value of
+            preserve_rng_state is True.
+
+    Returns:
+        Output of function on args.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.distributed.fleet.utils import recompute
+            import random
+
+            # required: gpu
+
+            def get_fc_block(block_idx, input_size, is_last=False):
+                block_name = "block_" + str(block_idx)
+                block = paddle.nn.Sequential(
+                    (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
+                    (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
+                    (block_name + "_relu_1", paddle.nn.ReLU()),
+                    (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
+                    (block_name + "_relu_2", paddle.nn.ReLU()),
+                )
+                if is_last:
+                    block.add_sublayer(
+                        block_name + "_fc_2",
+                        paddle.nn.Linear(
+                            input_size, 1, bias_attr=False
+                        )
+                    )
+                else:
+                    block.add_sublayer(
+                        block_name + "_fc_2",
+                        paddle.nn.Linear(input_size, input_size, bias_attr=False)
+                    )
+                return block
+
+            class Naive_fc_net(paddle.nn.Layer):
+                def __init__(self, input_size=10,
+                             recompute_blocks=[1, 3],
+                             recompute_kwargs={}):
+                    super(Naive_fc_net, self).__init__()
+                    self.recompute_blocks = recompute_blocks
+                    self.recompute_kwargs = recompute_kwargs
+                    self.runfunc0 = get_fc_block(0, input_size, is_last=False)
+                    self.runfunc1 = get_fc_block(1, input_size, is_last=False)
+                    self.runfunc2 = get_fc_block(2, input_size, is_last=False)
+                    self.runfunc3 = get_fc_block(3, input_size, is_last=False)
+                    self.runfunc4 = get_fc_block(4, input_size, is_last=True)
+                    self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]
+
+                def forward(self, inputs):
+                    nums = len(self.total_func)
+                    for i in range(nums):
+                        if i in self.recompute_blocks:
+                            inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True})
+                        else:
+                            inputs = self.total_func[i](inputs)
+                    return inputs
+
+            def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
+                gen = paddle.seed(10)
+                gen.manual_seed(10)
+                random.seed(10)
+                if cuda_state:
+                    paddle.set_cuda_rng_state(cuda_state)
+                batch_size, input_size = 1, 10
+                model = Naive_fc_net(
+                    input_size,
+                    recompute_blocks=recompute_block,
+                    recompute_kwargs=recompute_kwargs)
+                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+                loss_ = []
+                param_ = []
+                grad_ = []
+                for _ in range(5):
+                    x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
+                    y_pred = model(x)
+                    loss = y_pred.mean()
+                    loss_.append(loss.item())
+                    loss.backward()
+                    optimizer.step()
+                    param_.append(model.parameters()[9])
+                    grad_.append(model.parameters()[3]._grad_ivar())
+                    optimizer.clear_grad()
+                return loss_, param_, grad_
+
+            cuda_state = paddle.get_cuda_rng_state()
+            # without recompute
+            loss_ref, param_ref, grad_ref = run_model(
+                cuda_state, recompute_block=[]
+            )
+
+            loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
+
+            print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
+            # The result of the recompute_loss should be the same as the normal_loss.
     """
     return fleet.recompute.recompute(function, *args, **kwargs)
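The wrapper above only forwards to the new implementation. A minimal sketch of migrating off the deprecated path (assuming the new entry point is callable as paddle.distributed.fleet.recompute, as the update_to hint states; layer and x are illustrative):

.. code-block:: python

    import paddle
    import paddle.distributed.fleet as fleet

    layer = paddle.nn.Linear(8, 8)
    x = paddle.rand([2, 8])
    x.stop_gradient = False

    # old, deprecated path: paddle.distributed.fleet.utils.recompute(layer, x)
    # new path suggested by the @deprecated decorator:
    y = fleet.recompute(layer, x)
    y.mean().backward()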