From 6707142a796c5be74fe96b9b972fe27a3e7491ad Mon Sep 17 00:00:00 2001 From: Qi Shao <17864154871@163.com> Date: Wed, 10 May 2023 10:13:18 +0800 Subject: [PATCH] revert argsort to fix OOM bug (#53647) Revert argsort to the version without full sort algorithm implemented --- paddle/phi/kernels/gpu/argsort_kernel.cu | 425 ++++++----------------- 1 file changed, 110 insertions(+), 315 deletions(-) diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu index 0ad699cdc64..64c0589db10 100644 --- a/paddle/phi/kernels/gpu/argsort_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_kernel.cu @@ -65,10 +65,8 @@ struct SegmentOffsetIter { int num_cols_; }; -#define PADDLE_CUDA_NUM_THREADS 1024 - template -static __global__ void FillIndex(T *indices, T num_rows, T num_cols) { +static __global__ void FillIndex(T* indices, T num_rows, T num_cols) { int col_id = threadIdx.x; int row_id = blockIdx.x; @@ -81,246 +79,23 @@ static __global__ void FillIndex(T *indices, T num_rows, T num_cols) { // Sort by flag descending, True: descending. False: Ascending. // Default is false. -static __global__ void FillIndexAndSegmentKernel(int2 *data, - int numel, - int nsort) { - CUDA_KERNEL_LOOP(idx, numel) { - auto segment = static_cast(idx / nsort); - auto sort = static_cast(idx % nsort); - data[idx] = int2{segment, sort}; - } -} - -#define CUB_WRAPPER(func, ctx, ...) \ - do { \ - size_t temp_storage_bytes = 0; \ - gpuError_t err; \ - err = func(nullptr, temp_storage_bytes, __VA_ARGS__); \ - PADDLE_ENFORCE_GPU_SUCCESS(err); \ - DenseTensor temp_storage; \ - int64_t temp_size = temp_storage_bytes; \ - temp_storage.Resize({temp_size}); \ - ctx.template Alloc(&temp_storage); \ - err = func(temp_storage.data(), temp_storage_bytes, __VA_ARGS__); \ - PADDLE_ENFORCE_GPU_SUCCESS(err); \ - } while (false) - -template -static void RadixSortPairs(const phi::GPUContext &ctx, - const KT *keys_in, - const VT *values_in, - KT *keys_out, - VT *values_out, - int64_t n, - bool descending = false, - int64_t begin_bit = 0, - int64_t end_bit = sizeof(KT) * 8) { - if (keys_out == nullptr) { - DenseTensor key_out_owner; - key_out_owner.Resize({n}); - ctx.template Alloc(&key_out_owner); - keys_out = key_out_owner.data(); - } - - if (descending) { - CUB_WRAPPER(cub::DeviceRadixSort::SortPairsDescending, - ctx, - keys_in, - keys_out, - values_in, - values_out, - n, - begin_bit, - end_bit, - ctx.stream()); - } else { - CUB_WRAPPER(cub::DeviceRadixSort::SortPairs, - ctx, - keys_in, - keys_out, - values_in, - values_out, - n, - begin_bit, - end_bit, - ctx.stream()); - } -} - -template -static void RadixSortKeys(const phi::GPUContext &ctx, - const KT *keys_in, - KT *keys_out, - int64_t n, - bool descending, - int64_t begin_bit, - int64_t end_bit) { - if (descending) { - CUB_WRAPPER(cub::DeviceRadixSort::SortKeysDescending, - ctx, - keys_in, - keys_out, - n, - begin_bit, - end_bit, - ctx.stream()); - } else { - CUB_WRAPPER(cub::DeviceRadixSort::SortKeys, - ctx, - keys_in, - keys_out, - n, - begin_bit, - end_bit, - ctx.stream()); - } -} - -template -static __global__ void SortPostprocessKernel(const T *in, - const int2 *i_s_ptr, - T *out, - int64_t *index, - int nsegments, - int nsort) { - CUDA_KERNEL_LOOP(i, nsegments * nsort) { - int segment = i / nsort; // segment_id - int j = i % nsort; - - int offset = segment * nsort; - const T *in_ = in + offset; - T *out_ = out + offset; - int64_t *index_ = index + offset; - const int2 *i_s_ptr_ = i_s_ptr + offset; - - int idx = i_s_ptr_[j].y; - index_[j] = idx; - out_[j] 
= in_[idx]; - } -} - -template -inline void SegmentedSortPairsByFullSort(const phi::GPUContext &ctx, - const T *const self_ptr, - T *const values_ptr, - int64_t *const indices_ptr, - const int64_t nsegments, - const int64_t nsort, - const int64_t n, - const bool descending) { - int64_t segment_bits = std::max( - 1L, static_cast(std::ceil(std::log2(nsegments)))); - - const auto numel = nsort * nsegments; - - DenseTensor indices_and_segment; - int64_t indices_and_segment_size = numel; - indices_and_segment.Resize({indices_and_segment_size * 2}); - ctx.template Alloc(&indices_and_segment); - auto i_s_ptr_base = indices_and_segment.data(); - auto i_s_ptr = reinterpret_cast(i_s_ptr_base); - - dim3 block = PADDLE_CUDA_NUM_THREADS; - auto block_num = (numel - 1) / PADDLE_CUDA_NUM_THREADS + 1; - dim3 grid = static_cast(block_num); - - auto cu_stream = ctx.stream(); - - FillIndexAndSegmentKernel<<>>( - i_s_ptr, numel, nsort); - - DenseTensor indices_and_segment2; - int64_t indices_and_segment2_size = numel; - indices_and_segment2.Resize({indices_and_segment2_size * 2}); - ctx.template Alloc(&indices_and_segment2); - auto i_s_ptr2_base = indices_and_segment2.data(); - auto i_s_ptr2 = reinterpret_cast(i_s_ptr2_base); - - RadixSortPairs( - ctx, self_ptr, i_s_ptr, nullptr, i_s_ptr2, n, descending); - - RadixSortKeys(ctx, - reinterpret_cast(i_s_ptr2), - reinterpret_cast(i_s_ptr), - n, - false, - 0, - segment_bits); - - SortPostprocessKernel<<>>( - self_ptr, i_s_ptr, values_ptr, indices_ptr, nsegments, nsort); -} - -// The method is called when # of the rows of the input is less than or equal to -// 4 -template -void ArgFullSortForTinyRows(const phi::GPUContext &ctx, - const DenseTensor *input, - DenseTensor *output, - DenseTensor *indices, - const IndexType num_rows, - const IndexType num_cols, - const bool descending) { - auto gpu_stream = ctx.stream(); - size_t temp_storage_bytes = -1; - - IndexType numel = num_rows * num_cols; - if (numel == 0) { - return; - } - - IndexType numel_or_intmax = - std::min(numel, static_cast(std::numeric_limits::max())); - IndexType nsort = num_cols; - IndexType nbatch = (numel_or_intmax / nsort) * nsort; - - T *sorted_out_ptr; - IndexType *sorted_indices_ptr; - const T *input_data = input->data(); - T *out = ctx.template Alloc(output); - IndexType *ind = ctx.template Alloc(indices); - sorted_out_ptr = out; - sorted_indices_ptr = ind; - - int64_t remaining = numel; - - while (remaining > 0) { - int64_t n = std::min(remaining, nbatch); - IndexType nsegments = n / nsort; - - SegmentedSortPairsByFullSort(ctx, - input_data, - sorted_out_ptr, - sorted_indices_ptr, - nsegments, - nsort, - n, - descending); - - remaining -= n; - input_data += n; - sorted_out_ptr += n; - sorted_indices_ptr += n; - } -} - -template -void ArgFullSort(const phi::GPUContext &ctx, - const DenseTensor *input, - DenseTensor *output, - DenseTensor *indices, - const IndexType num_rows, - const IndexType num_cols, +template +void ArgFullSort(const phi::GPUContext& ctx, + const DenseTensor* input, + DenseTensor* output, + DenseTensor* indices, + const IndType num_rows, + const IndType num_cols, const bool descending) { auto cu_stream = ctx.stream(); DenseTensor input_indices; - const std::vector dims = {num_rows, num_cols}; + const std::vector dims = {num_rows, num_cols}; auto dim = phi::make_ddim(dims); input_indices.Resize(dim); - ctx.template Alloc(&input_indices); + ctx.template Alloc(&input_indices); size_t temp_storage_bytes = -1; - auto ComputeBlockSize = [](IndexType col) { + auto ComputeBlockSize 
= [](IndType col) { if (col > 512) return 1024; else if (col > 256 && col <= 512) @@ -339,78 +114,118 @@ void ArgFullSort(const phi::GPUContext &ctx, int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX; // Init a index array FillIndex<<>>( - input_indices.data(), num_rows, num_cols); + input_indices.data(), num_rows, num_cols); - T *sorted_out_ptr; - IndexType *sorted_indices_ptr; - const T *inp = input->data(); - T *out = ctx.template Alloc(output); - IndexType *ind = ctx.template Alloc(indices); + T* sorted_out_ptr; + IndType* sorted_indices_ptr; + const T* inp = input->data(); + T* out = ctx.template Alloc(output); + IndType* ind = ctx.template Alloc(indices); sorted_out_ptr = out; sorted_indices_ptr = ind; // create iter for counting input - cub::CountingInputIterator counting_iter(0); + cub::CountingInputIterator counting_iter(0); // segment_offset is used for move to next row - cub::TransformInputIterator> + cub::CountingInputIterator> segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols)); gpuError_t err; if (descending) { - CUB_WRAPPER(cub::DeviceSegmentedRadixSort::SortPairsDescending, - ctx, - inp, - sorted_out_ptr, - input_indices.data(), - sorted_indices_ptr, - num_cols * num_rows, - num_rows, - segment_offsets_t, - segment_offsets_t + 1, - 0, - sizeof(T) * 8, - ctx.stream()); + err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, + temp_storage_bytes, + inp, + sorted_out_ptr, + input_indices.data(), + sorted_indices_ptr, + num_cols * num_rows, + num_rows, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + cu_stream); } else { - CUB_WRAPPER(cub::DeviceSegmentedRadixSort::SortPairs, - ctx, - inp, - sorted_out_ptr, - input_indices.data(), - sorted_indices_ptr, - num_cols * num_rows, - num_rows, - segment_offsets_t, - segment_offsets_t + 1, - 0, - sizeof(T) * 8, - ctx.stream()); + err = + cub::DeviceSegmentedRadixSort::SortPairs(nullptr, + temp_storage_bytes, + inp, + sorted_out_ptr, + input_indices.data(), + sorted_indices_ptr, + num_cols * num_rows, + num_rows, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + cu_stream); } + PADDLE_ENFORCE_GPU_SUCCESS(err); + + DenseTensor temp_storage; + int64_t temp_size = temp_storage_bytes; + temp_storage.Resize({temp_size}); + ctx.template Alloc(&temp_storage); + + if (descending) { + err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + temp_storage.data(), + temp_storage_bytes, + inp, + sorted_out_ptr, + input_indices.data(), + sorted_indices_ptr, + num_cols * num_rows, + num_rows, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + cu_stream); + } else { + err = + cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data(), + temp_storage_bytes, + inp, + sorted_out_ptr, + input_indices.data(), + sorted_indices_ptr, + num_cols * num_rows, + num_rows, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + cu_stream); + } + + PADDLE_ENFORCE_GPU_SUCCESS(err); } template -void ArgsortKernel(const Context &dev_ctx, - const DenseTensor &input, +void ArgsortKernel(const Context& dev_ctx, + const DenseTensor& input, int axis, bool descending, - DenseTensor *output, - DenseTensor *indices) { + DenseTensor* output, + DenseTensor* indices) { auto in_dims = input.dims(); auto rank = in_dims.size(); axis = (axis < 0) ? 
(in_dims.size() + axis) : axis; - - const T *in_data = input.data(); + const T* in_data = input.data(); auto size = input.numel(); - T *out_data = dev_ctx.template Alloc(output); - int64_t *ids_data = dev_ctx.template Alloc(indices); + T* out_data = dev_ctx.template Alloc(output); + int64_t* ids_data = dev_ctx.template Alloc(indices); if (rank == 0) { phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, output); phi::funcs::set_constant(dev_ctx, indices, 0); return; } - // Use thrust for parallel acceleration when the input size is equal to the // length of the ‘axis’ dimension. // Compared to the following 'Special case for full sort', ascending sort is @@ -431,23 +246,13 @@ void ArgsortKernel(const Context &dev_ctx, const int64_t input_height = phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); const int64_t input_width = in_dims[in_dims.size() - 1]; - if (input_height <= 4) { - ArgFullSortForTinyRows(dev_ctx, - &input, - output, - indices, - input_height, - input_width, - descending); - } else { - ArgFullSort(dev_ctx, - &input, - output, - indices, - input_height, - input_width, - descending); - } + ArgFullSort(dev_ctx, + &input, + output, + indices, + input_height, + input_width, + descending); } else { // if not full sort, do transpose first std::vector trans; @@ -466,7 +271,7 @@ void ArgsortKernel(const Context &dev_ctx, DenseTensor trans_inp; trans_inp.Resize(trans_dims); - T *trans_inp_data = dev_ctx.template Alloc(&trans_inp); + T* trans_inp_data = dev_ctx.template Alloc(&trans_inp); // Do transpose TransposeKernel(dev_ctx, input, trans, &trans_inp); @@ -484,23 +289,13 @@ void ArgsortKernel(const Context &dev_ctx, dev_ctx.template Alloc(&tmp_indices); dev_ctx.template Alloc(indices); - if (input_height <= 4) { - ArgFullSortForTinyRows(dev_ctx, - &trans_inp, - &tmp_out, - &tmp_indices, - input_height, - input_width, - descending); - } else { - ArgFullSort(dev_ctx, - &trans_inp, - &tmp_out, - &tmp_indices, - input_height, - input_width, - descending); - } + ArgFullSort(dev_ctx, + &trans_inp, + &tmp_out, + &tmp_indices, + input_height, + input_width, + descending); TransposeKernel(dev_ctx, tmp_indices, trans, indices); // transpose back -- GitLab
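
Note (not part of the patch): the ArgFullSort kernel restored by this revert drives cub::DeviceSegmentedRadixSort with the standard two-pass CUB pattern: a first call with a null workspace pointer that only reports temp_storage_bytes, followed by a second call once that workspace has been allocated. The sketch below reproduces that pattern outside of Paddle, using plain cudaMalloc instead of DenseTensor and an explicit row-offset array instead of the TransformInputIterator/SegmentOffsetIter pair; the function name and the host-side setup are illustrative assumptions, not code from the repository.

// Minimal sketch (not from the patch) of the two-pass CUB pattern used by the
// restored ArgFullSort: query the workspace size with a null pointer, allocate
// it, then run the segmented radix sort for real.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <vector>

void SegmentedArgSortRows(const float* d_keys_in,  // num_rows * num_cols keys
                          const int* d_values_in,  // per-row indices 0..num_cols-1
                          float* d_keys_out,
                          int* d_values_out,
                          int num_rows,
                          int num_cols,
                          cudaStream_t stream) {
  // One segment per row: offsets[i] = i * num_cols, offsets[num_rows] = total.
  // (The real kernel builds these offsets on the fly with SegmentOffsetIter.)
  std::vector<int> h_offsets(num_rows + 1);
  for (int i = 0; i <= num_rows; ++i) h_offsets[i] = i * num_cols;
  int* d_offsets = nullptr;
  cudaMalloc(&d_offsets, h_offsets.size() * sizeof(int));
  cudaMemcpyAsync(d_offsets, h_offsets.data(), h_offsets.size() * sizeof(int),
                  cudaMemcpyHostToDevice, stream);

  // Pass 1: with a null workspace pointer CUB only reports the bytes needed.
  size_t temp_storage_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(
      nullptr, temp_storage_bytes,
      d_keys_in, d_keys_out, d_values_in, d_values_out,
      num_rows * num_cols, num_rows,
      d_offsets, d_offsets + 1,
      0, sizeof(float) * 8, stream);

  // Allocate exactly what CUB asked for (the kernel uses a DenseTensor here).
  void* d_temp_storage = nullptr;
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Pass 2: the same call with a real workspace performs the sort.
  cub::DeviceSegmentedRadixSort::SortPairs(
      d_temp_storage, temp_storage_bytes,
      d_keys_in, d_keys_out, d_values_in, d_values_out,
      num_rows * num_cols, num_rows,
      d_offsets, d_offsets + 1,
      0, sizeof(float) * 8, stream);

  cudaStreamSynchronize(stream);
  cudaFree(d_temp_storage);
  cudaFree(d_offsets);
}

For context on the memory footprint: the removed SegmentedSortPairsByFullSort path allocated two additional int64 buffers of 2 * numel elements each (indices_and_segment and indices_and_segment2) on top of CUB's own workspace, while the restored path only needs the input_indices tensor plus the CUB workspace sized above.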