提交 82b820e9 编写于 作者: Q qiaolongfei

fix rowwise_add_grad_op

上级 cef27dab
...@@ -63,7 +63,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { ...@@ -63,7 +63,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto dims0 = ctx.Input<Tensor>("X")->dims(); auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims(); auto dims1 = ctx.Input<Tensor>("b")->dims();
PADDLE_ENFORCE_EQ(1, framework::product(dims1), "b dims should be 1") PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0); ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1); ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册