diff --git a/paddle/operators/momentum_op.cc b/paddle/operators/momentum_op.cc index 9be4d15a43d87ae1a27c81498e8b19b0049a3bfa..2d4d6f13720f0e6888edbddcb3243116506227ba 100644 --- a/paddle/operators/momentum_op.cc +++ b/paddle/operators/momentum_op.cc @@ -75,12 +75,17 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("VelocityOut", "(Tensor) Output updated velocity"); AddAttr("mu", "(float) Momentum coefficient"); + AddAttr("useNesterov", "(bool) Use Nesterov Momentum") + .SetDefault(false); AddComment(R"DOC( -Momentum Algorithm (momentum). +Momentum Algorithm with a flag for Nestrov Moemntum (momentum). velocity = mu * velocity + gradient -param = param - learning_rate * velocity +if (use_nesterov): + param = param - gradient * learning_rate + mu * velocity * learning_rate +else: + param = param - learning_rate * velocity )DOC"); } diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h index f7a724f048782ceee8509ddafcb4834fd8dbba8a..e6d6d1da3df9f7e43a93fcc2e12658a01a491f81 100644 --- a/paddle/operators/momentum_op.h +++ b/paddle/operators/momentum_op.h @@ -34,6 +34,7 @@ class MomentumOpKernel : public framework::OpKernel { velocity_out->mutable_data(ctx.GetPlace()); float mu = ctx.Attr("mu"); + bool use_nesterov = ctx.Attr("useNesterov"); auto p_out = framework::EigenVector::Flatten(*param_out); auto v_out = framework::EigenVector::Flatten(*velocity_out); @@ -46,8 +47,14 @@ class MomentumOpKernel : public framework::OpKernel { auto place = ctx.GetEigenDevice(); Eigen::DSizes grad_dsize(grad->numel()); + v_out.device(place) = v * mu + g; - p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out; + if (use_nesterov) { + p_out.device(place) = p - g * lr.broadcast(grad_dsize) + + v_out * mu * lr.broadcast(grad_dsize); + } else { + p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out; + } } }; diff --git a/python/paddle/v2/framework/tests/test_momentum_op.py b/python/paddle/v2/framework/tests/test_momentum_op.py index d3353ff6e4f4da32eaefdd4e816a621ddac8bece..654d31975aab4578055e7e70ade202bd2c3d93cb 100644 --- a/python/paddle/v2/framework/tests/test_momentum_op.py +++ b/python/paddle/v2/framework/tests/test_momentum_op.py @@ -3,7 +3,7 @@ import numpy as np from op_test import OpTest -class TestMomentumOp(OpTest): +class TestMomentumOp1(OpTest): def setUp(self): self.op_type = "momentum" @@ -12,6 +12,7 @@ class TestMomentumOp(OpTest): velocity = np.zeros((123, 321)).astype("float32") learning_rate = np.array([0.001]).astype("float32") mu = 0.0001 + use_nesterov = False self.inputs = { 'Param': param, @@ -23,7 +24,47 @@ class TestMomentumOp(OpTest): self.attrs = {'mu': mu} velocity_out = mu * velocity + grad - param_out = param - learning_rate * velocity_out + if use_nesterov: + param_out = param - grad * learning_rate + \ + velocity_out * mu * learning_rate + else: + param_out = param - learning_rate * velocity_out + + self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} + + def test_check_output(self): + self.check_output() + + +class TestMomentumOp2(OpTest): + '''Test Momentum with defaukt values for attributes + ''' + + def setUp(self): + self.op_type = "momentum" + + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + velocity = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.001]).astype("float32") + mu = 0.0001 + use_nesterov = True + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Velocity': velocity, + 'LearningRate': learning_rate + } + + self.attrs = {'mu': mu, 'useNesterov': use_nesterov} + + velocity_out = mu * velocity + grad + if use_nesterov: + param_out = param - grad * learning_rate + \ + velocity_out * mu * learning_rate + else: + param_out = param - learning_rate * velocity_out self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py index 3e5ff733e9b55fe8c9727e9721e25083a494be15..237bcfccceee89f62fc05e4c6c972a76d1875367 100644 --- a/python/paddle/v2/framework/tests/test_rmsprop_op.py +++ b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -46,7 +46,7 @@ class TestRmspropOp1(OpTest): class TestRmspropOp2(OpTest): - '''Test RMSProp with defaukt values for attributes + '''Test RMSProp with default values for attributes ''' def setUp(self):