From 3eaf8d2cead9fc3d7b82c5c928c331917ea687b6 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Tue, 11 Jan 2022 19:49:01 +0800 Subject: [PATCH] Modified Kernel Primitive API and elementwise for xpu2 #38688 --- .../elementwise/elementwise_op_broadcast.cu.h | 8 +- .../elementwise/elementwise_op_impl.cu.h | 3 +- .../datamover_primitives_xpu2.h | 172 +++++++++--------- .../kernel_primitives/kernel_primitives.h | 15 +- paddle/fluid/platform/hostdevice.h | 9 +- paddle/pten/kernels/gpu/elementwise.h | 104 +++++------ 6 files changed, 164 insertions(+), 147 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 25c983566b3..e3d4607b713 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -25,8 +25,7 @@ namespace kps = paddle::operators::kernel_primitives; template void LaunchBroadcastElementwiseCudaKernel( - const platform::CUDADeviceContext &ctx, - const std::vector &ins, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { std::vector pt_inputs; std::vector pt_outputs; @@ -58,8 +57,7 @@ void LaunchBroadcastElementwiseCudaKernel( template void LaunchElementwiseCudaKernel( - const platform::CUDADeviceContext &cuda_ctx, - const std::vector &ins, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { std::vector pt_inputs; std::vector pt_outputs; @@ -85,7 +83,7 @@ void LaunchElementwiseCudaKernel( pt_outputs.push_back(pt_outputs_tmp[i].get()); } pten::LaunchElementwiseCudaKernel( - cuda_ctx, pt_inputs, &pt_outputs, axis, func); + ctx, pt_inputs, &pt_outputs, axis, func); } } // namespace operators diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 1d8acd5eca5..36ff1ae254d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -35,8 +35,7 @@ using ElementwiseType = pten::ElementwiseType; template void LaunchSameDimsElementwiseCudaKernel( - const platform::CUDADeviceContext &ctx, - const std::vector &ins, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { std::vector pt_inputs; std::vector pt_outputs; diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h index b27ba27b3c6..33389953589 100644 --- a/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h +++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h @@ -32,42 +32,50 @@ struct alignas(sizeof(T) * VecSize) VectorType { * index of the output data. if input or output shape is [dim0, dim1] then dims * must be [dim1, dim0]. */ +#pragma pack(4) template struct BroadcastConfig { - uint32_t stride_in[framework::DDim::kMaxRank]; - uint32_t stride_out[framework::DDim::kMaxRank]; - uint32_t shape_in[framework::DDim::kMaxRank]; + int strides_in[framework::DDim::kMaxRank]; + int strides_out[framework::DDim::kMaxRank]; + int in_dim[framework::DDim::kMaxRank]; HOSTDEVICE BroadcastConfig() {} HOSTDEVICE BroadcastConfig(const std::vector& out_dims, const std::vector& in_dims, int dim_size) { - std::vector strides_in; - std::vector strides_out; - std::vector shapes_in; - - strides_out.resize(dim_size, 1); - strides_in.resize(dim_size, 1); - shapes_in.resize(dim_size, 1); - - for (int i = 0; i < dim_size; ++i) { - shape_in[i] = in_dims[dim_size - i - 1]; + std::vector strides_in_tmp; + std::vector strides_out_tmp; + std::vector dim_tmp; + strides_in_tmp.resize(dim_size, 1); + strides_out_tmp.resize(dim_size, 1); + dim_tmp.resize(dim_size, 1); + for (int i = 1; i < dim_size; i++) { + strides_in_tmp[i] = strides_in_tmp[i - 1] * in_dims[i - 1]; + strides_out_tmp[i] = strides_out_tmp[i - 1] * out_dims[i - 1]; } - for (int i = 1; i < dim_size - 1; ++i) { - strides_out[dim_size - i - 1] = std::accumulate( - out_dims.begin(), out_dims.begin() + i, 1, std::multiplies()) - strides_in[dim_size - i - 1] = - std::accumulate(in_dims.begin(), in_dims.begin() + i, 1, - std::multiplies()) + for (int i = 0; i < dim_size; i++) { + dim_tmp[i] = in_dims[i]; } - memcpy(stride_in, strides_in.data(), kDims * sizeof(uint32_t)); - memcpy(stride_out, strides_out.data(), kDims * sizeof(uint32_t)); - memcpy(shape_in, shapes_in.data(), kDims * sizeof(uint32_t)); + memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int)); + memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int)); + memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int)); + } + + __device__ inline int operator()(int index_output) const { + int index_src = 0; +#pragma unroll + for (int i = kDims - 1; i >= 0; --i) { + int tmp_index = (index_output / strides_out[i]); + index_output = index_output - tmp_index * strides_out[i]; + index_src += (tmp_index % in_dim[i]) * strides_in[i]; + } + return index_src; } }; +#pragma pack() } // namespace details @@ -99,12 +107,12 @@ struct BroadcastConfig { */ template -__device__ __forceinline__ void ReadData(Ty* dst, const Tx _global_ptr_* src, - int size_nx, int size_ny, - int stride_nx, int stride_ny) { +__device__ __inline__ void ReadData(Ty* dst, const Tx _global_ptr_* src, + int size_nx, int size_ny, int stride_nx, + int stride_ny) { int thread_offset = core_id(); int left_size_nx = size_nx - thread_offset; - __local__ T in_temp[1]; + __local__ Tx in_temp[1]; // Each branch is added for better performance if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1 if (IsBoundary) { @@ -168,7 +176,7 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx _global_ptr_* src, * init_data: Initial value. */ template -__device__ __forceinline__ void Init(T* dst, T init_data) { +__device__ __inline__ void Init(T* dst, T init_data) { #pragma unroll for (int i = 0; i < NX; i++) { dst[i] = init_data; @@ -197,8 +205,8 @@ __device__ __forceinline__ void Init(T* dst, T init_data) { * size: The current block needs to load size data continuously. */ template -__device__ __forceinline__ void ReadData(T* dst, const T _global_ptr_* src, - int num) { +__device__ __inline__ void ReadData(T* dst, const T _global_ptr_* src, + int num) { int thread_offset = core_id() * NX; __local__ T in_temp[1]; if (IsBoundary) { // core_num() * NX > num @@ -241,10 +249,11 @@ __device__ __forceinline__ void ReadData(T* dst, const T _global_ptr_* src, */ template -__device__ __forceinline__ void ReadDataBc( - T* dst, const T _global_ptr_* src, uint32_t block_offset, - details::BroadcastConfig config, int total_num_output, int stride_nx, - int stride_ny) { +__device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, + uint32_t block_offset, + details::BroadcastConfig config, + int total_num_output, int stride_nx, + int stride_ny) { uint32_t thread_offset = block_offset + core_id(); uint32_t index_src = 0; __local__ T in_temp[1]; @@ -256,16 +265,11 @@ __device__ __forceinline__ void ReadDataBc( uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx; index_src = 0; if (IsBoundary) { - if (index_output >= total_num_output) { + if (index_output >= (uint32_t)total_num_output) { break; } } -#pragma unroll - for (int i = 0; i < Rank; ++i) { - uint32_t tmp = index_output / config.stride_out[i]; - index_output = index_output - tmp * config.stride_out[i]; - index_src += (tmp % config.shape_in[i]) * config.stride_in[i]; - } + index_src = config(index_output); GM2LM(src + index_src, in_temp, sizeof(T)); dst[nx + ny * NX] = in_temp[0]; } @@ -305,33 +309,34 @@ __device__ __forceinline__ void ReadDataBc( */ template -__device__ __forceinline__ void ReadDataReduce( - T* dst, const T _global_ptr_* src, int block_offset, - const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx, - int stride_ny, bool reduce_last_dim) { - __local__ T in_temp[1]; +__device__ __inline__ void ReadDataReduce(T* dst, const T _global_ptr_* src, + int block_offset, + const IndexCal& index_cal, + int size_nx, int size_ny, + int stride_nx, int stride_ny, + bool reduce_last_dim) { + __local__ Tx in_temp[1]; int thread_offset = 0; - int left_size_nx = size_nx; - int left_size_ny = size_ny; + int left_idx = 0; if (reduce_last_dim) { - thread_offset = block_offset + core_id(); - left_size_nx -= thread_offset; + thread_offset = core_id(); + left_idx = 0; } else { - thread_offset = block_offset + core_id(); - left_size_ny -= thread_offset; + thread_offset = 0; + left_idx = 0; } if (NX == 1) { #pragma unroll for (int ny = 0; ny < NY; ++ny) { if (IsBoundary) { - if (ny * stride_ny >= left_size_ny) { + if (thread_offset >= size_ny) { break; } } - uint32_t index_src = index_cal(thread_offset); - GM2LM(src + index_src, in_temp, sizeof(T)); - dst[ny] = in_temp[0]; + uint32_t index_src = index_cal(thread_offset + block_offset); + GM2LM(src + index_src, in_temp, sizeof(Tx)); + dst[ny] = static_cast(func(in_temp[0])); thread_offset += stride_ny; } } else { @@ -340,17 +345,16 @@ __device__ __forceinline__ void ReadDataReduce( #pragma unroll for (int ny = 0; ny < NY; ++ny) { if (IsBoundary) { - if ((ny * stride_ny >= left_size_ny) || - (nx * stride_nx >= left_size_nx)) { + if ((thread_offset >= size_ny) || + (left_idx + nx * stride_nx >= size_nx)) { break; } } - uint32_t index_src = index_cal(thread_offset); - GM2LM(src + index_src, in_temp, sizeof(T)); - dst[nx + ny * NX] = in_temp[0]; + uint32_t index_src = index_cal(thread_offset + block_offset); + GM2LM(src + index_src, in_temp, sizeof(Tx)); + dst[nx + ny * NX] = static_cast(func(in_temp[0])); thread_offset += stride_ny; } - thread_offset += stride_nx; } } } @@ -421,9 +425,9 @@ __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { */ template -__device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, - int size_nx, int size_ny, - int stride_nx, int stride_ny) { +__device__ __inline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, + int size_nx, int size_ny, int stride_nx, + int stride_ny) { int thread_offset = core_id(); int left_size_nx = size_nx - thread_offset; __local__ Ty in_temp[1]; @@ -433,11 +437,11 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, if (IsBoundary) { if (left_size_nx > 0) { in_temp[0] = static_cast(src[0]); - LM2GM(in_temp, dst + thread_offset, sizeof(T)); + LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); } } else { in_temp[0] = static_cast(src[0]); - LM2GM(in_temp, dst + thread_offset, sizeof(T)); + LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); } } else if (NX == 1) { #pragma unroll @@ -449,7 +453,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, } in_temp[0] = static_cast(src[idy]); - LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(T)); + LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty)); } } else if (NY == 1) { // for NY == 1 and NX != 1 #pragma unroll @@ -461,7 +465,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, } in_temp[0] = static_cast(src[idx]); - LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(T)); + LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty)); } } else { // for NX != 1 and NY != 1 #pragma unroll @@ -480,7 +484,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, } in_temp[0] = static_cast(src[idx + idy * NX]); LM2GM(in_temp, dst + thread_offset + idx * stride_nx + idy * stride_ny, - sizeof(T)); + sizeof(Ty)); } } } @@ -498,7 +502,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, * init_data: The register pointer of init data, the size is NX. */ template -__device__ __forceinline__ void Init(T* dst, T* init_data, int num) { +__device__ __inline__ void Init(T* dst, T* init_data, int num) { #pragma unroll for (int i = 0; i < NX; i++) { if (IsBoundary) { @@ -535,30 +539,26 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) { */ template -__device__ __forceinline__ void ReadDataBc( - T* dst, const T _global_ptr_* src, uint32_t block_offset, - details::BroadcastConfig config, int total_num_output) { - uint32_t thread_offset = block_offset + core_id() * NX; - uint32_t index_src = 0; - __local__ T in_temp[1]; +__device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, + uint32_t block_offset, + details::BroadcastConfig config, + int total_num_output) { + int thread_offset = block_offset + core_id() * NX; + int index_src = 0; + __local__ T in_temp; #pragma unroll - for (uint32_t nx = 0; nx < NX; ++nx) { - uint32_t index_output = thread_offset + nx; + for (int nx = 0; nx < NX; ++nx) { + int index_output = thread_offset + nx; index_src = 0; if (IsBoundary) { if (index_output >= total_num_output) { break; } } -#pragma unroll - for (int i = 0; i < Rank; ++i) { - uint32_t tmp = index_output / config.stride_out[i]; - index_output = index_output - tmp * config.stride_out[i]; - index_src += (tmp % config.shape_in[i]) * config.stride_in[i]; - } - GM2LM(src + index_src, in_temp, sizeof(T)); - dst[nx + ny * NX] = in_temp[0]; + index_src = config(index_output); + GM2LM(src + index_src, &in_temp, sizeof(T)); + dst[nx] = in_temp; } } diff --git a/paddle/fluid/operators/kernel_primitives/kernel_primitives.h b/paddle/fluid/operators/kernel_primitives/kernel_primitives.h index e20e77ae26a..558f8c81c66 100644 --- a/paddle/fluid/operators/kernel_primitives/kernel_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/kernel_primitives.h @@ -13,11 +13,18 @@ // limitations under the License. #pragma once -#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" #include "paddle/fluid/operators/kernel_primitives/helper_primitives.h" #ifdef PADDLE_WITH_XPU2 #include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h" #include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h" +#include "paddle/fluid/operators/kernel_primitives/functor_primitives_xpu2.h" + +#define KPStream XPUStream +#define KPDevice paddle::platform::XPUDeviceContext +#define _ptr_ _global_ptr_ +#define __forceinline__ __inline__ +#define __restrict__ + #define THREAD_ID_X core_id() #define THREAD_ID_Y 0 #define THREAD_ID_Z 0 @@ -36,6 +43,12 @@ #else #include "paddle/fluid/operators/kernel_primitives/compute_primitives.h" #include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h" +#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" + +#define KPStream gpuStream_t +#define KPDevice paddle::platform::CUDADeviceContext +#define _ptr_ + #define THREAD_ID_X threadIdx.x #define THREAD_ID_Y threadIdx.y #define THREAD_ID_Z threadIdx.z diff --git a/paddle/fluid/platform/hostdevice.h b/paddle/fluid/platform/hostdevice.h index 1ffbbc217e2..65005a5adbb 100644 --- a/paddle/fluid/platform/hostdevice.h +++ b/paddle/fluid/platform/hostdevice.h @@ -17,7 +17,14 @@ #include #endif -#if (defined(__CUDACC__) || defined(__HIPCC__)) +#ifdef __xpu_kp__ +#include +#include "xpu/kernel/cluster_header.h" +#include "xpu/kernel/debug.h" +#include "xpu/kernel/math.h" +#endif + +#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu_kp__)) #define HOSTDEVICE __host__ __device__ #define DEVICE __device__ #define HOST __host__ diff --git a/paddle/pten/kernels/gpu/elementwise.h b/paddle/pten/kernels/gpu/elementwise.h index f78328c01a3..e4cc894e483 100644 --- a/paddle/pten/kernels/gpu/elementwise.h +++ b/paddle/pten/kernels/gpu/elementwise.h @@ -86,7 +86,7 @@ struct ElementwisePrimitiveCaller { template struct ElementwiseWriteDataCaller { __device__ __forceinline__ void operator()( - paddle::framework::Array outs, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, ConditionalT src[VecSize], int block_offset, int num) { @@ -109,7 +109,7 @@ struct ElementwiseWriteDataCaller { template struct ElementwiseWriteDataCaller { __device__ __forceinline__ void operator()( - paddle::framework::Array outs, + paddle::framework::Array<_ptr_ OutT *, 1> outs, OutT src[VecSize], int block_offset, int num) { @@ -126,8 +126,8 @@ template __device__ void VectorizedElementwiseKernelImpl( - const paddle::framework::Array &in, - paddle::framework::Array outs, + const paddle::framework::Array &in, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, int num, int data_offset, Functor func) { @@ -161,8 +161,8 @@ template __global__ void VectorizedElementwiseKernel( - paddle::framework::Array ins, - paddle::framework::Array outs, + paddle::framework::Array ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, int size, int main_offset, Functor func) { @@ -212,17 +212,13 @@ template -void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, +void ElementwiseCudaKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { auto numel = ins[0]->numel(); - int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize); - int grid_size = - ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; - auto stream = ctx.stream(); - paddle::framework::Array ins_data; - paddle::framework::Array outs_data; + paddle::framework::Array ins_data; + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < Arity; ++i) { ins_data[i] = ins[i]->data(); @@ -231,8 +227,9 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, outs_data[i] = (*outs)[i]->mutable_data(); } #ifdef PADDLE_WITH_XPU2 - block_size = 128; - grid_size = 8; + int block_size = 64; + int grid_size = 8; + auto stream = ctx.x_context()->xpu_stream; int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; VectorizedElementwiseKernel<<>>( ins_data, outs_data, numel, main_offset, func); #else + int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize); + int grid_size = + ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; + auto stream = ctx.stream(); VectorizedElementwiseKernel void LaunchSameDimsElementwiseCudaKernel( - const paddle::platform::CUDADeviceContext &ctx, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { @@ -471,12 +472,12 @@ struct DimensionsTransform { template __device__ __forceinline__ void LoadData( T *dst, - const T *__restrict__ src, + const _ptr_ T *src, uint32_t block_offset, const kps::details::BroadcastConfig &config, int numel, int num, - bool need_broadcast) { + int need_broadcast) { // numel : whole num of output // num: how many data will be deal with in this time if (need_broadcast) { @@ -496,9 +497,9 @@ template __device__ void ElementwiseBroadcastKernelImpl( - const paddle::framework::Array &ins, - paddle::framework::Array outs, - const paddle::framework::Array &use_broadcast, + const paddle::framework::Array &ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + const paddle::framework::Array &use_broadcast, uint32_t numel, const paddle::framework::Array, Arity> &configs, @@ -540,9 +541,9 @@ template __global__ void ElementwiseBroadcastKernel( - paddle::framework::Array ins, - paddle::framework::Array outs, - paddle::framework::Array use_broadcast, + paddle::framework::Array ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + paddle::framework::Array use_broadcast, uint32_t numel, paddle::framework::Array, Arity> configs, @@ -570,7 +571,8 @@ __global__ void ElementwiseBroadcastKernel( block_offset, func); } - if (block_offset < numel) { + int num = numel - block_offset; + if (num > 0) { ElementwiseBroadcastKernelImpl( - ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func); + ins, outs, use_broadcast, numel, configs, num, block_offset, func); } #else if (block_offset < main_offset) { @@ -619,23 +621,16 @@ template -void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx, +void LaunchKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func, DimensionsTransform merge_dims) { int numel = (*outs)[0]->numel(); - const int threads = 256; - int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; - - int main_offset = (numel / (VecSize * threads)) * VecSize * threads; - int tail_tid = numel % (VecSize * threads); - auto stream = ctx.stream(); - paddle::framework::Array, Arity> configs; - paddle::framework::Array use_broadcast; - paddle::framework::Array ins_data; - paddle::framework::Array outs_data; + paddle::framework::Array use_broadcast; + paddle::framework::Array ins_data; + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < NumOuts; ++i) { outs_data[i] = (*outs)[i]->mutable_data(); @@ -643,7 +638,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx, for (int i = 0; i < Arity; i++) { use_broadcast[i] = (ins[i]->numel() != numel); - ins_data[i] = ins[i]->data(); + ins_data[i] = (_ptr_ InT *)(ins[i]->data()); if (use_broadcast[i]) { // get the broadcast config, // if data shape is[m, n], then you should set data_dim = {n, m} @@ -654,10 +649,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx, } #ifdef PADDLE_WITH_XPU2 - threads = 128; - blocks = 8; - main_offset = (numel / (VecSize * threads)) * VecSize * threads; - tail_tid = numel % (VecSize * threads); + const int threads = 64; + const int blocks = 8; + int main_offset = (numel / (VecSize * threads)) * VecSize * threads; + int tail_tid = numel % (VecSize * threads); + auto stream = ctx.x_context()->xpu_stream; ElementwiseBroadcastKernel void LaunchBroadcastKernelForDifferentVecSize( - const paddle::platform::CUDADeviceContext &ctx, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, @@ -737,7 +738,7 @@ template void LaunchBroadcastElementwiseCudaKernel( - const paddle::platform::CUDADeviceContext &ctx, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, @@ -835,12 +836,11 @@ template -void LaunchElementwiseCudaKernel( - const paddle::platform::CUDADeviceContext &cuda_ctx, - const std::vector &ins, - std::vector *outs, - int axis, - Functor func) { +void LaunchElementwiseCudaKernel(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + int axis, + Functor func) { std::vector dims_size; bool no_broadcast_flag = true; for (auto *in : ins) { @@ -849,14 +849,14 @@ void LaunchElementwiseCudaKernel( } if (no_broadcast_flag) { LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, outs, func); + ctx, ins, outs, func); } else { axis = axis == -1 ? *std::max_element(dims_size.begin(), dims_size.end()) - *std::min_element(dims_size.begin(), dims_size.end()) : axis; LaunchBroadcastElementwiseCudaKernel( - cuda_ctx, ins, outs, axis, func); + ctx, ins, outs, axis, func); } } -- GitLab