Unverified commit 0de94cd9 authored by limingshu, committed by GitHub

H2D data transfer optimization for concat kernel (#49040)

Parent f484a61e
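The heart of this change is how the per-input device pointers reach the concat kernel. Previously, unless 2-4 same-shape inputs hit a special-cased kernel, the pointer (and column-offset) arrays were copied host-to-device before every launch; the new PointerWrapper / PointerAndColWrapper structs instead carry up to 128 addresses by value in the kernel argument list, so an explicit H2D transfer is only needed in the PointerToPointer fallback when the input count exceeds the largest specialization. The standalone sketch below, with hypothetical names (PtrPack, demo_kernel, kMaxInputs), illustrates the by-value pattern; it is not the Paddle code itself.

```cuda
// Minimal standalone sketch (not the Paddle code) of the idea applied here:
// pack the input addresses into a fixed-size struct passed to the kernel
// *by value*, so the pointer array travels in kernel argument space and no
// separate host-to-device copy of the pointers is needed.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kMaxInputs = 4;  // assumed capacity for this sketch

template <typename T, int Size>
struct PtrPack {
  const T* addrs[Size];  // copied into kernel argument space at launch
  __device__ const T* operator[](int i) const { return addrs[i]; }
};

template <typename T, int Size>
__global__ void demo_kernel(PtrPack<T, Size> pack, int len, T* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < Size * len) {
    out[i] = pack[i / len][i % len];  // naive "concat" of equal-length inputs
  }
}

int main() {
  constexpr int len = 8;
  float host[len];
  float *in[kMaxInputs], *out;
  PtrPack<float, kMaxInputs> pack;
  for (int k = 0; k < kMaxInputs; ++k) {
    for (int j = 0; j < len; ++j) host[j] = k * len + j;
    cudaMalloc(&in[k], len * sizeof(float));
    cudaMemcpy(in[k], host, len * sizeof(float), cudaMemcpyHostToDevice);
    pack.addrs[k] = in[k];  // filled on the host; no extra H2D copy needed
  }
  cudaMalloc(&out, kMaxInputs * len * sizeof(float));
  demo_kernel<<<1, 64>>>(pack, len, out);  // pack rides along with the launch
  cudaDeviceSynchronize();
  float result[kMaxInputs * len];
  cudaMemcpy(result, out, sizeof(result), cudaMemcpyDeviceToHost);
  printf("result[31] = %g\n", result[31]);  // expect 31
  return 0;
}
```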
......@@ -53,20 +53,25 @@ inline T DivUp(T a, T b) {
// https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
// for rounding an integer value up to the next highest power of 2.
inline int64_t RoundToPowerOfTwo(int64_t n) {
inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
n--;
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
int64_t min_val = 32;
return std::max(min_val, (n + 1));
}
inline int64_t RoundToPowerOfTwo(int64_t n) {
constexpr int64_t min_val = 32;
int64_t num = RoundToNextHighPowOfTwo(n, min_val);
#ifdef __HIPCC__
int64_t max_val = 256;
#else
int64_t max_val = 1024;
#endif
return std::min(max_val, std::max(min_val, (n + 1)));
return std::min(max_val, num);
}
#ifdef WITH_NV_JETSON
......
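The bit-twiddling helper above smears the highest set bit of n-1 downward so that n is rounded up to the next power of two, clamped from below by min_val. A standalone restatement with a few spot checks (illustrative only, not part of the build):

```cuda
// Standalone restatement of the round-up-to-power-of-two helper shown above.
#include <algorithm>
#include <cassert>
#include <cstdint>

inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
  n--;
  n |= (n >> 1);   // smear the highest set bit downward ...
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);  // ... so that n + 1 becomes the next power of two
  return std::max(min_val, n + 1);
}

int main() {
  assert(RoundToNextHighPowOfTwo(1) == 1);
  assert(RoundToNextHighPowOfTwo(5) == 8);
  assert(RoundToNextHighPowOfTwo(64) == 64);
  assert(RoundToNextHighPowOfTwo(3, 4) == 4);     // clamped by min_val
  assert(RoundToNextHighPowOfTwo(65, 32) == 128);
  return 0;
}
```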
......@@ -15,49 +15,155 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
namespace phi {
namespace funcs {
template <typename T, int Size>
struct PointerWrapper {
public:
const T* ins_addr[Size];
__device__ inline const T* operator[](int i) const { return ins_addr[i]; }
PointerWrapper() {}
PointerWrapper(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const T** pre_alloced_host_ptr) {
for (auto i = 0; i < ins.size(); ++i) {
ins_addr[i] = ins[i].data<T>();
}
}
};
template <typename T>
__global__ void ConcatKernel_(const T** inputs,
const int64_t* input_cols,
int col_size,
const int64_t output_rows,
const int64_t output_cols,
T* output) {
int64_t curr_segment = 0;
int64_t curr_offset = input_cols[0];
CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, int64_t) {
int64_t curr_col_offset = input_cols[curr_segment + 1];
struct PointerToPointer {
public:
T** ins_addr{nullptr};
__device__ inline const T* operator[](int i) const { return ins_addr[i]; }
PointerToPointer() {}
PointerToPointer(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const T** pre_alloced_host_ptr,
paddle::memory::AllocationPtr* dev_ins_ptr) {
auto in_num = ins.size();
for (auto i = 0; i < in_num; ++i) {
pre_alloced_host_ptr[i] = ins[i].data<T>();
}
*dev_ins_ptr = paddle::memory::Alloc(
ctx.GetPlace(),
in_num * sizeof(T*),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
pre_alloced_host_ptr, in_num);
paddle::memory::Copy(ctx.GetPlace(),
(*dev_ins_ptr)->ptr(),
phi::CPUPlace(),
restored,
in_num * sizeof(T*),
ctx.stream());
ins_addr = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
}
};
template <typename T, typename IndexT, int Size>
struct PointerAndColWrapper {
public:
IndexT col_length[Size];
PointerAndColWrapper(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const IndexT& inputs_col_num,
const T** pre_alloced_host_ptr,
IndexT* inputs_col) {
for (auto i = 0; i < inputs_col_num; ++i) {
col_length[i] = inputs_col[i];
}
ins_ptr_wrapper = PointerWrapper<T, Size>(ctx, ins, pre_alloced_host_ptr);
}
__device__ inline const T* operator[](int i) const {
return ins_ptr_wrapper[i];
}
private:
PointerWrapper<T, Size> ins_ptr_wrapper;
};
template <typename T, typename IndexT>
struct PointerToPointerAndCol {
public:
IndexT* col_length{nullptr};
PointerToPointerAndCol(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const IndexT inputs_col_num,
const T** pre_alloced_host_ptr,
IndexT* inputs_col,
paddle::memory::AllocationPtr* dev_ins_ptr,
paddle::memory::AllocationPtr* dev_col_ptr) {
*dev_col_ptr = paddle::memory::Alloc(
ctx.GetPlace(),
inputs_col_num * sizeof(IndexT),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
inputs_col, inputs_col_num);
paddle::memory::Copy(ctx.GetPlace(),
(*dev_col_ptr)->ptr(),
phi::CPUPlace(),
restored,
inputs_col_num * sizeof(IndexT),
ctx.stream());
col_length = static_cast<IndexT*>((*dev_col_ptr)->ptr());
ins_ptr_wrapper =
PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr);
}
__device__ inline const T* operator[](int i) const {
return ins_ptr_wrapper[i];
}
private:
PointerToPointer<T> ins_ptr_wrapper;
};
template <typename T, typename IndexT, typename PointerAndColWrapperT>
__global__ void ConcatTensorWithDifferentShape(PointerAndColWrapperT ins_datas,
int col_size,
const IndexT output_rows,
const IndexT output_cols,
T* output) {
IndexT curr_segment = 0;
IndexT curr_offset = ins_datas.col_length[0];
CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) {
IndexT curr_col_offset = ins_datas.col_length[curr_segment + 1];
while (curr_col_offset <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
curr_col_offset = input_cols[curr_segment + 1];
curr_col_offset = ins_datas.col_length[curr_segment + 1];
}
int64_t local_col = tid_x - curr_offset;
int64_t segment_width = curr_col_offset - curr_offset;
IndexT local_col = tid_x - curr_offset;
IndexT segment_width = curr_col_offset - curr_offset;
const T* input_ptr = inputs[curr_segment];
int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
const T* input_ptr = ins_datas[curr_segment];
IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * segment_width + local_col];
}
}
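As a concrete trace of the segment search in ConcatTensorWithDifferentShape: suppose three inputs contribute 3, 5, and 2 columns, so the prefix sums in col_length are {0, 3, 8, 10} and output_cols is 10. A thread with tid_x = 6 starts at curr_segment = 0 with curr_offset = 0, advances past the boundary at 3, stops before the boundary at 8, and ends with curr_segment = 1, curr_offset = 3, local_col = 3, segment_width = 5; it therefore copies column 3 of the second input into output column 6 for every row it visits.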
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
const int64_t fixed_in_col,
const int64_t out_rows,
const int64_t out_cols,
T* output_data) {
CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, int64_t) {
int64_t split = tid_x * 1.0 / fixed_in_col;
int64_t in_offset = tid_x - split * fixed_in_col;
const T* input_ptr = inputs_data[split];
int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
template <typename T, typename IndexT, typename PointerWrapperT>
__global__ void ConcatTensorWithSameShape(PointerWrapperT ins_data,
const IndexT fixed_in_col,
const IndexT out_rows,
const IndexT out_cols,
T* output_data) {
CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) {
IndexT split = tid_x / fixed_in_col;
IndexT in_offset = tid_x - split * fixed_in_col;
const T* input_ptr = ins_data[split];
IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
output_data[tid_y * out_cols + tid_x] =
input_ptr[tid_y * fixed_in_col + in_offset];
......@@ -65,65 +171,6 @@ __device__ void ConcatKernelDetail(const T** inputs_data,
}
}
template <typename T>
__global__ void ConcatKernel_(const T* input_addr0,
const T* input_addr1,
const int64_t fixed_in_col,
const int64_t out_rows,
const int64_t out_cols,
T* output_data) {
const T* inputs_data[2];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
ConcatKernelDetail<T>(
inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}
template <typename T>
__global__ void ConcatKernel_(const T* input_addr0,
const T* input_addr1,
const T* input_addr2,
const int64_t fixed_in_col,
const int64_t out_rows,
const int64_t out_cols,
T* output_data) {
const T* inputs_data[3];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
inputs_data[2] = input_addr2;
ConcatKernelDetail<T>(
inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}
template <typename T>
__global__ void ConcatKernel_(const T* input_addr0,
const T* input_addr1,
const T* input_addr2,
const T* input_addr3,
const int64_t fixed_in_col,
const int64_t out_rows,
const int64_t out_cols,
T* output_data) {
const T* inputs_data[4];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
inputs_data[2] = input_addr2;
inputs_data[3] = input_addr3;
ConcatKernelDetail<T>(
inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}
template <typename T>
__global__ void ConcatKernel_(const T** inputs_data,
const int in_num,
const int64_t fixed_in_col,
const int64_t out_rows,
const int64_t out_cols,
T* output_data) {
ConcatKernelDetail<T>(
inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}
template <typename T>
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
......@@ -254,155 +301,146 @@ static inline void GetBlockDims(const phi::GPUContext& context,
* All tensors must have the same number of dimensions, and the size of
* each dimension must match, except along the concatenation axis.
*/
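As a concrete example of the row/column flattening performed below: concatenating tensors of shape [4, 3, 5] and [4, 6, 5] along axis = 1 gives in_row = 4, in_col = 15 for the first input and t_cols = 30 for the second, so out_row = 4 and out_col = 45; because the column counts differ, has_same_shape becomes false and the different-shape kernel driven by the col_length prefix sums is used.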
template <typename T>
struct ConcatFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& context,
const std::vector<phi::DenseTensor>& input,
int axis,
phi::DenseTensor* output) {
// TODO(zcd): Add input data validity checking
int64_t in_num = input.size();
int64_t in_row = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
in_row *= dim_0[i];
}
int64_t in_col = input[0].numel() / in_row;
int64_t out_row = in_row, out_col = 0;
int64_t inputs_col_num = in_num + 1;
std::vector<const T*> inputs_data_vec(in_num);
std::vector<int64_t> inputs_col_vec(inputs_col_num);
const T** inputs_data = inputs_data_vec.data();
int64_t* inputs_col = inputs_col_vec.data();
// There are some differences between hip runtime and NV runtime.
// In NV, when pageable memory data smaller than 64 KB is transferred from
// host to device, the copy is automatically asynchronous.
// However, only pinned memory in hip can be copied asynchronously.
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device:
// memory copies from host to device of a memory block of 64 KB or less
// are asynchronous with respect to the host.
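The comment above is why the HIP branch below stages inputs_data and inputs_col in pinned (page-locked) host memory and only frees it from a stream callback. A minimal plain-CUDA sketch of the same idea, using the runtime API directly rather than Paddle's allocator helpers (names and sizes are illustrative):

```cuda
// A plain-CUDA sketch (not Paddle's helper API): an async H2D copy is only
// truly asynchronous when the source is page-locked, so the host staging
// buffer must stay alive until the copy's stream has consumed it.
#include <cstdint>
#include <cuda_runtime.h>

int main() {
  constexpr int n = 5;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int64_t* host_cols = nullptr;                      // pinned host buffer
  cudaHostAlloc(&host_cols, n * sizeof(int64_t), cudaHostAllocDefault);
  for (int i = 0; i < n; ++i) host_cols[i] = i * 8;  // e.g. column prefix sums

  int64_t* dev_cols = nullptr;
  cudaMalloc(&dev_cols, n * sizeof(int64_t));
  cudaMemcpyAsync(dev_cols, host_cols, n * sizeof(int64_t),
                  cudaMemcpyHostToDevice, stream);   // async: source is pinned

  // ... launch kernels on `stream` that read dev_cols ...

  cudaStreamSynchronize(stream);  // the Paddle code frees via AddStreamCallback
  cudaFreeHost(host_cols);
  cudaFree(dev_cols);
  cudaStreamDestroy(stream);
  return 0;
}
```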
template <typename T, typename IndexT>
void ConcatFunctorWithIndexType(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
int axis,
phi::DenseTensor* output) {
// TODO(zcd): Add input data validity checking
IndexT in_num = ins.size();
IndexT in_row = 1;
auto dim_0 = ins[0].dims();
for (int i = 0; i < axis; ++i) {
in_row *= dim_0[i];
}
IndexT in_col = ins[0].numel() / in_row;
IndexT out_row = in_row, out_col = 0;
IndexT inputs_col_num = in_num + 1;
std::vector<const T*> inputs_data_vec(in_num, nullptr);
std::vector<IndexT> inputs_col_vec(inputs_col_num, 0);
const T** inputs_data = inputs_data_vec.data();
IndexT* inputs_col = inputs_col_vec.data();
#ifdef PADDLE_WITH_HIP
paddle::memory::AllocationPtr data_alloc, col_alloc;
// TODO(chentianyu03): try to find a method to remove the Alloc function
data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
in_num * sizeof(T*));
inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
// TODO(chentianyu03): try to find a method to remove the Alloc function
col_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
inputs_col_num * sizeof(int));
inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
// TODO(chentianyu03): try to find a method to remove the Alloc function
paddle::memory::AllocationPtr data_alloc = paddle::memory::Alloc(
paddle::platform::CUDAPinnedPlace(), in_num * sizeof(T*));
inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
paddle::memory::AllocationPtr col_alloc = paddle::memory::Alloc(
paddle::platform::CUDAPinnedPlace(), inputs_col_num * sizeof(IndexT));
inputs_col = reinterpret_cast<IndexT*>(col_alloc->ptr());
#endif
inputs_col[0] = 0;
bool has_same_shape = true;
for (int i = 0; i < in_num; ++i) {
int64_t t_cols = input[i].numel() / in_row;
if (has_same_shape) {
if (t_cols != in_col) has_same_shape = false;
}
out_col += t_cols;
inputs_col[i + 1] = out_col;
inputs_data[i] = input[i].data<T>();
bool has_same_shape = true;
for (int i = 0; i < in_num; ++i) {
IndexT t_cols = ins[i].numel() / in_row;
if (has_same_shape) {
has_same_shape &= (t_cols == in_col);
}
dim3 block_dims;
dim3 grid_dims;
GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);
paddle::memory::allocation::AllocationPtr tmp_dev_ins_data;
const T** dev_ins_data = nullptr;
if (!has_same_shape || in_num < 2 || in_num > 4) {
tmp_dev_ins_data = paddle::memory::Alloc(
context.GetPlace(),
in_num * sizeof(T*),
phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
inputs_data, in_num);
paddle::memory::Copy(context.GetPlace(),
tmp_dev_ins_data->ptr(),
paddle::platform::CPUPlace(),
restored,
in_num * sizeof(T*),
context.stream());
dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
out_col += t_cols;
inputs_col[i + 1] = out_col;
}
dim3 block_dims;
dim3 grid_dims;
GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
IndexT limit_num = has_same_shape ? in_num : inputs_col_num;
#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
func_impl(4, ##__VA_ARGS__); \
func_impl(8, ##__VA_ARGS__); \
func_impl(16, ##__VA_ARGS__); \
func_impl(32, ##__VA_ARGS__); \
func_impl(64, ##__VA_ARGS__); \
func_impl(128, ##__VA_ARGS__);
if (has_same_shape) {
#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data); \
__VA_ARGS__; \
} break;
switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
IMPL_CONCATE_CUDA_KERNEL_HELPER(
IMPL_CONCAT_CUDA_KERNEL_CASE,
ConcatTensorWithSameShape<T, IndexT, decltype(ptr_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
ptr_array, in_col, out_row, out_col, output->data<T>()));
default: {
paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
ConcatTensorWithSameShape<T, IndexT, decltype(ptr_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
ptr_array, in_col, out_row, out_col, output->data<T>());
}
}
if (has_same_shape) {
if (in_num == 2) {
ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
inputs_data[0],
inputs_data[1],
in_col,
out_row,
out_col,
output->data<T>());
} else if (in_num == 3) {
ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
inputs_data[0],
inputs_data[1],
inputs_data[2],
in_col,
out_row,
out_col,
output->data<T>());
} else if (in_num == 4) {
ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
inputs_data[0],
inputs_data[1],
inputs_data[2],
inputs_data[3],
in_col,
out_row,
out_col,
output->data<T>());
} else {
ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
#undef IMPL_CONCAT_CUDA_KERNEL_CASE
} else {
#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerAndColWrapper<T, IndexT, size_> ptr_col_array( \
ctx, ins, inputs_col_num, inputs_data, inputs_col); \
__VA_ARGS__; \
} break;
switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
IMPL_CONCATE_CUDA_KERNEL_HELPER(
IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
ConcatTensorWithDifferentShape<T, IndexT, decltype(ptr_col_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(ptr_col_array,
inputs_col_num,
out_row,
out_col,
output->data<T>()));
default: {
paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
paddle::memory::AllocationPtr dev_col_ptr{nullptr};
PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
ins,
inputs_col_num,
inputs_data,
inputs_col,
&dev_ins_ptr,
&dev_col_ptr);
ConcatTensorWithDifferentShape<T, IndexT, decltype(ptr_col_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(ptr_col_array,
inputs_col_num,
out_row,
out_col,
output->data<T>());
}
} else {
auto tmp_dev_ins_col_data = paddle::memory::Alloc(
context.GetPlace(),
inputs_col_num * sizeof(int64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
inputs_col, inputs_col_num);
paddle::memory::Copy(context.GetPlace(),
tmp_dev_ins_col_data->ptr(),
paddle::platform::CPUPlace(),
restored,
inputs_col_num * sizeof(int64_t),
context.stream());
int64_t* dev_ins_col_data =
static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
dev_ins_data,
dev_ins_col_data,
static_cast<int>(inputs_col_num),
out_row,
out_col,
output->data<T>());
}
#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
}
#undef IMPL_CONCATE_CUDA_KERNEL_HELPER
#ifdef PADDLE_WITH_HIP
// Prevent the pinned memory value from being covered and release the memory
// after the launch kernel of the stream is executed (reapply pinned memory
// next time)
auto* data_alloc_released = data_alloc.release();
auto* col_alloc_released = col_alloc.release();
context.AddStreamCallback([data_alloc_released, col_alloc_released] {
VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
paddle::memory::allocation::Allocator::AllocationDeleter(
data_alloc_released);
paddle::memory::allocation::Allocator::AllocationDeleter(
col_alloc_released);
});
// Prevent the pinned memory from being overwritten, and release it only
// after the kernels launched on this stream have finished (pinned memory
// is re-allocated on the next call).
auto* data_alloc_released = data_alloc.release();
auto* col_alloc_released = col_alloc.release();
ctx.AddStreamCallback([data_alloc_released, col_alloc_released] {
VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
paddle::memory::allocation::Allocator::AllocationDeleter(
data_alloc_released);
paddle::memory::allocation::Allocator::AllocationDeleter(
col_alloc_released);
});
#endif
}
template <typename T>
struct ConcatFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& context,
const std::vector<phi::DenseTensor>& input,
int axis,
phi::DenseTensor* output) {
if (output->numel() < std::numeric_limits<int32_t>::max()) {
ConcatFunctorWithIndexType<T, int32_t>(context, input, axis, output);
} else {
ConcatFunctorWithIndexType<T, int64_t>(context, input, axis, output);
}
}
};
......@@ -488,7 +526,7 @@ class SplitFunctor<phi::GPUContext, T> {
outputs_data, o_num);
paddle::memory::Copy(context.GetPlace(),
tmp_dev_outs_data->ptr(),
paddle::platform::CPUPlace(),
phi::CPUPlace(),
restored,
o_num * sizeof(T*),
context.stream());
......@@ -539,7 +577,7 @@ class SplitFunctor<phi::GPUContext, T> {
outputs_cols, outputs_cols_num);
paddle::memory::Copy(context.GetPlace(),
tmp_dev_ins_col_data->ptr(),
paddle::platform::CPUPlace(),
phi::CPUPlace(),
restored,
outputs_cols_num * sizeof(int64_t),
context.stream());
......
......@@ -25,7 +25,9 @@ namespace phi {
template <typename IndexT>
struct DivmodWarpper {
public:
void SetDivden(IndexT dividen) { divmoder = phi::funcs::FastDivMod(dividen); }
void SetDivisor(IndexT divisor) {
divmoder = phi::funcs::FastDivMod(divisor);
}
__device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
return divmoder.Divmod(val);
}
......@@ -39,7 +41,7 @@ struct DivmodWarpper<int64_t> {
public:
using DivModT = phi::AlignedVector<int64_t, 2>;
void SetDivden(int64_t dividen) { dividen_ = dividen; }
void SetDivisor(int64_t divisor) { dividen_ = divisor; }
__device__ inline DivModT div_mod(int64_t val) {
DivModT data;
data[0] = val / dividen_;
......@@ -51,15 +53,14 @@ struct DivmodWarpper<int64_t> {
int64_t dividen_;
};
constexpr int kWarpperSize = 64;
template <typename T, typename IndexT>
template <typename T, typename IndexT, int Size>
struct PointerArray : public DivmodWarpper<IndexT> {
public:
const T* data[kWarpperSize];
const T* data[Size];
PointerArray(const std::vector<const DenseTensor*>& x,
int num,
int64_t dividen) {
this->SetDivden(dividen);
IndexT divisor) {
this->SetDivisor(divisor);
for (auto i = 0; i < num; ++i) {
data[i] = x[i]->data<T>();
}
......@@ -69,33 +70,33 @@ struct PointerArray : public DivmodWarpper<IndexT> {
template <typename Context, typename T, typename IndexT>
struct PointerToPointer : public DivmodWarpper<IndexT> {
public:
T** data;
T** data{nullptr};
PointerToPointer(const Context& ctx,
const std::vector<const DenseTensor*>& x,
int num,
int64_t dividen) {
this->SetDivden(dividen);
auto byte_len = num * sizeof(T*);
IndexT num,
IndexT divisor,
paddle::memory::AllocationPtr* dev_ins_ptr) {
this->SetDivisor(divisor);
std::vector<const T*> x_datas(num);
for (int i = 0; i < num; ++i) {
x_datas[i] = x[i]->data<T>();
}
auto tmp_x_data = paddle::memory::Alloc(
*dev_ins_ptr = paddle::memory::Alloc(
ctx.GetPlace(),
byte_len,
num * sizeof(T*),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
paddle::memory::Copy(ctx.GetPlace(),
tmp_x_data->ptr(),
(*dev_ins_ptr)->ptr(),
phi::CPUPlace(),
reinterpret_cast<void*>(x_datas.data()),
x_datas.size() * sizeof(T*),
num * sizeof(T*),
ctx.stream());
data = reinterpret_cast<T**>(tmp_x_data->ptr());
data = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
}
};
template <typename T, typename IndexT, typename WarpT>
__global__ void StackCUDAKernel(WarpT input_warpper,
template <typename T, typename IndexT, typename WrapT>
__global__ void StackCUDAKernel(WrapT input_warpper,
IndexT split_size,
IndexT rows,
IndexT cols,
......@@ -117,14 +118,56 @@ __global__ void StackCUDAKernel(WarpT input_warpper,
}
}
template <typename T, typename IndexT, typename Context>
void LaunchStackCUDAKernelWithIndexType(
const Context& ctx,
const IndexT x_col,
const IndexT x_row,
const IndexT out_col,
const phi::backends::gpu::GpuLaunchConfig& cfg,
const std::vector<const DenseTensor*>& x,
T* dst_data) {
int num = static_cast<int>(x.size());
#define IMPL_STACK_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerArray<T, IndexT, size_> ptr_array(x, num, x_col); \
__VA_ARGS__; \
} break;
#define IMPL_STACK_CUDA_KERNEL_HELPER(...) \
IMPL_STACK_CUDA_KERNEL_CASE(2, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(64, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(128, ##__VA_ARGS__);
switch (phi::backends::gpu::RoundToNextHighPowOfTwo(num, 4)) {
IMPL_STACK_CUDA_KERNEL_HELPER(
StackCUDAKernel<T, IndexT, decltype(ptr_array)>
<<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
ptr_array, x_col, x_row, out_col, dst_data));
default: {
paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
PointerToPointer<Context, T, IndexT> ptr_array(
ctx, x, num, x_col, &dev_ins_ptr);
StackCUDAKernel<T, IndexT, decltype(ptr_array)>
<<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
ptr_array, x_col, x_row, out_col, dst_data);
}
}
#undef IMPL_STACK_CUDA_KERNEL_HELPER
#undef IMPL_STACK_CUDA_KERNEL_CASE
}
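Both macro-generated switches above (in ConcatFunctorWithIndexType and LaunchStackCUDAKernelWithIndexType) follow the same dispatch pattern: round the input count up to a power of two, instantiate a wrapper whose fixed-size pointer array just fits so it can be passed by value in kernel arguments, and fall back to the device-memory PointerToPointer path for anything larger. A condensed, host-only sketch of that pattern, with illustrative names and loosely following the concat helper's 4-128 case list:

```cuda
// Condensed sketch of the size-dispatch pattern used above (names are
// illustrative; the real code launches CUDA kernels from each case).
#include <cstdio>

template <int Size>
struct FixedPack { const float* addrs[Size]; };  // rides in kernel arguments

void LaunchWithDevicePointers(int num) {         // stand-in for the fallback
  printf("fallback: copy %d pointers to device memory first\n", num);
}

template <int Size>
void LaunchWithFixedPack(int num) {
  FixedPack<Size> pack{};  // would be filled with the num input addresses
  printf("pack %d pointers into FixedPack<%d> (%zu-byte kernel argument)\n",
         num, Size, sizeof(pack));
}

inline int RoundUpPow2(int n, int min_val) {
  int p = 1;
  while (p < n) p <<= 1;
  return p < min_val ? min_val : p;
}

void Dispatch(int num) {
  switch (RoundUpPow2(num, 4)) {
    case 4:   LaunchWithFixedPack<4>(num);   break;
    case 8:   LaunchWithFixedPack<8>(num);   break;
    case 16:  LaunchWithFixedPack<16>(num);  break;
    case 32:  LaunchWithFixedPack<32>(num);  break;
    case 64:  LaunchWithFixedPack<64>(num);  break;
    case 128: LaunchWithFixedPack<128>(num); break;
    default:  LaunchWithDevicePointers(num); break;  // more than 128 inputs
  }
}

int main() {
  Dispatch(3);    // -> FixedPack<4>
  Dispatch(70);   // -> FixedPack<128>
  Dispatch(300);  // -> fallback via device memory
  return 0;
}
```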
template <typename T, typename Context>
void StackKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
int axis,
DenseTensor* out) {
if (axis < 0) axis += (x[0]->dims().size() + 1);
int n = static_cast<int>(x.size());
T* y_data = dev_ctx.template Alloc<T>(out);
int num = static_cast<int>(x.size());
T* dst_data = dev_ctx.template Alloc<T>(out);
// Split x dim from axis to matrix
int64_t x_row = 1, x_col = 1;
......@@ -132,40 +175,17 @@ void StackKernel(const Context& dev_ctx,
x_row *= x[0]->dims()[i];
}
x_col = x[0]->numel() / x_row;
int64_t out_col = x_col * n;
int64_t out_col = x_col * num;
auto config =
phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);
#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper) \
StackCUDAKernel<T, index_t, decltype(input_warpper)> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
dev_ctx.stream()>>>(input_warpper, \
static_cast<index_t>(x_col), \
static_cast<index_t>(x_row), \
static_cast<index_t>(out_col), \
y_data);
bool use_int32 = out->numel() < std::numeric_limits<int32_t>::max();
if (n <= kWarpperSize) {
if (use_int32) {
PointerArray<T, int32_t> ptr_array(x, n, x_col);
IMPL_STACK_CUDA_KERNEL(int32_t, ptr_array);
} else {
PointerArray<T, int64_t> ptr_array(x, n, x_col);
IMPL_STACK_CUDA_KERNEL(int64_t, ptr_array);
}
if (out->numel() < std::numeric_limits<int32_t>::max()) {
LaunchStackCUDAKernelWithIndexType<T, int32_t, Context>(
dev_ctx, x_col, x_row, out_col, config, x, dst_data);
} else {
if (use_int32) {
PointerToPointer<Context, T, int32_t> ptr_array(dev_ctx, x, n, x_col);
IMPL_STACK_CUDA_KERNEL(int32_t, ptr_array);
} else {
PointerToPointer<Context, T, int64_t> ptr_array(dev_ctx, x, n, x_col);
IMPL_STACK_CUDA_KERNEL(int64_t, ptr_array);
}
LaunchStackCUDAKernelWithIndexType<T, int64_t, Context>(
dev_ctx, x_col, x_row, out_col, config, x, dst_data);
}
#undef IMPL_STACK_CUDA_KERNEL
}
} // namespace phi
......