diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h index 3a2de0c4a093514a1c40321ab7dad61011709204..b059223eaf6e7a0907f8344c4ee44087002d005d 100644 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -51,8 +51,7 @@ template use_broadcast, uint32_t numel, - phi::Array, MAX_INPUT_NUM> - configlists, + phi::Array configlists, int main_tid, int tail_tid, Functor func) { int fix = blockIdx.x * blockDim.x * VecSize; int num = tail_tid; @@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary( // load in0 if (use_broadcast[0]) { - kernel_primitives::ReadDataBc( + kernel_primitives::ReadDataBc( arg0, in0, fix, configlists[0], numel); } else { kernel_primitives::ReadData(arg0, in0 + fix, num); } // load in1 if (use_broadcast[1]) { - kernel_primitives::ReadDataBc( + kernel_primitives::ReadDataBc( arg1, in1, fix, configlists[1], numel); } else { kernel_primitives::ReadData(arg1, in1 + fix, num); @@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n, int main_tid = numel / (data_per_thread * vec_size * threads); int tail_tid = numel % (data_per_thread * vec_size * threads); - phi::Array, MAX_INPUT_NUM> configlists; + phi::Array configlists; phi::Array use_broadcast; use_broadcast[0] = false; @@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n, // Here, dims are transposed due to the logic in BroadcastConfig. std::vector input1_dims = {n, 1}; std::vector out_dims = {n, m}; - configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2); + configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2); auto func = AddFunctor(); auto stream = ctx.stream(); diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 17735c05ada52d15677bafe8cd03d3fbc58e2504..b473d68b68ba9eaaacc7c8974d41fe5d5a4678a6 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -223,31 +223,42 @@ struct DimensionsTransform { } }; -template -__device__ __forceinline__ void LoadData( - T *dst, - const _ptr_ T *src, - uint32_t block_offset, - const kps::details::BroadcastConfig &config, - int numel, - int num, - int need_broadcast) { - // numel : whole num of output - // num: how many data will be deal with in this time - if (need_broadcast) { - kps::ReadDataBc( - dst, src, block_offset, config, numel); +template +int GetVecsize(const std::vector &ins, + std::vector *outs) { + int in_vec_size = 4; + int out_vec_size = 4; + if (NumOuts > 1) { + for (int i = 0; i < NumOuts; ++i) { + PADDLE_ENFORCE_EQ( + (*outs)[i]->dims(), + (*outs)[0]->dims(), + phi::errors::InvalidArgument( + "The shape of each output tensor shall be identical yet, but " + "%d-th output tensor`s shape is not.", + i)); + out_vec_size = std::min( + phi::GetVectorizedSize((*outs)[i]->data()), out_vec_size); + } } else { - kps::ReadData(dst, src + block_offset, num); + out_vec_size = phi::GetVectorizedSize((*outs)[0]->data()); + } + + for (auto *in : ins) { + auto temp_size = phi::GetVectorizedSize(in->data()); + in_vec_size = in->dims() == (*outs)[0]->dims() + ? std::min(temp_size, in_vec_size) + : in_vec_size; } + return std::min(out_vec_size, in_vec_size); } -template +template __device__ __forceinline__ void LoadData( T *dst, const _ptr_ T *src, uint32_t block_offset, - const kps::details::BroadcastConfig &config, + const kps::details::BroadcastConfig &config, int numel, int num, int need_broadcast, @@ -255,7 +266,7 @@ __device__ __forceinline__ void LoadData( // numel : whole num of output // num: how many data will be deal with in this time if (need_broadcast) { - kps::ReadDataBc( + kps::ReadDataBc( dst, src, block_offset, config, numel, read_lens); } else { kps::ReadData( @@ -269,14 +280,13 @@ template __device__ void VectorizedBroadcastKernelImpl( const phi::Array &ins, phi::Array<_ptr_ OutT *, NumOuts> outs, const phi::Array &use_broadcast, uint32_t numel, - const phi::Array, Arity> &configs, + const phi::Array &configs, int num, int block_offset, int read_lens, @@ -287,14 +297,14 @@ __device__ void VectorizedBroadcastKernelImpl( #pragma unroll for (int i = 0; i < Arity; i++) { kps::Init(args[i], static_cast(1.0f), read_lens); - LoadData(args[i], - ins[i], - block_offset, - configs[i], - numel, - num, - use_broadcast[i], - read_lens); + LoadData(args[i], + ins[i], + block_offset, + configs[i], + numel, + num, + use_broadcast[i], + read_lens); } constexpr bool kCallElementwiseAny = paddle::platform::FunctionTraits::has_pointer_args; @@ -315,14 +325,13 @@ template + int VecSize> __global__ void VectorizedBroadcastKernel( phi::Array ins, phi::Array<_ptr_ OutT *, NumOuts> outs, phi::Array use_broadcast, uint32_t numel, - phi::Array, Arity> configs, + phi::Array configs, int main_offset, int tail_tid, int read_lens, @@ -338,7 +347,6 @@ __global__ void VectorizedBroadcastKernel( Arity, NumOuts, VecSize, - Rank, false>(ins, outs, use_broadcast, @@ -357,7 +365,6 @@ __global__ void VectorizedBroadcastKernel( Arity, NumOuts, VecSize, - Rank, true>(ins, outs, use_broadcast, @@ -376,7 +383,6 @@ __global__ void VectorizedBroadcastKernel( Arity, NumOuts, VecSize, - Rank, false>(ins, outs, use_broadcast, @@ -393,7 +399,6 @@ __global__ void VectorizedBroadcastKernel( Arity, NumOuts, VecSize, - Rank, true>(ins, outs, use_broadcast, @@ -412,15 +417,14 @@ template -void LaunchBroadcastKernel(const KPDevice &ctx, - const std::vector &ins, - std::vector *outs, - Functor func, - DimensionsTransform merge_dims) { + int VecSize> +void LaunchBroadcastKernel( + const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + Functor func, + const phi::Array &configs) { int numel = (*outs)[0]->numel(); - phi::Array, Arity> configs; phi::Array use_broadcast; phi::Array ins_data; phi::Array<_ptr_ OutT *, NumOuts> outs_data; @@ -432,132 +436,41 @@ void LaunchBroadcastKernel(const KPDevice &ctx, for (int i = 0; i < Arity; i++) { use_broadcast[i] = (ins[i]->numel() != numel); ins_data[i] = (const _ptr_ InT *)(ins[i]->data()); -#ifdef PADDLE_WITH_XPU_KP - if (i == 0) { - configs[i] = kps::details::BroadcastConfig(merge_dims.out_dims, - merge_dims.in_dims[0], - merge_dims.in_dims[1], - merge_dims.dim_size); - } else if (i == 1) { - configs[i] = kps::details::BroadcastConfig(merge_dims.out_dims, - merge_dims.in_dims[1], - merge_dims.in_dims[0], - merge_dims.dim_size); - } -#else - if (use_broadcast[i]) { - // get the broadcast config, - // if data shape is[m, n], then you should set data_dim = {n, m} - // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} - configs[i] = kps::details::BroadcastConfig( - merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); - } -#endif } #ifdef PADDLE_WITH_XPU_KP const int threads = 64; const int blocks = 8; int read_lens = configs[0].buf_len; + auto stream = ctx.x_context()->xpu_stream; int main_offset = (numel / (read_lens * threads)) * read_lens * threads; int tail_tid = numel % (read_lens * threads); - auto stream = ctx.x_context()->xpu_stream; - if (configs[0].cmp_type != kps::details::OptType::CanNotOptimize) { - main_offset = numel; - VectorizedBroadcastKernel<<>>(ins_data, - outs_data, - use_broadcast, - numel, - configs, - main_offset, - tail_tid, - read_lens, - func); - } else { - VectorizedBroadcastKernel<<>>(ins_data, - outs_data, - use_broadcast, - numel, - configs, - main_offset, - tail_tid, - read_lens, - func); - } #else - const int threads = 256; - int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; - int main_offset = (numel / (VecSize * threads)) * VecSize * threads; - int tail_tid = numel % (VecSize * threads); + auto gpu_config = + phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize); + int read_lens = VecSize; auto stream = ctx.stream(); + auto threads = gpu_config.thread_per_block; + auto blocks = gpu_config.block_per_grid; + int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) * + read_lens * gpu_config.GetBlockSize(); + int tail_tid = numel % (read_lens * gpu_config.GetBlockSize()); +#endif VectorizedBroadcastKernel<<>>(ins_data, - outs_data, - use_broadcast, - numel, - configs, - main_offset, - tail_tid, - VecSize, - func); -#endif -} - -template -void BroadcastKernelForDifferentDimSize( - const KPDevice &ctx, - const std::vector &ins, - std::vector *outs, - int axis, - Functor func) { - const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis); - -#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \ - case rank: { \ - LaunchBroadcastKernel( \ - ctx, ins, outs, func, merge_dims); \ - } break; - - switch (merge_dims.dim_size) { - CALL_BROADCAST_FOR_DIM_SIZE(1); - CALL_BROADCAST_FOR_DIM_SIZE(2); - CALL_BROADCAST_FOR_DIM_SIZE(3); - CALL_BROADCAST_FOR_DIM_SIZE(4); - CALL_BROADCAST_FOR_DIM_SIZE(5); - CALL_BROADCAST_FOR_DIM_SIZE(6); - CALL_BROADCAST_FOR_DIM_SIZE(7); - CALL_BROADCAST_FOR_DIM_SIZE(8); - default: { - PADDLE_THROW(phi::errors::InvalidArgument( - "The maximum dimension of input tensor is expected to be less than " - "%d, but received %d.", - merge_dims.dim_size, - phi::DDim::kMaxRank)); - } - } -#undef CALL_BROADCAST_FOR_DIM_SIZE + VecSize><<>>( + ins_data, + outs_data, + use_broadcast, + numel, + configs, + main_offset, + tail_tid, + read_lens, + func); } template ; const int kArity = Traits::has_pointer_args ? static_cast(ET) : Traits::arity; - PADDLE_ENFORCE_EQ(ins.size(), - kArity, - phi::errors::InvalidArgument( - "The number of inputs is expected to be equal to the " - "arity of functor. But received: the number of inputs " - "is %d, the arity of functor is %d.", - ins.size(), - kArity)); - PADDLE_ENFORCE_LE(kArity, - 3, - phi::errors::InvalidArgument( - "Currently only broadcast of ternary is supported " - "and verified, but received %d.", - kArity)); - PADDLE_ENFORCE_EQ(outs->size(), - NumOuts, - phi::errors::InvalidArgument( - "Number of outputs shall equal to number of functions, " - "but number of outputs is %d, of functions is %d.", - outs->size(), - NumOuts)); - int in_vec_size = 4; - int out_vec_size = 4; - if (NumOuts > 1) { - for (int i = 0; i < NumOuts; ++i) { - PADDLE_ENFORCE_EQ( - (*outs)[i]->dims(), - (*outs)[0]->dims(), - phi::errors::InvalidArgument( - "The shape of each output tensor shall be identical yet, but " - "%d-th output tensor`s shape is not.", - i)); - out_vec_size = std::min( - phi::GetVectorizedSize((*outs)[i]->data()), out_vec_size); - } - } else { - out_vec_size = phi::GetVectorizedSize((*outs)[0]->data()); - } + PADDLE_ENFORCE_EQ( + ins.size(), + kArity, + phi::errors::InvalidArgument("The number of inputs is expected to be " + "equal to the " + "arity of functor. But recieved: the " + "number of inputs " + "is %d, the arity of functor is %d.", + ins.size(), + kArity)); + PADDLE_ENFORCE_LE( + kArity, + 3, + phi::errors::InvalidArgument("Currently only broadcast of ternary is " + "supported " + "and verified, but received %d.", + kArity)); + PADDLE_ENFORCE_EQ( + outs->size(), + NumOuts, + phi::errors::InvalidArgument("Number of outputs shall equal to number " + "of functions, " + "but number of outputs is %d, of " + "functions is %d.", + outs->size(), + NumOuts)); + + // mergedim and get vec_size + const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis); + phi::Array configs; - for (auto *in : ins) { - auto temp_size = phi::GetVectorizedSize(in->data()); - in_vec_size = in->dims() == (*outs)[0]->dims() - ? std::min(temp_size, in_vec_size) - : in_vec_size; +// get vec_size +#ifdef PADDLE_WITH_XPU_KP + PADDLE_ENFORCE_EQ( + ins.size(), + 2, + phi::errors::InvalidArgument( + "XPU only support inputs is 2, but received %d", ins.size())); + configs[0] = kps::details::BroadcastConfig(merge_dims.out_dims, + merge_dims.in_dims[0], + merge_dims.in_dims[1], + merge_dims.dim_size); + configs[1] = kps::details::BroadcastConfig(merge_dims.out_dims, + merge_dims.in_dims[1], + merge_dims.in_dims[0], + merge_dims.dim_size); + auto type = kps::details::OptType::CanNotOptimize; + bool is_optimize = configs[0].cmp_type != type; + int vec_size = is_optimize ? VecSizeL : VecSizeM; +#else + for (int i = 0; i < kArity; i++) { + // get the broadcast config, + // if data shape is[m, n], then you should set data_dim = {n, m} + // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} + if (ins[i]->numel()) { + configs[i] = kps::details::BroadcastConfig( + merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); + } } - int vec_size = std::min(out_vec_size, in_vec_size); + int vec_size = GetVecsize(ins, outs); +#endif switch (vec_size) { - case 4: { - BroadcastKernelForDifferentDimSize(ctx, ins, outs, axis, func); + case VecSizeL: { + LaunchBroadcastKernel( + ctx, ins, outs, func, configs); break; } - case 2: { - BroadcastKernelForDifferentDimSize(ctx, ins, outs, axis, func); + case VecSizeM: { + LaunchBroadcastKernel( + ctx, ins, outs, func, configs); break; } - case 1: { - BroadcastKernelForDifferentDimSize(ctx, ins, outs, axis, func); + case VecSizeS: { + LaunchBroadcastKernel( + ctx, ins, outs, func, configs); break; } default: { diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h index ea1a830f89ab59d2d571ead95664224af809dd87..8b0c42c9d19b1294254ddd467dd412e420c8bb58 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives.h +++ b/paddle/phi/kernels/primitive/datamover_primitives.h @@ -82,10 +82,10 @@ struct FastDivMod { * index of the output data. if input or output shape is [dim0, dim1] then dims * must be [dim1, dim0]. */ -template struct BroadcastConfig { - FastDivMod divmoders[kDims]; + FastDivMod divmoders[phi::DDim::kMaxRank]; uint32_t strides[phi::DDim::kMaxRank]; + int kDims; HOSTDEVICE BroadcastConfig() {} HOSTDEVICE BroadcastConfig(const std::vector& out_dims, @@ -109,7 +109,7 @@ struct BroadcastConfig { std::multiplies()) : strides_in[i]; } - + kDims = dim_size; memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t)); memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod)); } @@ -436,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, * stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_ny: Each read one element stride stride_ny elements in the first dim. */ -template +template __device__ __forceinline__ void ReadDataBc( T* dst, const T* __restrict__ src, uint32_t block_offset, - details::BroadcastConfig config, + const details::BroadcastConfig& config, int total_num_output, int stride_nx, int stride_ny) { @@ -465,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc( } } #pragma unroll - for (int i = 0; i < Rank; ++i) { + for (int i = 0; i < phi::DDim::kMaxRank; ++i) { + if (i >= config.kDims) break; auto fast_divmoder = config.divmoders[i].Divmod(index_output); index_output = fast_divmoder.val[0]; index_src += fast_divmoder.val[1] * config.strides[i]; @@ -785,53 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) { * coordinate mapping relationship between output data and input data. * total_num_output: Total number of original output. */ -template -__device__ __forceinline__ void ReadDataBc( - T* dst, - const T* __restrict__ src, - uint32_t block_offset, - details::BroadcastConfig config, - int total_num_output) { - uint32_t thread_offset = block_offset + threadIdx.x * NX; - uint32_t index_src = 0; - -#pragma unroll - for (uint32_t nx = 0; nx < NX; ++nx) { - uint32_t index_output = thread_offset + nx; - index_src = 0; - if (IsBoundary) { - if (index_output >= total_num_output) { - break; - } - } -#pragma unroll - for (int i = 0; i < Rank; ++i) { - auto fast_divmoder = config.divmoders[i].Divmod(index_output); - index_output = fast_divmoder.val[0]; - index_src += fast_divmoder.val[1] * config.strides[i]; - } - dst[nx] = src[index_src]; - } -} - -template +template __device__ __forceinline__ void ReadDataBc( T* dst, const T* __restrict__ src, uint32_t block_offset, - details::BroadcastConfig config, + const details::BroadcastConfig& config, int total_num_output, - int read_lens) { + int read_lens = NX) { uint32_t thread_offset = block_offset + threadIdx.x * NX; uint32_t index_src = 0; @@ -845,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc( } } #pragma unroll - for (int i = 0; i < Rank; ++i) { + for (int i = 0; i < phi::DDim::kMaxRank; ++i) { + if (i >= config.kDims) break; auto fast_divmoder = config.divmoders[i].Divmod(index_output); index_output = fast_divmoder.val[0]; index_src += fast_divmoder.val[1] * config.strides[i]; @@ -853,6 +811,7 @@ __device__ __forceinline__ void ReadDataBc( dst[nx] = src[index_src]; } } + /** * @brief Initialize register with data index. * diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h index eb25632378a58607ed2918ecd9254a4e1f6e8f54..3799b9d4892f85f751f6007f3ea852febfe382c1 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h @@ -65,7 +65,6 @@ struct alignas(sizeof(T) * VecSize) VectorType { * must be [dim1, dim0]. */ #pragma pack(4) -template struct BroadcastConfig { int strides_in[phi::DDim::kMaxRank]; int strides_out[phi::DDim::kMaxRank]; @@ -78,7 +77,7 @@ struct BroadcastConfig { int n = 1; int k = 1; int buf_len = 0; - + int kDims; HOSTDEVICE BroadcastConfig() {} HOSTDEVICE BroadcastConfig(const std::vector& out_dims, @@ -99,7 +98,7 @@ struct BroadcastConfig { for (int i = 0; i < dim_size; i++) { dim_tmp[i] = in_dims[i]; } - + kDims = dim_size; memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int)); memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int)); memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int)); @@ -551,7 +550,6 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, * NY: The number of data rows loaded by each thread. * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. - * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * IsBoundary: Indicates whether to perform block access storage out-of-bounds * judgment. When the number of data processed by the block is less than * NX x NY x core_num(), boundary judgment is required to avoid memory access @@ -567,16 +565,11 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, * stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_ny: Each read one element stride stride_ny elements in the first dim. */ -template +template __device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, uint32_t block_offset, - details::BroadcastConfig config, + const details::BroadcastConfig& config, int total_num_output, int stride_nx, int stride_ny) { @@ -882,60 +875,6 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) { } } -/** - * @brief Read 1D data from global memory to register with broadcast form. - * - * @template paraments - * T: The type of data stored in the global memory. - * NX: The number of data continuously loaded by each thread. - * NY: The number of data rows loaded by each thread, only NY = 1 was supported. - * BlockSize: Identifies the current device thread index method. For xpu, - * core_id() is used as the index. - * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. - * IsBoundary: Indicates whether to perform block access storage out-of-bounds - * judgment. When the number of data processed by the block is less than - * NX x NY x core_num(), boundary judgment is required to avoid memory access - * crossing the boundary. - * - * @param: - * dst: The register pointer of the thread, the size is NX * NY. - * src: The original input data pointer of kernel. - * block_offset: The data offset of this block, core_num() * blockIdx.x * NX; - * config: Calculation configuration of broadcast. It is used to calculate the - * coordinate mapping relationship between output data and input data. - * total_num_output: Total number of original output. - */ -template -__device__ __inline__ void ReadDataBc( - T* dst, - const T _global_ptr_* src, - uint32_t block_offset, - const details::BroadcastConfig& config, - int total_num_output) { - int thread_offset = block_offset + core_id() * NX; - int index_src = 0; - - __local__ T in_temp; -#pragma unroll - for (int nx = 0; nx < NX; ++nx) { - int index_output = thread_offset + nx; - index_src = 0; - if (IsBoundary) { - if (index_output >= total_num_output) { - break; - } - } - index_src = config(index_output); - GM2LM(src + index_src, &in_temp, sizeof(T)); - dst[nx] = in_temp; - } -} - /** * @brief Read data from global memory to local memory with broadcast * {m, 1, k}-> {m, n, k} form. @@ -952,12 +891,12 @@ __device__ __inline__ void ReadDataBc( * coordinate mapping relationship between output data and input data. * read_lens: The number of data continuously loaded by each thread. */ -template +template __device__ __inline__ void ReadDataBcM1kMnk( T* dst, const T _global_ptr_* src, int thread_offset, - const details::BroadcastConfig& config, + const details::BroadcastConfig& config, int read_lens) { int index_output = thread_offset; int index_base = config(index_output); @@ -999,12 +938,12 @@ __device__ __inline__ void ReadDataBcM1kMnk( * coordinate mapping relationship between output data and input data. * read_lens: The number of data continuously loaded by each thread. */ -template +template __device__ __inline__ void ReadDataBcM1Mn( T* dst, const T _global_ptr_* src, int thread_offset, - const details::BroadcastConfig& config, + const details::BroadcastConfig& config, int read_lens) { int index_output = thread_offset; int index_base = config(index_output); @@ -1027,7 +966,6 @@ __device__ __inline__ void ReadDataBcM1Mn( * * @template paraments * T: Data type of register. - * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * * @param: * dst: The register pointer of the thread, the size is NX. @@ -1037,12 +975,12 @@ __device__ __inline__ void ReadDataBcM1Mn( * coordinate mapping relationship between output data and input data. * read_lens: The number of data continuously loaded by each thread. */ -template +template __device__ __inline__ void ReadDataBc1NMn( T* dst, const T _global_ptr_* src, int thread_offset, - const details::BroadcastConfig& config, + const details::BroadcastConfig& config, int read_lens) { int index_output = thread_offset; int index_base = config(index_output); @@ -1075,7 +1013,6 @@ __device__ __inline__ void ReadDataBc1NMn( * * @template paraments * T: Data type of register. - * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * * @param: * dst: The register pointer of the thread, the size is NX. @@ -1085,12 +1022,12 @@ __device__ __inline__ void ReadDataBc1NMn( * coordinate mapping relationship between output data and input data. * read_lens: The number of data continuously loaded by each thread. */ -template +template __device__ __inline__ void ReadDataBc1N1Mnk( T* dst, const T _global_ptr_* src, int thread_offset, - const details::BroadcastConfig& config, + const details::BroadcastConfig& config, int read_lens) { int index_output = thread_offset; int index_base = config(index_output); @@ -1130,7 +1067,6 @@ __device__ __inline__ void ReadDataBc1N1Mnk( * * @template paraments * T: Data type of register. - * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * * @param: * dst: The register pointer of the thread, the size is NX. @@ -1140,13 +1076,12 @@ __device__ __inline__ void ReadDataBc1N1Mnk( * coordinate mapping relationship between output data and input data. * read_lens: The number of data continuously loaded by each thread. */ -template -__device__ __inline__ void ReadDataBc1N( - T* dst, - const T _global_ptr_* src, - int thread_offset, - const details::BroadcastConfig& config, - int read_lens) { +template +__device__ __inline__ void ReadDataBc1N(T* dst, + const T _global_ptr_* src, + int thread_offset, + const details::BroadcastConfig& config, + int read_lens) { int index_output = thread_offset; int index_base = config(index_output); T in_temp; @@ -1174,12 +1109,12 @@ __device__ __inline__ void ReadDataBc1N( * total_num_output: Total number of original output. * read_lens: The number of data continuously loaded by each thread. */ -template +template __device__ __inline__ void ReadDataBcCanNotCmp( T* dst, const T _global_ptr_* src, int thread_offset, - const details::BroadcastConfig& config, + const details::BroadcastConfig& config, int total_num_output, int read_lens) { int index_output = thread_offset; @@ -1215,7 +1150,6 @@ __device__ __inline__ void ReadDataBcCanNotCmp( * NY: The number of data rows loaded by each thread, only NY = 1 was supported. * BlockSize: Identifies the current device thread index method. For xpu, * core_id() is used as the index. - * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. * IsBoundary: Indicates whether to perform block access storage out-of-bounds * judgment. When the number of data processed by the block is less than * NX x NY x core_num(), boundary judgment is required to avoid memory access @@ -1230,33 +1164,27 @@ __device__ __inline__ void ReadDataBcCanNotCmp( * read_lens: The number of data continuously loaded by each thread. * total_num_output: Total number of original output. */ -template -__device__ __inline__ void ReadDataBc( - T* dst, - const T _global_ptr_* src, - uint32_t block_offset, - const details::BroadcastConfig& config, - int total_num_output, - int read_lens) { +template +__device__ __inline__ void ReadDataBc(T* dst, + const T _global_ptr_* src, + uint32_t block_offset, + const details::BroadcastConfig& config, + int total_num_output, + int read_lens) { int thread_offset = block_offset + core_id() * read_lens; if (config.cmp_type == details::OptType::MNK_M1K) { - ReadDataBcM1kMnk(dst, src, thread_offset, config, read_lens); + ReadDataBcM1kMnk(dst, src, thread_offset, config, read_lens); } else if (config.cmp_type == details::OptType::N_1) { - ReadDataBc1N(dst, src, thread_offset, config, read_lens); + ReadDataBc1N(dst, src, thread_offset, config, read_lens); } else if (config.cmp_type == details::OptType::MN_M) { - ReadDataBcM1Mn(dst, src, thread_offset, config, read_lens); + ReadDataBcM1Mn(dst, src, thread_offset, config, read_lens); } else if (config.cmp_type == details::OptType::MN_N) { - ReadDataBc1NMn(dst, src, thread_offset, config, read_lens); + ReadDataBc1NMn(dst, src, thread_offset, config, read_lens); } else if (config.cmp_type == details::OptType::MNK_1N1) { - ReadDataBc1N1Mnk(dst, src, thread_offset, config, read_lens); + ReadDataBc1N1Mnk(dst, src, thread_offset, config, read_lens); } else { - ReadDataBcCanNotCmp( + ReadDataBcCanNotCmp( dst, src, thread_offset, config, total_num_output, read_lens); } } diff --git a/paddle/phi/kernels/primitive/kernel_primitives.h b/paddle/phi/kernels/primitive/kernel_primitives.h index ea5846c3a241809254655d9cbde6e3fbebfedaf9..f68a046ae077a8fa2727bf5c8ad0322aca76f209 100644 --- a/paddle/phi/kernels/primitive/kernel_primitives.h +++ b/paddle/phi/kernels/primitive/kernel_primitives.h @@ -40,7 +40,9 @@ #define GRID_NUM_X cluster_num() #define GRID_NUM_Y 0 #define GRID_NUM_Z 0 - +#define VecSizeL 512 +#define VecSizeM 256 +#define VecSizeS 128 #else #define KPStream gpuStream_t @@ -64,6 +66,9 @@ #define GRID_NUM_Y gridDim.y #define GRID_NUM_Z gridDim.z +#define VecSizeL 4 +#define VecSizeM 2 +#define VecSizeS 1 #endif // include file