diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index 98849ca60d4d74dcd78538495bd19b358f81ecd1..6aee9383da3c1d6bd6bd06afc4e9ee037179ef3d 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -529,6 +529,44 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx, funcs::InverseMultiplyFunctor>( dev_ctx, dout, ddy_safe, dx, axis); + } else if ((!dx) && dy) { + DenseTensor tmp_a(ddout->dtype()); + tmp_a.Resize(ddout->dims()); + + dev_ctx.template Alloc(&tmp_a); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, ddy_safe, &tmp_a, axis); + + auto ddout_t1 = phi::EigenVector::Flatten(tmp_a); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddx_safe, y, ddout, axis); + + auto ddout_t2 = phi::EigenVector::Flatten(*ddout); + ddout_t2.device(place) = ddout_t2 + ddout_t1; + + // NOTE: in the following ElemwiseGradCompute call the first output + // tensor is nullptr, so the branch that computes the first output is + // not activated and the MulGradDX functor passed below is not + // called; that inactive branch has little effect on running + // speed. + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, + ddx_safe, + ddy_safe, + dout, + dout, + axis, + nullptr, + dy, + MulGradDX(), + MulGradDY()); } else { DenseTensor tmp_a(ddout->dtype()); tmp_a.Resize(ddout->dims()); @@ -554,19 +592,18 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx, } } } else { - if (dx && dy) { - phi::funcs::ElemwiseGradCompute, MulGradDY>( - dev_ctx, - ddx_safe, - ddy_safe, - dout, - dout, - axis, - dx, - dy, - MulGradDX(), - MulGradDY()); - } + VLOG(3) << "Calculating here with dx: " << dx << ", dy: " << dy; + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, + ddx_safe, + ddy_safe, + dout, + dout, + axis, + dx, + dy, + MulGradDX(), + MulGradDY()); } }