From 8ebc31d9358c919fdd6f50d502f4ee071a91d38e Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 4 Oct 2017 17:13:02 -0700 Subject: [PATCH] optimize the dsize --- paddle/operators/sgd_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index b501d244d7f..26f4012f258 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -33,10 +33,10 @@ class SGDOpKernel : public framework::OpKernel { auto p = framework::EigenVector::Flatten(*param); auto g = framework::EigenVector::Flatten(*grad); auto o = framework::EigenVector::Flatten(*param_out); - auto lr = framework::EigenVector::From(*learning_rate); + auto lr = framework::EigenVector::Flatten(*learning_rate); auto place = ctx.GetEigenDevice(); - Eigen::DSizes grad_dsize(grad->dims()[0], grad->dims()[1]); + Eigen::DSizes grad_dsize(grad->numel()); o.device(place) = p - lr.broadcast(grad_dsize) * g; } }; -- GitLab