diff --git a/paddle/operators/adagrad_op.h b/paddle/operators/adagrad_op.h
index 0d77dbcbacd4efb6c1900e57b5c4ea9e9b136771..667c1939a2a9cdb6d0a669dc3388c8ccd8107c57 100644
--- a/paddle/operators/adagrad_op.h
+++ b/paddle/operators/adagrad_op.h
@@ -47,8 +47,8 @@ class AdagradOpKernel : public framework::OpKernel<T> {
           *ctx.Input<framework::Tensor>("Grad"));
       auto moment = framework::EigenVector<T>::Flatten(
           *ctx.Input<framework::Tensor>("Moment"));
-      auto lr = framework::EigenVector<T>::Flatten(
-          *ctx.Input<framework::Tensor>("LearningRate"));
+      auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+      auto* lr = learning_rate->data<T>();
 
       auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
       auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
@@ -57,7 +57,7 @@ class AdagradOpKernel : public framework::OpKernel<T> {
       moment_out.device(*place) = moment + grad * grad;
       Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
       param_out.device(*place) =
-          param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
+          param - lr[0] * grad / (moment_out.sqrt() + epsilon);
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto* param_tensor = ctx.Input<framework::Tensor>("Param");
       PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
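
Note: the patch replaces an Eigen broadcast of the one-element LearningRate tensor with a direct scalar read (lr[0]), which is why the m_dsize broadcast disappears from the dense update. For reference, the dense branch computes the standard Adagrad step; a minimal sketch of the same arithmetic over plain arrays (all names here are hypothetical, not Paddle APIs) might look like:

    #include <cmath>
    #include <cstddef>

    // Dense Adagrad update: accumulate squared gradients in `moment`,
    // then scale each parameter step by 1 / (sqrt(moment) + epsilon).
    void AdagradDenseUpdate(float* param, float* moment, const float* grad,
                            float lr, float epsilon, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) {
        moment[i] += grad[i] * grad[i];
        param[i] -= lr * grad[i] / (std::sqrt(moment[i]) + epsilon);
      }
    }

Since lr is a single scalar applied uniformly, reading it once is equivalent to broadcasting a 1-element tensor across the gradient, just without constructing the broadcast expression.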