From 452c75b8034e485a2626e22cac39c95c07b883b4 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 9 Mar 2022 21:37:32 +0800 Subject: [PATCH] move elementwise mul grad (#40252) --- .../new_executor/standalone_executor_test.cc | 2 +- .../elementwise/elementwise_functor.h | 41 --- .../elementwise/elementwise_mul_op.cc | 49 ---- .../elementwise/elementwise_mul_op.cu | 68 ----- .../elementwise/elementwise_mul_op.h | 238 --------------- .../kernels/cpu/elementwise_grad_kernel.cc | 61 +++- paddle/phi/kernels/elementwise_grad_kernel.h | 39 +++ .../phi/kernels/funcs/elementwise_functor.h | 44 +++ paddle/phi/kernels/gpu/elementwise_grad.h | 37 +++ .../kernels/gpu/elementwise_grad_kernel.cu | 54 ++++ .../impl/elementwise_grad_kernel_impl.h | 273 ++++++++++++++++++ paddle/phi/ops/compat/elementwise_sig.cc | 34 +++ 12 files changed, 539 insertions(+), 401 deletions(-) diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index 62d87b6917e..a69cc0d6b86 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -46,7 +46,7 @@ USE_OP(matmul_grad); USE_OP(square); USE_OP(transpose2_grad); USE_OP(concat_grad); -USE_OP(elementwise_mul_grad); +USE_OP_ITSELF(elementwise_mul_grad); USE_OP(sigmoid_grad); USE_OP(tanh_grad); USE_OP(sum); diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index 8e0bf78e9b7..14baeaa74d2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -196,47 +196,6 @@ struct MinGradXYFunctor { } }; -template -struct MulGradFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; } -}; -template -struct MulGradFunctor> { - inline HOSTDEVICE Complex operator()(const Complex a, - const Complex b) const { - Complex b_conj(b.real, -b.imag); - return a * b_conj; - } -}; - -template -struct MulGradXYFunctor { - inline HOSTDEVICE phi::Array operator()(const InT a, const InT b, - const InT c) { - phi::Array outs; - // dx = dout * y - outs[0] = a * b; - // dy = dout * x - outs[1] = a * c; - return outs; - } -}; - -template -struct MulGradXYFunctor, Complex> { - inline HOSTDEVICE phi::Array, 2> operator()( - const Complex a, const Complex b, const Complex c) { - phi::Array, 2> outs; - // dx = dout * y - Complex b_conj(b.real, -b.imag); - outs[0] = a * b_conj; - // dy = dout * x - Complex c_conj(c.real, -c.imag); - outs[1] = a * c_conj; - return outs; - } -}; - // Ternary compare template struct MaxGradXFunctor { diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index e172279145e..830e09eeae4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -173,55 +173,6 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex>, ops::ElementwiseMulKernel>); -REGISTER_OP_CPU_KERNEL( - elementwise_mul_grad, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel>, - ops::ElementwiseMulGradKernel>); -REGISTER_OP_CPU_KERNEL( - elementwise_mul_grad_grad, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, 
- ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel>, - ops::ElementwiseMulDoubleGradKernel>); -REGISTER_OP_CPU_KERNEL( - elementwise_mul_triple_grad, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel>, - ops::ElementwiseMulTripleGradKernel>); REGISTER_OP_VERSION(elementwise_mul) .AddCheckpoint( diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 45c87a27a18..f7b9fd1e265 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -63,33 +63,6 @@ class ElementwiseMulKernel } }; -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseMulGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { - int axis = ctx.Attr("axis"); - const auto& dev_ctx = - ctx.template device_context(); - const auto place = ctx.GetPlace(); - - if (dx != nullptr && dy != nullptr) { - std::vector ins = {dout, y, x}; - GetGradXAndYOut( - dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor()); - } else if (dx != nullptr && dy == nullptr) { - std::vector ins = {dout, y}; - GetGradXOrYOut(dev_ctx, place, axis, ins, dout, - dx, MulGradFunctor()); - } else if (dx == nullptr && dy != nullptr) { - std::vector ins = {dout, x}; - GetGradXOrYOut(dev_ctx, place, axis, ins, dout, - dy, MulGradFunctor()); - } -} - } // namespace operators } // namespace paddle @@ -103,44 +76,3 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMulKernel, ops::ElementwiseMulKernel>, ops::ElementwiseMulKernel>); -REGISTER_OP_CUDA_KERNEL( - elementwise_mul_grad, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel>, - ops::ElementwiseMulGradKernel>); -REGISTER_OP_CUDA_KERNEL( - elementwise_mul_grad_grad, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel>, - ops::ElementwiseMulDoubleGradKernel>); -REGISTER_OP_CUDA_KERNEL( - elementwise_mul_triple_grad, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel>, - ops::ElementwiseMulTripleGradKernel>); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index c81266d5844..58a3123c7e3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -137,244 +137,6 
@@ class ElementwiseMulKernel : public framework::OpKernel { } } }; -template -struct MulGradDX { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } -}; - -template -struct MulGradDX> { - HOSTDEVICE paddle::platform::complex operator()( - paddle::platform::complex x, paddle::platform::complex y, - paddle::platform::complex out, - paddle::platform::complex dout) const { - paddle::platform::complex y_conj(y.real, -y.imag); - return dout * y_conj; - } -}; - -template -struct MulGradDY { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } -}; - -template -struct MulGradDY> { - HOSTDEVICE paddle::platform::complex operator()( - paddle::platform::complex x, paddle::platform::complex y, - paddle::platform::complex out, - paddle::platform::complex dout) const { - paddle::platform::complex x_conj(x.real, -x.imag); - return dout * x_conj; - } -}; -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseMulGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { - int axis = ctx.Attr("axis"); - ElemwiseGradCompute, MulGradDY>( - ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX(), MulGradDY()); -} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseMulGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy); -#endif - -template -class ElementwiseMulGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* out = dout; // out is not necessary - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - ElementwiseMulGrad(ctx, x, y, out, dout, dx, dy); - } -}; - -template -class ElementwiseMulDoubleGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input("DOut"); - auto* ddx = ctx.Input("DDX"); - auto* ddy = ctx.Input("DDY"); - - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - auto* ddout = ctx.Output("DDOut"); - - if (ddout) ddout->mutable_data(ctx.GetPlace()); - - Tensor ddx_safe, ddy_safe; - GetDoubleGradSafeTensor(ctx, x, ddx, &ddx_safe); - GetDoubleGradSafeTensor(ctx, y, ddy, &ddy_safe); - - // dx = dout * ddy - // dy = dout * ddx - // ddout = ddx * y + x * ddy - // change computation sequence to save memory, so ddout can inplace ddx and - // dx can be used as 'tmp' tensor - // (1) dx = x * ddy - // (2) dy = dout * ddx - // (3) ddout = ddx * y - // (4) ddout = ddout + dx - // (5) dx = dout * ddy - if (ddout) { - int axis = ctx.Attr("axis"); - auto& place = - *ctx.template device_context().eigen_device(); - // size(ddout) > size(ddx), ddout can't use memory of ddx using inplace - if (ddout->numel() > ddx->numel()) { - ElemwiseGradCompute, MulGradDY>( - ctx, ddx_safe, ddy_safe, *dout, 
*dout, axis, dx, dy, MulGradDX(), - MulGradDY()); - - Tensor ddout_tmp; - ddout_tmp.mutable_data(ddout->dims(), ctx.GetPlace()); - - default_elementwise_mul(ctx, y, &ddx_safe, ddout); - default_elementwise_mul(ctx, &ddy_safe, x, - &ddout_tmp); - - auto ddout_t = framework::EigenVector::Flatten(*ddout); - auto ddout_tmp_t = framework::EigenVector::Flatten(ddout_tmp); - ddout_t.device(place) = ddout_t + ddout_tmp_t; - } else { - // use dx to save memory, other than alloc tmp tensor - Tensor* ddout_tmp = dx; - - default_elementwise_mul(ctx, x, &ddy_safe, ddout_tmp); - // NOTE: in the following ElemwiseGradCompute, for the - // first output tensor is nullptr, the branch to calculate first - // output tensor will not be activated, DivGradDx function will not - // be called and can be ignored, the first branch has little effect - // on running speed. - ElemwiseGradCompute, MulGradDY>( - ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy, - MulGradDX(), MulGradDY()); - default_elementwise_mul(ctx, &ddx_safe, y, ddout); - - auto ddout_t = framework::EigenVector::Flatten(*ddout); - auto ddout_tmp_t = framework::EigenVector::Flatten(*ddout_tmp); - ddout_t.device(place) = ddout_t + ddout_tmp_t; - default_elementwise_mul(ctx, dout, &ddy_safe, dx); - } - } - } -}; - -template -class ElementwiseMulTripleGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - // get input - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input("DOut"); - auto* ddx = ctx.Input("DDX"); - auto* ddy = ctx.Input("DDY"); - - auto* d_dx = ctx.Input("D_DX"); - auto* d_dy = ctx.Input("D_DY"); - auto* d_ddout = ctx.Input("D_DDOut"); - - // get output - auto* out_d_x = ctx.Output("D_X"); - auto* out_d_y = ctx.Output("D_Y"); - auto* out_d_dout = ctx.Output("D_DOut"); - - auto* out_d_ddx = ctx.Output("D_DDX"); - auto* out_d_ddy = ctx.Output("D_DDY"); - - if (out_d_x) out_d_x->mutable_data(x->dims(), ctx.GetPlace()); - if (out_d_y) out_d_y->mutable_data(y->dims(), ctx.GetPlace()); - if (out_d_dout) out_d_dout->mutable_data(dout->dims(), ctx.GetPlace()); - if (out_d_ddx) out_d_ddx->mutable_data(x->dims(), ctx.GetPlace()); - if (out_d_ddy) out_d_ddy->mutable_data(y->dims(), ctx.GetPlace()); - - auto& place = *ctx.template device_context().eigen_device(); - - Tensor ddx_safe, ddy_safe; - GetDoubleGradSafeTensor(ctx, x, ddx, &ddx_safe); - GetDoubleGradSafeTensor(ctx, y, ddy, &ddy_safe); - - if (d_ddout) { - if (out_d_x) { - // out_d_x = ddy * d_ddout - default_elementwise_mul(ctx, &ddy_safe, d_ddout, - out_d_x); - } - if (out_d_y) { - // out_d_y = ddx * d_ddout - default_elementwise_mul(ctx, &ddx_safe, d_ddout, - out_d_y); - } - } - - if (out_d_dout) { - // get out_d_dout - // out_d_dout = ddy * d_dx + d_dy * ddx - Tensor out_d_dout_tmp; - out_d_dout_tmp.mutable_data(dout->dims(), ctx.GetPlace()); - default_elementwise_mul(ctx, d_dy, &ddx_safe, - out_d_dout); - default_elementwise_mul(ctx, &ddy_safe, d_dx, - &out_d_dout_tmp); - auto out_d_dout_t = framework::EigenVector::Flatten(*out_d_dout); - auto out_d_dout_tmp_t = - framework::EigenVector::Flatten(out_d_dout_tmp); - out_d_dout_t.device(place) = out_d_dout_t + out_d_dout_tmp_t; - } - - if (out_d_ddx) { - // get out_d_ddx - // out_d_ddx = dout * d_dy + y * d_ddout - Tensor out_d_ddx_tmp; - out_d_ddx_tmp.mutable_data(ddx->dims(), ctx.GetPlace()); - default_elementwise_mul(ctx, dout, d_dy, out_d_ddx); - default_elementwise_mul(ctx, y, d_ddout, - 
&out_d_ddx_tmp); - auto out_d_ddx_t = framework::EigenVector::Flatten(*out_d_ddx); - auto out_d_ddx_tmp_t = framework::EigenVector::Flatten(out_d_ddx_tmp); - out_d_ddx_t.device(place) = out_d_ddx_t + out_d_ddx_tmp_t; - } - - if (out_d_ddy) { - // get out_d_ddy - // out_d_ddy = dout * d_dx + x * d_ddout - Tensor out_d_ddy_tmp; - out_d_ddy_tmp.mutable_data(ddy->dims(), ctx.GetPlace()); - default_elementwise_mul(ctx, dout, d_dx, out_d_ddy); - default_elementwise_mul(ctx, x, d_ddout, - &out_d_ddy_tmp); - auto out_d_ddy_t = framework::EigenVector::Flatten(*out_d_ddy); - auto out_d_ddy_tmp_t = framework::EigenVector::Flatten(out_d_ddy_tmp); - out_d_ddy_t.device(place) = out_d_ddy_t + out_d_ddy_tmp_t; - } - } -}; } // namespace operators } // namespace paddle diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc index c9177f1c46e..cd513e809fd 100644 --- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc @@ -121,6 +121,20 @@ void DivideGradKernel(const Context& dev_ctx, dev_ctx, x, y, out, dout, axis, dx, dy, DivGradDX(), DivGradDY()); } +template +void MultiplyGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + funcs::ElementwiseGradPreProcess(dout, dx); + auto* out = &dout; // out is not necessary + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, x, y, *out, dout, axis, dx, dy, MulGradDX(), MulGradDY()); +} + } // namespace phi PD_REGISTER_KERNEL(add_grad, @@ -193,8 +207,8 @@ PD_REGISTER_KERNEL(divide_grad, double, int, int64_t, - paddle::platform::complex, - paddle::platform::complex) {} + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(divide_double_grad, CPU, @@ -204,5 +218,44 @@ PD_REGISTER_KERNEL(divide_double_grad, double, int, int64_t, - paddle::platform::complex, - paddle::platform::complex) {} + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_grad, + CPU, + ALL_LAYOUT, + phi::MultiplyGradKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_double_grad, + CPU, + ALL_LAYOUT, + phi::MultiplyDoubleGradKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_triple_grad, + CPU, + ALL_LAYOUT, + phi::MultiplyTripleGradKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/elementwise_grad_kernel.h b/paddle/phi/kernels/elementwise_grad_kernel.h index bcd5a98f07e..58ae11a9c42 100644 --- a/paddle/phi/kernels/elementwise_grad_kernel.h +++ b/paddle/phi/kernels/elementwise_grad_kernel.h @@ -85,4 +85,43 @@ void DivideDoubleGradKernel(const Context& dev_ctx, DenseTensor* dy, DenseTensor* dout, DenseTensor* ddout); + +template +void MultiplyGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy); + +template +void MultiplyDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + int axis, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout); + +template +void MultiplyTripleGradKernel(const Context& 
dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + paddle::optional d_ddout, + int axis, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_dout, + DenseTensor* d_ddx, + DenseTensor* d_ddy); + } // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index 5615a450b5c..b01d50015f0 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -160,5 +160,49 @@ struct DivGradYFunctor> { } }; +template +struct MultiplyGradFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; } +}; +template +struct MultiplyGradFunctor> { + inline HOSTDEVICE ComplexType operator()(const ComplexType a, + const ComplexType b) const { + ComplexType b_conj(b.real, -b.imag); + return a * b_conj; + } +}; + +template +struct MultiplyGradXYFunctor { + inline HOSTDEVICE phi::Array operator()(const InT a, + const InT b, + const InT c) { + phi::Array outs; + // dx = dout * y + outs[0] = a * b; + // dy = dout * x + outs[1] = a * c; + return outs; + } +}; + +template +struct MultiplyGradXYFunctor, ComplexType> { + inline HOSTDEVICE phi::Array, 2> operator()( + const ComplexType a, + const ComplexType b, + const ComplexType c) { + phi::Array, 2> outs; + // dx = dout * y + ComplexType b_conj(b.real, -b.imag); + outs[0] = a * b_conj; + // dy = dout * x + ComplexType c_conj(c.real, -c.imag); + outs[1] = a * c_conj; + return outs; + } +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad.h b/paddle/phi/kernels/gpu/elementwise_grad.h index 98df65c92f3..e5432b5f918 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad.h +++ b/paddle/phi/kernels/gpu/elementwise_grad.h @@ -360,4 +360,41 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx, } } +/* +****************************** + Mul Grad +****************************** +*/ + +template +void ElementwiseMulGrad(const GPUContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + int axis) { + const auto place = dev_ctx.GetPlace(); + + if (dx != nullptr && dy != nullptr) { + std::vector ins = {&dout, &y, &x}; + GetGradXAndYOut( + dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::MultiplyGradXYFunctor()); + } else if (dx != nullptr && dy == nullptr) { + std::vector ins = {&dout, &y}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dx, funcs::MultiplyGradFunctor()); + } else if (dx == nullptr && dy != nullptr) { + std::vector ins = {&dout, &x}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dy, funcs::MultiplyGradFunctor()); + } +} } // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 45c8b9a2163..81f7fac1088 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -136,6 +136,18 @@ void DivideGradKernel(const Context& dev_ctx, } } +template +void MultiplyGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + funcs::ElementwiseGradPreProcess(dout, dx); + ElementwiseMulGrad(dev_ctx, x, y, dout, dx, dy, axis); +} + } // namespace phi PD_REGISTER_KERNEL(add_grad, @@ -228,3 +240,45 @@ 
PD_REGISTER_KERNEL(divide_double_grad, int64_t, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_grad, + GPU, + ALL_LAYOUT, + phi::MultiplyGradKernel, + float, + phi::dtype::float16, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_double_grad, + GPU, + ALL_LAYOUT, + phi::MultiplyDoubleGradKernel, + float, + phi::dtype::float16, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_triple_grad, + GPU, + ALL_LAYOUT, + phi::MultiplyTripleGradKernel, + float, + phi::dtype::float16, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index e8831f90213..65427e87506 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -259,4 +259,277 @@ void DivideDoubleGradKernel(const Context& dev_ctx, } } +template +struct MulGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } +}; + +template +struct MulGradDX> { + HOSTDEVICE phi::dtype::complex operator()( + phi::dtype::complex x, + phi::dtype::complex y, + phi::dtype::complex out, + phi::dtype::complex dout) const { + phi::dtype::complex y_conj(y.real, -y.imag); + return dout * y_conj; + } +}; + +/* +****************************** + Multiply Grad +****************************** +*/ + +template +struct MulGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } +}; + +template +struct MulGradDY> { + HOSTDEVICE phi::dtype::complex operator()( + phi::dtype::complex x, + phi::dtype::complex y, + phi::dtype::complex out, + phi::dtype::complex dout) const { + phi::dtype::complex x_conj(x.real, -x.imag); + return dout * x_conj; + } +}; + +template +void MultiplyDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + int axis, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout) { + if (ddout) dev_ctx.template Alloc(ddout); + + DenseTensor ddx_safe, ddy_safe; + funcs::GetDoubleGradSafeTensor( + dev_ctx, x, ddx.get_ptr(), &ddx_safe); + funcs::GetDoubleGradSafeTensor( + dev_ctx, y, ddy.get_ptr(), &ddy_safe); + + // dx = dout * ddy + // dy = dout * ddx + // ddout = ddx * y + x * ddy + // change computation sequence to save memory, so ddout can inplace ddx and + // dx can be used as 'tmp' tensor + // (1) dx = x * ddy + // (2) dy = dout * ddx + // (3) ddout = ddx * y + // (4) ddout = ddout + dx + // (5) dx = dout * ddy + if (ddout) { + auto& place = *dev_ctx.eigen_device(); + // size(ddout) > size(ddx), ddout can't use memory of ddx using inplace + if (ddout->numel() > ddx.get_ptr()->numel()) { + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, + ddx_safe, + ddy_safe, + dout, + dout, + axis, + dx, + dy, + MulGradDX(), + MulGradDY()); + + DenseTensor ddout_tmp; + ddout_tmp.Resize(ddout->dims()); + dev_ctx.template Alloc(&ddout_tmp); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, y, ddx_safe, ddout, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddy_safe, x, &ddout_tmp, axis); + + auto ddout_t = phi::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = 
phi::EigenVector::Flatten(ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + } else { + // use dx to save memory, other than alloc tmp tensor + DenseTensor* ddout_tmp = dx; + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, ddy_safe, ddout_tmp, axis); + // NOTE: in the following ElemwiseGradCompute, for the + // first output tensor is nullptr, the branch to calculate first + // output tensor will not be activated, DivGradDx function will not + // be called and can be ignored, the first branch has little effect + // on running speed. + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, + ddx_safe, + ddy_safe, + dout, + dout, + axis, + nullptr, + dy, + MulGradDX(), + MulGradDY()); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddx_safe, y, ddout, axis); + + auto ddout_t = phi::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = phi::EigenVector::Flatten(*ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, ddy_safe, dx, axis); + } + } +} + +template +void MultiplyTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + paddle::optional d_ddout, + int axis, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_dout, + DenseTensor* d_ddx, + DenseTensor* d_ddy) { + if (d_x) { + d_x->Resize(x.dims()); + dev_ctx.template Alloc(d_x); + } + if (d_y) { + d_y->Resize(y.dims()); + dev_ctx.template Alloc(d_y); + } + if (d_dout) { + d_dout->Resize(dout.dims()); + dev_ctx.template Alloc(d_dout); + } + if (d_ddx) { + d_ddx->Resize(x.dims()); + dev_ctx.template Alloc(d_ddx); + } + if (d_ddy) { + d_ddy->Resize(y.dims()); + dev_ctx.template Alloc(d_ddy); + } + + auto& place = *dev_ctx.eigen_device(); + + DenseTensor ddx_safe, ddy_safe; + funcs::GetDoubleGradSafeTensor( + dev_ctx, x, ddx.get_ptr(), &ddx_safe); + funcs::GetDoubleGradSafeTensor( + dev_ctx, y, ddy.get_ptr(), &ddy_safe); + + if (d_ddout.get_ptr()) { + if (d_x) { + // d_x = ddy * d_ddout + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddy_safe, *(d_ddout.get_ptr()), d_x, axis); + } + if (d_y) { + // d_y = ddx * d_ddout + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddx_safe, *(d_ddout.get_ptr()), d_y, axis); + } + } + + if (d_dout) { + // get d_dout + // d_dout = ddy * d_dx + d_dy * ddx + DenseTensor d_dout_tmp; + d_dout_tmp.Resize(dout.dims()); + dev_ctx.template Alloc(&d_dout_tmp); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, d_dy, ddx_safe, d_dout, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddy_safe, d_dx, &d_dout_tmp, axis); + auto d_dout_t = phi::EigenVector::Flatten(*d_dout); + auto d_dout_tmp_t = phi::EigenVector::Flatten(d_dout_tmp); + d_dout_t.device(place) = d_dout_t + d_dout_tmp_t; + } + + if (d_ddx) { + // get d_ddx + // d_ddx = dout * d_dy + y * d_ddout + DenseTensor d_ddx_tmp; + d_ddx_tmp.Resize(ddx->dims()); + dev_ctx.template Alloc(&d_ddx_tmp); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, d_dy, d_ddx, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, y, *(d_ddout.get_ptr()), &d_ddx_tmp, axis); + auto d_ddx_t = phi::EigenVector::Flatten(*d_ddx); + auto 
d_ddx_tmp_t = phi::EigenVector::Flatten(d_ddx_tmp); + d_ddx_t.device(place) = d_ddx_t + d_ddx_tmp_t; + } + + if (d_ddy) { + // get d_ddy + // d_ddy = dout * d_dx + x * d_ddout + DenseTensor d_ddy_tmp; + d_ddy_tmp.Resize(ddy->dims()); + dev_ctx.template Alloc(&d_ddy_tmp); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, d_dx, d_ddy, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, *(d_ddout.get_ptr()), &d_ddy_tmp, axis); + auto d_ddy_t = phi::EigenVector::Flatten(*d_ddy); + auto d_ddy_tmp_t = phi::EigenVector::Flatten(d_ddy_tmp); + d_ddy_t.device(place) = d_ddy_t + d_ddy_tmp_t; + } +} + } // namespace phi diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index d4a25866907..fc890fa3a49 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -122,6 +122,31 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( {GradVarName("Y"), "DOut", "DDOut"}); } +KernelSignature ElementwiseMulGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("multiply_grad", + {"X", "Y", GradVarName("Out")}, + {"axis"}, + {GradVarName("X"), GradVarName("Y")}); +} + +KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("multiply_double_grad", + {"X", "Y", "DOut", "DDX", "DDY"}, + {"axis"}, + {GradVarName("X"), GradVarName("Y"), "DDOut"}); +} + +KernelSignature ElementwiseMulTripleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "multiply_triple_grad", + {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"}, + {"axis"}, + {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"}); +} + } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add); @@ -135,6 +160,9 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad); +PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad); +PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad); +PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad); PD_REGISTER_ARG_MAPPING_FN(elementwise_add, phi::ElementwiseAddOpArgumentMapping); @@ -158,3 +186,9 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad, phi::ElementwiseDivGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad, phi::ElementwiseDivDoubleGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad, + phi::ElementwiseMulGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad, + phi::ElementwiseMulDoubleGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad, + phi::ElementwiseMulTripleGradOpArgumentMapping); -- GitLab
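
Note on what the relocated kernels compute (illustrative, not part of the patch): for elementwise multiply, out = x * y, so the first-order gradients are dx = dout * y and dy = dout * x, with the other operand conjugated for complex types — this is exactly what MulGradDX / MulGradDY and funcs::MultiplyGradFunctor encode. The sketch below restates that rule with plain std::vector / std::complex stand-ins for DenseTensor, ignoring broadcasting and the axis attribute; it is a minimal, self-contained illustration, not Paddle code.

// conj_like(): identity for real types, complex conjugate for std::complex,
// mirroring the real/complex specializations of MulGradDX / MulGradDY.
#include <complex>
#include <cstdio>
#include <vector>

template <typename T>
T conj_like(const T& v) { return v; }

template <typename T>
std::complex<T> conj_like(const std::complex<T>& v) { return std::conj(v); }

// For out = x * y (elementwise), given the upstream gradient dout:
//   dx = dout * conj(y),  dy = dout * conj(x)
template <typename T>
void multiply_grad(const std::vector<T>& x, const std::vector<T>& y,
                   const std::vector<T>& dout,
                   std::vector<T>* dx, std::vector<T>* dy) {
  dx->resize(x.size());
  dy->resize(y.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*dx)[i] = dout[i] * conj_like(y[i]);
    (*dy)[i] = dout[i] * conj_like(x[i]);
  }
}

int main() {
  // Real case: with dout == 1 everywhere, dx == y and dy == x.
  std::vector<double> x{1.0, 2.0}, y{3.0, 4.0}, dout{1.0, 1.0}, dx, dy;
  multiply_grad(x, y, dout, &dx, &dy);
  std::printf("dx = {%g, %g}, dy = {%g, %g}\n", dx[0], dx[1], dy[0], dy[1]);

  // Complex case: the gradient uses the conjugated operand.
  using C = std::complex<double>;
  std::vector<C> cx{C(1, 1)}, cy{C(2, -3)}, cdout{C(1, 0)}, cdx, cdy;
  multiply_grad(cx, cy, cdout, &cdx, &cdy);
  std::printf("cdx = (%g, %g), cdy = (%g, %g)\n",
              cdx[0].real(), cdx[0].imag(), cdy[0].real(), cdy[0].imag());
  return 0;
}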
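
Note on the double-grad computation order (illustrative, not part of the patch): MultiplyDoubleGradKernel keeps the original memory-saving sequence documented in its comment — (1) dx = x * ddy, (2) dy = dout * ddx, (3) ddout = ddx * y, (4) ddout += dx, (5) dx = dout * ddy — so that ddout can reuse ddx's storage and dx doubles as scratch. The small sketch below, under the same plain-std::vector assumptions as above (no broadcasting, no axis), checks that this reordering reproduces the direct formulas dx = dout * ddy, dy = dout * ddx, ddout = ddx * y + x * ddy.

// Verifies the reordered double-grad sequence against the direct formulas.
#include <cassert>
#include <vector>

using Vec = std::vector<double>;

static Vec mul(const Vec& a, const Vec& b) {
  Vec r(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] * b[i];
  return r;
}
static Vec add(const Vec& a, const Vec& b) {
  Vec r(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] + b[i];
  return r;
}

int main() {
  Vec x{1, 2, 3}, y{4, 5, 6}, dout{0.1, 0.2, 0.3}, ddx{7, 8, 9}, ddy{1, 1, 2};

  // Direct formulas for the second-order gradients of out = x * y.
  Vec dx_ref = mul(dout, ddy);
  Vec dy_ref = mul(dout, ddx);
  Vec ddout_ref = add(mul(ddx, y), mul(x, ddy));

  // Memory-saving order from the kernel comment.
  Vec dx = mul(x, ddy);     // (1) dx used as scratch: x * ddy
  Vec dy = mul(dout, ddx);  // (2) dy = dout * ddx
  Vec ddout = mul(ddx, y);  // (3) ddout = ddx * y
  ddout = add(ddout, dx);   // (4) ddout += scratch  ->  ddx * y + x * ddy
  dx = mul(dout, ddy);      // (5) dx finally gets its real value: dout * ddy

  assert(dx == dx_ref && dy == dy_ref && ddout == ddout_ref);
  return 0;
}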