From 2e355f032e6b457b1e6f8ddc75ac1b518e0ee831 Mon Sep 17 00:00:00 2001 From: Siddharth Goyal Date: Thu, 9 Nov 2017 12:55:10 -0800 Subject: [PATCH] Fix attribute naming for momentum_op (#5453) * Fix attribute naming for momentum_op * Fix minor typo in comment * Fix attribute name * Fix names in test_optimizer * Fix python wrapper --- paddle/operators/momentum_op.cc | 2 +- paddle/operators/momentum_op.h | 2 +- python/paddle/v2/framework/optimizer.py | 2 +- python/paddle/v2/framework/tests/test_momentum_op.py | 4 ++-- python/paddle/v2/framework/tests/test_optimizer.py | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/operators/momentum_op.cc b/paddle/operators/momentum_op.cc index e8ce16f4cfc..19954006195 100644 --- a/paddle/operators/momentum_op.cc +++ b/paddle/operators/momentum_op.cc @@ -75,7 +75,7 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("VelocityOut", "(Tensor) Output updated velocity"); AddAttr("mu", "(float) Momentum coefficient"); - AddAttr("useNesterov", + AddAttr("use_nesterov", "(bool, default false) " "Use Nesterov Momentum") .SetDefault(false); diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h index e6d6d1da3df..8f7f5eb5c21 100644 --- a/paddle/operators/momentum_op.h +++ b/paddle/operators/momentum_op.h @@ -34,7 +34,7 @@ class MomentumOpKernel : public framework::OpKernel { velocity_out->mutable_data(ctx.GetPlace()); float mu = ctx.Attr("mu"); - bool use_nesterov = ctx.Attr("useNesterov"); + bool use_nesterov = ctx.Attr("use_nesterov"); auto p_out = framework::EigenVector::Flatten(*param_out); auto v_out = framework::EigenVector::Flatten(*velocity_out); diff --git a/python/paddle/v2/framework/optimizer.py b/python/paddle/v2/framework/optimizer.py index f20865d604f..5b4cdecf2c4 100644 --- a/python/paddle/v2/framework/optimizer.py +++ b/python/paddle/v2/framework/optimizer.py @@ -297,7 +297,7 @@ class MomentumOptimizer(Optimizer): "VelocityOut": velocity_acc }, attrs={"mu": self._momentum, - "useNesterov": self._use_nesterov}) + "use_nesterov": self._use_nesterov}) return momentum_op diff --git a/python/paddle/v2/framework/tests/test_momentum_op.py b/python/paddle/v2/framework/tests/test_momentum_op.py index 654d31975aa..638095f7564 100644 --- a/python/paddle/v2/framework/tests/test_momentum_op.py +++ b/python/paddle/v2/framework/tests/test_momentum_op.py @@ -37,7 +37,7 @@ class TestMomentumOp1(OpTest): class TestMomentumOp2(OpTest): - '''Test Momentum with defaukt values for attributes + '''Test Momentum with default values for attributes ''' def setUp(self): @@ -57,7 +57,7 @@ class TestMomentumOp2(OpTest): 'LearningRate': learning_rate } - self.attrs = {'mu': mu, 'useNesterov': use_nesterov} + self.attrs = {'mu': mu, 'use_nesterov': use_nesterov} velocity_out = mu * velocity + grad if use_nesterov: diff --git a/python/paddle/v2/framework/tests/test_optimizer.py b/python/paddle/v2/framework/tests/test_optimizer.py index 9333df8f7f3..a39e7402600 100644 --- a/python/paddle/v2/framework/tests/test_optimizer.py +++ b/python/paddle/v2/framework/tests/test_optimizer.py @@ -98,7 +98,7 @@ class TestMomentumOptimizer(unittest.TestCase): self.assertEqual(len(opts), 1) sgd_op = opts[0] self.assertEqual(sgd_op.type, "momentum") - self.assertFalse(sgd_op.attr('useNesterov')) + self.assertFalse(sgd_op.attr('use_nesterov')) # Check accumulators accumulators = momentum_optimizer.get_accumulators() @@ -143,7 +143,7 @@ class TestMomentumOptimizer(unittest.TestCase): self.assertEqual(len(opts), 1) sgd_op = opts[0] self.assertEqual(sgd_op.type, "momentum") - self.assertTrue(sgd_op.attr('useNesterov')) + self.assertTrue(sgd_op.attr('use_nesterov')) # Check accumulators accumulators = momentum_optimizer.get_accumulators() -- GitLab