From 66c18f4a6572859d3f47e7550aae65bdc12f1806 Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Tue, 3 Dec 2019 18:00:53 +0800 Subject: [PATCH] [cherry-pick] Improve argsort performance. (#21267) (#21442) * Improve argsort performance. - Give 200000 data to compute argsort on v100, can speed up ~190x before opt cost: 0.53s after opt cost:0.0027s - Add fp16 support * Refine error message * Refine code * Add descending sort test=develop Signed-off-by: zhaoyuchen --- paddle/fluid/operators/argsort_op.cc | 7 + paddle/fluid/operators/argsort_op.cu | 279 ++++++++++++------ paddle/fluid/operators/argsort_op.h | 133 ++++++--- paddle/fluid/operators/multihead_matmul_op.cu | 2 +- python/paddle/fluid/layers/tensor.py | 8 +- .../fluid/tests/unittests/test_argsort_op.py | 100 ++++++- 6 files changed, 395 insertions(+), 134 deletions(-) mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_argsort_op.py diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc index d25160f423..3f6d6a73d2 100644 --- a/paddle/fluid/operators/argsort_op.cc +++ b/paddle/fluid/operators/argsort_op.cc @@ -73,6 +73,13 @@ Output(Indices) gives the sorted order along the given axis Attr(axis). "When axis < 0, the actual axis will be the |axis|'th " "counting backwards. Default -1, the last dimension.") .SetDefault(-1); + AddAttr( + "descending", + "(bool, default false) The descending attribute is a flag to tell" + "algorithm how to sort the input data." + "If descending is true, will sort by descending order," + "else if false, sort by ascending order. Default value is false.") + .SetDefault(false); } }; diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu index 1a0b303817..8eb56e70b5 100644 --- a/paddle/fluid/operators/argsort_op.cu +++ b/paddle/fluid/operators/argsort_op.cu @@ -14,82 +14,150 @@ limitations under the License. */ #include #include +#include "cub/cub.cuh" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/argsort_op.h" +#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" +// set cub base traits in order to handle float16 +namespace cub { +template <> +struct NumericTraits + : BaseTraits {}; +} // namespace cub + namespace paddle { namespace operators { using Tensor = framework::Tensor; -using platform::PADDLE_CUDA_NUM_THREADS; - -const int kMaxRank = 9; // The max rank of a tensor allowed in Fluid - -__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size, - int axis, int64_t n, int64_t* trg_idx, - int64_t* med_ids) { - int64_t index = threadIdx.x + blockDim.x * blockIdx.x; - if (index < n) { - int64_t shape_out_axis[kMaxRank - 1] = {0}; - int64_t dims_out_axis[kMaxRank - 1] = {0}; - int64_t tmp = index; - int64_t pos_in_axis = 0; - int64_t i = dims_size - 2; - int64_t dim_axis = 0; - for (int64_t j = dims_size - 1; j >= 0; --j) { - int64_t dim = in_dims[j]; - if (j != axis) { - shape_out_axis[i] = tmp % dim; - dims_out_axis[i] = dim; - i--; - } else { - dim_axis = dim; - pos_in_axis = tmp % dim_axis; - } - tmp /= dim; - } - int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0; - for (int64_t j = 0; j < dims_size - 2; ++j) { - group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1]; - } - int64_t traget_idx = group * dim_axis + pos_in_axis; - trg_idx[index] = traget_idx; - med_ids[traget_idx] = pos_in_axis; - } -} +// Iter for move to next row +struct SegmentOffsetIter { + EIGEN_DEVICE_FUNC + explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} -template -__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n, - T* med_out) { - int index = threadIdx.x + blockDim.x * blockIdx.x; - if (index < n) { - med_out[trg_idx[index]] = in[index]; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const { + return idx * num_cols_; } -} + + int num_cols_; +}; template -__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out, - int64_t* med_ids) { - int index = threadIdx.x + blockDim.x * blockIdx.x; - if (index < groups) { - thrust::sort_by_key(thrust::device, med_out + index * axis_dim, - med_out + axis_dim * (1 + index), - med_ids + index * axis_dim); +static __global__ void FillIndex(T* indices, T num_rows, T num_cols) { + int col_id = threadIdx.x; + int row_id = blockIdx.x; + + for (T j = row_id; j < num_rows; j += gridDim.x) { + for (T i = col_id; i < num_cols; i += blockDim.x) { + indices[j * num_cols + i] = i; + } } } -template -__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids, - const int64_t* trg_idx, int64_t n, T* out, - int64_t* indices) { - int index = threadIdx.x + blockDim.x * blockIdx.x; - if (index < n) { - out[index] = med_out[trg_idx[index]]; - indices[index] = med_ids[trg_idx[index]]; +// Sort by flag descending, True: descending. False: Ascending. +// Default is false. +template +void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input, + Tensor* output, Tensor* indices, const IndType num_rows, + const IndType num_cols, const bool descending) { + auto cu_stream = ctx.stream(); + + Tensor input_indices; + + const std::vector dims = {num_rows, num_cols}; + auto dim = framework::make_ddim(dims); + input_indices.Resize(dim); + input_indices.mutable_data(ctx.GetPlace()); + + size_t temp_storage_bytes = -1; + + auto ComputeBlockSize = [](IndType col) { + if (col > 512) + return 1024; + else if (col > 256 && col <= 512) + return 512; + else if (col > 128 && col <= 256) + return 256; + else if (col > 64 && col <= 128) + return 128; + else + return 64; + }; + + int block_size = ComputeBlockSize(num_cols); + + int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x; + // actually, int num_rows < max_grid_size + int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX; + // Init a index array + FillIndex<<>>( + input_indices.data(), num_rows, num_cols); + + T* sorted_out_ptr; + IndType* sorted_indices_ptr; + + const T* inp = input->data(); + T* out = output->mutable_data(ctx.GetPlace()); + IndType* ind = indices->mutable_data(ctx.GetPlace()); + + sorted_out_ptr = out; + sorted_indices_ptr = ind; + + // create iter for counting input + cub::CountingInputIterator counting_iter(0); + // segment_offset is used for move to next row + cub::TransformInputIterator> + segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols)); + + cudaError_t err; + if (descending) { + err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, inp, sorted_out_ptr, + input_indices.data(), sorted_indices_ptr, num_cols * num_rows, + num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, + cu_stream); + } else { + err = cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, temp_storage_bytes, inp, sorted_out_ptr, + input_indices.data(), sorted_indices_ptr, num_cols * num_rows, + num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, + cu_stream); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + err, + "ArgSortOP failed as could not launch " + "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate" + "temp_storage_bytes, status:%s.", + temp_storage_bytes, cudaGetErrorString(err)); + + Tensor temp_storage; + temp_storage.mutable_data(ctx.GetPlace(), temp_storage_bytes); + + if (descending) { + err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + temp_storage.data(), temp_storage_bytes, inp, sorted_out_ptr, + input_indices.data(), sorted_indices_ptr, num_cols * num_rows, + num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, + cu_stream); + } else { + err = cub::DeviceSegmentedRadixSort::SortPairs( + temp_storage.data(), temp_storage_bytes, inp, sorted_out_ptr, + input_indices.data(), sorted_indices_ptr, num_cols * num_rows, + num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, + cu_stream); } + + PADDLE_ENFORCE_CUDA_SUCCESS( + err, + "ArgSortOP failed as could not launch " + "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, " + "temp_storage_bytes:%d status:%s.", + temp_storage_bytes, cudaGetErrorString(err)); } template @@ -100,51 +168,76 @@ class ArgsortOpCUDAKernel : public framework::OpKernel { auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); int axis = ctx.Attr("axis"); + bool descending = ctx.Attr("descending"); auto in_dims = input->dims(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; - const T* in_data = input->data(); - T* out_data = output->mutable_data(ctx.GetPlace()); - int64_t* ids_data = indices->mutable_data(ctx.GetPlace()); - int64_t numel = input->numel(); int64_t groups = numel / in_dims[axis]; - std::vector in_dims_vec = vectorize(in_dims); - thrust::device_vector in_dims_dev(in_dims_vec.begin(), - in_dims_vec.end()); - int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data()); - // Mediate tensor for sorting data and indices - Tensor mediate_output, mediate_indices; - T* med_out_data = - mediate_output.mutable_data(input->dims(), ctx.GetPlace()); - int64_t* med_ids_data = - mediate_indices.mutable_data(in_dims, ctx.GetPlace()); - // Target index of each element along the given axis in the mediate tensors - Tensor trg_idx_t; - int64_t* trg_idx = trg_idx_t.mutable_data(in_dims, ctx.GetPlace()); - - auto stream = ctx.cuda_device_context().stream(); - const int num_threads = PADDLE_CUDA_NUM_THREADS; - - ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>( - in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data); - - PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>( - in_data, trg_idx, numel, med_out_data); - - Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>( - in_dims[axis], groups, med_out_data, med_ids_data); - - PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0, - stream>>>(med_out_data, med_ids_data, trg_idx, numel, - out_data, ids_data); + // Special case for full sort, speedup ~190x. + if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t input_height = framework::product( + framework::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; + const auto& dev_ctx = ctx.cuda_device_context(); + ArgFullSort(dev_ctx, input, output, indices, input_height, + input_width, descending); + } else { + // if not full sort, do transpose first + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + framework::DDim trans_dims(in_dims); + for (int i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + + Tensor trans_inp; + T* trans_inp_data = trans_inp.mutable_data(trans_dims, ctx.GetPlace()); + int ndims = trans.size(); + const auto& dev_ctx = ctx.cuda_device_context(); + // Do transpose + TransCompute(ndims, dev_ctx, *input, + &trans_inp, trans); + + const int64_t input_height = framework::product( + framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + Tensor tmp_out; + tmp_out.mutable_data(trans_dims, ctx.GetPlace()); + T* out_data = output->mutable_data(ctx.GetPlace()); + + Tensor tmp_indices; + // temp indices for sorting + tmp_indices.mutable_data(trans_dims, ctx.GetPlace()); + indices->mutable_data(ctx.GetPlace()); + + ArgFullSort(dev_ctx, &trans_inp, &tmp_out, &tmp_indices, + input_height, input_width, descending); + + TransCompute( + ndims, dev_ctx, tmp_indices, indices, trans); + // transpose back + TransCompute(ndims, dev_ctx, tmp_out, + output, trans); + return; + } } }; } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel, - paddle::operators::ArgsortOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + argsort, paddle::operators::ArgsortOpCUDAKernel, + paddle::operators::ArgsortOpCUDAKernel, + paddle::operators::ArgsortOpCUDAKernel); diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h index 7e9112cfb7..cd9f9a9ea2 100644 --- a/paddle/fluid/operators/argsort_op.h +++ b/paddle/fluid/operators/argsort_op.h @@ -16,11 +16,58 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/transpose_op.h" namespace paddle { namespace operators { +template +using EigenMatrix = framework::EigenMatrix; + +template +using EigenVector = framework::EigenVector; + +using Tensor = framework::Tensor; + +template +static void FullSort(Type input_height, Type input_width, int input_dim, + const framework::Tensor* input, T* t_out, Type* t_indices, + bool descending) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (Type i = 0; i < input_height; ++i) { + std::vector> col_vec; + col_vec.reserve(input_width); + if (input_dim == 1) { + auto e_input = EigenVector::Flatten(*input); + for (Type j = 0; j < input_width; ++j) { + col_vec.push_back(std::pair(e_input(j), j)); + } + } else { + auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); + for (Type j = 0; j < input_width; ++j) { + col_vec.push_back(std::pair(e_input(i, j), j)); + } + } + std::sort(col_vec.begin(), col_vec.end(), + [&](const std::pair& l, const std::pair& r) { + if (descending) + return l.first > r.first; + else + return l.first < r.first; + }); + + for (Type j = 0; j < input_width; ++j) { + t_out[i * input_width + j] = col_vec[j].first; + t_indices[i * input_width + j] = col_vec[j].second; + } + } +} template class ArgsortKernel : public framework::OpKernel { public: @@ -29,50 +76,68 @@ class ArgsortKernel : public framework::OpKernel { auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); int axis = ctx.Attr("axis"); + bool descending = ctx.Attr("descending"); auto in_dims = input->dims(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; - const T* in_data = input->data(); T* out_data = output->mutable_data(ctx.GetPlace()); - int64_t* ids_data = indices->mutable_data(ctx.GetPlace()); - - int64_t groups = input->numel() / in_dims[axis]; - int64_t stride = (axis == in_dims.size() - 1) - ? 1 - : framework::product(framework::slice_ddim( - in_dims, axis + 1, in_dims.size())); - - for (int64_t i = 0; i < groups; ++i) { - int64_t idx = i; - std::vector shape_vec(in_dims.size(), 0); - for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) { - if (dim != axis) { - shape_vec[dim] = idx % in_dims[dim]; - idx /= in_dims[dim]; - } - } - int64_t start_index = shape_vec[0]; - for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) { - start_index = start_index * in_dims[dim + 1] + shape_vec[dim + 1]; - } + // Do full sort + if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t input_height = framework::product( + framework::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; - std::vector org_index_vec(in_dims[axis], start_index); - for (int64_t j = 1; j < in_dims[axis]; ++j) { - org_index_vec[j] += j * stride; + int64_t* ids_data = indices->mutable_data(ctx.GetPlace()); + FullSort(input_height, input_width, in_dims.size(), input, + out_data, ids_data, descending); + } else { + // If not full sort do transpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + framework::DDim trans_dims(in_dims); + for (int i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; } - std::sort(org_index_vec.begin(), org_index_vec.end(), - [in_data](const int64_t v1, const int64_t v2) { - return in_data[v1] < in_data[v2]; - }); + Tensor trans_inp; + trans_inp.mutable_data(trans_dims, ctx.GetPlace()); + int ndims = trans.size(); + auto& dev_ctx = ctx.template device_context(); + // Do transpose + TransCompute(ndims, dev_ctx, *input, + &trans_inp, trans); - for (size_t j = 0; j < org_index_vec.size(); ++j) { - int64_t index = start_index + j * stride; - out_data[index] = in_data[org_index_vec[j]]; - ids_data[index] = (org_index_vec[j] - start_index) / stride; - } + const int64_t input_height = framework::product( + framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + Tensor tmp_out; + T* t_out = tmp_out.mutable_data(trans_dims, ctx.GetPlace()); + output->mutable_data(ctx.GetPlace()); + + Tensor tmp_indices; + + auto* t_ind = + tmp_indices.mutable_data(trans_dims, ctx.GetPlace()); + + FullSort(input_height, input_width, in_dims.size(), + &trans_inp, t_out, t_ind, descending); + + indices->mutable_data(ctx.GetPlace()); + TransCompute( + ndims, dev_ctx, tmp_indices, indices, trans); + // transpose back + TransCompute(ndims, dev_ctx, tmp_out, + output, trans); } } }; diff --git a/paddle/fluid/operators/multihead_matmul_op.cu b/paddle/fluid/operators/multihead_matmul_op.cu index 8728fd9d21..74bc7731a9 100644 --- a/paddle/fluid/operators/multihead_matmul_op.cu +++ b/paddle/fluid/operators/multihead_matmul_op.cu @@ -53,7 +53,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; - val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f); + val = (threadIdx.x < block_span) ? shared[lane] : static_cast(0.0f); val = warpReduceSum(val, mask); return val; diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 772b60a492..16add7e448 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -838,7 +838,7 @@ def argmax(x, axis=0): return out -def argsort(input, axis=-1, name=None): +def argsort(input, axis=-1, descending=False, name=None): """ This OP sorts the input along the given axis, and returns sorted output data Varibale and its corresponding index Variable with the same shape as @@ -850,6 +850,9 @@ def argsort(input, axis=-1, name=None): axis(int, optional): Axis to compute indices along. The effective range is [-R, R), where R is Rank(x). when axis<0, it works the same way as axis+R. Default is 0. + descending(bool, optional) : Descending is a flag, if set to true, + algorithm will sort by descending order, else sort by + ascending order. Default is false. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -915,7 +918,8 @@ def argsort(input, axis=-1, name=None): inputs={'X': input}, outputs={'Out': out, 'Indices': ids}, - attrs={'axis': axis}) + attrs={'axis': axis, + 'descending': descending}) return out, ids diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py old mode 100644 new mode 100755 index 7bc6f2599d..89ff5d7101 --- a/python/paddle/fluid/tests/unittests/test_argsort_op.py +++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py @@ -17,24 +17,42 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core class TestArgsortOp(OpTest): def setUp(self): self.init_axis() - x = np.random.random((2, 3, 4, 5, 10)).astype("float32") + self.init_datatype() + self.init_direction() + x = np.random.random((2, 3, 4, 5, 10)).astype(self.dtype) + self.attrs = {'axis': self.axis, 'descending': self.descending} if self.axis < 0: self.axis = self.axis + len(x.shape) - self.indices = np.argsort(x, kind='quicksort', axis=self.axis) - self.out = np.sort(x, kind='quicksort', axis=self.axis) + if self.descending: + self.indices = np.flip( + np.argsort( + x, kind='quicksort', axis=self.axis), self.axis) + self.out = np.flip( + np.sort( + x, kind='quicksort', axis=self.axis), self.axis) + else: + self.indices = np.argsort(x, kind='quicksort', axis=self.axis) + self.out = np.sort(x, kind='quicksort', axis=self.axis) + self.op_type = "argsort" self.inputs = {'X': x} - self.attrs = {'axis': self.axis} self.outputs = {'Indices': self.indices, 'Out': self.out} def init_axis(self): self.axis = -1 + def init_datatype(self): + self.dtype = "float32" + + def init_direction(self): + self.descending = False + def test_check_output(self): self.check_output() @@ -49,10 +67,84 @@ class TestArgsortOpAxis1(TestArgsortOp): self.axis = 1 +class TestArgsortOpAxis2(TestArgsortOp): + def init_axis(self): + self.axis = 2 + + +class TestArgsortOpAxisNeg1(TestArgsortOp): + def init_axis(self): + self.axis = -1 + + class TestArgsortOpAxisNeg2(TestArgsortOp): def init_axis(self): self.axis = -2 +class TestArgsortOpFP16(TestArgsortOp): + def init_datatype(self): + if core.is_compiled_with_cuda(): + self.dtype = 'float16' + + def test_check_output(self): + pass + + def test_check_output_with_place(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-5) + + +class TestArgsortOpFP16Axis0(TestArgsortOpFP16): + def init_axis(self): + self.axis = 0 + + +class TestArgsortOpFP16Axis2(TestArgsortOpFP16): + def init_axis(self): + self.axis = 2 + + +class TestArgsortOpFP16AxisNeg2(TestArgsortOpFP16): + def init_axis(self): + self.axis = -2 + + +class TestArgsortOpFP16Axis4Neg4(TestArgsortOpFP16): + def init_axis(self): + self.axis = -4 + + +class TestArgsortOpDescendingAxis(TestArgsortOp): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis0(TestArgsortOpAxis0): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis1(TestArgsortOpAxis1): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis2(TestArgsortOpAxis2): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg1(TestArgsortOpAxisNeg1): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg2(TestArgsortOpAxisNeg2): + def init_direction(self): + self.descending = True + + if __name__ == "__main__": unittest.main() -- GitLab