未验证 提交 b3283f4c 编写于 作者: 傅剑寒 提交者: GitHub

Optimize flip kernel by eliminating H2D data transfer, test=develop (#46046)

上级 65bdd80b
......@@ -13,126 +13,123 @@
// limitations under the License.
#include "paddle/phi/kernels/flip_kernel.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/array.h"
namespace phi {
template <typename T>
template <typename T, size_t Rank>
__global__ void flip_cuda_kernel(const int N,
const T* in_data,
T* out_data,
int64_t* x_shape,
int64_t* x_stride,
int* flip_dims,
int flip_dims_size,
int total_dims) {
phi::Array<int64_t, Rank> shape,
phi::Array<int64_t, Rank> stride,
phi::Array<int, Rank> flip_dims,
int flip_dims_size) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int cur_indices = idx, rem = 0, dst_offset = 0;
for (int i = 0; i < total_dims; ++i) {
for (int i = 0; i < Rank; ++i) {
int64_t temp = cur_indices;
cur_indices = cur_indices / x_stride[i];
rem = temp - cur_indices * x_stride[i];
cur_indices = cur_indices / stride[i];
rem = temp - cur_indices * stride[i];
// flip the indices if it is in flip_dims
for (int j = 0; j < flip_dims_size; ++j) {
if (i == flip_dims[j]) {
cur_indices = x_shape[i] - 1 - cur_indices;
cur_indices = shape[i] - 1 - cur_indices;
}
}
dst_offset += cur_indices * x_stride[i];
dst_offset += cur_indices * stride[i];
cur_indices = rem;
}
out_data[idx] = in_data[dst_offset];
}
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
template <typename T, typename Context, size_t N>
void launch_flip_cuda_kernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
const auto gplace = dev_ctx.GetPlace();
auto cplace = phi::CPUPlace();
std::vector<int> flip_dims = axis;
std::vector<int> flip_dims_v = axis;
auto* in_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
const int flip_dims_size = static_cast<int>(flip_dims.size());
auto x_dims = x.dims();
const int total_dims = x_dims.size();
const int N = x.numel();
const int numel = x.numel();
int block_size = 512;
dim3 dim_block(block_size);
dim3 dim_grid((N + block_size - 1) / block_size);
dim3 dim_grid((numel + block_size - 1) / block_size);
for (size_t i = 0; i < flip_dims.size(); ++i) {
if (flip_dims[i] < 0) {
flip_dims[i] += total_dims;
for (size_t i = 0; i < flip_dims_v.size(); ++i) {
if (flip_dims_v[i] < 0) {
flip_dims_v[i] += total_dims;
}
}
auto x_stride = phi::stride(x_dims);
std::vector<int64_t> x_dims_v = phi::vectorize(x_dims);
std::vector<int64_t> x_stride_v = phi::vectorize(x_stride);
int bytes = total_dims * sizeof(int64_t);
auto x_strides_array_tmp = paddle::memory::Alloc(
dev_ctx.GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int64_t* x_strides_array_gpu =
reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
paddle::memory::Copy(gplace,
x_strides_array_gpu,
cplace,
x_stride_v.data(),
bytes,
dev_ctx.stream());
auto x_shape_array_tmp = paddle::memory::Alloc(
dev_ctx.GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int64_t* x_shape_array_gpu =
reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
paddle::memory::Copy(gplace,
x_shape_array_gpu,
cplace,
x_dims_v.data(),
bytes,
dev_ctx.stream());
bytes = flip_dims_size * sizeof(int);
auto flip_dims_array_tmp = paddle::memory::Alloc(
dev_ctx.GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int* flip_dims_array_gpu = reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
paddle::memory::Copy(gplace,
flip_dims_array_gpu,
cplace,
flip_dims.data(),
bytes,
dev_ctx.stream());
phi::Array<int64_t, N> stride_a;
phi::Array<int64_t, N> shape_a;
phi::Array<int, N> flip_dims_a;
size_t flip_dims_size = flip_dims_v.size();
for (size_t idx = 0; idx < N; ++idx) {
stride_a[idx] = x_stride[idx];
shape_a[idx] = x_dims[idx];
flip_dims_a[idx] = idx < flip_dims_size ? flip_dims_v[idx] : 0;
}
flip_cuda_kernel<T, N><<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(
numel, in_data, out_data, shape_a, stride_a, flip_dims_a, flip_dims_size);
}
flip_cuda_kernel<T>
<<<dim_grid, dim_block, 0, dev_ctx.stream()>>>(N,
in_data,
out_data,
x_shape_array_gpu,
x_strides_array_gpu,
flip_dims_array_gpu,
flip_dims_size,
total_dims);
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
const size_t total_dims = x.dims().size();
switch (total_dims) {
case 1:
launch_flip_cuda_kernel<T, Context, 1>(dev_ctx, x, axis, out);
break;
case 2:
launch_flip_cuda_kernel<T, Context, 2>(dev_ctx, x, axis, out);
break;
case 3:
launch_flip_cuda_kernel<T, Context, 3>(dev_ctx, x, axis, out);
break;
case 4:
launch_flip_cuda_kernel<T, Context, 4>(dev_ctx, x, axis, out);
break;
case 5:
launch_flip_cuda_kernel<T, Context, 5>(dev_ctx, x, axis, out);
break;
case 6:
launch_flip_cuda_kernel<T, Context, 6>(dev_ctx, x, axis, out);
break;
case 7:
launch_flip_cuda_kernel<T, Context, 7>(dev_ctx, x, axis, out);
break;
case 8:
launch_flip_cuda_kernel<T, Context, 8>(dev_ctx, x, axis, out);
break;
case 9:
launch_flip_cuda_kernel<T, Context, 9>(dev_ctx, x, axis, out);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"dims of input tensor should be less than 10, But received"
"%d",
x.dims().size()));
}
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册