提交 c10da26c 编写于 作者: S sidgoyal78

Modify implementation

上级 d28b3094
...@@ -57,25 +57,30 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -57,25 +57,30 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
// Declares the inputs, outputs, attribute, and documentation of the
// momentum optimizer op. (Reconstructed post-commit version of this
// constructor; the source was a side-by-side diff rendering.)
MomentumOpMaker(framework::OpProto *proto,
                framework::OpAttrChecker *op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  // Inputs: the parameter being optimized, its gradient, the running
  // velocity buffer, and the (scalar tensor) learning rate.
  AddInput("Param",
           "(Tensor, default Tensor<float>) "
           "Input parameter that has to be updated");
  AddInput("Grad",
           "(Tensor, default Tensor<float>) "
           "Input gradient of the parameter");
  AddInput("Velocity",
           "(Tensor, default Tensor<float>) "
           "Input velocity (corresponding to the parameter) "
           "that has to be updated");
  AddInput("LearningRate",
           "(Tensor, default Tensor<float>) "
           "Input learning rate");
  // Outputs: updated parameter and updated velocity.
  AddOutput("ParamOut", "(Tensor) Output updated parameter");
  AddOutput("VelocityOut", "(Tensor) Output updated velocity");
  AddAttr<float>("mu", "(float) Momentum coefficient");
  AddComment(R"DOC(
Momentum Algorithm (momentum).

velocity = mu * velocity + gradient
param = param - learning_rate * velocity

Ref: Sutskever, Ilya, et al. "On the importance of initialization
and momentum in deep learning." ICML 2013;
http://jmlr.org/proceedings/papers/v28/sutskever13.pdf
)DOC");
}
......
...@@ -36,16 +36,16 @@ class MomentumOpKernel : public framework::OpKernel<T> { ...@@ -36,16 +36,16 @@ class MomentumOpKernel : public framework::OpKernel<T> {
float mu = ctx.Attr<float>("mu"); float mu = ctx.Attr<float>("mu");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param")); auto param = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad")); auto grad = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
auto v = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity")); auto velocity = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity"));
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0]; float learning_rate = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
auto p_out = EigenVector<T>::Flatten(*param_out); auto p_out = EigenVector<T>::Flatten(*param_out);
auto v_out = EigenVector<T>::Flatten(*velocity_out); auto v_out = EigenVector<T>::Flatten(*velocity_out);
auto place = ctx.GetEigenDevice<Place>(); auto place = ctx.GetEigenDevice<Place>();
v_out.device(place) = mu * v - lr * g; v_out.device(place) = velocity * mu + grad;
p_out.device(place) = p + v_out; p_out.device(place) = param - learning_rate * v_out;
} }
}; };
......
...@@ -22,8 +22,8 @@ class TestMomentumOp(OpTest): ...@@ -22,8 +22,8 @@ class TestMomentumOp(OpTest):
self.attrs = {'mu': mu} self.attrs = {'mu': mu}
velocity_out = mu * velocity - learning_rate * grad velocity_out = mu * velocity + grad
param_out = param + velocity_out param_out = param - learning_rate * velocity_out
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册