/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"

#include "paddle/fluid/memory/malloc.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"

namespace phi {
namespace funcs {

static inline void GetBlockDims(const phi::GPUContext& context,
                                int64_t num_rows,
                                int64_t num_cols,
                                dim3* block_dims,
                                dim3* grid_dims) {
  // Set the thread block and grid according to CurrentDeviceId
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
    block_cols = ((num_cols + 31) >> 5) << 5;
  }
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  constexpr int waves = 1;
  int max_threads = context.GetMaxPhysicalThreadCount() * waves;
  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  int grid_cols =
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows = std::min(max_blocks / grid_cols,
                           std::max(num_rows / block_rows, (int64_t)1));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x)
#endif

template <typename T, int Size>
struct PointerWrapper {
 public:
  const void* ins_addr[Size];
  __device__ inline const void* operator[](int i) const { return ins_addr[i]; }

  PointerWrapper() {}
  PointerWrapper(const phi::GPUContext& ctx,
                 const std::vector<phi::DenseTensor>& ins,
                 const T** pre_alloced_host_ptr) {
    SetInputAddr(ins);
  }

 protected:
  void SetInputAddr(const std::vector<phi::DenseTensor>& ins) {
    for (auto i = 0; i < ins.size(); ++i) {
      ins_addr[i] = ins[i].data();
    }
  }
};

template <typename T, int Size>
struct PADDLE_ALIGN(256) AlignedPointerWrapper
    : public PointerWrapper<T, Size> {
 public:
  AlignedPointerWrapper() {}
  AlignedPointerWrapper(const phi::GPUContext& ctx,
                        const std::vector<phi::DenseTensor>& ins,
                        const T** pre_alloced_host_ptr) {
    this->SetInputAddr(ins);
  }
};

template <typename T>
struct PointerToPointer {
 public:
  void** ins_addr{nullptr};
  __device__ inline const void* operator[](int i) const { return ins_addr[i]; }

  PointerToPointer() {}
  PointerToPointer(const phi::GPUContext& ctx,
                   const std::vector<phi::DenseTensor>& ins,
                   const T** pre_alloced_host_ptr,
                   paddle::memory::AllocationPtr* dev_ins_ptr) {
    auto in_num = ins.size();
    for (auto i = 0; i < in_num; ++i) {
      pre_alloced_host_ptr[i] = ins[i].data<T>();
    }
    *dev_ins_ptr = paddle::memory::Alloc(
        ctx.GetPlace(),
        in_num * sizeof(T*),
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        pre_alloced_host_ptr, in_num);
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_ins_ptr)->ptr(),
                         phi::CPUPlace(),
                         restored,
                         in_num * sizeof(T*),
                         ctx.stream());
    ins_addr = reinterpret_cast<void**>((*dev_ins_ptr)->ptr());
  }
};

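// The *AndCol wrappers below bundle the column prefix sums (inputs_col) with
// the input pointers, so a kernel receives both through a single argument:
// PointerAndColWrapper embeds them in the kernel parameter itself, while
// PointerToPointerAndCol stages them in device memory for large input counts.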
template <typename T, typename IndexT, int Size>
struct PADDLE_ALIGN(256) PointerAndColWrapper {
 public:
  IndexT col_length[Size];
  PointerAndColWrapper(const phi::GPUContext& ctx,
                       const std::vector<phi::DenseTensor>& ins,
                       const IndexT& inputs_col_num,
                       const T** pre_alloced_host_ptr,
                       IndexT* inputs_col) {
    for (auto i = 0; i < inputs_col_num; ++i) {
      col_length[i] = inputs_col[i];
    }
    ins_ptr_wrapper = PointerWrapper<T, Size>(ctx, ins, pre_alloced_host_ptr);
  }

  __device__ inline const void* operator[](int i) const {
    return ins_ptr_wrapper[i];
  }

 private:
  PointerWrapper<T, Size> ins_ptr_wrapper;
};

template <typename T, typename IndexT>
struct PointerToPointerAndCol {
 public:
  IndexT* col_length{nullptr};
  PointerToPointerAndCol(const phi::GPUContext& ctx,
                         const std::vector<phi::DenseTensor>& ins,
                         const IndexT inputs_col_num,
                         const T** pre_alloced_host_ptr,
                         IndexT* inputs_col,
                         paddle::memory::AllocationPtr* dev_ins_ptr,
                         paddle::memory::AllocationPtr* dev_col_ptr) {
    *dev_col_ptr = paddle::memory::Alloc(
        ctx.GetPlace(),
        inputs_col_num * sizeof(IndexT),
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        inputs_col, inputs_col_num);
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_col_ptr)->ptr(),
                         phi::CPUPlace(),
                         restored,
                         inputs_col_num * sizeof(IndexT),
                         ctx.stream());
    col_length = static_cast<IndexT*>((*dev_col_ptr)->ptr());
    ins_ptr_wrapper =
        PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr);
  }

  __device__ inline const void* operator[](int i) const {
    return ins_ptr_wrapper[i];
  }

 private:
  PointerToPointer<T> ins_ptr_wrapper;
};

#undef PADDLE_ALIGN

template <int MovSize>
struct alignas(MovSize) Packed {
  __device__ Packed() {
    // do nothing
  }
  union {
    char buf[MovSize];
  };
};

template <typename IndexT, int MovSize, typename PointerAndColWrapperT>
__global__ void ConcatTensorWithDifferentShape(
    const PointerAndColWrapperT ins_datas,
    int col_size,
    const IndexT output_rows,
    const IndexT output_cols,
    void* output) {
  Packed<MovSize>* dst = reinterpret_cast<Packed<MovSize>*>(output);

  IndexT curr_segment = 0;
  IndexT curr_offset = ins_datas.col_length[0];

  CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) {
    IndexT curr_col_offset = ins_datas.col_length[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = ins_datas.col_length[curr_segment + 1];
    }

    IndexT local_col = tid_x - curr_offset;
    IndexT segment_width = curr_col_offset - curr_offset;

    const Packed<MovSize>* input_ptr =
        reinterpret_cast<const Packed<MovSize>*>(ins_datas[curr_segment]);

    IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
      dst[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
    }
  }
}

template <typename IndexT, int MovSize, typename PointerWrapperT>
__global__ void ConcatTensorWithSameShape(const PointerWrapperT ins_data,
                                          const IndexT fixed_in_col,
                                          const IndexT out_rows,
                                          const IndexT out_cols,
                                          void* output_data) {
  Packed<MovSize>* dst = reinterpret_cast<Packed<MovSize>*>(output_data);
  CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) {
    IndexT split = tid_x / fixed_in_col;
    IndexT in_offset = tid_x - split * fixed_in_col;
    const Packed<MovSize>* input_ptr =
        reinterpret_cast<const Packed<MovSize>*>(ins_data[split]);
    IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      dst[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

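// The dispatch helpers below specialize the pointer wrappers on the input
// count rounded up to a power of two (4 through 128), so small pointer tables
// travel by value inside the kernel arguments; anything larger falls back to
// the PointerToPointer* wrappers, which keep the table in device memory.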
#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
  func_impl(4, ##__VA_ARGS__);                          \
  func_impl(8, ##__VA_ARGS__);                          \
  func_impl(16, ##__VA_ARGS__);                         \
  func_impl(32, ##__VA_ARGS__);                         \
  func_impl(64, ##__VA_ARGS__);                         \
  func_impl(128, ##__VA_ARGS__);

template <typename T, typename IndexT, int MovSize>
void DispatchConcatWithDifferentShapeKernelLimitNum(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const IndexT inputs_col_num,
    const T** inputs_data,
    IndexT* inputs_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT in_num,
    const IndexT limit_num) {
  dim3 block_dims;
  dim3 grid_dims;
  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);

#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...)    \
  case size_: {                                             \
    PointerAndColWrapper<T, IndexT, size_> ptr_col_array(   \
        ctx, ins, inputs_col_num, inputs_data, inputs_col); \
    __VA_ARGS__;                                            \
  } break;

  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
    IMPL_CONCATE_CUDA_KERNEL_HELPER(
        IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
        ConcatTensorWithDifferentShape<IndexT, MovSize, decltype(ptr_col_array)>
        <<<grid_dims, block_dims, 0, ctx.stream()>>>(
            ptr_col_array, inputs_col_num, out_row, out_col, output->data()));
    default: {
      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
      paddle::memory::AllocationPtr dev_col_ptr{nullptr};
      PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
                                                      ins,
                                                      inputs_col_num,
                                                      inputs_data,
                                                      inputs_col,
                                                      &dev_ins_ptr,
                                                      &dev_col_ptr);
      ConcatTensorWithDifferentShape<IndexT, MovSize, decltype(ptr_col_array)>
          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
              ptr_col_array, inputs_col_num, out_row, out_col, output->data());
    }
  }
#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
}

template <typename T, typename IndexT>
void DispatchConcatWithDifferentShapeMovsize(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const IndexT inputs_col_num,
    const T** inputs_data,
    IndexT* inputs_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT mov_size,
    const IndexT in_num,
    const IndexT limit_num) {
  if (mov_size == 16) {
    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 16>(
        ctx, ins, inputs_col_num, inputs_data, inputs_col, out_row, out_col,
        output, in_num, limit_num);
  } else if (mov_size == 8) {
    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 8>(
        ctx, ins, inputs_col_num, inputs_data, inputs_col, out_row, out_col,
        output, in_num, limit_num);
  } else if (mov_size == 4) {
    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 4>(
        ctx, ins, inputs_col_num, inputs_data, inputs_col, out_row, out_col,
        output, in_num, limit_num);
  } else if (mov_size == 2) {
    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 2>(
        ctx, ins, inputs_col_num, inputs_data, inputs_col, out_row, out_col,
        output, in_num, limit_num);
  } else {
    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 1>(
        ctx, ins, inputs_col_num, inputs_data, inputs_col, out_row, out_col,
        output, in_num, limit_num);
  }
}

template <typename T, typename IndexT, int MovSize>
void DispatchConcatWithSameShapeKernelLimitNum(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const T** inputs_data,
    IndexT in_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT in_num,
    const IndexT limit_num) {
  dim3 block_dims;
  dim3 grid_dims;
  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);

#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)                       \
  case size_: {                                                        \
    AlignedPointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data);  \
    __VA_ARGS__;                                                       \
  } break;

  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
    IMPL_CONCATE_CUDA_KERNEL_HELPER(
        IMPL_CONCAT_CUDA_KERNEL_CASE,
        ConcatTensorWithSameShape<IndexT, MovSize, decltype(ptr_array)>
        <<<grid_dims, block_dims, 0, ctx.stream()>>>(
            ptr_array, in_col, out_row, out_col, output->data()));
    default: {
      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
      PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
      ConcatTensorWithSameShape<IndexT, MovSize, decltype(ptr_array)>
          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
              ptr_array, in_col, out_row, out_col, output->data());
    }
  }
#undef IMPL_CONCAT_CUDA_KERNEL_CASE
}

#undef IMPL_CONCATE_CUDA_KERNEL_HELPER

template <typename T, typename IndexT>
void DispatchConcatWithSameShapeMovsize(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const T** inputs_data,
    IndexT in_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT mov_size,
    const IndexT in_num,
    const IndexT limit_num) {
  if (mov_size == 16) {
    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 16>(
        ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
        limit_num);
  } else if (mov_size == 8) {
    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 8>(
        ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
        limit_num);
  } else if (mov_size == 4) {
    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 4>(
        ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
        limit_num);
  } else if (mov_size == 2) {
    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 2>(
        ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
        limit_num);
  } else {
    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 1>(
        ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
        limit_num);
  }
}

template <typename T, typename IndexT>
void DispatchConcatKernel(const phi::GPUContext& ctx,
                          const std::vector<phi::DenseTensor>& ins,
                          const IndexT inputs_col_num,
                          const T** inputs_data,
                          IndexT* inputs_col,
                          const IndexT out_row,
                          const IndexT out_col,
                          phi::DenseTensor* output,
                          const IndexT in_num,
                          const IndexT limit_num,
                          bool has_same_shape) {
  constexpr IndexT MaxVecSize = 16 / sizeof(T);
  bool find_vecsize_flag = false;
  IndexT dispatch_vec_size = 1;

  auto output_data = reinterpret_cast<uintptr_t>(output->data());
  for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) {
    const IndexT mov_size = vec_size * sizeof(T);
    for (IndexT idx = 1; idx < in_num + 1; idx++) {
      auto input_data = reinterpret_cast<uintptr_t>(inputs_data[idx - 1]);
      // inputs_col[0] is 0, so start from idx = 1 and take the difference of
      // adjacent prefix sums as this input's column width.
      const IndexT input_col = inputs_col[idx] - inputs_col[idx - 1];
      if (input_col % vec_size == 0 && output_data % mov_size == 0 &&
          input_data % mov_size == 0) {
        if (idx == in_num) {
          find_vecsize_flag = true;
        }
      } else {
        break;
      }
    }
    if (find_vecsize_flag) {
      dispatch_vec_size = vec_size;
      break;
    }
  }

  const int64_t vectorized_out_col = out_col / dispatch_vec_size;
  for (IndexT idx = 0; idx < in_num + 1; idx++) {
    inputs_col[idx] /= dispatch_vec_size;
  }
  const IndexT mov_size = sizeof(T) * dispatch_vec_size;
  if (has_same_shape) {
    // In the same-shape case every input has the same column width, so
    // inputs_col[1] (the width of the first input) serves for all of them.
    DispatchConcatWithSameShapeMovsize<T, IndexT>(ctx,
                                                  ins,
                                                  inputs_data,
                                                  inputs_col[1],
                                                  out_row,
                                                  vectorized_out_col,
                                                  output,
                                                  mov_size,
                                                  in_num,
                                                  limit_num);
  } else {
    DispatchConcatWithDifferentShapeMovsize<T, IndexT>(ctx,
                                                       ins,
                                                       inputs_col_num,
                                                       inputs_data,
                                                       inputs_col,
                                                       out_row,
                                                       vectorized_out_col,
                                                       output,
                                                       mov_size,
                                                       in_num,
                                                       limit_num);
  }
}

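// Worked example (illustrative): concatenating two [4, 4] float tensors along
// axis = 1 gives in_row = 4, in_col = 4, out_col = 8 and inputs_col =
// {0, 4, 8}. has_same_shape is true, and if the input and output pointers are
// 16-byte aligned, DispatchConcatKernel selects vec_size = 16 / sizeof(float)
// = 4, so each thread moves one 16-byte Packed element per row it handles.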
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T, typename IndexT>
void ConcatFunctorWithIndexType(const phi::GPUContext& ctx,
                                const std::vector<phi::DenseTensor>& ins,
                                int axis,
                                phi::DenseTensor* output) {
  // TODO(zcd): Add input data validity checking
  IndexT in_num = ins.size();
  IndexT in_row = 1;
  auto dim_0 = ins[0].dims();
  for (int i = 0; i < axis; ++i) {
    in_row *= dim_0[i];
  }
  IndexT in_col = ins[0].numel() / in_row;
  IndexT out_row = in_row, out_col = 0;

  IndexT inputs_col_num = in_num + 1;
  std::vector<const T*> inputs_data_vec(in_num, nullptr);
  std::vector<IndexT> inputs_col_vec(inputs_col_num, 0);
  const T** inputs_data = inputs_data_vec.data();
  IndexT* inputs_col = inputs_col_vec.data();
#ifdef PADDLE_WITH_HIP
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  paddle::memory::AllocationPtr data_alloc = paddle::memory::Alloc(
      paddle::platform::CUDAPinnedPlace(), in_num * sizeof(T*));
  inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
  paddle::memory::AllocationPtr col_alloc = paddle::memory::Alloc(
      paddle::platform::CUDAPinnedPlace(), inputs_col_num * sizeof(IndexT));
  inputs_col = reinterpret_cast<IndexT*>(col_alloc->ptr());
#endif

  bool has_same_shape = true;
  for (int i = 0; i < in_num; ++i) {
    IndexT t_cols = ins[i].numel() / in_row;
    if (has_same_shape) {
      has_same_shape &= (t_cols == in_col);
    }
    out_col += t_cols;
    inputs_col[i + 1] = out_col;
  }
  IndexT limit_num = has_same_shape ? in_num : inputs_col_num;

  DispatchConcatKernel<T, IndexT>(ctx,
                                  ins,
                                  inputs_col_num,
                                  inputs_data,
                                  inputs_col,
                                  out_row,
                                  out_col,
                                  output,
                                  in_num,
                                  limit_num,
                                  has_same_shape);

#ifdef PADDLE_WITH_HIP
  // Prevent the pinned memory from being overwritten: release it only after
  // the kernel launched on this stream has executed (pinned memory is
  // allocated again on the next call).
  auto* data_alloc_released = data_alloc.release();
  auto* col_alloc_released = col_alloc.release();
  ctx.AddStreamCallback([data_alloc_released, col_alloc_released] {
    VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
    VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
    paddle::memory::allocation::Allocator::AllocationDeleter(
        data_alloc_released);
    paddle::memory::allocation::Allocator::AllocationDeleter(
        col_alloc_released);
  });
#endif
}

template <typename T>
struct ConcatFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& context,
                  const std::vector<phi::DenseTensor>& input,
                  int axis,
                  phi::DenseTensor* output) {
    if (output->numel() < std::numeric_limits<int32_t>::max()) {
      ConcatFunctorWithIndexType<T, int32_t>(context, input, axis, output);
    } else {
      ConcatFunctorWithIndexType<T, int64_t>(context, input, axis, output);
    }
  }
};

template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
struct PointerAndColArray
    : public funcs::PointerArraySetter<phi::GPUContext, T, Size> {
 public:
  funcs::ValueArray<IndexT, Size> val_array;

  PointerAndColArray() {}
  PointerAndColArray(const phi::GPUContext& ctx,
                     const int out_col_num,
                     IndexT* out_cols,
                     std::vector<phi::DenseTensor*>* t,
                     T** pre_alloc_host_buf = nullptr)
      : funcs::PointerArraySetter<phi::GPUContext, T, Size>(
            ctx,
            t,
            /*need_alloc=*/false,
            /*use_cuda_graph=*/true,
            pre_alloc_host_buf) {
    IndexT* dev_ptr = nullptr;
    if (Size == SegmentedArraySize::kVariableLength) {
      size_t num_bytes = out_col_num * sizeof(IndexT);
      dev_ptr = reinterpret_cast<IndexT*>(this->AllocAndCopy(
          ctx, reinterpret_cast<void*>(out_cols), num_bytes, true));
      val_array.Set(dev_ptr, out_col_num);
    } else {
      val_array.Set(out_cols, out_col_num);
    }
  }
};

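// The split kernels below mirror the concat kernels above: instead of
// gathering rows from several inputs into one output, each thread copies one
// element of an input row into the output tensor selected by its column
// segment. Output pointers that are nullptr are skipped.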
template <typename T, typename IndexT, typename DataArrayT>
__global__ void SplitTensorWithSameShape(const T* input_data,
                                         const IndexT out_row,
                                         const IndexT cumulative_col,
                                         const IndexT fixed_out_col,
                                         DataArrayT data_array) {
  CUDA_KERNEL_LOOP_TYPE(tid_x, cumulative_col, IndexT) {
    IndexT split = tid_x / fixed_out_col;
    IndexT in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = data_array.data[split];
    if (output_ptr != nullptr) {
      IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < out_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * fixed_out_col + in_offset] =
            input_data[tid_y * cumulative_col + tid_x];
    }
  }
}

template <typename T, typename IndexT, typename DataArrayT, typename ValArrayT>
__global__ void SplitTensorWithDifferentShape(const T* input_data,
                                              const IndexT out_row,
                                              const IndexT cumulative_col,
                                              DataArrayT data_array,
                                              ValArrayT col_array) {
  IndexT curr_segment = 0;
  IndexT curr_offset = col_array.data[0];
  CUDA_KERNEL_LOOP_TYPE(tid_x, cumulative_col, IndexT) {
    IndexT curr_col_offset = col_array.data[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = col_array.data[curr_segment + 1];
    }

    IndexT local_col = tid_x - curr_offset;
    IndexT segment_width = curr_col_offset - curr_offset;
    T* output_ptr = data_array.data[curr_segment];
    if (output_ptr != nullptr) {
      IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < out_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * cumulative_col + tid_x];
    }
  }
}

template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
void SplitFunctionDispatchWithSameShape(const phi::GPUContext& ctx,
                                        const IndexT out_col,
                                        const IndexT out_row,
                                        const IndexT cumulative_col,
                                        const T* input_data,
                                        std::vector<phi::DenseTensor*>* outs,
                                        T** pre_alloc_host_buf) {
  dim3 grid_dims;
  dim3 block_dims;
  GetBlockDims(ctx, out_row, cumulative_col, &block_dims, &grid_dims);

  funcs::PointerArraySetter<phi::GPUContext, T, Size> setter(
      ctx,
      outs,
      /*need_alloc=*/false,
      /*use_cuda_graph=*/true,
      pre_alloc_host_buf);
  SplitTensorWithSameShape<T, IndexT, decltype(setter.array)>
      <<<grid_dims, block_dims, 0, ctx.stream()>>>(
          input_data, out_row, cumulative_col, out_col, setter.array);
}

template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
void SplitFunctionDispatchWithDifferentShape(
    const phi::GPUContext& ctx,
    const int out_col_num,
    const IndexT out_row,
    const IndexT cumulative_col,
    const T* input_data,
    std::vector<phi::DenseTensor*>* outs,
    IndexT* output_cols,
    T** pre_alloc_host_buf) {
  dim3 grid_dims;
  dim3 block_dims;
  GetBlockDims(ctx, out_row, cumulative_col, &block_dims, &grid_dims);
  PointerAndColArray<T, IndexT, Size> setter(
      ctx, out_col_num, output_cols, outs, pre_alloc_host_buf);

  SplitTensorWithDifferentShape<T,
                                IndexT,
                                decltype(setter.array),
                                decltype(setter.val_array)>
      <<<grid_dims, block_dims, 0, ctx.stream()>>>(
          input_data, out_row, cumulative_col, setter.array, setter.val_array);
}

template <typename T, typename IndexT>
void SplitFunctorDispatchWithIndexType(
    const phi::GPUContext& ctx,
    int axis,
    const phi::DenseTensor& input,
    const std::vector<const phi::DenseTensor*>& ref_ins,
    std::vector<phi::DenseTensor*>* outs) {
  // TODO(zcd): Add input data validity checking
  int out_num = outs->size();
  IndexT out_row = 1;
  auto ref_dim = ref_ins[0]->dims();
  for (int i = 0; i < axis; ++i) {
    out_row *= ref_dim[i];
  }
  IndexT out_col = ref_ins[0]->numel() / out_row;
  IndexT cumulative_col = 0;
  bool has_same_shape = true;

  int out_cols_num = out_num + 1;
  std::vector<IndexT> outputs_cols_vec(out_cols_num, 0);
  IndexT* outs_cols = outputs_cols_vec.data();
  T** outs_data = nullptr;

// There are some differences between the HIP runtime and the NV runtime.
// On NVIDIA, host-to-device copies of pageable memory blocks smaller than
// 64 KB are performed asynchronously automatically. On HIP, however, only
// pinned memory can be copied asynchronously.
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device:
// "Memory copies from host to device of a memory block of 64 KB or less"
#ifdef PADDLE_WITH_HIP
  paddle::memory::AllocationPtr data_alloc, cols_alloc;
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
                                     out_num * sizeof(T*));
  outs_data = reinterpret_cast<T**>(data_alloc->ptr());
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  cols_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
                                     (out_cols_num) * sizeof(IndexT));
  outs_cols = reinterpret_cast<IndexT*>(cols_alloc->ptr());
#endif

  outs_cols[0] = 0;
  for (int i = 0; i < out_num; ++i) {
    IndexT t_col = ref_ins.at(i)->numel() / out_row;
    if (has_same_shape) {
      has_same_shape &= (t_col == out_col);
    }
    cumulative_col += t_col;
    outs_cols[i + 1] = cumulative_col;
  }
  int limit_num = has_same_shape ? out_num : out_cols_num;
  if (has_same_shape) {
    switch (funcs::CalcArraySize(limit_num)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          SplitFunctionDispatchWithSameShape<T, IndexT, kArraySize>(
              ctx,
              out_col,
              out_row,
              cumulative_col,
              input.data<T>(),
              outs,
              outs_data));
    }
  } else {
    switch (funcs::CalcArraySize(limit_num)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          SplitFunctionDispatchWithDifferentShape<T, IndexT, kArraySize>(
              ctx,
              out_cols_num,
              out_row,
              cumulative_col,
              input.data<T>(),
              outs,
              outs_cols,
              outs_data));
    }
  }

#ifdef PADDLE_WITH_HIP
  // Prevent the pinned memory from being overwritten: release it only after
  // the kernel launched on this stream has executed (pinned memory is
  // allocated again on the next call).
  auto* data_alloc_released = data_alloc.release();
  auto* cols_alloc_released = cols_alloc.release();
  ctx.AddStreamCallback([data_alloc_released, cols_alloc_released] {
    paddle::memory::allocation::Allocator::AllocationDeleter(
        data_alloc_released);
    paddle::memory::allocation::Allocator::AllocationDeleter(
        cols_alloc_released);
  });
#endif
}

template <typename T>
class SplitFunctor<phi::GPUContext, T> {
 public:
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const std::vector<const phi::DenseTensor*>& ref_inputs,
                  int axis,
                  std::vector<phi::DenseTensor*>* outputs) {
    int64_t numel = input.numel();
    // NOTE(zhiqiu): splitting a tensor of shape [0,3,4] at axis=1 results in
    // 3 tensors of shape [0,1,4].
    if (input.numel() == 0) {
      return;
    }

    if (numel < std::numeric_limits<int32_t>::max()) {
      SplitFunctorDispatchWithIndexType<T, int32_t>(
          context, axis, input, ref_inputs, outputs);
    } else {
      SplitFunctorDispatchWithIndexType<T, int64_t>(
          context, axis, input, ref_inputs, outputs);
    }
  }
};

#define DEFINE_FUNCTOR(type)                           \
  template class ConcatFunctor<phi::GPUContext, type>; \
  template class SplitFunctor<phi::GPUContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace funcs
}  // namespace phi