/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

// Tuning constants; their uses are not visible in this section.
// NOTE: "THREAHOLD" is a misspelling of "THRESHOLD". The name is kept because
// code elsewhere may reference it.
#define MATRIX_SOFTMAX_ALIGN_BYTES 16
#define MATRIX_SOFTMAX_THREAHOLD 100000

// Expands to a single `case` label of a switch. Inside the case body,
// `kBlockDim` is a compile-time constant equal to `dim`, and the variadic
// arguments are emitted as statements.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

// Same as FIXED_BLOCK_DIM_BASE, but binds the compile-time constant `VecSize`.
#define FIXED_VEC_SIZE_BASE(vec_size, ...) \
  case (vec_size): {                       \
    constexpr auto VecSize = (vec_size);   \
    __VA_ARGS__;                           \
  } break

// Case labels for block sizes 512, 256, 128, 64 and 32. The macro must be
// expanded inside a `switch`. Only these cases are emitted; no `default:` label
// is generated.
#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

// Case labels for vector widths 8 and 4 (same usage as FIXED_BLOCK_DIM).
#define FIXED_VEC_SIZE(...)              \
  FIXED_VEC_SIZE_BASE(8, ##__VA_ARGS__); \
  FIXED_VEC_SIZE_BASE(4, ##__VA_ARGS__)

namespace phi {

using ScopedTensorDescriptor = paddle::platform::ScopedTensorDescriptor;
using GPUDNNDataLayout = paddle::platform::DataLayout;

// Vectorization trait 4 * sizeof(T). `Type` is a built-in vector type as wide
// as four elements of T:
//   long4 for double (assuming a 64-bit `long`),
//   int4 for float,
//   int2 for the 16-bit float16/bfloat16.
// The primary template is empty, so only the specialized element types
// provide `Type`.
template <typename T>
class VecT4 {};
template <>
class VecT4<double> {
 public:
  using Type = long4;
};
template <>
class VecT4<float> {
 public:
  using Type = int4;
};
template <>
class VecT4<phi::dtype::float16> {
 public:
  using Type = int2;
};
template <>
class VecT4<phi::dtype::bfloat16> {
 public:
  using Type = int2;
};

// Vectorization trait 2 * sizeof(T). `Type` is a built-in vector type as wide
// as two elements of T:
//   int4 for double,
//   int2 for float,
//   int for float16/bfloat16.
template <typename T>
class VecT2 {};
template <>
class VecT2<double> {
 public:
  using Type = int4;
};
template <>
class VecT2<float> {
 public:
  using Type = int2;
};
template <>
class VecT2<phi::dtype::float16> {
 public:
  using Type = int;
};
template <>
class VecT2<phi::dtype::bfloat16> {
 public:
  using Type = int;
};

// Returns the smallest k >= 0 such that 2^k >= value (0 for value <= 1).
static inline int Log2Ceil(int value) {
  int log2_value = 0;
  // Probe with a 64-bit one. `1 << 31` on a signed int is undefined, and
  // value can exceed 2^30.
  while ((static_cast<int64_t>(1) << log2_value) < value) ++log2_value;
  return log2_value;
}

// Picks a power-of-two thread-block size for a row of `dim_size` elements
// when each thread handles `vec_size` elements:
//   1. Start from min(dim_size / vec_size, 1024).
//   2. Halve it when vec_size > 1.
//   3. Round up to a power of two.
//   4. Clamp from below to 32.
// The result is therefore a multiple of 32 in [32, 1024].
// `vec_size` must be positive, since it is used as a divisor.
inline int getBlockSize(int vec_size, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size =
      std::min(dim_size / vec_size, static_cast<uint64_t>(1024));

  if (vec_size > 1) {
    max_block_size /= 2;
  }

  while (block_size < (max_block_size)) block_size *= 2;
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return block_size;
}

// XOR-butterfly all-reduce (sum) for each of the `BatchSize` values in `sum`.
// It reduces within aligned groups of `WarpSize` lanes; afterwards every lane
// of a group holds that group's totals.
// The shuffle uses the full 0xFFFFFFFF mask, so all 32 lanes of the warp must
// reach this call.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceSum(T* sum) {
#pragma unroll
  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < BatchSize; ++i) {
      T sum_val =
          paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
      sum[i] = sum[i] + sum_val;
    }
  }
}

// Same as WarpReduceSum, but combines values with `max`. The parameter is
// named `sum` for historical reasons; it holds the running maxima.
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceMax(T* sum) {
#pragma unroll
  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < BatchSize; ++i) {
      T max_val =
          paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
      sum[i] = max(sum[i], max_val);
    }
  }
}

// Block-wide max of one value per thread; on return, every thread's `*val`
// holds the block maximum.
// Preconditions:
//   - Every thread of the block must call it, because it contains
//     __syncthreads().
//   - blockDim.x must be a multiple of 32, because the warp step uses a full
//     shuffle mask.
//   - blockDim.x must be <= 1024, because there is one shared slot per warp.
//   - T is expected to be a floating-point type.
template <typename T>
__inline__ __device__ void BlockReduceMax(T* val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  WarpReduceMax<T, 1, 32>(val);

  // Make sure every warp has finished reading `shared` from a previous call
  // before it is overwritten (same guard as BlockReduceSum).
  __syncthreads();
  if (lane == 0) shared[wid] = *val;

  __syncthreads();

  int block_span = (blockDim.x + warpSize - 1) >> 5;
  // Lanes without a corresponding warp contribute the identity of max, so
  // the result is correct for any input range and any block size.
  *val = (lane < block_span) ? shared[lane]
                             : -std::numeric_limits<T>::infinity();
  WarpReduceMax<T, 1, 32>(val);
}

// Block-wide sum of one value per thread; on return, every thread's `*val`
// holds the block total. Same preconditions as BlockReduceMax.
template <typename T>
__inline__ __device__ void BlockReduceSum(T* val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  WarpReduceSum<T, 1, 32>(val);

  // Guard against overwriting `shared` while another warp is still reading
  // the previous call's partials.
  __syncthreads();
  if (lane == 0) shared[wid] = *val;

  __syncthreads();

  int block_span = (blockDim.x + warpSize - 1) >> 5;
  *val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
  WarpReduceSum<T, 1, 32>(val);
}

// Binary max reducer whose identity element is -infinity.
template <typename Tx, typename Ty = Tx>
struct ReduceMaxFunctor {
  inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }

  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
    return max(a, b);
  }
};

// Fold step: max of an accumulator (AccT) and an element (T), computed in AccT.
template <typename T, typename AccT>
struct MaxFunctor {
  __device__ __forceinline__ AccT operator()(const AccT& max_v,
                                             const T& v) const {
    return max(max_v, static_cast<AccT>(v));
  }
};

// Unary exp(x), converted to Ty.
template <typename Tx, typename Ty = Tx>
struct ExpFunctor {
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::exp(x));
  }
};

// Unary exp(x) * y. The factor y is fixed at construction and defaults to 1.
template <typename Tx, typename Ty = Tx>
struct ExpMulFunctor {
  HOSTDEVICE inline ExpMulFunctor() { y = static_cast<Tx>(1.0f); }

  HOSTDEVICE explicit inline ExpMulFunctor(Tx y) : y((Tx)(y)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::exp(x) * y);
  }

 private:
  Tx y;
};

// Unary x - y. The offset y is fixed at construction and defaults to 0.
template <typename Tx, typename Ty = Tx>
struct UnarySubFunctor {
  HOSTDEVICE inline UnarySubFunctor() { y = static_cast<Tx>(0.0f); }

  HOSTDEVICE explicit inline UnarySubFunctor(Tx y) : y((Tx)(y)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x - y);
  }

 private:
  Tx y;
};

// Unary log(x). The int constructor ignores its argument; presumably it lets
// this functor be built the same way as size-parameterized functors (verify).
template <typename Tx, typename Ty = Tx>
struct UnaryLogFunctor {
  HOSTDEVICE inline UnaryLogFunctor() {}

  HOSTDEVICE explicit inline UnaryLogFunctor(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::log(x));
  }
};

// Converts Tx to Ty, mapping -infinity explicitly to Ty's -infinity.
// Presumably this keeps -inf exact across the conversion (verify). The int
// constructor ignores its argument.
template <typename Tx, typename Ty = Tx>
struct DataTransFunctor {
  HOSTDEVICE inline DataTransFunctor() {}

  HOSTDEVICE explicit inline DataTransFunctor(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return x == -std::numeric_limits<Tx>::infinity()
               ? -std::numeric_limits<Ty>::infinity()
               : static_cast<Ty>(x);
  }
};

// Unary x / n, computed as x * (1 / n). The reciprocal is taken once at
// construction; the default factor is 1.
template <typename Tx, typename Ty = Tx>
struct UnaryDivFunctor {
  HOSTDEVICE inline UnaryDivFunctor() { n_inv = static_cast<Tx>(1.0f); }

  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n) : n_inv((Tx)(1.0 / n)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x * n_inv);
  }

 private:
  Tx n_inv;
};

// Softmax output element exp(x - max) / sum. Presumably `max` is the row
// maximum and `sum` the row's sum of exp(x - max) (verify at call sites).
template <typename Tx, typename Ty = Tx>
struct SoftmaxForwardFunctor {
  HOSTDEVICE inline SoftmaxForwardFunctor(Tx max, Tx sum)
      : max(max), sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::exp(x - max) / sum);
  }

 private:
  Tx max;
  Tx sum;
};

// Softmax gradient element out * (grad_out - sum). Presumably `sum` is the
// row's dot product of grad_out and out (verify at call sites).
template <typename Tx, typename Ty = Tx>
struct SoftmaxBackwardFunctor {
  HOSTDEVICE inline SoftmaxBackwardFunctor(Tx sum) : sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    return static_cast<Ty>(out * (grad_out - sum));
  }

 private:
  Tx sum;
};

// Log-softmax output element x - max - log(sum). log(sum) is computed once
// in the constructor.
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxForwardFunctor {
  HOSTDEVICE inline LogSoftmaxForwardFunctor(Tx max, Tx sum)
      : max(max), log_sum(std::log(sum)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x - max - log_sum);
  }

 private:
  Tx max;
  Tx log_sum;
};

// Log-softmax gradient element grad_out - exp(out) * sum. Presumably `sum` is
// the row's sum of grad_out (verify at call sites).
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxBackwardFunctor {
  HOSTDEVICE inline LogSoftmaxBackwardFunctor(Tx sum) : sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    return static_cast<Ty>(grad_out - std::exp(out) * sum);
  }

 private:
  Tx sum;
};

// Fold step: sum + exp(v - max_v), computed in AccT. max_v is fixed at
// construction.
template <typename T, typename AccT>
struct SumExpFunctor {
  HOSTDEVICE inline SumExpFunctor(AccT v) : max_v(v) {}

  HOSTDEVICE inline AccT operator()(AccT sum, T v) const {
    return sum + std::exp(static_cast<AccT>(v) - max_v);
  }

 private:
  AccT max_v;
};

// NOTE(review): the definition that continues past this point may also have
// lost its template parameter list; verify and restore it there.
template