From 45c7d905646a3a6f06aecdc4d6fd020a30551582 Mon Sep 17 00:00:00 2001 From: JamesLim <61349199+JamesLim-sy@users.noreply.github.com> Date: Wed, 10 Mar 2021 08:44:56 +0800 Subject: [PATCH] Optimization of elementwise CUDA kernel (#30801) --- .../elementwise/elementwise_op_function.h | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 923611143a3..c69baadb3c2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -99,6 +99,7 @@ inline void get_mid_dims(const framework::DDim &x_dims, (*post) *= x_dims[i]; } } + inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim, const int *index_array) { int index_ = 0; @@ -202,12 +203,16 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x, #if defined(__NVCC__) || defined(__HIPCC__) template -__global__ void ElementwiseKernel(const T *x, const T *y, OutType *out, int pre, - int n, int post, int total, Functor func) { +__global__ void ElementwiseKernel(const T *__restrict__ x_data, + const T *__restrict__ y_data, + OutType *__restrict__ out_data, int n, + int post, const size_t total, Functor func) { int tid = threadIdx.x + blockDim.x * blockIdx.x; - int idx = tid / post % n; - if (tid < total) { - out[tid] = func(x[tid], y[idx]); + int stride = blockDim.x * gridDim.x; + + for (int i = tid; i < total; i += stride) { + int idx = i / post % n; + out_data[i] = func(x_data[i], y_data[idx]); } } @@ -224,14 +229,16 @@ void ComputeElementwiseCUDA(const framework::Tensor *x, int numel = pre * n * post; int threads = 256; int blocks = (numel + threads - 1) / threads; + if (is_xsize_larger) { ElementwiseKernel<<>>( - x_data, y_data, out_data, pre, n, post, numel, func); + x_data, y_data, out_data, n, post, numel, func); + } else { ElementwiseKernel<<>>( - y_data, x_data, out_data, pre, n, post, numel, func); + y_data, x_data, out_data, n, post, numel, func); } } -- GitLab