diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h
index 3a2de0c4a093514a1c40321ab7dad61011709204..b059223eaf6e7a0907f8344c4ee44087002d005d 100644
--- a/paddle/fluid/operators/fused/attn_bias_add.cu.h
+++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -51,8 +51,7 @@ template <typename T,
     phi::Array<bool, MAX_INPUT_NUM> use_broadcast,
     uint32_t numel,
-    phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM>
-        configlists,
+    phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
     int main_tid,
     int tail_tid,
     Functor func) {
   int fix = blockIdx.x * blockDim.x * VecSize;
   int num = tail_tid;
@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(
   // load in0
   if (use_broadcast[0]) {
-    kernel_primitives::ReadDataBc<T, VecSize, DATA_PER_THREAD, 1, 2>(
+    kernel_primitives::ReadDataBc<T, VecSize, DATA_PER_THREAD, 1>(
         arg0, in0, fix, configlists[0], numel);
   } else {
     kernel_primitives::ReadData<T, VecSize, 1, 1>(arg0, in0 + fix, num);
   }
   // load in1
   if (use_broadcast[1]) {
-    kernel_primitives::ReadDataBc<T, VecSize, DATA_PER_THREAD, 1, 2>(
+    kernel_primitives::ReadDataBc<T, VecSize, DATA_PER_THREAD, 1>(
         arg1, in1, fix, configlists[1], numel);
   } else {
     kernel_primitives::ReadData<T, VecSize, 1, 1>(arg1, in1 + fix, num);
   }
@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
   int main_tid = numel / (data_per_thread * vec_size * threads);
   int tail_tid = numel % (data_per_thread * vec_size * threads);
 
-  phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
+  phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
   phi::Array<bool, MAX_INPUT_NUM> use_broadcast;
 
   use_broadcast[0] = false;
@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
   // Here, dims are transposed due to the logic in BroadcastConfig.
   std::vector<int64_t> input1_dims = {n, 1};
   std::vector<int64_t> out_dims = {n, m};
-  configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
+  configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);
 
   auto func = AddFunctor<T>();
   auto stream = ctx.stream();
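Note on this file: BroadcastConfig loses its compile-time rank parameter, so LaunchBiasAddFwKernel builds one configlists entry whose rank (2 here, for the {n, 1} bias against the {n, m} output) is just a constructor argument. The sketch below replays the stride/divmod index mapping such a config encodes, in plain host C++; the standalone main, the hand-set strides, and integer divmod in place of FastDivMod are illustrative assumptions, not Paddle code.

#include <cstdint>
#include <cstdio>

// Host-side sketch of the broadcast index mapping, using the same
// transposed-dims convention as the hunk above: out_dims = {n, m},
// bias dims = {n, 1}. Not Paddle API; a standalone illustration.
int main() {
  const int m = 2, n = 4;        // output shape [m, n], bias shape [n]
  const int dim_size = 2;        // runtime rank, as passed to the constructor
  uint32_t out_dims[] = {n, m};  // fastest-varying dimension first
  uint32_t strides[] = {1, 0};   // stride 0 marks the broadcast dim of {n, 1}

  const float bias[] = {0.f, 1.f, 2.f, 3.f};
  for (uint32_t out_idx = 0; out_idx < uint32_t(m * n); ++out_idx) {
    uint32_t rest = out_idx, src = 0;
    for (int i = 0; i < dim_size; ++i) {  // one divmod per dimension
      src += (rest % out_dims[i]) * strides[i];
      rest /= out_dims[i];
    }
    std::printf("out[%u] += bias[%u] (%.0f)\n", out_idx, src, double(bias[src]));
  }
}

Running it shows bias indices cycling 0..3 across every row of the output, which is exactly the mapping the kernel's ReadDataBc performs per thread.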
diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h
index d6b9f0935a24cfede25417a8b4fc072125c6426e..2a4c46eb797cca974d90fcc69967819ecfe249ea 100644
--- a/paddle/phi/kernels/funcs/broadcast_function.h
+++ b/paddle/phi/kernels/funcs/broadcast_function.h
@@ -185,19 +185,19 @@ struct DimensionsTransform {
   }
 };
 
-template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+template <typename T, int VecSize, bool IsBoundary = false>
 __device__ __forceinline__ void LoadData(
     T *dst,
     const _ptr_ T *src,
     uint32_t block_offset,
-    const kps::details::BroadcastConfig<Rank> &config,
+    const kps::details::BroadcastConfig &config,
     int numel,
     int num,
     int need_broadcast) {
   // numel : whole num of output
   // num: how many data will be deal with in this time
   if (need_broadcast) {
-    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
+    kps::ReadDataBc<T, VecSize, 1, 1, IsBoundary>(
         dst, src, block_offset, config, numel);
   } else {
     kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
   }
 }
@@ -210,14 +210,13 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
           int NumOuts,
           int VecSize,
-          int Rank,
           bool IsBoundary = false>
 __device__ void VectorizedBroadcastKernelImpl(
     const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
     phi::Array<_ptr_ OutT *, NumOuts> outs,
     const phi::Array<bool, Arity> &use_broadcast,
     uint32_t numel,
-    const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
+    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
     int num,
     int block_offset,
     Functor func) {
@@ -227,13 +226,13 @@ __device__ void VectorizedBroadcastKernelImpl(
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
     kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
-    LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
-                                             ins[i],
-                                             block_offset,
-                                             configs[i],
-                                             numel,
-                                             num,
-                                             use_broadcast[i]);
+    LoadData<InT, VecSize, IsBoundary>(args[i],
+                                       ins[i],
+                                       block_offset,
+                                       configs[i],
+                                       numel,
+                                       num,
+                                       use_broadcast[i]);
   }
   constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
@@ -254,14 +253,13 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
           int NumOuts,
-          int VecSize,
-          int Rank>
+          int VecSize>
 __global__ void VectorizedBroadcastKernel(
     phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
     phi::Array<_ptr_ OutT *, NumOuts> outs,
     phi::Array<bool, Arity> use_broadcast,
     uint32_t numel,
-    phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
+    phi::Array<kps::details::BroadcastConfig, Arity> configs,
     int main_offset,
     int tail_tid,
     Functor func) {
@@ -276,7 +274,6 @@ __global__ void VectorizedBroadcastKernel(
                                   Arity,
                                   NumOuts,
                                   VecSize,
-                                  Rank,
                                   false>(ins,
                                          outs,
                                          use_broadcast,
@@ -294,7 +291,6 @@ __global__ void VectorizedBroadcastKernel(
                                   Arity,
                                   NumOuts,
                                   VecSize,
-                                  Rank,
                                   true>(
         ins, outs, use_broadcast, numel, configs, num, block_offset, func);
   }
@@ -306,7 +302,6 @@ __global__ void VectorizedBroadcastKernel(
                                   Arity,
                                   NumOuts,
                                   VecSize,
-                                  Rank,
                                   false>(ins,
                                          outs,
                                          use_broadcast,
@@ -322,7 +317,6 @@ __global__ void VectorizedBroadcastKernel(
                                   Arity,
                                   NumOuts,
                                   VecSize,
-                                  Rank,
                                   true>(
         ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
@@ -334,15 +328,14 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
           int NumOuts,
-          int VecSize,
-          int Rank>
+          int VecSize>
 void LaunchBroadcastKernel(const KPDevice &ctx,
                            const std::vector<const DenseTensor *> &ins,
                            std::vector<DenseTensor *> *outs,
                            Functor func,
                            DimensionsTransform merge_dims) {
   int numel = (*outs)[0]->numel();
-  phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
+  phi::Array<kps::details::BroadcastConfig, Arity> configs;
   phi::Array<bool, Arity> use_broadcast;
   phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
   phi::Array<_ptr_ OutT *, NumOuts> outs_data;
@@ -358,7 +351,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
       // get the broadcast config,
       // if data shape is[m, n], then you should set data_dim = {n, m}
      // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
-      configs[i] = kps::details::BroadcastConfig<Rank>(
+      configs[i] = kps::details::BroadcastConfig(
           merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
   }
@@ -374,15 +367,14 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
                             Functor,
                             Arity,
                             NumOuts,
-                            VecSize,
-                            Rank><<<blocks, threads, stream>>>(ins_data,
-                                                               outs_data,
-                                                               use_broadcast,
-                                                               numel,
-                                                               configs,
-                                                               main_offset,
-                                                               tail_tid,
-                                                               func);
+                            VecSize><<<blocks, threads, stream>>>(ins_data,
+                                                                  outs_data,
+                                                                  use_broadcast,
+                                                                  numel,
+                                                                  configs,
+                                                                  main_offset,
+                                                                  tail_tid,
+                                                                  func);
 #else
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
@@ -394,58 +386,18 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
                             Functor,
                             Arity,
                             NumOuts,
-                            VecSize,
-                            Rank><<<blocks, threads, 0, stream>>>(ins_data,
-                                                                  outs_data,
-                                                                  use_broadcast,
-                                                                  numel,
-                                                                  configs,
-                                                                  main_offset,
-                                                                  tail_tid,
-                                                                  func);
+                            VecSize><<<blocks, threads, 0, stream>>>(
+      ins_data,
+      outs_data,
+      use_broadcast,
+      numel,
+      configs,
+      main_offset,
+      tail_tid,
+      func);
 #endif
 }
 
-template <typename InT,
-          typename OutT,
-          typename Functor,
-          int Arity,
-          int NumOuts,
-          int VecSize>
-void BroadcastKernelForDifferentDimSize(
-    const KPDevice &ctx,
-    const std::vector<const DenseTensor *> &ins,
-    std::vector<DenseTensor *> *outs,
-    int axis,
-    Functor func) {
-  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
-
-#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                                     \
-  case rank: {                                                                \
-    LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
-        ctx, ins, outs, func, merge_dims);                                    \
-  } break;
-
-  switch (merge_dims.dim_size) {
-    CALL_BROADCAST_FOR_DIM_SIZE(1);
-    CALL_BROADCAST_FOR_DIM_SIZE(2);
-    CALL_BROADCAST_FOR_DIM_SIZE(3);
-    CALL_BROADCAST_FOR_DIM_SIZE(4);
-    CALL_BROADCAST_FOR_DIM_SIZE(5);
-    CALL_BROADCAST_FOR_DIM_SIZE(6);
-    CALL_BROADCAST_FOR_DIM_SIZE(7);
-    CALL_BROADCAST_FOR_DIM_SIZE(8);
-    default: {
-      PADDLE_THROW(phi::errors::InvalidArgument(
-          "The maximum dimension of input tensor is expected to be less than "
-          "%d, but recieved %d.",
-          merge_dims.dim_size,
-          phi::DDim::kMaxRank));
-    }
-  }
-#undef CALL_BROADCAST_FOR_DIM_SIZE
-}
-
 template <typename InT,
           typename OutT,
           typename Functor,
@@ -470,30 +421,19 @@ void BroadcastKernelForDifferentVecSize(
 
+  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
   switch (vec_size) {
     case 4: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         Arity,
-                                         NumOuts,
-                                         4>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, 4>(
+          ctx, ins, outs, func, merge_dims);
       break;
     }
     case 2: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         Arity,
-                                         NumOuts,
-                                         2>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, 2>(
+          ctx, ins, outs, func, merge_dims);
      break;
     }
    case 1: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         Arity,
-                                         NumOuts,
-                                         1>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, 1>(
+          ctx, ins, outs, func, merge_dims);
       break;
     }
     default: {
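Note on this file: the dispatch math is the visible win. BroadcastKernelForDifferentVecSize used to fan out over vec_size (4/2/1) and then BroadcastKernelForDifferentDimSize fanned out again over ranks 1..8, so every functor could instantiate up to 24 LaunchBroadcastKernel variants; after the change the rank lives in the config object and only the three vec_size instantiations remain. A minimal standalone sketch of that before/after shape (the names here are hypothetical, not the Paddle templates):

#include <cstdio>

struct Config {
  int kDims;  // runtime rank, mirroring the new BroadcastConfig member
};

template <int VecSize>
void Launch(const Config& cfg) {
  // One instantiation per VecSize; the rank is read from cfg at run time.
  std::printf("VecSize=%d handles rank %d\n", VecSize, cfg.kDims);
}

void Dispatch(int vec_size, const Config& cfg) {
  switch (vec_size) {  // the only remaining compile-time fan-out
    case 4: Launch<4>(cfg); break;
    case 2: Launch<2>(cfg); break;
    default: Launch<1>(cfg); break;
  }
}

int main() {
  for (int rank = 1; rank <= 8; ++rank) {
    Dispatch(4, Config{rank});  // previously each rank was its own kernel
  }
}

Besides binary size and compile time, this also removes the hard rank-8 ceiling that the old switch enforced with PADDLE_THROW.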
diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h
index 993349f2d9e14112e2acbd0bc098f3a00351232b..8b0c42c9d19b1294254ddd467dd412e420c8bb58 100644
--- a/paddle/phi/kernels/primitive/datamover_primitives.h
+++ b/paddle/phi/kernels/primitive/datamover_primitives.h
@@ -82,10 +82,10 @@ struct FastDivMod {
  * index of the output data. if input or output shape is [dim0, dim1] then dims
  * must be [dim1, dim0].
  */
-template <int kDims>
 struct BroadcastConfig {
-  FastDivMod divmoders[kDims];
+  FastDivMod divmoders[phi::DDim::kMaxRank];
   uint32_t strides[phi::DDim::kMaxRank];
+  int kDims;
   HOSTDEVICE BroadcastConfig() {}
 
   HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
@@ -109,7 +109,7 @@ struct BroadcastConfig {
                               std::multiplies<int64_t>())
                         : strides_in[i];
     }
-
+    kDims = dim_size;
     memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t));
     memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod));
   }
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
   }
 }
 
+template <typename T, int NX>
+__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
+#pragma unroll
+  for (int i = 0; i < NX; i++) {
+    dst[i] = init_data;
+  }
+}
+
 /**
  * The difference from the above function is that
 * it supports different data types of inputs.
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
   }
 }
 
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void ReadData(T* dst,
+                                         const T* __restrict__ src,
+                                         int num,
+                                         int read_lens) {
+  if (IsBoundary) {  // blockDim.x * NX > num
+    int thread_offset = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if (idx + thread_offset < num) {
+        dst[idx] = src[thread_offset + idx];
+      }
+    }
+  } else {  // blockDim.x * NX < num
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
+
+    using VecType = details::VectorType<T, kVectorSize>;
+    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
+    VecType vec_temp[kVectorsPerThread];
+
+#pragma unroll
+    for (int i = 0; i < kVectorsPerThread; ++i) {
+      vec_temp[i] = vec_input[thread_offset + i];
+#pragma unroll
+      for (int idx = 0; idx < NX; ++idx) {
+        dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
+      }
+    }
+  }
+}
+
 /**
 * @brief Read 1D data from global memory to register. The difference
 * from the above function is that it supports different data types of inputs.
@@ -396,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 */
-template <typename T,
-          int NX,
-          int NY,
-          int BlockSize,
-          int Rank,
-          bool IsBoundary = false>
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataBc(
     T* dst,
     const T* __restrict__ src,
     uint32_t block_offset,
-    details::BroadcastConfig<Rank> config,
+    const details::BroadcastConfig& config,
     int total_num_output,
     int stride_nx,
     int stride_ny) {
@@ -425,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc(
       }
     }
 #pragma unroll
-      for (int i = 0; i < Rank; ++i) {
+      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
+        if (i >= config.kDims) break;
        auto fast_divmoder = config.divmoders[i].Divmod(index_output);
        index_output = fast_divmoder.val[0];
        index_src += fast_divmoder.val[1] * config.strides[i];
@@ -576,6 +612,36 @@ __device__ __forceinline__ void WriteData(T* dst,
   }
 }
 
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void WriteData(T* dst,
+                                          T* __restrict__ src,
+                                          int num,
+                                          int read_lens) {
+  if (IsBoundary) {
+    int thread_offset = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if ((thread_offset + idx) < num) {
+        dst[thread_offset + idx] = src[idx];
+      }
+    }
+  } else {
+    // Vector type
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
+
+    int thread_offset = threadIdx.x * kVectorsPerThread;
+    using VecType = details::VectorType<T, kVectorSize>;
+    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
+    VecType vec_temp[kVectorsPerThread];
+#pragma unroll
+    for (int idx = 0; idx < kVectorsPerThread; ++idx) {
+      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
+      vec_dst[thread_offset + idx] = vec_temp[idx];
+    }
+  }
+}
+
 /**
 * @brief Write 2D data from register to global memory according to Tx type, and
 * store it as Ty type.
@@ -715,18 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 */
-template <typename T,
-          int NX,
-          int NY,
-          int BlockSize,
-          int Rank,
-          bool IsBoundary = false>
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataBc(
     T* dst,
     const T* __restrict__ src,
     uint32_t block_offset,
-    details::BroadcastConfig<Rank> config,
-    int total_num_output) {
+    const details::BroadcastConfig& config,
+    int total_num_output,
+    int read_lens = NX) {
   uint32_t thread_offset = block_offset + threadIdx.x * NX;
   uint32_t index_src = 0;
@@ -740,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc(
     }
   }
 #pragma unroll
-  for (int i = 0; i < Rank; ++i) {
+  for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
+    if (i >= config.kDims) break;
     auto fast_divmoder = config.divmoders[i].Divmod(index_output);
     index_output = fast_divmoder.val[0];
     index_src += fast_divmoder.val[1] * config.strides[i];
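Note on this file: the replacement for the compile-time Rank loop is the pattern in both ReadDataBc hunks, keep `#pragma unroll` legal by bounding the loop with the compile-time phi::DDim::kMaxRank, then break once the runtime kDims stored in the config is reached. A host-side sketch of that loop, with plain integer divmod standing in for FastDivMod; kMaxRank and the Cfg struct here are stand-ins for illustration, not the real types:

#include <cstdint>
#include <cstdio>

constexpr int kMaxRank = 9;  // stand-in for phi::DDim::kMaxRank

struct Cfg {
  uint32_t out_dims[kMaxRank];
  uint32_t strides[kMaxRank];
  int kDims;  // runtime rank, filled in by the constructor from dim_size
};

uint32_t SrcIndex(uint32_t index_output, const Cfg& cfg) {
  uint32_t index_src = 0;
#pragma unroll  // CUDA-style pragma: trip count is compile-time, so unrolling stays legal
  for (int i = 0; i < kMaxRank; ++i) {
    if (i >= cfg.kDims) break;  // the runtime rank cuts the loop short
    index_src += (index_output % cfg.out_dims[i]) * cfg.strides[i];
    index_output /= cfg.out_dims[i];
  }
  return index_src;
}

int main() {
  Cfg cfg{{4, 2}, {1, 0}, 2};  // a [4] bias broadcast over a [2, 4] output
  std::printf("output 6 reads input %u\n", SrcIndex(6, cfg));  // prints 2
}

The unrolled body still exists for all kMaxRank iterations, but the early break keeps the arithmetic identical to the old per-rank instantiations while letting one compiled kernel serve every rank.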
diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
index d2cfdbdec3064c8e9cf20d101afc2adf0ed011a8..d756b1fff18e1f0f090b0a2c8b749c65de563dba 100644
--- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
@@ -32,11 +32,11 @@ struct alignas(sizeof(T) * VecSize) VectorType {
 * must be [dim1, dim0].
 */
 #pragma pack(4)
-template <int kDims>
 struct BroadcastConfig {
   int strides_in[phi::DDim::kMaxRank];
   int strides_out[phi::DDim::kMaxRank];
   int in_dim[phi::DDim::kMaxRank];
+  int kDims;
 
   HOSTDEVICE BroadcastConfig() {}
 
@@ -58,6 +59,7 @@ struct BroadcastConfig {
       dim_tmp[i] = in_dims[i];
     }
 
+    kDims = dim_size;
     memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
     memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
     memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
@@ -328,16 +329,11 @@ __device__ __inline__ void ReadData(ArgsT* dst,
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 */
-template <typename T,
-          int NX,
-          int NY,
-          int BlockSize,
-          int Rank,
-          bool IsBoundary = false>
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __inline__ void ReadDataBc(T* dst,
                                       const T _global_ptr_* src,
                                       uint32_t block_offset,
-                                      details::BroadcastConfig<Rank> config,
+                                      details::BroadcastConfig config,
                                       int total_num_output,
                                       int stride_nx,
                                       int stride_ny) {
@@ -643,18 +639,12 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) {
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 */
-template <typename T,
-          int NX,
-          int NY,
-          int BlockSize,
-          int Rank,
-          bool IsBoundary = false>
-__device__ __inline__ void ReadDataBc(
-    T* dst,
-    const T _global_ptr_* src,
-    uint32_t block_offset,
-    const details::BroadcastConfig<Rank>& config,
-    int total_num_output) {
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __inline__ void ReadDataBc(T* dst,
+                                      const T _global_ptr_* src,
+                                      uint32_t block_offset,
+                                      const details::BroadcastConfig& config,
+                                      int total_num_output) {
   int thread_offset = block_offset + core_id() * NX;
   int index_src = 0;
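Note on this file: the XPU2 header gets the same surgery, BroadcastConfig<kDims> becomes a plain struct carrying kDims, and both ReadDataBc overloads drop Rank. The CUDA variant above additionally grew an `int read_lens = NX` default argument, which presumably leaves room for the XPU path to pass a wider per-core read length while keeping call sites uniform. A compilable stub of the resulting call-site contract; the function body and all names are assumptions for illustration, not the primitives' real implementation:

#include <cstdint>
#include <cstdio>

struct BroadcastConfig {
  uint32_t out_dims[9];
  uint32_t strides[9];
  int kDims = 0;  // runtime rank replaces the old <int kDims> parameter
};

// Stub with the post-diff shape: one config type for any rank, and a
// read_lens parameter defaulted to NX, as in the CUDA header.
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
void ReadDataBc(T* dst, const T* src, uint32_t block_offset,
                const BroadcastConfig& config, int total_num_output,
                int read_lens = NX) {
  for (int k = 0; k < read_lens; ++k) {
    uint32_t out_idx = block_offset + k, src_idx = 0;
    if (IsBoundary && out_idx >= uint32_t(total_num_output)) break;
    for (int i = 0; i < config.kDims; ++i) {  // runtime-rank divmod mapping
      src_idx += (out_idx % config.out_dims[i]) * config.strides[i];
      out_idx /= config.out_dims[i];
    }
    dst[k] = src[src_idx];
  }
}

int main() {
  BroadcastConfig cfg{{4, 2}, {1, 0}, 2};  // bias [4] against output [2, 4]
  float bias[4] = {0.f, 1.f, 2.f, 3.f}, reg[4];
  ReadDataBc<float, 4, 1, 1>(reg, bias, /*block_offset=*/4, cfg, /*numel=*/8);
  std::printf("%g %g %g %g\n", reg[0], reg[1], reg[2], reg[3]);  // 0 1 2 3
}

Because rank no longer appears in any signature, the CUDA, XPU2, and dispatch layers all compile one body per vector width, and the same BroadcastConfig object can be handed to whichever backend is active.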