From a5d13d593c1e180a89c023075c4f96c38d65fe1c Mon Sep 17 00:00:00 2001
From: Jiawei Wang
Date: Tue, 1 Dec 2020 10:57:05 +0800
Subject: [PATCH] Momentum Velocity init in Momentum.__init__() (#29223)

* add lamb optimizer and unittest

* fix momentum resume training

* fix momentum acc
---
 .../fluid/tests/unittests/test_momentum_op.py |  1 -
 python/paddle/optimizer/momentum.py           | 16 +++++++++++++---
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py
index 40a1c8def5d..1bb57409b78 100644
--- a/python/paddle/fluid/tests/unittests/test_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -294,7 +294,6 @@ class TestMomentumV2(unittest.TestCase):
 
     def test_momentum(self):
         paddle.enable_static()
-        place = fluid.CPUPlace()
         main = fluid.Program()
         with fluid.program_guard(main):
diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py
index 2cfd8deaef7..601fdce7a34 100644
--- a/python/paddle/optimizer/momentum.py
+++ b/python/paddle/optimizer/momentum.py
@@ -16,6 +16,8 @@ from .optimizer import Optimizer
 from ..fluid import core
 from ..fluid import framework
 from ..fluid.framework import Variable, name_scope
+from ..fluid.layer_helper import LayerHelper
+import paddle.fluid as fluid
 
 __all__ = ["Momentum"]
 
@@ -105,12 +107,20 @@ class Momentum(Optimizer):
         self.type = "momentum"
         self._momentum = momentum
         self._use_nesterov = bool(use_nesterov)
+        if framework.in_dygraph_mode():
+            self.helper = LayerHelper(self.__class__.__name__)
+            for p in parameters:
+                self._add_accumulator(self._velocity_acc_str, p)
+        else:
+            all_parameters = fluid.default_main_program().global_block(
+            ).all_parameters()
+            self.helper = LayerHelper(self.__class__.__name__)
+            for p in all_parameters:
+                self._add_accumulator(self._velocity_acc_str, p)
 
     def _create_accumulators(self, block, parameters):
         assert isinstance(block, framework.Block)
-
-        for p in parameters:
-            self._add_accumulator(self._velocity_acc_str, p)
+        # create accumulator in init func, so no implementation here
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
-- 
GitLab
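
Editor's note: below is a minimal dygraph sketch, separate from the patch itself, of why creating the velocity accumulators in Momentum.__init__() helps resume training. Because the accumulators now exist immediately after construction, they appear in state_dict() before the first step() call, so a checkpointed velocity has a matching target when restored. The layer shape and hyperparameters here are illustrative only, not taken from the patch.

import paddle

paddle.disable_static()  # dygraph mode, exercising the in_dygraph_mode() branch above

# Illustrative model; any layer with trainable parameters would do.
layer = paddle.nn.Linear(13, 1)
opt = paddle.optimizer.Momentum(
    learning_rate=0.01, momentum=0.9, parameters=layer.parameters())

# Velocity accumulators already exist at this point, so optimizer state can
# be saved before any step() has run.
state = opt.state_dict()
print(sorted(state.keys()))  # expect one velocity entry per parameter

# Resuming: a freshly built optimizer can accept the checkpointed state
# directly, rather than silently re-initializing velocity on the first
# post-resume step because the accumulators had not been created yet.
opt2 = paddle.optimizer.Momentum(
    learning_rate=0.01, momentum=0.9, parameters=layer.parameters())
opt2.set_state_dict(state)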