Commit a3842494 authored by Abhinav Arora, committed by GitHub

Adding nesterov momentum to python momentum wrapper (#5055)

* Adding nesterov momentum to python momentum wrapper
* Fixing optimizer test after merge
Parent 0760043d
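For context, vanilla momentum and Nesterov momentum differ only in how the parameter step uses the updated velocity. Below is a minimal NumPy sketch of the two update rules that the new `useNesterov` attribute selects between. It follows the common Nesterov approximation; the op's actual kernel formula is not part of this diff, so the exact expression here is an assumption, and `momentum_step` is a name invented for illustration.

import numpy as np

def momentum_step(param, grad, velocity, lr, mu, use_nesterov=False):
    # Both variants accumulate the gradient into the velocity.
    velocity_out = mu * velocity + grad
    if use_nesterov:
        # Nesterov: the step adds a look-ahead along the updated velocity.
        param_out = param - lr * (grad + mu * velocity_out)
    else:
        # Vanilla momentum: the step follows the updated velocity only.
        param_out = param - lr * velocity_out
    return param_out, velocity_out

# Example: one update on a 3-element parameter vector.
p, v = momentum_step(np.ones(3), np.full(3, 0.5), np.zeros(3),
                     lr=0.01, mu=0.2, use_nesterov=True)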
@@ -211,13 +211,14 @@ class MomentumOptimizer(Optimizer):
     """
     _velocity_acc_str = "velocity"

-    def __init__(self, learning_rate, momentum):
+    def __init__(self, learning_rate, momentum, use_nesterov=False):
         assert learning_rate is not None
         assert momentum is not None
         super(MomentumOptimizer, self).__init__()
         self.type = "momentum"
         self._learning_rate = learning_rate
         self._momentum = momentum
+        self._use_nesterov = bool(use_nesterov)

     def _initialize_tensors(self, block):
         assert isinstance(block, framework.Block)
@@ -259,7 +260,8 @@ class MomentumOptimizer(Optimizer):
                 "ParamOut": param_and_grad[0],
                 "VelocityOut": velocity_acc
             },
-            attrs={"mu": self._momentum})
+            attrs={"mu": self._momentum,
+                   "useNesterov": self._use_nesterov})

         return momentum_op
...
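With the constructor change above, enabling the new behavior is a one-argument switch. A hypothetical usage sketch, mirroring the unit test that follows (the variables `params_grads` and `loss` are placeholders, not defined in this diff):

# Hypothetical usage; argument values follow the test below.
momentum_optimizer = MomentumOptimizer(
    learning_rate=0.01, momentum=0.2, use_nesterov=True)
opts = momentum_optimizer.create_optimization_pass(params_grads, loss)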
@@ -36,7 +36,7 @@ class TestMomentumOptimizer(unittest.TestCase):
         def get_velocity_str(self):
             return self._velocity_acc_str

-    def test_momentum_optimizer(self):
+    def test_vanilla_momentum_optimizer(self):
         program = framework.Program()
         block = program.global_block()
         mul_x = block.create_parameter(
@@ -60,6 +60,42 @@ class TestMomentumOptimizer(unittest.TestCase):
         self.assertEqual(len(opts), 1)
         sgd_op = opts[0]
         self.assertEqual(sgd_op.type, "momentum")
+        self.assertFalse(sgd_op.attr('useNesterov'))
+
+        # Check accumulators
+        accumulators = momentum_optimizer.get_accumulators()
+        self.assertEqual(len(accumulators), 1)
+        self.assertTrue(momentum_optimizer.get_velocity_str() in accumulators)
+        velocity_acc = accumulators[momentum_optimizer.get_velocity_str()]
+        self.assertEqual(len(velocity_acc), 1)
+        self.assertTrue(mul_x.name in velocity_acc)
+
+    def test_nesterov_momentum_optimizer(self):
+        program = framework.Program()
+        block = program.global_block()
+        mul_x = block.create_parameter(
+            dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
+        mul_y = block.create_var(
+            dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
+        mul_out = block.create_var(
+            dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
+        block.append_op(
+            type="mul",
+            inputs={"X": mul_x,
+                    "Y": mul_y},
+            outputs={"Out": mul_out},
+            attrs={"x_num_col_dims": 1})
+        momentum_optimizer = self.MockMomentum(
+            learning_rate=0.01, momentum=0.2, use_nesterov=True)
+        params_grads = append_backward_ops(mul_out)
+        self.assertEqual(len(params_grads), 1)
+        self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
+        opts = momentum_optimizer.create_optimization_pass(params_grads,
+                                                           mul_out)
+        self.assertEqual(len(opts), 1)
+        sgd_op = opts[0]
+        self.assertEqual(sgd_op.type, "momentum")
+        self.assertTrue(sgd_op.attr('useNesterov'))
+
         # Check accumulators
         accumulators = momentum_optimizer.get_accumulators()
...