未验证 提交 1a0b3661 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

Add concat optimization (#49540)

* add concat optimization

* refine

* remove annotation

* use alignas instead of aligned_storage
上级 591be3bd
...@@ -20,18 +20,43 @@ limitations under the License. */ ...@@ -20,18 +20,43 @@ limitations under the License. */
namespace phi { namespace phi {
namespace funcs { namespace funcs {
static inline void GetBlockDims(const phi::GPUContext& context,
int64_t num_rows,
int64_t num_cols,
dim3* block_dims,
dim3* grid_dims) {
// Set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (num_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((num_cols + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
*block_dims = dim3(block_cols, block_rows, 1);
constexpr int waves = 1;
int max_threads = context.GetMaxPhysicalThreadCount() * waves;
int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows = std::min(max_blocks / grid_cols,
std::max(num_rows / block_rows, (int64_t)1));
*grid_dims = dim3(grid_cols, grid_rows, 1);
}
template <typename T, int Size> template <typename T, int Size>
struct PointerWrapper { struct PointerWrapper {
public: public:
const T* ins_addr[Size]; const void* ins_addr[Size];
__device__ inline const T* operator[](int i) const { return ins_addr[i]; } __device__ inline const void* operator[](int i) const { return ins_addr[i]; }
PointerWrapper() {} PointerWrapper() {}
PointerWrapper(const phi::GPUContext& ctx, PointerWrapper(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins, const std::vector<phi::DenseTensor>& ins,
const T** pre_alloced_host_ptr) { const T** pre_alloced_host_ptr) {
for (auto i = 0; i < ins.size(); ++i) { for (auto i = 0; i < ins.size(); ++i) {
ins_addr[i] = ins[i].data<T>(); ins_addr[i] = ins[i].data();
} }
} }
}; };
...@@ -39,8 +64,8 @@ struct PointerWrapper { ...@@ -39,8 +64,8 @@ struct PointerWrapper {
template <typename T> template <typename T>
struct PointerToPointer { struct PointerToPointer {
public: public:
T** ins_addr{nullptr}; void** ins_addr{nullptr};
__device__ inline const T* operator[](int i) const { return ins_addr[i]; } __device__ inline const void* operator[](int i) const { return ins_addr[i]; }
PointerToPointer() {} PointerToPointer() {}
PointerToPointer(const phi::GPUContext& ctx, PointerToPointer(const phi::GPUContext& ctx,
...@@ -63,7 +88,7 @@ struct PointerToPointer { ...@@ -63,7 +88,7 @@ struct PointerToPointer {
restored, restored,
in_num * sizeof(T*), in_num * sizeof(T*),
ctx.stream()); ctx.stream());
ins_addr = reinterpret_cast<T**>((*dev_ins_ptr)->ptr()); ins_addr = reinterpret_cast<void**>((*dev_ins_ptr)->ptr());
} }
}; };
...@@ -82,7 +107,7 @@ struct PointerAndColWrapper { ...@@ -82,7 +107,7 @@ struct PointerAndColWrapper {
ins_ptr_wrapper = PointerWrapper<T, Size>(ctx, ins, pre_alloced_host_ptr); ins_ptr_wrapper = PointerWrapper<T, Size>(ctx, ins, pre_alloced_host_ptr);
} }
__device__ inline const T* operator[](int i) const { __device__ inline const void* operator[](int i) const {
return ins_ptr_wrapper[i]; return ins_ptr_wrapper[i];
} }
...@@ -118,7 +143,7 @@ struct PointerToPointerAndCol { ...@@ -118,7 +143,7 @@ struct PointerToPointerAndCol {
PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr); PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr);
} }
__device__ inline const T* operator[](int i) const { __device__ inline const void* operator[](int i) const {
return ins_ptr_wrapper[i]; return ins_ptr_wrapper[i];
} }
...@@ -126,16 +151,31 @@ struct PointerToPointerAndCol { ...@@ -126,16 +151,31 @@ struct PointerToPointerAndCol {
PointerToPointer<T> ins_ptr_wrapper; PointerToPointer<T> ins_ptr_wrapper;
}; };
template <typename T, typename IndexT, typename PointerAndColWrapperT> template <int MovSize>
__global__ void ConcatTensorWithDifferentShape(PointerAndColWrapperT ins_datas, struct alignas(MovSize) Packed {
__device__ Packed() {
// do nothing
}
union {
char buf[MovSize];
};
};
template <typename IndexT, int MovSize, typename PointerAndColWrapperT>
__global__ void ConcatTensorWithDifferentShape(
const PointerAndColWrapperT ins_datas,
int col_size, int col_size,
const IndexT output_rows, const IndexT output_rows,
const IndexT output_cols, const IndexT output_cols,
T* output) { void* output) {
Packed<MovSize>* dst = reinterpret_cast<Packed<MovSize>*>(output);
IndexT curr_segment = 0; IndexT curr_segment = 0;
IndexT curr_offset = ins_datas.col_length[0]; IndexT curr_offset = ins_datas.col_length[0];
CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) { CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) {
IndexT curr_col_offset = ins_datas.col_length[curr_segment + 1]; IndexT curr_col_offset = ins_datas.col_length[curr_segment + 1];
while (curr_col_offset <= tid_x) { while (curr_col_offset <= tid_x) {
curr_offset = curr_col_offset; curr_offset = curr_col_offset;
++curr_segment; ++curr_segment;
...@@ -145,32 +185,335 @@ __global__ void ConcatTensorWithDifferentShape(PointerAndColWrapperT ins_datas, ...@@ -145,32 +185,335 @@ __global__ void ConcatTensorWithDifferentShape(PointerAndColWrapperT ins_datas,
IndexT local_col = tid_x - curr_offset; IndexT local_col = tid_x - curr_offset;
IndexT segment_width = curr_col_offset - curr_offset; IndexT segment_width = curr_col_offset - curr_offset;
const T* input_ptr = ins_datas[curr_segment]; const Packed<MovSize>* input_ptr =
reinterpret_cast<const Packed<MovSize>*>(ins_datas[curr_segment]);
IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y; IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] = for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
dst[tid_y * output_cols + tid_x] =
input_ptr[tid_y * segment_width + local_col]; input_ptr[tid_y * segment_width + local_col];
} }
}
} }
template <typename T, typename IndexT, typename PointerWrapperT> template <typename IndexT, int MovSize, typename PointerWrapperT>
__global__ void ConcatTensorWithSameShape(PointerWrapperT ins_data, __global__ void ConcatTensorWithSameShape(const PointerWrapperT ins_data,
const IndexT fixed_in_col, const IndexT fixed_in_col,
const IndexT out_rows, const IndexT out_rows,
const IndexT out_cols, const IndexT out_cols,
T* output_data) { void* output_data) {
Packed<MovSize>* dst = reinterpret_cast<Packed<MovSize>*>(output_data);
CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) { CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) {
IndexT split = tid_x / fixed_in_col; IndexT split = tid_x / fixed_in_col;
IndexT in_offset = tid_x - split * fixed_in_col; IndexT in_offset = tid_x - split * fixed_in_col;
const T* input_ptr = ins_data[split]; const Packed<MovSize>* input_ptr =
reinterpret_cast<const Packed<MovSize>*>(ins_data[split]);
IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y; IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) { for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
output_data[tid_y * out_cols + tid_x] = dst[tid_y * out_cols + tid_x] =
input_ptr[tid_y * fixed_in_col + in_offset]; input_ptr[tid_y * fixed_in_col + in_offset];
} }
} }
} }
#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
func_impl(4, ##__VA_ARGS__); \
func_impl(8, ##__VA_ARGS__); \
func_impl(16, ##__VA_ARGS__); \
func_impl(32, ##__VA_ARGS__); \
func_impl(64, ##__VA_ARGS__); \
func_impl(128, ##__VA_ARGS__);
template <typename T, typename IndexT, int MovSize>
void DispatchConcatWithDifferentShapeKernelLimitNum(
const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const IndexT inputs_col_num,
const T** inputs_data,
IndexT* inputs_col,
const IndexT out_row,
const IndexT out_col,
phi::DenseTensor* output,
const IndexT in_num,
const IndexT limit_num) {
dim3 block_dims;
dim3 grid_dims;
GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerAndColWrapper<T, IndexT, size_> ptr_col_array( \
ctx, ins, inputs_col_num, inputs_data, inputs_col); \
__VA_ARGS__; \
} break;
switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
IMPL_CONCATE_CUDA_KERNEL_HELPER(
IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
ConcatTensorWithDifferentShape<IndexT, MovSize, decltype(ptr_col_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
ptr_col_array, inputs_col_num, out_row, out_col, output->data()));
default: {
paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
paddle::memory::AllocationPtr dev_col_ptr{nullptr};
PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
ins,
inputs_col_num,
inputs_data,
inputs_col,
&dev_ins_ptr,
&dev_col_ptr);
ConcatTensorWithDifferentShape<IndexT, MovSize, decltype(ptr_col_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
ptr_col_array, inputs_col_num, out_row, out_col, output->data());
}
}
#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
}
template <typename T, typename IndexT>
void DispatchConcatWithDifferentShapeMovsize(
const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const IndexT inputs_col_num,
const T** inputs_data,
IndexT* inputs_col,
const IndexT out_row,
const IndexT out_col,
phi::DenseTensor* output,
const IndexT mov_size,
const IndexT in_num,
const IndexT limit_num) {
if (mov_size == 16) {
DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 16>(
ctx,
ins,
inputs_col_num,
inputs_data,
inputs_col,
out_row,
out_col,
output,
in_num,
limit_num);
} else if (mov_size == 8) {
DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 8>(ctx,
ins,
inputs_col_num,
inputs_data,
inputs_col,
out_row,
out_col,
output,
in_num,
limit_num);
} else if (mov_size == 4) {
DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 4>(ctx,
ins,
inputs_col_num,
inputs_data,
inputs_col,
out_row,
out_col,
output,
in_num,
limit_num);
} else if (mov_size == 2) {
DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 2>(ctx,
ins,
inputs_col_num,
inputs_data,
inputs_col,
out_row,
out_col,
output,
in_num,
limit_num);
} else {
DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 1>(ctx,
ins,
inputs_col_num,
inputs_data,
inputs_col,
out_row,
out_col,
output,
in_num,
limit_num);
}
}
template <typename T, typename IndexT, int MovSize>
void DispatchConcatWithSameShapeKernelLimitNum(
const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const T** inputs_data,
IndexT in_col,
const IndexT out_row,
const IndexT out_col,
phi::DenseTensor* output,
const IndexT in_num,
const IndexT limit_num) {
dim3 block_dims;
dim3 grid_dims;
GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data); \
__VA_ARGS__; \
} break;
switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
IMPL_CONCATE_CUDA_KERNEL_HELPER(
IMPL_CONCAT_CUDA_KERNEL_CASE,
ConcatTensorWithSameShape<IndexT, MovSize, decltype(ptr_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
ptr_array, in_col, out_row, out_col, output->data()));
default: {
paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
ConcatTensorWithSameShape<IndexT, MovSize, decltype(ptr_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
ptr_array, in_col, out_row, out_col, output->data());
}
}
#undef IMPL_CONCAT_CUDA_KERNEL_CASE
}
#undef IMPL_CONCATE_CUDA_KERNEL_HELPER
template <typename T, typename IndexT>
void DispatchConcatWithSameShapeMovsize(
const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const T** inputs_data,
IndexT in_col,
const IndexT out_row,
const IndexT out_col,
phi::DenseTensor* output,
const IndexT mov_size,
const IndexT in_num,
const IndexT limit_num) {
if (mov_size == 16) {
DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 16>(ctx,
ins,
inputs_data,
in_col,
out_row,
out_col,
output,
in_num,
limit_num);
} else if (mov_size == 8) {
DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 8>(ctx,
ins,
inputs_data,
in_col,
out_row,
out_col,
output,
in_num,
limit_num);
} else if (mov_size == 4) {
DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 4>(ctx,
ins,
inputs_data,
in_col,
out_row,
out_col,
output,
in_num,
limit_num);
} else if (mov_size == 2) {
DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 2>(ctx,
ins,
inputs_data,
in_col,
out_row,
out_col,
output,
in_num,
limit_num);
} else {
DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 1>(ctx,
ins,
inputs_data,
in_col,
out_row,
out_col,
output,
in_num,
limit_num);
}
}
template <typename T, typename IndexT>
void DispatchConcatKernel(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const IndexT inputs_col_num,
const T** inputs_data,
IndexT* inputs_col,
const IndexT out_row,
const IndexT out_col,
phi::DenseTensor* output,
const IndexT in_num,
const IndexT limit_num,
bool has_same_shape) {
constexpr IndexT MaxVecSize = 16 / sizeof(T);
bool find_vecsize_flag = false;
IndexT dispatch_vec_size = 1;
for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) {
for (IndexT idx = 0; idx < in_num + 1; idx++) {
// Since input_cols[0] is 0, we need to jump.
const IndexT input_col = inputs_col[idx + 1] - inputs_col[idx];
if (input_col % vec_size == 0) {
if (idx == in_num - 1) {
find_vecsize_flag = true;
}
} else {
break;
}
}
if (find_vecsize_flag) {
dispatch_vec_size = vec_size;
break;
}
}
const int64_t vectorized_out_col = out_col / dispatch_vec_size;
for (IndexT idx = 0; idx < in_num + 1; idx++) {
inputs_col[idx] /= dispatch_vec_size;
}
const IndexT mov_size = sizeof(T) * dispatch_vec_size;
if (has_same_shape) {
// In same shape situation, each input's col are equal, so here we select to
// use inputs_col[1].
DispatchConcatWithSameShapeMovsize<T, IndexT>(ctx,
ins,
inputs_data,
inputs_col[1],
out_row,
vectorized_out_col,
output,
mov_size,
in_num,
limit_num);
} else {
DispatchConcatWithDifferentShapeMovsize<T, IndexT>(ctx,
ins,
inputs_col_num,
inputs_data,
inputs_col,
out_row,
vectorized_out_col,
output,
mov_size,
in_num,
limit_num);
}
}
template <typename T> template <typename T>
__global__ void SplitKernel_(const T* input_data, __global__ void SplitKernel_(const T* input_data,
const int64_t in_row, const int64_t in_row,
...@@ -273,30 +616,6 @@ __global__ void SplitKernel_(const T* input_data, ...@@ -273,30 +616,6 @@ __global__ void SplitKernel_(const T* input_data,
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data); SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
} }
static inline void GetBlockDims(const phi::GPUContext& context,
int64_t num_rows,
int64_t num_cols,
dim3* block_dims,
dim3* grid_dims) {
// Set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (num_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((num_cols + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
*block_dims = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows = std::min(max_blocks / grid_cols,
std::max(num_rows / block_rows, (int64_t)1));
*grid_dims = dim3(grid_cols, grid_rows, 1);
}
/* /*
* All tensors' dimension should be the same and the values of * All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension. * each dimension must be the same, except the axis dimension.
...@@ -340,79 +659,19 @@ void ConcatFunctorWithIndexType(const phi::GPUContext& ctx, ...@@ -340,79 +659,19 @@ void ConcatFunctorWithIndexType(const phi::GPUContext& ctx,
out_col += t_cols; out_col += t_cols;
inputs_col[i + 1] = out_col; inputs_col[i + 1] = out_col;
} }
dim3 block_dims;
dim3 grid_dims;
GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
IndexT limit_num = has_same_shape ? in_num : inputs_col_num; IndexT limit_num = has_same_shape ? in_num : inputs_col_num;
#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \ DispatchConcatKernel<T, IndexT>(ctx,
func_impl(4, ##__VA_ARGS__); \
func_impl(8, ##__VA_ARGS__); \
func_impl(16, ##__VA_ARGS__); \
func_impl(32, ##__VA_ARGS__); \
func_impl(64, ##__VA_ARGS__); \
func_impl(128, ##__VA_ARGS__);
if (has_same_shape) {
#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data); \
__VA_ARGS__; \
} break;
switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
IMPL_CONCATE_CUDA_KERNEL_HELPER(
IMPL_CONCAT_CUDA_KERNEL_CASE,
ConcatTensorWithSameShape<T, IndexT, decltype(ptr_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
ptr_array, in_col, out_row, out_col, output->data<T>()));
default: {
paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
ConcatTensorWithSameShape<T, IndexT, decltype(ptr_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
ptr_array, in_col, out_row, out_col, output->data<T>());
}
}
#undef IMPL_CONCAT_CUDA_KERNEL_CASE
} else {
#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerAndColWrapper<T, IndexT, size_> ptr_col_array( \
ctx, ins, inputs_col_num, inputs_data, inputs_col); \
__VA_ARGS__; \
} break;
switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
IMPL_CONCATE_CUDA_KERNEL_HELPER(
IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
ConcatTensorWithDifferentShape<T, IndexT, decltype(ptr_col_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(ptr_col_array,
inputs_col_num,
out_row,
out_col,
output->data<T>()));
default: {
paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
paddle::memory::AllocationPtr dev_col_ptr{nullptr};
PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
ins, ins,
inputs_col_num, inputs_col_num,
inputs_data, inputs_data,
inputs_col, inputs_col,
&dev_ins_ptr,
&dev_col_ptr);
ConcatTensorWithDifferentShape<T, IndexT, decltype(ptr_col_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(ptr_col_array,
inputs_col_num,
out_row, out_row,
out_col, out_col,
output->data<T>()); output,
} in_num,
} limit_num,
#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE has_same_shape);
}
#undef IMPL_CONCATE_CUDA_KERNEL_HELPER
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// Prevent pinned memory from being covered and release the memory after // Prevent pinned memory from being covered and release the memory after
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册