Unverified commit 625dd722, authored by ShenLiang, committed by GitHub

fix recompute (#41396)

Parent: a6b6bcbf
@@ -14,9 +14,11 @@
import paddle
from paddle.fluid import core
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.fluid import framework
import contextlib
from paddle.fluid.framework import in_dygraph_mode

import logging
logger = logging.getLogger(__name__)
@@ -32,7 +34,7 @@ __all__ = []
def detach_variable(inputs):
    out = []
    for inp in inputs:
        if not isinstance(inp, (core.eager.Tensor, core.VarBase)):
            out.append(inp)
            continue
@@ -44,7 +46,7 @@ def detach_variable(inputs):
def check_recompute_necessary(inputs):
    if not any(input_.stop_gradient == False for input_ in inputs
               if isinstance(input_, (core.eager.Tensor, paddle.Tensor))):
        logger.warn(
            "[Recompute]: None of the inputs to current recompute block need grad, "
            "therefore there is NO need to recompute this block in backward !")
@@ -60,6 +62,140 @@ def swith_rng_state(rng_state):
    paddle.set_cuda_rng_state(orig_cuda_rng_state)

class EagerRecomputeFunction(EagerPyLayer):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        if framework._dygraph_tracer()._has_grad:
            check_recompute_necessary(args)

        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input;
        # the order of tensors in backward()'s output should be the same as the tensors in forward()'s input.
        # None tensor inputs will be filtered in backward inputs.

        # save input for backward
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE recompute with RNG restore only supports the scenario of one process per CUDA GPU;
        # one process with multiple GPUs and mixed GPU-CPU scenarios are not supported.
        if ctx.preserve_rng_state:
            cur_device = paddle.get_device()
            if 'gpu:' not in cur_device:
                raise RuntimeError(
                    "Recompute with RNG preserve is not supported on current device: {}.".
                    format(cur_device))
            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()

        # TODO support AMP
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            ctx.amp_level = 'O1'
        else:
            raise ValueError("unsupported amp level: {}".format(
                tracer._amp_level))

        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError("unsupported amp dtype: {}".format(
                tracer._amp_dtype))

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        with paddle.fluid.dygraph.guard():
            # TODO need to check whether the recompute call is valid or not
            # Restore inputs
            inputs = list(ctx.inputs)
            tensor_indices = ctx.tensor_indices
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(tensor_indices):
                inputs[idx] = tensors[i]

            # paddle.enable_grad()
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # NOTE support AMP
            # need to restore the auto_cast state as well as the white/black op lists
            if ctx.preserve_rng_state:
                with swith_rng_state(ctx.fw_cuda_rng_state):
                    with paddle.amp.auto_cast(
                            enable=ctx.is_fw_autocast,
                            custom_white_list=ctx.amp_white_list,
                            custom_black_list=ctx.amp_black_list,
                            level=ctx.amp_level,
                            dtype=ctx.amp_dtype):
                        detached_inputs = detach_variable(tuple(inputs))
                        outputs = ctx.run_function(*detached_inputs)
            else:
                with paddle.amp.auto_cast(
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_level,
                        dtype=ctx.amp_dtype):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs)

            if isinstance(outputs, core.eager.Tensor):
                outputs = (outputs, )
            assert len(outputs) == len(args)

            # run backward() with only the tensors that require grad
            forward_outputs_with_grad = []
            # NOTE In a Transformer-like network, if the user puts the attention mask into the recompute segment's output,
            # pylayer will force the stop_gradient of the attention mask to be False, which makes the number of
            # tensors that need grad not match.
            # The following backward_inputs_with_grad is used to avoid this case.
            backward_inputs_with_grad = []
            for i in range(len(outputs)):
                if isinstance(
                        outputs[i],
                        core.eager.Tensor) and not outputs[i].stop_gradient:
                    forward_outputs_with_grad.append(outputs[i])
                    backward_inputs_with_grad.append(args[i])

            if len(forward_outputs_with_grad) == 0:
                raise RuntimeError(
                    "none of the outputs has requires_grad=True, this recompute() is not necessary"
                )

            # actually run the backward pass
            with paddle.amp.auto_cast(enable=False):
                paddle.autograd.backward(forward_outputs_with_grad,
                                         backward_inputs_with_grad)

            grads = tuple(
                inp.grad for inp in detached_inputs
                if isinstance(inp, core.eager.Tensor))
            return grads

class RecomputeFunction(PyLayer):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
@@ -315,4 +451,7 @@ def recompute(function, *args, **kwargs):
        raise ValueError("Unexpected keyword arguments: " + ",".join(
            arg for arg in kwargs))

    if in_dygraph_mode():
        return EagerRecomputeFunction.apply(function, preserve, *args)
    else:
        return RecomputeFunction.apply(function, preserve, *args)
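
For context, below is a minimal usage sketch (not part of this diff) of the patched `recompute` entry point; the model name `TwoBlockNet` and its sizes are invented for illustration. With the dispatch added above, the same call routes to `EagerRecomputeFunction` in eager mode and to `RecomputeFunction` otherwise. Note that the default `preserve_rng_state=True` path requires a CUDA device, as enforced in `forward()` above, so this sketch disables it to stay device-agnostic.

```python
import paddle
from paddle.distributed.fleet.utils import recompute


class TwoBlockNet(paddle.nn.Layer):
    """Hypothetical two-block model, for illustration only."""

    def __init__(self, hidden=64):
        super().__init__()
        self.block1 = paddle.nn.Sequential(
            paddle.nn.Linear(hidden, hidden), paddle.nn.ReLU())
        self.block2 = paddle.nn.Sequential(
            paddle.nn.Linear(hidden, hidden), paddle.nn.ReLU())

    def forward(self, x):
        # block1's activations are dropped after forward and recomputed
        # during backward; preserve_rng_state=False skips the GPU-only
        # RNG-restore path so the sketch also runs on CPU.
        x = recompute(self.block1, x, preserve_rng_state=False)
        return self.block2(x)


net = TwoBlockNet()
data = paddle.randn([8, 64])
data.stop_gradient = False
loss = net(data).mean()
loss.backward()
```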
@@ -23,6 +23,7 @@ from paddle.distributed.fleet.utils import recompute
import random
import paddle.fluid.layers as layers
from paddle.fluid.framework import _test_eager_guard


def get_fc_block(block_idx, input_size, is_last=False):
@@ -141,96 +142,75 @@ def run_model(recompute_block=[],
class TestPyLayer(unittest.TestCase):
    def test_base_case(self, enable_autocast=False, pure_fp16=False):
        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
            self.assertEqual(loss_ref, loss)
            self.assertEqual(param_ref, param)
            self.assertEqual(grad_ref, grad)

        # without recompute
        loss_ref, param_ref, grad_ref = run_model(
            recompute_block=[],
            enable_autocast=enable_autocast,
            pure_fp16=pure_fp16)

        # recompute second block
        loss, param, grad = run_model(
            recompute_block=[1],
            enable_autocast=enable_autocast,
            pure_fp16=pure_fp16)
        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

        # recompute fourth block
        loss, param, grad = run_model(
            recompute_block=[3],
            enable_autocast=enable_autocast,
            pure_fp16=pure_fp16)
        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

        # recompute second to fourth block
        loss, param, grad = run_model(
            recompute_block=[1, 2, 3],
            enable_autocast=enable_autocast,
            pure_fp16=pure_fp16)
        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

        # recompute second & fourth block
        loss, param, grad = run_model(
            recompute_block=[1, 3],
            enable_autocast=enable_autocast,
            pure_fp16=pure_fp16)
        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

    def test_fc_net_with_dropout(self):
        with _test_eager_guard():
            self.test_base_case()
        self.test_base_case()

    def test_fc_net_without_restore_rng(self):
        with _test_eager_guard():
            loss_ref, param_ref, grad_ref = run_model(
                recompute_block=[2],
                recompute_kwargs={"preserve_rng_state": False},
                enable_autocast=True)

    def test_fc_net_with_amp(self):
        with _test_eager_guard():
            self.test_base_case(enable_autocast=True)
        self.test_base_case(enable_autocast=True)

    def test_fc_net_with_fp16(self):
        with _test_eager_guard():
            self.test_base_case(enable_autocast=True, pure_fp16=True)
        self.test_base_case(enable_autocast=True, pure_fp16=True)

    def test_recompute_kwargs(self):
        with _test_eager_guard():
            paddle.set_device("gpu")
            kwargs = {"is_test": False}
            with self.assertRaises(ValueError):
                loss_ref, param_ref, grad_ref = run_model(
                    recompute_block=[2], recompute_kwargs=kwargs)

        paddle.set_device("gpu")
        kwargs = {"is_test": False}
        with self.assertRaises(ValueError):
@@ -238,6 +218,11 @@ class TestPyLayer(unittest.TestCase):
                recompute_block=[2], recompute_kwargs=kwargs)

    def test_recompute_cpu_rng(self):
        with _test_eager_guard():
            paddle.set_device("cpu")
            with self.assertRaises(RuntimeError):
                loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])

        paddle.set_device("cpu")
        with self.assertRaises(RuntimeError):
            loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
...
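
As a reference for reviewers, here is a minimal, self-contained `PyLayer` sketch (not part of this commit) illustrating the `save_for_backward` / `saved_tensor` contract that both `RecomputeFunction` and the new `EagerRecomputeFunction` rely on: `backward` returns one gradient per tensor input of `forward`, in the same order. The `Scale` layer and its numbers are invented for illustration.

```python
import paddle
from paddle.autograd import PyLayer


class Scale(PyLayer):
    """Hypothetical y = alpha * x layer, for illustration only."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Non-tensor arguments are stashed on ctx; tensors go through
        # save_for_backward so they can be retrieved in backward().
        ctx.alpha = alpha
        ctx.save_for_backward(x)
        return x * alpha

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensor()
        # One gradient is returned for the single tensor input `x`;
        # the non-tensor `alpha` does not get a gradient.
        return grad_out * ctx.alpha


x = paddle.randn([4])
x.stop_gradient = False
y = Scale.apply(x, 2.0)
y.sum().backward()
print(x.grad)  # each element equals alpha (= 2.0)
```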