Unverified commit 66c18f4a, authored by zhaoyuchen2018, committed by GitHub

[cherry-pick] Improve argsort performance. (#21267) (#21442)

* Improve argsort performance.

- Sorting 200,000 elements with argsort on a V100 is sped up by ~190x:
  before optimization: 0.53 s
  after optimization: 0.0027 s

- Add fp16 support

* Refine error message
* Refine code
* Add descending sort

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
Parent 735a2db0
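The speedup quoted in the commit message can be sanity-checked from Python with a small ad-hoc script such as the one below. This is only a rough sketch against the Fluid 1.x graph API, not the benchmark actually used for the numbers above; the feed shape, warm-up policy, and executor boilerplate are assumptions.

# Rough, hypothetical timing sketch for argsort on GPU (Fluid 1.x API).
# Shapes, iteration count, and warm-up are guesses, not the original benchmark.
import time
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(
    name='x', shape=[200000], dtype='float32', append_batch_size=False)
out, ids = fluid.layers.argsort(x, axis=-1)

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())

data = np.random.random(200000).astype('float32')
exe.run(feed={'x': data}, fetch_list=[out, ids])  # warm-up

start = time.time()
for _ in range(10):
    exe.run(feed={'x': data}, fetch_list=[out, ids])
print('mean argsort time: %.4fs' % ((time.time() - start) / 10))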
...
@@ -73,6 +73,13 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
                  "When axis < 0, the actual axis will be the |axis|'th "
                  "counting backwards. Default -1, the last dimension.")
         .SetDefault(-1);
+    AddAttr<bool>(
+        "descending",
+        "(bool, default false) The descending attribute is a flag to tell"
+        "algorithm how to sort the input data."
+        "If descending is true, will sort by descending order,"
+        "else if false, sort by ascending order. Default value is false.")
+        .SetDefault(false);
   }
 };
...
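The new descending attribute mirrors NumPy semantics: with descending=False the op behaves like np.sort/np.argsort, and with descending=True both outputs are simply reversed along the sort axis, which is also how the unit tests further down construct their expected results. A small NumPy-only illustration:

import numpy as np

x = np.array([3.0, 1.0, 2.0], dtype=np.float32)

# descending = False (default): ascending order
print(np.sort(x))          # [1. 2. 3.]
print(np.argsort(x))       # [1 2 0]

# descending = True: the same result flipped along the sort axis,
# exactly how the unit tests below build their expected outputs.
print(np.flip(np.sort(x), axis=-1))     # [3. 2. 1.]
print(np.flip(np.argsort(x), axis=-1))  # [0 2 1]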
...
@@ -14,82 +14,150 @@ limitations under the License. */
 #include <thrust/execution_policy.h>
 #include <thrust/sort.h>
+#include "cub/cub.cuh"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/argsort_op.h"
+#include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"

+// set cub base traits in order to handle float16
+namespace cub {
+template <>
+struct NumericTraits<paddle::platform::float16>
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
+                 paddle::platform::float16> {};
+}  // namespace cub
+
 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
-using platform::PADDLE_CUDA_NUM_THREADS;
-
-const int kMaxRank = 9;  // The max rank of a tensor allowed in Fluid
-
-__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
-                                 int axis, int64_t n, int64_t* trg_idx,
-                                 int64_t* med_ids) {
-  int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    int64_t shape_out_axis[kMaxRank - 1] = {0};
-    int64_t dims_out_axis[kMaxRank - 1] = {0};
-    int64_t tmp = index;
-    int64_t pos_in_axis = 0;
-    int64_t i = dims_size - 2;
-    int64_t dim_axis = 0;
-    for (int64_t j = dims_size - 1; j >= 0; --j) {
-      int64_t dim = in_dims[j];
-      if (j != axis) {
-        shape_out_axis[i] = tmp % dim;
-        dims_out_axis[i] = dim;
-        i--;
-      } else {
-        dim_axis = dim;
-        pos_in_axis = tmp % dim_axis;
-      }
-      tmp /= dim;
-    }
-    int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
-    for (int64_t j = 0; j < dims_size - 2; ++j) {
-      group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
-    }
-
-    int64_t traget_idx = group * dim_axis + pos_in_axis;
-    trg_idx[index] = traget_idx;
-    med_ids[traget_idx] = pos_in_axis;
-  }
-}
-
-template <typename T>
-__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
-                              T* med_out) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    med_out[trg_idx[index]] = in[index];
-  }
-}
-
-template <typename T>
-__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
-                     int64_t* med_ids) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < groups) {
-    thrust::sort_by_key(thrust::device, med_out + index * axis_dim,
-                        med_out + axis_dim * (1 + index),
-                        med_ids + index * axis_dim);
-  }
-}
-
-template <typename T>
-__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
-                                   const int64_t* trg_idx, int64_t n, T* out,
-                                   int64_t* indices) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    out[index] = med_out[trg_idx[index]];
-    indices[index] = med_ids[trg_idx[index]];
-  }
-}
+// Iter for move to next row
+struct SegmentOffsetIter {
+  EIGEN_DEVICE_FUNC
+  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
+    return idx * num_cols_;
+  }
+
+  int num_cols_;
+};
+
+template <typename T>
+static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
+  int col_id = threadIdx.x;
+  int row_id = blockIdx.x;
+
+  for (T j = row_id; j < num_rows; j += gridDim.x) {
+    for (T i = col_id; i < num_cols; i += blockDim.x) {
+      indices[j * num_cols + i] = i;
+    }
+  }
+}
+
+// Sort by flag descending, True: descending. False: Ascending.
+// Default is false.
+template <typename T, typename IndType>
+void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
+                 Tensor* output, Tensor* indices, const IndType num_rows,
+                 const IndType num_cols, const bool descending) {
+  auto cu_stream = ctx.stream();
+
+  Tensor input_indices;
+
+  const std::vector<IndType> dims = {num_rows, num_cols};
+  auto dim = framework::make_ddim(dims);
+  input_indices.Resize(dim);
+  input_indices.mutable_data<IndType>(ctx.GetPlace());
+
+  size_t temp_storage_bytes = -1;
+
+  auto ComputeBlockSize = [](IndType col) {
+    if (col > 512)
+      return 1024;
+    else if (col > 256 && col <= 512)
+      return 512;
+    else if (col > 128 && col <= 256)
+      return 256;
+    else if (col > 64 && col <= 128)
+      return 128;
+    else
+      return 64;
+  };
+
+  int block_size = ComputeBlockSize(num_cols);
+
+  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
+  // actually, int num_rows < max_grid_size
+  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
+  // Init a index array
+  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
+      input_indices.data<IndType>(), num_rows, num_cols);
+
+  T* sorted_out_ptr;
+  IndType* sorted_indices_ptr;
+
+  const T* inp = input->data<T>();
+
+  T* out = output->mutable_data<T>(ctx.GetPlace());
+
+  IndType* ind = indices->mutable_data<IndType>(ctx.GetPlace());
+
+  sorted_out_ptr = out;
+  sorted_indices_ptr = ind;
+
+  // create iter for counting input
+  cub::CountingInputIterator<IndType> counting_iter(0);
+  // segment_offset is used for move to next row
+  cub::TransformInputIterator<IndType, SegmentOffsetIter,
+                              cub::CountingInputIterator<IndType>>
+      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
+
+  cudaError_t err;
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  }
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      err,
+      "ArgSortOP failed as could not launch "
+      "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate"
+      "temp_storage_bytes, status:%s.",
+      temp_storage_bytes, cudaGetErrorString(err));
+
+  Tensor temp_storage;
+  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
+
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  }
+
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      err,
+      "ArgSortOP failed as could not launch "
+      "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
+      "temp_storage_bytes:%d status:%s.",
+      temp_storage_bytes, cudaGetErrorString(err));
+}

 template <typename T>
...
@@ -100,51 +168,76 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
     int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");

     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;
-    const T* in_data = input->data<T>();
-    T* out_data = output->mutable_data<T>(ctx.GetPlace());
-    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());

     int64_t numel = input->numel();
-    int64_t groups = numel / in_dims[axis];
-
-    std::vector<int64_t> in_dims_vec = vectorize(in_dims);
-    thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
-                                               in_dims_vec.end());
-    int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());
-    // Mediate tensor for sorting data and indices
-    Tensor mediate_output, mediate_indices;
-    T* med_out_data =
-        mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace());
-    int64_t* med_ids_data =
-        mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
-    // Target index of each element along the given axis in the mediate tensors
-    Tensor trg_idx_t;
-    int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());
-
-    auto stream = ctx.cuda_device_context().stream();
-    const int num_threads = PADDLE_CUDA_NUM_THREADS;
-
-    ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);
-
-    PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_data, trg_idx, numel, med_out_data);
-
-    Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_dims[axis], groups, med_out_data, med_ids_data);
-
-    PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0,
-                         stream>>>(med_out_data, med_ids_data, trg_idx, numel,
-                                   out_data, ids_data);
+    // Special case for full sort, speedup ~190x.
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+      const auto& dev_ctx = ctx.cuda_device_context();
+      ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
+                              input_width, descending);
+    } else {
+      // if not full sort, do transpose first
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_inp;
+      T* trans_inp_data = trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      const auto& dev_ctx = ctx.cuda_device_context();
+      // Do transpose
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
+                                                   &trans_inp, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+      T* out_data = output->mutable_data<T>(ctx.GetPlace());
+
+      Tensor tmp_indices;
+      // temp indices for sorting
+      tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      indices->mutable_data<int64_t>(ctx.GetPlace());
+
+      ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
+                              input_height, input_width, descending);
+
+      TransCompute<platform::CUDADeviceContext, int64_t>(
+          ndims, dev_ctx, tmp_indices, indices, trans);
+      // transpose back
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                   output, trans);
+      return;
+    }
   }
 };

 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
-                        paddle::operators::ArgsortOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
+    paddle::operators::ArgsortOpCUDAKernel<double>,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
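In the fast path, ArgFullSort views the input as a (num_rows, num_cols) matrix, fills an index tensor with 0..num_cols-1 for every row (FillIndex), and then sorts each row as an independent segment with cub::DeviceSegmentedRadixSort, carrying the indices along as the sorted values. The NumPy sketch below shows the same computation; arg_full_sort is a made-up helper for illustration, not part of the commit.

import numpy as np

def arg_full_sort(x, descending=False):
    """Row-wise full sort of a 2-D view, returning (sorted values, indices).

    Mirrors the semantics of the CUDA ArgFullSort path: every row is an
    independent segment, and the index array plays the role of the values
    carried alongside the sorted keys.
    """
    indices = np.argsort(x, axis=-1)  # per-row ordering
    if descending:
        indices = np.flip(indices, axis=-1)
    sorted_vals = np.take_along_axis(x, indices, axis=-1)
    return sorted_vals, indices

x = np.random.random((4, 10)).astype(np.float32)
out, ids = arg_full_sort(x, descending=True)
assert (out == np.flip(np.sort(x, axis=-1), axis=-1)).all()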
...
@@ -16,11 +16,58 @@ limitations under the License. */
 #include <algorithm>
 #include <utility>
 #include <vector>
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/transpose_op.h"

 namespace paddle {
 namespace operators {

+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+using Tensor = framework::Tensor;
+
+template <typename T, typename Type>
+static void FullSort(Type input_height, Type input_width, int input_dim,
+                     const framework::Tensor* input, T* t_out, Type* t_indices,
+                     bool descending) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (Type i = 0; i < input_height; ++i) {
+    std::vector<std::pair<T, Type>> col_vec;
+    col_vec.reserve(input_width);
+    if (input_dim == 1) {
+      auto e_input = EigenVector<T>::Flatten(*input);
+      for (Type j = 0; j < input_width; ++j) {
+        col_vec.push_back(std::pair<T, Type>(e_input(j), j));
+      }
+    } else {
+      auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
+      for (Type j = 0; j < input_width; ++j) {
+        col_vec.push_back(std::pair<T, Type>(e_input(i, j), j));
+      }
+    }
+    std::sort(col_vec.begin(), col_vec.end(),
+              [&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
+                if (descending)
+                  return l.first > r.first;
+                else
+                  return l.first < r.first;
+              });
+
+    for (Type j = 0; j < input_width; ++j) {
+      t_out[i * input_width + j] = col_vec[j].first;
+      t_indices[i * input_width + j] = col_vec[j].second;
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 class ArgsortKernel : public framework::OpKernel<T> {
  public:
...
@@ -29,50 +76,68 @@ class ArgsortKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<framework::Tensor>("Out");
     auto* indices = ctx.Output<framework::Tensor>("Indices");
     int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");

     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;

-    const T* in_data = input->data<T>();
     T* out_data = output->mutable_data<T>(ctx.GetPlace());
-    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-
-    int64_t groups = input->numel() / in_dims[axis];
-    int64_t stride = (axis == in_dims.size() - 1)
-                         ? 1
-                         : framework::product(framework::slice_ddim(
-                               in_dims, axis + 1, in_dims.size()));
-
-    for (int64_t i = 0; i < groups; ++i) {
-      int64_t idx = i;
-      std::vector<int64_t> shape_vec(in_dims.size(), 0);
-      for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) {
-        if (dim != axis) {
-          shape_vec[dim] = idx % in_dims[dim];
-          idx /= in_dims[dim];
-        }
-      }
-
-      int64_t start_index = shape_vec[0];
-      for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) {
-        start_index = start_index * in_dims[dim + 1] + shape_vec[dim + 1];
-      }
-
-      std::vector<int64_t> org_index_vec(in_dims[axis], start_index);
-      for (int64_t j = 1; j < in_dims[axis]; ++j) {
-        org_index_vec[j] += j * stride;
-      }
-
-      std::sort(org_index_vec.begin(), org_index_vec.end(),
-                [in_data](const int64_t v1, const int64_t v2) {
-                  return in_data[v1] < in_data[v2];
-                });
-
-      for (size_t j = 0; j < org_index_vec.size(); ++j) {
-        int64_t index = start_index + j * stride;
-        out_data[index] = in_data[org_index_vec[j]];
-        ids_data[index] = (org_index_vec[j] - start_index) / stride;
-      }
+    // Do full sort
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+      int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+      FullSort<T, int64_t>(input_height, input_width, in_dims.size(), input,
+                           out_data, ids_data, descending);
+    } else {
+      // If not full sort do transpose
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_inp;
+      trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+      // Do transpose
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *input,
+                                                  &trans_inp, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+      output->mutable_data<T>(ctx.GetPlace());
+
+      Tensor tmp_indices;
+      auto* t_ind =
+          tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+
+      FullSort<T, int64_t>(input_height, input_width, in_dims.size(),
+                           &trans_inp, t_out, t_ind, descending);
+
+      indices->mutable_data<int64_t>(ctx.GetPlace());
+      TransCompute<platform::CPUDeviceContext, int64_t>(
+          ndims, dev_ctx, tmp_indices, indices, trans);
+      // transpose back
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                  output, trans);
     }
   }
 };
...
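For a sort axis other than the last one, both the CUDA and the CPU kernels build a permutation trans that swaps the target axis with the last dimension, run the full sort along the last axis, and then apply the same permutation again to restore the original layout (the swap is its own inverse, so no separate inverse permutation is needed). A NumPy sketch of this strategy; argsort_along is a made-up name used only for illustration.

import numpy as np

def argsort_along(x, axis, descending=False):
    """NumPy sketch of the transpose -> sort-last-axis -> transpose-back path.

    `trans` swaps `axis` with the last dimension, so applying it twice is the
    identity; the kernels therefore reuse the same permutation to undo it.
    """
    ndim = x.ndim
    trans = list(range(axis)) + [ndim - 1] + list(range(axis + 1, ndim - 1)) + [axis]
    xt = np.transpose(x, trans)
    ids = np.argsort(xt, axis=-1)
    if descending:
        ids = np.flip(ids, axis=-1)
    out = np.take_along_axis(xt, ids, axis=-1)
    # transpose back with the same permutation
    return np.transpose(out, trans), np.transpose(ids, trans)

x = np.random.random((2, 3, 4, 5)).astype(np.float32)
out, ids = argsort_along(x, axis=1)
assert (out == np.sort(x, axis=1)).all()
assert (ids == np.argsort(x, axis=1)).all()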
...
@@ -53,7 +53,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   // align block_span to warpSize
   int block_span = (blockDim.x + warpSize - 1) >> 5;
-  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
   val = warpReduceSum<T>(val, mask);

   return val;
...
...
@@ -838,7 +838,7 @@ def argmax(x, axis=0):
     return out


-def argsort(input, axis=-1, name=None):
+def argsort(input, axis=-1, descending=False, name=None):
     """
     This OP sorts the input along the given axis, and returns sorted output
     data Varibale and its corresponding index Variable with the same shape as
...
@@ -850,6 +850,9 @@ def argsort(input, axis=-1, name=None):
         axis(int, optional): Axis to compute indices along. The effective range
             is [-R, R), where R is Rank(x). when axis<0, it works the same way
            as axis+R. Default is 0.
+        descending(bool, optional) : Descending is a flag, if set to true,
+            algorithm will sort by descending order, else sort by
+            ascending order. Default is false.
         name(str, optional): The default value is None. Normally there is no
             need for user to set this property. For more information, please
             refer to :ref:`api_guide_Name`.
...
@@ -915,7 +918,8 @@ def argsort(input, axis=-1, name=None):
         inputs={'X': input},
         outputs={'Out': out,
                  'Indices': ids},
-        attrs={'axis': axis})
+        attrs={'axis': axis,
+               'descending': descending})
     return out, ids
...
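At the Python level the extended signature is used as follows. This is a minimal sketch against the Fluid 1.x API; only fluid.layers.argsort itself comes from this diff, while the surrounding data/executor boilerplate is assumed.

import numpy as np
import paddle.fluid as fluid

# Build a tiny program that sorts along the last axis in descending order.
x = fluid.layers.data(
    name='x', shape=[3, 4], dtype='float32', append_batch_size=False)
out, indices = fluid.layers.argsort(x, axis=-1, descending=True)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

data = np.random.random((3, 4)).astype('float32')
res_out, res_ids = exe.run(feed={'x': data}, fetch_list=[out, indices])

# Expected values, computed with NumPy for comparison.
np.testing.assert_allclose(res_out, np.flip(np.sort(data, axis=-1), axis=-1))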
...
@@ -17,24 +17,42 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core


 class TestArgsortOp(OpTest):
     def setUp(self):
         self.init_axis()
-        x = np.random.random((2, 3, 4, 5, 10)).astype("float32")
+        self.init_datatype()
+        self.init_direction()
+        x = np.random.random((2, 3, 4, 5, 10)).astype(self.dtype)
+        self.attrs = {'axis': self.axis, 'descending': self.descending}
         if self.axis < 0:
             self.axis = self.axis + len(x.shape)
-        self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
-        self.out = np.sort(x, kind='quicksort', axis=self.axis)
+        if self.descending:
+            self.indices = np.flip(
+                np.argsort(
+                    x, kind='quicksort', axis=self.axis), self.axis)
+            self.out = np.flip(
+                np.sort(
+                    x, kind='quicksort', axis=self.axis), self.axis)
+        else:
+            self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
+            self.out = np.sort(x, kind='quicksort', axis=self.axis)

         self.op_type = "argsort"
         self.inputs = {'X': x}
-        self.attrs = {'axis': self.axis}
         self.outputs = {'Indices': self.indices, 'Out': self.out}

     def init_axis(self):
         self.axis = -1

+    def init_datatype(self):
+        self.dtype = "float32"
+
+    def init_direction(self):
+        self.descending = False
+
     def test_check_output(self):
         self.check_output()
...
@@ -49,10 +67,84 @@ class TestArgsortOpAxis1(TestArgsortOp):
         self.axis = 1


+class TestArgsortOpAxis2(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 2
+
+
+class TestArgsortOpAxisNeg1(TestArgsortOp):
+    def init_axis(self):
+        self.axis = -1
+
+
 class TestArgsortOpAxisNeg2(TestArgsortOp):
     def init_axis(self):
         self.axis = -2


+class TestArgsortOpFP16(TestArgsortOp):
+    def init_datatype(self):
+        if core.is_compiled_with_cuda():
+            self.dtype = 'float16'
+
+    def test_check_output(self):
+        pass
+
+    def test_check_output_with_place(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=1e-5)
+
+
+class TestArgsortOpFP16Axis0(TestArgsortOpFP16):
+    def init_axis(self):
+        self.axis = 0
+
+
+class TestArgsortOpFP16Axis2(TestArgsortOpFP16):
+    def init_axis(self):
+        self.axis = 2
+
+
+class TestArgsortOpFP16AxisNeg2(TestArgsortOpFP16):
+    def init_axis(self):
+        self.axis = -2
+
+
+class TestArgsortOpFP16Axis4Neg4(TestArgsortOpFP16):
+    def init_axis(self):
+        self.axis = -4
+
+
+class TestArgsortOpDescendingAxis(TestArgsortOp):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis0(TestArgsortOpAxis0):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis1(TestArgsortOpAxis1):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis2(TestArgsortOpAxis2):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg1(TestArgsortOpAxisNeg1):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg2(TestArgsortOpAxisNeg2):
+    def init_direction(self):
+        self.descending = True
+
+
 if __name__ == "__main__":
     unittest.main()