diff --git a/paddle/fluid/operators/kernel_primitives/compute_primitives.h b/paddle/fluid/operators/kernel_primitives/compute_primitives.h
index 1d23cfe007558f843e12ec43803bf4963e43e072..ccd301aa8ca3d44904669e59c970cf61fd1ef032 100644
--- a/paddle/fluid/operators/kernel_primitives/compute_primitives.h
+++ b/paddle/fluid/operators/kernel_primitives/compute_primitives.h
@@ -14,8 +14,140 @@
 
 #pragma once
 
+#ifdef PADDLE_WITH_CUDA
+#include <cuda_fp16.h>
+#endif
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_fp16.h>
+#endif
+
+#include <type_traits>
+#include "paddle/fluid/platform/float16.h"
+
 namespace paddle {
 namespace operators {
-namespace kernel_primitives {}
+namespace kernel_primitives {
+namespace details {
+
+template <typename T>
+class MPTypeTrait {
+ public:
+  using Type = T;
+};
+
+template <>
+class MPTypeTrait<platform::float16> {
+ public:
+  using Type = float;
+};
+
+}  // namespace details
+
+/*************************** Compute Functor ****************************/
+template <typename T, typename Enable = void>
+struct DivFunctor {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    return args[0] / args[1];
+  }
+};
+
+template <typename T>
+struct DivFunctor<T, typename std::enable_if_t<std::is_integral<T>::value>> {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    PADDLE_ENFORCE(args[1] != 0,
+                   platform::errors::InvalidArgument(
+                       "Invalid Argument Error: Integer division by zero "
+                       "encountered in divide. Please check the input value."));
+    return args[0] / args[1];
+  }
+};
+
+/*************************** Compute Function ****************************/
+
+/**
+ * @brief compute function for elementwise_two, in1 and in2 have the same shape
+ * @param:
+ * T : the type of in1 and in2
+ * NX: the row of in1 and in2
+ * NY: the col of in1 and in2
+ * BlockSize: the stride of col
+ * OpFunc: compute functor eg: ADD, SUB, XOR, OR, MUL
+ */
+template <typename T, typename OutT, int NX, int NY, int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
+                                                  const T* in2,
+                                                  OpFunc compute) {
+  T args[2];
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; ++idx) {
+    args[0] = in1[idx];
+    args[1] = in2[idx];
+    out[idx] = static_cast<OutT>(compute(args));
+  }
+}
+
+/**
+ * @brief fma eg: a * b + c, where in1, in2, in3 and out have the same shape
+ * @param:
+ * T : the type of in1, in2 and in3
+ * NX: the row of in1, in2 and in3
+ * NY: the col of in1, in2 and in3
+ * BlockSize: the stride of col
+ */
+template <typename T, typename OutT, int NX, int NY, int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseFma(OutT* out, const T* in1,
+                                               const T* in2, const T* in3,
+                                               OpFunc compute) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; ++idx) {
+    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
+  }
 }
+
+/**
+ * @brief compute function for elementwise_two, in1 is [1, NY], in2 is [NX, NY]
+ * @param:
+ * T : the type of in1 and in2
+ * NX: the row of in1 and in2
+ * NY: the col of in2
+ * BlockSize: the stride of col
+ * OpFunc: compute functor eg: ADD, SUB, XOR, OR, MUL
+ */
+template <typename T, typename OutT, int NX, int NY, int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
+                                            const T* in2, OpFunc compute) {
+#pragma unroll
+  for (int idx = 0; idx < NX; idx++) {
+#pragma unroll
+    for (int idy = 0; idy < NY; idy++) {
+      out[idx + idy * NX] =
+          static_cast<OutT>(compute(in1[idx], in2[idx + idy * NX]));
+    }
+  }
 }
+
+/**
+ * @brief compute function for unary, in is [NX, NY]
+ * @param:
+ * T : the type of in
+ * NX: the row of in
+ * NY: the col of in
+ * BlockSize: the stride of col
+ * OpFunc: compute functor eg: relu, sigmoid, exp
+ */
+template <typename T, typename OutT, int NX, int NY, int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
+                                                 OpFunc compute) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; idx++) {
+    out[idx] = static_cast<OutT>(compute(in + idx));
+  }
+}
+
+}  // namespace kernel_primitives
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
index 1d23cfe007558f843e12ec43803bf4963e43e072..d520c33ca9bccf25483ae229aca2a49c8df18e3a 100644
--- a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
+++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
@@ -13,9 +13,205 @@
 // limitations under the License.
 
 #pragma once
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <numeric>
+#include <vector>
 
 namespace paddle {
 namespace operators {
-namespace kernel_primitives {}
+namespace kernel_primitives {
+namespace details {
+
+#define INT_BITS 32
+
+template <typename T, int VecSize>
+struct alignas(sizeof(T) * VecSize) VectorType {
+  T val[VecSize];
+};
+
+struct FastDivMod {
+  // 1st value represents the result of the input number divided by the
+  // recorded divisor
+  // 2nd value represents the result of the input number modulo the recorded
+  // divisor
+  using DivModT = VectorType<uint32_t, 2>;
+
+  FastDivMod() {}
+  HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
+    static_assert(sizeof(unsigned int) == 4,
+                  "Only Support 32-bit unsigned int.");
+
+    for (shift_val = 0; shift_val < INT_BITS; ++shift_val) {
+      auto shift_limit = 1 << shift_val;
+      if (shift_limit >= divisor) break;
+    }
+    uint64_t long_one = 1;
+    uint64_t temp_div =
+        ((long_one << INT_BITS) * ((long_one << shift_val) - divisor)) /
+            divisor +
+        1;
+    multiplier = temp_div;
+  }
+
+  __device__ __forceinline__ uint32_t Div(uint32_t n) const {
+    uint32_t t = __umulhi(n, multiplier);
+    return (t + n) >> shift_val;
+  }
+
+  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
+    uint32_t q = Div(n);
+    DivModT result = {q, n - q * divisor};
+    return result;
+  }
+
+  int32_t divisor;
+  int32_t shift_val;
+  uint32_t multiplier;
+};
+
+template <int kDims>
+struct BroadcastConfig {
+  FastDivMod divmoders[kDims];
+  uint32_t strides[framework::DDim::kMaxRank];
+  HOSTDEVICE BroadcastConfig() {}
+
+  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
+                             const std::vector<int64_t>& in_dims,
+                             int dim_size) {
+    std::vector<uint32_t> strides_in;
+    std::vector<FastDivMod> divmoders_in;
+    // for divmoders
+    divmoders_in.resize(dim_size);
+    for (int i = 0; i < dim_size; ++i) {
+      divmoders_in[i] = FastDivMod(out_dims[i]);
+    }
+    // for strides
+    strides_in.resize(dim_size, 1);
+    for (int i = 0; i < dim_size; ++i) {
+      strides_in[i] = in_dims[i] == 1 ? 0 : strides_in[i];
+      strides_in[i] =
+          (i != 0 && strides_in[i] != 0)
+              ? std::accumulate(in_dims.begin(), in_dims.begin() + i, 1,
+                                std::multiplies<int64_t>())
+              : strides_in[i];
+    }
+
+    memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t));
+    memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod));
+  }
+};
+
+#undef INT_BITS
+}  // namespace details
+
+template <typename T, int NX, int NY, int BlockSize>
+__device__ __forceinline__ void ReadDataBase(T* dst, const T* __restrict__ src,
+                                             int size) {
+  int dx = threadIdx.x * NX;
+#pragma unroll
+  for (int idx = 0; idx < NX; ++idx) {
+    if ((idx + dx) >= size) {
+      break;
+    }
+    dst[idx] = src[idx + dx];
+  }
+}
+
+template <typename T, int NX, int NY, int BlockSize>
+__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
+                                         int size) {
+  const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+  const int VECTORS_PER_THREAD = NX / VECTOR_SIZE;
+
+  // Vector per thread
+  if (blockDim.x * NX > size) {
+    ReadDataBase<T, NX, NY, BlockSize>(dst, src, size);
+  } else {
+    // Vector type
+    using VecType = details::VectorType<T, VECTOR_SIZE>;
+    VecType vec_temp[VECTORS_PER_THREAD];
+    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
+    ReadDataBase<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
+        vec_temp, vec_input, VECTORS_PER_THREAD * blockDim.x);
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
+    }
+  }
+}
+
+/** @brief: ReadDataBc
+ * read data from src ptr when the shape of src and dst are different
+ * @param:
+ * src: the source pointer
+ * dst: the dst pointer
+ * stride_nx: the stride of src in the NX dimension
+ * stride_ny: the stride of src in the NY dimension
+ * the shape of dst is [NY, NX]
+ */
+template <typename T, int NX, int NY, int BlockSize, int ShapeSize>
+__device__ __forceinline__ void ReadDataBc(
+    T* dst, const T* __restrict__ src, uint32_t fix,
+    details::BroadcastConfig<ShapeSize> config, int num, int stride_nx,
+    int stride_ny) {
+  uint32_t base_offset = fix + threadIdx.x * NX;
+  uint32_t offset = 0;
+
+#pragma unroll
+  for (int ny = 0; ny < NY; ++ny) {
+#pragma unroll
+    for (uint32_t nx = 0; nx < NX; ++nx) {
+      uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
+      if (idx < num) {
+        offset = 0;
+#pragma unroll
+        for (int i = 0; i < ShapeSize; ++i) {
+          auto fast_divmoder = config.divmoders[i].Divmod(idx);
+          idx = fast_divmoder.val[0];
+          offset += fast_divmoder.val[1] * config.strides[i];
+        }
+        dst[nx + ny * NX] = src[offset];
+      }
+    }
+  }
+}
+
+template <typename T, int NX, int NY, int BlockSize>
+__device__ __forceinline__ void WriteDataBase(T* dst, const T* __restrict__ src,
+                                              int size) {
+  int dx = threadIdx.x * NX;
+#pragma unroll
+  for (int idx = 0; idx < NX; ++idx) {
+    if ((idx + dx) >= size) {
+      break;
+    }
+    dst[idx + dx] = src[idx];
+  }
 }
+
+template <typename T, int NX, int NY, int BlockSize>
+__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
+                                          int size) {
+  const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+  const int VECTORS_PER_THREAD = NX / VECTOR_SIZE;
+
+  // Vector per thread
+  if (blockDim.x * NX > size) {
+    WriteDataBase<T, NX, NY, BlockSize>(dst, src, size);
+  } else {
+    // Vector type
+    using VecType = details::VectorType<T, VECTOR_SIZE>;
+    VecType vec_temp[VECTORS_PER_THREAD];
+#pragma unroll
+    for (int idx = 0; idx < VECTORS_PER_THREAD; ++idx) {
+      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
+    }
+    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
+    WriteDataBase<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
+        vec_dst, vec_temp, VECTORS_PER_THREAD * blockDim.x);
+  }
 }
+
+}  // namespace kernel_primitives
+}  // namespace operators
+}  // namespace paddle
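
Usage sketch (illustrative only, not part of this patch): the snippet below shows one way ReadData, ElementwiseBinary and WriteData might be composed into a simple tiled elementwise-add kernel. AddFunctor, VecAddKernel and the launch configuration are assumptions made for this example; only the primitive APIs themselves come from the diff above.

// Minimal sketch, assuming the headers added by this PR are available.
#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"

namespace kp = paddle::operators::kernel_primitives;

// Hypothetical functor following the `OutT operator()(const T* args)`
// convention used by DivFunctor in compute_primitives.h.
template <typename T>
struct AddFunctor {
  inline HOSTDEVICE T operator()(const T* args) const {
    return args[0] + args[1];
  }
};

// Hypothetical kernel: each thread handles NX consecutive elements of its
// block's tile; ReadData/WriteData fall back to the scalar ReadDataBase /
// WriteDataBase path when the tail of the tensor does not fill a whole tile.
template <typename T, int NX, int BlockSize>
__global__ void VecAddKernel(const T* x, const T* y, T* z, int n) {
  int tile_base = blockIdx.x * BlockSize * NX;  // first element of this tile
  int remain = n - tile_base;                   // elements left for this block
  T x_reg[NX], y_reg[NX], z_reg[NX];
  kp::ReadData<T, NX, 1, BlockSize>(x_reg, x + tile_base, remain);
  kp::ReadData<T, NX, 1, BlockSize>(y_reg, y + tile_base, remain);
  kp::ElementwiseBinary<T, T, NX, 1, BlockSize, AddFunctor<T>>(
      z_reg, x_reg, y_reg, AddFunctor<T>());
  kp::WriteData<T, NX, 1, BlockSize>(z + tile_base, z_reg, remain);
}

// Example launch (assumed sizes, blockDim.x must equal BlockSize):
//   int grid = (n + 256 * 4 - 1) / (256 * 4);
//   VecAddKernel<float, 4, 256><<<grid, 256>>>(x, y, z, n);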