From c7855125819c65efd4d06589e8855211563d4dc4 Mon Sep 17 00:00:00 2001
From: shixingbo <90814748+bmb0537@users.noreply.github.com>
Date: Tue, 10 May 2022 10:00:29 +0800
Subject: [PATCH] broadcast_add kp performance optimization (#42097)

---
 paddle/phi/kernels/funcs/broadcast_function.h | 138 +++-
 paddle/phi/kernels/funcs/elementwise_base.h   |  56 +-
 .../kernels/primitive/compute_primitives.h    |  14 +
 .../primitive/compute_primitives_xpu2.h       |  14 +
 .../kernels/primitive/datamover_primitives.h  | 104 +++
 .../primitive/datamover_primitives_xpu2.h     | 596 +++++++++++++++++-
 .../phi/kernels/primitive/kernel_primitives.h |   1 +
 7 files changed, 880 insertions(+), 43 deletions(-)

diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h
index aafa40a3d0..38cd41d3b6 100644
--- a/paddle/phi/kernels/funcs/broadcast_function.h
+++ b/paddle/phi/kernels/funcs/broadcast_function.h
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
   }
 }
 
+template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+__device__ __forceinline__ void LoadData(
+    T *dst,
+    const _ptr_ T *src,
+    uint32_t block_offset,
+    const kps::details::BroadcastConfig<Rank> &config,
+    int numel,
+    int num,
+    int need_broadcast,
+    int read_lens) {
+  // numel: the total number of output elements
+  // num: how many elements are dealt with in this pass
+  if (need_broadcast) {
+    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
+        dst, src, block_offset, config, numel, read_lens);
+  } else {
+    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
+        dst, src + block_offset, num, read_lens);
+  }
+}
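+// Example: for out = x + y with x.shape = {3, 100} and y.shape = {100}
+// (numel = 300), x is read contiguously (need_broadcast is false since
+// x.numel() == numel), while y goes through ReadDataBc, which maps each
+// output index back to one of y's 100 elements.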
+
 template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
           int NumOuts,
           int VecSize,
           int Rank,
           bool IsBoundary = false>
 __device__ void VectorizedBroadcastKernelImpl(
     const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
     phi::Array<_ptr_ OutT *, NumOuts> outs,
     const phi::Array<int, Arity> &use_broadcast,
     uint32_t numel,
     const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
     int num,
     int block_offset,
+    int read_lens,
     Functor func) {
-  InT args[Arity][VecSize];
-  ConditionalT<OutT, NumOuts> result[VecSize];
+  __simd__ InT args[Arity][VecSize];
+  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
 
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
-    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
     LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
                                              ins[i],
                                              block_offset,
                                              configs[i],
                                              numel,
                                              num,
-                                             use_broadcast[i]);
+                                             use_broadcast[i],
+                                             read_lens);
   }
   constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
                                          Functor,
                                          Arity,
                                          kCallElementwiseAny>()(
-      func, args, result);
-
-  phi::funcs::ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
-      outs, result, block_offset, num);
+      func, args, result, read_lens);
+  phi::funcs::
+      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
+          outs, result, block_offset, num, read_lens);
 }
 
 template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
           int NumOuts,
           int VecSize,
           int Rank>
 __global__ void VectorizedBroadcastKernel(
     phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
     phi::Array<_ptr_ OutT *, NumOuts> outs,
     phi::Array<int, Arity> use_broadcast,
     uint32_t numel,
     phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
     int main_offset,
     int tail_tid,
+    int read_lens,
     Functor func) {
-  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
-  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
 
 #ifdef PADDLE_WITH_XPU_KP
   for (; block_offset < main_offset; block_offset += stride) {
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
                                   use_broadcast,
                                   numel,
                                   configs,
-                                  BLOCK_NUM_X * VecSize,
+                                  BLOCK_NUM_X * read_lens,
                                   block_offset,
+                                  read_lens,
                                   func);
   }
   int num = numel - block_offset;
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
                                   NumOuts,
                                   VecSize,
                                   Rank,
-                                  true>(
-        ins, outs, use_broadcast, numel, configs, num, block_offset, func);
+                                  true>(ins,
+                                        outs,
+                                        use_broadcast,
+                                        numel,
+                                        configs,
+                                        num,
+                                        block_offset,
+                                        read_lens,
+                                        func);
   }
 #else
   if (block_offset < main_offset) {
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
                                    configs,
                                    BLOCK_NUM_X * VecSize,
                                    block_offset,
+                                   read_lens,
                                    func);
   } else {
     VectorizedBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
                                   NumOuts,
                                   VecSize,
                                   Rank,
-                                  true>(
-        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
+                                  true>(ins,
+                                        outs,
+                                        use_broadcast,
+                                        numel,
+                                        configs,
+                                        tail_tid,
+                                        block_offset,
+                                        read_lens,
+                                        func);
   }
 #endif
 }
 
@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
     ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
+#ifdef PADDLE_WITH_XPU_KP
+    if (i == 0) {
+      configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
+                                                       merge_dims.in_dims[0],
+                                                       merge_dims.in_dims[1],
+                                                       merge_dims.dim_size);
+    } else if (i == 1) {
+      configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
+                                                       merge_dims.in_dims[1],
+                                                       merge_dims.in_dims[0],
+                                                       merge_dims.dim_size);
+    }
+#else
     if (use_broadcast[i]) {
       // get the broadcast config,
      // if data shape is [m, n], then you should set data_dim = {n, m}
@@ -399,28 +452,50 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
       configs[i] = kps::details::BroadcastConfig<Rank>(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
+#endif
   }
 
 #ifdef PADDLE_WITH_XPU_KP
   const int threads = 64;
   const int blocks = 8;
-  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
-  int tail_tid = numel % (VecSize * threads);
+  int read_lens = configs[0].buf_len;
+  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
+  int tail_tid = numel % (read_lens * threads);
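+  // e.g. numel = 2500, threads = 64, read_lens = 16: each pass covers
+  // 16 * 64 = 1024 elements, so main_offset = 2048 and the remaining
+  // tail_tid = 452 elements take the boundary (IsBoundary = true) path.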
   auto stream = ctx.x_context()->xpu_stream;
-  VectorizedBroadcastKernel<InT,
-                            OutT,
-                            Functor,
-                            Arity,
-                            NumOuts,
-                            VecSize,
-                            Rank><<<blocks, threads, 0, stream>>>(ins_data,
-                                                                  outs_data,
-                                                                  use_broadcast,
-                                                                  numel,
-                                                                  configs,
-                                                                  main_offset,
-                                                                  tail_tid,
-                                                                  func);
+  if (configs[0].cmp_type != kps::details::OptType::CanNotOptimize) {
+    main_offset = numel;
+    VectorizedBroadcastKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              NumOuts,
+                              VecSize,
+                              Rank><<<blocks, threads, 0, stream>>>(ins_data,
+                                                                    outs_data,
+                                                                    use_broadcast,
+                                                                    numel,
+                                                                    configs,
+                                                                    main_offset,
+                                                                    tail_tid,
+                                                                    read_lens,
+                                                                    func);
+  } else {
+    VectorizedBroadcastKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              NumOuts,
+                              VecSize,
+                              Rank><<<blocks, threads, 0, stream>>>(ins_data,
+                                                                    outs_data,
+                                                                    use_broadcast,
+                                                                    numel,
+                                                                    configs,
+                                                                    main_offset,
+                                                                    tail_tid,
+                                                                    read_lens,
+                                                                    func);
+  }
 #else
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
                                                  configs,
                                                  main_offset,
                                                  tail_tid,
+                                                 VecSize,
                                                  func);
 #endif
 }

diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h
index 332ec0b031..4ee46facc7 100644
--- a/paddle/phi/kernels/funcs/elementwise_base.h
+++ b/paddle/phi/kernels/funcs/elementwise_base.h
@@ -577,14 +577,16 @@ template <typename InT,
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
         result, args, func);
   }
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func);
   }
 };
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
         result, args[0], func);
   }
@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
-        result, args[0], args[1], func);
+        result, args[0], args[1], func, read_lens);
   }
 };
 
@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
         result, args[0], args[1], args[2], func);
   }
 };
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller {
   }
 };
 
+template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
+struct ElementwiseWriteDataCallerBc {
+  __device__ __forceinline__ void operator()(
+      phi::Array<_ptr_ OutT *, NumOuts> outs,
+      ConditionalT<OutT, NumOuts> src[VecSize],
+      int block_offset,
+      int num,
+      int read_lens) {
+    OutT dst[NumOuts][VecSize];
+#pragma unroll
+    for (int i = 0; i < read_lens; ++i) {
+#pragma unroll
+      for (int j = 0; j < NumOuts; ++j) {
+        dst[j][i] = (src[i])[j];
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < NumOuts; ++i) {
+      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+          outs[i] + block_offset, dst[i], num, read_lens);
+    }
+  }
+};
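+// e.g. with NumOuts = 2, src[i] holds the i-th (out0, out1) pair produced by
+// the functor; the loops above regroup those pairs into one contiguous
+// register row per output before each row is written back.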
+
+template <typename OutT, int VecSize, bool IsBoundary>
+struct ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, 1> {
+  __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs,
+                                             OutT src[VecSize],
+                                             int block_offset,
+                                             int num,
+                                             int read_lens) {
+    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+        outs[0] + block_offset, src, num, read_lens);
+  }
+};
+
diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h
--- a/paddle/phi/kernels/primitive/compute_primitives.h
+++ b/paddle/phi/kernels/primitive/compute_primitives.h
+template <typename InT,
+          typename OutT,
+          int NX,
+          int NY,
+          int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseBinary(
+    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; ++idx) {
+    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
+  }
+}
+
 /**
  * @brief Ternary calculation according to OpFunc. Shape of input and output
  * are the same.

diff --git a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
index 0e77b11988..eb45def836 100644
--- a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+++ b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
@@ -17,6 +17,7 @@
 #include "xpu/kernel/cluster_header.h"
 #include "xpu/kernel/debug.h"
 #include "xpu/kernel/math.h"
+#include "xpu/kernel/simd_header.h"
 
 namespace phi {
 namespace kps {
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
   }
 }
 
+template <typename InT,
+          typename OutT,
+          int NX,
+          int NY,
+          int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseBinary(
+    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
+  for (int idx = 0; idx < read_lens; ++idx) {
+    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
+  }
+}
+
 /**
  * @brief Ternary calculation according to OpFunc. Shape of input and output
  * are the same.

diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h
index 993349f2d9..ea1a830f89 100644
--- a/paddle/phi/kernels/primitive/datamover_primitives.h
+++ b/paddle/phi/kernels/primitive/datamover_primitives.h
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
   }
 }
 
+template <typename T, int NX>
+__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
+#pragma unroll
+  for (int i = 0; i < NX; i++) {
+    dst[i] = init_data;
+  }
+}
+
 /**
  * The difference from the above function is that
  * it supports different data types of inputs.
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
   }
 }
 
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void ReadData(T* dst,
+                                         const T* __restrict__ src,
+                                         int num,
+                                         int read_lens) {
+  if (IsBoundary) {  // blockDim.x * NX > num
+    int thread_offset = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if (idx + thread_offset < num) {
+        dst[idx] = src[thread_offset + idx];
+      }
+    }
+  } else {  // blockDim.x * NX < num
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
+
+    using VecType = details::VectorType<T, kVectorSize>;
+    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
+    VecType vec_temp[kVectorsPerThread];
+
+#pragma unroll
+    for (int i = 0; i < kVectorsPerThread; ++i) {
+      vec_temp[i] = vec_input[thread_offset + i];
+#pragma unroll
+      for (int idx = 0; idx < NX; ++idx) {
+        dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
+      }
+    }
+  }
+}
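+// e.g. with T = float and NX = 4, kVectorSize is 4, so each thread issues a
+// single 16-byte vectorized load instead of four scalar loads.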
 /**
  * @brief Read 1D data from global memory to register. The difference
  * from the above function is that it supports different data types of inputs.
@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
   }
 }
 
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void WriteData(T* dst,
+                                          T* __restrict__ src,
+                                          int num,
+                                          int read_lens) {
+  if (IsBoundary) {
+    int thread_offset = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if ((thread_offset + idx) < num) {
+        dst[thread_offset + idx] = src[idx];
+      }
+    }
+  } else {
+    // Vector type
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
+
+    int thread_offset = threadIdx.x * kVectorsPerThread;
+    using VecType = details::VectorType<T, kVectorSize>;
+    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
+    VecType vec_temp[kVectorsPerThread];
+#pragma unroll
+    for (int idx = 0; idx < kVectorsPerThread; ++idx) {
+      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
+      vec_dst[thread_offset + idx] = vec_temp[idx];
+    }
+  }
+}
+
 /**
  * @brief Write 2D data from register to global memory according to Tx type, and
  * store it as Ty type.
@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
   }
 }
 
+template <typename T, int NX, int NY, int BlockSize, int Rank, bool IsBoundary = false>
+__device__ __forceinline__ void ReadDataBc(
+    T* dst,
+    const T* __restrict__ src,
+    uint32_t block_offset,
+    details::BroadcastConfig<Rank> config,
+    int total_num_output,
+    int read_lens) {
+  uint32_t thread_offset = block_offset + threadIdx.x * NX;
+  uint32_t index_src = 0;
+
+#pragma unroll
+  for (uint32_t nx = 0; nx < NX; ++nx) {
+    uint32_t index_output = thread_offset + nx;
+    index_src = 0;
+    if (IsBoundary) {
+      if (index_output >= total_num_output) {
+        break;
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < Rank; ++i) {
+      auto fast_divmoder = config.divmoders[i].Divmod(index_output);
+      index_output = fast_divmoder.val[0];
+      index_src += fast_divmoder.val[1] * config.strides[i];
+    }
+    dst[nx] = src[index_src];
+  }
+}
+
 /**
  * @brief Initialize register with data index.
  *

diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
index a18fc7cbb3..eb25632378 100644
--- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
@@ -21,6 +21,39 @@ namespace phi {
 namespace kps {
 namespace details {
 
+enum class OptType {    // Optimization type chosen after input shapes are compressed
+  CanNotOptimize = -1,  // cannot optimize, fall back to generic broadcast
+  N_1,                  // just like {1} op {100} or {100} op {1}
+  MN_N,                 // just like {100} op {3, 100} or {3, 100} op {100}
+  MN_M,                 // just like {3} op {3, 100} or {3, 100} op {3}
+  MNK_1N1,              // just like {3} op {2, 3, 100} or {2, 3, 100} op {3}
+  MNK_M1K,              // just like {2, 1, 100} op {2, 3, 100} or
+                        // {2, 3, 100} op {2, 1, 100}
+};
+
+// Rules to determine whether dimensions can be merged:
+// rule 0 - xshape[idx] == yshape[idx]
+// rule 1 - xshape[idx] == 1 && yshape[idx] != 1
+// rule 2 - xshape[idx] != 1 && yshape[idx] == 1
+static int judge_case(int a, int b) {
+  if (a == b) {
+    return 0;
+  } else if (a == 1 && b != 1) {
+    return 1;
+  } else if (a != 1 && b == 1) {
+    return 2;
+  }
+  return -1;
+}
+
+static bool case_is_same(int case_front, int case_back) {
+  return case_front == case_back;
+}
+
 template <typename T, int VecSize>
 struct alignas(sizeof(T) * VecSize) VectorType {
   T val[VecSize];
@@ -37,11 +70,20 @@ struct BroadcastConfig {
   int strides_in[phi::DDim::kMaxRank];
   int strides_out[phi::DDim::kMaxRank];
   int in_dim[phi::DDim::kMaxRank];
+  int dim_after_cmp[phi::DDim::kMaxRank];
+  int dim_size_after_cmp = 0;
+  int cmp_res = 0;
+  OptType cmp_type = OptType::CanNotOptimize;
+  int m = 1;
+  int n = 1;
+  int k = 1;
+  int buf_len = 0;
 
   HOSTDEVICE BroadcastConfig() {}
 
   HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                              const std::vector<int64_t>& in_dims,
+                             const std::vector<int64_t>& another_in_dims,
                              int dim_size) {
     std::vector<int> strides_in_tmp;
     std::vector<int> strides_out_tmp;
@@ -61,18 +103,187 @@ struct BroadcastConfig {
     memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
     memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
     memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
+
+    cmp_res = get_mnk_for_broadcast_ops(in_dims, another_in_dims);
+    get_opt_type(another_in_dims);
+    buf_len = get_buf_len();
+  }
+
+  int get_buf_len() {
+    if (cmp_type == OptType::CanNotOptimize) {
+      return 256;
+    }
+    int max_buf_len = 512;
+    int buf_len = m / 16 * 16;
+    if (buf_len == 0) {
+      buf_len = m;
+    }
+    return std::min(max_buf_len, buf_len);
+  }
 
   __device__ inline int operator()(int index_output) const {
     int index_src = 0;
-#pragma unroll
-    for (int i = kDims - 1; i >= 0; --i) {
-      int tmp_index = (index_output / strides_out[i]);
-      index_output = index_output - tmp_index * strides_out[i];
-      index_src += (tmp_index % in_dim[i]) * strides_in[i];
+
+    switch (cmp_type) {
+      int div, mod, tmp_index;
+      case OptType::MNK_M1K:
+        div = index_output / (m * n);
+        mod = index_output % (m * n) % m;
+        index_src = div * m + mod;
+        break;
+      case OptType::MNK_1N1:
+        // index_src = index_output / m % n;
+        index_src = index_output % (m * n) / m;
+        break;
+      case OptType::N_1:
+        index_src = 0;
+        break;
+      case OptType::MN_N:
+        index_src = index_output / m;
+        break;
+      case OptType::MN_M:
+        index_src = index_output % m;
+        break;
+      case OptType::CanNotOptimize:
+        for (int i = kDims - 1; i >= 0; --i) {
+          tmp_index = (index_output / strides_out[i]);
+          index_output = index_output - tmp_index * strides_out[i];
+          index_src += (tmp_index % in_dim[i]) * strides_in[i];
+        }
+        break;
     }
     return index_src;
   }
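+  // e.g. with m = 100, MN_M maps output index 260 to source index
+  // 260 % 100 = 60, while MN_N maps the same index to 260 / 100 = 2.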
+
+  void get_opt_type(const std::vector<int64_t>& y_dim_after_cmp) {
+    if (dim_size_after_cmp == 1) {
+      if (dim_after_cmp[0] == 1 && y_dim_after_cmp[0] != 1) {  // {1} op {n}
+        n = y_dim_after_cmp[0];
+        cmp_type = OptType::N_1;
+      } else if (dim_after_cmp[0] != 1 &&
+                 y_dim_after_cmp[0] == 1) {  // {n} op {1}
+        n = dim_after_cmp[0];
+        cmp_type = OptType::N_1;
+      } else {
+        cmp_type = OptType::CanNotOptimize;  // xshape == yshape
+      }
+    }
+    if (dim_size_after_cmp == 2) {
+      if (dim_after_cmp[0] == 1 && dim_after_cmp[1] != 1 &&
+          y_dim_after_cmp[0] != 1 &&
+          y_dim_after_cmp[1] != 1) {  // {n} op {m, n}
+        m = y_dim_after_cmp[0];
+        n = y_dim_after_cmp[1];
+        cmp_type = OptType::MN_N;
+      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] == 1 &&
+                 y_dim_after_cmp[0] != 1 &&
+                 y_dim_after_cmp[1] != 1) {  // {m} op {m, n}
+        m = y_dim_after_cmp[0];
+        n = y_dim_after_cmp[1];
+        cmp_type = OptType::MN_M;
+      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] != 1 &&
+                 y_dim_after_cmp[0] == 1 &&
+                 y_dim_after_cmp[1] != 1) {  // {m, n} op {n}
+        m = dim_after_cmp[0];
+        n = dim_after_cmp[1];
+        cmp_type = OptType::MN_N;
+      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] != 1 &&
+                 y_dim_after_cmp[0] != 1 &&
+                 y_dim_after_cmp[1] == 1) {  // {m, n} op {m}
+        m = dim_after_cmp[0];
+        n = dim_after_cmp[1];
+        cmp_type = OptType::MN_M;
+      } else {
+        cmp_type = OptType::CanNotOptimize;
+      }
+    }
+    if (dim_size_after_cmp == 3) {
+      if (dim_after_cmp[0] == 1 && dim_after_cmp[1] != 1 &&
+          dim_after_cmp[2] == 1 && y_dim_after_cmp[0] != 1 &&
+          y_dim_after_cmp[1] != 1 &&
+          y_dim_after_cmp[2] != 1) {  // {1, n, 1} op {m, n, k}
+        m = y_dim_after_cmp[0];
+        n = y_dim_after_cmp[1];
+        k = y_dim_after_cmp[2];
+        cmp_type = OptType::MNK_1N1;
+      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] != 1 &&
+                 dim_after_cmp[2] != 1 && y_dim_after_cmp[0] == 1 &&
+                 y_dim_after_cmp[1] != 1 &&
+                 y_dim_after_cmp[2] == 1) {  // {m, n, k} op {1, n, 1}
+        m = dim_after_cmp[0];
+        n = dim_after_cmp[1];
+        k = dim_after_cmp[2];
+        cmp_type = OptType::MNK_1N1;
+      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] == 1 &&
+                 dim_after_cmp[2] != 1 && y_dim_after_cmp[0] != 1 &&
+                 y_dim_after_cmp[1] != 1 &&
+                 y_dim_after_cmp[2] != 1) {  // {m, 1, k} op {m, n, k}
+        m = y_dim_after_cmp[0];
+        n = y_dim_after_cmp[1];
+        k = y_dim_after_cmp[2];
+        cmp_type = OptType::MNK_M1K;
+      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] != 1 &&
+                 dim_after_cmp[2] != 1 && y_dim_after_cmp[0] != 1 &&
+                 y_dim_after_cmp[1] == 1 &&
+                 y_dim_after_cmp[2] != 1) {  // {m, n, k} op {m, 1, k}
+        m = dim_after_cmp[0];
+        n = dim_after_cmp[1];
+        k = dim_after_cmp[2];
+        cmp_type = OptType::MNK_M1K;
+      } else {
+        cmp_type = OptType::CanNotOptimize;
+      }
+    }
+  }
+
+  int get_mnk_for_broadcast_ops(const std::vector<int64_t>& xshape,
+                                const std::vector<int64_t>& yshape) {
+    int idx = 0;
+    int cmp_x = 0;
+    int cmp_y = 0;
+    bool is_same = false;
+    std::vector<int64_t> xshape_after_remove_ones = xshape;
+    std::vector<int64_t> yshape_after_remove_ones = yshape;
+    // first step: remove excess ones
+    std::vector<int64_t>::iterator x_iter = xshape_after_remove_ones.begin();
+    std::vector<int64_t>::iterator y_iter = yshape_after_remove_ones.begin();
+    for (; x_iter != xshape_after_remove_ones.end();) {
+      if (*x_iter == 1 && *y_iter == 1) {
+        x_iter = xshape_after_remove_ones.erase(x_iter);
+        y_iter = yshape_after_remove_ones.erase(y_iter);
+      } else {
+        x_iter++;
+        y_iter++;
+      }
+    }
+    // second step: compress dims
+    int after_cmp_idx = 0;
+    for (int i = 0; i < 3; i++) {
+      cmp_x = xshape_after_remove_ones[idx];
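+  // Worked example: for xshape = {2, 1, 100} vs yshape = {2, 3, 100},
+  // judge_case per dimension gives {0, 1, 0}; no adjacent dims share a case,
+  // so nothing merges: dim_after_cmp = {2, 1, 100}, dim_size_after_cmp = 3,
+  // which get_opt_type classifies as {m, 1, k} op {m, n, k} (MNK_M1K).
+  // For xshape = {2, 3, 5} vs yshape = {2, 3, 1}, the two leading case-0
+  // dimensions merge and x compresses to dim_after_cmp = {6, 5}.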
+      cmp_y = yshape_after_remove_ones[idx];
+      while ((idx + 1) < xshape_after_remove_ones.size()) {
+        is_same = case_is_same(judge_case(xshape_after_remove_ones[idx],
+                                          yshape_after_remove_ones[idx]),
+                               judge_case(xshape_after_remove_ones[idx + 1],
+                                          yshape_after_remove_ones[idx + 1]));
+        if (is_same) {
+          cmp_x = cmp_x * xshape_after_remove_ones[idx + 1];
+          cmp_y = cmp_y * yshape_after_remove_ones[idx + 1];
+          idx++;
+        } else {
+          break;
+        }
+      }
+      idx = idx + 1;
+      dim_after_cmp[after_cmp_idx] = cmp_x;
+      after_cmp_idx++;
+      if (idx == xshape_after_remove_ones.size()) {
+        dim_size_after_cmp = after_cmp_idx;
+        return 0;
+      }
+    }
+    return -1;  // cannot compress dims
+  }
 };
 #pragma pack()
 
@@ -199,6 +410,14 @@ __device__ __inline__ void Init(T* dst, T init_data) {
   }
 }
 
+template <typename T, int NX>
+__device__ __inline__ void Init(T* dst, T init_data, int read_lens) {
+#pragma unroll
+  for (int i = 0; i < read_lens; i++) {
+    dst[i] = init_data;
+  }
+}
+
 /**
  * The difference from the above function is that
  * it supports different data types of inputs.
@@ -251,6 +470,26 @@ __device__ __inline__ void ReadData(T* dst,
   }
 }
 
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __inline__ void ReadData(T* dst,
+                                    const T _global_ptr_* src,
+                                    int num,
+                                    int read_lens) {
+  int thread_offset = core_id() * read_lens;
+  __local__ T in_temp[1];
+  if (IsBoundary) {  // core_num() * read_lens > num
+#pragma unroll
+    for (int idx = 0; idx < read_lens; ++idx) {
+      if (idx + thread_offset < num) {
+        GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
+        dst[idx] = in_temp[0];
+      }
+    }
+  } else {  // core_num() * read_lens < num
+    GM2LM(src + thread_offset, dst, read_lens * sizeof(T));
+  }
+}
+
 /**
  * @brief Read 1D data from global memory to register. The difference
  * from the above function is that it supports different data types of inputs.
@@ -479,10 +718,32 @@ __device__ __forceinline__ void ReadDataReduce(
  * size: The current block needs to load size elements continuously.
  */
 
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ void WriteData(T _global_ptr_* dst,
+                          const T* src,
+                          int num,
+                          int read_lens) {
+  int thread_offset = core_id() * read_lens;
+  __local__ T in_temp[1];
+
+  if (IsBoundary) {  // core_num() * read_lens > num
+#pragma unroll
+    for (int idx = 0; idx < read_lens; ++idx) {
+      if (idx + thread_offset < num) {
+        in_temp[0] = src[idx];
+        LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
+      }
+    }
+  } else {  // core_num() * read_lens < num
+    LM2GM(src, dst + thread_offset, read_lens * sizeof(T));
+  }
+}
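+// GM2LM / LM2GM copy between XPU global memory and a core's local memory.
+// The boundary path above moves one element per transfer; the fast path
+// moves all read_lens elements in a single bulk copy.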
+
 template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
   int thread_offset = core_id() * NX;
   __local__ T in_temp[1];
+
   if (IsBoundary) {  // core_num() * NX > num
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
@@ -675,6 +936,331 @@ __device__ __inline__ void ReadDataBc(
   }
 }
 
+/**
+ * @brief Read data from global memory to local memory with broadcast
+ * {m, 1, k} -> {m, n, k} form.
+ *
+ * @template parameters
+ * T: Data type of register.
+ * Rank: The shape size of out. E.g. in[1, 35], out[32, 35]: the shape size is 2.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX.
+ * src: The original input data pointer of the kernel.
+ * thread_offset: The data offset of this thread.
+ * config: Calculation configuration of broadcast. It is used to calculate the
+ * coordinate mapping relationship between output data and input data.
+ * read_lens: The number of elements continuously loaded by each thread.
+ */
+template <typename T, int Rank>
+__device__ __inline__ void ReadDataBcM1kMnk(
+    T* dst,
+    const T _global_ptr_* src,
+    int thread_offset,
+    const details::BroadcastConfig<Rank>& config,
+    int read_lens) {
+  int index_output = thread_offset;
+  int index_base = config(index_output);
+  int m = config.m;
+  int n = config.n;
+
+  int m_pos = index_base % m;
+  if ((m - m_pos) < read_lens) {
+    int last_col = m - m_pos;
+    GM2LM(src + index_base, dst, last_col * sizeof(T));
+    int n_pos = index_output % (m * n) / m;
+    int next_part_index = 0;
+    if (n_pos != config.n - 1) {
+      next_part_index = index_base / m * m;
+    } else {
+      next_part_index = (index_base / m + 1) * m;
+    }
+    GM2LM(src + next_part_index,
+          dst + last_col,
+          (read_lens - last_col) * sizeof(T));
+  } else {
+    GM2LM(src + index_base, dst, read_lens * sizeof(T));
+  }
+}
+
+/**
+ * @brief Read data from global memory to local memory with broadcast
+ * {m, 1} -> {m, n} form.
+ *
+ * @template parameters
+ * T: Data type of register.
+ * Rank: The shape size of out. E.g. in[1, 35], out[32, 35]: the shape size is 2.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX.
+ * src: The original input data pointer of the kernel.
+ * thread_offset: The data offset of this thread.
+ * config: Calculation configuration of broadcast. It is used to calculate the
+ * coordinate mapping relationship between output data and input data.
+ * read_lens: The number of elements continuously loaded by each thread.
+ */
+template <typename T, int Rank>
+__device__ __inline__ void ReadDataBcM1Mn(
+    T* dst,
+    const T _global_ptr_* src,
+    int thread_offset,
+    const details::BroadcastConfig<Rank>& config,
+    int read_lens) {
+  int index_output = thread_offset;
+  int index_base = config(index_output);
+  int m = config.m;
+  int n = config.n;
+
+  int m_pos = index_base % m;
+  if ((m - m_pos) < read_lens) {
+    int last_col = m - m_pos;
+    GM2LM(src + index_base, dst, last_col * sizeof(T));
+    GM2LM(src, dst + last_col, (read_lens - last_col) * sizeof(T));
+  } else {
+    GM2LM(src + index_base, dst, read_lens * sizeof(T));
+  }
+}
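+// e.g. m = 100, read_lens = 64, thread_offset = 260: index_base = 60, so the
+// last_col = 40 elements src[60..99] are loaded first, then the read wraps
+// around to src[0..23] for the remaining 24 elements.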
+
+/**
+ * @brief Read data from global memory to local memory with broadcast
+ * {1, n} -> {m, n} form.
+ *
+ * @template parameters
+ * T: Data type of register.
+ * Rank: The shape size of out. E.g. in[1, 35], out[32, 35]: the shape size is 2.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX.
+ * src: The original input data pointer of the kernel.
+ * thread_offset: The data offset of this thread.
+ * config: Calculation configuration of broadcast. It is used to calculate the
+ * coordinate mapping relationship between output data and input data.
+ * read_lens: The number of elements continuously loaded by each thread.
+ */
+template <typename T, int Rank>
+__device__ __inline__ void ReadDataBc1NMn(
+    T* dst,
+    const T _global_ptr_* src,
+    int thread_offset,
+    const details::BroadcastConfig<Rank>& config,
+    int read_lens) {
+  int index_output = thread_offset;
+  int index_base = config(index_output);
+  int m = config.m;
+  int n = config.n;
+  T in_temp;
+
+  int m_pos = index_output % m;
+  if ((m - m_pos) < read_lens) {
+    int last_col = m - m_pos;
+    GM2LM(src + index_base, &in_temp, sizeof(T));
+    for (int i = 0; i < last_col; i++) {
+      dst[i] = in_temp;
+    }
+    GM2LM(src + index_base + 1, &in_temp, sizeof(T));
+    for (int i = 0; i < read_lens - last_col; i++) {
+      dst[last_col + i] = in_temp;
+    }
+  } else {
+    GM2LM(src + index_base, &in_temp, sizeof(T));
+    for (int i = 0; i < read_lens; i++) {
+      dst[i] = in_temp;
+    }
+  }
+}
+
+/**
+ * @brief Read data from global memory to local memory with broadcast
+ * {1, n, 1} -> {m, n, k} form.
+ *
+ * @template parameters
+ * T: Data type of register.
+ * Rank: The shape size of out. E.g. in[1, 35], out[32, 35]: the shape size is 2.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX.
+ * src: The original input data pointer of the kernel.
+ * thread_offset: The data offset of this thread.
+ * config: Calculation configuration of broadcast. It is used to calculate the
+ * coordinate mapping relationship between output data and input data.
+ * read_lens: The number of elements continuously loaded by each thread.
+ */
+template <typename T, int Rank>
+__device__ __inline__ void ReadDataBc1N1Mnk(
+    T* dst,
+    const T _global_ptr_* src,
+    int thread_offset,
+    const details::BroadcastConfig<Rank>& config,
+    int read_lens) {
+  int index_output = thread_offset;
+  int index_base = config(index_output);
+  int m = config.m;
+  int n = config.n;
+  T in_temp;
+
+  int m_pos = index_output % m;
+  if ((m - m_pos) < read_lens) {
+    int last_col = m - m_pos;
+    GM2LM(src + index_base, &in_temp, sizeof(T));
+    for (int i = 0; i < last_col; i++) {
+      dst[i] = in_temp;
+    }
+    int n_pos = index_output % (m * n) / m;
+    int next_part_index = 0;
+    if (n_pos != n - 1) {
+      next_part_index = n_pos + 1;
+    } else {
+      next_part_index = 0;
+    }
+    GM2LM(src + next_part_index, &in_temp, sizeof(T));
+    for (int i = 0; i < read_lens - last_col; i++) {
+      dst[last_col + i] = in_temp;
+    }
+  } else {
+    GM2LM(src + index_base, &in_temp, sizeof(T));
+    for (int i = 0; i < read_lens; i++) {
+      dst[i] = in_temp;
+    }
+  }
+}
+
+/**
+ * @brief Read data from global memory to local memory with broadcast
+ * {1} -> {n} form.
+ *
+ * @template parameters
+ * T: Data type of register.
+ * Rank: The shape size of out. E.g. in[1, 35], out[32, 35]: the shape size is 2.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX.
+ * src: The original input data pointer of the kernel.
+ * thread_offset: The data offset of this thread.
+ * config: Calculation configuration of broadcast. It is used to calculate the
+ * coordinate mapping relationship between output data and input data.
+ * read_lens: The number of elements continuously loaded by each thread.
+ */
+template <typename T, int Rank>
+__device__ __inline__ void ReadDataBc1N(
+    T* dst,
+    const T _global_ptr_* src,
+    int thread_offset,
+    const details::BroadcastConfig<Rank>& config,
+    int read_lens) {
+  int index_output = thread_offset;
+  int index_base = config(index_output);
+  T in_temp;
+
+  GM2LM(src + index_base, &in_temp, sizeof(T));
+  for (int i = 0; i < read_lens; i++) {
+    dst[i] = in_temp;
+  }
+}
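+// In the {1} op {n} case every output element reads the same scalar, so one
+// GM2LM transfer fills in_temp and the register buffer is replicated from it
+// with no further global-memory traffic.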
+
+/**
+ * @brief Read data from global memory to local memory with a broadcast
+ * form that cannot be compressed.
+ *
+ * @template parameters
+ * T: Data type of register.
+ * Rank: The shape size of out. E.g. in[1, 35], out[32, 35]: the shape size is 2.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX.
+ * src: The original input data pointer of the kernel.
+ * thread_offset: The data offset of this thread.
+ * config: Calculation configuration of broadcast. It is used to calculate the
+ * coordinate mapping relationship between output data and input data.
+ * total_num_output: Total number of original output elements.
+ * read_lens: The number of elements continuously loaded by each thread.
+ */
+template <typename T, int Rank, bool IsBoundary = false>
+__device__ __inline__ void ReadDataBcCanNotCmp(
+    T* dst,
+    const T _global_ptr_* src,
+    int thread_offset,
+    const details::BroadcastConfig<Rank>& config,
+    int total_num_output,
+    int read_lens) {
+  int index_output = thread_offset;
+  int index_base = config(index_output);
+  T in_temp;
+  constexpr int cache_size = 256;
+  __local__ T src_temp[cache_size];
+  GM2LM(src + index_base, src_temp, cache_size * sizeof(T));
+
+  for (int nx = 0; nx < read_lens; ++nx) {
+    index_output = thread_offset + nx;
+    if (IsBoundary) {
+      if (index_output >= total_num_output) {
+        break;
+      }
+    }
+    int index_src = config(index_output);
+    if (index_src >= index_base && index_src < index_base + cache_size) {
+      in_temp = src_temp[index_src - index_base];
+    } else {
+      GM2LM(src + index_src, &in_temp, sizeof(T));
+    }
+    dst[nx] = in_temp;
+  }
+}
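+// The cache_size window keeps the common case in local memory: a per-element
+// GM2LM is issued only when a mapped source index falls outside
+// [index_base, index_base + cache_size).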
+
+/**
+ * @brief Read 1D data from global memory to register with broadcast form.
+ *
+ * @template parameters
+ * T: The type of data stored in the global memory.
+ * NX: The number of elements continuously loaded by each thread.
+ * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
+ * BlockSize: Identifies the current device thread index method. For xpu,
+ * core_id() is used as the index.
+ * Rank: The shape size of out. E.g. in[1, 35], out[32, 35]: the shape size is 2.
+ * IsBoundary: Indicates whether to perform block access storage out-of-bounds
+ * judgment. When the number of data processed by the block is less than
+ * NX x NY x core_num(), boundary judgment is required to avoid memory access
+ * crossing the boundary.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX * NY.
+ * src: The original input data pointer of the kernel.
+ * block_offset: The data offset of this block, core_num() * blockIdx.x * NX;
+ * config: Calculation configuration of broadcast. It is used to calculate the
+ * coordinate mapping relationship between output data and input data.
+ * read_lens: The number of elements continuously loaded by each thread.
+ * total_num_output: Total number of original output elements.
+ */
+template <typename T, int NX, int NY, int BlockSize, int Rank, bool IsBoundary = false>
+__device__ __inline__ void ReadDataBc(
+    T* dst,
+    const T _global_ptr_* src,
+    uint32_t block_offset,
+    const details::BroadcastConfig<Rank>& config,
+    int total_num_output,
+    int read_lens) {
+  int thread_offset = block_offset + core_id() * read_lens;
+
+  if (config.cmp_type == details::OptType::MNK_M1K) {
+    ReadDataBcM1kMnk(dst, src, thread_offset, config, read_lens);
+  } else if (config.cmp_type == details::OptType::N_1) {
+    ReadDataBc1N(dst, src, thread_offset, config, read_lens);
+  } else if (config.cmp_type == details::OptType::MN_M) {
+    ReadDataBcM1Mn(dst, src, thread_offset, config, read_lens);
+  } else if (config.cmp_type == details::OptType::MN_N) {
+    ReadDataBc1NMn(dst, src, thread_offset, config, read_lens);
+  } else if (config.cmp_type == details::OptType::MNK_1N1) {
+    ReadDataBc1N1Mnk(dst, src, thread_offset, config, read_lens);
+  } else {
+    ReadDataBcCanNotCmp<T, Rank, IsBoundary>(
+        dst, src, thread_offset, config, total_num_output, read_lens);
+  }
+}
+
 /**
  * @brief Initialize register with data index.
  *

diff --git a/paddle/phi/kernels/primitive/kernel_primitives.h b/paddle/phi/kernels/primitive/kernel_primitives.h
index b5a1e88acc..ea5846c3a2 100644
--- a/paddle/phi/kernels/primitive/kernel_primitives.h
+++ b/paddle/phi/kernels/primitive/kernel_primitives.h
@@ -46,6 +46,7 @@
 #define KPStream gpuStream_t
 #define KPDevice phi::GPUContext
 #define _ptr_
+#define __simd__
 
 #define THREAD_ID_X threadIdx.x
 #define THREAD_ID_Y threadIdx.y