diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 5834f091d9a4de02afe7488ededc0189ae6f21d0..85c371e9f9d450c55741b901eff6f102fa6c3f6f 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -14,8 +14,8 @@
 #pragma once

-// CUDA and HIP use same api
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+// CUDA, XPU and HIP use same api
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || defined(__xpu__)

 #include
 #include
@@ -220,7 +220,7 @@ struct IndexCalculator {
   phi::Array dims;
   phi::Array strides;
   phi::Array reduce_strides;
-#ifndef PADDLE_WITH_XPU2
+#ifndef PADDLE_WITH_XPU_KP
   phi::Array divmoders;
 #endif
 };
@@ -231,81 +231,65 @@ struct ReduceIndexMapping {
   HOSTDEVICE explicit ReduceIndexMapping(const kps::DimConfig& dims)
       : dim(dims) {}

+#ifdef PADDLE_WITH_XPU_KP
   __device__ __forceinline__ int BlockIdX() {
-#ifdef PADDLE_WITH_XPU2
     if (ReduceLastDim) {
       return (cluster_id() / dim.split_num_x % dim.split_num_y);
     } else {
       return cluster_id() % dim.split_num_x;
     }
-#else
-    return blockIdx.x;
-#endif
   }

   __device__ __forceinline__ int BlockIdY() {
-#ifdef PADDLE_WITH_XPU2
     if (ReduceLastDim) {
       return (cluster_id() % dim.split_num_x);
     } else {
       return (cluster_id() / dim.split_num_x % dim.split_num_y);
     }
-#else
-    return blockIdx.y;
-#endif
   }

-  __device__ __forceinline__ int BlockDimX() {
-#ifdef PADDLE_WITH_XPU2
-    return dim.deal_size_x;
-#else
-    return blockDim.x;
-#endif
-  }
+  __device__ __forceinline__ int BlockDimX() { return dim.deal_size_x; }

-  __device__ __forceinline__ int BlockDimY() {
-#ifdef PADDLE_WITH_XPU2
-    return 1;
-#else
-    return blockDim.y;
-#endif
-  }
+  __device__ __forceinline__ int BlockDimY() { return 1; }

   __device__ __forceinline__ int GridDimX() {
-#ifdef PADDLE_WITH_XPU2
     if (ReduceLastDim) {
       return dim.split_num_y;
     } else {
       return dim.split_num_x;
     }
-#else
-    return gridDim.x;
-#endif
   }

   __device__ __forceinline__ int GridDimY() {
-#ifdef PADDLE_WITH_XPU2
     if (ReduceLastDim) {
       return dim.split_num_x;
     } else {
       return dim.split_num_y;
     }
-#else
-    return gridDim.y;
-#endif
   }

   __device__ __forceinline__ int GetLoopSize() {
-#ifdef PADDLE_WITH_XPU2
     if (ReduceLastDim) {
       return dim.deal_size_y;
     } else {
       return dim.deal_size_x;
     }
+  }
 #else
-    return 1;
+  __device__ __forceinline__ int BlockIdX() { return blockIdx.x; }
+
+  __device__ __forceinline__ int BlockIdY() { return blockIdx.y; }
+
+  __device__ __forceinline__ int BlockDimX() { return blockDim.x; }
+
+  __device__ __forceinline__ int BlockDimY() { return blockDim.y; }
+
+  __device__ __forceinline__ int GridDimX() { return gridDim.x; }
+
+  __device__ __forceinline__ int GridDimY() { return gridDim.y; }
+
+  __device__ int GetLoopSize() { return 1; }
 #endif
-  }
 };

 // when reduce_type == kReduceLastDim this struct will be used
@@ -341,7 +325,7 @@ struct ReduceConfig {
   // when should_reduce_again is true, we need malloc temp space for temp data
   void SetOutputData(Ty* y_data,
-                     const phi::GPUContext& dev_ctx,
+                     const KPDevice& dev_ctx,
                      phi::DenseTensor* tmp) {
     if (should_reduce_again) {
       tmp->Resize(phi::make_ddim(
@@ -640,9 +624,7 @@ struct ReduceConfig {
   int blocking_size;
   bool should_reduce_again;
   bool reduce_last_dim;
-
   Ty* output_data;
-
   dim3 block;
   dim3 grid;
 };
@@ -770,9 +752,10 @@ __global__ void ReduceAnyKernel(const Tx* x,
     kps::Reduce(
         &reduce_var, &reduce_var, reducer, reduce_last_dim);
-    if (need_store) {
-      y[store_offset + i] = static_cast<Ty>(reduce_var);
-    }
+
+    Ty result = static_cast<Ty>(reduce_var);
+    kps::details::WriteData(
+        y + store_offset + i, &result, static_cast<int>(need_store));
   }
 }
@@ -882,30 +865,18 @@ static void LaunchReduceKernel(const Tx* x_data,
     dim.SetRem(config.reduce_num % config.block.x, 0, 0);

 #ifdef PADDLE_WITH_XPU_KP
-    ReduceAnyKernel<Tx,
-                    Ty,
-                    MPType,
-                    ReduceOp,
-                    TransformOp,
-                    OneDimIndexCal><<<8, 64, 0, stream>>>(
-        x_data,
-        config.output_data,
-        reducer,
-        transform,
-        init,
-        config.reduce_num,
-        config.left_num,
-        config.reduce_last_dim,
-        reduce_index_calculator,
-        left_index_calculator,
-        dim);
+    auto grid_num = 8;
+    auto block_num = 64;
 #else
+    auto grid_num = config.grid;
+    auto block_num = config.block;
+#endif
     ReduceAnyKernel<Tx,
                     Ty,
                     MPType,
                     ReduceOp,
                     TransformOp,
-                    OneDimIndexCal><<<config.grid, config.block, 0, stream>>>(
+                    OneDimIndexCal><<<grid_num, block_num, 0, stream>>>(
         x_data,
         config.output_data,
         reducer,
@@ -917,7 +888,6 @@ static void LaunchReduceKernel(const Tx* x_data,
         reduce_index_calculator,
         left_index_calculator,
         dim);
-#endif

   } else {
     int reduce_rank = config.reduce_strides.size();
@@ -938,30 +908,18 @@ static void LaunchReduceKernel(const Tx* x_data,
     dim.SetRem(config.reduce_num % config.block.x, 0, 0);

 #ifdef PADDLE_WITH_XPU_KP
-    ReduceAnyKernel<Tx,
-                    Ty,
-                    MPType,
-                    ReduceOp,
-                    TransformOp,
-                    IndexCalculator><<<8, 64, 0, stream>>>(
-        x_data,
-        config.output_data,
-        reducer,
-        transform,
-        init,
-        config.reduce_num,
-        config.left_num,
-        config.reduce_last_dim,
-        reduce_index_calculator,
-        left_index_calculator,
-        dim);
+    auto grid_num = 8;
+    auto block_num = 64;
 #else
+    auto grid_num = config.grid;
+    auto block_num = config.block;
+#endif
     ReduceAnyKernel<Tx,
                     Ty,
                     MPType,
                     ReduceOp,
                     TransformOp,
-                    IndexCalculator><<<config.grid, config.block, 0, stream>>>(
+                    IndexCalculator><<<grid_num, block_num, 0, stream>>>(
         x_data,
         config.output_data,
         reducer,
@@ -973,7 +931,6 @@ static void LaunchReduceKernel(const Tx* x_data,
         reduce_index_calculator,
         left_index_calculator,
         dim);
-#endif
   }

   if (config.should_reduce_again) {
@@ -993,22 +950,9 @@ static void LaunchReduceKernel(const Tx* x_data,
         kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0);
     dim.SetRem(config.left_num % block.x, 0, 0);
 #ifdef PADDLE_WITH_XPU_KP
-    ReduceHigherDimKernel<
-        Ty,
-        Ty,
-        MPType,
-        ReduceOp,
-        kps::IdentityFunctor<Ty>><<<8, 64, 0, stream>>>(
-        config.output_data,
-        y_data,
-        reducer,
-        kps::IdentityFunctor<Ty>(),
-        init,
-        config.grid.y,
-        config.left_num,
-        config.grid.y,
-        dim);
-#else
+    grid = 8;
+    block = 64;
+#endif
     ReduceHigherDimKernel<
         Ty,
         Ty,
@@ -1024,7 +968,6 @@ static void LaunchReduceKernel(const Tx* x_data,
         config.left_num,
         config.grid.y,
         dim);
-#endif
   }
 }
@@ -1038,7 +981,7 @@ CubTensorReduceImpl(const Tx* x_data,
                     Ty* y_data,
                     const TransformOp& transform,
                     int reduce_num,
-                    const phi::GPUContext& dev_ctx,
+                    const KPDevice& dev_ctx,
                     KPStream stream) {
   auto reducer = ReduceOp();
   cub::TransformInputIterator trans_x(x_data,
@@ -1077,7 +1020,7 @@ CubTensorReduceImpl(const Tx* x_data,
                     Ty* y_data,
                     const TransformOp& transform,
                     int reduce_num,
-                    const phi::GPUContext& dev_ctx,
+                    const KPDevice& dev_ctx,
                     KPStream stream) {
   PADDLE_THROW(phi::errors::InvalidArgument(
       "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
@@ -1087,12 +1030,16 @@ template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
           typename TransformOp>
-void ReduceKernel(const phi::GPUContext& dev_ctx,
+void ReduceKernel(const KPDevice& dev_ctx,
                   const phi::DenseTensor& x,
                   phi::DenseTensor* y,
                   const TransformOp& transform,
                   const std::vector<int>& origin_reduce_dims) {
+#ifdef PADDLE_WITH_XPU_KP
+  auto stream = dev_ctx.x_context()->xpu_stream;
+#else
   auto stream = dev_ctx.stream();
+#endif
   dev_ctx.Alloc(y);

   auto x_dim = phi::vectorize(x.dims());
@@ -1149,11 +1096,17 @@ void ReduceKernel(const phi::GPUContext& dev_ctx,
         0);

 #ifdef PADDLE_WITH_XPU_KP
+    auto grid_num = 8;
+    auto block_num = 64;
+#else
+    auto grid_num = config.grid;
+    auto block_num = config.block;
+#endif
     ReduceHigherDimKernel<Tx,
                           Ty,
                           MPType,
                           ReduceOp<MPType>,
-                          TransformOp><<<8, 64, 0, stream>>>(
+                          TransformOp><<<grid_num, block_num, 0, stream>>>(
         x_data,
         config.output_data,
         reducer,
@@ -1163,23 +1116,6 @@ void ReduceKernel(const phi::GPUContext& dev_ctx,
         config.left_num,
         config.blocking_size,
         dim);
-#else
-    ReduceHigherDimKernel<
-        Tx,
-        Ty,
-        MPType,
-        ReduceOp<MPType>,
-        TransformOp><<<config.grid, config.block, 0, stream>>>(
-        x_data,
-        config.output_data,
-        reducer,
-        transform,
-        reducer.initial(),
-        config.reduce_num,
-        config.left_num,
-        config.blocking_size,
-        dim);
-#endif

     if (config.should_reduce_again) {
       dim3 block = dim3(config.block.x, 1, 1);
@@ -1189,22 +1125,9 @@ void ReduceKernel(const phi::GPUContext& dev_ctx,
      dim2.SetRem(config.left_num % config.block.x, 0, 0);

 #ifdef PADDLE_WITH_XPU_KP
-      ReduceHigherDimKernel<
-          Ty,
-          Ty,
-          MPType,
-          ReduceOp<MPType>,
-          kps::IdentityFunctor<Ty>><<<8, 64, 0, stream>>>(
-          config.output_data,
-          y_data,
-          reducer,
-          kps::IdentityFunctor<Ty>(config.grid.y),
-          reducer.initial(),
-          config.grid.y,
-          config.left_num,
-          config.grid.y,
-          dim2);
-#else
+      grid = 8;
+      block = 64;
+#endif
       ReduceHigherDimKernel<
           Ty,
           Ty,
@@ -1220,7 +1143,6 @@ void ReduceKernel(const phi::GPUContext& dev_ctx,
          config.left_num,
          config.grid.y,
          dim2);
-#endif
    }
    return;
  }
diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h
index 2f1e2f589c5122987d9776700f3aa7bd95daa7a5..1d4181f3b9a89509ada2a8fe27d584a9b5aa039c 100644
--- a/paddle/phi/kernels/primitive/datamover_primitives.h
+++ b/paddle/phi/kernels/primitive/datamover_primitives.h
@@ -115,6 +115,14 @@ struct BroadcastConfig {
   }
 };

+template <typename T>
+__device__ __forceinline__ void WriteData(T* dst,
+                                          T* __restrict__ src,
+                                          int num) {
+  for (int i = 0; i < num; i++) {
+    dst[i] = src[i];
+  }
+}
 #undef INT_BITS

 }  // namespace details
diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
index 53a8b7d0c9ef9489056ab293d97e5767b23531fe..d2cfdbdec3064c8e9cf20d101afc2adf0ed011a8 100644
--- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
@@ -76,6 +76,16 @@ struct BroadcastConfig {
 };
 #pragma pack()

+template <typename T>
+__device__ __forceinline__ void WriteData(T* _global_ptr_ dst,
+                                          T* src,
+                                          int num) {
+  if (num > 0) {
+    LM2GM(src, dst, num * sizeof(T));
+  }
+}
+#undef INT_BITS
+
 }  // namespace details

 /**
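
Illustrative note, not part of the patch: the new kps::details::WriteData(dst, src, num) primitive lets the reduce kernel replace the old "if (need_store) { y[...] = ...; }" branch with an unconditional call whose element count carries the predicate (num == 0 means "skip the store"), which maps to a plain loop on CUDA/HIP and to an LM2GM copy on XPU KP. The standalone CUDA sketch below demonstrates only that calling pattern; the kernel name StoreEveryOther, its predicate, and the host driver are hypothetical, and the WriteData body simply mirrors the GPU variant added above.

// Minimal sketch, assuming a CUDA toolchain; compile with: nvcc writedata_demo.cu
#include <cstdio>
#include <cuda_runtime.h>

namespace details {
// Copies `num` elements from src to dst; with num == 0 the call is a no-op,
// so a "should I store?" flag can be folded into the count instead of branching.
template <typename T>
__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src, int num) {
  for (int i = 0; i < num; i++) {
    dst[i] = src[i];
  }
}
}  // namespace details

// Hypothetical kernel: every thread computes a value, but only even-indexed
// threads actually store it, expressed through WriteData's count argument.
__global__ void StoreEveryOther(const float* in, float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;
  bool need_store = (idx % 2 == 0);   // stand-in for the kernel's need_store flag
  float result = in[idx] * 2.0f;      // stand-in for the reduced value
  details::WriteData<float>(out + idx, &result, static_cast<int>(need_store));
}

int main() {
  const int n = 8;
  float h_in[n], h_out[n] = {0};
  for (int i = 0; i < n; ++i) h_in[i] = static_cast<float>(i);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, n * sizeof(float));

  StoreEveryOther<<<1, n>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

  // Odd slots stay zero because their WriteData count was 0.
  for (int i = 0; i < n; ++i) printf("%d: %.1f\n", i, h_out[i]);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}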