未验证 提交 593bbfe3 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #11765 from jacquesqiao/fix-adam-op-for-selectedrows

fix adam op for selected rows
...@@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel { ...@@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel {
"Beta2 power accumulator should have 1 dimension"); "Beta2 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param"); auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ( if (ctx->GetInputsVarType("Grad")[0] ==
param_dims, ctx->GetInputDim("Grad"), framework::proto::VarType::LOD_TENSOR) {
"Param and Grad input of AdamOp should have same dimension"); PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"), param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment1 input of AdamOp should have same dimension"); "Param and Moment1 input of AdamOp should have same dimension");
......
...@@ -282,6 +282,10 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -282,6 +282,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad = auto& grad =
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad"); Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
if (grad.rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
// merge duplicated rows if any. // merge duplicated rows if any.
scatter::MergeAdd<DeviceContext, T> merge_func; scatter::MergeAdd<DeviceContext, T> merge_func;
auto grad_merge = auto grad_merge =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册