// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include #ifndef _USE_MATH_DEFINES #define _USE_MATH_DEFINES #endif #include #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { namespace funcs { enum ActBwdOpFwdDeps { kNoDeps = 0x00, // Do not need any forward input/output kDepX = 0x01, // Only need forward input X kDepOut = 0x02, // Only need forward output Out }; template struct BaseActivationFunctor { using ELEMENT_TYPE = T; using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } }; template struct Sine { HOSTDEVICE T operator()(const T& val) const { return sin(val); } }; template <> struct Sine { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(sin(static_cast(val))); } }; template struct Cosine { HOSTDEVICE T operator()(const T& val) const { return cos(val); } }; template <> struct Cosine { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(cos(static_cast(val))); } }; // sine'(x) = cos(x) template struct SinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosine()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // sine(x) = sin(x) template struct SinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Sine()); } }; // cosine'(x) = -sin(x) template struct CosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * x.unaryExpr(Sine()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // cosine(x) = cos(x) template struct CosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Cosine()); } }; template struct Tangent { HOSTDEVICE T operator()(const T& val) const { return tan(val); } }; template <> struct Tangent { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(tan(static_cast(val))); } }; // Tangent'(x) = -Tangent(x) template struct TanGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout / x.unaryExpr(Cosine()).square(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // Tangent(x) = tan(x) template struct TanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Tangent()); } }; template struct Sinh { HOSTDEVICE T operator()(const T& val) const { return sinh(val); } }; template <> struct Sinh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(sinhf(static_cast(val))); } }; template struct Cosh { HOSTDEVICE T operator()(const T& val) const { return cosh(val); } }; template <> struct Cosh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(coshf(static_cast(val))); } }; // sinh(x) = sinh(x) template struct SinhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Sinh()); } }; // cosh(x) = cosh(x) template struct CoshFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Cosh()); } }; // sinh'(x) = cosh(x) template struct SinhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // cosh'(x) = sinh(x) template struct CoshGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Sinh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Acos { HOSTDEVICE T operator()(const T& val) const { return acos(val); } }; template <> struct Acos { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(acos(static_cast(val))); } }; // Acos(x) = acos(x) template struct AcosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Acos()); } }; // acos'(x) = -1/sqrt(1-x^2) template struct AcosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Asin { HOSTDEVICE T operator()(const T& val) const { return asin(val); } }; template <> struct Asin { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(asin(static_cast(val))); } }; // Asin(x) = asin(x) template struct AsinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Asin()); } }; // asin'(x) = 1/sqrt(1-x^2) template struct AsinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Atan { HOSTDEVICE T operator()(const T& val) const { return atan(val); } }; template <> struct Atan { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(atan(static_cast(val))); } }; // Atan(x) = atan(x) template struct AtanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Atan()); } }; // atan'(x) = 1 / (1 + x^2) template struct AtanGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) + x.square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Acosh { HOSTDEVICE T operator()(const T& val) const { return acosh(val); } }; template <> struct Acosh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(acosh(static_cast(val))); } }; // Acosh(x) = acosh(x) template struct AcoshFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Acosh()); } }; // acosh'(x) = 1/sqrt(x^2 - 1) template struct AcoshGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (x * x - static_cast(1)).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Asinh { HOSTDEVICE T operator()(const T& val) const { return asinh(val); } }; template <> struct Asinh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(asinh(static_cast(val))); } }; // Asinh(x) = asinh(x) template struct AsinhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Asinh()); } }; // asinh'(x) = 1/sqrt(x^2 + 1) template struct AsinhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (x.square() + static_cast(1)).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Atanh { HOSTDEVICE T operator()(const T& val) const { return atanh(val); } }; template <> struct Atanh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(atanh(static_cast(val))); } }; // Atanh(x) = atanh(x) template struct AtanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Atanh()); } }; // atanh'(x) = 1/(1 - x^2) template struct AtanhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // relu(x) = max(x, 0) template struct ReluCPUFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) { return v > static_cast(0) ? v : static_cast(0); }); } }; template struct ReluCUDAFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)); } }; template struct ReluGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (out > static_cast(0)).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct ReluGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const DenseTensor* X, const DenseTensor* Out, const DenseTensor* ddX, DenseTensor* ddOut, DenseTensor* dOut, DenseTensor* dX) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad")); auto out = EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad")); if (ddOut) { auto ddout = EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad")); ddout.device(*d) = ddx * (out > static_cast(0)).template cast(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; #if defined(__NVCC__) || defined(__HIPCC__) template struct CudaReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); // relu(x) = max(x, 0) __device__ __forceinline__ T operator()(const T x) const { return x > zero ? x : zero; } }; template struct CudaReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); // dx = dout * (out > 0) __device__ __forceinline__ T operator()(const T dout, const T out) const { return out > zero ? dout : zero; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct CudaCosFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // cos(x) = cos(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(cos(x)); } }; template struct CudaCosGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * (-sin(x)) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(-dout * sin(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaSinFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // sin(x) = sin(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(sin(x)); } }; template struct CudaSinGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * cos(x) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * cos(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaTanFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // tan(x) = tan(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(tan(x)); } }; template struct CudaTanGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout / cos(x)^2 __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout / (cos(x) * cos(x))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAsinFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // asin(x) = asin(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(asin(x)); } }; template struct CudaAsinGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout / sqrt(1 - x^2) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout / sqrt(one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAcosFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // acos(x) = acos(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(acos(x)); } }; template struct CudaAcosGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = -dout / sqrt(1 - x^2) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(-dout / sqrt(one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaCoshFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // cosh(x) = cosh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(cosh(x)); } }; template struct CudaCoshGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * sinh(x) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * sinh(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaSinhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // sinh(x) = sinh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(sinh(x)); } }; template struct CudaSinhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * cosh(x) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * cosh(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAcoshFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // Acosh(x) = acosh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(acosh(x)); } }; template struct CudaAcoshGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1 / sqrt(x^2 - 1) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * one / sqrt(x * x - one)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAsinhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // Asinh(x) = asinh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(asinh(x)); } }; template struct CudaAsinhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1/sqrt(x^2 + 1) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * one / sqrt(x * x + one)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAtanhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // Atanh(x) = atanh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(atanh(x)); } }; template struct CudaAtanhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1/(1- x^2) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * one / (one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAtanFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // atan(x) = atan(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(atan(x)); } }; template struct CudaAtanGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); // dx = dout / (1 + x^2) __device__ __forceinline__ T operator()(const T dout, const T x) const { return dout / (one + x * x); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; #endif } // namespace funcs } // namespace phi