From ee7b3ed09e699da191fe238ab409f32318637380 Mon Sep 17 00:00:00 2001
From: qiaolongfei <qiaolongfei@baidu.com>
Date: Wed, 4 Oct 2017 15:33:44 -0700
Subject: [PATCH] use EigenScalar to get learning_rate from GPU device

---
 paddle/operators/sgd_op.h | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index d72d333a9a..954fd48272 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -23,6 +23,9 @@ using Tensor = framework::Tensor;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
 class SGDOpKernel : public framework::OpKernel<T> {
@@ -31,13 +34,14 @@ class SGDOpKernel : public framework::OpKernel<T> {
     auto param = ctx.Input<Tensor>("Param");
     auto grad = ctx.Input<Tensor>("Grad");
     auto param_out = ctx.Output<Tensor>("ParamOut");
-    float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
+    auto learning_rate = ctx.Input<Tensor>("LearningRate");
 
     param_out->mutable_data<T>(ctx.GetPlace());
 
     auto p = EigenVector<T>::Flatten(*param);
     auto g = EigenVector<T>::Flatten(*grad);
     auto o = EigenVector<T>::Flatten(*param_out);
+    auto lr = EigenScalar<T>::From(*learning_rate);
     auto place = ctx.GetEigenDevice<Place>();
 
     o.device(place) = p - lr * g;
-- 
GitLab