提交 4ab36a71 编写于 作者: D dongzhihong

"fix error"

上级 4c9699c5
......@@ -68,8 +68,8 @@ class MulOpGrad : public framework::OperatorWithKernel {
"Input(Out@GRAD) should not be null");
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto dim0 = ctx.Input<Tensor>(framework::GradVarName("X"))->dims();
auto dim1 = ctx.Input<Tensor>(framework::GradVarName("Y"))->dims();
auto dim0 = ctx.Output<Tensor>(framework::GradVarName("X"))->dims();
auto dim1 = ctx.Output<Tensor>(framework::GradVarName("Y"))->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
PADDLE_ENFORCE(dim0[0] * dim1[0] == out_dims[0],
"Out@GRAD[0] must equal to X[0] * Y[0]");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册