提交 df7a266a 编写于 作者: Q qiaolongfei

fix adam op for selected rows

上级 b756063c
......@@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel {
"Beta2 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
}
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment1 input of AdamOp should have same dimension");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册