From 775c60246b66469e06f01a50c89b7b39594a3b63 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Wed, 4 Oct 2017 16:53:21 -0700
Subject: [PATCH] remove using in sgd header file

---
 paddle/operators/sgd_op.h | 27 ++++++++++-----------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index 954fd4827..b501d244d 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -19,32 +19,25 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
-
 template <typename Place, typename T>
 class SGDOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param = ctx.Input<Tensor>("Param");
-    auto grad = ctx.Input<Tensor>("Grad");
-    auto param_out = ctx.Output<Tensor>("ParamOut");
-    auto learning_rate = ctx.Input<Tensor>("LearningRate");
+    auto param = ctx.Input<framework::Tensor>("Param");
+    auto grad = ctx.Input<framework::Tensor>("Grad");
+    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
 
     param_out->mutable_data<T>(ctx.GetPlace());
 
-    auto p = EigenVector<T>::Flatten(*param);
-    auto g = EigenVector<T>::Flatten(*grad);
-    auto o = EigenVector<T>::Flatten(*param_out);
-    auto lr = EigenScalar<T>::From(*learning_rate);
+    auto p = framework::EigenVector<T>::Flatten(*param);
+    auto g = framework::EigenVector<T>::Flatten(*grad);
+    auto o = framework::EigenVector<T>::Flatten(*param_out);
+    auto lr = framework::EigenVector<T>::From(*learning_rate);
     auto place = ctx.GetEigenDevice<Place>();
 
-    o.device(place) = p - lr * g;
+    Eigen::DSizes<int, 2> grad_dsize(grad->dims()[0], grad->dims()[1]);
+    o.device(place) = p - lr.broadcast(grad_dsize) * g;
   }
 };
 
-- 
GitLab
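
A note on the new update line: the kernel now computes the SGD step
o = p - lr * g with LearningRate read as a size-1 EigenVector and broadcast
across the gradient before the element-wise multiply. Below is a minimal
standalone sketch of that broadcast pattern using Eigen's unsupported Tensor
module directly; it is not Paddle code, and all names and sizes are
illustrative. The sketch broadcasts to the flattened element count, whereas
the patch derives the broadcast shape from the gradient's first two
dimensions.

// Minimal sketch (assumes Eigen's unsupported Tensor module is available);
// mirrors the kernel's o = p - lr.broadcast(...) * g on flattened,
// rank-1 tensors. Names and sizes are illustrative only.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // Flattened parameter and gradient, as EigenVector<T>::Flatten produces.
  Eigen::Tensor<float, 1> p(6);
  Eigen::Tensor<float, 1> g(6);
  p.setConstant(1.0f);
  g.setConstant(0.5f);

  // LearningRate is a size-1 tensor; it must be broadcast to the
  // gradient's length before the element-wise multiply.
  Eigen::Tensor<float, 1> lr(1);
  lr.setConstant(0.1f);

  // Broadcast factor: repeat lr once per gradient element.
  Eigen::DSizes<Eigen::DenseIndex, 1> grad_dsize(g.size());
  Eigen::Tensor<float, 1> o = p - lr.broadcast(grad_dsize) * g;

  std::cout << o << "\n";  // every entry: 1.0 - 0.1 * 0.5 = 0.95
  return 0;
}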