Unverified · Commit a78f0a16 authored by limingshu, committed by GitHub

H2D data transfer optimization with usage of structure type for stack kernel (#48899)

* first commit.

* refine performance with fast_divmod

* refine performance with fast_divmod
Parent 96e58f87
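
What changed, in short: the old kernel received its input pointers through a device-side buffer that had to be allocated and filled with an extra host-to-device (H2D) copy on every call. With at most kWarpperSize (64) inputs, the pointers now travel inside a plain struct that is passed to the kernel by value, so they go through the normal kernel-argument path and the separate allocation and H2D transfer disappear; for 32-bit indices, a FastDivMod-based div/mod also replaces the per-element `/` and `%`. Below is a minimal, self-contained CUDA sketch of the struct-by-value idea only — names such as `PointerPack`, `kMaxInputs`, and `StackByValueKernel` are illustrative, not Paddle's.

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Fixed-size bundle of input pointers. Because it is passed to the kernel
// by value, the pointers ride along in the regular kernel-argument buffer
// and no separate device allocation or cudaMemcpyAsync is needed.
constexpr int kMaxInputs = 8;  // illustrative; Paddle uses kWarpperSize = 64

struct PointerPack {
  const float* ptrs[kMaxInputs];
};

// Stacks `num` inputs of length `cols` each into out[num * cols]:
// out[i * cols + j] = inputs.ptrs[i][j].
__global__ void StackByValueKernel(PointerPack inputs,
                                   int num,
                                   int cols,
                                   float* out) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int total = num * cols;
  for (; idx < total; idx += gridDim.x * blockDim.x) {
    int which = idx / cols;   // which input tensor
    int offset = idx % cols;  // offset inside that tensor
    // (the Paddle kernel replaces this / and % with FastDivMod)
    out[idx] = inputs.ptrs[which][offset];
  }
}

int main() {
  const int num = 3, cols = 4;
  PointerPack pack{};
  std::vector<float*> dev_inputs(num);
  for (int i = 0; i < num; ++i) {
    std::vector<float> h(cols, static_cast<float>(i + 1));
    cudaMalloc(&dev_inputs[i], cols * sizeof(float));
    cudaMemcpy(dev_inputs[i], h.data(), cols * sizeof(float),
               cudaMemcpyHostToDevice);
    pack.ptrs[i] = dev_inputs[i];  // filled on the host; no device pointer table
  }
  float* dev_out;
  cudaMalloc(&dev_out, num * cols * sizeof(float));
  StackByValueKernel<<<1, 128>>>(pack, num, cols, dev_out);

  std::vector<float> h_out(num * cols);
  cudaMemcpy(h_out.data(), dev_out, h_out.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%.0f ", v);  // expected: 1 1 1 1 2 2 2 2 3 3 3 3
  printf("\n");
  for (float* p : dev_inputs) cudaFree(p);
  cudaFree(dev_out);
  return 0;
}
```

The trade-off is the fixed upper bound on the number of inputs: kernel-argument space is limited (4 KB on most toolchains), which is why the commit keeps a pointer-to-pointer fallback for n > kWarpperSize.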
@@ -18,30 +18,101 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/fast_divmod.h"
 
 namespace phi {
 
-template <typename T, typename IntType>
-__global__ void StackCUDAKernel(T** input_ptrs,
-                                IntType split_size,
-                                IntType rows,
-                                IntType cols,
+template <typename IndexT>
+struct DivmodWarpper {
+ public:
+  void SetDivden(IndexT dividen) { divmoder = phi::funcs::FastDivMod(dividen); }
+  __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
+    return divmoder.Divmod(val);
+  }
+
+ private:
+  phi::funcs::FastDivMod divmoder;
+};
+
+template <>
+struct DivmodWarpper<int64_t> {
+ public:
+  using DivModT = phi::AlignedVector<int64_t, 2>;
+
+  void SetDivden(int64_t dividen) { dividen_ = dividen; }
+  __device__ inline DivModT div_mod(int64_t val) {
+    DivModT data;
+    data[0] = val / dividen_;
+    data[1] = val - data[0] * dividen_;
+    return data;
+  }
+
+ private:
+  int64_t dividen_;
+};
+
+constexpr int kWarpperSize = 64;
+
+template <typename T, typename IndexT>
+struct PointerArray : public DivmodWarpper<IndexT> {
+ public:
+  const T* data[kWarpperSize];
+  PointerArray(const std::vector<const DenseTensor*>& x,
+               int num,
+               int64_t dividen) {
+    this->SetDivden(dividen);
+    for (auto i = 0; i < num; ++i) {
+      data[i] = x[i]->data<T>();
+    }
+  }
+};
+
+template <typename Context, typename T, typename IndexT>
+struct PointerToPointer : public DivmodWarpper<IndexT> {
+ public:
+  T** data;
+  PointerToPointer(const Context& ctx,
+                   const std::vector<const DenseTensor*>& x,
+                   int num,
+                   int64_t dividen) {
+    this->SetDivden(dividen);
+    auto byte_len = num * sizeof(T*);
+    std::vector<const T*> x_datas(num);
+    for (int i = 0; i < num; ++i) {
+      x_datas[i] = x[i]->data<T>();
+    }
+    auto tmp_x_data = paddle::memory::Alloc(
+        ctx.GetPlace(),
+        byte_len,
+        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+    paddle::memory::Copy(ctx.GetPlace(),
+                         tmp_x_data->ptr(),
+                         phi::CPUPlace(),
+                         reinterpret_cast<void*>(x_datas.data()),
+                         x_datas.size() * sizeof(T*),
+                         ctx.stream());
+    data = reinterpret_cast<T**>(tmp_x_data->ptr());
+  }
+};
+
+template <typename T, typename IndexT, typename WarpT>
+__global__ void StackCUDAKernel(WarpT input_warpper,
+                                IndexT split_size,
+                                IndexT rows,
+                                IndexT cols,
                                 T* __restrict__ output) {
-  IntType grid_x = static_cast<IntType>(blockIdx.x) * blockDim.x + threadIdx.x;
-  IntType grid_x_stride = static_cast<IntType>(blockDim.x) * gridDim.x;
-  IntType grid_y_stride = static_cast<IntType>(blockDim.y) * gridDim.y;
+  IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
+  IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
+  IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;
 
   for (; grid_x < cols; grid_x += grid_x_stride) {
-    IntType grid_y =
-        static_cast<IntType>(blockIdx.y) * blockDim.y + threadIdx.y;
+    IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
 
-    IntType split = grid_x / split_size;
-    const T* input_ptr = input_ptrs[split];
-    IntType col_offset = grid_x % split_size;
+    auto divmod_rslt = input_warpper.div_mod(grid_x);
+    const T* input_ptr = input_warpper.data[divmod_rslt[0]];
 #pragma unroll
     for (; grid_y < rows; grid_y += grid_y_stride) {
       output[grid_y * cols + grid_x] =
-          input_ptr[grid_y * split_size + col_offset];
+          input_ptr[grid_y * split_size + divmod_rslt[1]];
     }
   }
 }
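
The 32-bit index path above delegates `grid_x / split_size` and `grid_x % split_size` to `phi::funcs::FastDivMod`, which precomputes a magic multiplier and shift for the runtime-constant divisor so each element costs a multiply-high plus shift instead of an integer division. The sketch below re-derives that round-up magic-number scheme in isolation and checks it against plain `/` and `%`; it is illustrative, not a copy of Paddle's `fast_divmod.h` (on the device the high 32 bits of the product would come from `__umulhi`).

```cuda
#include <cstdint>
#include <cstdio>

// Round-up "magic number" division by a runtime-constant divisor d:
//   s          = ceil(log2(d))
//   multiplier = floor(2^(32+s) / d) - 2^32 + 1        (fits in 32 bits)
//   n / d      = (mulhi32(n, multiplier) + n) >> s
struct FastDivModSketch {
  uint32_t divisor, shift, multiplier;

  explicit FastDivModSketch(uint32_t d) : divisor(d) {
    shift = 0;
    while ((1u << shift) < d) ++shift;
    uint64_t one = 1;
    multiplier =
        static_cast<uint32_t>(((one << 32) * ((one << shift) - d)) / d + 1);
  }

  // 64-bit multiply here so the check also runs on the host.
  uint32_t Div(uint32_t n) const {
    uint32_t hi =
        static_cast<uint32_t>((static_cast<uint64_t>(n) * multiplier) >> 32);
    return (hi + n) >> shift;
  }
  uint32_t Mod(uint32_t n) const { return n - Div(n) * divisor; }
};

int main() {
  // Spot-check the scheme against the plain operators.
  for (uint32_t d : {1u, 3u, 7u, 128u, 1000u, 12345u}) {
    FastDivModSketch fdm(d);
    for (uint32_t n : {0u, 1u, d - 1, d, d + 1, 999983u, 4000000000u}) {
      if (fdm.Div(n) != n / d || fdm.Mod(n) != n % d) {
        printf("mismatch: n=%u d=%u\n", n, d);
        return 1;
      }
    }
  }
  printf("fast divmod matches / and %% on all samples\n");
  return 0;
}
```

Built with nvcc (or any C++11 compiler, since the check is host-only), the program prints the confirmation line; the same multiplier/shift pair is what the kernel applies per element.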
@@ -52,24 +123,8 @@ void StackKernel(const Context& dev_ctx,
                  int axis,
                  DenseTensor* out) {
   if (axis < 0) axis += (x[0]->dims().size() + 1);
   int n = static_cast<int>(x.size());
   T* y_data = dev_ctx.template Alloc<T>(out);
-  std::vector<const T*> x_datas(n);
-  for (int i = 0; i < n; i++) {
-    x_datas[i] = x[i]->data<T>();
-  }
-
-  auto tmp_x_data = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      x_datas.size() * sizeof(T*),
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  paddle::memory::Copy(dev_ctx.GetPlace(),
-                       tmp_x_data->ptr(),
-                       phi::CPUPlace(),
-                       reinterpret_cast<void*>(x_datas.data()),
-                       x_datas.size() * sizeof(T*),
-                       dev_ctx.stream());
 
   // Split x dim from axis to matrix
   int64_t x_row = 1, x_col = 1;
@@ -78,33 +133,40 @@ void StackKernel(const Context& dev_ctx,
   }
   x_col = x[0]->numel() / x_row;
   int64_t out_col = x_col * n;
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);
 
-  if (out->numel() < std::numeric_limits<int32_t>::max()) {
-    StackCUDAKernel<T, int32_t>
-        <<<config.block_per_grid,
-           config.thread_per_block,
-           0,
-           dev_ctx.stream()>>>(reinterpret_cast<T**>(tmp_x_data->ptr()),
-                               static_cast<int32_t>(x_col),
-                               static_cast<int32_t>(x_row),
-                               static_cast<int32_t>(out_col),
-                               y_data);
+#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper)      \
+  StackCUDAKernel<T, index_t, decltype(input_warpper)>      \
+      <<<config.block_per_grid,                             \
+         config.thread_per_block,                           \
+         0,                                                 \
+         dev_ctx.stream()>>>(input_warpper,                 \
+                             static_cast<index_t>(x_col),   \
+                             static_cast<index_t>(x_row),   \
+                             static_cast<index_t>(out_col), \
+                             y_data);
+
+  bool use_int32 = out->numel() < std::numeric_limits<int32_t>::max();
+  if (n <= kWarpperSize) {
+    if (use_int32) {
+      PointerArray<T, int32_t> ptr_array(x, n, x_col);
+      IMPL_STACK_CUDA_KERNEL(int32_t, ptr_array);
+    } else {
+      PointerArray<T, int64_t> ptr_array(x, n, x_col);
+      IMPL_STACK_CUDA_KERNEL(int64_t, ptr_array);
+    }
   } else {
-    StackCUDAKernel<T, int64_t>
-        <<<config.block_per_grid,
-           config.thread_per_block,
-           0,
-           dev_ctx.stream()>>>(reinterpret_cast<T**>(tmp_x_data->ptr()),
-                               x_col,
-                               x_row,
-                               out_col,
-                               y_data);
+    if (use_int32) {
+      PointerToPointer<Context, T, int32_t> ptr_array(dev_ctx, x, n, x_col);
+      IMPL_STACK_CUDA_KERNEL(int32_t, ptr_array);
+    } else {
+      PointerToPointer<Context, T, int64_t> ptr_array(dev_ctx, x, n, x_col);
+      IMPL_STACK_CUDA_KERNEL(int64_t, ptr_array);
+    }
   }
+#undef IMPL_STACK_CUDA_KERNEL
 }
 
 }  // namespace phi
 
 PD_REGISTER_KERNEL(stack,
......
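
For input counts above kWarpperSize, the PointerToPointer branch still stages the pointer table through device memory with an asynchronous H2D copy — the very pattern the fast path removes. A standalone sketch of that staging pattern, using the plain CUDA runtime API instead of paddle::memory (names and the float-only type are illustrative):

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Fallback pattern: the pointer table itself lives in device memory and must
// be filled with an extra host-to-device copy before the launch.
__global__ void GatherFirstElement(const float* const* table, int num,
                                   float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num) out[i] = table[i][0];
}

int main() {
  const int num = 3;
  std::vector<const float*> host_table(num);
  for (int i = 0; i < num; ++i) {
    float value = static_cast<float>(i + 1);
    float* d_in;
    cudaMalloc(&d_in, sizeof(float));
    cudaMemcpy(d_in, &value, sizeof(float), cudaMemcpyHostToDevice);
    host_table[i] = d_in;
  }

  // The extra device allocation + H2D transfer that the struct-by-value
  // path avoids for small input counts.
  const float** dev_table;
  cudaMalloc(&dev_table, num * sizeof(float*));
  cudaMemcpyAsync(dev_table, host_table.data(), num * sizeof(float*),
                  cudaMemcpyHostToDevice, /*stream=*/0);

  float* d_out;
  cudaMalloc(&d_out, num * sizeof(float));
  GatherFirstElement<<<1, 32>>>(dev_table, num, d_out);

  std::vector<float> h_out(num);
  cudaMemcpy(h_out.data(), d_out, num * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%.0f ", v);  // expected: 1 2 3
  printf("\n");
  // (cleanup of device buffers omitted in this sketch)
  return 0;
}
```

Compared with the struct-by-value path, this needs one extra allocation and one extra small H2D transfer per call, which is the overhead the commit title refers to.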