diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h
index fa3788a8abd147abb4934268be9af3a57fbc956b..f7a724f048782ceee8509ddafcb4834fd8dbba8a 100644
--- a/paddle/operators/momentum_op.h
+++ b/paddle/operators/momentum_op.h
@@ -19,33 +19,35 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-
 template <typename Place, typename T>
 class MomentumOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param_out = ctx.Output<Tensor>("ParamOut");
-    auto velocity_out = ctx.Output<Tensor>("VelocityOut");
+    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
+    auto param = ctx.Input<framework::Tensor>("Param");
+    auto velocity = ctx.Input<framework::Tensor>("Velocity");
+    auto grad = ctx.Input<framework::Tensor>("Grad");
+    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
 
     param_out->mutable_data<T>(ctx.GetPlace());
     velocity_out->mutable_data<T>(ctx.GetPlace());
 
     float mu = ctx.Attr<float>("mu");
 
-    auto param = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
-    auto grad = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
-    auto velocity = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity"));
-    float learning_rate = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
-    auto p_out = EigenVector<T>::Flatten(*param_out);
-    auto v_out = EigenVector<T>::Flatten(*velocity_out);
+    auto p_out = framework::EigenVector<T>::Flatten(*param_out);
+    auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
+
+    auto p = framework::EigenVector<T>::Flatten(*param);
+    auto v = framework::EigenVector<T>::Flatten(*velocity);
+    auto g = framework::EigenVector<T>::Flatten(*grad);
+    auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
+
     auto place = ctx.GetEigenDevice<Place>();
 
-    v_out.device(place) = velocity * mu + grad;
-    p_out.device(place) = param - learning_rate * v_out;
+    Eigen::DSizes<int, 1> grad_dsize(grad->numel());
+    v_out.device(place) = v * mu + g;
+    p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out;
   }
 };
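
Note (not part of the diff): the removed line fetched the learning rate on the host via ->data<float>()[0], which is only safe when the tensor lives in host memory; when Place is a GPU place that pointer refers to device memory. The new kernel instead keeps LearningRate as a one-element tensor and broadcasts it across the gradient inside the Eigen expression, so the whole update runs on the compute device. Below is a minimal, self-contained sketch of that broadcast pattern; the tensor sizes and values are illustrative stand-ins, not taken from the PR.

// Sketch of the Eigen broadcast used by the new kernel: a 1-element
// learning-rate tensor is tiled to the gradient's length so the momentum
// update is expressed entirely as device-side Eigen operations.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // Stand-ins for the flattened Param / Velocity / Grad / LearningRate.
  Eigen::Tensor<float, 1> p(4), v(4), g(4), lr(1);
  p.setValues({1.f, 2.f, 3.f, 4.f});
  v.setZero();
  g.setValues({0.1f, 0.1f, 0.1f, 0.1f});
  lr.setConstant(0.01f);
  const float mu = 0.9f;

  // Same shape trick as the kernel: tile the single learning-rate value
  // across all gradient elements before the elementwise product.
  Eigen::DSizes<Eigen::DenseIndex, 1> grad_dsize(g.size());
  Eigen::Tensor<float, 1> v_out = v * mu + g;
  Eigen::Tensor<float, 1> p_out = p - lr.broadcast(grad_dsize) * v_out;

  std::cout << p_out << std::endl;  // expected: 0.999 1.999 2.999 3.999
  return 0;
}

With Eigen's headers on the include path this prints the updated parameters 0.999 1.999 2.999 3.999, i.e. p - lr * (mu * v + g) applied elementwise, matching the two device() assignments in the kernel.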