From d5afc1bac8e19cfe0c53fcace4be6d06f60a10ef Mon Sep 17 00:00:00 2001 From: shixingbo <90814748+bmb0537@users.noreply.github.com> Date: Tue, 7 Jun 2022 21:27:01 +0800 Subject: [PATCH] Optimized the performance of activation op in XPU2 (#43187) --- .../operators/optimizers/cast_with_ptr.h | 2 +- paddle/phi/kernels/funcs/elementwise_base.h | 77 +++++++++++++------ .../kernels/primitive/datamover_primitives.h | 5 +- .../primitive/datamover_primitives_xpu2.h | 25 +++--- 4 files changed, 73 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/optimizers/cast_with_ptr.h b/paddle/fluid/operators/optimizers/cast_with_ptr.h index eb031ae0c9..ec7db8537b 100644 --- a/paddle/fluid/operators/optimizers/cast_with_ptr.h +++ b/paddle/fluid/operators/optimizers/cast_with_ptr.h @@ -44,7 +44,7 @@ static void VecCastKernel(const platform::CUDADeviceContext &ctx, const InT *x, phi::Array<_ptr_ OutT *, 1> out_arr; out_arr[0] = y; phi::funcs::VectorizedElementwiseKernel - <<>>(in_arr, out_arr, n, main_offset, + <<>>(in_arr, out_arr, n, main_offset, VecSize, FunctorT()); } diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 71dfbc206a..8b5a3cf8aa 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -513,19 +513,23 @@ struct Loader { ArgsT *args, int num, int data_offset, + int read_lens, bool is_boundary) { using Type = std::tuple_element_t; - kps::Init(args, static_cast(1.0f)); + kps::Init( + args, static_cast(1.0f), read_lens); if (is_boundary) { kps::ReadData( args, reinterpret_cast(in[Index]) + data_offset, - num); + num, + read_lens); } else { kps::ReadData( args, reinterpret_cast(in[Index]) + data_offset, - num); + num, + read_lens); } } }; @@ -660,11 +664,20 @@ template struct SameDimsElementwisePrimitiveCaller { - __device__ inline void operator()(Functor func, ArgsT *args, OutT *result) { + __device__ inline void operator()(Functor func, + ArgsT *args, + OutT *result, + int read_lens) { +#ifdef PADDLE_WITH_XPU_KP + for (int idx = 0; idx < read_lens; ++idx) { + result[idx] = static_cast(Apply(func, args[idx])); + } +#else #pragma unroll for (int idx = 0; idx < VecSize; ++idx) { result[idx] = static_cast(Apply(func, args[idx])); } +#endif } }; @@ -750,6 +763,7 @@ __device__ void VectorizedElementwiseKernelImpl( phi::Array<_ptr_ OutT *, NumOuts> outs, int num, int data_offset, + int read_lens, Functor func) { using Traits = paddle::platform::FunctionTraits; using ArgsT = typename Traits::ArgsTuple; @@ -757,16 +771,16 @@ __device__ void VectorizedElementwiseKernelImpl( ConditionalT result[VecSize]; Unroller::step( - in, args, num, data_offset, IsBoundary); + in, args, num, data_offset, read_lens, IsBoundary); SameDimsElementwisePrimitiveCaller, VecSize, Functor, ArgsT, - Arity>()(func, args, result); + Arity>()(func, args, result, read_lens); - ElementwiseWriteDataCaller()( - outs, result, data_offset, num); + ElementwiseWriteDataCallerBc()( + outs, result, data_offset, num, read_lens); } template @@ -775,9 +789,10 @@ __global__ void VectorizedElementwiseKernel( phi::Array<_ptr_ OutT *, NumOuts> outs, int size, int main_offset, + int read_lens, Functor func) { - int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; - int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; + int data_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens; + int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens; for (; data_offset < main_offset; data_offset += stride) { VectorizedElementwiseKernelImpl( - ins, outs, VecSize * BLOCK_NUM_X, data_offset, func); + ins, outs, read_lens * BLOCK_NUM_X, data_offset, read_lens, func); } int num = size - data_offset; @@ -795,7 +810,8 @@ __global__ void VectorizedElementwiseKernel( Arity, NumOuts, VecSize, - true>(ins, outs, num, data_offset, func); + true>( + ins, outs, num, data_offset, read_lens, func); } } @@ -803,6 +819,7 @@ template void ElementwiseCudaKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, + int read_lens, Functor func) { auto numel = (*outs)[0]->numel(); // To avoid running errors when ins.size()== 0 @@ -817,10 +834,10 @@ void ElementwiseCudaKernel(const KPDevice &ctx, int block_size = 64; int grid_size = 8; auto stream = ctx.x_context()->xpu_stream; - int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; + int main_offset = (numel / (read_lens * block_size)) * read_lens * block_size; VectorizedElementwiseKernel <<>>( - ins_data, outs_data, numel, main_offset, func); + ins_data, outs_data, numel, main_offset, read_lens, func); #else auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize); @@ -829,7 +846,7 @@ void ElementwiseCudaKernel(const KPDevice &ctx, auto stream = ctx.stream(); VectorizedElementwiseKernel <<>>( - ins_data, outs_data, numel, main_offset, func); + ins_data, outs_data, numel, main_offset, VecSize, func); #endif } @@ -868,20 +885,32 @@ void ElementwiseKernel(const KPDevice &ctx, } } +#ifdef PADDLE_WITH_XPU_KP + const int buf_size = 256; + int numel = (*outs)[0]->numel(); + int block_size = 64; + int grid_size = 8; + int nthreads = block_size * grid_size; + int read_lens = + std::min(buf_size, kps::details::RoundUpDiv(numel, 32 * nthreads) * 32); + int vec_size = buf_size; +#else // calculate the max vec_size for all ins and outs int vec_size = GetVectorizedSizeForTensors(ins, *outs); + int read_lens = vec_size; +#endif switch (vec_size) { - case 4: - ElementwiseCudaKernel( - ctx, ins, outs, func); + case VecSizeL: + ElementwiseCudaKernel( + ctx, ins, outs, read_lens, func); break; - case 2: - ElementwiseCudaKernel( - ctx, ins, outs, func); + case VecSizeM: + ElementwiseCudaKernel( + ctx, ins, outs, read_lens, func); break; - case 1: - ElementwiseCudaKernel( - ctx, ins, outs, func); + case VecSizeS: + ElementwiseCudaKernel( + ctx, ins, outs, read_lens, func); break; default: { PADDLE_THROW(phi::errors::Unimplemented( diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h index 8b0c42c9d1..bf60d1610e 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives.h +++ b/paddle/phi/kernels/primitive/datamover_primitives.h @@ -259,7 +259,7 @@ __device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) { * it supports different data types of inputs. */ template -__device__ __forceinline__ void Init(ArgsT* dst, T init_data) { +__device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) { #pragma unroll for (int i = 0; i < NX; i++) { std::get(dst[i]) = init_data; @@ -382,7 +382,8 @@ template __device__ __forceinline__ void ReadData(ArgsT* dst, const T* __restrict__ src, - int num) { + int num, + int read_lens) { if (IsBoundary) { // blockDim.x * NX > num int thread_offset = threadIdx.x * NX; #pragma unroll diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h index 1e5dfe2a54..f2d187f89b 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h @@ -21,6 +21,8 @@ namespace phi { namespace kps { namespace details { +int RoundUpDiv(int n, int k) { return (n + k - 1) / k; } + enum class OptType { // Optimize type of calc after input shape compressed CanNotOptimize = -1, // can not optimize, broadcast first N_1, // just like {1} op {100} or {100} op {1} @@ -425,9 +427,10 @@ __device__ __inline__ void Init(T* dst, T init_data, int read_lens) { * it supports different data types of inputs. */ template -__device__ __forceinline__ void Init(ArgsT* dst, T init_data) { +__device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) { + mfence(); #pragma unroll - for (int i = 0; i < NX; i++) { + for (int i = 0; i < read_lens; i++) { std::get(dst[i]) = init_data; } } @@ -523,22 +526,24 @@ template __device__ __forceinline__ void ReadData(ArgsT* dst, const T _global_ptr_* src, - int num) { - int thread_offset = core_id() * NX; + int num, + int read_lens) { + int thread_offset = core_id() * read_lens; __local__ T in_temp[1]; __local__ T in_vec[NX]; - if (IsBoundary) { // core_num() * NX > num + if (IsBoundary) { // core_num() * read_lens > num #pragma unroll - for (int idx = 0; idx < NX; ++idx) { + for (int idx = 0; idx < read_lens; ++idx) { if (idx + thread_offset < num) { GM2LM(src + thread_offset + idx, in_temp, sizeof(T)); std::get(dst[idx]) = in_temp[0]; + mfence(); } } - } else { // core_num() * NX < num - GM2LM(src + thread_offset, in_vec, NX * sizeof(T)); + } else { // core_num() * read_lens < num + GM2LM(src + thread_offset, in_vec, read_lens * sizeof(T)); #pragma unroll - for (int idx = 0; idx < NX; ++idx) { + for (int idx = 0; idx < read_lens; ++idx) { std::get(dst[idx]) = in_vec[idx]; } } @@ -727,10 +732,12 @@ __device__ void WriteData(T _global_ptr_* dst, for (int idx = 0; idx < read_lens; ++idx) { if (idx + thread_offset < num) { in_temp[0] = src[idx]; + mfence(); LM2GM(in_temp, dst + idx + thread_offset, sizeof(T)); } } } else { // core_num() * read_lens < num + mfence(); LM2GM(src, dst + thread_offset, read_lens * sizeof(T)); } } -- GitLab