Unverified commit 2e355f03, authored by Siddharth Goyal, committed by GitHub

Fix attribute naming for momentum_op (#5453)

* Fix attribute naming for momentum_op

* Fix minor typo in comment

* Fix attribute name

* Fix names in test_optimizer

* Fix python wrapper
Parent: c88f98cf
@@ -75,7 +75,7 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("VelocityOut", "(Tensor) Output updated velocity");

     AddAttr<float>("mu", "(float) Momentum coefficient");
-    AddAttr<bool>("useNesterov",
+    AddAttr<bool>("use_nesterov",
                   "(bool, default false) "
                   "Use Nesterov Momentum")
         .SetDefault(false);
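Note: the string registered here is the exact key Python callers must pass, so the rename has to be applied end to end across the maker, the kernel, the wrapper, and the tests. A toy Python sketch (not framework code) of lookup by exact key, including the fallback that .SetDefault(false) provides:

    # Toy illustration: attribute lookup is by exact string key.
    attrs = {"mu": 0.9, "use_nesterov": False}       # what the Python wrapper now sends
    mu = attrs["mu"]
    use_nesterov = attrs.get("use_nesterov", False)  # mirrors .SetDefault(false)
    assert "useNesterov" not in attrs                # the old camelCase key is gone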
@@ -34,7 +34,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
     velocity_out->mutable_data<T>(ctx.GetPlace());

     float mu = ctx.Attr<float>("mu");
-    bool use_nesterov = ctx.Attr<bool>("useNesterov");
+    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

     auto p_out = framework::EigenVector<T>::Flatten(*param_out);
     auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
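For context, the kernel uses this flag to choose between plain and Nesterov momentum. A minimal NumPy sketch of the update, assuming the standard formulation (the helper name momentum_step is hypothetical, not framework API):

    import numpy as np

    def momentum_step(param, grad, velocity, lr, mu, use_nesterov=False):
        # velocity accumulates a decayed sum of past gradients
        velocity_out = mu * velocity + grad
        if use_nesterov:
            # look-ahead step: apply the gradient plus the discounted new velocity
            param_out = param - (grad + mu * velocity_out) * lr
        else:
            param_out = param - lr * velocity_out
        return param_out, velocity_out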
@@ -297,7 +297,7 @@ class MomentumOptimizer(Optimizer):
                 "VelocityOut": velocity_acc
             },
             attrs={"mu": self._momentum,
-                   "useNesterov": self._use_nesterov})
+                   "use_nesterov": self._use_nesterov})

         return momentum_op
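With the wrapper fixed, user code selects Nesterov momentum through the optimizer constructor. A hedged usage sketch, with the import path and constructor signature assumed from the tests touched by this commit:

    import paddle.v2.framework.optimizer as optimizer

    momentum_optimizer = optimizer.MomentumOptimizer(
        learning_rate=0.01, momentum=0.2, use_nesterov=True)
    # the appended "momentum" op now carries the snake_case attrs:
    # {"mu": 0.2, "use_nesterov": True}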
@@ -37,7 +37,7 @@ class TestMomentumOp1(OpTest):


 class TestMomentumOp2(OpTest):
-    '''Test Momentum with defaukt values for attributes
+    '''Test Momentum with default values for attributes
     '''

     def setUp(self):
@@ -57,7 +57,7 @@ class TestMomentumOp2(OpTest):
             'LearningRate': learning_rate
         }

-        self.attrs = {'mu': mu, 'useNesterov': use_nesterov}
+        self.attrs = {'mu': mu, 'use_nesterov': use_nesterov}

         velocity_out = mu * velocity + grad
         if use_nesterov:
@@ -98,7 +98,7 @@ class TestMomentumOptimizer(unittest.TestCase):
         self.assertEqual(len(opts), 1)
         sgd_op = opts[0]
         self.assertEqual(sgd_op.type, "momentum")
-        self.assertFalse(sgd_op.attr('useNesterov'))
+        self.assertFalse(sgd_op.attr('use_nesterov'))

         # Check accumulators
         accumulators = momentum_optimizer.get_accumulators()
@@ -143,7 +143,7 @@ class TestMomentumOptimizer(unittest.TestCase):
         self.assertEqual(len(opts), 1)
         sgd_op = opts[0]
         self.assertEqual(sgd_op.type, "momentum")
-        self.assertTrue(sgd_op.attr('useNesterov'))
+        self.assertTrue(sgd_op.attr('use_nesterov'))

         # Check accumulators
         accumulators = momentum_optimizer.get_accumulators()