diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index d72d333a9a03df17b34356090ca420cbe75fa58f..954fd4827252d7ea6cfc0a90a010acd2378cc328 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -23,6 +23,9 @@ using Tensor = framework::Tensor; template using EigenVector = framework::EigenVector; +template +using EigenScalar = framework::EigenScalar; template class SGDOpKernel : public framework::OpKernel { @@ -31,13 +34,14 @@ class SGDOpKernel : public framework::OpKernel { auto param = ctx.Input("Param"); auto grad = ctx.Input("Grad"); auto param_out = ctx.Output("ParamOut"); - float lr = ctx.Input("LearningRate")->data()[0]; + auto learning_rate = ctx.Input("LearningRate"); param_out->mutable_data(ctx.GetPlace()); auto p = EigenVector::Flatten(*param); auto g = EigenVector::Flatten(*grad); auto o = EigenVector::Flatten(*param_out); + auto lr = EigenScalar::From(*learning_rate); auto place = ctx.GetEigenDevice(); o.device(place) = p - lr * g;