Unverified commit b1627455, authored by zhaoyuchen2018, committed by GitHub

Add descending for argsort (#21400)

* Add descending for argsort

* Refine api doc description.

* Refine descending description

* Add int32 logic to speed up sorting when the data size is small.

* Remove the int32 optimization, as it is not supported on the Python side
Parent 6b09b73e
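In short, this change threads a new `descending` attribute from the Python `argsort` layer into both the CPU and CUDA kernels. Below is a minimal numpy sketch of the intended semantics; the `fluid` call in the trailing comment simply mirrors the signature added by this diff and assumes a build that contains it:

import numpy as np

# Reference semantics of the new flag: a descending argsort is the
# ascending result reversed along the sort axis.
x = np.array([[3.0, 1.0, 2.0]])
asc_ids = np.argsort(x, axis=-1)      # [[1, 2, 0]] -> values 1.0, 2.0, 3.0
desc_ids = np.flip(asc_ids, axis=-1)  # [[0, 2, 1]] -> values 3.0, 2.0, 1.0

# With this patch applied:
#   out, ids = fluid.layers.argsort(input, axis=-1, descending=True)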
...
@@ -73,6 +73,13 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
              "When axis < 0, the actual axis will be the |axis|'th "
              "counting backwards. Default -1, the last dimension.")
         .SetDefault(-1);
+    AddAttr<bool>(
+        "descending",
+        "(bool, default false) The descending attribute is a flag to tell "
+        "the algorithm how to sort the input data. "
+        "If descending is true, sort in descending order; "
+        "otherwise, sort in ascending order. Default is false.")
+        .SetDefault(false);
   }
 };
...
@@ -58,11 +58,12 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   }
 }
 
-// Default use ascending sort
+// Sort by flag descending, True: descending. False: Ascending.
+// Default is false.
 template <typename T, typename IndType>
-void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
-                          const Tensor* input, Tensor* output, Tensor* indices,
-                          const IndType num_rows, const IndType num_cols) {
+void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
+                 Tensor* output, Tensor* indices, const IndType num_rows,
+                 const IndType num_cols, const bool descending) {
   auto cu_stream = ctx.stream();
 
   Tensor input_indices;
@@ -113,12 +114,20 @@ void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
                                cub::CountingInputIterator<IndType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
-  auto err = cub::DeviceSegmentedRadixSort::SortPairs(
-      nullptr, temp_storage_bytes, inp, sorted_out_ptr,
-      input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
-      num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
-      cu_stream);
+  cudaError_t err;
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  }
   PADDLE_ENFORCE_CUDA_SUCCESS(
       err,
       "ArgSortOP failed as could not launch "
@@ -129,11 +138,19 @@ void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
   Tensor temp_storage;
   temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
 
-  err = cub::DeviceSegmentedRadixSort::SortPairs(
-      temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
-      input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
-      num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
-      cu_stream);
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  }
 
   PADDLE_ENFORCE_CUDA_SUCCESS(
       err,
@@ -151,6 +168,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
     int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");
 
     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;
@@ -164,14 +182,8 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
           framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
       const int64_t input_width = in_dims[in_dims.size() - 1];
       const auto& dev_ctx = ctx.cuda_device_context();
-      if (input_width < INT_MAX && input_height < INT_MAX) {
-        ArgFullSortAscending<T, int>(dev_ctx, input, output, indices,
-                                     static_cast<int>(input_height),
-                                     static_cast<int>(input_width));
-      } else {
-        ArgFullSortAscending<T, int64_t>(dev_ctx, input, output, indices,
-                                         input_height, input_width);
-      }
+      ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
+                              input_width, descending);
     } else {
       // if not full sort, do transpose first
       std::vector<int> trans;
@@ -205,29 +217,15 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
       T* out_data = output->mutable_data<T>(ctx.GetPlace());
 
       Tensor tmp_indices;
-      if (input_height < INT_MAX && input_width < INT_MAX) {
-        // temp indices for sorting
-        tmp_indices.mutable_data<int>(trans_dims, ctx.GetPlace());
-        indices->mutable_data<int>(ctx.GetPlace());
-        ArgFullSortAscending<T, int>(
-            dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
-            static_cast<int>(input_height), static_cast<int>(input_width));
-        TransCompute<platform::CUDADeviceContext, int>(
-            ndims, dev_ctx, tmp_indices, indices, trans);
-      } else {
-        // temp indices for sorting
-        tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
-        indices->mutable_data<int64_t>(ctx.GetPlace());
-        ArgFullSortAscending<T, int64_t>(dev_ctx, &trans_inp, &tmp_out,
-                                         &tmp_indices, input_height,
-                                         input_width);
-        TransCompute<platform::CUDADeviceContext, int64_t>(
-            ndims, dev_ctx, tmp_indices, indices, trans);
-      }
+      // temp indices for sorting
+      tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      indices->mutable_data<int64_t>(ctx.GetPlace());
+      ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
+                              input_height, input_width, descending);
+      TransCompute<platform::CUDADeviceContext, int64_t>(
+          ndims, dev_ctx, tmp_indices, indices, trans);
       // transpose back
       TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
                                                    output, trans);
...
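A note for readers unfamiliar with CUB: `ArgFullSort` deliberately calls `cub::DeviceSegmentedRadixSort` twice. The first call passes a null workspace pointer and only computes `temp_storage_bytes`; the buffer is then allocated and the second call performs the actual sort. Functionally, the kernel sorts `num_rows` independent segments of `num_cols` elements each. A rough numpy model of that computation follows; the function name is mine, and tie ordering is not guaranteed to match the radix sort bit for bit:

import numpy as np

def arg_full_sort(mat, descending=False):
    # Model of the GPU path: sort every row (segment) of a
    # [num_rows, num_cols] view, returning values and indices.
    ids = np.argsort(mat, axis=-1)
    if descending:
        ids = np.flip(ids, axis=-1)
    out = np.take_along_axis(mat, ids, axis=-1)
    return out, ids

vals, idx = arg_full_sort(np.array([[0.5, 2.0, 1.0]]), descending=True)
# vals -> [[2.0, 1.0, 0.5]], idx -> [[1, 2, 0]]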
...
@@ -16,11 +16,58 @@ limitations under the License. */
 #include <algorithm>
 #include <utility>
 #include <vector>
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/transpose_op.h"
 
 namespace paddle {
 namespace operators {
 
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+using Tensor = framework::Tensor;
+
+template <typename T, typename Type>
+static void FullSort(Type input_height, Type input_width, int input_dim,
+                     const framework::Tensor* input, T* t_out, Type* t_indices,
+                     bool descending) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (Type i = 0; i < input_height; ++i) {
+    std::vector<std::pair<T, Type>> col_vec;
+    col_vec.reserve(input_width);
+    if (input_dim == 1) {
+      auto e_input = EigenVector<T>::Flatten(*input);
+      for (Type j = 0; j < input_width; ++j) {
+        col_vec.push_back(std::pair<T, Type>(e_input(j), j));
+      }
+    } else {
+      auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
+      for (Type j = 0; j < input_width; ++j) {
+        col_vec.push_back(std::pair<T, Type>(e_input(i, j), j));
+      }
+    }
+    std::sort(col_vec.begin(), col_vec.end(),
+              [&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
+                if (descending)
+                  return l.first > r.first;
+                else
+                  return l.first < r.first;
+              });
+
+    for (Type j = 0; j < input_width; ++j) {
+      t_out[i * input_width + j] = col_vec[j].first;
+      t_indices[i * input_width + j] = col_vec[j].second;
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 class ArgsortKernel : public framework::OpKernel<T> {
  public:
@@ -29,50 +76,68 @@ class ArgsortKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<framework::Tensor>("Out");
     auto* indices = ctx.Output<framework::Tensor>("Indices");
     int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");
 
     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;
 
-    const T* in_data = input->data<T>();
     T* out_data = output->mutable_data<T>(ctx.GetPlace());
-    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
 
-    int64_t groups = input->numel() / in_dims[axis];
-    int64_t stride = (axis == in_dims.size() - 1)
-                         ? 1
-                         : framework::product(framework::slice_ddim(
-                               in_dims, axis + 1, in_dims.size()));
-
-    for (int64_t i = 0; i < groups; ++i) {
-      int64_t idx = i;
-      std::vector<int64_t> shape_vec(in_dims.size(), 0);
-      for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) {
-        if (dim != axis) {
-          shape_vec[dim] = idx % in_dims[dim];
-          idx /= in_dims[dim];
-        }
-      }
-
-      int64_t start_index = shape_vec[0];
-      for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) {
-        start_index = start_index * in_dims[dim + 1] + shape_vec[dim + 1];
-      }
-
-      std::vector<int64_t> org_index_vec(in_dims[axis], start_index);
-      for (int64_t j = 1; j < in_dims[axis]; ++j) {
-        org_index_vec[j] += j * stride;
-      }
-
-      std::sort(org_index_vec.begin(), org_index_vec.end(),
-                [in_data](const int64_t v1, const int64_t v2) {
-                  return in_data[v1] < in_data[v2];
-                });
-
-      for (size_t j = 0; j < org_index_vec.size(); ++j) {
-        int64_t index = start_index + j * stride;
-        out_data[index] = in_data[org_index_vec[j]];
-        ids_data[index] = (org_index_vec[j] - start_index) / stride;
-      }
+    // Do full sort
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+      int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+      FullSort<T, int64_t>(input_height, input_width, in_dims.size(), input,
+                           out_data, ids_data, descending);
+    } else {
+      // If not full sort do transpose
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_inp;
+      trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+      // Do transpose
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *input,
+                                                  &trans_inp, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+      output->mutable_data<T>(ctx.GetPlace());
+
+      Tensor tmp_indices;
+      auto* t_ind =
+          tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      FullSort<T, int64_t>(input_height, input_width, in_dims.size(),
+                           &trans_inp, t_out, t_ind, descending);
+
+      indices->mutable_data<int64_t>(ctx.GetPlace());
+      TransCompute<platform::CPUDeviceContext, int64_t>(
+          ndims, dev_ctx, tmp_indices, indices, trans);
+      // transpose back
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                  output, trans);
     }
   }
 };
...
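The CPU kernel handles an arbitrary `axis` by swapping that axis with the last one, running the row-wise `FullSort`, and transposing the results back; the `trans` vector built above is exactly that swap permutation. A short numpy sketch of the same trick (my naming, not Paddle's):

import numpy as np

def argsort_any_axis(x, axis, descending=False):
    # Swap `axis` with the last dim, sort rows, then swap back;
    # applying the same swap twice restores the original layout.
    xt = np.swapaxes(x, axis, -1)
    ids = np.argsort(xt, axis=-1)
    if descending:
        ids = np.flip(ids, axis=-1)
    out = np.take_along_axis(xt, ids, axis=-1)
    return np.swapaxes(out, axis, -1), np.swapaxes(ids, axis, -1)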
@@ -53,7 +53,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
 
   // align block_span to warpSize
   int block_span = (blockDim.x + warpSize - 1) >> 5;
-  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
   val = warpReduceSum<T>(val, mask);
 
   return val;
...
@@ -802,7 +802,7 @@ def argmax(x, axis=0):
     return out
 
 
-def argsort(input, axis=-1, name=None):
+def argsort(input, axis=-1, descending=False, name=None):
     """
     This OP sorts the input along the given axis, and returns sorted output
     data Variable and its corresponding index Variable with the same shape as
@@ -814,6 +814,9 @@ def argsort(input, axis=-1, name=None):
         axis(int, optional): Axis to compute indices along. The effective range
             is [-R, R), where R is Rank(x). when axis<0, it works the same way
             as axis+R. Default is 0.
+        descending(bool, optional) : If set to True, the algorithm will sort
+            in descending order; otherwise it will sort in ascending order.
+            Default is False.
         name(str, optional): The default value is None. Normally there is no
             need for user to set this property. For more information, please
             refer to :ref:`api_guide_Name`.
@@ -879,7 +882,8 @@ def argsort(input, axis=-1, name=None):
         inputs={'X': input},
         outputs={'Out': out,
                  'Indices': ids},
-        attrs={'axis': axis})
+        attrs={'axis': axis,
+               'descending': descending})
     return out, ids
...
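For completeness, a hedged end-to-end example in the fluid 1.x style of this era; it assumes a build that includes this patch, and the feed values are made up for illustration:

import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[3], dtype='float32')
out, ids = fluid.layers.argsort(x, axis=-1, descending=True)

exe = fluid.Executor(fluid.CPUPlace())
res_out, res_ids = exe.run(
    fluid.default_main_program(),
    feed={'x': np.array([[0.5, 2.0, 1.0]], dtype='float32')},
    fetch_list=[out, ids])
# Expected: res_out -> [[2.0, 1.0, 0.5]], res_ids -> [[1, 2, 0]]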
@@ -24,14 +24,24 @@ class TestArgsortOp(OpTest):
     def setUp(self):
         self.init_axis()
         self.init_datatype()
+        self.init_direction()
         x = np.random.random((2, 3, 4, 5, 10)).astype(self.dtype)
+        self.attrs = {'axis': self.axis, 'descending': self.descending}
         if self.axis < 0:
             self.axis = self.axis + len(x.shape)
-        self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
-        self.out = np.sort(x, kind='quicksort', axis=self.axis)
+        if self.descending:
+            self.indices = np.flip(
+                np.argsort(
+                    x, kind='quicksort', axis=self.axis), self.axis)
+            self.out = np.flip(
+                np.sort(
+                    x, kind='quicksort', axis=self.axis), self.axis)
+        else:
+            self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
+            self.out = np.sort(x, kind='quicksort', axis=self.axis)
         self.op_type = "argsort"
         self.inputs = {'X': x}
-        self.attrs = {'axis': self.axis}
         self.outputs = {'Indices': self.indices, 'Out': self.out}
 
     def init_axis(self):
@@ -40,6 +50,9 @@ class TestArgsortOp(OpTest):
     def init_datatype(self):
         self.dtype = "float32"
 
+    def init_direction(self):
+        self.descending = False
+
     def test_check_output(self):
         self.check_output()
 
@@ -103,5 +116,35 @@ class TestArgsortOpFP16Axis4Neg4(TestArgsortOpFP16):
         self.axis = -4
 
 
+class TestArgsortOpDescendingAxis(TestArgsortOp):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis0(TestArgsortOpAxis0):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis1(TestArgsortOpAxis1):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis2(TestArgsortOpAxis2):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg1(TestArgsortOpAxisNeg1):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg2(TestArgsortOpAxisNeg2):
+    def init_direction(self):
+        self.descending = True
+
+
 if __name__ == "__main__":
     unittest.main()
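One subtlety in the expected values: for the descending cases the test flips the ascending `np.argsort` result, which reverses the relative order of equal elements (and `kind='quicksort'` is not stable to begin with). With continuous `np.random.random` inputs ties are vanishingly unlikely, so the construction is sound here. A small illustration of what ties would change:

import numpy as np

x = np.array([1.0, 1.0, 0.5])
flipped = np.flip(np.argsort(x, kind='stable'))  # [1, 0, 2]: tie order reversed
direct = np.argsort(-x, kind='stable')           # [0, 1, 2]: tie order kept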