提交 c10da26c 编写于 作者: S sidgoyal78

Modify implementation

上级 d28b3094
......@@ -57,25 +57,30 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
MomentumOpMaker(framework::OpProto *proto,
                framework::OpAttrChecker *op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  // Register the operator's inputs, outputs and attribute with the
  // framework so they can be checked and documented automatically.
  // NOTE(review): the previous revision duplicated the pre-change
  // registrations (Param/Grad/... were added twice) and mixed two
  // contradictory update rules in the DOC string; only the updated
  // registrations and the rule actually implemented by the kernel
  // (velocity = mu*velocity + grad; param = param - lr*velocity) are kept.
  AddInput("Param",
           "(Tensor, default Tensor<float>) "
           "Input parameter that has to be updated");
  AddInput("Grad",
           "(Tensor, default Tensor<float>) "
           "Input gradient of the parameter");
  AddInput("Velocity",
           "(Tensor, default Tensor<float>) "
           "Input velocity (corresponding to the parameter) "
           "that has to be updated");
  AddInput("LearningRate",
           "(Tensor, default Tensor<float>) "
           "Input learning rate");
  AddOutput("ParamOut", "(Tensor) Output updated parameter");
  AddOutput("VelocityOut", "(Tensor) Output updated velocity");
  AddAttr<float>("mu", "(float) Momentum coefficient");
  AddComment(R"DOC(
Momentum Algorithm (momentum).

velocity = mu * velocity + gradient
param = param - learning_rate * velocity

)DOC");
}
......
......@@ -36,16 +36,16 @@ class MomentumOpKernel : public framework::OpKernel<T> {
float mu = ctx.Attr<float>("mu");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
auto v = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity"));
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
auto param = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto grad = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
auto velocity = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity"));
float learning_rate = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
auto p_out = EigenVector<T>::Flatten(*param_out);
auto v_out = EigenVector<T>::Flatten(*velocity_out);
auto place = ctx.GetEigenDevice<Place>();
v_out.device(place) = mu * v - lr * g;
p_out.device(place) = p + v_out;
v_out.device(place) = velocity * mu + grad;
p_out.device(place) = param - learning_rate * v_out;
}
};
......
......@@ -22,8 +22,8 @@ class TestMomentumOp(OpTest):
self.attrs = {'mu': mu}
velocity_out = mu * velocity - learning_rate * grad
param_out = param + velocity_out
velocity_out = mu * velocity + grad
param_out = param - learning_rate * velocity_out
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册