From 20e579ef2ad9e3afe184ae05ea31ca4b575f810f Mon Sep 17 00:00:00 2001 From: xuezhong Date: Fri, 1 Feb 2019 03:50:46 +0000 Subject: [PATCH] add initial_accumulator_value for adagrad test=develop --- python/paddle/fluid/optimizer.py | 14 +++++++++++++- .../paddle/fluid/tests/unittests/test_optimizer.py | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index e0e781a322..ce5e5c4f37 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -662,7 +662,8 @@ class AdagradOptimizer(Optimizer): learning_rate, epsilon=1.0e-6, regularization=None, - name=None): + name=None, + initial_accumulator_value=0.1): assert learning_rate is not None assert epsilon is not None super(AdagradOptimizer, self).__init__( @@ -671,6 +672,7 @@ class AdagradOptimizer(Optimizer): name=name) self.type = "adagrad" self._epsilon = epsilon + self.initial_accumulator_value = initial_accumulator_value def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) @@ -683,6 +685,16 @@ class AdagradOptimizer(Optimizer): moment_acc = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) + startup_block = framework.default_startup_program().global_block() + startup_block.append_op( + type='fill_constant', + inputs={}, + outputs={'Out': [moment_acc]}, + attrs={ + 'dtype': moment_acc.dtype, + 'value': self.initial_accumulator_value, + 'shape': moment_acc.shape, + }) # Create the adagrad optimizer op adagrad_op = block.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 34c9b7e006..95ddc135b3 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -274,7 +274,7 @@ class TestAdagradOptimizer(unittest.TestCase): # Check init_program init_ops = init_program.global_block().ops - self.assertEqual(len(init_ops), 2) + self.assertEqual(len(init_ops), 3) self.assertEqual(init_ops[0].type, "fill_constant") self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate) self.assertEqual(init_ops[1].type, "fill_constant") -- GitLab