From bcef8275a4650365d4edff69a7670ac2cc9c2b72 Mon Sep 17 00:00:00 2001
From: 傅剑寒
Date: Tue, 20 Sep 2022 02:17:06 -0700
Subject: [PATCH] Flip Kernel Optimization (#46119)

---
 paddle/phi/kernels/gpu/flip_kernel.cu | 66 ++++++++++++++-------------
 1 file changed, 35 insertions(+), 31 deletions(-)

diff --git a/paddle/phi/kernels/gpu/flip_kernel.cu b/paddle/phi/kernels/gpu/flip_kernel.cu
index 6bcc3d6ff4e..6e9dbf37a91 100644
--- a/paddle/phi/kernels/gpu/flip_kernel.cu
+++ b/paddle/phi/kernels/gpu/flip_kernel.cu
@@ -16,6 +16,7 @@
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/array.h"
@@ -23,7 +24,7 @@
 namespace phi {
 
 template <typename T, size_t Rank>
-__global__ void flip_cuda_kernel(const int N,
+__global__ void flip_cuda_kernel(const int64_t N,
                                  const T* in_data,
                                  T* out_data,
                                  phi::Array<int64_t, Rank> shape,
@@ -53,41 +54,44 @@ __global__ void flip_cuda_kernel(const int N,
 }
 
 template <typename T, typename Context, size_t N>
-void launch_flip_cuda_kernel(const Context& dev_ctx,
-                             const DenseTensor& x,
-                             const std::vector<int>& axis,
-                             DenseTensor* out) {
-  std::vector<int> flip_dims_v = axis;
+void LaunchFlipCudaKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const std::vector<int>& axis,
+                          DenseTensor* out) {
   auto* in_data = x.data<T>();
   auto* out_data = dev_ctx.template Alloc<T>(out);
 
   auto x_dims = x.dims();
   const int total_dims = x_dims.size();
-  const int numel = x.numel();
-
-  int block_size = 512;
-  dim3 dim_block(block_size);
-  dim3 dim_grid((numel + block_size - 1) / block_size);
-
-  for (size_t i = 0; i < flip_dims_v.size(); ++i) {
-    if (flip_dims_v[i] < 0) {
-      flip_dims_v[i] += total_dims;
-    }
-  }
-
+  const int64_t numel = x.numel();
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
   auto x_stride = phi::stride(x_dims);
   phi::Array<int64_t, N> stride_a;
   phi::Array<int64_t, N> shape_a;
   phi::Array<int, N> flip_dims_a;
-  size_t flip_dims_size = flip_dims_v.size();
+  size_t flip_dims_size = axis.size();
+
   for (size_t idx = 0; idx < N; ++idx) {
     stride_a[idx] = x_stride[idx];
    shape_a[idx] = x_dims[idx];
-    flip_dims_a[idx] = idx < flip_dims_size ? flip_dims_v[idx] : 0;
+    flip_dims_a[idx] = idx < flip_dims_size ? axis[idx] : 0;
+  }
+
+  for (size_t i = 0; i < flip_dims_a.size(); ++i) {
+    if (flip_dims_a[i] < 0) {
+      flip_dims_a[i] += total_dims;
+    }
   }
 
-  flip_cuda_kernel<T, N><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
-      numel, in_data, out_data, shape_a, stride_a, flip_dims_a, flip_dims_size);
+  flip_cuda_kernel<T, N>
+      <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
+          numel,
+          in_data,
+          out_data,
+          shape_a,
+          stride_a,
+          flip_dims_a,
+          flip_dims_size);
 }
 
 template <typename T, typename Context>
@@ -98,31 +102,31 @@ void FlipKernel(const Context& dev_ctx,
   const size_t total_dims = x.dims().size();
   switch (total_dims) {
     case 1:
-      launch_flip_cuda_kernel<T, Context, 1>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 1>(dev_ctx, x, axis, out);
       break;
     case 2:
-      launch_flip_cuda_kernel<T, Context, 2>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 2>(dev_ctx, x, axis, out);
       break;
     case 3:
-      launch_flip_cuda_kernel<T, Context, 3>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 3>(dev_ctx, x, axis, out);
       break;
     case 4:
-      launch_flip_cuda_kernel<T, Context, 4>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 4>(dev_ctx, x, axis, out);
       break;
     case 5:
-      launch_flip_cuda_kernel<T, Context, 5>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 5>(dev_ctx, x, axis, out);
       break;
     case 6:
-      launch_flip_cuda_kernel<T, Context, 6>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 6>(dev_ctx, x, axis, out);
       break;
     case 7:
-      launch_flip_cuda_kernel<T, Context, 7>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 7>(dev_ctx, x, axis, out);
       break;
     case 8:
-      launch_flip_cuda_kernel<T, Context, 8>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 8>(dev_ctx, x, axis, out);
       break;
     case 9:
-      launch_flip_cuda_kernel<T, Context, 9>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 9>(dev_ctx, x, axis, out);
       break;
     default:
      PADDLE_THROW(phi::errors::InvalidArgument(
-- 
GitLab
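
The substance of this patch is the launch-configuration change: the hand-rolled
512-thread grid math is replaced by phi's GetGpuLaunchConfig1D helper, and the
element count is widened from int to int64_t. The following is a minimal sketch
of that pattern in plain CUDA C++ (placed after the patch trailer, where git am
ignores it). Only the member names block_per_grid and thread_per_block are taken
from the launch site in the diff above; the struct layout, the 256-thread
default, and the free-function signature (the real helper in
paddle/phi/backends/gpu/gpu_launch_config.h also takes a device context) are
assumptions for illustration, not phi's actual implementation.

    #include <cstdint>
    #include <cuda_runtime.h>

    // Minimal stand-in for phi's GpuLaunchConfig (assumed layout; only the
    // two member names below are confirmed by the launch site in the diff).
    struct GpuLaunchConfig {
      dim3 thread_per_block;
      dim3 block_per_grid;
    };

    // Sketch of a 1D launch-config heuristic: pick a block size, then
    // ceil-divide the element count into blocks. The 256-thread default is
    // an assumption, not phi's exact policy.
    inline GpuLaunchConfig GetGpuLaunchConfig1D(int64_t numel) {
      const int threads = 256;
      const int64_t blocks = (numel + threads - 1) / threads;
      GpuLaunchConfig config;
      config.thread_per_block = dim3(threads);
      config.block_per_grid = dim3(static_cast<unsigned int>(blocks));
      return config;
    }

A launch then reads like the patched call site:

    auto config = GetGpuLaunchConfig1D(numel);
    flip_cuda_kernel<T, N>
        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(/*...*/);

The practical gain over the removed dim3 dim_grid((numel + 512 - 1) / 512) is
centralization: a shared helper can clamp the grid to device limits and choose
block sizes per architecture, while the int64_t count avoids overflow for
tensors with more than 2^31 elements.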