Unverified commit bcef8275, authored by 傅剑寒, committed by GitHub

Flip Kernel Optimization (#46119)

Parent: 60f9c60c
@@ -16,6 +16,7 @@
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/array.h"
@@ -23,7 +24,7 @@
 namespace phi {

 template <typename T, size_t Rank>
-__global__ void flip_cuda_kernel(const int N,
+__global__ void flip_cuda_kernel(const int64_t N,
                                  const T* in_data,
                                  T* out_data,
                                  phi::Array<int64_t, Rank> shape,
@@ -53,41 +54,44 @@ __global__ void flip_cuda_kernel(const int N,
 }

 template <typename T, typename Context, size_t N>
-void launch_flip_cuda_kernel(const Context& dev_ctx,
-                             const DenseTensor& x,
-                             const std::vector<int>& axis,
-                             DenseTensor* out) {
-  std::vector<int> flip_dims_v = axis;
+void LaunchFlipCudaKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const std::vector<int>& axis,
+                          DenseTensor* out) {
   auto* in_data = x.data<T>();
   auto* out_data = dev_ctx.template Alloc<T>(out);
   auto x_dims = x.dims();
   const int total_dims = x_dims.size();
-  const int numel = x.numel();
-  int block_size = 512;
-  dim3 dim_block(block_size);
-  dim3 dim_grid((numel + block_size - 1) / block_size);
-  for (size_t i = 0; i < flip_dims_v.size(); ++i) {
-    if (flip_dims_v[i] < 0) {
-      flip_dims_v[i] += total_dims;
-    }
-  }
+  const int64_t numel = x.numel();
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
   auto x_stride = phi::stride(x_dims);
   phi::Array<int64_t, N> stride_a;
   phi::Array<int64_t, N> shape_a;
   phi::Array<int, N> flip_dims_a;
-  size_t flip_dims_size = flip_dims_v.size();
+  size_t flip_dims_size = axis.size();
   for (size_t idx = 0; idx < N; ++idx) {
     stride_a[idx] = x_stride[idx];
     shape_a[idx] = x_dims[idx];
-    flip_dims_a[idx] = idx < flip_dims_size ? flip_dims_v[idx] : 0;
+    flip_dims_a[idx] = idx < flip_dims_size ? axis[idx] : 0;
   }
+  for (size_t i = 0; i < flip_dims_a.size(); ++i) {
+    if (flip_dims_a[i] < 0) {
+      flip_dims_a[i] += total_dims;
+    }
+  }
-  flip_cuda_kernel<T, N><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
-      numel, in_data, out_data, shape_a, stride_a, flip_dims_a, flip_dims_size);
+  flip_cuda_kernel<T, N>
+      <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
+          numel,
+          in_data,
+          out_data,
+          shape_a,
+          stride_a,
+          flip_dims_a,
+          flip_dims_size);
 }

 template <typename T, typename Context>
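Aside on the hunk above: the old launcher hard-coded 512-thread blocks (`dim3 dim_block(512)`), while the new one asks `phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel)` for the grid/block shape and launches with `config.block_per_grid` / `config.thread_per_block`. The standalone CUDA sketch below mirrors that pattern without the Paddle build tree; `MakeLaunchConfig1D` and `scale_kernel` are hypothetical illustrations, not phi code.

```cpp
#include <cuda_runtime.h>

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for a 1-D launch-config helper: cap the block size at
// the device limit (or 512) and size the grid from the element count.
struct LaunchConfig1D {
  dim3 block_per_grid;
  dim3 thread_per_block;
};

LaunchConfig1D MakeLaunchConfig1D(int64_t numel, int device = 0) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  int threads = prop.maxThreadsPerBlock < 512 ? prop.maxThreadsPerBlock : 512;
  int64_t blocks = (numel + threads - 1) / threads;
  return {dim3(static_cast<unsigned int>(blocks)), dim3(threads)};
}

// Trivial elementwise kernel, used only to demonstrate the launch pattern.
__global__ void scale_kernel(const int64_t n, const float* in, float* out) {
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * in[i];
}

int main() {
  const int64_t n = 1 << 20;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));

  // Old pattern: a fixed 512-thread block.
  dim3 dim_block(512), dim_grid((n + 512 - 1) / 512);
  scale_kernel<<<dim_grid, dim_block>>>(n, in, out);

  // New pattern: grid/block derived from the element count and the device.
  LaunchConfig1D config = MakeLaunchConfig1D(n);
  scale_kernel<<<config.block_per_grid, config.thread_per_block>>>(n, in, out);

  cudaDeviceSynchronize();
  std::printf("launched with %u blocks of %u threads\n",
              config.block_per_grid.x, config.thread_per_block.x);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```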
@@ -98,31 +102,31 @@ void FlipKernel(const Context& dev_ctx,
   const size_t total_dims = x.dims().size();
   switch (total_dims) {
     case 1:
-      launch_flip_cuda_kernel<T, Context, 1>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 1>(dev_ctx, x, axis, out);
       break;
     case 2:
-      launch_flip_cuda_kernel<T, Context, 2>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 2>(dev_ctx, x, axis, out);
       break;
     case 3:
-      launch_flip_cuda_kernel<T, Context, 3>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 3>(dev_ctx, x, axis, out);
       break;
     case 4:
-      launch_flip_cuda_kernel<T, Context, 4>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 4>(dev_ctx, x, axis, out);
       break;
     case 5:
-      launch_flip_cuda_kernel<T, Context, 5>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 5>(dev_ctx, x, axis, out);
       break;
     case 6:
-      launch_flip_cuda_kernel<T, Context, 6>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 6>(dev_ctx, x, axis, out);
       break;
     case 7:
-      launch_flip_cuda_kernel<T, Context, 7>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 7>(dev_ctx, x, axis, out);
       break;
     case 8:
-      launch_flip_cuda_kernel<T, Context, 8>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 8>(dev_ctx, x, axis, out);
       break;
     case 9:
-      launch_flip_cuda_kernel<T, Context, 9>(dev_ctx, x, axis, out);
+      LaunchFlipCudaKernel<T, Context, 9>(dev_ctx, x, axis, out);
       break;
     default:
       PADDLE_THROW(phi::errors::InvalidArgument(
......
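The body of `flip_cuda_kernel` itself falls inside the elided lines above, so for orientation here is a host-side sketch (an assumption about the general technique, not code copied from the kernel) of the strided index mapping a flip of this kind performs per element: split the linear index into per-axis coordinates using the strides, mirror the coordinate on every flipped axis, and rebuild the destination offset. `FlipOffset` is an illustrative name, not a function in the file.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Illustrative only: map a flattened index of a row-major tensor to the
// offset of its mirror image along the flipped axes.
template <size_t Rank>
int64_t FlipOffset(int64_t idx,
                   const std::array<int64_t, Rank>& shape,
                   const std::array<int64_t, Rank>& stride,
                   const std::array<int, Rank>& flip_dims,
                   int flip_dims_size) {
  int64_t remaining = idx;
  int64_t dst_offset = 0;
  for (size_t i = 0; i < Rank; ++i) {
    int64_t coord = remaining / stride[i];  // coordinate along axis i
    remaining -= coord * stride[i];
    for (int j = 0; j < flip_dims_size; ++j) {
      if (static_cast<int>(i) == flip_dims[j]) {
        coord = shape[i] - 1 - coord;       // mirror along a flipped axis
      }
    }
    dst_offset += coord * stride[i];
  }
  return dst_offset;
}

int main() {
  // A 2 x 3 row-major tensor flipped along axis 1: element (r, c) maps to (r, 2 - c).
  std::array<int64_t, 2> shape{2, 3}, stride{3, 1};
  std::array<int, 2> flip{1, 0};  // only the first entry is used (flip_dims_size = 1)
  for (int64_t i = 0; i < 6; ++i) {
    std::printf("%lld -> %lld\n",
                static_cast<long long>(i),
                static_cast<long long>(FlipOffset<2>(i, shape, stride, flip, 1)));
  }
  return 0;
}
```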