提交 82b820e9 编写于 作者: Q qiaolongfei

fix rowwise_add_grad_op

上级 cef27dab
......@@ -63,7 +63,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) should not be null");
auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims();
PADDLE_ENFORCE_EQ(1, framework::product(dims1), "b dims should be 1")
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册