Unverified commit 08c19c58, authored by zhaoyuchen2018, committed by GitHub

Improve argsort performance. (#21267)

* Improve argsort performance.

- For 200,000 elements, argsort on a V100 speeds up by ~190x
  (a benchmark sketch follows the commit message):
  cost before optimization: 0.53 s
  cost after optimization: 0.0027 s

- Add fp16 support

* Refine error message
* Refine code

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
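
For reference, a minimal sketch of the kind of measurement quoted above, assuming a CUDA build of Paddle and the fluid.layers.argsort API; shapes, iteration count, and timings are illustrative only. Note that fetching outputs includes device-to-host copies, so the raw kernel time will be smaller than what this script prints.

import time
import numpy as np
import paddle.fluid as fluid

# Build a program that argsorts 200,000 float32 values along the last axis,
# i.e. the full-sort fast path introduced by this commit.
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(
        name="x", shape=[1, 200000], dtype="float32", append_batch_size=False)
    out, ids = fluid.layers.argsort(x, axis=-1)

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_prog)

data = np.random.random((1, 200000)).astype("float32")
exe.run(main_prog, feed={"x": data}, fetch_list=[out, ids])  # warm-up

start = time.time()
for _ in range(10):
    exe.run(main_prog, feed={"x": data}, fetch_list=[out, ids])
print("average argsort time: %.4f s" % ((time.time() - start) / 10))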
Parent commit: 7fcaa39b
@@ -14,82 +14,133 @@ limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/argsort_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
// Specialize cub's base traits so that cub's radix sort can handle
// paddle::platform::float16 (stored as uint16_t).
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t,
paddle::platform::float16> {};
} // namespace cub
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;
const int kMaxRank = 9; // The max rank of a tensor allowed in Fluid
__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
int axis, int64_t n, int64_t* trg_idx,
int64_t* med_ids) {
int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < n) {
int64_t shape_out_axis[kMaxRank - 1] = {0};
int64_t dims_out_axis[kMaxRank - 1] = {0};
int64_t tmp = index;
int64_t pos_in_axis = 0;
int64_t i = dims_size - 2;
int64_t dim_axis = 0;
for (int64_t j = dims_size - 1; j >= 0; --j) {
int64_t dim = in_dims[j];
if (j != axis) {
shape_out_axis[i] = tmp % dim;
dims_out_axis[i] = dim;
i--;
} else {
dim_axis = dim;
pos_in_axis = tmp % dim_axis;
}
tmp /= dim;
}
int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
for (int64_t j = 0; j < dims_size - 2; ++j) {
group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
}
int64_t traget_idx = group * dim_axis + pos_in_axis;
trg_idx[index] = traget_idx;
med_ids[traget_idx] = pos_in_axis;
}
}
template <typename T>
__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
T* med_out) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < n) {
med_out[trg_idx[index]] = in[index];
}
}

// Iterator that maps a segment (row) index to its offset; used to move to
// the next row.
struct SegmentOffsetIter {
EIGEN_DEVICE_FUNC
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
return idx * num_cols_;
}
int num_cols_;
};
template <typename T>
__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
int64_t* med_ids) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < groups) {
thrust::sort_by_key(thrust::device, med_out + index * axis_dim,
med_out + axis_dim * (1 + index),
med_ids + index * axis_dim);
}
}

template <typename T>
static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (T j = row_id; j < num_rows; j += gridDim.x) {
for (T i = col_id; i < num_cols; i += blockDim.x) {
indices[j * num_cols + i] = i;
}
}
}
template <typename T>
__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
const int64_t* trg_idx, int64_t n, T* out,
int64_t* indices) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < n) {
out[index] = med_out[trg_idx[index]];
indices[index] = med_ids[trg_idx[index]];
}
}

// Ascending sort is used by default
template <typename T, typename IndType>
void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
const Tensor* input, Tensor* output, Tensor* indices,
const IndType num_rows, const IndType num_cols) {
auto cu_stream = ctx.stream();
Tensor input_indices;
const std::vector<IndType> dims = {num_rows, num_cols};
auto dim = framework::make_ddim(dims);
input_indices.Resize(dim);
input_indices.mutable_data<IndType>(ctx.GetPlace());
size_t temp_storage_bytes = -1;
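// Choose a thread-block size close to the row width (capped at 1024) so that
// FillIndex covers each row in few iterations per thread.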
auto ComputeBlockSize = [](IndType col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};
int block_size = ComputeBlockSize(num_cols);
int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
// Typically num_rows < maxGridDimX, so one block is launched per row.
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
// Init an index array
FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<IndType>(), num_rows, num_cols);
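// input_indices now holds 0, 1, ..., num_cols - 1 in every row.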
T* sorted_out_ptr;
IndType* sorted_indices_ptr;
const T* inp = input->data<T>();
T* out = output->mutable_data<T>(ctx.GetPlace());
IndType* ind = indices->mutable_data<IndType>(ctx.GetPlace());
sorted_out_ptr = out;
sorted_indices_ptr = ind;
// create iter for counting input
cub::CountingInputIterator<IndType> counting_iter(0);
// segment_offset is used for move to next row
cub::TransformInputIterator<IndType, SegmentOffsetIter,
cub::CountingInputIterator<IndType>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
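// segment_offsets_t[i] == i * num_cols, so segment i covers the half-open
// range [i * num_cols, (i + 1) * num_cols), i.e. one row per sort segment.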
auto err = cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
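// With a null temporary-storage pointer, the call above only computes the
// required temp_storage_bytes; no sorting is performed yet.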
PADDLE_ENFORCE_CUDA_SUCCESS(
err,
"ArgSortOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairs to calculate "
"temp_storage_bytes, status: %s.",
cudaGetErrorString(err));
Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
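// The second call, now with real temporary storage, performs the row-wise
// ascending radix sort of the values together with their indices.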
err = cub::DeviceSegmentedRadixSort::SortPairs(
temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
PADDLE_ENFORCE_CUDA_SUCCESS(
err,
"ArgSortOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairs to sort input, "
"temp_storage_bytes:%d status:%s.",
temp_storage_bytes, cudaGetErrorString(err));
}
template <typename T>
@@ -104,47 +155,91 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
int64_t numel = input->numel();
int64_t groups = numel / in_dims[axis];
std::vector<int64_t> in_dims_vec = vectorize(in_dims);
thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
in_dims_vec.end());
int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());
// Mediate tensor for sorting data and indices
Tensor mediate_output, mediate_indices;
T* med_out_data =
mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace());
int64_t* med_ids_data =
mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
// Target index of each element along the given axis in the mediate tensors
Tensor trg_idx_t;
int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());
auto stream = ctx.cuda_device_context().stream();
const int num_threads = PADDLE_CUDA_NUM_THREADS;
ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);
PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
in_data, trg_idx, numel, med_out_data);
Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>(
in_dims[axis], groups, med_out_data, med_ids_data);
PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0,
stream>>>(med_out_data, med_ids_data, trg_idx, numel,
out_data, ids_data);
// Special case for full sort, speedup ~190x.
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
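// Prefer 32-bit indexing when both dimensions fit in int; it is cheaper
// inside the cub kernels than 64-bit indexing.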
if (input_width < INT_MAX && input_height < INT_MAX) {
ArgFullSortAscending<T, int>(dev_ctx, input, output, indices,
static_cast<int>(input_height),
static_cast<int>(input_width));
} else {
ArgFullSortAscending<T, int64_t>(dev_ctx, input, output, indices,
input_height, input_width);
}
} else {
// if not full sort, do transpose first
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
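// 'trans' swaps the sort axis with the last axis, e.g. for a rank-4 input
// and axis == 1 the permutation is {0, 3, 2, 1}.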
framework::DDim trans_dims(in_dims);
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}
Tensor trans_inp;
T* trans_inp_data = trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
const auto& dev_ctx = ctx.cuda_device_context();
// Do transpose
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
&trans_inp, trans);
const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
Tensor tmp_out;
tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
T* out_data = output->mutable_data<T>(ctx.GetPlace());
Tensor tmp_indices;
if (input_height < INT_MAX && input_width < INT_MAX) {
// temp indices for sorting
tmp_indices.mutable_data<int>(trans_dims, ctx.GetPlace());
indices->mutable_data<int>(ctx.GetPlace());
ArgFullSortAscending<T, int>(
dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
static_cast<int>(input_height), static_cast<int>(input_width));
TransCompute<platform::CUDADeviceContext, int>(
ndims, dev_ctx, tmp_indices, indices, trans);
} else {
// temp indices for sorting
tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
indices->mutable_data<int64_t>(ctx.GetPlace());
ArgFullSortAscending<T, int64_t>(dev_ctx, &trans_inp, &tmp_out,
&tmp_indices, input_height,
input_width);
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, tmp_indices, indices, trans);
}
// transpose back
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
output, trans);
return;
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
paddle::operators::ArgsortOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
paddle::operators::ArgsortOpCUDAKernel<double>,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
@@ -17,12 +17,14 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
class TestArgsortOp(OpTest):
    def setUp(self):
        self.init_axis()
        x = np.random.random((2, 3, 4, 5, 10)).astype("float32")
        self.init_datatype()
        x = np.random.random((2, 3, 4, 5, 10)).astype(self.dtype)
        if self.axis < 0:
            self.axis = self.axis + len(x.shape)
        self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
@@ -35,6 +37,9 @@ class TestArgsortOp(OpTest):
    def init_axis(self):
        self.axis = -1

    def init_datatype(self):
        self.dtype = "float32"

    def test_check_output(self):
        self.check_output()
@@ -49,10 +54,54 @@ class TestArgsortOpAxis1(TestArgsortOp):
        self.axis = 1


class TestArgsortOpAxis2(TestArgsortOp):
    def init_axis(self):
        self.axis = 2


class TestArgsortOpAxisNeg1(TestArgsortOp):
    def init_axis(self):
        self.axis = -1


class TestArgsortOpAxisNeg2(TestArgsortOp):
    def init_axis(self):
        self.axis = -2


class TestArgsortOpFP16(TestArgsortOp):
    def init_datatype(self):
        if core.is_compiled_with_cuda():
            self.dtype = 'float16'

    def test_check_output(self):
        pass
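    # float16 argsort is only registered for the CUDA place, so the default
    # CPU check above is a no-op; the CUDA-only check below uses atol=1e-5.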
    def test_check_output_with_place(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=1e-5)


class TestArgsortOpFP16Axis0(TestArgsortOpFP16):
    def init_axis(self):
        self.axis = 0


class TestArgsortOpFP16Axis2(TestArgsortOpFP16):
    def init_axis(self):
        self.axis = 2


class TestArgsortOpFP16AxisNeg2(TestArgsortOpFP16):
    def init_axis(self):
        self.axis = -2


class TestArgsortOpFP16Axis4Neg4(TestArgsortOpFP16):
    def init_axis(self):
        self.axis = -4


if __name__ == "__main__":
    unittest.main()