提交 20e579ef 编写于 作者: X xuezhong

Add an `initial_accumulator_value` parameter (default 0.1) to AdagradOptimizer, used to initialize the moment accumulator

test=develop
上级 c1092374
...@@ -662,7 +662,8 @@ class AdagradOptimizer(Optimizer): ...@@ -662,7 +662,8 @@ class AdagradOptimizer(Optimizer):
learning_rate, learning_rate,
epsilon=1.0e-6, epsilon=1.0e-6,
regularization=None, regularization=None,
name=None): name=None,
initial_accumulator_value=0.1):
assert learning_rate is not None assert learning_rate is not None
assert epsilon is not None assert epsilon is not None
super(AdagradOptimizer, self).__init__( super(AdagradOptimizer, self).__init__(
...@@ -671,6 +672,7 @@ class AdagradOptimizer(Optimizer): ...@@ -671,6 +672,7 @@ class AdagradOptimizer(Optimizer):
name=name) name=name)
self.type = "adagrad" self.type = "adagrad"
self._epsilon = epsilon self._epsilon = epsilon
self.initial_accumulator_value = initial_accumulator_value
def _create_accumulators(self, block, parameters): def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
...@@ -683,6 +685,16 @@ class AdagradOptimizer(Optimizer): ...@@ -683,6 +685,16 @@ class AdagradOptimizer(Optimizer):
moment_acc = self._get_accumulator(self._moment_acc_str, moment_acc = self._get_accumulator(self._moment_acc_str,
param_and_grad[0]) param_and_grad[0])
startup_block = framework.default_startup_program().global_block()
startup_block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [moment_acc]},
attrs={
'dtype': moment_acc.dtype,
'value': self.initial_accumulator_value,
'shape': moment_acc.shape,
})
# Create the adagrad optimizer op # Create the adagrad optimizer op
adagrad_op = block.append_op( adagrad_op = block.append_op(
......
...@@ -274,7 +274,7 @@ class TestAdagradOptimizer(unittest.TestCase): ...@@ -274,7 +274,7 @@ class TestAdagradOptimizer(unittest.TestCase):
# Check init_program # Check init_program
init_ops = init_program.global_block().ops init_ops = init_program.global_block().ops
self.assertEqual(len(init_ops), 2) self.assertEqual(len(init_ops), 3)
self.assertEqual(init_ops[0].type, "fill_constant") self.assertEqual(init_ops[0].type, "fill_constant")
self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate) self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
self.assertEqual(init_ops[1].type, "fill_constant") self.assertEqual(init_ops[1].type, "fill_constant")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册