未验证 提交 a5d13d59 编写于 作者: J Jiawei Wang 提交者: GitHub

Momentum Velocity init in Momentum.__init__() (#29223)

* add lamb optimizer and unittest

* fix momentum resume training

* fix momentum acc
上级 4556ad76
...@@ -294,7 +294,6 @@ class TestMomentumV2(unittest.TestCase): ...@@ -294,7 +294,6 @@ class TestMomentumV2(unittest.TestCase):
def test_momentum(self): def test_momentum(self):
paddle.enable_static() paddle.enable_static()
place = fluid.CPUPlace() place = fluid.CPUPlace()
main = fluid.Program() main = fluid.Program()
with fluid.program_guard(main): with fluid.program_guard(main):
......
...@@ -16,6 +16,8 @@ from .optimizer import Optimizer ...@@ -16,6 +16,8 @@ from .optimizer import Optimizer
from ..fluid import core from ..fluid import core
from ..fluid import framework from ..fluid import framework
from ..fluid.framework import Variable, name_scope from ..fluid.framework import Variable, name_scope
from ..fluid.layer_helper import LayerHelper
import paddle.fluid as fluid
__all__ = ["Momentum"] __all__ = ["Momentum"]
...@@ -105,12 +107,20 @@ class Momentum(Optimizer): ...@@ -105,12 +107,20 @@ class Momentum(Optimizer):
self.type = "momentum" self.type = "momentum"
self._momentum = momentum self._momentum = momentum
self._use_nesterov = bool(use_nesterov) self._use_nesterov = bool(use_nesterov)
if framework.in_dygraph_mode():
self.helper = LayerHelper(self.__class__.__name__)
for p in parameters:
self._add_accumulator(self._velocity_acc_str, p)
else:
all_parameters = fluid.default_main_program().global_block(
).all_parameters()
self.helper = LayerHelper(self.__class__.__name__)
for p in all_parameters:
self._add_accumulator(self._velocity_acc_str, p)
def _create_accumulators(self, block, parameters): def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
# create accumulator in init func, so no implementation here
for p in parameters:
self._add_accumulator(self._velocity_acc_str, p)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册