Unverified Commit 625dd722 authored by S ShenLiang, committed by GitHub

fix recompute (#41396)

Parent a6b6bcbf
@@ -14,9 +14,11 @@
import paddle
from paddle.fluid import core
from paddle.autograd import PyLayer
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.fluid import framework
import contextlib
from paddle.fluid.framework import in_dygraph_mode
import logging
logger = logging.getLogger(__name__)
@@ -32,7 +34,7 @@ __all__ = []
def detach_variable(inputs):
out = []
for inp in inputs:
if not isinstance(inp, core.VarBase):
if not isinstance(inp, (core.eager.Tensor, core.VarBase)):
out.append(inp)
continue
@@ -44,7 +46,7 @@ def detach_variable(inputs):
def check_recompute_necessary(inputs):
if not any(input_.stop_gradient == False for input_ in inputs
if isinstance(input_, paddle.Tensor)):
if isinstance(input_, (core.eager.Tensor, paddle.Tensor))):
logger.warn(
"[Recompute]: None of the inputs to the current recompute block need grad, "
"therefore there is NO need to recompute this block in backward!")
@@ -60,6 +62,140 @@ def swith_rng_state(rng_state):
paddle.set_cuda_rng_state(orig_cuda_rng_state)
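# As a hedged illustration (the file's actual swith_rng_state body is only
# partially visible in this hunk), an RNG-swap helper like the one used by
# backward() below can be written as a small context manager built from the
# paddle APIs already shown in this diff; the name _swap_cuda_rng_state is
# hypothetical.
import contextlib
import paddle

@contextlib.contextmanager
def _swap_cuda_rng_state(rng_state):
    # install the CUDA RNG state captured during the original forward pass
    orig_cuda_rng_state = paddle.get_cuda_rng_state()
    paddle.set_cuda_rng_state(rng_state)
    try:
        yield
    finally:
        # always restore the caller's RNG state, even if recomputation fails
        paddle.set_cuda_rng_state(orig_cuda_rng_state)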
class EagerRecomputeFunction(EagerPyLayer):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args)
# store for recomputing
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
# NOTE the number of outputs of backward() should equal the number of tensors in forward()'s input,
# and the order of tensors in backward()'s output should match the order of tensors in forward()'s input.
# Non-tensor inputs are filtered out of the backward inputs.
# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
# NOTE recompute with RNG state restore only supports the scenario of one process per CUDA GPU;
# one process with multiple GPUs and mixed GPU-CPU scenarios are not supported.
if ctx.preserve_rng_state:
cur_device = paddle.get_device()
if 'gpu:' not in cur_device:
raise RuntimeError(
"Recompute with RNG perserve is not support current device: {}.".
format(cur_device))
ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
if tracer._amp_dtype == 'float16':
ctx.amp_dtype = 'float16'
elif tracer._amp_dtype in ('bfloat16', 'float32'):
ctx.amp_dtype = 'bfloat16'
else:
raise ValueError("unsupported amp dtype: {}".format(
tracer._amp_dtype))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
with paddle.no_grad():
outputs = run_function(*args)
return outputs
@staticmethod
def backward(ctx, *args):
with paddle.fluid.dygraph.guard():
# TODO: check whether the recompute call is valid
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensor()
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
# paddle.enable_grad()
tracer = framework._dygraph_tracer()
tracer._has_grad = True
# NOTE support AMP
# need to restore the auto_cast state as well as the white/black op lists
if ctx.preserve_rng_state:
with swith_rng_state(ctx.fw_cuda_rng_state):
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level,
dtype=ctx.amp_dtype):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
else:
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level,
dtype=ctx.amp_dtype):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.eager.Tensor):
outputs = (outputs, )
assert len(outputs) == len(args)
# run backward() with only the tensors that require grad
forward_outputs_with_grad = []
# NOTE In Transformer-like networks, if the user puts the attention mask into the recompute segment's outputs,
# PyLayer will force the attention mask's stop_gradient to False, which makes the number of
# tensors that need grad mismatch between forward outputs and backward inputs.
# backward_inputs_with_grad below is used to avoid this case.
backward_inputs_with_grad = []
for i in range(len(outputs)):
if isinstance(
outputs[i],
core.eager.Tensor) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs_with_grad.append(args[i])
if len(forward_outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True, this recompute() is not necessary"
)
# actually backward
with paddle.amp.auto_cast(enable=False):
paddle.autograd.backward(forward_outputs_with_grad,
backward_inputs_with_grad)
grads = tuple(
inp.grad for inp in detached_inputs
if isinstance(inp, core.eager.Tensor))
return grads
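# Hedged sketch (hypothetical toy values) of the ctx.inputs / tensor_indices
# bookkeeping used by forward() and backward() above: tensor arguments are
# stashed via save_for_backward, non-tensor arguments stay in ctx.inputs, and
# tensor_indices records where the saved tensors are spliced back in backward.
import paddle

args = (paddle.randn([2, 2]), "non-tensor-flag", paddle.randn([2, 2]))
inputs, tensor_indices, tensor_inputs = [], [], []
for i, arg in enumerate(args):
    if paddle.is_tensor(arg):
        tensor_inputs.append(arg)  # would go through ctx.save_for_backward
        tensor_indices.append(i)
        inputs.append(None)        # placeholder keeps positions aligned
    else:
        inputs.append(arg)
# in backward(), the saved tensors are restored into their original slots
for i, idx in enumerate(tensor_indices):
    inputs[idx] = tensor_inputs[i]
assert inputs[1] == "non-tensor-flag" and paddle.is_tensor(inputs[0])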
class RecomputeFunction(PyLayer):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
@@ -315,4 +451,7 @@ def recompute(function, *args, **kwargs):
raise ValueError("Unexpected keyword arguments: " + ",".join(
arg for arg in kwargs))
return RecomputeFunction.apply(function, preserve, *args)
if in_dygraph_mode():
return EagerRecomputeFunction.apply(function, preserve, *args)
else:
return RecomputeFunction.apply(function, preserve, *args)
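# Hedged usage sketch of the public recompute API whose dispatch is changed
# above (the layer sizes and tensor shape here are hypothetical): wrapping a
# block with recompute drops its intermediate activations during forward and
# re-runs the block during backward. Under in_dygraph_mode() it now routes to
# EagerRecomputeFunction.apply, otherwise to RecomputeFunction.apply. Note
# that the default preserve_rng_state=True requires a GPU device, per the
# RuntimeError raised in forward() above.
import paddle
from paddle.distributed.fleet.utils import recompute

block = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16), paddle.nn.ReLU(), paddle.nn.Linear(16, 16))
x = paddle.randn([4, 16])
x.stop_gradient = False
y = recompute(block, x)  # forward runs under no_grad; inputs saved for replay
y.sum().backward()       # the block is re-executed here to produce gradients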
@@ -23,6 +23,7 @@ from paddle.distributed.fleet.utils import recompute
import random
import paddle.fluid.layers as layers
from paddle.fluid.framework import _test_eager_guard
def get_fc_block(block_idx, input_size, is_last=False):
@@ -141,96 +142,75 @@ def run_model(recompute_block=[],
class TestPyLayer(unittest.TestCase):
def test_fc_net_with_dropout(self):
def test_base_case(self, enable_autocast=False, pure_fp16=False):
def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
self.assertEqual(loss_ref, loss)
self.assertEqual(param_ref, param)
self.assertEqual(grad_ref, grad)
# without recompute
loss_ref, param_ref, grad_ref = run_model(recompute_block=[])
# recompute second block
loss, param, grad = run_model(recompute_block=[1])
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute fourth block
loss, param, grad = run_model(recompute_block=[3])
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second to fourth block
loss, param, grad = run_model(recompute_block=[1, 2, 3])
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second & fourth block
loss, param, grad = run_model(recompute_block=[1, 3])
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_without_restore_rng(self):
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2],
recompute_kwargs={"preserve_rng_state": False},
enable_autocast=True)
def test_fc_net_with_amp(self):
def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
self.assertEqual(loss_ref, loss)
self.assertEqual(param_ref, param)
self.assertEqual(grad_ref, grad)
# without recompute
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[], enable_autocast=True)
recompute_block=[],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
# recompute second block
loss, param, grad = run_model(recompute_block=[1], enable_autocast=True)
loss, param, grad = run_model(
recompute_block=[1],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute fourth block
loss, param, grad = run_model(recompute_block=[3], enable_autocast=True)
loss, param, grad = run_model(
recompute_block=[3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second to fourth block
loss, param, grad = run_model(
recompute_block=[1, 2, 3], enable_autocast=True)
recompute_block=[1, 2, 3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
# recompute second & fourth block
loss, param, grad = run_model(
recompute_block=[1, 3], enable_autocast=True)
recompute_block=[1, 3],
enable_autocast=enable_autocast,
pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_fp16(self):
def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
self.assertEqual(loss_ref, loss)
self.assertEqual(param_ref, param)
self.assertEqual(grad_ref, grad)
# without recompute
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[], enable_autocast=True, pure_fp16=True)
# recompute second block
loss, param, grad = run_model(
recompute_block=[1], enable_autocast=True, pure_fp16=True)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_dropout(self):
with _test_eager_guard():
self.test_base_case()
self.test_base_case()
# recompute fourth block
loss, param, grad = run_model(
recompute_block=[3], enable_autocast=True, pure_fp16=True)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_without_restore_rng(self):
with _test_eager_guard():
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2],
recompute_kwargs={"preserve_rng_state": False},
enable_autocast=True)
# recompute second to fourth block
loss, param, grad = run_model(
recompute_block=[1, 2, 3], enable_autocast=True, pure_fp16=True)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_amp(self):
with _test_eager_guard():
self.test_base_case(enable_autocast=True)
self.test_base_case(enable_autocast=True)
# recompute second & fourth block
loss, param, grad = run_model(
recompute_block=[1, 3], enable_autocast=True, pure_fp16=True)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
def test_fc_net_with_fp16(self):
with _test_eager_guard():
self.test_base_case(enable_autocast=True, pure_fp16=True)
self.test_base_case(enable_autocast=True, pure_fp16=True)
def test_recompute_kwargs(self):
with _test_eager_guard():
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(ValueError):
loss_ref, param_ref, grad_ref = run_model(
recompute_block=[2], recompute_kwargs=kwargs)
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(ValueError):
@@ -238,6 +218,11 @@ class TestPyLayer(unittest.TestCase):
recompute_block=[2], recompute_kwargs=kwargs)
def test_recompute_cpu_rng(self):
with _test_eager_guard():
paddle.set_device("cpu")
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
paddle.set_device("cpu")
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])