提交 5380a547 编写于 作者: K kavyasrinet 提交者: GitHub

Adding Nesterov Momentum (#4948)

上级 23785584
......@@ -75,12 +75,17 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("VelocityOut", "(Tensor) Output updated velocity");
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("useNesterov", "(bool) Use Nesterov Momentum")
.SetDefault(false);
AddComment(R"DOC(
Momentum Algorithm (momentum).
Momentum Algorithm with a flag for Nestrov Moemntum (momentum).
velocity = mu * velocity + gradient
param = param - learning_rate * velocity
if (use_nesterov):
param = param - gradient * learning_rate + mu * velocity * learning_rate
else:
param = param - learning_rate * velocity
)DOC");
}
......
......@@ -34,6 +34,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
velocity_out->mutable_data<T>(ctx.GetPlace());
float mu = ctx.Attr<float>("mu");
bool use_nesterov = ctx.Attr<bool>("useNesterov");
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
......@@ -46,8 +47,14 @@ class MomentumOpKernel : public framework::OpKernel<T> {
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
v_out.device(place) = v * mu + g;
p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out;
if (use_nesterov) {
p_out.device(place) = p - g * lr.broadcast(grad_dsize) +
v_out * mu * lr.broadcast(grad_dsize);
} else {
p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out;
}
}
};
......
......@@ -3,7 +3,7 @@ import numpy as np
from op_test import OpTest
class TestMomentumOp(OpTest):
class TestMomentumOp1(OpTest):
def setUp(self):
self.op_type = "momentum"
......@@ -12,6 +12,7 @@ class TestMomentumOp(OpTest):
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
mu = 0.0001
use_nesterov = False
self.inputs = {
'Param': param,
......@@ -23,7 +24,47 @@ class TestMomentumOp(OpTest):
self.attrs = {'mu': mu}
velocity_out = mu * velocity + grad
param_out = param - learning_rate * velocity_out
if use_nesterov:
param_out = param - grad * learning_rate + \
velocity_out * mu * learning_rate
else:
param_out = param - learning_rate * velocity_out
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def test_check_output(self):
self.check_output()
class TestMomentumOp2(OpTest):
'''Test Momentum with defaukt values for attributes
'''
def setUp(self):
self.op_type = "momentum"
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
mu = 0.0001
use_nesterov = True
self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate
}
self.attrs = {'mu': mu, 'useNesterov': use_nesterov}
velocity_out = mu * velocity + grad
if use_nesterov:
param_out = param - grad * learning_rate + \
velocity_out * mu * learning_rate
else:
param_out = param - learning_rate * velocity_out
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
......
......@@ -46,7 +46,7 @@ class TestRmspropOp1(OpTest):
class TestRmspropOp2(OpTest):
'''Test RMSProp with defaukt values for attributes
'''Test RMSProp with default values for attributes
'''
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册