未验证 提交 bcef8275 编写于 作者: 傅剑寒 提交者: GitHub

Flip Kernel Optimization (#46119)

上级 60f9c60c
......@@ -16,6 +16,7 @@
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/array.h"
......@@ -23,7 +24,7 @@
namespace phi {
template <typename T, size_t Rank>
__global__ void flip_cuda_kernel(const int N,
__global__ void flip_cuda_kernel(const int64_t N,
const T* in_data,
T* out_data,
phi::Array<int64_t, Rank> shape,
......@@ -53,41 +54,44 @@ __global__ void flip_cuda_kernel(const int N,
}
template <typename T, typename Context, size_t N>
void launch_flip_cuda_kernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
std::vector<int> flip_dims_v = axis;
void LaunchFlipCudaKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
auto* in_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
auto x_dims = x.dims();
const int total_dims = x_dims.size();
const int numel = x.numel();
int block_size = 512;
dim3 dim_block(block_size);
dim3 dim_grid((numel + block_size - 1) / block_size);
for (size_t i = 0; i < flip_dims_v.size(); ++i) {
if (flip_dims_v[i] < 0) {
flip_dims_v[i] += total_dims;
}
}
const int64_t numel = x.numel();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
auto x_stride = phi::stride(x_dims);
phi::Array<int64_t, N> stride_a;
phi::Array<int64_t, N> shape_a;
phi::Array<int, N> flip_dims_a;
size_t flip_dims_size = flip_dims_v.size();
size_t flip_dims_size = axis.size();
for (size_t idx = 0; idx < N; ++idx) {
stride_a[idx] = x_stride[idx];
shape_a[idx] = x_dims[idx];
flip_dims_a[idx] = idx < flip_dims_size ? flip_dims_v[idx] : 0;
flip_dims_a[idx] = idx < flip_dims_size ? axis[idx] : 0;
}
for (size_t i = 0; i < flip_dims_a.size(); ++i) {
if (flip_dims_a[i] < 0) {
flip_dims_a[i] += total_dims;
}
}
flip_cuda_kernel<T, N><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
numel, in_data, out_data, shape_a, stride_a, flip_dims_a, flip_dims_size);
flip_cuda_kernel<T, N>
<<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
numel,
in_data,
out_data,
shape_a,
stride_a,
flip_dims_a,
flip_dims_size);
}
template <typename T, typename Context>
......@@ -98,31 +102,31 @@ void FlipKernel(const Context& dev_ctx,
const size_t total_dims = x.dims().size();
switch (total_dims) {
case 1:
launch_flip_cuda_kernel<T, Context, 1>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 1>(dev_ctx, x, axis, out);
break;
case 2:
launch_flip_cuda_kernel<T, Context, 2>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 2>(dev_ctx, x, axis, out);
break;
case 3:
launch_flip_cuda_kernel<T, Context, 3>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 3>(dev_ctx, x, axis, out);
break;
case 4:
launch_flip_cuda_kernel<T, Context, 4>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 4>(dev_ctx, x, axis, out);
break;
case 5:
launch_flip_cuda_kernel<T, Context, 5>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 5>(dev_ctx, x, axis, out);
break;
case 6:
launch_flip_cuda_kernel<T, Context, 6>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 6>(dev_ctx, x, axis, out);
break;
case 7:
launch_flip_cuda_kernel<T, Context, 7>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 7>(dev_ctx, x, axis, out);
break;
case 8:
launch_flip_cuda_kernel<T, Context, 8>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 8>(dev_ctx, x, axis, out);
break;
case 9:
launch_flip_cuda_kernel<T, Context, 9>(dev_ctx, x, axis, out);
LaunchFlipCudaKernel<T, Context, 9>(dev_ctx, x, axis, out);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册