diff --git a/paddle/fluid/operators/kernel_primitives/functor_primitives.h b/paddle/fluid/operators/kernel_primitives/functor_primitives.h new file mode 100644 index 0000000000000000000000000000000000000000..fcfcdc28b1f00964cda5d97ae403fe2d3df7948c --- /dev/null +++ b/paddle/fluid/operators/kernel_primitives/functor_primitives.h @@ -0,0 +1,230 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace operators { +namespace kernel_primitives { +namespace details { + +static __device__ __forceinline__ platform::float16 Exp(platform::float16 x) { + return ::Eigen::numext::exp(x); +} + +static __device__ __forceinline__ float Exp(float x) { return expf(x); } + +static __device__ __forceinline__ double Exp(double x) { return exp(x); } + +static __device__ __forceinline__ platform::float16 Log(platform::float16 x) { + return ::Eigen::numext::log(x); +} + +static __device__ __forceinline__ float Log(float x) { return logf(x); } + +static __device__ __forceinline__ double Log(double x) { return log(x); } + +} // namespace details + +/******************************** Unary Functor *******************************/ + +/** + * @brief Default unary exp functor + */ +template +struct ExpFunctor { + HOSTDEVICE inline ExpFunctor() {} + + HOSTDEVICE explicit inline ExpFunctor(int n) {} + + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(details::Exp(x)); + } +}; + +/** + * @brief Default unary identity functor + */ +template +struct IdentityFunctor { + HOSTDEVICE inline IdentityFunctor() {} + + HOSTDEVICE explicit inline IdentityFunctor(int n) {} + + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(x); + } +}; + +/** + * @brief Default unary div functor. Divide by a constant + */ +template +struct DivideFunctor { + HOSTDEVICE inline DivideFunctor() { n_inv = static_cast(1.0f); } + + HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((Tx)(1.0 / n)) {} + + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(x * n_inv); + } + + private: + Tx n_inv; +}; + +/** + * @brief Default unary square functor + */ +template +struct SquareFunctor { + HOSTDEVICE inline SquareFunctor() {} + + HOSTDEVICE explicit inline SquareFunctor(int n) {} + + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(x) * static_cast(x); + } +}; + +/****************************** Binary Functor ********************************/ + +/** + * @brief Default binary min functor + */ +template +struct MinFunctor { + inline T initial() { return static_cast(std::numeric_limits::max()); } + + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return (b < a) ? b : a; + } +}; + +/** + * @brief Default binary max functor + */ +template +struct MaxFunctor { + inline T initial() { + return static_cast(std::numeric_limits::lowest()); + } + + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return (b > a) ? b : a; + } +}; + +/** + * @brief Default binary add functor + */ +template +struct AddFunctor { + inline T initial() { return static_cast(0.0f); } + + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return b + a; + } +}; + +/** + * @brief Default binary add functor + */ +template +struct MulFunctor { + inline T initial() { return static_cast(1.0f); } + + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return b * a; + } +}; + +/** + * @brief Default binary logic or functor + */ +template +struct LogicalOrFunctor { + inline T initial() { return static_cast(false); } + + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return b || a; + } +}; + +/** + * @brief Default binary logic and functor + */ +template +struct LogicalAndFunctor { + inline T initial() { return static_cast(true); } + + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return b && a; + } +}; + +/** + * @brief Default binary sub functor + */ +template +struct SubFunctor { + inline T initial() { return static_cast(0.0f); } + + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; } +}; + +/** + * @brief Default binary div functor + */ +template +struct DivFunctor { + inline T initial() { return static_cast(1.0f); } + + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; } +}; + +template +struct DivFunctor::value>::type> { + inline T initial() { return static_cast(1.0f); } + + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + // For int32/int64, need to check whether the divison is zero. + PADDLE_ENFORCE_NE(b, 0, + platform::errors::InvalidArgument( + "Integer division by zero encountered " + "in (floor) divide. Please check the input value.")); + return a / b; + } +}; + +/** + * @brief Default binary floor divide functor + */ +template +struct FloorDivFunctor { + inline T initial() { return static_cast(1.0f); } + + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + PADDLE_ENFORCE_NE(b, 0, + platform::errors::InvalidArgument( + "Integer division by zero encountered " + "in (floor) divide. Please check the input value.")); + return static_cast(std::trunc(a / b)); + } +}; + +} // namespace kernel_primitives +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/kernel_primitives/kernel_primitives.h b/paddle/fluid/operators/kernel_primitives/kernel_primitives.h index 45ee4fd738174b86e899ebab65b0cac41deed9de..9a4f8bb026b9da6d6d3d04d8ed020d0a98c01548 100644 --- a/paddle/fluid/operators/kernel_primitives/kernel_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/kernel_primitives.h @@ -16,6 +16,7 @@ #include "paddle/fluid/operators/kernel_primitives/compute_primitives.h" #include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h" +#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" #include "paddle/fluid/operators/kernel_primitives/helper_primitives.h" namespace paddle {