From 79ee6d63694f7f8252408fb5533bd56b1f7f3e6e Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Mon, 22 Nov 2021 09:01:18 +0800
Subject: [PATCH] modified the elementwise_op_broadcast and elementwise_op_impl
 for xpu2 (#37226)

* modified the elementwise_op_broadcast and elementwise_op_impl for xpu2
---
 .../elementwise/elementwise_broadcast.cu.h    | 125 +++++++++++++-----
 .../elementwise/elementwise_no_broadcast.cu.h |  49 ++++---
 2 files changed, 128 insertions(+), 46 deletions(-)

diff --git a/paddle/pten/kernels/functions/cuda/elementwise/elementwise_broadcast.cu.h b/paddle/pten/kernels/functions/cuda/elementwise/elementwise_broadcast.cu.h
index 2e5ea5fa481..40d3cf60f09 100644
--- a/paddle/pten/kernels/functions/cuda/elementwise/elementwise_broadcast.cu.h
+++ b/paddle/pten/kernels/functions/cuda/elementwise/elementwise_broadcast.cu.h
@@ -196,7 +196,7 @@ template <typename InT,
           int VecSize,
           int Rank,
           bool IsBoundary = false>
-__device__ void DealSegment(
+__device__ void ElementwiseBroadcastKernelImpl(
     const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
     OutT *out,
     const paddle::framework::Array<bool, Arity> &use_broadcast,
@@ -204,12 +204,11 @@ __device__ void DealSegment(
     const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
         &configs,
     int num,
+    int block_offset,
     Functor func) {
   InT args[Arity][VecSize];
   OutT result[VecSize];
 
-  int block_offset = blockIdx.x * blockDim.x * VecSize;
-
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
     kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
@@ -240,27 +239,73 @@ template <typename InT,
           int Arity,
           int VecSize,
           int Rank>
-__global__ void BroadcastKernel(
+__global__ void ElementwiseBroadcastKernel(
     paddle::framework::Array<const InT *__restrict__, Arity> ins,
     OutT *out,
     paddle::framework::Array<bool, Arity> use_broadcast,
     uint32_t numel,
     paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
         configs,
-    int main_tid,
+    int main_offset,
     int tail_tid,
     Functor func) {
-  int block_offset = blockIdx.x * blockDim.x * VecSize;
-  // data offset of this block
-  if (blockIdx.x < main_tid) {
-    int num = blockDim.x * VecSize;  // blockIdx.x < main_tid
-    pten::DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>(
-        ins, out, use_broadcast, numel, configs, num, func);
-  } else {  // reminder
-    int num = tail_tid;
-    pten::DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>(
-        ins, out, use_broadcast, numel, configs, num, func);
+  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+#ifdef PADDLE_WITH_XPU2
+  for (; block_offset < main_offset; block_offset += stride) {
+    ElementwiseBroadcastKernelImpl<InT,
+                                   OutT,
+                                   Functor,
+                                   Arity,
+                                   VecSize,
+                                   Rank,
+                                   false>(ins,
+                                          out,
+                                          use_broadcast,
+                                          numel,
+                                          configs,
+                                          BLOCK_NUM_X * VecSize,
+                                          block_offset,
+                                          func);
+  }
+  if (block_offset < numel) {
+    ElementwiseBroadcastKernelImpl<InT,
+                                   OutT,
+                                   Functor,
+                                   Arity,
+                                   VecSize,
+                                   Rank,
+                                   true>(
+        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
+
+#else
+  if (block_offset < main_offset) {
+    ElementwiseBroadcastKernelImpl<InT,
+                                   OutT,
+                                   Functor,
+                                   Arity,
+                                   VecSize,
+                                   Rank,
+                                   false>(ins,
+                                          out,
+                                          use_broadcast,
+                                          numel,
+                                          configs,
+                                          BLOCK_NUM_X * VecSize,
+                                          block_offset,
+                                          func);
+  } else {
+    ElementwiseBroadcastKernelImpl<InT,
+                                   OutT,
+                                   Functor,
+                                   Arity,
+                                   VecSize,
+                                   Rank,
+                                   true>(
+        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+  }
+#endif
 }
 
 template <typename InT,
@@ -298,20 +343,40 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
           merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
  }
-
-  BroadcastKernel<InT,
-                  OutT,
-                  Functor,
-                  Arity,
-                  VecSize,
-                  Rank><<<blocks, threads, 0, stream>>>(ins_data,
-                                                        out_data,
-                                                        use_broadcast,
-                                                        numel,
-                                                        configs,
-                                                        main_tid,
-                                                        tail_tid,
-                                                        func);
+#ifdef PADDLE_WITH_XPU2
+  threads = 128;
+  blocks = 8;
+  main_offset = (numel / (VecSize * threads)) * VecSize * threads;
+  tail_tid = numel % (VecSize * threads);
+  ElementwiseBroadcastKernel<InT,
+                             OutT,
+                             Functor,
+                             Arity,
+                             VecSize,
+                             Rank><<<blocks, threads, stream>>>(ins_data,
+                                                                out_data,
+                                                                use_broadcast,
+                                                                numel,
+                                                                configs,
+                                                                main_offset,
+                                                                tail_tid,
+                                                                func);
+#else
+  ElementwiseBroadcastKernel<InT,
+                             OutT,
+                             Functor,
+                             Arity,
+                             VecSize,
+                             Rank><<<blocks, threads, 0, stream>>>(
+      ins_data,
+      out_data,
+      use_broadcast,
+      numel,
+      configs,
+      main_offset,
+      tail_tid,
+      func);
+#endif
 }
 
 template <typename InT,
diff --git a/paddle/pten/kernels/functions/cuda/elementwise/elementwise_no_broadcast.cu.h b/paddle/pten/kernels/functions/cuda/elementwise/elementwise_no_broadcast.cu.h
index 10142ba0a37..4eaf8867fd0 100644
--- a/paddle/pten/kernels/functions/cuda/elementwise/elementwise_no_broadcast.cu.h
+++ b/paddle/pten/kernels/functions/cuda/elementwise/elementwise_no_broadcast.cu.h
@@ -57,16 +57,15 @@ template <typename InT,
           int Arity,
           int VecSize,
           bool IsBoundary>
-__device__ void DealSegment(
+__device__ void VectorizedElementwiseKernelImpl(
     const paddle::framework::Array<const InT *__restrict__, Arity> &in,
     OutT *out,
     int num,
+    int data_offset,
     Functor func) {
   InT args[Arity][VecSize];
   OutT result[VecSize];
 
-  int data_offset = VecSize * blockIdx.x * blockDim.x;
-
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
     kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
@@ -87,18 +86,23 @@ __device__ void DealSegment(
 }
 
 template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
-__global__ void ElementVectorizeKernel(
+__global__ void VectorizedElementwiseKernel(
     paddle::framework::Array<const InT *__restrict__, Arity> ins,
     OutT *out,
     int size,
+    int main_offset,
     Functor func) {
-  int data_offset = VecSize * blockIdx.x * blockDim.x;
+  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  for (; data_offset < main_offset; data_offset += stride) {
+    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, false>(
+        ins, out, VecSize * BLOCK_NUM_X, data_offset, func);
+  }
+
   int num = size - data_offset;
-  // the num this time have to deal with
-  if (VecSize * blockDim.x > num) {  // reminder segment
-    DealSegment<InT, OutT, Functor, Arity, VecSize, true>(ins, out, num, func);
-  } else {  // complete segment
-    DealSegment<InT, OutT, Functor, Arity, VecSize, false>(ins, out, num, func);
+  if (num > 0) {
+    VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, true>(
+        ins, out, num, data_offset, func);
   }
 }
 
@@ -132,12 +136,25 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
   for (int i = 0; i < Arity; i++) {
     ins_data[i] = ins[i]->data<InT>();
   }
-  ElementVectorizeKernel<InT,
-                         OutT,
-                         Functor,
-                         Arity,
-                         VecSize><<<grid_size, block_size, 0, stream>>>(
-      ins_data, out_data, numel, func);
+#ifdef PADDLE_WITH_XPU2
+  block_size = 128;
+  grid_size = 8;
+  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
+  VectorizedElementwiseKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              VecSize><<<grid_size, block_size, stream>>>(
+      ins_data, out_data, numel, main_offset, func);
+#else
+  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
+  VectorizedElementwiseKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              VecSize><<<grid_size, block_size, 0, stream>>>(
+      ins_data, out_data, numel, main_offset, func);
+#endif
 }
 
 template <typename InT, typename OutT, typename Functor>
-- 
GitLab
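
Note on the pattern (a sketch, not code from the patch): both files are reworked around the same split. A grid-stride loop covers the first main_offset elements in full BLOCK_NUM_X * VecSize chunks with no bounds checks, and the remaining numel % (VecSize * threads) elements (tail_tid) are handled with per-element guards via the IsBoundary = true instantiation. The BLOCK_ID_X / BLOCK_NUM_X / GRID_NUM_X macros stand for blockIdx.x / blockDim.x / gridDim.x on CUDA, which lets the same source also build for XPU2, where the patch fixes the launch shape at 8 blocks of 128 threads. The self-contained CUDA sketch below illustrates the split; AddFunctor, VectorizedKernel, the float element type, and the launch parameters are illustrative assumptions rather than Paddle code.

// Illustrative stand-in for the patch's kernels; names and launch shape
// are assumptions for this sketch, not code taken from Paddle.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kVecSize = 4;  // elements per thread per loop iteration

struct AddFunctor {
  __device__ float operator()(float a, float b) const { return a + b; }
};

// Same main_offset/tail split as the patch: [0, main_offset) is covered in
// full blockDim.x * kVecSize chunks with no bounds checks; the remainder
// (size - main_offset, less than one chunk) is guarded per element, which
// is what the IsBoundary = true instantiation does in the patch.
template <typename Functor>
__global__ void VectorizedKernel(const float *x, const float *y, float *out,
                                 int size, int main_offset, Functor func) {
  int chunk = blockDim.x * kVecSize;  // elements one block covers per pass
  int offset = blockIdx.x * chunk;    // this block's first chunk
  int stride = gridDim.x * chunk;     // grid-wide stride

  // Main part: every index below main_offset is in range by construction.
  for (; offset < main_offset; offset += stride) {
    int idx = offset + threadIdx.x * kVecSize;
#pragma unroll
    for (int i = 0; i < kVecSize; ++i) {
      out[idx + i] = func(x[idx + i], y[idx + i]);
    }
  }

  // Tail: main_offset is a multiple of chunk, so exactly one block leaves
  // the loop with offset == main_offset; every other block lands past the
  // end of the data, making num <= 0 for it.
  int num = size - offset;
  if (num > 0) {
    for (int i = threadIdx.x; i < num; i += blockDim.x) {
      out[offset + i] = func(x[offset + i], y[offset + i]);
    }
  }
}

int main() {
  const int n = 10000;  // deliberately not a multiple of 128 * kVecSize
  float *x, *y, *out;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) {
    x[i] = 1.0f;
    y[i] = static_cast<float>(i);
  }

  // Fixed launch shape mirroring the patch's XPU2 branch (8 x 128); the
  // grid-stride loop makes the kernel correct for any n under this shape.
  const int threads = 128, blocks = 8;
  int main_offset = (n / (kVecSize * threads)) * kVecSize * threads;
  VectorizedKernel<<<blocks, threads>>>(x, y, out, n, main_offset, AddFunctor());
  cudaDeviceSynchronize();

  printf("out[0] = %.1f, out[%d] = %.1f\n", out[0], n - 1, out[n - 1]);
  cudaFree(x);
  cudaFree(y);
  cudaFree(out);
  return 0;
}

Note that the patch's CUDA branch keeps the old one-chunk-per-block structure (an if/else on block_offset rather than a loop), while the XPU2 branch must loop because its grid is fixed at 8 blocks regardless of numel; the sketch above uses the looping form, which subsumes both.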