diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cu b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
index fa663528eb0158fa75f8d3e86f83115621491816..dc9150e4f2c565b8d6b03e4120bd9ba19d874f3c 100644
--- a/paddle/phi/kernels/funcs/concat_and_split_functor.cu
+++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
@@ -14,8 +14,8 @@ limitations under the License. */

 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/fluid/memory/malloc.h"
-#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/kernels/funcs/segmented_array.h"

 namespace phi {
 namespace funcs {
@@ -45,6 +45,12 @@ static inline void GetBlockDims(const phi::GPUContext& context,
   *grid_dims = dim3(grid_cols, grid_rows, 1);
 }

+#if !defined(_WIN32)
+#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
+#else
+#define PADDLE_ALIGN(x)
+#endif
+
 template <typename T, int Size>
 struct PointerWrapper {
  public:
@@ -55,12 +61,29 @@ struct PointerWrapper {
   PointerWrapper(const phi::GPUContext& ctx,
                  const std::vector<phi::DenseTensor>& ins,
                  const T** pre_alloced_host_ptr) {
+    SetInputAddr(ins);
+  }
+
+ protected:
+  void SetInputAddr(const std::vector<phi::DenseTensor>& ins) {
     for (auto i = 0; i < ins.size(); ++i) {
       ins_addr[i] = ins[i].data();
     }
   }
 };

+template <typename T, int Size>
+struct PADDLE_ALIGN(256) AlignedPointerWrapper
+    : public PointerWrapper<T, Size> {
+ public:
+  AlignedPointerWrapper() {}
+  AlignedPointerWrapper(const phi::GPUContext& ctx,
+                        const std::vector<phi::DenseTensor>& ins,
+                        const T** pre_alloced_host_ptr) {
+    this->SetInputAddr(ins);
+  }
+};
+
 template <typename T>
 struct PointerToPointer {
  public:
@@ -93,7 +116,7 @@ struct PointerToPointer {
 };

 template <typename T, typename IndexT, int Size>
-struct PointerAndColWrapper {
+struct PADDLE_ALIGN(256) PointerAndColWrapper {
  public:
   IndexT col_length[Size];
   PointerAndColWrapper(const phi::GPUContext& ctx,
@@ -151,6 +174,8 @@ struct PointerToPointerAndCol {
   PointerToPointer<T> ins_ptr_wrapper;
 };

+#undef PADDLE_ALIGN
+
 template <typename T, int MovSize>
 struct alignas(MovSize) Packed {
   __device__ Packed() {
@@ -358,10 +383,10 @@ void DispatchConcatWithSameShapeKernelLimitNum(
   dim3 grid_dims;
   GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);

-#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)                 \
-  case size_: {                                                 \
-    PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data);  \
-    __VA_ARGS__;                                                \
+#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)                        \
+  case size_: {                                                         \
+    AlignedPointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data);   \
+    __VA_ARGS__;                                                        \
   } break;

   switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
@@ -519,108 +544,6 @@ void DispatchConcatKernel(const phi::GPUContext& ctx,
   }
 }

-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t* out_cols,
-                             int out_cols_size,
-                             T** outputs_data) {
-  int64_t curr_segment = 0;
-  int64_t curr_offset = out_cols[0];
-  CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
-    int64_t curr_col_offset = out_cols[curr_segment + 1];
-    while (curr_col_offset <= tid_x) {
-      curr_offset = curr_col_offset;
-      ++curr_segment;
-      curr_col_offset = out_cols[curr_segment + 1];
-    }
-
-    int64_t local_col = tid_x - curr_offset;
-    int64_t segment_width = curr_col_offset - curr_offset;
-    T* output_ptr = outputs_data[curr_segment];
-    if (output_ptr != nullptr) {
-      int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-        output_ptr[tid_y * segment_width + local_col] =
-            input_data[tid_y * in_col + tid_x];
-    }
-  }
-}
-
-template <typename T>
-__device__ void SplitKernelDetail(const T* input_data,
-                                  const int64_t in_row,
-                                  const int64_t in_col,
-                                  const int64_t fixed_out_col,
-                                  T** outputs_data) {
-  CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
-    int64_t split = tid_x / fixed_out_col;
-    int64_t in_offset = tid_x - split * fixed_out_col;
-    T* output_ptr = outputs_data[split];
-    if (output_ptr != nullptr) {
-      int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-        output_ptr[tid_y * fixed_out_col + in_offset] =
-            input_data[tid_y * in_col + tid_x];
-    }
-  }
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t fixed_out_col,
-                             T** outputs_data) {
-  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t fixed_out_col,
-                             T* outputs_addr0,
-                             T* outputs_addr1) {
-  T* outputs_data[2];
-  outputs_data[0] = outputs_addr0;
-  outputs_data[1] = outputs_addr1;
-  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t fixed_out_col,
-                             T* outputs_addr0,
-                             T* outputs_addr1,
-                             T* outputs_addr2) {
-  T* outputs_data[3];
-  outputs_data[0] = outputs_addr0;
-  outputs_data[1] = outputs_addr1;
-  outputs_data[2] = outputs_addr2;
-  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t fixed_out_col,
-                             T* outputs_addr0,
-                             T* outputs_addr1,
-                             T* outputs_addr2,
-                             T* outputs_addr3) {
-  T* outputs_data[4];
-  outputs_data[0] = outputs_addr0;
-  outputs_data[1] = outputs_addr1;
-  outputs_data[2] = outputs_addr2;
-  outputs_data[3] = outputs_addr3;
-  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
-}
-
 /*
  * All tensors' dimension should be the same and the values of
  * each dimension must be the same, except the axis dimension.
@@ -708,37 +631,152 @@ struct ConcatFunctor {
   }
 };

-template <typename T>
-class SplitFunctor<phi::GPUContext, T> {
+template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
+struct PointerAndColArray
+    : public funcs::PointerArraySetter<phi::GPUContext, T, Size> {
  public:
-  void operator()(const phi::GPUContext& context,
-                  const phi::DenseTensor& input,
-                  const std::vector<const phi::DenseTensor*>& ref_inputs,
-                  int axis,
-                  std::vector<phi::DenseTensor*>* outputs) {
-    // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
-    // tensors of shape [0,1,4]
-    if (input.numel() == 0) {
-      return;
+  funcs::ValueArray<IndexT, Size> val_array;
+
+  PointerAndColArray() {}
+  PointerAndColArray(const phi::GPUContext& ctx,
+                     const int out_col_num,
+                     IndexT* out_cols,
+                     std::vector<phi::DenseTensor*>* t,
+                     T** pre_alloc_host_buf = nullptr)
+      : funcs::PointerArraySetter<phi::GPUContext, T, Size>(
+            ctx,
+            t,
+            /*need_alloc=*/false,
+            /*use_cuda_graph=*/true,
+            pre_alloc_host_buf) {
+    IndexT* dev_ptr = nullptr;
+    if (Size == SegmentedArraySize::kVariableLength) {
+      size_t num_bytes = out_col_num * sizeof(IndexT);
+      dev_ptr = reinterpret_cast<IndexT*>(this->AllocAndCopy(
+          ctx, reinterpret_cast<void*>(out_cols), num_bytes, true));
+      val_array.Set(dev_ptr, out_col_num);
+    } else {
+      val_array.Set(out_cols, out_col_num);
+    }
+  }
+};
+
+template <typename T, typename IndexT, typename DataArrayT>
+__global__ void SplitTensorWithSameShape(const T* input_data,
+                                         const IndexT out_row,
+                                         const IndexT cumulative_col,
+                                         const IndexT fixed_out_col,
+                                         DataArrayT data_array) {
+  CUDA_KERNEL_LOOP_TYPE(tid_x, cumulative_col, IndexT) {
+    IndexT split = tid_x / fixed_out_col;
+    IndexT in_offset = tid_x - split * fixed_out_col;
+    T* output_ptr = data_array.data[split];
+    if (output_ptr != nullptr) {
+      IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < out_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * fixed_out_col + in_offset] =
+            input_data[tid_y * cumulative_col + tid_x];
+    }
+  }
+}
+
+template <typename T, typename IndexT, typename DataArrayT, typename ValArrayT>
+__global__ void SplitTensorWithDifferentShape(const T* input_data,
+                                              const IndexT out_row,
+                                              const IndexT cumulative_col,
+                                              DataArrayT data_array,
+                                              ValArrayT col_array) {
+  IndexT curr_segment = 0;
+  IndexT curr_offset = col_array.data[0];
+  CUDA_KERNEL_LOOP_TYPE(tid_x, cumulative_col, IndexT) {
+    IndexT curr_col_offset = col_array.data[curr_segment + 1];
+    while (curr_col_offset <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+      curr_col_offset = col_array.data[curr_segment + 1];
     }

-    // TODO(zcd): Add input data validity checking
-    int o_num = outputs->size();
-    int64_t out_row = 1;
-    auto dim_0 = ref_inputs[0]->dims();
-    for (int i = 0; i < axis; ++i) {
-      out_row *= dim_0[i];
+    IndexT local_col = tid_x - curr_offset;
+    IndexT segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = data_array.data[curr_segment];
+    if (output_ptr != nullptr) {
+      IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < out_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * segment_width + local_col] =
+            input_data[tid_y * cumulative_col + tid_x];
     }
+  }
+}
+
+template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
+void SplitFunctionDispatchWithSameShape(const phi::GPUContext& ctx,
+                                        const IndexT out_col,
+                                        const IndexT out_row,
+                                        const IndexT cumulative_col,
+                                        const T* input_data,
+                                        std::vector<phi::DenseTensor*>* outs,
+                                        T** pre_alloc_host_buf) {
+  dim3 grid_dims;
+  dim3 block_dims;
+  GetBlockDims(ctx, out_row, cumulative_col, &block_dims, &grid_dims);
+
+  funcs::PointerArraySetter<phi::GPUContext, T, Size> setter(
+      ctx,
+      outs,
+      /*need_alloc=*/false,
+      /*use_cuda_graph=*/true,
+      pre_alloc_host_buf);
+  SplitTensorWithSameShape
+      <<<grid_dims, block_dims, 0, ctx.stream()>>>(
+          input_data, out_row, cumulative_col, out_col, setter.array);
+}
+
+template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
+void SplitFunctionDispatchWithDifferentShape(
+    const phi::GPUContext& ctx,
+    const int out_col_num,
+    const IndexT out_row,
+    const IndexT cumulative_col,
+    const T* input_data,
+    std::vector<phi::DenseTensor*>* outs,
+    IndexT* output_cols,
+    T** pre_alloc_host_buf) {
+  dim3 grid_dims;
+  dim3 block_dims;
+  GetBlockDims(ctx, out_row, cumulative_col, &block_dims, &grid_dims);
+  PointerAndColArray<T, IndexT, Size> setter(
+      ctx, out_col_num, output_cols, outs, pre_alloc_host_buf);
+
+  SplitTensorWithDifferentShape
+      <<<grid_dims, block_dims, 0, ctx.stream()>>>(
+          input_data, out_row, cumulative_col, setter.array, setter.val_array);
+}

-    int64_t out0_col = ref_inputs[0]->numel() / out_row;
-    int64_t in_col = 0, in_row = out_row;
-    bool has_same_shape = true;
+template <typename T, typename IndexT>
+void SplitFunctorDispatchWithIndexType(
+    const phi::GPUContext& ctx,
+    int axis,
+    const phi::DenseTensor& input,
+    const std::vector<const phi::DenseTensor*>& ref_ins,
+    std::vector<phi::DenseTensor*>* outs) {
+  // TODO(zcd): Add input data validity checking
+  int out_num = outs->size();
+  IndexT out_row = 1;
+  auto ref_dim = ref_ins[0]->dims();
+  for (int i = 0; i < axis; ++i) {
+    out_row *= ref_dim[i];
+  }
+  IndexT out_col = ref_ins[0]->numel() / out_row;
+  IndexT cumulative_col = 0;
+  bool has_same_shape = true;

-    int outputs_cols_num = o_num + 1;
-    std::vector<T*> outputs_data_vec(o_num);
-    std::vector<int64_t> outputs_cols_vec(outputs_cols_num);
-    T** outputs_data = outputs_data_vec.data();
-    int64_t* outputs_cols = outputs_cols_vec.data();
+  int out_cols_num = out_num + 1;
+  std::vector<IndexT> outputs_cols_vec(out_cols_num, 0);
+  IndexT* outs_cols = outputs_cols_vec.data();
+  T** outs_data = nullptr;

   // There are some differences between hip runtime and NV runtime.
   // In NV, when the pageable memory data less than 64K is transferred from
@@ -748,128 +786,90 @@ class SplitFunctor {
   // 3.2.6.1. Concurrent Execution between Host and Device
   // Memory copies from host to device of a memory block of 64 KB or less
 #ifdef PADDLE_WITH_HIP
-    paddle::memory::AllocationPtr data_alloc, cols_alloc;
-    // TODO(chentianyu03): try to find a method to remove the Alloc function
-    data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
-                                       o_num * sizeof(T*));
-    outputs_data = reinterpret_cast<T**>(data_alloc->ptr());
-    // TODO(chentianyu03): try to find a method to remove the Alloc function
-    cols_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
-                                       (outputs_cols_num) * sizeof(int64_t));
-    outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr());
+  paddle::memory::AllocationPtr data_alloc, cols_alloc;
+  // TODO(chentianyu03): try to find a method to remove the Alloc function
+  data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
+                                     out_num * sizeof(T*));
+  outs_data = reinterpret_cast<T**>(data_alloc->ptr());
+  // TODO(chentianyu03): try to find a method to remove the Alloc function
+  cols_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
+                                     (out_cols_num) * sizeof(IndexT));
+  outs_cols = reinterpret_cast<IndexT*>(cols_alloc->ptr());
 #endif

-    outputs_cols[0] = 0;
-    for (int i = 0; i < o_num; ++i) {
-      int64_t t_col = ref_inputs.at(i)->numel() / out_row;
-      if (has_same_shape) {
-        if (t_col != out0_col) has_same_shape = false;
-      }
-      in_col += t_col;
-      outputs_cols[i + 1] = in_col;
-      if (outputs->at(i) != nullptr) {
-        outputs_data[i] = outputs->at(i)->data<T>();
-      } else {
-        outputs_data[i] = nullptr;
-      }
+  outs_cols[0] = 0;
+  for (int i = 0; i < out_num; ++i) {
+    IndexT t_col = ref_ins.at(i)->numel() / out_row;
+    if (has_same_shape) {
+      has_same_shape &= (t_col == cumulative_col);
     }
-
-    dim3 block_dims;
-    dim3 grid_dims;
-    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);
-
-    paddle::memory::allocation::AllocationPtr tmp_dev_outs_data;
-    T** dev_out_gpu_data = nullptr;
-    if (!has_same_shape || o_num < 2 || o_num > 4) {
-      // TODO(chentianyu03): try to find a method to remove the Alloc function
-      tmp_dev_outs_data = paddle::memory::Alloc(
-          context.GetPlace(),
-          o_num * sizeof(T*),
-          phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
-      auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
-          outputs_data, o_num);
-      paddle::memory::Copy(context.GetPlace(),
-                           tmp_dev_outs_data->ptr(),
-                           phi::CPUPlace(),
-                           restored,
-                           o_num * sizeof(T*),
-                           context.stream());
-      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
+    cumulative_col += t_col;
+    outs_cols[i + 1] = cumulative_col;
+  }
+  int limit_num = has_same_shape ? out_num : out_cols_num;
+  if (has_same_shape) {
+    switch (funcs::CalcArraySize(limit_num)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          SplitFunctionDispatchWithSameShape<T, IndexT, kArraySize>(
+              ctx,
+              out_col,
+              out_row,
+              cumulative_col,
+              input.data<T>(),
+              outs,
+              outs_data));
     }
-
-    if (has_same_shape) {
-      if (o_num == 2) {
-        SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-            input.data<T>(),
-            in_row,
-            in_col,
-            out0_col,
-            outputs_data[0],
-            outputs_data[1]);
-      } else if (o_num == 3) {
-        SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-            input.data<T>(),
-            in_row,
-            in_col,
-            out0_col,
-            outputs_data[0],
-            outputs_data[1],
-            outputs_data[2]);
-      } else if (o_num == 4) {
-        SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-            input.data<T>(),
-            in_row,
-            in_col,
-            out0_col,
-            outputs_data[0],
-            outputs_data[1],
-            outputs_data[2],
-            outputs_data[3]);
-      } else {
-        SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-            input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
-      }
-    } else {
-      auto tmp_dev_ins_col_data =
-          // TODO(chentianyu03): try to find a method to remove the Alloc
-          // function
-          paddle::memory::Alloc(
-              context.GetPlace(),
-              outputs_cols_num * sizeof(int64_t),
-              phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
-      auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
-          outputs_cols, outputs_cols_num);
-      paddle::memory::Copy(context.GetPlace(),
-                           tmp_dev_ins_col_data->ptr(),
-                           phi::CPUPlace(),
-                           restored,
-                           outputs_cols_num * sizeof(int64_t),
-                           context.stream());
-      int64_t* dev_outs_col_data =
-          reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
-
-      SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          input.data<T>(),
-          in_row,
-          in_col,
-          dev_outs_col_data,
-          static_cast<int>(outputs_cols_num),
-          dev_out_gpu_data);
+  } else {
+    switch (funcs::CalcArraySize(limit_num)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          SplitFunctionDispatchWithDifferentShape<T, IndexT, kArraySize>(
+              ctx,
+              out_cols_num,
+              out_row,
+              cumulative_col,
+              input.data<T>(),
+              outs,
+              outs_cols,
+              outs_data));
     }
+  }

 #ifdef PADDLE_WITH_HIP
-    // Prevent the pinned memory value from being covered and release the memory
-    // after the launch kernel of the stream is executed (reapply pinned memory
-    // next time)
-    auto* data_alloc_released = data_alloc.release();
-    auto* cols_alloc_released = cols_alloc.release();
-    context.AddStreamCallback([data_alloc_released, cols_alloc_released] {
-      paddle::memory::allocation::Allocator::AllocationDeleter(
-          data_alloc_released);
-      paddle::memory::allocation::Allocator::AllocationDeleter(
-          cols_alloc_released);
-    });
+  // Prevent pinned memory from being covered and release the memory after
+  // kernel launch of the stream is executed (reapply pinned memory next time)
+  auto* data_alloc_released = data_alloc.release();
+  auto* cols_alloc_released = cols_alloc.release();
+  ctx.AddStreamCallback([data_alloc_released, cols_alloc_released] {
+    paddle::memory::allocation::Allocator::AllocationDeleter(
+        data_alloc_released);
+    paddle::memory::allocation::Allocator::AllocationDeleter(
+        cols_alloc_released);
+  });
 #endif
+}
+
+template <typename T>
+class SplitFunctor<phi::GPUContext, T> {
+ public:
+  void operator()(const phi::GPUContext& context,
+                  const phi::DenseTensor& input,
+                  const std::vector<const phi::DenseTensor*>& ref_inputs,
+                  int axis,
+                  std::vector<phi::DenseTensor*>* outputs) {
+    int64_t numel = input.numel();
+    // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in
+    // 3 tensors of shape [0,1,4]
+    if (input.numel() == 0) {
+      return;
+    }
+
+    if (numel < std::numeric_limits<int32_t>::max()) {
+      SplitFunctorDispatchWithIndexType<T, int32_t>(
+          context, axis, input, ref_inputs, outputs);
+    } else {
+      SplitFunctorDispatchWithIndexType<T, int64_t>(
+          context, axis, input, ref_inputs, outputs);
+    }
   }
 };
diff --git a/paddle/phi/kernels/funcs/segmented_array.h b/paddle/phi/kernels/funcs/segmented_array.h
index aa03eb4e9fcd21e1c5ac6901bb7a30d7fd1f895a..cacaa8f81fe862aefdcf126cac4cceacdb0ec384 100644
--- a/paddle/phi/kernels/funcs/segmented_array.h
+++ b/paddle/phi/kernels/funcs/segmented_array.h
@@ -14,6 +14,7 @@

 #pragma once

+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "paddle/phi/core/dense_tensor.h"

 namespace phi {
@@ -34,6 +35,26 @@ enum class SegmentedArraySize {
   kFixed64 = 64,
 };

+template <typename T, SegmentedArraySize Size, int Num = static_cast<int>(Size)>
+struct PADDLE_ALIGN(256) ValueArray {
+ public:
+  T data[Num];
+
+  void Set(T* ptr, const int num) {
+    for (auto i = 0; i < num; ++i) {
+      data[i] = ptr[i];
+    }
+  }
+};
+
+template <typename T, int Num>
+struct PADDLE_ALIGN(256) ValueArray<T, SegmentedArraySize::kVariableLength, Num> {
+ public:
+  T* data{nullptr};
+
+  void Set(T* ptr, const int num) { data = ptr; }
+};
+
 template <typename T, SegmentedArraySize Size>
 struct PADDLE_ALIGN(256) ConstPointerArray {
  public:
@@ -62,8 +83,8 @@ struct PADDLE_ALIGN(256) PointerArray {
  public:
   T* data[static_cast<int>(Size)];

-  void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) {
-    for (auto i = 0; i < ptrs.size(); ++i) {
+  void Set(T** ptrs, const int num, T** dev_ptr = nullptr) {
+    for (auto i = 0; i < num; ++i) {
       data[i] = ptrs[i];
     }
   }
@@ -74,9 +95,7 @@ struct PADDLE_ALIGN(256) PointerArray {
  public:
   T** data{nullptr};

-  void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) {
-    data = dev_ptr;
-  }
+  void Set(T** ptrs, const int num, T** dev_ptr = nullptr) { data = dev_ptr; }
 };

 #undef PADDLE_ALIGN
@@ -84,13 +103,24 @@ struct PADDLE_ALIGN(256) PointerArray {
 template <typename Context>
 struct ArraySetterBase {
  protected:
-  void* AllocAndCopy(const Context& ctx, void* src, size_t num_bytes) {
+  void* AllocAndCopy(const Context& ctx,
+                     void* src,
+                     size_t num_bytes,
+                     bool use_cuda_graph = false) {
     allocation = paddle::memory::Alloc(
         ctx.GetPlace(),
         num_bytes,
         phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+
+    int8_t* restored = reinterpret_cast<int8_t*>(src);
+#ifdef PADDLE_WITH_CUDA
+    if (use_cuda_graph) {
+      restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
+          restored, num_bytes);
+    }
+#endif
     phi::backends::gpu::GpuMemcpyAsync(allocation->ptr(),
-                                       src,
+                                       restored,
                                        num_bytes,
                                        phi::gpuMemcpyHostToDevice,
                                        ctx.stream());
@@ -131,13 +161,28 @@ struct PointerArraySetter : public ArraySetterBase<Context> {
  public:
   PointerArray<T, Size> array;

-  PointerArraySetter(const Context& ctx, std::vector<DenseTensor*>* t) {
+  // need_alloc : tensor data needs extra buffer or not.
+  // use_cuda_graph: tensor data shall be captured by cuda_graph or not.
+  // pre_alloc_host_buf: tensor data is temporarily stored by pinned memory or
+  // not.
+  PointerArraySetter(const Context& ctx,
+                     std::vector<DenseTensor*>* t,
+                     bool need_alloc = false,
+                     bool use_cuda_graph = false,
+                     T** pre_alloc_host_buf = nullptr) {
     ptrs.resize(t->size());
+    T** data_ptr = ptrs.data();
+#ifdef PADDLE_WITH_HIP
+    if (pre_alloc_host_buf) {
+      data_ptr = pre_alloc_host_buf;
+    }
+#endif
     for (int i = 0; i < t->size(); ++i) {
       if (t->at(i) && (t->at(i)->numel() > 0)) {
-        ptrs[i] = ctx.template Alloc<T>(t->at(i));
+        data_ptr[i] =
+            need_alloc ? ctx.template Alloc<T>(t->at(i)) : t->at(i)->data<T>();
       } else {
-        ptrs[i] = nullptr;
+        data_ptr[i] = nullptr;
       }
     }

@@ -145,10 +190,9 @@ struct PointerArraySetter : public ArraySetterBase<Context> {
     if (Size == SegmentedArraySize::kVariableLength) {
       size_t num_bytes = t->size() * sizeof(T*);
       dev_ptr = reinterpret_cast<T**>(this->AllocAndCopy(
-          ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
+          ctx, reinterpret_cast<void*>(data_ptr), num_bytes, use_cuda_graph));
     }
-
-    array.Set(ptrs, dev_ptr);
+    array.Set(data_ptr, t->size(), dev_ptr);
   }

  private:
diff --git a/paddle/phi/kernels/funcs/stack_and_unstack.h b/paddle/phi/kernels/funcs/stack_and_unstack.h
index c516d4892bf629226bed9bb8f93cd3436792d639..0b2b5443383a94d506f7cc2a2749f0594b82bebc 100644
--- a/paddle/phi/kernels/funcs/stack_and_unstack.h
+++ b/paddle/phi/kernels/funcs/stack_and_unstack.h
@@ -192,7 +192,7 @@ void LaunchUnStackKernel(const Context& ctx,
           << ", out_col=" << out_col << ", num_splits=" << num_splits;

   auto x_ptr = x.data<T>();
-  PointerArraySetter<Context, T, Size> setter(ctx, outs);
+  PointerArraySetter<Context, T, Size> setter(ctx, outs, /*need_alloc=*/true);

   if (out_col == 1) {
     // For the case axis == (x.dims().size() - 1)
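// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the patch above). It shows the
// core technique the patch applies to SplitFunctor: packing the output
// pointers into a fixed-size, over-aligned struct that is passed to the
// kernel by value, so no separate host-to-device copy of the pointer table is
// needed for the common (small, fixed) output counts. All names here
// (OutPtrArray, FillSplitByValue, kMaxOuts) are hypothetical and independent
// of Paddle.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kMaxOuts = 4;

// Trivially copyable aggregate; travels to the device through the
// kernel-argument space, mirroring PointerArray/AlignedPointerWrapper.
struct __align__(256) OutPtrArray {
  float* data[kMaxOuts];
};

// Splits a single row of `total_cols` values into equal-width segments.
__global__ void FillSplitByValue(OutPtrArray outs, int seg_width,
                                 int total_cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= total_cols) return;
  int seg = col / seg_width;
  int local = col - seg * seg_width;
  outs.data[seg][local] = static_cast<float>(col);
}

int main() {
  const int seg_width = 8;
  OutPtrArray outs{};
  for (int i = 0; i < kMaxOuts; ++i) {
    cudaMalloc(&outs.data[i], seg_width * sizeof(float));
  }
  // The struct itself is the kernel argument; no cudaMemcpy of a T** table.
  FillSplitByValue<<<1, 64>>>(outs, seg_width, seg_width * kMaxOuts);
  float host[8];
  cudaMemcpy(host, outs.data[1], sizeof(host), cudaMemcpyDeviceToHost);
  printf("output[1][0] = %.0f\n", host[0]);  // expected: 8
  for (int i = 0; i < kMaxOuts; ++i) cudaFree(outs.data[i]);
  return 0;
}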
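// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the patch above). It shows why
// the variable-length fallback kept by the patch stages the host-side pointer
// table in pinned memory: the table reaches the device through an
// asynchronous copy, so the host buffer must stay valid until the stream has
// consumed it -- which is what the AddStreamCallback in the HIP branch
// arranges. Names below are illustrative only.
#include <cuda_runtime.h>

__global__ void TouchOutputs(float** table, int n) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    for (int i = 0; i < n; ++i) {
      if (table[i] != nullptr) table[i][0] = static_cast<float>(i);
    }
  }
}

int main() {
  const int n = 16;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Pinned host memory keeps cudaMemcpyAsync truly asynchronous.
  float** host_table = nullptr;
  cudaMallocHost(&host_table, n * sizeof(float*));
  for (int i = 0; i < n; ++i) cudaMalloc(&host_table[i], sizeof(float));

  float** dev_table = nullptr;
  cudaMalloc(&dev_table, n * sizeof(float*));
  cudaMemcpyAsync(dev_table, host_table, n * sizeof(float*),
                  cudaMemcpyHostToDevice, stream);
  TouchOutputs<<<1, 1, 0, stream>>>(dev_table, n);

  // The pinned table may only be released after the stream is done with it.
  cudaStreamSynchronize(stream);
  for (int i = 0; i < n; ++i) cudaFree(host_table[i]);
  cudaFreeHost(host_table);
  cudaFree(dev_table);
  cudaStreamDestroy(stream);
  return 0;
}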