From a523b6f49f4048d8c32c4d6c53dc22fdcebfe2b0 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 18 Jun 2018 02:30:24 -0700 Subject: [PATCH] Add python api for argsort_op --- doc/fluid/api/layers.rst | 6 ++++ paddle/fluid/operators/argsort_op.cc | 12 +++---- python/paddle/fluid/layers/tensor.py | 51 ++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/doc/fluid/api/layers.rst b/doc/fluid/api/layers.rst index 1f8f6360404..4157faae4c2 100644 --- a/doc/fluid/api/layers.rst +++ b/doc/fluid/api/layers.rst @@ -1105,6 +1105,12 @@ argmax .. autofunction:: paddle.fluid.layers.argmax :noindex: +argsort +------ + +.. autofunction:: paddle.fluid.layers.argsort + :noindex: + ones ---- diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc index ca9a884b983..a2f5a254570 100644 --- a/paddle/fluid/operators/argsort_op.cc +++ b/paddle/fluid/operators/argsort_op.cc @@ -34,13 +34,13 @@ class ArgsortOp : public framework::OperatorWithKernel { auto num_dims = in_dims.size(); PADDLE_ENFORCE(axis < num_dims, - "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) " - "dimension %d.", + "Attr(axis) %d of ArgsortOp is out of bounds for Input(X)'s " + "rank %d.", + axis, num_dims); + PADDLE_ENFORCE(axis >= -num_dims, + "Attr(axis) %d of ArgsortOp must be not less than " + "-rank(Input(X)) (%d).", axis, num_dims); - PADDLE_ENFORCE(in_dims.size() + axis >= 0, - "Attr(axis) %d of ArgsortOp plus the rank %d of Input(X) " - "must be nonnegative.", - axis, in_dims.size()); ctx->SetOutputDim("Out", in_dims); ctx->SetOutputDim("Indices", in_dims); diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 149e77b5241..656bd5bb1d7 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -33,6 +33,7 @@ __all__ = [ 'fill_constant', 'argmin', 'argmax', + 'argsort', 'ones', 'zeros', 'reverse', @@ -438,6 +439,56 @@ def argmax(x, axis=0): return out +def argsort(input, axis=-1): + """ + Performs sorting on the input Variable along the given axis, and outputs + sorted data Varibale and its corresponding index Variable with the same + shape as :attr:`input`. + + .. code-block:: text + + For example, the given axis is -1 and the input Variable + + input = [[0.15849551, 0.45865775, 0.8563702 ], + [0.12070083, 0.28766365, 0.18776911]], + + after argsort, the sorted Vairable becomes + + out = [[0.15849551, 0.45865775, 0.8563702 ], + [0.12070083, 0.18776911, 0.28766365]], + + and the sorted indices along the given axis turn outs to be + + indices = [[0, 1, 2], + [0, 2, 1]] + + Args: + input(Variable): The input Variable for sorting. + axis(int): The axis along which to sort the input Variable. When + :attr:`axis` < 0, the actual axis will be :attr:`axis` + + rank(:attr:`input`). Default -1, the last dimension. + + Returns: + tuple: A tuple of sorted data Variable and the sorted indices. + + Examples: + .. code-block:: python + + input = fluid.layers.data(data=[2, 3]) + out, indices = fluid.layers.argsort(input, axis=0) + """ + helper = LayerHelper("argsort", **locals()) + out = helper.create_tmp_variable(dtype=input.dtype, stop_gradient=True) + ids = helper.create_tmp_variable(VarDesc.VarType.INT64, stop_gradient=True) + helper.append_op( + type='argsort', + inputs={'X': input}, + outputs={'Out': out, + 'Indics': ids}, + attts={'axis': axis}) + return out, ids + + def ones(shape, dtype, force_cpu=False): """ **ones** -- GitLab