From dd09d8b2bfe68ff3a0df6a95fc6fd751434fa9cf Mon Sep 17 00:00:00 2001
From: peterzhang2029
Date: Fri, 5 Jan 2018 15:08:48 +0800
Subject: [PATCH] fix gpu error

---
 paddle/operators/adagrad_op.h | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/paddle/operators/adagrad_op.h b/paddle/operators/adagrad_op.h
index be234cf4c4..66f5b0f449 100644
--- a/paddle/operators/adagrad_op.h
+++ b/paddle/operators/adagrad_op.h
@@ -47,9 +47,7 @@ class AdagradOpKernel : public framework::OpKernel<T> {
           *ctx.Input<framework::Tensor>("Grad"));
       auto moment = framework::EigenVector<T>::Flatten(
           *ctx.Input<framework::Tensor>("Moment"));
       auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-      auto* lr = learning_rate->data<T>();
-
       auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
       auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
       auto* place = ctx.template device_context<DeviceContext>().eigen_device();
@@ -57,8 +55,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
 
       moment_out.device(*place) = moment + grad * grad;
       Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
-      param_out.device(*place) =
-          param - lr[0] * grad / (moment_out.sqrt() + epsilon);
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        auto* lr = learning_rate->data<T>();
+        param_out.device(*place) =
+            param - lr[0] * grad / (moment_out.sqrt() + epsilon);
+      } else {
+        auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
+        param_out.device(*place) =
+            param -
+            lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
+      }
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto* param_tensor = ctx.Input<framework::Tensor>("Param");
       PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
-- 
GitLab
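
Why the patch fixes the GPU error: the deleted code obtained the learning
rate as a raw pointer via learning_rate->data<T>() and read lr[0] on the
host while constructing the Eigen expression. When the tensors are placed
on the GPU, that pointer refers to device memory, so the host-side read of
lr[0] is invalid. The fix keeps the cheap pointer read for CPU places and,
for GPU places, broadcasts the one-element learning-rate tensor inside the
Eigen expression, so the scalar is read by whichever device evaluates the
expression.

The following standalone host-side sketch illustrates the broadcast pattern
with Eigen's unsupported Tensor module. It is an illustration only, not
Paddle's kernel: the tensor sizes and values are invented, and it runs
entirely on the CPU.

    #include <unsupported/Eigen/CXX11/Tensor>

    #include <iostream>

    int main() {
      const int n = 4;

      // Made-up stand-ins for the kernel's flattened tensors.
      Eigen::Tensor<float, 1> param(n);
      Eigen::Tensor<float, 1> grad(n);
      Eigen::Tensor<float, 1> moment_out(n);
      param.setConstant(1.0f);
      grad.setConstant(0.5f);
      moment_out.setConstant(0.25f);

      // The learning rate is a one-element tensor, as in the patch.
      Eigen::Tensor<float, 1> lr(1);
      lr.setConstant(0.1f);

      const float epsilon = 1.0e-6f;
      Eigen::DSizes<int, 1> m_dsize(n);

      // GPU-branch pattern from the patch: broadcasting lr to n elements
      // keeps the scalar read inside the expression, so no raw pointer
      // into the learning-rate buffer is ever dereferenced by the caller.
      Eigen::Tensor<float, 1> param_out =
          param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);

      std::cout << param_out << std::endl;  // four updated parameter values
      return 0;
    }

On a GPU place, Paddle evaluates the same kind of expression through an
Eigen GPU device (the *place obtained from the device context), which is
why keeping the scalar inside the expression avoids the host-side
dereference that the deleted lr[0] code performed.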