// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/fluid/platform/device_context.h" namespace phi { template void ActivationGradImpl(const Context& dev_ctx, const DenseTensor* X, const DenseTensor* Out, const DenseTensor* dOut, DenseTensor* dX, const Functor& functor) { if (static_cast(Functor::FwdDeps()) & static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { PADDLE_ENFORCE_NOT_NULL( Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); } PADDLE_ENFORCE_NOT_NULL( dOut, errors::NotFound("The input DenseTensor dOut can not be nullptr")); PADDLE_ENFORCE_NOT_NULL( dX, errors::NotFound("The output DenseTensor dX can not be nullptr")); if (!Out) { Out = dOut; // fake out } if (static_cast(Functor::FwdDeps()) & static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { PADDLE_ENFORCE_NOT_NULL( X, errors::NotFound("The input DenseTensor X can not be nullptr")); } else { VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); X = dX; } dev_ctx.template Alloc(dX); auto dout = phi::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad")); auto out = phi::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad")); auto dx = phi::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad")); auto x = phi::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad")); auto* place = dev_ctx.eigen_device(); // use 32bit index to speed up computation bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); if (use_32bit_index && is_gpu_place) { functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout), To32BitIndex(dx)); } else { functor(*place, x, out, dout, dx); } } template void ActivationDoubleGradImpl(const Context& dev_ctx, const DenseTensor* X, const DenseTensor* Out, const DenseTensor* ddX, DenseTensor* dX, DenseTensor* dOut, DenseTensor* ddOut, const Functor& functor) { if (static_cast(Functor::FwdDeps()) & static_cast(funcs::ActBwdOpFwdDeps::kDepX)) { PADDLE_ENFORCE_NOT_NULL( X, errors::NotFound("The input DenseTensor X can not be nullptr")); } else { VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); X = ddX; } if (static_cast(Functor::FwdDeps()) & static_cast(funcs::ActBwdOpFwdDeps::kDepOut)) { PADDLE_ENFORCE_NOT_NULL( Out, errors::NotFound("The input DenseTensor Out can not be nullptr")); } else { VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name(); Out = ddX; } if (ddOut) { dev_ctx.template Alloc(ddOut); } if (dOut) { dev_ctx.template Alloc(dOut); } if (dX) { dX->Resize(Out->dims()); dev_ctx.template Alloc(dX); } functor(dev_ctx, X, Out, ddX, ddOut, dOut, dX); } template void ReluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& ddx, DenseTensor* ddout) { funcs::ReluGradGradFunctor relu_double_grad_functor; ActivationDoubleGradImpl>( dev_ctx, nullptr, &out, &ddx, nullptr, nullptr, ddout, relu_double_grad_functor); } template void LeakyReluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& ddx, float alpha, DenseTensor* ddout) { funcs::LeakyReluGradGradFunctor leaky_relu_double_grad_functor; leaky_relu_double_grad_functor.alpha = alpha; ActivationDoubleGradImpl>( dev_ctx, &x, nullptr, &ddx, nullptr, nullptr, ddout, leaky_relu_double_grad_functor); } template void TanhDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dout_new, DenseTensor* ddout) { if (dout_new) { dout_new->Resize(out.dims()); dev_ctx.template Alloc(dout_new); } if (ddout) { ddout->Resize(out.dims()); dev_ctx.template Alloc(ddout); } funcs::TanhGradGradFunctor functor; functor(dev_ctx, &out, &ddx, &dout, dout_new, ddout); } template void TanhTripleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, const DenseTensor& d_dout_new, const DenseTensor& d_ddout, DenseTensor* d_out_new, DenseTensor* d_dout, DenseTensor* d_ddx) { if (d_dout) { d_dout->Resize(out.dims()); dev_ctx.template Alloc(d_dout); } if (d_out_new) { d_dout->Resize(out.dims()); dev_ctx.template Alloc(d_out_new); } if (d_ddx) { d_dout->Resize(ddx.dims()); dev_ctx.template Alloc(d_ddx); } funcs::TanhTripleGradFunctor functor; functor(dev_ctx, &out, &ddx, &dout, &d_ddout, &d_dout_new, // input d_dout, d_out_new, d_ddx); // output } template void EluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, float alpha, DenseTensor* dx, DenseTensor* ddout) { if (dx) { dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } funcs::ELUGradGradFunctor functor; functor.alpha = alpha; functor(dev_ctx, &x, &ddx, ddout, &dout, dx); } template void LogitGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, float eps, DenseTensor* x_grad) { dev_ctx.template Alloc(x_grad); auto eigen_x = EigenVector::Flatten(x); auto eigen_dout = EigenVector::Flatten(out_grad); auto eigen_dx = EigenVector::Flatten(*x_grad); auto& place = *dev_ctx.eigen_device(); auto eigen_p = EigenVector::Flatten(x); funcs::LogitGradFunctor functor; functor(place, eigen_x, eigen_dout, eigen_dx, eigen_p, eps); } template void SigmoidDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dout_new, DenseTensor* ddout) { if (dout_new) { dout_new->Resize(out.dims()); dev_ctx.template Alloc(dout_new); } if (ddout) { ddout->Resize(out.dims()); dev_ctx.template Alloc(ddout); } funcs::SigmoidGradGradFunctor functor; functor(dev_ctx, &out, &ddx, &dout, dout_new, ddout); } template void SigmoidTripleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, const DenseTensor& d_dout_new, paddle::optional d_ddout, DenseTensor* d_out_new, DenseTensor* d_dout, DenseTensor* d_ddx) { if (d_dout) { d_dout->Resize(out.dims()); dev_ctx.template Alloc(d_dout); } if (d_out_new) { d_out_new->Resize(out.dims()); dev_ctx.template Alloc(d_out_new); } if (d_ddx) { d_ddx->Resize(ddx.dims()); dev_ctx.template Alloc(d_ddx); } funcs::SigmoidTripleGradFunctor functor; functor(dev_ctx, &out, &ddx, &dout, d_ddout.get_ptr(), &d_dout_new, d_dout, d_out_new, d_ddx); } template void LogDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { if (dx) { dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } funcs::LogGradGradFunctor functor; functor(dev_ctx, &x, &ddx, ddout, &dout, dx); } template void PowGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& dout, const Scalar& factor, DenseTensor* dx) { PADDLE_ENFORCE_NOT_NULL( dx, errors::NotFound("The output DenseTensor dX can not be nullptr")); if (dx) { dev_ctx.template Alloc(dx); } auto dout_flatten = EigenVector::Flatten( GET_DATA_SAFELY(&dout, "Input", "Out@GRAD", "PowGrad")); auto dx_flatten = EigenVector::Flatten( GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad")); auto x_flatten = EigenVector::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad")); auto* place = dev_ctx.eigen_device(); phi::funcs::PowGradFunctor functor; auto attrs = functor.GetAttrs(); *(attrs[0].second) = factor.to(); functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten); } } // namespace phi