diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 329ab95327b76ad997ef27257ca6a5f2a12f1379..460e458ca4f7f40746f0dbf7e258a165faa88e1a 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -66,11 +66,11 @@ class MulOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); - auto x_dims = ctx.Output("X")->dims(); - auto y_dims = ctx.Output("Y")->dims(); + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); - auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); PADDLE_ENFORCE(x_dims[0] == out_dims[0], "Out@GRAD M X N must equal to X dims 0, M "); PADDLE_ENFORCE(y_dims[1] == out_dims[1],