未验证 提交 b1627455 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

Add descending for argsort (#21400)

* Add ascending for argsort

* Refine api doc description.

* Refine descending description

* Add int32 logic to speedup when data is small size.

* Remove int32 opt as not support in python
上级 6b09b73e
......@@ -73,6 +73,13 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
"When axis < 0, the actual axis will be the |axis|'th "
"counting backwards. Default -1, the last dimension.")
.SetDefault(-1);
// Sort-direction flag for the op. Adjacent string literals are concatenated
// verbatim, so every fragment ends with a space to keep the words of the
// generated description separated.
AddAttr<bool>(
    "descending",
    "(bool, default false) The descending attribute is a flag to tell "
    "algorithm how to sort the input data. "
    "If descending is true, will sort by descending order, "
    "else if false, sort by ascending order. Default value is false.")
    .SetDefault(false);
}
};
......
......@@ -58,11 +58,12 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
}
// Default use ascending sort
// Sort by flag descending, True: descending. False: Ascending.
// Default is false.
template <typename T, typename IndType>
void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
const Tensor* input, Tensor* output, Tensor* indices,
const IndType num_rows, const IndType num_cols) {
void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
Tensor* output, Tensor* indices, const IndType num_rows,
const IndType num_cols, const bool descending) {
auto cu_stream = ctx.stream();
Tensor input_indices;
......@@ -113,12 +114,20 @@ void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
cub::CountingInputIterator<IndType>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
auto err = cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
cudaError_t err;
if (descending) {
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr, temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
} else {
err = cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
}
PADDLE_ENFORCE_CUDA_SUCCESS(
err,
"ArgSortOP failed as could not launch "
......@@ -129,11 +138,19 @@ void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
err = cub::DeviceSegmentedRadixSort::SortPairs(
temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
if (descending) {
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
} else {
err = cub::DeviceSegmentedRadixSort::SortPairs(
temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
}
PADDLE_ENFORCE_CUDA_SUCCESS(
err,
......@@ -151,6 +168,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
int axis = ctx.Attr<int>("axis");
bool descending = ctx.Attr<bool>("descending");
auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
......@@ -164,14 +182,8 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
if (input_width < INT_MAX && input_height < INT_MAX) {
ArgFullSortAscending<T, int>(dev_ctx, input, output, indices,
static_cast<int>(input_height),
static_cast<int>(input_width));
} else {
ArgFullSortAscending<T, int64_t>(dev_ctx, input, output, indices,
input_height, input_width);
}
ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
input_width, descending);
} else {
// if not full sort, do transpose first
std::vector<int> trans;
......@@ -205,29 +217,15 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
T* out_data = output->mutable_data<T>(ctx.GetPlace());
Tensor tmp_indices;
if (input_height < INT_MAX && input_width < INT_MAX) {
// temp indices for sorting
tmp_indices.mutable_data<int>(trans_dims, ctx.GetPlace());
indices->mutable_data<int>(ctx.GetPlace());
ArgFullSortAscending<T, int>(
dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
static_cast<int>(input_height), static_cast<int>(input_width));
TransCompute<platform::CUDADeviceContext, int>(
ndims, dev_ctx, tmp_indices, indices, trans);
} else {
// temp indices for sorting
tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
indices->mutable_data<int64_t>(ctx.GetPlace());
ArgFullSortAscending<T, int64_t>(dev_ctx, &trans_inp, &tmp_out,
&tmp_indices, input_height,
input_width);
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, tmp_indices, indices, trans);
}
// temp indices for sorting
tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
indices->mutable_data<int64_t>(ctx.GetPlace());
ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
input_height, input_width, descending);
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, tmp_indices, indices, trans);
// transpose back
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
output, trans);
......
......@@ -16,11 +16,58 @@ limitations under the License. */
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/transpose_op.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
using Tensor = framework::Tensor;
// Sorts every row of `input` along its last dimension.
//
// The tensor is read through a row-major Eigen view: a flat vector when
// `input_dim` is 1, otherwise a matrix obtained with
// EigenMatrix<T>::Reshape(*input, input_dim - 1). For each of the
// `input_height` rows, the `input_width` values are paired with their column
// position and sorted by value only; `descending` selects greater-first
// ordering, otherwise smaller-first. Sorted values go to `t_out` and the
// original column positions to `t_indices`, both at offset
// row * input_width + column.
template <typename T, typename Type>
static void FullSort(Type input_height, Type input_width, int input_dim,
                     const framework::Tensor* input, T* t_out, Type* t_indices,
                     bool descending) {
  using ValueWithIndex = std::pair<T, Type>;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (Type row = 0; row < input_height; ++row) {
    // Collect (value, column) pairs for the current row.
    std::vector<ValueWithIndex> row_pairs;
    row_pairs.reserve(input_width);
    if (input_dim != 1) {
      auto matrix_view = EigenMatrix<T>::Reshape(*input, input_dim - 1);
      for (Type col = 0; col < input_width; ++col) {
        row_pairs.push_back(ValueWithIndex(matrix_view(row, col), col));
      }
    } else {
      auto vector_view = EigenVector<T>::Flatten(*input);
      for (Type col = 0; col < input_width; ++col) {
        row_pairs.push_back(ValueWithIndex(vector_view(col), col));
      }
    }

    // Order by value only; the stored column index never takes part in the
    // comparison.
    if (descending) {
      std::sort(row_pairs.begin(), row_pairs.end(),
                [](const ValueWithIndex& l, const ValueWithIndex& r) {
                  return l.first > r.first;
                });
    } else {
      std::sort(row_pairs.begin(), row_pairs.end(),
                [](const ValueWithIndex& l, const ValueWithIndex& r) {
                  return l.first < r.first;
                });
    }

    // Scatter the sorted row into the caller-provided output buffers.
    T* out_row = t_out + row * input_width;
    Type* index_row = t_indices + row * input_width;
    for (Type col = 0; col < input_width; ++col) {
      out_row[col] = row_pairs[col].first;
      index_row[col] = row_pairs[col].second;
    }
  }
}
template <typename DeviceContext, typename T>
class ArgsortKernel : public framework::OpKernel<T> {
public:
......@@ -29,50 +76,68 @@ class ArgsortKernel : public framework::OpKernel<T> {
auto* output = ctx.Output<framework::Tensor>("Out");
auto* indices = ctx.Output<framework::Tensor>("Indices");
int axis = ctx.Attr<int>("axis");
bool descending = ctx.Attr<bool>("descending");
auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
int64_t groups = input->numel() / in_dims[axis];
int64_t stride = (axis == in_dims.size() - 1)
? 1
: framework::product(framework::slice_ddim(
in_dims, axis + 1, in_dims.size()));
for (int64_t i = 0; i < groups; ++i) {
int64_t idx = i;
std::vector<int64_t> shape_vec(in_dims.size(), 0);
for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) {
if (dim != axis) {
shape_vec[dim] = idx % in_dims[dim];
idx /= in_dims[dim];
}
}
int64_t start_index = shape_vec[0];
for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) {
start_index = start_index * in_dims[dim + 1] + shape_vec[dim + 1];
}
// Do full sort
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
std::vector<int64_t> org_index_vec(in_dims[axis], start_index);
for (int64_t j = 1; j < in_dims[axis]; ++j) {
org_index_vec[j] += j * stride;
int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
FullSort<T, int64_t>(input_height, input_width, in_dims.size(), input,
out_data, ids_data, descending);
} else {
// If not full sort do transpose
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
framework::DDim trans_dims(in_dims);
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}
std::sort(org_index_vec.begin(), org_index_vec.end(),
[in_data](const int64_t v1, const int64_t v2) {
return in_data[v1] < in_data[v2];
});
Tensor trans_inp;
trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
// Do transpose
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *input,
&trans_inp, trans);
for (size_t j = 0; j < org_index_vec.size(); ++j) {
int64_t index = start_index + j * stride;
out_data[index] = in_data[org_index_vec[j]];
ids_data[index] = (org_index_vec[j] - start_index) / stride;
}
const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
output->mutable_data<T>(ctx.GetPlace());
Tensor tmp_indices;
auto* t_ind =
tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
FullSort<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_inp, t_out, t_ind, descending);
indices->mutable_data<int64_t>(ctx.GetPlace());
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_ctx, tmp_indices, indices, trans);
// transpose back
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out,
output, trans);
}
}
};
......
......@@ -53,7 +53,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
val = warpReduceSum<T>(val, mask);
return val;
......
......@@ -802,7 +802,7 @@ def argmax(x, axis=0):
return out
def argsort(input, axis=-1, name=None):
def argsort(input, axis=-1, descending=False, name=None):
"""
This OP sorts the input along the given axis, and returns sorted output
data Variable and its corresponding index Variable with the same shape as
......@@ -814,6 +814,9 @@ def argsort(input, axis=-1, name=None):
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is Rank(x). when axis<0, it works the same way
as axis+R. Default is -1.
descending(bool, optional) : Descending is a flag, if set to true,
algorithm will sort by descending order, else sort by
ascending order. Default is false.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
......@@ -879,7 +882,8 @@ def argsort(input, axis=-1, name=None):
inputs={'X': input},
outputs={'Out': out,
'Indices': ids},
attrs={'axis': axis})
attrs={'axis': axis,
'descending': descending})
return out, ids
......
......@@ -24,14 +24,24 @@ class TestArgsortOp(OpTest):
def setUp(self):
self.init_axis()
self.init_datatype()
self.init_direction()
x = np.random.random((2, 3, 4, 5, 10)).astype(self.dtype)
self.attrs = {'axis': self.axis, 'descending': self.descending}
if self.axis < 0:
self.axis = self.axis + len(x.shape)
self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
self.out = np.sort(x, kind='quicksort', axis=self.axis)
if self.descending:
self.indices = np.flip(
np.argsort(
x, kind='quicksort', axis=self.axis), self.axis)
self.out = np.flip(
np.sort(
x, kind='quicksort', axis=self.axis), self.axis)
else:
self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
self.out = np.sort(x, kind='quicksort', axis=self.axis)
self.op_type = "argsort"
self.inputs = {'X': x}
self.attrs = {'axis': self.axis}
self.outputs = {'Indices': self.indices, 'Out': self.out}
def init_axis(self):
......@@ -40,6 +50,9 @@ class TestArgsortOp(OpTest):
def init_datatype(self):
self.dtype = "float32"
def init_direction(self):
self.descending = False
def test_check_output(self):
self.check_output()
......@@ -103,5 +116,35 @@ class TestArgsortOpFP16Axis4Neg4(TestArgsortOpFP16):
self.axis = -4
# Descending-order variant of TestArgsortOp: inherits everything else from the
# base case and only flips the sort direction.
class TestArgsortOpDescendingAxis(TestArgsortOp):
def init_direction(self):
# Select descending order for this case.
self.descending = True
# Descending-order variant of TestArgsortOpAxis0: inherits everything else from
# that case and only flips the sort direction.
class TestArgsortOpDescendingAxis0(TestArgsortOpAxis0):
def init_direction(self):
# Select descending order for this case.
self.descending = True
# Descending-order variant of TestArgsortOpAxis1: inherits everything else from
# that case and only flips the sort direction.
class TestArgsortOpDescendingAxis1(TestArgsortOpAxis1):
def init_direction(self):
# Select descending order for this case.
self.descending = True
# Descending-order variant of TestArgsortOpAxis2: inherits everything else from
# that case and only flips the sort direction.
class TestArgsortOpDescendingAxis2(TestArgsortOpAxis2):
def init_direction(self):
# Select descending order for this case.
self.descending = True
# Descending-order variant of TestArgsortOpAxisNeg1: inherits everything else
# from that case and only flips the sort direction.
class TestArgsortOpDescendingAxisNeg1(TestArgsortOpAxisNeg1):
def init_direction(self):
# Select descending order for this case.
self.descending = True
# Descending-order variant of TestArgsortOpAxisNeg2: inherits everything else
# from that case and only flips the sort direction.
class TestArgsortOpDescendingAxisNeg2(TestArgsortOpAxisNeg2):
def init_direction(self):
# Select descending order for this case.
self.descending = True
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册