From 0a895bc0dfd0c1075e5ef8fd7b9fd63ceda31b92 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Tue, 25 Aug 2020 12:59:22 +0800 Subject: [PATCH] improve unique op (#26537) * add unique_v2 op * remove unique_v2 op * update doc --- paddle/fluid/operators/unique_op.cc | 113 +++++++-- paddle/fluid/operators/unique_op.h | 239 +++++++++++++++++- paddle/fluid/pybind/op_function_generator.cc | 1 + .../fluid/tests/unittests/test_unique.py | 160 ++++++++++++ python/paddle/tensor/manipulation.py | 121 ++++++++- 5 files changed, 613 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/unique_op.cc b/paddle/fluid/operators/unique_op.cc index c141033b2b3..1aea96a15eb 100644 --- a/paddle/fluid/operators/unique_op.cc +++ b/paddle/fluid/operators/unique_op.cc @@ -24,17 +24,63 @@ class UniqueOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unique"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "unique"); - OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique"); - auto in_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ( - in_dims.size(), 1, - platform::errors::InvalidArgument("The Input(X) should be 1-D Tensor, " - "But now the dims of Input(X) is %d.", - in_dims.size())); + if (!ctx->Attrs().Get("is_sorted")) { + OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique"); + PADDLE_ENFORCE_EQ(in_dims.size(), 1, + platform::errors::InvalidArgument( + "The Input(X) should be 1-D Tensor, " + "But now the dims of Input(X) is %d.", + in_dims.size())); + + ctx->SetOutputDim("Out", {-1}); + ctx->SetOutputDim("Index", in_dims); + return; + } + + bool return_index = ctx->Attrs().Get("return_index"); + bool return_inverse = ctx->Attrs().Get("return_inverse"); + bool return_counts = ctx->Attrs().Get("return_counts"); + auto axis_vec = ctx->Attrs().Get>("axis"); + + if (return_index) { + OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "unique"); + } + if (return_inverse) { + OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique"); + } + if (return_counts) { + OP_INOUT_CHECK(ctx->HasOutput("Counts"), "Output", "Counts", "unique"); + } - ctx->SetOutputDim("Out", {-1}); - ctx->SetOutputDim("Index", in_dims); + if (axis_vec.empty()) { + ctx->SetOutputDim("Out", {-1}); + if (return_inverse) { + ctx->SetOutputDim("Index", {framework::product(in_dims)}); + } + } else { + int axis = axis_vec[0]; + if (axis < 0) { + axis += in_dims.size(); + } + PADDLE_ENFORCE_LT( + axis, in_dims.size(), + platform::errors::InvalidArgument("The axis(%d) should be less than " + "the dimension size(%d) of x.", + axis, in_dims.size())); + auto out_dims = in_dims; + out_dims[axis] = -1; + ctx->SetOutputDim("Out", out_dims); + if (return_inverse) { + ctx->SetOutputDim("Index", {in_dims[axis]}); + } + } + if (return_index) { + ctx->SetOutputDim("Indices", {-1}); + } + if (return_counts) { + ctx->SetOutputDim("Counts", {-1}); + } } protected: @@ -49,14 +95,47 @@ class UniqueOp : public framework::OperatorWithKernel { class UniqueOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "Input tensor. It should be a 1-D tensor."); + AddInput("X", + "Input tensor. It should be a 1-D tensor when Attr(is_sorted)" + " is fasle or a N-D tensor when Attr(is_sorted) is true."); AddAttr("dtype", "data type for output index"); AddOutput("Out", "A unique subsequence for input tensor."); AddOutput("Index", - "An index tensor pointing to unique subsequence, which has " - "identical shape with input tensor and int64 dtype."); + "Equivalent to inverse in numpy.unique, " + "the indices for where elements in the original input ended up " + "in the returned unique tensor."); + AddOutput( + "Indices", + "The indices of the input tensor that result in the unique tensor.") + .AsDispensable(); + AddOutput("Counts", "The counts for each unique element.").AsDispensable(); + AddAttr("return_index", + "If True, also return the indices of the input" + " tensor that result in the unique Tensor.") + .SetDefault(false); + AddAttr( + "return_inverse", + "If True, also return the indices for where elements" + " in the original input ended up in the returned unique tensor.") + .SetDefault(false); + AddAttr("return_counts", + "If True, also return the counts for each unique element.") + .SetDefault(false); + AddAttr>( + "axis", + "The axis to apply unique. If None, the input will be flattened.") + .SetDefault({}); + AddAttr("is_sorted", + "If True, the unique elements of X are in ascending order." + "Otherwise, the unique elements are not sorted.") + .SetDefault(false); AddComment(R"DOC( - Return a unique subsequence for 1-D input tensor, and an index tensor pointing to this unique subsequence + 1. Return a unique subsequence for 1-D input tensor, and an index tensor + pointing to this unique subsequence when Attr(is_sorted) is false. This + means paddle.unique is called. + + 2. Returns the unique elements of X in ascending order when Attr(is_sorted) + is true. This means fluid.layers.unique is called. )DOC"); } }; @@ -65,6 +144,8 @@ class UniqueOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(unique, ops::UniqueOp, ops::UniqueOpMaker); -REGISTER_OP_CPU_KERNEL(unique, ops::UniqueKernel, - ops::UniqueKernel, ops::UniqueKernel, - ops::UniqueKernel); +REGISTER_OP_CPU_KERNEL( + unique, ops::UniqueKernel, + ops::UniqueKernel, + ops::UniqueKernel, + ops::UniqueKernel); diff --git a/paddle/fluid/operators/unique_op.h b/paddle/fluid/operators/unique_op.h index cdfd797cbfd..dc8b2ac5555 100644 --- a/paddle/fluid/operators/unique_op.h +++ b/paddle/fluid/operators/unique_op.h @@ -13,12 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include +#include +#include #include #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/transpose_op.h" namespace paddle { namespace operators { @@ -104,17 +109,243 @@ struct UniqueOpFunctor { } }; +static std::vector Unbind(const framework::Tensor& in) { + int64_t size = in.dims()[0]; + std::vector tensors(size); + for (int64_t i = 0; i < size; ++i) { + tensors[i] = in.Slice(i, i + 1); + } + return tensors; +} + +template +static bool Equal(const framework::Tensor& a, const framework::Tensor& b) { + if (a.numel() != b.numel()) { + return false; + } + for (int64_t i = 0; i < a.numel(); ++i) { + if (a.data()[i] != b.data()[i]) { + return false; + } + } + return true; +} + template +static void UniqueFlattendTensor(const framework::ExecutionContext& context, + const framework::Tensor& in, + framework::Tensor* out, bool return_index, + bool return_inverse, bool return_counts) { + const T* in_data = in.data(); + std::set unique(in_data, in_data + in.numel()); + out->Resize(framework::make_ddim({static_cast(unique.size())})); + auto out_data = out->mutable_data(context.GetPlace()); + std::copy(unique.begin(), unique.end(), out_data); + + if (return_index) { + auto* indices = context.Output("Indices"); + indices->Resize(framework::make_ddim({out->numel()})); + auto indices_data = indices->mutable_data(context.GetPlace()); + std::unordered_map indices_map; + indices_map.reserve(out->numel()); + for (int64_t i = 0; i < in.numel(); ++i) { + if (indices_map.find(in_data[i]) != indices_map.end()) continue; + indices_map[in_data[i]] = i; + } + for (int64_t i = 0; i < out->numel(); ++i) { + indices_data[i] = indices_map[out_data[i]]; + } + } + + if (return_inverse) { + auto* inverse = context.Output("Index"); + inverse->Resize(framework::make_ddim({in.numel()})); + auto inverse_data = inverse->mutable_data(context.GetPlace()); + std::unordered_map inverse_map; + inverse_map.reserve(out->numel()); + for (int64_t i = 0; i < out->numel(); ++i) { + inverse_map[out_data[i]] = i; + } + for (int64_t i = 0; i < in.numel(); ++i) { + inverse_data[i] = inverse_map[in_data[i]]; + } + } + + if (return_counts) { + auto* count = context.Output("Counts"); + count->Resize(framework::make_ddim({out->numel()})); + auto count_data = count->mutable_data(context.GetPlace()); + std::unordered_map counts_map; + counts_map.reserve(out->numel()); + for (int64_t i = 0; i < out->numel(); ++i) { + counts_map[out_data[i]] = 0; + } + for (int64_t i = 0; i < in.numel(); i++) { + counts_map[in_data[i]] += 1; + } + for (int64_t i = 0; i < out->numel(); i++) { + count_data[i] = counts_map[out_data[i]]; + } + } +} + +template +static ForwardIt UniqueDimImpl(const framework::ExecutionContext& context, + ForwardIt first, ForwardIt last, + const std::vector& sorted_indices_vec, + std::vector* inverse_vec, + std::vector* counts_vec, + std::vector* indices_vec) { + if (first == last) { + return last; + } + + (*inverse_vec)[sorted_indices_vec[0]] = 0; + (*counts_vec)[0] = 1; + (*indices_vec)[0] = sorted_indices_vec[0]; + + ForwardIt begin = first; + ForwardIt result = first; + + while (++first != last) { + int64_t idx_first = std::distance(begin, first); + int64_t idx_result = std::distance(begin, result); + if (!Equal(*result, *first)) { + if (++result != first) { + *result = std::move(*first); + } + idx_result += 1; + (*indices_vec)[idx_result] = sorted_indices_vec[idx_first]; + } + (*inverse_vec)[sorted_indices_vec[idx_first]] = idx_result; + (*counts_vec)[idx_result] += 1; + } + return ++result; +} + +template +static void UniqueDim(const framework::ExecutionContext& context, + const framework::Tensor& in, framework::Tensor* out, + bool return_index, bool return_inverse, + bool return_counts, int axis) { + // transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2] + std::vector permute(in.dims().size()); + std::iota(permute.begin(), permute.end(), 0); + permute[axis] = 0; + permute[0] = axis; + std::vector in_trans_dims_vec(framework::vectorize(in.dims())); + in_trans_dims_vec[axis] = in.dims()[0]; + in_trans_dims_vec[0] = in.dims()[axis]; + framework::Tensor in_trans; + framework::DDim in_trans_dims = framework::make_ddim(in_trans_dims_vec); + in_trans.Resize(in_trans_dims); + in_trans.mutable_data(context.GetPlace()); + auto& dev_ctx = context.template device_context(); + TransCompute(in.dims().size(), dev_ctx, in, &in_trans, + permute); + // reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2] + framework::DDim in_trans_flat_dims = + framework::flatten_to_2d(in_trans_dims, 1); + in_trans.Resize(in_trans_flat_dims); + + // sort indices + std::vector sorted_indices_vec(in_trans.dims()[0]); + std::iota(sorted_indices_vec.begin(), sorted_indices_vec.end(), 0); + int64_t col = in_trans.dims()[1]; + const T* in_trans_data = in_trans.data(); + std::sort(sorted_indices_vec.begin(), sorted_indices_vec.end(), + [&](int64_t a, int64_t b) -> bool { + for (int64_t i = 0; i < col; ++i) { + T lhs = in_trans_data[i + a * col]; + T rhs = in_trans_data[i + b * col]; + if (lhs < rhs) { + return true; + } else if (lhs > rhs) { + return false; + } + } + return false; + }); + + // sort tensor according to indices + framework::Tensor input_sorted; + input_sorted.Resize(in_trans_dims); + input_sorted.mutable_data(context.GetPlace()); + T* input_sorted_data = input_sorted.data(); + for (size_t i = 0; i < sorted_indices_vec.size(); ++i) { + memcpy(input_sorted_data + i * col, + in_trans_data + sorted_indices_vec[i] * col, col * sizeof(T)); + } + + std::vector input_unbind = Unbind(input_sorted); + std::vector inverse_vec(sorted_indices_vec.size(), 0); + std::vector counts_vec(sorted_indices_vec.size(), 0); + std::vector indices_vec(sorted_indices_vec.size(), 0); + auto last = UniqueDimImpl::iterator, T>( + context, input_unbind.begin(), input_unbind.end(), sorted_indices_vec, + &inverse_vec, &counts_vec, &indices_vec); + input_unbind.erase(last, input_unbind.end()); + counts_vec.erase(counts_vec.begin() + input_unbind.size(), counts_vec.end()); + indices_vec.erase(indices_vec.begin() + input_unbind.size(), + indices_vec.end()); + + math::ConcatFunctor concat_functor; + framework::Tensor out_trans; + std::vector out_trans_dims_vec = in_trans_dims_vec; + out_trans_dims_vec[0] = input_unbind.size(); + out_trans.Resize(framework::make_ddim(out_trans_dims_vec)); + out_trans.mutable_data(context.GetPlace()); + std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]); + out->Resize(framework::make_ddim(out_trans_dims_vec)); + out->mutable_data(context.GetPlace()); + concat_functor(dev_ctx, input_unbind, 0, &out_trans); + TransCompute(out_trans.dims().size(), dev_ctx, out_trans, + out, permute); + + if (return_inverse) { + auto* inverse = context.Output("Index"); + framework::TensorFromVector(inverse_vec, context.device_context(), inverse); + } + + if (return_counts) { + auto* count = context.Output("Counts"); + framework::TensorFromVector(counts_vec, context.device_context(), count); + } + + if (return_index) { + auto* indices = context.Output("Indices"); + framework::TensorFromVector(indices_vec, context.device_context(), indices); + } +} + +template class UniqueKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto data_type = static_cast( - context.Attr("dtype")); auto* x = context.Input("X"); auto* out = context.Output("Out"); - auto* index = context.Output("Index"); + if (!context.Attr("is_sorted")) { + auto data_type = static_cast( + context.Attr("dtype")); + auto* index = context.Output("Index"); + + framework::VisitDataType(data_type, UniqueOpFunctor(out, index, x)); + return; + } - framework::VisitDataType(data_type, UniqueOpFunctor(out, index, x)); + std::vector axis_vec = context.Attr>("axis"); + bool return_index = context.Attr("return_index"); + bool return_inverse = context.Attr("return_inverse"); + bool return_counts = context.Attr("return_counts"); + + if (axis_vec.empty()) { + UniqueFlattendTensor(context, *x, out, return_index, return_inverse, + return_counts); + } else { + int axis = axis_vec[0]; + UniqueDim(context, *x, out, return_index, + return_inverse, return_counts, axis); + } } }; diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index ec458ee7957..d7126b95865 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -62,6 +62,7 @@ std::map> op_outs_map = { {"sync_batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, + {"unique", {"Out", "Index", "Indices", "Counts"}}, }; // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are diff --git a/python/paddle/fluid/tests/unittests/test_unique.py b/python/paddle/fluid/tests/unittests/test_unique.py index 65194524adf..ae36f8a9861 100644 --- a/python/paddle/fluid/tests/unittests/test_unique.py +++ b/python/paddle/fluid/tests/unittests/test_unique.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid.op import Operator @@ -125,5 +126,164 @@ class TestRandomGPU(TestUniqueOp): self.check_output_with_place(place, atol=1e-5) +class TestSortedUniqueOp(TestUniqueOp): + def init_config(self): + self.inputs = {'X': np.array([2, 3, 3, 1, 5, 3], dtype='int64')} + unique, indices, inverse, count = np.unique( + self.inputs['X'], + return_index=True, + return_inverse=True, + return_counts=True, + axis=None) + self.attrs = { + 'dtype': int(core.VarDesc.VarType.INT32), + "return_index": True, + "return_inverse": True, + "return_counts": True, + "axis": None, + "is_sorted": True + } + self.outputs = { + 'Out': unique, + 'Indices': indices, + "Index": inverse, + "Counts": count, + } + + +class TestUniqueOpAxisNone(TestUniqueOp): + def init_config(self): + self.inputs = {'X': np.random.random((4, 7, 10)).astype('float64')} + unique, indices, inverse, counts = np.unique( + self.inputs['X'], + return_index=True, + return_inverse=True, + return_counts=True, + axis=None) + self.attrs = { + 'dtype': int(core.VarDesc.VarType.INT32), + "return_index": True, + "return_inverse": True, + "return_counts": True, + "axis": None, + "is_sorted": True + } + self.outputs = { + 'Out': unique, + 'Indices': indices, + "Index": inverse, + "Counts": counts, + } + + +class TestUniqueOpAxis1(TestUniqueOp): + def init_config(self): + self.inputs = {'X': np.random.random((3, 8, 8)).astype('float64')} + unique, indices, inverse, counts = np.unique( + self.inputs['X'], + return_index=True, + return_inverse=True, + return_counts=True, + axis=1) + self.attrs = { + 'dtype': int(core.VarDesc.VarType.INT32), + "return_index": True, + "return_inverse": True, + "return_counts": True, + "axis": [1], + "is_sorted": True + } + self.outputs = { + 'Out': unique, + 'Indices': indices, + "Index": inverse, + "Counts": counts, + } + + +class TestUniqueAPI(unittest.TestCase): + def test_dygraph_api_out(self): + paddle.disable_static() + x_data = x_data = np.random.randint(0, 10, (120)) + x = paddle.to_tensor(x_data) + out = paddle.unique(x) + expected_out = np.unique(x_data) + self.assertTrue((out.numpy() == expected_out).all(), True) + paddle.enable_static() + + def test_dygraph_api_attr(self): + paddle.disable_static() + x_data = np.random.random((3, 5, 5)).astype("float32") + x = paddle.to_tensor(x_data) + out, index, inverse, counts = paddle.unique( + x, + return_index=True, + return_inverse=True, + return_counts=True, + axis=0) + np_out, np_index, np_inverse, np_counts = np.unique( + x_data, + return_index=True, + return_inverse=True, + return_counts=True, + axis=0) + self.assertTrue((out.numpy() == np_out).all(), True) + self.assertTrue((index.numpy() == np_index).all(), True) + self.assertTrue((inverse.numpy() == np_inverse).all(), True) + self.assertTrue((counts.numpy() == np_counts).all(), True) + paddle.enable_static() + + def test_static_graph(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + x = paddle.data(name='x', shape=[3, 2], dtype='float64') + unique, inverse, counts = paddle.unique( + x, return_inverse=True, return_counts=True, axis=0) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + x_np = np.array([[1, 2], [3, 4], [1, 2]]).astype('float64') + result = exe.run(feed={"x": x_np}, + fetch_list=[unique, inverse, counts]) + np_unique, np_inverse, np_counts = np.unique( + x_np, return_inverse=True, return_counts=True, axis=0) + self.assertTrue(np.allclose(result[0], np_unique)) + self.assertTrue(np.allclose(result[1], np_inverse)) + self.assertTrue(np.allclose(result[2], np_counts)) + + +class TestUniqueError(unittest.TestCase): + def test_input_dtype(self): + def test_x_dtype(): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + x = paddle.data(name='x', shape=[10, 10], dtype='float16') + result = paddle.unique(x) + + self.assertRaises(TypeError, test_x_dtype) + + def test_attr(self): + x = paddle.data(name='x', shape=[10, 10], dtype='float64') + + def test_return_index(): + result = paddle.unique(x, return_index=0) + + self.assertRaises(TypeError, test_return_index) + + def test_return_inverse(): + result = paddle.unique(x, return_inverse='s') + + self.assertRaises(TypeError, test_return_inverse) + + def test_return_counts(): + result = paddle.unique(x, return_counts=3) + + self.assertRaises(TypeError, test_return_counts) + + def test_axis(): + result = paddle.unique(x, axis='12') + + self.assertRaises(TypeError, test_axis) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 65469759a38..44ec0a5a4df 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -27,7 +27,6 @@ from ..fluid.layers import expand_as #DEFINE_ALIAS from ..fluid.layers import slice #DEFINE_ALIAS from ..fluid.layers import strided_slice #DEFINE_ALIAS from ..fluid.layers import transpose #DEFINE_ALIAS -from ..fluid.layers import unique #DEFINE_ALIAS from ..fluid.layers import unstack #DEFINE_ALIAS from ..fluid.layers import scatter_nd_add #DEFINE_ALIAS @@ -608,6 +607,126 @@ def squeeze(x, axis=None, name=None): return layers.squeeze(x, axis, name) +def unique(x, + return_index=False, + return_inverse=False, + return_counts=False, + axis=None, + name=None): + """ + Returns the unique elements of `x` in ascending order. + + Args: + x(Tensor): The input tensor, it's data type should be float32, float64, int32, int64. + return_index(bool, optional): If True, also return the indices of the input tensor that + result in the unique Tensor. + return_inverse(bool, optional): If True, also return the indices for where elements in + the original input ended up in the returned unique tensor. + return_counts(bool, optional): If True, also return the counts for each unique element. + axis(int, optional): The axis to apply unique. If None, the input will be flattened. + Default: None. + name(str, optional): Name for the operation. For more information, please refer to + :ref:`api_guide_Name`. Default: None. + + Returns: + tuple: (out, indices, inverse, counts). `out` is the unique tensor for `x`. `indices` is \ + provided only if `return_index` is True. `inverse` is provided only if `return_inverse` \ + is True. `counts` is provided only if `return_counts` is True. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + paddle.disable_static() + x_data = np.array([2, 3, 3, 1, 5, 3]) + x = paddle.to_tensor(x_data) + unique = paddle.unique(x) + np_unique = unique.numpy() # [1 2 3 5] + _, indices, inverse, counts = paddle.unique(x, return_index=True, return_inverse=True, return_counts=True) + np_indices = indices.numpy() # [3 0 1 4] + np_inverse = inverse.numpy() # [1 2 2 0 3 2] + np_counts = counts.numpy() # [1 1 3 1] + + x_data = np.array([[2, 1, 3], [3, 0, 1], [2, 1, 3]]) + unique = paddle.unique(x) + np_unique = unique.numpy() # [0 1 2 3] + + unique = paddle.unique(x, axis=0) + np_unique = unique.numpy() + # [[2 1 3] + # [3 0 1]] + """ + if axis is None: + axis = [] + else: + axis = [axis] + + if in_dygraph_mode(): + out, inverse, indices, counts = core.ops.unique( + x, 'dtype', + convert_np_dtype_to_dtype_('int32'), 'return_index', return_index, + 'return_inverse', return_inverse, 'return_counts', return_counts, + 'axis', axis, "is_sorted", True) + outs = [out] + if return_index: + outs.append(indices) + if return_inverse: + outs.append(inverse) + if return_counts: + outs.append(counts) + + if len(outs) == 1: + return outs[0] + + return tuple(outs) + + check_variable_and_dtype(x, "input", + ['float32', 'float64', 'int32', 'int64'], 'unique') + check_type(return_index, 'return_index', bool, 'unique') + check_type(return_inverse, 'return_inverse', bool, 'unique') + check_type(return_counts, 'return_counts', bool, 'unique') + if len(axis) != 0: + check_type(axis[0], 'axis', int, 'unique') + + helper = LayerHelper('unique', **locals()) + attrs = { + 'dtype': int(core.VarDesc.VarType.INT32), + "return_index": return_index, + "return_inverse": return_inverse, + "return_counts": return_counts, + "axis": axis, + "is_sorted": True + } + out = helper.create_variable_for_type_inference( + dtype=x.dtype, stop_gradient=True) + inverse = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.INT64, stop_gradient=True) + outputs = {"Out": out, "Index": inverse} + outs = [out] + if return_index: + indices = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.INT64, stop_gradient=True) + outputs["Indices"] = indices + outs.append(indices) + if return_inverse: + outs.append(inverse) + if return_counts: + counts = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.INT64, stop_gradient=True) + outputs["Counts"] = counts + outs.append(counts) + + helper.append_op( + type="unique", inputs={"X": x}, attrs=attrs, outputs=outputs) + + if len(outs) == 1: + return outs[0] + + return tuple(outs) + + def unsqueeze(x, axis, name=None): """ :alias_main: paddle.unsqueeze -- GitLab