From 3e2a3488865b5af7ab1d3d90efa7bd1331aca0bf Mon Sep 17 00:00:00 2001
From: mapingshuo
Date: Tue, 28 Jul 2020 15:31:42 +0800
Subject: [PATCH] add string variable support for RecomputeOptimizer (#25728)

---
 python/paddle/fluid/optimizer.py              | 26 ++++++++++++++++---
 .../tests/unittests/test_dgc_optimizer.py     |  1 +
 .../fluid/tests/unittests/test_optimizer.py   | 17 ++++++++++++
 3 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 6e7a90e44e5..85d07f687e3 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import numpy as np
+import six
 import logging
 from collections import defaultdict
 
@@ -4554,6 +4555,17 @@ class RecomputeOptimizer(Optimizer):
         self._learning_rate_map = self._optimizer._learning_rate_map
 
     def _set_checkpoints(self, checkpoints):
+        """
+        Args:
+            checkpoints (list): List of Variable or string
+        """
+        assert isinstance(
+            checkpoints, list
+        ), "_checkpoints should be a list of Variable or a list of String"
+        for ckpt in checkpoints:
+            assert (
+                isinstance(ckpt, six.string_types) or isinstance(ckpt, Variable)
+            ), "_checkpoints should be a list of Variable or a list of String"
         self._checkpoints = checkpoints
 
     def load(self, stat_dict):
@@ -4690,6 +4702,8 @@ class RecomputeOptimizer(Optimizer):
                     no_grad_set=None)
                 print("Finished backward")
         """
+        assert (self._checkpoints is not None
+                ), "You should call _set_checkpoints first"
 
         if framework.in_dygraph_mode():
             raise NotImplementedError(
@@ -4698,11 +4712,15 @@
         self._dtype = loss.dtype
         program = loss.block.program
         with program_guard(program, startup_program):
+            checkpoint_vars = []
+            for ckpt in self._checkpoints:
+                if isinstance(ckpt, Variable):
+                    checkpoint_vars.append(ckpt)
+                else:
+                    checkpoint_vars.append(loss.block.var(ckpt))
+
             params_grads = append_backward(
-                loss,
-                parameter_list,
-                no_grad_set,
-                checkpoints=self._checkpoints)
+                loss, parameter_list, no_grad_set, checkpoints=checkpoint_vars)
             # Note: since we can't use all_reduce_op now,
             # dgc_op should be the last op of one grad.
             if hasattr(self._optimizer, "_append_dgc_ops"):
diff --git a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py
index 70fd4653b48..49b93e0dfaa 100644
--- a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py
@@ -77,6 +77,7 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
         if use_recompute:
             dgc_momentum_optimizer = optimizer.RecomputeOptimizer(
                 dgc_momentum_optimizer)
+            dgc_momentum_optimizer._set_checkpoints([])
         dgc_momentum_optimizer.get_accumulators = dgc_momentum_optimizer._optimizer.get_accumulators
         dgc_momentum_optimizer.get_velocity_str = dgc_momentum_optimizer._optimizer.get_velocity_str
 
diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py
index 3384c80499b..2e6e516aa2e 100644
--- a/python/paddle/fluid/tests/unittests/test_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_optimizer.py
@@ -714,6 +714,23 @@ class TestRecomputeOptimizer(unittest.TestCase):
             "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
         ])
 
+    def test_str_checkpoints(self):
+        mul_out, b1_out, b2_out, mean_out = self.net()
+        self.assertEqual(len(mean_out.block.ops), 4)
+        self.assertEqual([op.type for op in mean_out.block.ops],
+                         ["mul", "elementwise_add", "elementwise_add", "mean"])
+        sgd_optimizer = optimizer.SGD(learning_rate=1.0)
+        recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
+        recompute_optimizer._set_checkpoints([b1_out.name])
+        opts, params_grads = recompute_optimizer.minimize(mean_out)
+
+        self.assertEqual(len(mean_out.block.ops), 13)
+        self.assertEqual([op.type for op in mean_out.block.ops], [
+            "mul", "elementwise_add", "elementwise_add", "mean",
+            "fill_constant", "mean_grad", "elementwise_add_grad", "mul",
+            "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
+        ])
+
     def test_multi_checkpoint(self):
         mul_out, b1_out, b2_out, mean_out = self.net()
         self.assertEqual(len(mean_out.block.ops), 4)
-- 
GitLab
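
Reviewer note: a minimal usage sketch of what this patch enables. It assumes the 1.8-era paddle.fluid static-graph API; the network below (x, fc_1, prediction, loss) is hypothetical and not taken from the patch — only the checkpoint lines exercise the new behavior. The point is that _set_checkpoints now accepts variable names (strings) as well as Variable objects, and backward() resolves each name to a Variable via loss.block.var(name).

    # Hypothetical example network; assumes paddle.fluid 1.8-era APIs.
    import paddle.fluid as fluid
    from paddle.fluid import optimizer

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.data(name="x", shape=[None, 32], dtype="float32")
        label = fluid.data(name="label", shape=[None, 1], dtype="int64")
        fc_1 = fluid.layers.fc(input=x, size=64, act="relu")
        prediction = fluid.layers.fc(input=fc_1, size=2, act="softmax")
        loss = fluid.layers.mean(
            fluid.layers.cross_entropy(input=prediction, label=label))

        sgd = optimizer.SGD(learning_rate=0.01)
        recompute = optimizer.RecomputeOptimizer(sgd)
        # Variables and names can now be mixed in one list; string entries
        # are resolved with loss.block.var(name) when backward() runs.
        recompute._set_checkpoints([fc_1, prediction.name])
        recompute.minimize(loss)

Accepting names decouples checkpoint configuration from layer construction: a checkpoint can come from a plain string (e.g. a config file) without holding a reference to the Variable itself, which is what the new test_str_checkpoints unit test covers by passing [b1_out.name].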