diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu index fa04dc8d3ca7b301d892eb61dab813722ce2a106..153e6117227bf9fd273f83f8e64f859a54380053 100644 --- a/paddle/fluid/operators/math/concat_and_split.cu +++ b/paddle/fluid/operators/math/concat_and_split.cu @@ -77,6 +77,33 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, output_data); } +template +__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, + const T* input_addr2, const int fixed_in_col, + const int out_rows, const int out_cols, + T* output_data) { + const T* inputs_data[3]; + inputs_data[0] = input_addr0; + inputs_data[1] = input_addr1; + inputs_data[2] = input_addr2; + ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols, + output_data); +} + +template +__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, + const T* input_addr2, const T* input_addr3, + const int fixed_in_col, const int out_rows, + const int out_cols, T* output_data) { + const T* inputs_data[4]; + inputs_data[0] = input_addr0; + inputs_data[1] = input_addr1; + inputs_data[2] = input_addr2; + inputs_data[3] = input_addr3; + ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols, + output_data); +} + template __global__ void ConcatKernel(const T** inputs_data, const int in_num, const int fixed_in_col, const int out_rows, @@ -147,6 +174,31 @@ __global__ void SplitKernel(const T* input_data, const int in_row, SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); } +template +__global__ void SplitKernel(const T* input_data, const int in_row, + const int in_col, const int fixed_out_col, + T* outputs_addr0, T* outputs_addr1, + T* outputs_addr2) { + T* outputs_data[3]; + outputs_data[0] = outputs_addr0; + outputs_data[1] = outputs_addr1; + outputs_data[2] = outputs_addr2; + SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); +} + +template +__global__ void SplitKernel(const T* input_data, const int in_row, + const int in_col, const int fixed_out_col, + T* outputs_addr0, T* outputs_addr1, + T* outputs_addr2, T* outputs_addr3) { + T* outputs_data[4]; + outputs_data[0] = outputs_addr0; + outputs_data[1] = outputs_addr1; + outputs_data[2] = outputs_addr2; + outputs_data[3] = outputs_addr3; + SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); +} + static inline void GetBlockDims(const platform::CUDADeviceContext& context, int num_rows, int num_cols, dim3* block_dims, dim3* grid_dims) { @@ -210,7 +262,7 @@ class ConcatFunctor { memory::allocation::AllocationPtr tmp_dev_ins_data; const T** dev_ins_data = nullptr; - if (!has_same_shape || (in_num != 2)) { + if (!has_same_shape || in_num < 2 || in_num > 4) { tmp_dev_ins_data = platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate( inputs_data.size() * sizeof(T*)); @@ -226,6 +278,14 @@ class ConcatFunctor { ConcatKernel<<>>( inputs_data[0], inputs_data[1], in_col, out_row, out_col, output->data()); + } else if (in_num == 3) { + ConcatKernel<<>>( + inputs_data[0], inputs_data[1], inputs_data[2], in_col, out_row, + out_col, output->data()); + } else if (in_num == 4) { + ConcatKernel<<>>( + inputs_data[0], inputs_data[1], inputs_data[2], inputs_data[3], + in_col, out_row, out_col, output->data()); } else { ConcatKernel<<>>( dev_ins_data, in_num, in_col, out_row, out_col, output->data()); @@ -294,7 +354,7 @@ class SplitFunctor { memory::allocation::AllocationPtr tmp_dev_outs_data; T** dev_out_gpu_data = nullptr; - if (!has_same_shape || (o_num != 2)) { + if (!has_same_shape || o_num < 2 || o_num > 4) { tmp_dev_outs_data = platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate( outputs_data.size() * sizeof(T*)); @@ -310,6 +370,14 @@ class SplitFunctor { SplitKernel<<>>( input.data(), in_row, in_col, out0_col, outputs_data[0], outputs_data[1]); + } else if (o_num == 3) { + SplitKernel<<>>( + input.data(), in_row, in_col, out0_col, outputs_data[0], + outputs_data[1], outputs_data[2]); + } else if (o_num == 4) { + SplitKernel<<>>( + input.data(), in_row, in_col, out0_col, outputs_data[0], + outputs_data[1], outputs_data[2], outputs_data[3]); } else { SplitKernel<<>>( input.data(), in_row, in_col, out0_col, dev_out_gpu_data);