From d0f0a2520c0d865d1dc432c4c9621b9eea8626eb Mon Sep 17 00:00:00 2001
From: zhongpu <2013000149@qq.com>
Date: Thu, 9 Jan 2020 22:06:36 +0800
Subject: [PATCH] test Optimizer in dygraph (#21949)

* test Optimizer in dygraph, test=develop
* add optest for Optimizer in dygraph, test=develop
* fix adagrad optimizer, test=develop
* fix dpsgd optimizer, test=develop
* fix test_optimizer.py, test=develop
* fix dpsgd optimizer, this op only support cpu, test=develop
* add optest for optimizer, test=develop
* add description for dpsgd, test=develop
* add rmsprop to white_list in unused_var_check.cc, test=develop
* polish code style, test=develop
* polish code style, test=develop
* delete seed attribute for DpsgdOptimizer, test=develop
* change testing to debugging, test=develop
---
 paddle/fluid/framework/unused_var_check.cc    |   3 +-
 paddle/fluid/operators/optimizers/dpsgd_op.cc |  10 +
 paddle/fluid/operators/optimizers/dpsgd_op.h  |  10 +-
 python/paddle/fluid/optimizer.py              |  42 ++-
 .../unittests/test_imperative_optimizer.py    | 268 +++++++++++++++++-
 .../fluid/tests/unittests/test_optimizer.py   |   2 +-
 6 files changed, 309 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/framework/unused_var_check.cc b/paddle/fluid/framework/unused_var_check.cc
index 1d433bfbd00..d33e9d8a76a 100644
--- a/paddle/fluid/framework/unused_var_check.cc
+++ b/paddle/fluid/framework/unused_var_check.cc
@@ -57,7 +57,8 @@ const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
     "warpctc_grad",
     "sync_batch_norm",
     "match_matrix_tensor_grad",
-    "ngraph_engine"};
+    "ngraph_engine",
+    "rmsprop"};

 namespace paddle {
 namespace framework {
diff --git a/paddle/fluid/operators/optimizers/dpsgd_op.cc b/paddle/fluid/operators/optimizers/dpsgd_op.cc
index 9a7b2112d4e..3bcf17fc7b3 100644
--- a/paddle/fluid/operators/optimizers/dpsgd_op.cc
+++ b/paddle/fluid/operators/optimizers/dpsgd_op.cc
@@ -83,6 +83,16 @@ class DpsgdOpMaker : public framework::OpProtoAndCheckerMaker {
              "(float, default 1.0e-8) "
              "Constant for numerical stability")
         .SetDefault(1.0f);
+    AddAttr<int>(
+        "seed",
+        "(int, default 0) "
+        "This property is only used for debugging; users do not need to set it. "
+        "Random seed for generating samples. If seed is set to 0, this "
+        "operator will use the "
+        "system's random number seed; otherwise, this operator will always "
+        "generate the same random "
+        "number every time.")
+        .SetDefault(0);
     AddComment(R"DOC(
 Dpsgd Optimizer.

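For reference, a minimal dygraph usage sketch of the seed behavior added above, assuming the fluid APIs of this Paddle release. The Linear layer, the random input, and the reduce_mean loss are illustrative placeholders; the DpsgdOptimizer arguments and the private _seed field mirror the new TestImperativeDpsgdOptimizer test later in this patch. Setting _seed to a nonzero value makes the op reproducible for debugging; leaving it unset falls back to a time(NULL)-based seed.

import numpy
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.optimizer import DpsgdOptimizer

# Dpsgd only provides a CPU kernel, so the sketch pins CPUPlace.
with fluid.dygraph.guard(fluid.CPUPlace()):
    layer = Linear(13, 1)  # placeholder model
    optimizer = DpsgdOptimizer(
        learning_rate=0.01,
        clip=10.0,
        batch_size=16.0,
        sigma=1.0,
        parameter_list=layer.parameters())
    optimizer._seed = 100  # debugging only: forwarded to the op's new "seed" attribute

    x = to_variable(numpy.random.random((16, 13)).astype('float32'))
    loss = fluid.layers.reduce_mean(layer(x))
    loss.backward()
    optimizer.minimize(loss)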
diff --git a/paddle/fluid/operators/optimizers/dpsgd_op.h b/paddle/fluid/operators/optimizers/dpsgd_op.h
index 171691613bb..4eb52feb851 100644
--- a/paddle/fluid/operators/optimizers/dpsgd_op.h
+++ b/paddle/fluid/operators/optimizers/dpsgd_op.h
@@ -79,16 +79,14 @@ class DpsgdOpKernel : public framework::OpKernel<T> {
     float X;
     float mu = 0.0;
     float U1, U2;
-    unsigned seed = (unsigned int)(time(NULL));
+    unsigned seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
+    if (seed == 0) {
+      seed = (unsigned)(time(NULL));
+    }
     std::minstd_rand engine;
     engine.seed(seed);
     std::uniform_real_distribution<float> dist(0.0, 1.0);
     do {
-      // srand((unsigned int)(time(NULL)));
-      // U1 = (rand() * 1.0) / RAND_MAX;
-      // U2 = (rand() * 1.0) / RAND_MAX;
-      // U1 = rand_rr(&seed) * (1.0 / RAND_MAX);
-      // U2 = rand_rr(&seed) * (1.0 / RAND_MAX);
       U1 = dist(engine);
       U2 = dist(engine);
       V1 = 2 * U1 - 1;
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 376985d257b..29a6644f744 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -1030,6 +1030,8 @@ class DGCMomentumOptimizer(Optimizer):
                  num_trainers=None,
                  regularization=None,
                  name=None):
+        if framework.in_dygraph_mode():
+            raise Exception("In dygraph, don't support DGCMomentumOptimizer.")
         assert learning_rate is not None
         assert momentum is not None
         super(DGCMomentumOptimizer, self).__init__(
@@ -1526,24 +1528,16 @@ class AdagradOptimizer(Optimizer):
         assert isinstance(block, framework.Block)

         for p in parameters:
-            self._add_accumulator(self._moment_acc_str, p)
+            self._add_accumulator(
+                self._moment_acc_str,
+                p,
+                fill_value=self.initial_accumulator_value)

     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)

         moment_acc = self._get_accumulator(self._moment_acc_str,
                                            param_and_grad[0])
-        startup_block = framework.default_startup_program().global_block()
-        startup_block.append_op(
-            type='fill_constant',
-            inputs={},
-            outputs={'Out': [moment_acc]},
-            attrs={
-                'dtype': moment_acc.dtype,
-                'value': self.initial_accumulator_value,
-                'shape': moment_acc.shape,
-            })
-
         # Create the adagrad optimizer op
         adagrad_op = block.append_op(
             type=self.type,
@@ -2031,11 +2025,21 @@ class DpsgdOptimizer(Optimizer):
         self._clip = clip
         self._batch_size = batch_size
         self._sigma = sigma
+        '''
+        Note(wangzhongpu):
+        This property is only used for debugging; users do not need to set it.
+        The Dpsgd operator uses time(NULL) as its random seed to generate random numbers.
+        However, during debugging we need deterministic results, so we set self._seed to a fixed number.
+        '''
+        self._seed = None

     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)

         # create the dpsgd optimize op
+        if self._seed == None:
+            self._seed = 0
+
         dpsgd_op = block.append_op(
             type=self.type,
             inputs={
@@ -2047,7 +2051,8 @@
             attrs={
                 "clip": self._clip,
                 "batch_size": self._batch_size,
-                "sigma": self._sigma
+                "sigma": self._sigma,
+                "seed": self._seed
             },
             stop_gradient=True)

@@ -2846,6 +2851,8 @@ class ModelAverage(Optimizer):
                  max_average_window=10000,
                  regularization=None,
                  name=None):
+        if framework.in_dygraph_mode():
+            raise Exception("In dygraph, don't support ModelAverage.")
         super(ModelAverage, self).__init__(
             0.0, regularization=regularization, name=name)
         self.average_window = average_window_rate
@@ -3159,6 +3166,9 @@ class ExponentialMovingAverage(object):
     """

     def __init__(self, decay=0.999, thres_steps=None, name=None):
+        if framework.in_dygraph_mode():
+            raise Exception(
+                "In dygraph, don't support ExponentialMovingAverage.")
         self._decay = decay
         self._thres_steps = thres_steps
         self._name = name if name is not None else ''
@@ -3380,6 +3390,8 @@ class PipelineOptimizer(object):
                  queue_size=30,
                  sync_steps=1,
                  start_cpu_core_id=0):
+        if framework.in_dygraph_mode():
+            raise Exception("In dygraph, don't support PipelineOptimizer.")
         # TODO: check properties
         self._optimizer = optimizer
         self._cut_list = cut_list
@@ -3665,6 +3677,8 @@ class RecomputeOptimizer(Optimizer):
     """

     def __init__(self, optimizer):
+        if framework.in_dygraph_mode():
+            raise Exception("In dygraph, don't support RecomputeOptimizer.")
         self._optimizer = optimizer
         self._checkpoints = None

@@ -3951,6 +3965,8 @@ class LookaheadOptimizer(object):


     def __init__(self, inner_optimizer, alpha=0.5, k=5):
+        if framework.in_dygraph_mode():
+            raise Exception("In dygraph, don't support LookaheadOptimizer.")
         assert (inner_optimizer is not None), "inner optimizer can not be None"
         assert (
             0.0 <= alpha <= 1.0
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
index 217f57fdc82..ac12e79156d 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
@@ -22,11 +22,15 @@ import six
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core
-from paddle.fluid.optimizer import SGDOptimizer, Adam
+from paddle.fluid.optimizer import SGDOptimizer, Adam, MomentumOptimizer, LarsMomentumOptimizer, AdagradOptimizer, AdamaxOptimizer, DpsgdOptimizer, DecayedAdagradOptimizer, AdadeltaOptimizer, RMSPropOptimizer, FtrlOptimizer, LambOptimizer
+from paddle.fluid.optimizer import ModelAverage, DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, RecomputeOptimizer
 from paddle.fluid.dygraph import Linear
 from paddle.fluid.dygraph.base import to_variable
 from test_imperative_base import new_program_scope

+# Note(wangzhongpu)
+# In dygraph, don't support ModelAverage, DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, RecomputeOptimizer.
+

 class MLP(fluid.Layer):
     def __init__(self, param_attr=None, bias_attr=None):
@@ -60,11 +64,32 @@ class TestImperativeOptimizerBase(unittest.TestCase):

         return _reader_imple

-    def _check_mlp(self):
+    def _check_exception(self, exception_message, place=None):
+        seed = 90
+        batch_size = 128
+        if place == None:
+            place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
+            ) else fluid.CPUPlace()
+
+        with fluid.dygraph.guard(place):
+            try:
+                fluid.default_startup_program().random_seed = seed
+                fluid.default_main_program().random_seed = seed
+                mlp = MLP()
+                optimizer = self.get_optimizer_dygraph(
+                    parameter_list=mlp.parameters())
+            except Exception as e:
+                assert str(e) == exception_message
+
+    def _check_mlp(self, place=None):
         seed = 90
         batch_size = 128

-        with fluid.dygraph.guard():
+        if place == None:
+            place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
+            ) else fluid.CUDAPlace(0)
+
+        with fluid.dygraph.guard(place):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed

@@ -109,8 +134,11 @@ class TestImperativeOptimizerBase(unittest.TestCase):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed

-            exe = fluid.Executor(fluid.CPUPlace(
-            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
+            if place == None:
+                place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
+                ) else fluid.CUDAPlace(0)
+
+            exe = fluid.Executor(place)

             mlp = MLP()
             optimizer = self.get_optimizer()
@@ -312,5 +340,235 @@ class TestImperativeOptimizerNoamDecay(TestImperativeOptimizerBase):
         self._check_mlp()


+class TestImperativeMomentumOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = MomentumOptimizer(
+            learning_rate=0.001, momentum=0.9, parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
+        return optimizer
+
+    def test_momentum(self):
+        self._check_mlp()
+
+
+class TestImperativeLarsMomentumOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = LarsMomentumOptimizer(
+            learning_rate=0.001, momentum=0.9, parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = LarsMomentumOptimizer(learning_rate=0.001, momentum=0.9)
+        return optimizer
+
+    def test_larsmomentum(self):
+        self._check_mlp()
+
+
+class TestImperativeAdagradOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = AdagradOptimizer(
+            learning_rate=0.2, parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = AdagradOptimizer(learning_rate=0.2)
+        return optimizer
+
+    def test_adagrad(self):
+        self._check_mlp()
+
+
+class TestImperativeAdamaxOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = AdamaxOptimizer(
+            learning_rate=0.2, parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = AdamaxOptimizer(learning_rate=0.2)
+        return optimizer
+
+    def test_adamax(self):
+        self._check_mlp()
+
+
+class TestImperativeDpsgdOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = DpsgdOptimizer(
+            learning_rate=0.01,
+            clip=10.0,
+            batch_size=16.0,
+            sigma=1.0,
+            parameter_list=parameter_list)
+        optimizer._seed = 100
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = DpsgdOptimizer(
+            learning_rate=0.01, clip=10.0, batch_size=16.0, sigma=1.0)
+        optimizer._seed = 100
+        return optimizer
+
+    def test_dpsgd(self):
+        self._check_mlp(place=fluid.CPUPlace())
+
+
+class TestImperativeDecayedAdagradOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = DecayedAdagradOptimizer(
+            learning_rate=0.2, parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = DecayedAdagradOptimizer(learning_rate=0.2)
+        return optimizer
+
+    def test_decayadagrad(self):
+        self._check_mlp()
+
+
+class TestImperativeAdadeltaOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = AdadeltaOptimizer(
+            learning_rate=0.0003,
+            epsilon=1.0e-6,
+            rho=0.95,
+            parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = AdadeltaOptimizer(
+            learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
+        return optimizer
+
+    def test_adadelta(self):
+        self._check_mlp()
+
+
+class TestImperativeRMSPropOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = RMSPropOptimizer(
+            learning_rate=0.1, parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = RMSPropOptimizer(learning_rate=0.1)
+        return optimizer
+
+    def test_rmsprop(self):
+        self._check_mlp()
+
+
+class TestImperativeFtrlOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = FtrlOptimizer(
+            learning_rate=0.1, parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = FtrlOptimizer(learning_rate=0.1)
+        return optimizer
+
+    def test_ftrl(self):
+        self._check_mlp()
+
+
+def exclude_fn(param):
+    return param.name.endswith('.b_0')
+
+
+class TestImperativeLambOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = LambOptimizer(
+            learning_rate=0.002,
+            exclude_from_weight_decay_fn=exclude_fn,
+            parameter_list=parameter_list)
+        return optimizer
+
+    def get_optimizer(self):
+        optimizer = LambOptimizer(
+            learning_rate=0.002, exclude_from_weight_decay_fn=exclude_fn)
+        return optimizer
+
+    def test_lamb(self):
+        self._check_mlp()
+
+
+class TestImperativeModelAverage(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = ModelAverage(
+            0.15, min_average_window=10000, max_average_window=12500)
+        return optimizer
+
+    def test_modelaverage(self):
+        exception_message = "In dygraph, don't support ModelAverage."
+        self._check_exception(exception_message)
+
+
+class TestImperativeDGCMomentumOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = DGCMomentumOptimizer(
+            learning_rate=0.0001,
+            momentum=0.9,
+            rampup_step=1000,
+            rampup_begin_step=1252,
+            sparsity=[0.999, 0.999])
+        return optimizer
+
+    def test_dgcmomentum(self):
+        exception_message = "In dygraph, don't support DGCMomentumOptimizer."
+        self._check_exception(exception_message)
+
+
+class TestImperativeExponentialMovingAverage(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = ExponentialMovingAverage(0.999)
+        return optimizer
+
+    def test_exponentialmoving(self):
+        exception_message = "In dygraph, don't support ExponentialMovingAverage."
+        self._check_exception(exception_message)
+
+
+class TestImperativePipelineOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = fluid.optimizer.SGD(learning_rate=0.5,
+                                        parameter_list=parameter_list)
+        optimizer = PipelineOptimizer(optimizer)
+        return optimizer
+
+    def test_pipline(self):
+        exception_message = "In dygraph, don't support PipelineOptimizer."
+        self._check_exception(exception_message)
+
+
+class TestImperativeLookaheadOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = fluid.optimizer.SGD(learning_rate=0.5,
+                                        parameter_list=parameter_list)
+        optimizer = LookaheadOptimizer(optimizer, alpha=0.5, k=5)
+        return optimizer
+
+    def test_lookahead(self):
+        exception_message = "In dygraph, don't support LookaheadOptimizer."
+        self._check_exception(exception_message)
+
+
+class TestImperativeRecomputeOptimizer(TestImperativeOptimizerBase):
+    def get_optimizer_dygraph(self, parameter_list):
+        optimizer = fluid.optimizer.SGD(learning_rate=0.5,
+                                        parameter_list=parameter_list)
+        optimizer = RecomputeOptimizer(optimizer)
+        return optimizer
+
+    def test_recompute(self):
+        exception_message = "In dygraph, don't support RecomputeOptimizer."
+        self._check_exception(exception_message)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py
index e74786e1d49..f97c40b6d99 100644
--- a/python/paddle/fluid/tests/unittests/test_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_optimizer.py
@@ -270,7 +270,7 @@ class TestAdagradOptimizer(unittest.TestCase):

         # Check init_program
         init_ops = init_program.global_block().ops
-        self.assertEqual(len(init_ops), 3)
+        self.assertEqual(len(init_ops), 2)
         self.assertEqual(init_ops[0].type, "fill_constant")
         self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
         self.assertEqual(init_ops[1].type, "fill_constant")
-- 
GitLab
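As a companion to the new exception tests above, a minimal sketch of what the dygraph guards do at construction time, assuming the same fluid 1.x APIs used in this patch. The Linear layer is a placeholder, and RecomputeOptimizer stands in for any of the wrapped optimizers (ModelAverage, DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, RecomputeOptimizer) that now raise under a dygraph guard.

import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
from paddle.fluid.optimizer import RecomputeOptimizer

with fluid.dygraph.guard(fluid.CPUPlace()):
    layer = Linear(10, 10)  # placeholder model
    sgd = fluid.optimizer.SGD(learning_rate=0.5,
                              parameter_list=layer.parameters())
    try:
        RecomputeOptimizer(sgd)
    except Exception as e:
        # expected: "In dygraph, don't support RecomputeOptimizer."
        print(e)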