From 4acc87beb2110e9966327dbd427e0cdc17e05841 Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Thu, 1 Apr 2021 11:26:07 +0800
Subject: [PATCH] Optimize the perf of SameDimsAdd CUDA Kernel (#31872)

---
 .../elementwise/elementwise_add_op.cu         | 88 +++++++++++++------
 .../elementwise/elementwise_div_op.cu         |  2 +-
 .../elementwise/elementwise_mul_op.cu         |  2 +-
 .../elementwise/elementwise_op_function.cu.h  | 86 ++++++++++++------
 .../elementwise/elementwise_sub_op.cu         |  2 +-
 5 files changed, 125 insertions(+), 55 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index 8de6416065..68fd81f826 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -24,7 +24,10 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
+struct SameDimsElemwiseAdd<
+    platform::CUDADeviceContext, T,
+    typename std::enable_if<!std::is_same<T, platform::float16>::value &&
+                            !std::is_same<T, float>::value>::type> {
   void operator()(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
@@ -36,38 +39,68 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
   }
 };
 
-template <>
-struct SameDimsElemwiseAdd<platform::CUDADeviceContext, platform::float16> {
+template <typename T>
+struct SameDimsElemwiseAdd<
+    platform::CUDADeviceContext, T,
+    typename std::enable_if<std::is_same<T, platform::float16>::value ||
+                            std::is_same<T, float>::value>::type> {
   void operator()(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
     auto size = x->numel();
-    dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
-                              PADDLE_CUDA_THREAD_SIZE,
-                          1);
+    int vec_size = sizeof(float4) / sizeof(T);
+    dim3 grid_size =
+        dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) /
+                 PADDLE_CUDA_THREAD_SIZE,
+             1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
-    const half* x2 =
-        reinterpret_cast<const half*>(x->data<platform::float16>());
-    const half* y2 =
-        reinterpret_cast<const half*>(y->data<platform::float16>());
-    half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
-    SameDimsElemwiseAddCUDAKernel<<<
-        grid_size, block_size, 0,
-        ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
-        x2, y2, z2, size);
+    if (std::is_same<T, float>::value) {
+      SameDimsElemwiseAddCUDAKernel<<<
+          grid_size, block_size, 0,
+          ctx.template device_context<platform::CUDADeviceContext>()
+              .stream()>>>(x->data<float>(), y->data<float>(), z->data<float>(),
+                           size);
+    } else {
+      const half* x2 =
+          reinterpret_cast<const half*>(x->data<platform::float16>());
+      const half* y2 =
+          reinterpret_cast<const half*>(y->data<platform::float16>());
+      half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
+      SameDimsElemwiseAddCUDAKernel<<<
+          grid_size, block_size, 0,
+          ctx.template device_context<platform::CUDADeviceContext>()
+              .stream()>>>(x2, y2, z2, size);
+    }
   }
 };
 
 template <typename T>
-static __global__ void SimpleElemwiseAddGradCUDAKernel(const T* dout,
-                                                       int64_t size, T* dx,
-                                                       T* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
+static __global__ void SimpleElemwiseAddGradCUDAKernel(
+    const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = gridDim.x * blockDim.x;
+  int loop = size / vec_size;
+  int remainder = size % vec_size;
+  const float4* dout_vec = reinterpret_cast<const float4*>(dout);
+  float4* dx_vec = reinterpret_cast<float4*>(dx);
+  float4* dy_vec = reinterpret_cast<float4*>(dy);
+  float4 tmp_loop;
+
+  for (int i = tid; i < loop; i += stride) {
+    tmp_loop = dout_vec[i];
+    dx_vec[i] = tmp_loop;
+    dy_vec[i] = tmp_loop;
+  }
 
-  while (col < size) {
-    dx[col] = dout[col];
-    dy[col] = dout[col];
-    col += blockDim.x * gridDim.x;
+  if (tid == loop && remainder != 0) {
+    T tmp_rem;
+    while (remainder) {
+      int idx = size - remainder;
+      remainder--;
+      tmp_rem = dout[idx];
+      dx[idx] = tmp_rem;
+      dy[idx] = tmp_rem;
+    }
   }
 }
 
@@ -79,14 +112,17 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
                      const framework::Tensor* out,
                      const framework::Tensor* dout, framework::Tensor* dx,
                      framework::Tensor* dy) {
-  dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
   auto size = x->numel();
+  int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
+  dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
   dim3 grid_size =
-      dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
+      dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) /
+               PADDLE_CUDA_THREAD_SIZE,
+           1);
   SimpleElemwiseAddGradCUDAKernel<
       T><<<grid_size, block_size, 0,
            ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
-      dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
+      dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()),
       dy->mutable_data<T>(ctx.GetPlace()));
 }
 
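
A note on the access pattern above: SimpleElemwiseAddGradCUDAKernel now moves whole 16-byte float4 packets per thread, and the sub-packet tail is handled by the single thread whose tid equals loop; that thread always exists when remainder > 0, because the launch covers ceil(size / vec_size) work items. Stripped of the framework plumbing, the scheme reduces to the sketch below (a minimal standalone sketch with illustrative names, float-only, not part of the patch; the float4 casts assume 16-byte-aligned buffers, which cudaMalloc guarantees):

    // One float4 load feeds two float4 stores; thread `loop` copies the
    // remaining size % 4 scalars.
    __global__ void VecCopyTwice(const float* __restrict__ dout, int size,
                                 float* dx, float* dy) {
      int tid = blockIdx.x * blockDim.x + threadIdx.x;
      int stride = gridDim.x * blockDim.x;
      int loop = size / 4;       // complete float4 packets
      int remainder = size % 4;  // leftover scalars
      const float4* dout_vec = reinterpret_cast<const float4*>(dout);
      float4* dx_vec = reinterpret_cast<float4*>(dx);
      float4* dy_vec = reinterpret_cast<float4*>(dy);
      for (int i = tid; i < loop; i += stride) {
        float4 tmp = dout_vec[i];  // one 16-byte load ...
        dx_vec[i] = tmp;           // ... feeds two 16-byte stores
        dy_vec[i] = tmp;
      }
      if (tid == loop && remainder != 0) {
        for (int idx = size - remainder; idx < size; ++idx) {
          dx[idx] = dout[idx];
          dy[idx] = dout[idx];
        }
      }
    }
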
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
index 96583d0657..0cf9294c9d 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -43,7 +43,7 @@ struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
     auto size = x->numel();
-    dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
+    dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
                               PADDLE_CUDA_THREAD_SIZE,
                           1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
index 5b598ab2d7..e01b5eb5fb 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -43,7 +43,7 @@ struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
     auto size = x->numel();
-    dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
+    dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
                               PADDLE_CUDA_THREAD_SIZE,
                           1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
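
The constants changed in the div and mul launches above are the fp16 instance of the same vec_size arithmetic: one float4 carries sizeof(float4) / sizeof(half) = 16 / 2 = 8 half values, so the grid only has to cover (size + 7) / 8 vector packets instead of (size + 1) / 2 half2 pairs. Written out as a host-side helper (hypothetical names, not part of the patch):

    #include <cstdint>

    // ceil(size / vec_size) packets, then ceil(packets / threads) blocks.
    // T = half:  vec_size = 16 / 2 = 8  ->  (size + 7) / 8 packets
    // T = float: vec_size = 16 / 4 = 4  ->  (size + 3) / 4 packets
    inline int VecGridSize(int64_t size, int vec_size, int threads) {
      int64_t packets = (size + vec_size - 1) / vec_size;
      return static_cast<int>((packets + threads - 1) / threads);
    }

For example, 100000 fp16 elements give 12500 packets and, at 512 threads per block, a 25-block grid, roughly a quarter of the 98 blocks the old half2 scheme launched.
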
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
index 6d5dcc4dd6..8344b3d983 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
@@ -18,7 +18,11 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/hostdevice.h"
+#ifdef __HIPCC__
+#define PADDLE_CUDA_THREAD_SIZE 256
+#else
 #define PADDLE_CUDA_THREAD_SIZE 512
+#endif
 
 #ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
@@ -158,32 +162,62 @@ inline DEVICE half2 half2_div(const half2& a, const half2& b) {
 #endif
 }
 
-#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function)           \
-  template <typename T>                                                      \
-  __global__ void SameDimsElemwise##Func##CUDAKernel(const T* x, const T* y, \
-                                                     T* z, int64_t size) {   \
-    int col = blockIdx.x * blockDim.x + threadIdx.x;                         \
-    while (col < size) {                                                     \
-      z[col] = x[col] expr y[col];                                           \
-      col += blockDim.x * gridDim.x;                                         \
-    }                                                                        \
-  }                                                                          \
-  template <>                                                                \
-  inline __global__ void SameDimsElemwise##Func##CUDAKernel(                 \
-      const half* x, const half* y, half* z, int64_t size) {                 \
-    int start = threadIdx.x + blockDim.x * blockIdx.x;                       \
-    int stride = blockDim.x * gridDim.x;                                     \
-    int n2 = size / 2;                                                       \
-    const half2* x2 = reinterpret_cast<const half2*>(x);                     \
-    const half2* y2 = reinterpret_cast<const half2*>(y);                     \
-    half2* z2 = reinterpret_cast<half2*>(z);                                 \
-    for (int i = start; i < n2; i += stride) {                               \
-      z2[i] = FP16Function(x2[i], y2[i]);                                    \
-    }                                                                        \
-    if (start == 0 && (size % 2)) {                                          \
-      z[size - 1] = __float2half(__half2float(x[size - 1])                   \
-                                     expr __half2float(y[size - 1]));        \
-    }                                                                        \
+#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function)           \
+  inline __global__ void SameDimsElemwise##Func##CUDAKernel(                 \
+      const float* __restrict__ x, const float* __restrict__ y, float* z,    \
+      int64_t size) {                                                        \
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;                         \
+    int stride = gridDim.x * blockDim.x;                                     \
+    int loop = size / 4;                                                     \
+    int remainder = size % 4;                                                \
+    const float4* x_vec = reinterpret_cast<const float4*>(x);                \
+    const float4* y_vec = reinterpret_cast<const float4*>(y);                \
+    float4* z_vec = reinterpret_cast<float4*>(z);                            \
+    float4 x_f4, y_f4;                                                       \
+    for (int i = tid; i < loop; i += stride) {                               \
+      x_f4 = x_vec[i];                                                       \
+      y_f4 = y_vec[i];                                                       \
+      z_vec[i] = make_float4(x_f4.x expr y_f4.x, x_f4.y expr y_f4.y,         \
+                             x_f4.z expr y_f4.z, x_f4.w expr y_f4.w);        \
+    }                                                                        \
+    if (tid == loop && remainder != 0) {                                     \
+      while (remainder) {                                                    \
+        int idx = size - remainder;                                          \
+        remainder--;                                                         \
+        z[idx] = x[idx] expr y[idx];                                         \
+      }                                                                      \
+    }                                                                        \
+  }                                                                          \
+  inline __global__ void SameDimsElemwise##Func##CUDAKernel(                 \
+      const half* __restrict__ x, const half* __restrict__ y, half* z,       \
+      int64_t size) {                                                        \
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;                         \
+    int stride = gridDim.x * blockDim.x;                                     \
+    int loop = size / 8;                                                     \
+    int remainder = size % 8;                                                \
+    const float4* x_vec = reinterpret_cast<const float4*>(x);                \
+    const float4* y_vec = reinterpret_cast<const float4*>(y);                \
+    float4* z_vec = reinterpret_cast<float4*>(z);                            \
+    float4 x_h8, y_h8, z_h8;                                                 \
+    for (int i = tid; i < loop; i += stride) {                               \
+      x_h8 = x_vec[i];                                                       \
+      y_h8 = y_vec[i];                                                       \
+      half2* x_h2 = reinterpret_cast<half2*>(&x_h8);                         \
+      half2* y_h2 = reinterpret_cast<half2*>(&y_h8);                         \
+      half2* z_h2 = reinterpret_cast<half2*>(&z_h8);                         \
+      z_h2[0] = FP16Function(x_h2[0], y_h2[0]);                              \
+      z_h2[1] = FP16Function(x_h2[1], y_h2[1]);                              \
+      z_h2[2] = FP16Function(x_h2[2], y_h2[2]);                              \
+      z_h2[3] = FP16Function(x_h2[3], y_h2[3]);                              \
+      z_vec[i] = z_h8;                                                       \
+    }                                                                        \
+    if (tid == loop && remainder != 0) {                                     \
+      while (remainder) {                                                    \
+        int idx = size - remainder;                                          \
+        remainder--;                                                         \
+        z[idx] = __float2half(__half2float(x[idx]) expr __half2float(y[idx])); \
+      }                                                                      \
+    }                                                                        \
   }
 DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Add, +, half2_add)
 DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Sub, -, half2_sub)
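
The fp16 branch of the rewritten macro above is the core trick: eight halves travel as one float4, and the 16-byte register is reinterpreted as four half2 pairs so packed fp16 intrinsics (half2_add and friends) can process two elements per instruction. A self-contained sketch of the same idea using CUDA's __hadd2 directly (illustrative names, compute capability 5.3+ required; the FP16Function wrappers above fall back to float math on older GPUs):

    #include <cuda_fp16.h>

    // Eight halves per thread: one float4 load per input, four packed adds,
    // one float4 store; thread `loop` handles the scalar tail.
    __global__ void Half8Add(const half* __restrict__ x,
                             const half* __restrict__ y, half* z,
                             int64_t size) {
      int tid = blockIdx.x * blockDim.x + threadIdx.x;
      int stride = gridDim.x * blockDim.x;
      int loop = size / 8;
      int remainder = size % 8;
      const float4* x_vec = reinterpret_cast<const float4*>(x);
      const float4* y_vec = reinterpret_cast<const float4*>(y);
      float4* z_vec = reinterpret_cast<float4*>(z);
      for (int i = tid; i < loop; i += stride) {
        float4 x8 = x_vec[i], y8 = y_vec[i], z8;
        half2* x2 = reinterpret_cast<half2*>(&x8);
        half2* y2 = reinterpret_cast<half2*>(&y8);
        half2* z2 = reinterpret_cast<half2*>(&z8);
    #pragma unroll
        for (int j = 0; j < 4; ++j) {
          z2[j] = __hadd2(x2[j], y2[j]);  // two fp16 adds per instruction
        }
        z_vec[i] = z8;
      }
      if (tid == loop && remainder != 0) {  // scalar tail, as in the macro
        for (int idx = size - remainder; idx < size; ++idx) {
          z[idx] = __hadd(x[idx], y[idx]);
        }
      }
    }

Each thread then issues 128-bit memory transactions instead of 16-bit ones, which is what the rewrite buys on these memory-bound kernels.
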
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
index 1996cc471a..192999fd2a 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
@@ -43,7 +43,7 @@ struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
     auto size = x->numel();
-    dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
+    dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
                               PADDLE_CUDA_THREAD_SIZE,
                           1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
-- 
GitLab