Unverified commit 8752c912, authored by JZ-LIANG, committed by GitHub

Dygraph Recompute: support amp (#33251)

* Dygraph Recompute support AMP

* dygraph recompute: update unit test
Parent c70f1cad
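For context, a minimal sketch (not part of this commit) of the user-facing pattern the change enables and that the updated unit test below exercises: a recomputed sub-layer running inside paddle.amp.auto_cast with a GradScaler. The two Linear layers, the SGD settings, and the import path paddle.distributed.fleet.utils.recompute are illustrative assumptions, and a CUDA device is assumed because recompute preserves the CUDA RNG state by default.

import paddle
from paddle.distributed.fleet.utils import recompute

fc1 = paddle.nn.Linear(10, 10)
fc2 = paddle.nn.Linear(10, 10)
params = list(fc1.parameters()) + list(fc2.parameters())
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=params)
scaler = paddle.amp.GradScaler()

x = paddle.randn([1, 10], dtype='float32')
with paddle.amp.auto_cast(True):
    # fc1's activations are dropped here and recomputed during backward,
    # which after this commit happens under the same auto_cast state
    hidden = recompute(fc1, x)
    loss = fc2(hidden).mean()
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)
optimizer.clear_grad()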
@@ -97,10 +97,12 @@ class RecomputeFunction(PyLayer):
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
         # TODO support AMP
+        tracer = framework._dygraph_tracer()
+        ctx.is_fw_autocast = tracer._enable_autocast
+        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
         with paddle.no_grad():
             outputs = run_function(*args)
         return outputs
     @staticmethod
@@ -119,15 +121,23 @@ class RecomputeFunction(PyLayer):
         tracer = framework._dygraph_tracer()
         tracer._has_grad = True
-        # TODO support AMP
+        # NOTE support AMP
+        # need restore auto_cast state as well as w/b list
         if ctx.preserve_rng_state:
             with swith_rng_state(ctx.fw_cuda_rng_state):
+                with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list):
+                    detached_inputs = detach_variable(tuple(inputs))
+                    outputs = ctx.run_function(*detached_inputs)
+        else:
+            with paddle.amp.auto_cast(
+                    enable=ctx.is_fw_autocast,
+                    custom_white_list=ctx.amp_white_list,
+                    custom_black_list=ctx.amp_black_list):
                 detached_inputs = detach_variable(tuple(inputs))
                 outputs = ctx.run_function(*detached_inputs)
-        else:
-            detached_inputs = detach_variable(tuple(inputs))
-            outputs = ctx.run_function(*detached_inputs)
         if isinstance(outputs, core.VarBase):
             outputs = (outputs, )
@@ -155,7 +165,6 @@ class RecomputeFunction(PyLayer):
         grads = list(inp._grad_ivar() for inp in detached_inputs
                      if isinstance(inp, core.VarBase))
         return grads
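Taken together, the two hunks above capture the tracer's autocast configuration in forward and replay it when the forward pass is recomputed in backward. Below is a stripped-down sketch of that pattern, assuming the same private helpers that appear in the diff (framework._dygraph_tracer, tracer._enable_autocast, tracer._get_amp_op_list are internal APIs and may differ between Paddle releases); the two helper functions themselves are hypothetical names used only for illustration.

import paddle
from paddle.fluid import framework


def capture_autocast_state():
    # snapshot the autocast flag and the custom white/black op lists
    # active on the current dygraph tracer (what forward stores on ctx)
    tracer = framework._dygraph_tracer()
    return tracer._enable_autocast, tracer._get_amp_op_list()


def replay_under_autocast(run_function, saved_state, *inputs):
    # re-enter auto_cast with the captured settings before re-running the
    # forward function, mirroring what RecomputeFunction.backward now does
    is_fw_autocast, (white_list, black_list) = saved_state
    with paddle.amp.auto_cast(
            enable=is_fw_autocast,
            custom_white_list=white_list,
            custom_black_list=black_list):
        return run_function(*inputs)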
@@ -92,15 +92,12 @@ class Naive_fc_net(paddle.nn.Layer):
         return inputs
-def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
+def run_model(recompute_block=[], recompute_kwargs={}, enable_autocast=False):
     gen = paddle.seed(10)
     gen.manual_seed(10)
     np.random.seed(10)
     random.seed(10)
-    if cuda_state:
-        paddle.set_cuda_rng_state(cuda_state)
     batch_size, input_size = 1, 10
     model = Naive_fc_net(
         input_size,
@@ -110,19 +107,27 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
     optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                      parameters=model.parameters())
+    if enable_autocast:
+        scaler = paddle.amp.GradScaler()
     loss_ = []
     param_ = []
     grad_ = []
     for step in range(10):
         x_data = np.random.randn(batch_size, input_size).astype(np.float32)
         x = paddle.to_tensor(x_data)
         # x.stop_gradient = False
-        y_pred = model(x)
-        loss = y_pred.mean()
-        loss_.append(np.asarray(loss).tolist())
-        loss.backward()
-        optimizer.step()
+        with paddle.amp.auto_cast(True):
+            y_pred = model(x)
+            loss = y_pred.mean()
+        if enable_autocast:
+            scaler.scale(loss).backward()
+            scaler.minimize(optimizer, loss)
+        else:
+            loss_.append(np.asarray(loss).tolist())
+            loss.backward()
+            optimizer.step()
         param_.append(np.asarray(model.parameters()[9]).tolist())
         grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())
@@ -138,25 +143,57 @@ class TestPyLayer(unittest.TestCase):
             self.assertEqual(param_ref, param)
             self.assertEqual(grad_ref, grad)
-        cuda_state = paddle.get_cuda_rng_state()
+        # without recompute
+        loss_ref, param_ref, grad_ref = run_model(recompute_block=[])
+        # recompute second block
+        loss, param, grad = run_model(recompute_block=[1])
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+        # recompute fourth block
+        loss, param, grad = run_model(recompute_block=[3])
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+        # recompute second to fourth block
+        loss, param, grad = run_model(recompute_block=[1, 2, 3])
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+        # recompute second & fourth block
+        loss, param, grad = run_model(recompute_block=[1, 3])
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+    def test_fc_net_without_restore_rng(self):
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[2],
+            recompute_kwargs={"preserve_rng_state": False},
+            enable_autocast=True)
+    def test_fc_net_with_amp(self):
+        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
+            self.assertEqual(loss_ref, loss)
+            self.assertEqual(param_ref, param)
+            self.assertEqual(grad_ref, grad)
         # without recompute
         loss_ref, param_ref, grad_ref = run_model(
-            cuda_state, recompute_block=[])
+            recompute_block=[], enable_autocast=True)
         # recompute second block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
+        loss, param, grad = run_model(recompute_block=[1], enable_autocast=True)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
         # recompute fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[3])
+        loss, param, grad = run_model(recompute_block=[3], enable_autocast=True)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
         # recompute second to fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 2, 3])
+        loss, param, grad = run_model(
+            recompute_block=[1, 2, 3], enable_autocast=True)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
         # recompute second & fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
+        loss, param, grad = run_model(
+            recompute_block=[1, 3], enable_autocast=True)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
     def test_recompute_kwargs(self):
@@ -164,12 +201,12 @@ class TestPyLayer(unittest.TestCase):
         kwargs = {"is_test": False}
         with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
-                None, recompute_block=[2], recompute_kwargs=kwargs)
+                recompute_block=[2], recompute_kwargs=kwargs)
     def test_recompute_cpu_rng(self):
         paddle.set_device("cpu")
         with self.assertRaises(RuntimeError):
-            loss_ref, param_ref, grad_ref = run_model(None, recompute_block=[2])
+            loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
 if __name__ == '__main__':