// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/elementwise_multiply_kernel.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/scale_kernel.h" namespace phi { template void ActivationGradImpl(const Context& dev_ctx, const DenseTensor* X, const DenseTensor* Out, const DenseTensor* dOut, DenseTensor* dX, const Functor& functor) { if (static_cast(Functor::FwdDeps()) & static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { PADDLE_ENFORCE_NOT_NULL( Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); } PADDLE_ENFORCE_NOT_NULL( dOut, errors::NotFound("The input DenseTensor dOut can not be nullptr")); PADDLE_ENFORCE_NOT_NULL( dX, errors::NotFound("The output DenseTensor dX can not be nullptr")); if (!Out) { Out = dOut; // fake out } if (static_cast(Functor::FwdDeps()) & static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { PADDLE_ENFORCE_NOT_NULL( X, errors::NotFound("The input DenseTensor X can not be nullptr")); } else { VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); X = dX; } dev_ctx.template Alloc(dX); auto dout = phi::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad")); auto out = phi::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad")); auto dx = phi::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad")); auto x = phi::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad")); auto* place = dev_ctx.eigen_device(); // use 32bit index to speed up computation bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); if (use_32bit_index && is_gpu_place) { functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout), To32BitIndex(dx)); } else { functor(*place, x, out, dout, dx); } } template void ActivationDoubleGradImpl(const Context& dev_ctx, const DenseTensor* X, const DenseTensor* Out, const DenseTensor* ddX, DenseTensor* dX, DenseTensor* dOut, DenseTensor* ddOut, const Functor& functor) { if (static_cast(Functor::FwdDeps()) & static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { PADDLE_ENFORCE_NOT_NULL( X, errors::NotFound("The input DenseTensor X can not be nullptr")); } else { VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); X = ddX; } if (static_cast(Functor::FwdDeps()) & static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { PADDLE_ENFORCE_NOT_NULL( Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); } else { VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); Out = ddX; } if (ddOut) { dev_ctx.template Alloc(ddOut); } if (dOut) { dev_ctx.template Alloc(dOut); } if (dX) { dX->Resize(Out->dims()); dev_ctx.template Alloc(dX); } functor(dev_ctx, X, Out, ddX, ddOut, dOut, dX); } template void ReluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& ddx, DenseTensor* ddout) { funcs::ReluGradGradFunctor relu_double_grad_functor; ActivationDoubleGradImpl>( dev_ctx, nullptr, &out, &ddx, nullptr, nullptr, ddout, relu_double_grad_functor); } template void LeakyReluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& ddx, float alpha, DenseTensor* ddout) { funcs::LeakyReluGradGradFunctor leaky_relu_double_grad_functor; leaky_relu_double_grad_functor.alpha = alpha; ActivationDoubleGradImpl>( dev_ctx, &x, nullptr, &ddx, nullptr, nullptr, ddout, leaky_relu_double_grad_functor); } template void TanhDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dout_new, DenseTensor* ddout) { if (dout_new) { dout_new->Resize(out.dims()); dev_ctx.template Alloc(dout_new); } if (ddout) { ddout->Resize(out.dims()); dev_ctx.template Alloc(ddout); } funcs::TanhGradGradFunctor functor; functor(dev_ctx, &out, &ddx, &dout, dout_new, ddout); } template void TanhTripleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, const DenseTensor& d_dout_new, const DenseTensor& d_ddout, DenseTensor* d_out_new, DenseTensor* d_dout, DenseTensor* d_ddx) { if (d_dout) { d_dout->Resize(out.dims()); dev_ctx.template Alloc(d_dout); } if (d_out_new) { d_dout->Resize(out.dims()); dev_ctx.template Alloc(d_out_new); } if (d_ddx) { d_dout->Resize(ddx.dims()); dev_ctx.template Alloc(d_ddx); } funcs::TanhTripleGradFunctor functor; functor(dev_ctx, &out, &ddx, &dout, &d_ddout, &d_dout_new, // input d_dout, d_out_new, d_ddx); // output } template void EluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, float alpha, DenseTensor* dx, DenseTensor* ddout) { if (dx) { dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } funcs::ELUGradGradFunctor functor; functor.alpha = alpha; functor(dev_ctx, &x, &ddx, ddout, &dout, dx); } template void LogitGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, float eps, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); auto eigen_x = EigenVector::Flatten(x); auto eigen_dout = EigenVector::Flatten(out_grad); auto eigen_dx = EigenVector::Flatten(*x_grad); auto& place = *dev_ctx.eigen_device(); auto eigen_p = EigenVector::Flatten(x); funcs::LogitGradFunctor functor; functor(place, eigen_x, eigen_dout, eigen_dx, eigen_p, eps); } template void SigmoidDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dout_new, DenseTensor* ddout) { if (dout_new) { dout_new->Resize(out.dims()); dev_ctx.template Alloc(dout_new); } if (ddout) { ddout->Resize(out.dims()); dev_ctx.template Alloc(ddout); } funcs::SigmoidGradGradFunctor functor; functor(dev_ctx, &out, &ddx, &dout, dout_new, ddout); } template void SigmoidTripleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, const DenseTensor& d_dout_new, const paddle::optional& d_ddout, DenseTensor* d_out_new, DenseTensor* d_dout, DenseTensor* d_ddx) { if (d_dout) { d_dout->Resize(out.dims()); dev_ctx.template Alloc(d_dout); } if (d_out_new) { d_out_new->Resize(out.dims()); dev_ctx.template Alloc(d_out_new); } if (d_ddx) { d_ddx->Resize(ddx.dims()); dev_ctx.template Alloc(d_ddx); } funcs::SigmoidTripleGradFunctor functor; functor(dev_ctx, &out, &ddx, &dout, d_ddout.get_ptr(), &d_dout_new, d_dout, d_out_new, d_ddx); } template void LogDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { if (dx) { dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } funcs::LogGradGradFunctor functor; functor(dev_ctx, &x, &ddx, ddout, &dout, dx); } template void PowGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const Scalar& factor, DenseTensor* dx) { PADDLE_ENFORCE_NOT_NULL( dx, errors::NotFound("The output DenseTensor dX can not be nullptr")); if (dx) { dev_ctx.template Alloc(dx); } auto dout_flatten = EigenVector::Flatten( GET_DATA_SAFELY(&dout, "Input", "Out@GRAD", "PowGrad")); auto dx_flatten = EigenVector::Flatten( GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad")); auto x_flatten = EigenVector::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad")); auto* place = dev_ctx.eigen_device(); phi::funcs::PowGradFunctor functor; auto attrs = functor.GetAttrs(); *(attrs[0].second) = factor.to(); functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten); } template void PowDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, const Scalar& factor, DenseTensor* dx, DenseTensor* ddout) { PADDLE_ENFORCE_NOT_NULL( dx, errors::NotFound("The output DenseTensor DX can not be nullptr")); PADDLE_ENFORCE_NOT_NULL( ddout, errors::NotFound("The output DenseTensor DDOut can not be nullptr")); float exponent = factor.to(); if (exponent == 1) { *dx = phi::FullLike(dev_ctx, x, static_cast(0)); } else { DenseTensor dx_tmp1 = phi::Multiply(dev_ctx, dout, ddx); DenseTensor dx_tmp2 = phi::Multiply( dev_ctx, dx_tmp1, phi::Pow(dev_ctx, x, exponent - 2)); *dx = phi::Scale( dev_ctx, dx_tmp2, exponent * (exponent - 1), 0.0, true); } DenseTensor ddout_tmp = phi::Multiply( dev_ctx, ddx, phi::Pow(dev_ctx, x, exponent - 1)); *ddout = phi::Scale(dev_ctx, ddout_tmp, exponent, 0.0, true); } template void PowTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, const DenseTensor& d_dx, const DenseTensor& d_ddout, const Scalar& factor, DenseTensor* out_d_x, DenseTensor* out_d_dout, DenseTensor* out_d_ddx) { PADDLE_ENFORCE_NOT_NULL( out_d_x, errors::NotFound("The output DenseTensor D_X can not be nullptr")); PADDLE_ENFORCE_NOT_NULL( out_d_dout, errors::NotFound("The output DenseTensor D_DOut can not be nullptr")); PADDLE_ENFORCE_NOT_NULL( out_d_ddx, errors::NotFound("The output DenseTensor D_DDX can not be nullptr")); float exponent = factor.to(); if (exponent != 2 && exponent != 1) { // case1: b != 2 and b != 1 // D_X = D_DX * DDX * DOut * b * (b-1) * (b-2) * X^(b-3) // + D_DDOut * DDX * b * (b-1) * X^(b-2) DenseTensor out_d_x_tmp1 = phi::Multiply(dev_ctx, d_dx, ddx); DenseTensor out_d_x_tmp2 = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 3), exponent * (exponent - 1) * (exponent - 2), 0.0, true); DenseTensor out_d_x_part1 = phi::Multiply( dev_ctx, phi::Multiply(dev_ctx, out_d_x_tmp1, dout), out_d_x_tmp2); DenseTensor out_d_x_tmp3 = phi::Multiply(dev_ctx, d_ddout, ddx); DenseTensor out_d_x_tmp4 = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 2), exponent * (exponent - 1), 0.0, true); DenseTensor out_d_x_part2 = phi::Multiply(dev_ctx, out_d_x_tmp3, out_d_x_tmp4); *out_d_x = phi::Add(dev_ctx, out_d_x_part1, out_d_x_part2); // D_DOut = D_DX * DDX * b * (b-1) * X^(b-2) DenseTensor out_d_dout_tmp = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 2), exponent * (exponent - 1), 0.0, true); *out_d_dout = phi::Multiply(dev_ctx, out_d_x_tmp1, out_d_dout_tmp); // D_DDX = D_DX * DOut * b * (b-1) * X^(b-2) + D_DDOut * b * X^(b-1) DenseTensor out_d_ddx_tmp1 = phi::Multiply(dev_ctx, d_dx, dout); DenseTensor out_d_ddx_part1 = phi::Multiply(dev_ctx, out_d_ddx_tmp1, out_d_dout_tmp); DenseTensor out_d_ddx_tmp2 = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 1), exponent, 0.0, true); DenseTensor out_d_ddx_part2 = phi::Multiply(dev_ctx, d_ddout, out_d_ddx_tmp2); *out_d_ddx = phi::Add(dev_ctx, out_d_ddx_part1, out_d_ddx_part2); } else if (exponent == 2) { // case2: b = 2 // D_X = D_DDOut * DDX * b * (b-1) * X^(b-2) DenseTensor out_d_x_tmp1 = phi::Multiply(dev_ctx, d_ddout, ddx); DenseTensor out_d_x_tmp2 = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 2), exponent * (exponent - 1), 0.0, true); *out_d_x = phi::Multiply(dev_ctx, out_d_x_tmp1, out_d_x_tmp2); // D_DOut = D_DX * DDX * b * (b-1) * X^(b-2) DenseTensor out_d_dout_tmp1 = phi::Multiply(dev_ctx, d_dx, ddx); DenseTensor out_d_dout_tmp2 = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 2), exponent * (exponent - 1), 0.0, true); *out_d_dout = phi::Multiply(dev_ctx, out_d_dout_tmp1, out_d_dout_tmp2); // D_DDX = D_DX * DOut * b * (b-1) * X^(b-2) + D_DDOut * b * X^(b-1) DenseTensor out_d_ddx_tmp1 = phi::Multiply(dev_ctx, d_dx, dout); DenseTensor out_d_ddx_part1 = phi::Multiply(dev_ctx, out_d_ddx_tmp1, out_d_dout_tmp2); DenseTensor out_d_ddx_tmp2 = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 1), exponent, 0.0, true); DenseTensor out_d_ddx_part2 = phi::Multiply(dev_ctx, d_ddout, out_d_ddx_tmp2); *out_d_ddx = phi::Add(dev_ctx, out_d_ddx_part1, out_d_ddx_part2); } else { // case3: b = 1 // D_X = D_DX * DDX * DOut * b * (b-1) * (b-2) * X^(b-3) DenseTensor out_d_x_tmp1 = phi::Multiply(dev_ctx, d_dx, ddx); DenseTensor out_d_x_tmp2 = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 3), exponent * (exponent - 1) * (exponent - 2), 0.0, true); *out_d_x = phi::Multiply( dev_ctx, phi::Multiply(dev_ctx, out_d_x_tmp1, dout), out_d_x_tmp2); // D_DOut = 0 *out_d_dout = phi::FullLike(dev_ctx, dout, static_cast(0)); // D_DDX = D_DDOut * b * X^(b-1) DenseTensor out_d_ddx_tmp = phi::Scale(dev_ctx, phi::Pow(dev_ctx, x, exponent - 1), exponent, 0.0, true); *out_d_ddx = phi::Multiply(dev_ctx, d_ddout, out_d_ddx_tmp); } } template void SqrtDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dx, const DenseTensor& ddx, DenseTensor* dout, DenseTensor* ddout) { if (dout) { dout->Resize(out.dims()); dev_ctx.template Alloc(dout); } if (ddout) { ddout->Resize(out.dims()); dev_ctx.template Alloc(ddout); } phi::funcs::SqrtGradGradFunctor functor; functor(dev_ctx, &out, &dx, &ddx, dout, ddout); } // rsqrt Grad: dx = -0.5 * dy * y * y * y // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx template void RsqrtDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dx, const DenseTensor& ddx, DenseTensor* dout, DenseTensor* ddout) { if (dout) { dout->Resize(out.dims()); dev_ctx.template Alloc(dout); } if (ddout) { ddout->Resize(out.dims()); dev_ctx.template Alloc(ddout); } phi::funcs::RsqrtGradGradFunctor functor; functor(dev_ctx, &out, &dx, &ddx, dout, ddout); } template void CeluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, float alpha, DenseTensor* dx, DenseTensor* ddout) { if (dx) { dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } phi::funcs::CELUGradGradFunctor functor; auto attrs = functor.GetAttrs(); *(attrs[0].second) = alpha; functor(dev_ctx, &x, &dout, &ddx, dx, ddout); } template void SquareDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { if (dx) { dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } phi::funcs::SquareGradGradFunctor functor; functor(dev_ctx, &x, &dout, &ddx, dx, ddout); } template void SinDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { if (dx) { dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } phi::funcs::SinDoubleGradFunctor functor; functor(dev_ctx, &x, &dout, &ddx, dx, ddout); } template void SinTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, const DenseTensor& d_dx_new, const DenseTensor& d_ddout, DenseTensor* d_x_new, DenseTensor* d_dout, DenseTensor* d_ddx) { if (d_dout) { d_dout->Resize(x.dims()); dev_ctx.template Alloc(d_dout); } if (d_x_new) { d_dout->Resize(x.dims()); dev_ctx.template Alloc(d_x_new); } if (d_ddx) { d_dout->Resize(ddx.dims()); dev_ctx.template Alloc(d_ddx); } funcs::SinTripleGradFunctor functor; functor(dev_ctx, &x, &ddx, &dout, &d_ddout, &d_dx_new, // input d_dout, d_x_new, d_ddx); // output } template void CosDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { if (dx) { dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } phi::funcs::CosDoubleGradFunctor functor; functor(dev_ctx, &x, &dout, &ddx, dx, ddout); } template void CosTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, const DenseTensor& d_dx_new, const DenseTensor& d_ddout, DenseTensor* d_x_new, DenseTensor* d_dout, DenseTensor* d_ddx) { if (d_dout) { d_dout->Resize(x.dims()); dev_ctx.template Alloc(d_dout); } if (d_x_new) { d_dout->Resize(x.dims()); dev_ctx.template Alloc(d_x_new); } if (d_ddx) { d_dout->Resize(ddx.dims()); dev_ctx.template Alloc(d_ddx); } funcs::CosTripleGradFunctor functor; functor(dev_ctx, &x, &ddx, &dout, &d_ddout, &d_dx_new, // input d_dout, d_x_new, d_ddx); // output } } // namespace phi