Unverified commit b3283f4c, authored by 傅剑寒, committed by GitHub

Optimize flip kernel by eliminating H2D data transfer, test=develop (#46046)

Parent 65bdd80b
@@ -13,126 +13,123 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/flip_kernel.h"
 
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/array.h"
 
 namespace phi {
 
-template <typename T>
+template <typename T, size_t Rank>
 __global__ void flip_cuda_kernel(const int N,
                                  const T* in_data,
                                  T* out_data,
-                                 int64_t* x_shape,
-                                 int64_t* x_stride,
-                                 int* flip_dims,
-                                 int flip_dims_size,
-                                 int total_dims) {
+                                 phi::Array<int64_t, Rank> shape,
+                                 phi::Array<int64_t, Rank> stride,
+                                 phi::Array<int, Rank> flip_dims,
+                                 int flip_dims_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx >= N) {
     return;
   }
   int cur_indices = idx, rem = 0, dst_offset = 0;
-  for (int i = 0; i < total_dims; ++i) {
+  for (int i = 0; i < Rank; ++i) {
     int64_t temp = cur_indices;
-    cur_indices = cur_indices / x_stride[i];
-    rem = temp - cur_indices * x_stride[i];
+    cur_indices = cur_indices / stride[i];
+    rem = temp - cur_indices * stride[i];
     // flip the indices if it is in flip_dims
     for (int j = 0; j < flip_dims_size; ++j) {
       if (i == flip_dims[j]) {
-        cur_indices = x_shape[i] - 1 - cur_indices;
+        cur_indices = shape[i] - 1 - cur_indices;
       }
     }
-    dst_offset += cur_indices * x_stride[i];
+    dst_offset += cur_indices * stride[i];
     cur_indices = rem;
   }
   out_data[idx] = in_data[dst_offset];
 }
 
-template <typename T, typename Context>
-void FlipKernel(const Context& dev_ctx,
-                const DenseTensor& x,
-                const std::vector<int>& axis,
-                DenseTensor* out) {
-  const auto gplace = dev_ctx.GetPlace();
-  auto cplace = phi::CPUPlace();
-  std::vector<int> flip_dims = axis;
+template <typename T, typename Context, size_t N>
+void launch_flip_cuda_kernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const std::vector<int>& axis,
+                             DenseTensor* out) {
+  std::vector<int> flip_dims_v = axis;
   auto* in_data = x.data<T>();
   auto* out_data = dev_ctx.template Alloc<T>(out);
-  const int flip_dims_size = static_cast<int>(flip_dims.size());
   auto x_dims = x.dims();
   const int total_dims = x_dims.size();
-  const int N = x.numel();
+  const int numel = x.numel();
   int block_size = 512;
   dim3 dim_block(block_size);
-  dim3 dim_grid((N + block_size - 1) / block_size);
-  for (size_t i = 0; i < flip_dims.size(); ++i) {
-    if (flip_dims[i] < 0) {
-      flip_dims[i] += total_dims;
+  dim3 dim_grid((numel + block_size - 1) / block_size);
+  for (size_t i = 0; i < flip_dims_v.size(); ++i) {
+    if (flip_dims_v[i] < 0) {
+      flip_dims_v[i] += total_dims;
     }
   }
   auto x_stride = phi::stride(x_dims);
-  std::vector<int64_t> x_dims_v = phi::vectorize(x_dims);
-  std::vector<int64_t> x_stride_v = phi::vectorize(x_stride);
-  int bytes = total_dims * sizeof(int64_t);
-  auto x_strides_array_tmp = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      bytes,
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  int64_t* x_strides_array_gpu =
-      reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
-  paddle::memory::Copy(gplace,
-                       x_strides_array_gpu,
-                       cplace,
-                       x_stride_v.data(),
-                       bytes,
-                       dev_ctx.stream());
-  auto x_shape_array_tmp = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      bytes,
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  int64_t* x_shape_array_gpu =
-      reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
-  paddle::memory::Copy(gplace,
-                       x_shape_array_gpu,
-                       cplace,
-                       x_dims_v.data(),
-                       bytes,
-                       dev_ctx.stream());
-  bytes = flip_dims_size * sizeof(int);
-  auto flip_dims_array_tmp = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      bytes,
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  int* flip_dims_array_gpu = reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
-  paddle::memory::Copy(gplace,
-                       flip_dims_array_gpu,
-                       cplace,
-                       flip_dims.data(),
-                       bytes,
-                       dev_ctx.stream());
-  flip_cuda_kernel<T>
-      <<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(N,
-                                                     in_data,
-                                                     out_data,
-                                                     x_shape_array_gpu,
-                                                     x_strides_array_gpu,
-                                                     flip_dims_array_gpu,
-                                                     flip_dims_size,
-                                                     total_dims);
+  phi::Array<int64_t, N> stride_a;
+  phi::Array<int64_t, N> shape_a;
+  phi::Array<int, N> flip_dims_a;
+  size_t flip_dims_size = flip_dims_v.size();
+  for (size_t idx = 0; idx < N; ++idx) {
+    stride_a[idx] = x_stride[idx];
+    shape_a[idx] = x_dims[idx];
+    flip_dims_a[idx] = idx < flip_dims_size ? flip_dims_v[idx] : 0;
+  }
+  flip_cuda_kernel<T, N><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
+      numel, in_data, out_data, shape_a, stride_a, flip_dims_a, flip_dims_size);
+}
+
+template <typename T, typename Context>
+void FlipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int>& axis,
+                DenseTensor* out) {
+  const size_t total_dims = x.dims().size();
+  switch (total_dims) {
+    case 1:
+      launch_flip_cuda_kernel<T, Context, 1>(dev_ctx, x, axis, out);
+      break;
+    case 2:
+      launch_flip_cuda_kernel<T, Context, 2>(dev_ctx, x, axis, out);
+      break;
+    case 3:
+      launch_flip_cuda_kernel<T, Context, 3>(dev_ctx, x, axis, out);
+      break;
+    case 4:
+      launch_flip_cuda_kernel<T, Context, 4>(dev_ctx, x, axis, out);
+      break;
+    case 5:
+      launch_flip_cuda_kernel<T, Context, 5>(dev_ctx, x, axis, out);
+      break;
+    case 6:
+      launch_flip_cuda_kernel<T, Context, 6>(dev_ctx, x, axis, out);
+      break;
+    case 7:
+      launch_flip_cuda_kernel<T, Context, 7>(dev_ctx, x, axis, out);
+      break;
+    case 8:
+      launch_flip_cuda_kernel<T, Context, 8>(dev_ctx, x, axis, out);
+      break;
+    case 9:
+      launch_flip_cuda_kernel<T, Context, 9>(dev_ctx, x, axis, out);
+      break;
+    default:
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "dims of input tensor should be less than 10, But received"
+          "%d",
+          x.dims().size()));
+  }
 }
 
 }  // namespace phi
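The gist of the change: the kernel's per-tensor metadata (shape, strides, flip axes) is small and of bounded rank, so it can travel to the device inside the kernel's argument list instead of through a paddle::memory::Alloc plus an asynchronous host-to-device memory::Copy per array, and the runtime rank becomes a compile-time template parameter via the switch in FlipKernel. The sketch below illustrates the same two ideas in plain CUDA; it is a minimal standalone example and not Paddle code, and SmallArray, reverse_last_dim, launch, and dispatch are illustrative names.

// Minimal standalone sketch, plain CUDA, no Paddle dependencies.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Trivially copyable fixed-size array, so it can be passed to a kernel by
// value (the role phi::Array plays in the diff above).
template <typename T, int N>
struct SmallArray {
  T data[N];
  __host__ __device__ T& operator[](int i) { return data[i]; }
};

// Reverses the innermost dimension of a dense row-major tensor of rank `Rank`.
// Shape and strides arrive by value in the launch arguments: no device
// allocation, no host-to-device copy.
template <int Rank>
__global__ void reverse_last_dim(const float* in,
                                 float* out,
                                 SmallArray<int64_t, Rank> shape,
                                 SmallArray<int64_t, Rank> stride,
                                 int numel) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= numel) return;
  // Decompose the linear output index into per-dimension indices, flip only
  // the innermost one, and rebuild the source offset.
  int64_t rem = idx, src = 0;
  for (int i = 0; i < Rank; ++i) {
    int64_t cur = rem / stride[i];
    rem -= cur * stride[i];
    if (i == Rank - 1) cur = shape[i] - 1 - cur;
    src += cur * stride[i];
  }
  out[idx] = in[src];
}

// Fills the by-value metadata on the host and launches the templated kernel.
template <int Rank>
void launch(const float* d_in, float* d_out, const int64_t* dims, int numel) {
  SmallArray<int64_t, Rank> shape, stride;
  int64_t s = 1;
  for (int i = Rank - 1; i >= 0; --i) {
    shape[i] = dims[i];
    stride[i] = s;
    s *= dims[i];
  }
  const int block = 256;
  const int grid = (numel + block - 1) / block;
  reverse_last_dim<Rank><<<grid, block>>>(d_in, d_out, shape, stride, numel);
}

// Runtime rank -> compile-time Rank, mirroring the switch in FlipKernel.
void dispatch(const float* d_in, float* d_out, const int64_t* dims, int rank,
              int numel) {
  switch (rank) {
    case 1: launch<1>(d_in, d_out, dims, numel); break;
    case 2: launch<2>(d_in, d_out, dims, numel); break;
    case 3: launch<3>(d_in, d_out, dims, numel); break;
    default: std::abort();  // the sketch only handles rank <= 3
  }
}

int main() {
  const int64_t dims[2] = {2, 3};
  const int numel = 6;
  float h_in[numel] = {0, 1, 2, 3, 4, 5}, h_out[numel] = {};
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, numel * sizeof(float));
  cudaMalloc(&d_out, numel * sizeof(float));
  cudaMemcpy(d_in, h_in, numel * sizeof(float), cudaMemcpyHostToDevice);
  dispatch(d_in, d_out, dims, /*rank=*/2, numel);
  cudaMemcpy(h_out, d_out, numel * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < numel; ++i) printf("%.0f ", h_out[i]);  // 2 1 0 5 4 3
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Because the metadata is copied into the kernel launch arguments, no device buffers need to be allocated or kept alive for it, and making Rank a compile-time constant lets the compiler unroll the index-decomposition loop.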