From df7a266ae238afbbde27a38b279d5145f188f12b Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 27 Jun 2018 15:56:45 +0800 Subject: [PATCH] fix adam op for selected rows --- paddle/fluid/operators/adam_op.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/adam_op.cc b/paddle/fluid/operators/adam_op.cc index 6ee73c3000f..5d670fe3b9d 100644 --- a/paddle/fluid/operators/adam_op.cc +++ b/paddle/fluid/operators/adam_op.cc @@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel { "Beta2 power accumulator should have 1 dimension"); auto param_dims = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("Grad"), - "Param and Grad input of AdamOp should have same dimension"); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Grad"), + "Param and Grad input of AdamOp should have same dimension"); + } PADDLE_ENFORCE_EQ( param_dims, ctx->GetInputDim("Moment1"), "Param and Moment1 input of AdamOp should have same dimension"); -- GitLab