Unverified commit 8752c912, authored by JZ-LIANG, committed by GitHub

Dygraph Recompute: support amp (#33251)

* Dygraph Recompute support AMP

* dygraph recompute: update unit test

Parent c70f1cad
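For context, a minimal usage sketch of what this change enables: running a recomputed block inside paddle.amp.auto_cast and stepping the optimizer through paddle.amp.GradScaler, mirroring the updated unit test further below. The import path of recompute, the toy model, and the assumption of a CUDA device (recompute preserves the CUDA RNG state by default) are illustrative, not part of this commit.

import numpy as np
import paddle
from paddle.distributed.fleet.utils import recompute  # assumed entry point

# Toy two-layer model; any Layer/callable can serve as the recomputed block.
model = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.Linear(10, 10))
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())
scaler = paddle.amp.GradScaler()

x = paddle.to_tensor(np.random.randn(4, 10).astype(np.float32))
with paddle.amp.auto_cast(True):
    # With this patch, recompute() records the surrounding auto_cast state
    # (enabled flag plus custom white/black lists) during forward and
    # re-applies it when the block is recomputed in backward.
    y = recompute(model, x)
    loss = y.mean()
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)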
@@ -97,10 +97,12 @@ class RecomputeFunction(PyLayer):
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()

         # TODO support AMP
+        tracer = framework._dygraph_tracer()
+        ctx.is_fw_autocast = tracer._enable_autocast
+        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

         with paddle.no_grad():
             outputs = run_function(*args)
         return outputs

     @staticmethod
@@ -119,13 +121,21 @@ class RecomputeFunction(PyLayer):
             tracer = framework._dygraph_tracer()
             tracer._has_grad = True

-            # TODO support AMP
+            # NOTE support AMP
+            # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
                 with swith_rng_state(ctx.fw_cuda_rng_state):
-                    detached_inputs = detach_variable(tuple(inputs))
-                    outputs = ctx.run_function(*detached_inputs)
+                    with paddle.amp.auto_cast(
+                            enable=ctx.is_fw_autocast,
+                            custom_white_list=ctx.amp_white_list,
+                            custom_black_list=ctx.amp_black_list):
+                        detached_inputs = detach_variable(tuple(inputs))
+                        outputs = ctx.run_function(*detached_inputs)
             else:
-                detached_inputs = detach_variable(tuple(inputs))
-                outputs = ctx.run_function(*detached_inputs)
+                with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list):
+                    detached_inputs = detach_variable(tuple(inputs))
+                    outputs = ctx.run_function(*detached_inputs)
@@ -155,7 +165,6 @@ class RecomputeFunction(PyLayer):
             grads = list(inp._grad_ivar() for inp in detached_inputs
                          if isinstance(inp, core.VarBase))
             return grads
......
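The implementation change above is a capture-and-restore pattern: RecomputeFunction.forward() saves whether autocast was enabled together with the tracer's custom white/black op lists, and backward() replays the block under the same paddle.amp.auto_cast configuration (alongside the already-saved CUDA RNG state). Below is a simplified sketch of that pattern pulled out of the PyLayer for illustration; SavedAmpState and the two helper functions are hypothetical names, not Paddle internals.

from dataclasses import dataclass

import paddle
from paddle.fluid import framework


@dataclass
class SavedAmpState:
    # Illustrative container; the real patch stores these fields on the PyLayer ctx.
    is_fw_autocast: bool
    white_list: list
    black_list: list


def capture_amp_state():
    # Same tracer queries the patch adds in RecomputeFunction.forward().
    tracer = framework._dygraph_tracer()
    white_list, black_list = tracer._get_amp_op_list()
    return SavedAmpState(tracer._enable_autocast, white_list, black_list)


def rerun_under_saved_amp(state, run_function, *detached_inputs):
    # Backward-time recomputation: replay the block under the forward AMP config,
    # so recomputed activations match the ones produced in the original forward.
    with paddle.amp.auto_cast(
            enable=state.is_fw_autocast,
            custom_white_list=state.white_list,
            custom_black_list=state.black_list):
        return run_function(*detached_inputs)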
@@ -92,15 +92,12 @@ class Naive_fc_net(paddle.nn.Layer):
         return inputs


-def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
+def run_model(recompute_block=[], recompute_kwargs={}, enable_autocast=False):
     gen = paddle.seed(10)
     gen.manual_seed(10)
     np.random.seed(10)
     random.seed(10)

-    if cuda_state:
-        paddle.set_cuda_rng_state(cuda_state)
-
     batch_size, input_size = 1, 10
     model = Naive_fc_net(
         input_size,
@@ -110,16 +107,24 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
     optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                      parameters=model.parameters())

+    if enable_autocast:
+        scaler = paddle.amp.GradScaler()
+
     loss_ = []
     param_ = []
     grad_ = []
     for step in range(10):
         x_data = np.random.randn(batch_size, input_size).astype(np.float32)
         x = paddle.to_tensor(x_data)
         # x.stop_gradient = False
-        y_pred = model(x)
-        loss = y_pred.mean()
-        loss_.append(np.asarray(loss).tolist())
-        loss.backward()
-        optimizer.step()
+        with paddle.amp.auto_cast(True):
+            y_pred = model(x)
+            loss = y_pred.mean()
+        if enable_autocast:
+            scaler.scale(loss).backward()
+            scaler.minimize(optimizer, loss)
+        else:
+            loss_.append(np.asarray(loss).tolist())
+            loss.backward()
+            optimizer.step()
@@ -138,25 +143,57 @@ class TestPyLayer(unittest.TestCase):
             self.assertEqual(param_ref, param)
             self.assertEqual(grad_ref, grad)

-        cuda_state = paddle.get_cuda_rng_state()
         # without recompute
-        loss_ref, param_ref, grad_ref = run_model(
-            cuda_state, recompute_block=[])
+        loss_ref, param_ref, grad_ref = run_model(recompute_block=[])

         # recompute second block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
+        loss, param, grad = run_model(recompute_block=[1])
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

         # recompute fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[3])
+        loss, param, grad = run_model(recompute_block=[3])
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

         # recompute second to fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 2, 3])
+        loss, param, grad = run_model(recompute_block=[1, 2, 3])
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

         # recompute second & fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
+        loss, param, grad = run_model(recompute_block=[1, 3])
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

+    def test_fc_net_without_restore_rng(self):
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[2],
+            recompute_kwargs={"preserve_rng_state": False},
+            enable_autocast=True)
+
+    def test_fc_net_with_amp(self):
+        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
+            self.assertEqual(loss_ref, loss)
+            self.assertEqual(param_ref, param)
+            self.assertEqual(grad_ref, grad)
+
+        # without recompute
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[], enable_autocast=True)
+
+        # recompute second block
+        loss, param, grad = run_model(recompute_block=[1], enable_autocast=True)
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+        # recompute fourth block
+        loss, param, grad = run_model(recompute_block=[3], enable_autocast=True)
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+        # recompute second to fourth block
+        loss, param, grad = run_model(
+            recompute_block=[1, 2, 3], enable_autocast=True)
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+        # recompute second & fourth block
+        loss, param, grad = run_model(
+            recompute_block=[1, 3], enable_autocast=True)
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

     def test_recompute_kwargs(self):
@@ -164,12 +201,12 @@ class TestPyLayer(unittest.TestCase):
         kwargs = {"is_test": False}
         with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
-                None, recompute_block=[2], recompute_kwargs=kwargs)
+                recompute_block=[2], recompute_kwargs=kwargs)

     def test_recompute_cpu_rng(self):
         paddle.set_device("cpu")
         with self.assertRaises(RuntimeError):
-            loss_ref, param_ref, grad_ref = run_model(None, recompute_block=[2])
+            loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])


 if __name__ == '__main__':
......