From ee7b3ed09e699da191fe238ab409f32318637380 Mon Sep 17 00:00:00 2001 From: qiaolongfei <qiaolongfei@baidu.com> Date: Wed, 4 Oct 2017 15:33:44 -0700 Subject: [PATCH] use EigenScalar to get learning_rate from GPU device --- paddle/operators/sgd_op.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index d72d333a9a..954fd48272 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -23,6 +23,9 @@ using Tensor = framework::Tensor; template <typename T, int MajorType = Eigen::RowMajor, typename IndexType = Eigen::DenseIndex> using EigenVector = framework::EigenVector<T, MajorType, IndexType>; +template <typename T, int MajorType = Eigen::RowMajor, + typename IndexType = Eigen::DenseIndex> +using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>; template <typename Place, typename T> class SGDOpKernel : public framework::OpKernel<T> { @@ -31,13 +34,14 @@ class SGDOpKernel : public framework::OpKernel<T> { auto param = ctx.Input<Tensor>("Param"); auto grad = ctx.Input<Tensor>("Grad"); auto param_out = ctx.Output<Tensor>("ParamOut"); - float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0]; + auto learning_rate = ctx.Input<Tensor>("LearningRate"); param_out->mutable_data<T>(ctx.GetPlace()); auto p = EigenVector<T>::Flatten(*param); auto g = EigenVector<T>::Flatten(*grad); auto o = EigenVector<T>::Flatten(*param_out); + auto lr = EigenScalar<T>::From(*learning_rate); auto place = ctx.GetEigenDevice<Place>(); o.device(place) = p - lr * g; -- GitLab