Unverified commit 9e9b705a authored by Vvsmile, committed by GitHub

Optimize the implementation of the argsort operator. (#47738)

Optimize the implementation of the argsort operator
Parent de443726
......@@ -64,8 +64,10 @@ struct SegmentOffsetIter {
int num_cols_;
};
#define PADDLE_CUDA_NUM_THREADS 1024
template <typename T>
static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
static __global__ void FillIndex(T *indices, T num_rows, T num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
......@@ -78,23 +80,246 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
// Sort according to the descending flag: true sorts in descending order,
// false in ascending order. The default is false.
template <typename T, typename IndType>
void ArgFullSort(const phi::GPUContext& ctx,
const DenseTensor* input,
DenseTensor* output,
DenseTensor* indices,
const IndType num_rows,
const IndType num_cols,
static __global__ void FillIndexAndSegmentKernel(int2 *data,
int numel,
int nsort) {
CUDA_KERNEL_LOOP(idx, numel) {
auto segment = static_cast<int>(idx / nsort);
auto sort = static_cast<int>(idx % nsort);
data[idx] = int2{segment, sort};
}
}
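// Editorial sketch (not part of the original patch): each element idx is
// mapped to an int2 holding (segment id, position within the segment).
// Assuming the CUDA_KERNEL_LOOP covers every idx, nsort = 3 and numel = 6
// would yield
//   data = { {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} }
// Packing both fields into one int2 lets a later 64-bit radix sort reorder
// the segment id and the in-segment index together as a single key.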
#define CUB_WRAPPER(func, ctx, ...) \
do { \
size_t temp_storage_bytes = 0; \
gpuError_t err; \
err = func(nullptr, temp_storage_bytes, __VA_ARGS__); \
PADDLE_ENFORCE_GPU_SUCCESS(err); \
DenseTensor temp_storage; \
int64_t temp_size = temp_storage_bytes; \
temp_storage.Resize({temp_size}); \
ctx.template Alloc<uint8_t>(&temp_storage); \
err = func(temp_storage.data<uint8_t>(), temp_storage_bytes, __VA_ARGS__); \
PADDLE_ENFORCE_GPU_SUCCESS(err); \
} while (false)
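// Editorial note: CUB device-wide primitives use a two-phase calling
// convention. The first call with a null temporary-storage pointer only
// reports the required scratch size in temp_storage_bytes; the macro then
// allocates that scratch through the context allocator and repeats the call
// to do the actual work. Callers such as RadixSortKeys below simply write
//   CUB_WRAPPER(cub::DeviceRadixSort::SortKeys, ctx, keys_in, keys_out, n,
//               begin_bit, end_bit, ctx.stream());
// and never see the size query or the scratch allocation.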
template <typename KT, typename VT>
static void RadixSortPairs(const phi::GPUContext &ctx,
const KT *keys_in,
const VT *values_in,
KT *keys_out,
VT *values_out,
int64_t n,
bool descending = false,
int64_t begin_bit = 0,
int64_t end_bit = sizeof(KT) * 8) {
if (keys_out == nullptr) {
DenseTensor key_out_owner;
key_out_owner.Resize({n});
ctx.template Alloc<KT>(&key_out_owner);
keys_out = key_out_owner.data<KT>();
}
if (descending) {
CUB_WRAPPER(cub::DeviceRadixSort::SortPairsDescending,
ctx,
keys_in,
keys_out,
values_in,
values_out,
n,
begin_bit,
end_bit,
ctx.stream());
} else {
CUB_WRAPPER(cub::DeviceRadixSort::SortPairs,
ctx,
keys_in,
keys_out,
values_in,
values_out,
n,
begin_bit,
end_bit,
ctx.stream());
}
}
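// Editorial note: this helper sorts one flat run of n key/value pairs (no
// segmentation). When the caller does not need the sorted keys, passing
// keys_out == nullptr makes the helper allocate its own scratch key buffer,
// so a call such as the one in SegmentedSortPairsByFullSort,
//   RadixSortPairs<T, int2>(ctx, self_ptr, i_s_ptr, nullptr, i_s_ptr2, n,
//                           descending);
// only materializes the reordered int2 payload.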
template <typename KT>
static void RadixSortKeys(const phi::GPUContext &ctx,
const KT *keys_in,
KT *keys_out,
int64_t n,
bool descending,
int64_t begin_bit,
int64_t end_bit) {
if (descending) {
CUB_WRAPPER(cub::DeviceRadixSort::SortKeysDescending,
ctx,
keys_in,
keys_out,
n,
begin_bit,
end_bit,
ctx.stream());
} else {
CUB_WRAPPER(cub::DeviceRadixSort::SortKeys,
ctx,
keys_in,
keys_out,
n,
begin_bit,
end_bit,
ctx.stream());
}
}
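// Editorial note: begin_bit/end_bit restrict the radix sort to a bit range of
// the key. SegmentedSortPairsByFullSort below reinterprets the sorted int2
// pairs as 64-bit keys and sorts only bits [0, segment_bits), i.e. the
// segment id stored in the low half of each pair; because CUB's radix sort is
// stable, elements inside each segment keep the value order established by
// the preceding full sort.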
template <typename T>
static __global__ void SortPostprocessKernel(const T *in,
const int2 *i_s_ptr,
T *out,
int64_t *index,
int nsegments,
int nsort) {
CUDA_KERNEL_LOOP(i, nsegments * nsort) {
int segment = i / nsort; // segment_id
int j = i % nsort;
int offset = segment * nsort;
const T *in_ = in + offset;
T *out_ = out + offset;
int64_t *index_ = index + offset;
const int2 *i_s_ptr_ = i_s_ptr + offset;
int idx = i_s_ptr_[j].y;
index_[j] = idx;
out_[j] = in_[idx];
}
}
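// Editorial note: after the two radix passes, i_s_ptr holds, segment by
// segment, the original in-row positions (the .y field) in sorted-value
// order. This kernel gathers the input through those positions into `out`
// and records the positions themselves as the argsort indices.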
template <typename T>
inline void SegmentedSortPairsByFullSort(const phi::GPUContext &ctx,
const T *const self_ptr,
T *const values_ptr,
int64_t *const indices_ptr,
const int64_t nsegments,
const int64_t nsort,
const int64_t n,
const bool descending) {
int64_t segment_bits = std::max<int64_t>(
1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
const auto numel = nsort * nsegments;
DenseTensor indices_and_segment;
int64_t indices_and_segment_size = numel;
indices_and_segment.Resize({indices_and_segment_size * 2});
ctx.template Alloc<int64_t>(&indices_and_segment);
auto i_s_ptr_base = indices_and_segment.data<int64_t>();
auto i_s_ptr = reinterpret_cast<int2 *>(i_s_ptr_base);
dim3 block = PADDLE_CUDA_NUM_THREADS;
auto block_num = (numel - 1) / PADDLE_CUDA_NUM_THREADS + 1;
dim3 grid = static_cast<int>(block_num);
auto cu_stream = ctx.stream();
FillIndexAndSegmentKernel<<<grid, block, 0, cu_stream>>>(
i_s_ptr, numel, nsort);
DenseTensor indices_and_segment2;
int64_t indices_and_segment2_size = numel;
indices_and_segment2.Resize({indices_and_segment2_size * 2});
ctx.template Alloc<int64_t>(&indices_and_segment2);
auto i_s_ptr2_base = indices_and_segment2.data<int64_t>();
auto i_s_ptr2 = reinterpret_cast<int2 *>(i_s_ptr2_base);
RadixSortPairs<T, int2>(
ctx, self_ptr, i_s_ptr, nullptr, i_s_ptr2, n, descending);
RadixSortKeys<int64_t>(ctx,
reinterpret_cast<int64_t *>(i_s_ptr2),
reinterpret_cast<int64_t *>(i_s_ptr),
n,
false,
0,
segment_bits);
SortPostprocessKernel<<<grid, block, 0, cu_stream>>>(
self_ptr, i_s_ptr, values_ptr, indices_ptr, nsegments, nsort);
}
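// Editorial summary of the scheme above (not in the original patch): instead
// of one segmented sort per row, all rows are sorted in a single pass.
//   1. FillIndexAndSegmentKernel tags every element with (row id, column id).
//   2. RadixSortPairs sorts all values at once, carrying the tags as payload.
//   3. RadixSortKeys stably re-sorts the tags by row id only (segment_bits
//      bits), which regroups each row while keeping its sorted value order.
//   4. SortPostprocessKernel gathers the values and emits the column ids as
//      the argsort indices.
// For example, two rows [[3, 1], [2, 4]] sorted ascending interleave after
// step 2 as values 1, 2, 3, 4 with tags (0,1), (1,0), (0,0), (1,1); step 3
// regroups the tags into (0,1), (0,0), (1,0), (1,1), i.e. indices
// [[1, 0], [0, 1]].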
// This method is called when the number of rows of the input is less than
// or equal to 4.
template <typename T, typename IndexType>
void ArgFullSortForTinyRows(const phi::GPUContext &ctx,
const DenseTensor *input,
DenseTensor *output,
DenseTensor *indices,
const IndexType num_rows,
const IndexType num_cols,
const bool descending) {
auto gpu_stream = ctx.stream();
size_t temp_storage_bytes = -1;
IndexType numel = num_rows * num_cols;
if (numel == 0) {
return;
}
IndexType numel_or_intmax =
std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
IndexType nsort = num_cols;
IndexType nbatch = (numel_or_intmax / nsort) * nsort;
T *sorted_out_ptr;
IndexType *sorted_indices_ptr;
const T *input_data = input->data<T>();
T *out = ctx.template Alloc<T>(output);
IndexType *ind = ctx.template Alloc<IndexType>(indices);
sorted_out_ptr = out;
sorted_indices_ptr = ind;
int64_t remaining = numel;
while (remaining > 0) {
int64_t n = std::min(remaining, nbatch);
IndexType nsegments = n / nsort;
SegmentedSortPairsByFullSort(ctx,
input_data,
sorted_out_ptr,
sorted_indices_ptr,
nsegments,
nsort,
n,
descending);
remaining -= n;
input_data += n;
sorted_out_ptr += n;
sorted_indices_ptr += n;
}
}
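// Editorial note: the while loop above processes the tensor in batches of at
// most nbatch elements (numel clamped to INT_MAX and rounded down to a whole
// number of rows), so each underlying CUB call stays within int range; the
// input, output, and index pointers then advance by whole rows between
// batches.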
template <typename T, typename IndexType>
void ArgFullSort(const phi::GPUContext &ctx,
const DenseTensor *input,
DenseTensor *output,
DenseTensor *indices,
const IndexType num_rows,
const IndexType num_cols,
const bool descending) {
auto cu_stream = ctx.stream();
DenseTensor input_indices;
const std::vector<IndType> dims = {num_rows, num_cols};
const std::vector<IndexType> dims = {num_rows, num_cols};
auto dim = phi::make_ddim(dims);
input_indices.Resize(dim);
ctx.template Alloc<IndType>(&input_indices);
ctx.template Alloc<IndexType>(&input_indices);
size_t temp_storage_bytes = -1;
auto ComputeBlockSize = [](IndType col) {
auto ComputeBlockSize = [](IndexType col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
......@@ -113,111 +338,70 @@ void ArgFullSort(const phi::GPUContext& ctx,
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
// Init an index array
FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<IndType>(), num_rows, num_cols);
input_indices.data<IndexType>(), num_rows, num_cols);
T* sorted_out_ptr;
IndType* sorted_indices_ptr;
const T* inp = input->data<T>();
T* out = ctx.template Alloc<T>(output);
IndType* ind = ctx.template Alloc<IndType>(indices);
T *sorted_out_ptr;
IndexType *sorted_indices_ptr;
const T *inp = input->data<T>();
T *out = ctx.template Alloc<T>(output);
IndexType *ind = ctx.template Alloc<IndexType>(indices);
sorted_out_ptr = out;
sorted_indices_ptr = ind;
// Create an iterator for counting the input
cub::CountingInputIterator<IndType> counting_iter(0);
cub::CountingInputIterator<IndexType> counting_iter(0);
// segment_offsets_t is used to move to the next row
cub::TransformInputIterator<IndType,
cub::TransformInputIterator<IndexType,
SegmentOffsetIter,
cub::CountingInputIterator<IndType>>
cub::CountingInputIterator<IndexType>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
gpuError_t err;
if (descending) {
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr,
temp_storage_bytes,
inp,
sorted_out_ptr,
input_indices.data<IndType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
CUB_WRAPPER(cub::DeviceSegmentedRadixSort::SortPairsDescending,
ctx,
inp,
sorted_out_ptr,
input_indices.data<IndexType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
ctx.stream());
} else {
err =
cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
temp_storage_bytes,
inp,
sorted_out_ptr,
input_indices.data<IndType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
CUB_WRAPPER(cub::DeviceSegmentedRadixSort::SortPairs,
ctx,
inp,
sorted_out_ptr,
input_indices.data<IndexType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
ctx.stream());
}
PADDLE_ENFORCE_GPU_SUCCESS(err);
DenseTensor temp_storage;
int64_t temp_size = temp_storage_bytes;
temp_storage.Resize({temp_size});
ctx.template Alloc<uint8_t>(&temp_storage);
if (descending) {
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage.data<uint8_t>(),
temp_storage_bytes,
inp,
sorted_out_ptr,
input_indices.data<IndType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
} else {
err =
cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data<uint8_t>(),
temp_storage_bytes,
inp,
sorted_out_ptr,
input_indices.data<IndType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
}
PADDLE_ENFORCE_GPU_SUCCESS(err);
}
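// Editorial note on ArgFullSort: cub::DeviceSegmentedRadixSort sorts every
// row independently in a single launch. The segment boundaries are never
// stored in memory; segment_offsets_t is a TransformInputIterator over a
// counting iterator, so each row's start offset is computed on the fly and
// segment i spans [segment_offsets_t[i], segment_offsets_t[i + 1]). The
// CUB_WRAPPER macro replaces the explicit size-query / allocate / re-invoke
// sequence visible in the removed lines of this hunk.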
template <typename T, typename Context>
void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
void ArgsortKernel(const Context &dev_ctx,
const DenseTensor &input,
int axis,
bool descending,
DenseTensor* output,
DenseTensor* indices) {
DenseTensor *output,
DenseTensor *indices) {
auto in_dims = input.dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input.data<T>();
const T *in_data = input.data<T>();
auto size = input.numel();
T* out_data = dev_ctx.template Alloc<T>(output);
int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
T *out_data = dev_ctx.template Alloc<T>(output);
int64_t *ids_data = dev_ctx.template Alloc<int64_t>(indices);
// Use thrust for parallel acceleration when the input size is equal to the
// length of the ‘axis’ dimension.
......@@ -239,13 +423,23 @@ void ArgsortKernel(const Context& dev_ctx,
const int64_t input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
ArgFullSort<T, int64_t>(dev_ctx,
&input,
output,
indices,
input_height,
input_width,
descending);
if (input_height <= 4) {
ArgFullSortForTinyRows<T, int64_t>(dev_ctx,
&input,
output,
indices,
input_height,
input_width,
descending);
} else {
ArgFullSort<T, int64_t>(dev_ctx,
&input,
output,
indices,
input_height,
input_width,
descending);
}
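    // Editorial note: this dispatch is the core of the patch. With at most 4
    // rows, the segmented radix sort presumably has too few segments to keep
    // the GPU busy, so ArgFullSortForTinyRows sorts all rows in one global
    // radix sort and regroups them by row id; with more rows the original
    // segmented path is kept.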
} else {
// If it is not a full sort along the last axis, do a transpose first
std::vector<int> trans;
......@@ -264,7 +458,7 @@ void ArgsortKernel(const Context& dev_ctx,
DenseTensor trans_inp;
trans_inp.Resize(trans_dims);
T* trans_inp_data = dev_ctx.template Alloc<T>(&trans_inp);
T *trans_inp_data = dev_ctx.template Alloc<T>(&trans_inp);
// Do transpose
TransposeKernel<T, Context>(dev_ctx, input, trans, &trans_inp);
......@@ -282,13 +476,23 @@ void ArgsortKernel(const Context& dev_ctx,
dev_ctx.template Alloc<int64_t>(&tmp_indices);
dev_ctx.template Alloc<int64_t>(indices);
ArgFullSort<T, int64_t>(dev_ctx,
&trans_inp,
&tmp_out,
&tmp_indices,
input_height,
input_width,
descending);
if (input_height <= 4) {
ArgFullSortForTinyRows<T, int64_t>(dev_ctx,
&trans_inp,
&tmp_out,
&tmp_indices,
input_height,
input_width,
descending);
} else {
ArgFullSort<T, int64_t>(dev_ctx,
&trans_inp,
&tmp_out,
&tmp_indices,
input_height,
input_width,
descending);
}
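    // Editorial note: this is the same tiny-row dispatch as in the full-sort
    // branch; it runs after the input has been transposed so that the
    // requested axis becomes the innermost dimension, and the resulting
    // indices are transposed back below.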
TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
// transpose back
......