未验证 提交 2e355f03 编写于 作者: S Siddharth Goyal 提交者: GitHub

Fix attribute naming for momentum_op (#5453)

* Fix attribute naming for momentum_op

* Fix minor typo in comment

* Fix attribute name

* Fix names in test_optimizer

* Fix python wrapper
上级 c88f98cf
......@@ -75,7 +75,7 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("VelocityOut", "(Tensor) Output updated velocity");
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("useNesterov",
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
.SetDefault(false);
......
......@@ -34,7 +34,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
velocity_out->mutable_data<T>(ctx.GetPlace());
float mu = ctx.Attr<float>("mu");
bool use_nesterov = ctx.Attr<bool>("useNesterov");
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
......
......@@ -297,7 +297,7 @@ class MomentumOptimizer(Optimizer):
"VelocityOut": velocity_acc
},
attrs={"mu": self._momentum,
"useNesterov": self._use_nesterov})
"use_nesterov": self._use_nesterov})
return momentum_op
......
......@@ -37,7 +37,7 @@ class TestMomentumOp1(OpTest):
class TestMomentumOp2(OpTest):
'''Test Momentum with defaukt values for attributes
'''Test Momentum with default values for attributes
'''
def setUp(self):
......@@ -57,7 +57,7 @@ class TestMomentumOp2(OpTest):
'LearningRate': learning_rate
}
self.attrs = {'mu': mu, 'useNesterov': use_nesterov}
self.attrs = {'mu': mu, 'use_nesterov': use_nesterov}
velocity_out = mu * velocity + grad
if use_nesterov:
......
......@@ -98,7 +98,7 @@ class TestMomentumOptimizer(unittest.TestCase):
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "momentum")
self.assertFalse(sgd_op.attr('useNesterov'))
self.assertFalse(sgd_op.attr('use_nesterov'))
# Check accumulators
accumulators = momentum_optimizer.get_accumulators()
......@@ -143,7 +143,7 @@ class TestMomentumOptimizer(unittest.TestCase):
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "momentum")
self.assertTrue(sgd_op.attr('useNesterov'))
self.assertTrue(sgd_op.attr('use_nesterov'))
# Check accumulators
accumulators = momentum_optimizer.get_accumulators()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册