/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/phi/common/complex.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" namespace phi { template void AddGradImpl(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& out_grad, int axis, DenseTensor* x_grad, DenseTensor* y_grad, GradFunc grad_func) { phi::funcs::ElementwiseGradPreProcess(out_grad, x_grad); auto* out = &out_grad; // Special case when y_grad is not needed and x_grad doesn't reduce if (x_grad != nullptr && y_grad == nullptr && x_grad->dims() == out_grad.dims()) { VLOG(4) << "Special case when y_grad is not needed and x_grad doesn't " "reduce"; phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); } else if (x_grad == nullptr && y_grad != nullptr && y_grad->dims() == out_grad.dims()) { VLOG(4) << "Special case when x_grad is not needed and y_grad doesn't " "reduce"; phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, y_grad); } else { grad_func(dev_ctx, x, y, *out, out_grad, x_grad, y_grad, axis); } } template void AddDoubleGradImpl(const Context& dev_ctx, const DenseTensor& y, const paddle::optional& ddx, const paddle::optional& ddy, const DenseTensor& dout, int axis, DenseTensor* ddout) { // ddOut = ddx + ddy if (ddout) { DenseTensor ddx_safe, ddy_safe; funcs::GetDoubleGradSafeTensor( dev_ctx, dout, ddx.get_ptr(), &ddx_safe); funcs::GetDoubleGradSafeTensor( dev_ctx, y, ddy.get_ptr(), &ddy_safe); ddout->mutable_data(dev_ctx.GetPlace()); auto ddx_dims = ddx_safe.dims(); auto ddy_dims = ddy_safe.dims(); if (ddx_dims.size() >= ddy_dims.size()) { funcs::ElementwiseCompute, T>( dev_ctx, ddx_safe, ddy_safe, axis, funcs::AddFunctor(), ddout); } else { funcs::ElementwiseCompute, T>( dev_ctx, ddx_safe, ddy_safe, axis, funcs::InverseAddFunctor(), ddout); } } } template void SubtractDoubleGradImpl(const Context& dev_ctx, const DenseTensor& y, const paddle::optional& ddx, const paddle::optional& ddy, const DenseTensor& dout, int axis, DenseTensor* ddout) { // DDOut = ddx - ddy if (ddout) { DenseTensor ddx_safe, ddy_safe; funcs::GetDoubleGradSafeTensor( dev_ctx, dout, ddx.get_ptr(), &ddx_safe); funcs::GetDoubleGradSafeTensor( dev_ctx, y, ddy.get_ptr(), &ddy_safe); ddout->mutable_data(dev_ctx.GetPlace()); funcs::ElementwiseCompute, T>( dev_ctx, ddx_safe, ddy_safe, axis, funcs::SubtractFunctor(), ddout); } } /* ****************************** Divide Grad ****************************** */ template struct DivGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; } }; template struct DivGradDX> { HOSTDEVICE phi::dtype::complex operator()( phi::dtype::complex x, phi::dtype::complex y, phi::dtype::complex out, phi::dtype::complex dout) const { phi::dtype::complex y_conj(y.real, -y.imag); return dout / y_conj; } }; template struct DivGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout * out / y; } }; template struct DivGradDY> { HOSTDEVICE phi::dtype::complex operator()( phi::dtype::complex x, phi::dtype::complex y, phi::dtype::complex out, phi::dtype::complex dout) const { phi::dtype::complex out_div_y_conj((out / y).real, -(out / y).imag); return -dout * out_div_y_conj; } }; template struct DivDoubleDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return y * out * dout - x * dout; } }; template void DivideDoubleGradKernel(const Context& dev_ctx, const DenseTensor& y, const DenseTensor& out, const DenseTensor& dx, paddle::optional ddx, paddle::optional ddy, int axis, DenseTensor* dy, DenseTensor* dout, DenseTensor* ddout) { if (dy) { dy->Resize(y.dims()); dev_ctx.template Alloc(dy); } if (dout) { dout->Resize(out.dims()); dev_ctx.template Alloc(dout); } if (ddout) { ddout->Resize(out.dims()); dev_ctx.template Alloc(ddout); } // ddX_safe == null ? 0 : ddX // ddY_safe == null ? 0 : ddY DenseTensor ddX_safe, ddY_safe; phi::funcs::GetDoubleGradSafeTensor( dev_ctx, dx, ddx.get_ptr(), &ddX_safe); phi::funcs::GetDoubleGradSafeTensor( dev_ctx, y, ddy.get_ptr(), &ddY_safe); // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y // dY = Out * dX * ddY / Y - dX * ddX / Y // dOut = - dX * ddY // To save memory, (1) dout can be used as 'tmp' tensor, (2) ddout can // inplace ddx DenseTensor tmp; if (dout) { tmp = *dout; } else { tmp.Resize(out.dims()); dev_ctx.template Alloc(&tmp); } if (dy) { // dX_div_Y = dX / Y; DenseTensor dX_div_Y = tmp; funcs::DefaultElementwiseOperator, funcs::InverseDivideFunctor>( dev_ctx, dx, y, &dX_div_Y, axis); // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the // first output tensor is nullptr, the branch to calculate first // output tensor will not be activated, DivGradDx function will not // be called and can be ignored, the first branch has little effect // on running speed. // dY = Out * dX * ddY / Y - dX * ddX / Y phi::funcs::ElemwiseGradCompute, DivDoubleDY>( dev_ctx, ddX_safe, ddY_safe, out, dX_div_Y, axis, nullptr, dy, DivGradDX(), DivDoubleDY()); } if (ddout) { // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, out, ddY_safe, &tmp, axis); funcs::DefaultElementwiseOperator, funcs::InverseSubtractFunctor>( dev_ctx, ddX_safe, tmp, &tmp, axis); funcs::DefaultElementwiseOperator, funcs::InverseDivideFunctor>( dev_ctx, tmp, y, ddout, axis); } if (dout) { // dOut = - dX * ddY funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, dx, ddY_safe, dout, axis); auto& place = *dev_ctx.eigen_device(); auto dout_result = phi::EigenVector::Flatten(*dout); dout_result.device(place) = static_cast(-1) * dout_result; } } template void ElementwiseFMaxGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& out_grad, int axis, DenseTensor* x_grad, DenseTensor* y_grad) { funcs::ElementwiseGradPreProcess(out_grad, x_grad); auto out = out_grad; // Fake out, not used auto x_dim = x.dims(); auto y_dim = y.dims(); if (x.dims() == y.dims()) { funcs::ElemwiseGradComputeNoBroadcast, funcs::FMaxGradDy>( dev_ctx, x_dim, y_dim, x, y, out, out_grad, axis, x_grad, y_grad, funcs::FMaxGradDx(), funcs::FMaxGradDy()); } else { funcs::ElemwiseGradComputeWithBroadcast, funcs::FMaxGradDy>( dev_ctx, x_dim, y_dim, x, y, out, out_grad, axis, x_grad, y_grad, funcs::FMaxGradDx(), funcs::FMaxGradDy()); } } template void ElementwiseFMinGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& out_grad, int axis, DenseTensor* x_grad, DenseTensor* y_grad) { funcs::ElementwiseGradPreProcess(out_grad, x_grad); auto out = out_grad; // Fake out, not used auto x_dim = x.dims(); auto y_dim = y.dims(); if (x.dims() == y.dims()) { funcs::ElemwiseGradComputeNoBroadcast, funcs::FMinGradDy>( dev_ctx, x_dim, y_dim, x, y, out, out_grad, axis, x_grad, y_grad, funcs::FMinGradDx(), funcs::FMinGradDy()); } else { funcs::ElemwiseGradComputeWithBroadcast, funcs::FMinGradDy>( dev_ctx, x_dim, y_dim, x, y, out, out_grad, axis, x_grad, y_grad, funcs::FMinGradDx(), funcs::FMinGradDy()); } } template struct MulGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } }; // avoid [-Wint-in-bool-context] warning template <> struct MulGradDX { HOSTDEVICE bool operator()(bool x, bool y, bool out, bool dout) const { return dout && y; } }; template struct MulGradDX> { HOSTDEVICE phi::dtype::complex operator()( phi::dtype::complex x, phi::dtype::complex y, phi::dtype::complex out, phi::dtype::complex dout) const { phi::dtype::complex y_conj(y.real, -y.imag); return dout * y_conj; } }; /* ****************************** Multiply Grad ****************************** */ template struct MulGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } }; // avoid [-Wint-in-bool-context] warning template <> struct MulGradDY { HOSTDEVICE bool operator()(bool x, bool y, bool out, bool dout) const { return dout && x; } }; template struct MulGradDY> { HOSTDEVICE phi::dtype::complex operator()( phi::dtype::complex x, phi::dtype::complex y, phi::dtype::complex out, phi::dtype::complex dout) const { phi::dtype::complex x_conj(x.real, -x.imag); return dout * x_conj; } }; template void MultiplyDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& dout, paddle::optional ddx, paddle::optional ddy, int axis, DenseTensor* dx, DenseTensor* dy, DenseTensor* ddout) { if (ddout) dev_ctx.template Alloc(ddout); DenseTensor ddx_safe, ddy_safe; funcs::GetDoubleGradSafeTensor( dev_ctx, x, ddx.get_ptr(), &ddx_safe); funcs::GetDoubleGradSafeTensor( dev_ctx, y, ddy.get_ptr(), &ddy_safe); // dx = dout * ddy // dy = dout * ddx // ddout = ddx * y + x * ddy // change computation sequence to save memory, so ddout can inplace ddx and // dx can be used as 'tmp' tensor // (1) dx = x * ddy // (2) dy = dout * ddx // (3) ddout = ddx * y // (4) ddout = ddout + dx // (5) dx = dout * ddy if (ddout) { auto& place = *dev_ctx.eigen_device(); // size(ddout) > size(ddx), ddout can't use memory of ddx using inplace if (ddout->numel() > ddx.get_ptr()->numel()) { phi::funcs::ElemwiseGradCompute, MulGradDY>( dev_ctx, ddx_safe, ddy_safe, dout, dout, axis, dx, dy, MulGradDX(), MulGradDY()); DenseTensor ddout_tmp; ddout_tmp.Resize(ddout->dims()); dev_ctx.template Alloc(&ddout_tmp); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, y, ddx_safe, ddout, axis); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, ddy_safe, x, &ddout_tmp, axis); auto ddout_t = phi::EigenVector::Flatten(*ddout); auto ddout_tmp_t = phi::EigenVector::Flatten(ddout_tmp); ddout_t.device(place) = ddout_t + ddout_tmp_t; } else { // use dx to save memory, other than alloc tmp tensor DenseTensor* ddout_tmp = dx; funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, x, ddy_safe, ddout_tmp, axis); // NOTE: in the following ElemwiseGradCompute, for the // first output tensor is nullptr, the branch to calculate first // output tensor will not be activated, DivGradDx function will not // be called and can be ignored, the first branch has little effect // on running speed. phi::funcs::ElemwiseGradCompute, MulGradDY>( dev_ctx, ddx_safe, ddy_safe, dout, dout, axis, nullptr, dy, MulGradDX(), MulGradDY()); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, ddx_safe, y, ddout, axis); auto ddout_t = phi::EigenVector::Flatten(*ddout); auto ddout_tmp_t = phi::EigenVector::Flatten(*ddout_tmp); ddout_t.device(place) = ddout_t + ddout_tmp_t; funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, dout, ddy_safe, dx, axis); } } } template void MultiplyTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& dout, paddle::optional ddx, paddle::optional ddy, const DenseTensor& d_dx, const DenseTensor& d_dy, paddle::optional d_ddout, int axis, DenseTensor* d_x, DenseTensor* d_y, DenseTensor* d_dout, DenseTensor* d_ddx, DenseTensor* d_ddy) { if (d_x) { d_x->Resize(x.dims()); dev_ctx.template Alloc(d_x); } if (d_y) { d_y->Resize(y.dims()); dev_ctx.template Alloc(d_y); } if (d_dout) { d_dout->Resize(dout.dims()); dev_ctx.template Alloc(d_dout); } if (d_ddx) { d_ddx->Resize(x.dims()); dev_ctx.template Alloc(d_ddx); } if (d_ddy) { d_ddy->Resize(y.dims()); dev_ctx.template Alloc(d_ddy); } auto& place = *dev_ctx.eigen_device(); DenseTensor ddx_safe, ddy_safe; funcs::GetDoubleGradSafeTensor( dev_ctx, x, ddx.get_ptr(), &ddx_safe); funcs::GetDoubleGradSafeTensor( dev_ctx, y, ddy.get_ptr(), &ddy_safe); if (d_ddout.get_ptr()) { if (d_x) { // d_x = ddy * d_ddout funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, ddy_safe, *(d_ddout.get_ptr()), d_x, axis); } if (d_y) { // d_y = ddx * d_ddout funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, ddx_safe, *(d_ddout.get_ptr()), d_y, axis); } } if (d_dout) { // get d_dout // d_dout = ddy * d_dx + d_dy * ddx DenseTensor d_dout_tmp; d_dout_tmp.Resize(dout.dims()); dev_ctx.template Alloc(&d_dout_tmp); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, d_dy, ddx_safe, d_dout, axis); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, ddy_safe, d_dx, &d_dout_tmp, axis); auto d_dout_t = phi::EigenVector::Flatten(*d_dout); auto d_dout_tmp_t = phi::EigenVector::Flatten(d_dout_tmp); d_dout_t.device(place) = d_dout_t + d_dout_tmp_t; } if (d_ddx) { // get d_ddx // d_ddx = dout * d_dy + y * d_ddout DenseTensor d_ddx_tmp; d_ddx_tmp.Resize(ddx->dims()); dev_ctx.template Alloc(&d_ddx_tmp); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, dout, d_dy, d_ddx, axis); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, y, *(d_ddout.get_ptr()), &d_ddx_tmp, axis); auto d_ddx_t = phi::EigenVector::Flatten(*d_ddx); auto d_ddx_tmp_t = phi::EigenVector::Flatten(d_ddx_tmp); d_ddx_t.device(place) = d_ddx_t + d_ddx_tmp_t; } if (d_ddy) { // get d_ddy // d_ddy = dout * d_dx + x * d_ddout DenseTensor d_ddy_tmp; d_ddy_tmp.Resize(ddy->dims()); dev_ctx.template Alloc(&d_ddy_tmp); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, dout, d_dx, d_ddy, axis); funcs::DefaultElementwiseOperator, funcs::InverseMultiplyFunctor>( dev_ctx, x, *(d_ddout.get_ptr()), &d_ddy_tmp, axis); auto d_ddy_t = phi::EigenVector::Flatten(*d_ddy); auto d_ddy_tmp_t = phi::EigenVector::Flatten(d_ddy_tmp); d_ddy_t.device(place) = d_ddy_t + d_ddy_tmp_t; } } /* ****************************** Maximum Grad ****************************** */ template struct MaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * static_cast(x > y); } }; template struct MaxGradDy { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * static_cast(x <= y); } }; /* ****************************** Minimum Grad ****************************** */ template struct MinGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * static_cast(x < y); } }; template struct MinGradDy { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * static_cast(x >= y); } }; template struct HeavisideGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * static_cast(0); } }; template struct HeavisideGradDy { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * static_cast(x == static_cast(0)); } }; template void ElementwiseHeavisideGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& dout, int axis, DenseTensor* dx, DenseTensor* dy) { funcs::ElementwiseGradPreProcess(dout, dx); phi::funcs:: ElemwiseGradCompute, HeavisideGradDy>( dev_ctx, x, y, dout, dout, axis, dx, dy, HeavisideGradDx(), HeavisideGradDy()); } template struct PowGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { #if defined(__CUDA_ARCH__) || defined(__HIPCC__) if (std::is_integral::value) { return dout * y * std::pow(static_cast(x), static_cast(y - 1)); } #endif return dout * y * std::pow(x, y - 1); } }; template struct PowGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { #if defined(__CUDA_ARCH__) || defined(__HIPCC__) if (std::is_integral::value) { return dout * std::log(static_cast(x)) * std::pow(static_cast(x), static_cast(y)); } #endif return dout * std::log(x) * std::pow(x, y); } }; template void ElementwisePowGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& dout, int axis, DenseTensor* dx, DenseTensor* dy) { funcs::ElementwiseGradPreProcess(dout, dx); phi::funcs::ElemwiseGradCompute, PowGradDY>( dev_ctx, x, y, dout, dout, axis, dx, dy, PowGradDX(), PowGradDY()); } } // namespace phi