diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 5645df6677ee2a5b52ec00f6f215d4af61b9bab8..329ab95327b76ad997ef27257ca6a5f2a12f1379 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -66,10 +66,10 @@ class MulOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Output("X")->dims(); + auto y_dims = ctx.Output("Y")->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); - auto x_dims = ctx.Output(framework::GradVarName("X"))->dims(); - auto y_dims = ctx.Output(framework::GradVarName("Y"))->dims(); auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); PADDLE_ENFORCE(x_dims[0] == out_dims[0], "Out@GRAD M X N must equal to X dims 0, M "); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 2afed81842924796b6f54589fb61627917b4f2f4..9bbd0275261e68a86eb033398b50e73143c99e3b 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -53,7 +53,9 @@ class MulGradKernel : public framework::OpKernel { auto* dY = ctx.Output(framework::GradVarName("Y")); auto* device_context = const_cast(ctx.device_context_); + // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N math::matmul(*dOut, false, *Y, true, 1, dX, 0, device_context); + // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K math::matmul(*X, true, *dOut, false, 1, dY, 0, device_context); } };