Unverified commit 0de94cd9, authored by limingshu, committed by GitHub

H2D data transfer optimization for concat kernel (#49040)

Parent f484a61e
@@ -53,20 +53,25 @@ inline T DivUp(T a, T b) {
 // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
 // for round integer value into next highest power of 2.
-inline int64_t RoundToPowerOfTwo(int64_t n) {
+inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
   n--;
   n |= (n >> 1);
   n |= (n >> 2);
   n |= (n >> 4);
   n |= (n >> 8);
   n |= (n >> 16);
-  int64_t min_val = 32;
+  return std::max(min_val, (n + 1));
+}
+
+inline int64_t RoundToPowerOfTwo(int64_t n) {
+  constexpr int64_t min_val = 32;
+  int64_t num = RoundToNextHighPowOfTwo(n, min_val);
 #ifdef __HIPCC__
   int64_t max_val = 256;
 #else
   int64_t max_val = 1024;
 #endif
-  return std::min(max_val, std::max(min_val, (n + 1)));
+  return std::min(max_val, num);
 }
 
 #ifdef WITH_NV_JETSON
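The hunk above splits the old helper in two: RoundToNextHighPowOfTwo now does the plain bit-twiddling round-up with a caller-supplied floor, while RoundToPowerOfTwo keeps its previous clamp to [32, 1024] (or [32, 256] under __HIPCC__) on top of it. A minimal host-only sketch of the resulting behaviour — the two function bodies are copied from the hunk, but the main() harness and the sample values are illustrative additions, not part of the commit:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Standalone copies of the helpers from the hunk above, so their behaviour
// can be checked on the host without the CUDA/HIP toolchain.
inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
  n--;
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(min_val, (n + 1));
}

inline int64_t RoundToPowerOfTwo(int64_t n) {
  constexpr int64_t min_val = 32;
  constexpr int64_t max_val = 1024;  // 256 when built with __HIPCC__
  return std::min(max_val, RoundToNextHighPowOfTwo(n, min_val));
}

int main() {
  assert(RoundToNextHighPowOfTwo(20, 4) == 32);  // 20 rounds up to 32
  assert(RoundToNextHighPowOfTwo(2, 4) == 4);    // raised to the floor of 4
  assert(RoundToPowerOfTwo(7) == 32);            // rounds to 8, clamped up to 32
  assert(RoundToPowerOfTwo(100) == 128);         // plain round-up
  assert(RoundToPowerOfTwo(5000) == 1024);       // clamped down to max_val
  return 0;
}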
@@ -25,7 +25,9 @@ namespace phi {
 template <typename IndexT>
 struct DivmodWarpper {
  public:
-  void SetDivden(IndexT dividen) { divmoder = phi::funcs::FastDivMod(dividen); }
+  void SetDivisor(IndexT divisor) {
+    divmoder = phi::funcs::FastDivMod(divisor);
+  }
   __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
     return divmoder.Divmod(val);
   }
@@ -39,7 +41,7 @@ struct DivmodWarpper<int64_t> {
  public:
   using DivModT = phi::AlignedVector<int64_t, 2>;
 
-  void SetDivden(int64_t dividen) { dividen_ = dividen; }
+  void SetDivisor(int64_t divisor) { dividen_ = divisor; }
   __device__ inline DivModT div_mod(int64_t val) {
     DivModT data;
     data[0] = val / dividen_;
@@ -51,15 +53,14 @@ struct DivmodWarpper<int64_t> {
   int64_t dividen_;
 };
 
-constexpr int kWarpperSize = 64;
-template <typename T, typename IndexT>
+template <typename T, typename IndexT, int Size>
 struct PointerArray : public DivmodWarpper<IndexT> {
  public:
-  const T* data[kWarpperSize];
+  const T* data[Size];
   PointerArray(const std::vector<const DenseTensor*>& x,
                int num,
-               int64_t dividen) {
-    this->SetDivden(dividen);
+               IndexT divisor) {
+    this->SetDivisor(divisor);
     for (auto i = 0; i < num; ++i) {
       data[i] = x[i]->data<T>();
     }
@@ -69,33 +70,33 @@ struct PointerArray : public DivmodWarpper<IndexT> {
 template <typename Context, typename T, typename IndexT>
 struct PointerToPointer : public DivmodWarpper<IndexT> {
  public:
-  T** data;
+  T** data{nullptr};
   PointerToPointer(const Context& ctx,
                    const std::vector<const DenseTensor*>& x,
-                   int num,
-                   int64_t dividen) {
-    this->SetDivden(dividen);
-    auto byte_len = num * sizeof(T*);
+                   IndexT num,
+                   IndexT divisor,
+                   paddle::memory::AllocationPtr* dev_ins_ptr) {
+    this->SetDivisor(divisor);
     std::vector<const T*> x_datas(num);
     for (int i = 0; i < num; ++i) {
       x_datas[i] = x[i]->data<T>();
     }
-    auto tmp_x_data = paddle::memory::Alloc(
+    *dev_ins_ptr = paddle::memory::Alloc(
         ctx.GetPlace(),
-        byte_len,
+        num * sizeof(T*),
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
     paddle::memory::Copy(ctx.GetPlace(),
-                         tmp_x_data->ptr(),
+                         (*dev_ins_ptr)->ptr(),
                          phi::CPUPlace(),
                          reinterpret_cast<void*>(x_datas.data()),
-                         x_datas.size() * sizeof(T*),
+                         num * sizeof(T*),
                          ctx.stream());
-    data = reinterpret_cast<T**>(tmp_x_data->ptr());
+    data = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
   }
 };
 
-template <typename T, typename IndexT, typename WarpT>
-__global__ void StackCUDAKernel(WarpT input_warpper,
+template <typename T, typename IndexT, typename WrapT>
+__global__ void StackCUDAKernel(WrapT input_warpper,
                                 IndexT split_size,
                                 IndexT rows,
                                 IndexT cols,
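Both before and after this change, a small enough input count lets the pointer table travel by value inside the kernel-argument struct, with no device allocation and no separate host-to-device copy per launch. The commit widens that fast path: instead of one fixed 64-slot array, PointerArray now has compile-time bucket sizes up to 128, so the kernel argument is only as large as the rounded input count; the PointerToPointer fallback still does an explicit device allocation plus asynchronous copy, but its allocation is now handed back to the caller through dev_ins_ptr rather than dying with the constructor's local. A self-contained CUDA sketch of the pass-by-value idea — the struct, kernel, and variable names below are illustrative, not the ones in this file:

#include <cstdio>
#include <cuda_runtime.h>

// A fixed-size pointer table passed *by value*: it reaches the device in the
// kernel-argument buffer, so no cudaMalloc/cudaMemcpy is needed for the table
// itself (this mirrors what PointerArray does above).
template <typename T, int Size>
struct PtrArrayByValue {
  const T* data[Size];
};

template <typename T, int Size>
__global__ void SumFirstElems(PtrArrayByValue<T, Size> ins, int num, T* out) {
  // Purely illustrative body: add up the first element of every input.
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    T acc = 0;
    for (int i = 0; i < num; ++i) acc += ins.data[i][0];
    *out = acc;
  }
}

int main() {
  constexpr int kNum = 3;
  float host_vals[kNum] = {1.f, 2.f, 3.f};
  float* dev_bufs[kNum];
  PtrArrayByValue<float, 8> ins{};
  for (int i = 0; i < kNum; ++i) {
    cudaMalloc(reinterpret_cast<void**>(&dev_bufs[i]), sizeof(float));
    cudaMemcpy(dev_bufs[i], &host_vals[i], sizeof(float), cudaMemcpyHostToDevice);
    ins.data[i] = dev_bufs[i];  // device pointers stored in a host-side struct
  }
  float* dev_out = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&dev_out), sizeof(float));
  // The pointer table rides along with the launch; no extra H2D copy for it.
  SumFirstElems<float, 8><<<1, 32>>>(ins, kNum, dev_out);
  float result = 0.f;
  cudaMemcpy(&result, dev_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum of first elements: %f\n", result);  // expect 6.0
  return 0;
}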
@@ -117,14 +118,56 @@ __global__ void StackCUDAKernel(WarpT input_warpper,
   }
 }
 
+template <typename T, typename IndexT, typename Context>
+void LaunchStackCUDAKernelWithIndexType(
+    const Context& ctx,
+    const IndexT x_col,
+    const IndexT x_row,
+    const IndexT out_col,
+    const phi::backends::gpu::GpuLaunchConfig& cfg,
+    const std::vector<const DenseTensor*>& x,
+    T* dst_data) {
+  int num = static_cast<int>(x.size());
+#define IMPL_STACK_CUDA_KERNEL_CASE(size_, ...)              \
+  case size_: {                                              \
+    PointerArray<T, IndexT, size_> ptr_array(x, num, x_col); \
+    __VA_ARGS__;                                             \
+  } break;
+
+#define IMPL_STACK_CUDA_KERNEL_HELPER(...)          \
+  IMPL_STACK_CUDA_KERNEL_CASE(2, ##__VA_ARGS__);    \
+  IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__);    \
+  IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__);   \
+  IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__);   \
+  IMPL_STACK_CUDA_KERNEL_CASE(64, ##__VA_ARGS__);   \
+  IMPL_STACK_CUDA_KERNEL_CASE(128, ##__VA_ARGS__);
+
+  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(num, 4)) {
+    IMPL_STACK_CUDA_KERNEL_HELPER(
+        StackCUDAKernel<T, IndexT, decltype(ptr_array)>
+        <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
+            ptr_array, x_col, x_row, out_col, dst_data));
+    default: {
+      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
+      PointerToPointer<Context, T, IndexT> ptr_array(
+          ctx, x, num, x_col, &dev_ins_ptr);
+      StackCUDAKernel<T, IndexT, decltype(ptr_array)>
+          <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
+              ptr_array, x_col, x_row, out_col, dst_data);
+    }
+  }
+#undef IMPL_STACK_CUDA_KERNEL_HELPER
+#undef IMPL_STACK_CUDA_KERNEL_CASE
+}
+
 template <typename T, typename Context>
 void StackKernel(const Context& dev_ctx,
                  const std::vector<const DenseTensor*>& x,
                  int axis,
                  DenseTensor* out) {
   if (axis < 0) axis += (x[0]->dims().size() + 1);
-  int n = static_cast<int>(x.size());
-  T* y_data = dev_ctx.template Alloc<T>(out);
+  int num = static_cast<int>(x.size());
+  T* dst_data = dev_ctx.template Alloc<T>(out);
 
   // Split x dim from axis to matrix
   int64_t x_row = 1, x_col = 1;
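LaunchStackCUDAKernelWithIndexType picks a bucket by rounding the runtime input count up to a power of two with a floor of 4: each case of the macro-generated switch instantiates PointerArray with a compile-time Size and launches with the pointers in the argument buffer, while counts beyond the largest case drop into the default PointerToPointer path. A host-only sketch of that dispatch shape — the bucket sizes, names, and printouts are illustrative, and the rounding helper is repeated only so the sketch compiles on its own:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Copy of the rounding helper from the first hunk, for standalone use.
inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
  n--;
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(min_val, (n + 1));
}

// Stand-in for launching StackCUDAKernel with a PointerArray<T, IndexT, Size>.
template <int Size>
void LaunchWithFixedSizeArray(int num) {
  std::printf("num = %3d -> fixed-size bucket %d (pointers in kernel args)\n",
              num, Size);
}

void Dispatch(int num) {
#define STACK_BUCKET_CASE(size_)          \
  case size_:                             \
    LaunchWithFixedSizeArray<size_>(num); \
    break;
  switch (RoundToNextHighPowOfTwo(num, 4)) {
    STACK_BUCKET_CASE(8)
    STACK_BUCKET_CASE(16)
    STACK_BUCKET_CASE(32)
    STACK_BUCKET_CASE(64)
    STACK_BUCKET_CASE(128)
    default:
      std::printf("num = %3d -> fallback (device pointer table + H2D copy)\n",
                  num);
  }
#undef STACK_BUCKET_CASE
}

int main() {
  Dispatch(5);    // rounds to 8
  Dispatch(13);   // rounds to 16
  Dispatch(100);  // rounds to 128
  Dispatch(200);  // rounds to 256 -> fallback
  return 0;
}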
@@ -132,40 +175,17 @@ void StackKernel(const Context& dev_ctx,
     x_row *= x[0]->dims()[i];
   }
   x_col = x[0]->numel() / x_row;
-  int64_t out_col = x_col * n;
+  int64_t out_col = x_col * num;
 
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);
-#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper)       \
-  StackCUDAKernel<T, index_t, decltype(input_warpper)>       \
-      <<<config.block_per_grid,                              \
-         config.thread_per_block,                            \
-         0,                                                  \
-         dev_ctx.stream()>>>(input_warpper,                  \
-                             static_cast<index_t>(x_col),    \
-                             static_cast<index_t>(x_row),    \
-                             static_cast<index_t>(out_col),  \
-                             y_data);
-
-  bool use_int32 = out->numel() < std::numeric_limits<int32_t>::max();
-  if (n <= kWarpperSize) {
-    if (use_int32) {
-      PointerArray<T, int32_t> ptr_array(x, n, x_col);
-      IMPL_STACK_CUDA_KERNEL(int32_t, ptr_array);
-    } else {
-      PointerArray<T, int64_t> ptr_array(x, n, x_col);
-      IMPL_STACK_CUDA_KERNEL(int64_t, ptr_array);
-    }
+  if (out->numel() < std::numeric_limits<int32_t>::max()) {
+    LaunchStackCUDAKernelWithIndexType<T, int32_t, Context>(
+        dev_ctx, x_col, x_row, out_col, config, x, dst_data);
   } else {
-    if (use_int32) {
-      PointerToPointer<Context, T, int32_t> ptr_array(dev_ctx, x, n, x_col);
-      IMPL_STACK_CUDA_KERNEL(int32_t, ptr_array);
-    } else {
-      PointerToPointer<Context, T, int64_t> ptr_array(dev_ctx, x, n, x_col);
-      IMPL_STACK_CUDA_KERNEL(int64_t, ptr_array);
-    }
+    LaunchStackCUDAKernelWithIndexType<T, int64_t, Context>(
+        dev_ctx, x_col, x_row, out_col, config, x, dst_data);
   }
-#undef IMPL_STACK_CUDA_KERNEL
 }
 
 } // namespace phi
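The body of StackCUDAKernel itself is unchanged by this commit and elided from the diff. Roughly, each thread covers the out_col x x_row output grid and uses the wrapper's div_mod (whose divisor is x_col) to turn an output column into a pair (which input tensor, column inside that input). A simplified sketch of that indexing under the usual 2D grid-stride layout — this is an assumption about the elided body, not the actual kernel:

// Illustrative kernel: WrapT is any of the wrappers above, i.e. it provides
// a `data` pointer table and `div_mod`, with the divisor set to split_size.
template <typename T, typename IndexT, typename WrapT>
__global__ void StackLikeKernel(WrapT ins,
                                IndexT split_size,  // x_col
                                IndexT rows,        // x_row
                                IndexT cols,        // out_col = num_inputs * x_col
                                T* out) {
  IndexT grid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; grid_x < cols; grid_x += blockDim.x * gridDim.x) {
    auto divmod = ins.div_mod(grid_x);  // [0]: input index, [1]: column within it
    IndexT grid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; grid_y < rows; grid_y += blockDim.y * gridDim.y) {
      out[grid_y * cols + grid_x] =
          ins.data[divmod[0]][grid_y * split_size + divmod[1]];
    }
  }
}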