diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 903ca7b18478899ae97c3d6f01dfab2aa0a8a317..9a57e6b68f5e82ac3d4d9656cc6f7478ed75075a 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -68,8 +68,8 @@ class MulOpGrad : public framework::OperatorWithKernel { "Input(Out@GRAD) should not be null"); auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); - auto dim0 = ctx.Input(framework::GradVarName("X"))->dims(); - auto dim1 = ctx.Input(framework::GradVarName("Y"))->dims(); + auto dim0 = ctx.Output(framework::GradVarName("X"))->dims(); + auto dim1 = ctx.Output(framework::GradVarName("Y"))->dims(); auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); PADDLE_ENFORCE(dim0[0] * dim1[0] == out_dims[0], "Out@GRAD[0] must equal to X[0] * Y[0]");