提交 a523b6f4 编写于 作者: Y Yibing Liu

Add python api for argsort_op

上级 7ca511e0
......@@ -1105,6 +1105,12 @@ argmax
.. autofunction:: paddle.fluid.layers.argmax
:noindex:
argsort
------
.. autofunction:: paddle.fluid.layers.argsort
:noindex:
ones
----
......
......@@ -34,13 +34,13 @@ class ArgsortOp : public framework::OperatorWithKernel {
auto num_dims = in_dims.size();
PADDLE_ENFORCE(axis < num_dims,
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
"dimension %d.",
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X)'s "
"rank %d.",
axis, num_dims);
PADDLE_ENFORCE(axis >= -num_dims,
"Attr(axis) %d of ArgsortOp must be not less than "
"-rank(Input(X)) (%d).",
axis, num_dims);
PADDLE_ENFORCE(in_dims.size() + axis >= 0,
"Attr(axis) %d of ArgsortOp plus the rank %d of Input(X) "
"must be nonnegative.",
axis, in_dims.size());
ctx->SetOutputDim("Out", in_dims);
ctx->SetOutputDim("Indices", in_dims);
......
......@@ -33,6 +33,7 @@ __all__ = [
'fill_constant',
'argmin',
'argmax',
'argsort',
'ones',
'zeros',
'reverse',
......@@ -438,6 +439,56 @@ def argmax(x, axis=0):
return out
def argsort(input, axis=-1):
"""
Performs sorting on the input Variable along the given axis, and outputs
sorted data Varibale and its corresponding index Variable with the same
shape as :attr:`input`.
.. code-block:: text
For example, the given axis is -1 and the input Variable
input = [[0.15849551, 0.45865775, 0.8563702 ],
[0.12070083, 0.28766365, 0.18776911]],
after argsort, the sorted Vairable becomes
out = [[0.15849551, 0.45865775, 0.8563702 ],
[0.12070083, 0.18776911, 0.28766365]],
and the sorted indices along the given axis turn outs to be
indices = [[0, 1, 2],
[0, 2, 1]]
Args:
input(Variable): The input Variable for sorting.
axis(int): The axis along which to sort the input Variable. When
:attr:`axis` < 0, the actual axis will be :attr:`axis` +
rank(:attr:`input`). Default -1, the last dimension.
Returns:
tuple: A tuple of sorted data Variable and the sorted indices.
Examples:
.. code-block:: python
input = fluid.layers.data(data=[2, 3])
out, indices = fluid.layers.argsort(input, axis=0)
"""
helper = LayerHelper("argsort", **locals())
out = helper.create_tmp_variable(dtype=input.dtype, stop_gradient=True)
ids = helper.create_tmp_variable(VarDesc.VarType.INT64, stop_gradient=True)
helper.append_op(
type='argsort',
inputs={'X': input},
outputs={'Out': out,
'Indics': ids},
attts={'axis': axis})
return out, ids
def ones(shape, dtype, force_cpu=False):
"""
**ones**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册