diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py old mode 100644 new mode 100755 index e58c8aa1625ddecc6f80810d0266958a75ea4956..78503baf2fd5d2833e557a8d4e2f7271545aeca7 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -97,10 +97,12 @@ class RecomputeFunction(PyLayer): ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state() # TODO support AMP + tracer = framework._dygraph_tracer() + ctx.is_fw_autocast = tracer._enable_autocast + ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): outputs = run_function(*args) - return outputs @staticmethod @@ -119,15 +121,23 @@ class RecomputeFunction(PyLayer): tracer = framework._dygraph_tracer() tracer._has_grad = True - # TODO support AMP - + # NOTE support AMP + # need restore auto_cast state as well as w/b list if ctx.preserve_rng_state: with swith_rng_state(ctx.fw_cuda_rng_state): + with paddle.amp.auto_cast( + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs) + else: + with paddle.amp.auto_cast( + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) - else: - detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, core.VarBase): outputs = (outputs, ) @@ -155,7 +165,6 @@ class RecomputeFunction(PyLayer): grads = list(inp._grad_ivar() for inp in detached_inputs if isinstance(inp, core.VarBase)) - return grads diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py index 6de04c14bfa7080bcbf5e3b4c55f98da0f09a863..332603b812955000b4a58d31fd14b21225a9a0c8 100755 --- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py @@ -92,15 +92,12 @@ class Naive_fc_net(paddle.nn.Layer): return inputs -def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): +def run_model(recompute_block=[], recompute_kwargs={}, enable_autocast=False): gen = paddle.seed(10) gen.manual_seed(10) np.random.seed(10) random.seed(10) - if cuda_state: - paddle.set_cuda_rng_state(cuda_state) - batch_size, input_size = 1, 10 model = Naive_fc_net( input_size, @@ -110,19 +107,27 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) + if enable_autocast: + scaler = paddle.amp.GradScaler() + loss_ = [] param_ = [] grad_ = [] for step in range(10): + x_data = np.random.randn(batch_size, input_size).astype(np.float32) x = paddle.to_tensor(x_data) # x.stop_gradient = False - y_pred = model(x) - loss = y_pred.mean() - - loss_.append(np.asarray(loss).tolist()) - loss.backward() - optimizer.step() + with paddle.amp.auto_cast(True): + y_pred = model(x) + loss = y_pred.mean() + if enable_autocast: + scaler.scale(loss).backward() + scaler.minimize(optimizer, loss) + else: + loss_.append(np.asarray(loss).tolist()) + loss.backward() + optimizer.step() param_.append(np.asarray(model.parameters()[9]).tolist()) grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) @@ -138,25 +143,57 @@ class TestPyLayer(unittest.TestCase): self.assertEqual(param_ref, param) self.assertEqual(grad_ref, grad) - cuda_state = paddle.get_cuda_rng_state() + # without recompute + loss_ref, param_ref, grad_ref = run_model(recompute_block=[]) + + # recompute second block + loss, param, grad = run_model(recompute_block=[1]) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute fourth block + loss, param, grad = run_model(recompute_block=[3]) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second to fourth block + loss, param, grad = run_model(recompute_block=[1, 2, 3]) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second & fourth block + loss, param, grad = run_model(recompute_block=[1, 3]) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + def test_fc_net_without_restore_rng(self): + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[2], + recompute_kwargs={"preserve_rng_state": False}, + enable_autocast=True) + + def test_fc_net_with_amp(self): + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + # without recompute loss_ref, param_ref, grad_ref = run_model( - cuda_state, recompute_block=[]) + recompute_block=[], enable_autocast=True) # recompute second block - loss, param, grad = run_model(cuda_state, recompute_block=[1, 3]) + loss, param, grad = run_model(recompute_block=[1], enable_autocast=True) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) # recompute fourth block - loss, param, grad = run_model(cuda_state, recompute_block=[3]) + loss, param, grad = run_model(recompute_block=[3], enable_autocast=True) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) # recompute second to fourth block - loss, param, grad = run_model(cuda_state, recompute_block=[1, 2, 3]) + loss, param, grad = run_model( + recompute_block=[1, 2, 3], enable_autocast=True) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) # recompute second & fourth block - loss, param, grad = run_model(cuda_state, recompute_block=[1, 3]) + loss, param, grad = run_model( + recompute_block=[1, 3], enable_autocast=True) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) def test_recompute_kwargs(self): @@ -164,12 +201,12 @@ class TestPyLayer(unittest.TestCase): kwargs = {"is_test": False} with self.assertRaises(ValueError): loss_ref, param_ref, grad_ref = run_model( - None, recompute_block=[2], recompute_kwargs=kwargs) + recompute_block=[2], recompute_kwargs=kwargs) def test_recompute_cpu_rng(self): paddle.set_device("cpu") with self.assertRaises(RuntimeError): - loss_ref, param_ref, grad_ref = run_model(None, recompute_block=[2]) + loss_ref, param_ref, grad_ref = run_model(recompute_block=[2]) if __name__ == '__main__':