Unverified commit 522c2bc0, authored by wanghuancoder, committed by GitHub

delete old dygraph pylayer recompute (#49338)

Parent 2bbdc47a
@@ -18,7 +18,6 @@ import weakref
import paddle
from paddle import framework
from paddle.autograd import PyLayer
from paddle.autograd.py_layer import LegacyPyLayer
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
    get_rng_state_tracker,
)
@@ -67,147 +66,6 @@ def swith_rng_state_tracker(rng_state, tracker):
        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)


class LegacyRecomputeFunction(LegacyPyLayer):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
        # None tensor inputs will be filtered in backward inputs.

        # save input for backward
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE recompute with RNG restore only supports the scenario of one process per CUDA GPU;
        # one process with multiple GPUs and mixed GPU/CPU scenarios are not supported.
        if ctx.preserve_rng_state:
            ctx.fw_rng_state = paddle.get_rng_state()
            ctx.fwd_rng_state_tracker = (
                get_rng_state_tracker().get_states_tracker()
            )

        # TODO support AMP
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = (
            False if tracer._amp_level == core.AmpLevel.O0 else True
        )
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            ctx.amp_level = 'O1'
        else:
            raise ValueError(
                "unsupported amp level: {}".format(tracer._amp_level)
            )

        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError(
                "unsupported amp dtype: {}".format(tracer._amp_dtype)
            )

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        with paddle.fluid.dygraph.guard():
            # TODO need to check whether the recompute call is valid or not
            # Restore inputs
            inputs = list(ctx.inputs)
            tensor_indices = ctx.tensor_indices
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(tensor_indices):
                inputs[idx] = tensors[i]

            # paddle.enable_grad()
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # NOTE support AMP
            # need to restore the auto_cast state as well as the white/black op lists
            if ctx.preserve_rng_state:
                with swith_rng_state_tracker(
                    ctx.fw_rng_state, ctx.fwd_rng_state_tracker
                ):
                    with paddle.amp.auto_cast(
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_level,
                        dtype=ctx.amp_dtype,
                    ):
                        detached_inputs = detach_variable(tuple(inputs))
                        outputs = ctx.run_function(*detached_inputs)
            else:
                with paddle.amp.auto_cast(
                    enable=ctx.is_fw_autocast,
                    custom_white_list=ctx.amp_white_list,
                    custom_black_list=ctx.amp_black_list,
                    level=ctx.amp_level,
                    dtype=ctx.amp_dtype,
                ):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs)

            if isinstance(outputs, core.VarBase):
                outputs = (outputs,)
            assert len(outputs) == len(args)

            # run backward() with only the tensors that require grad
            forward_outputs_with_grad = []
            # NOTE In a Transformer-like network, if the user puts the attention mask into the recompute segment output,
            # pylayer will force the stop_gradient of the attention mask to be False, which makes the number of
            # tensors that need grad not match.
            # The following backward_inputs_with_grad is used to avoid this case.
            backward_inputs_with_grad = []
            for i in range(len(outputs)):
                if (
                    isinstance(outputs[i], core.VarBase)
                    and not outputs[i].stop_gradient
                ):
                    forward_outputs_with_grad.append(outputs[i])
                    backward_inputs_with_grad.append(args[i])

            if len(forward_outputs_with_grad) == 0:
                raise RuntimeError(
                    "none of output has requires_grad=True, this recompute() is not necessary"
                )

            # actually backward
            with paddle.amp.auto_cast(enable=False):
                paddle.autograd.backward(
                    forward_outputs_with_grad, backward_inputs_with_grad
                )

            grads = list(
                inp._grad_ivar()
                for inp in detached_inputs
                if isinstance(inp, core.VarBase)
            )
            return grads


class RecomputeFunction(PyLayer):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
......
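For context, below is a minimal usage sketch of the recompute entry point that remains after this change. It assumes the public paddle.distributed.fleet.utils.recompute API; the Block layer, hidden size, and tensor shapes are hypothetical and only illustrate the call site, they are not part of this diff.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet.utils import recompute


class Block(nn.Layer):
    # hypothetical two-layer block, used only to show where recompute is applied
    def __init__(self, hidden=64):
        super().__init__()
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


block = Block()
x = paddle.randn([8, 64])
x.stop_gradient = False

# activations inside `block` are discarded in the forward pass and
# recomputed during backward, trading extra compute for lower memory
y = recompute(block, x)
y.sum().backward()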