提交 713d0eb1 编写于 作者: Z zhupengyang 提交者: hong19860320

fix argsort doc (#1266)

* fix argsort doc
test=document_preview

* fix
test=document_preview

* test=document_preview

* fix
test=document_preview

* fix
test=document_preview
上级 dfc3ab06
......@@ -5,48 +5,61 @@ argsort
.. py:function:: paddle.fluid.layers.argsort(input,axis=-1,name=None)
对输入变量沿给定轴进行排序,输出排序好的数据和相应的索引,其维度和输入相同
对输入变量沿给定轴进行 **升序** 排列,输出排序好的数据和相应的索引,其维度和输入相同。**暂不支持降序排列**。
.. code-block:: text
例如:
给定 input 并指定 axis=-1
input = [[0.15849551, 0.45865775, 0.8563702 ],
[0.12070083, 0.28766365, 0.18776911]],
执行argsort操作后,得到排序数据:
out = [[0.15849551, 0.45865775, 0.8563702 ],
[0.12070083, 0.18776911, 0.28766365]],
根据指定axis排序后的数据indices变为:
indices = [[0, 1, 2],
[0, 2, 1]]
参数:
- **input** (Variable)-用于排序的输入变量
- **axis** (int)- 沿该参数指定的轴对输入进行排序。当axis<0,实际的轴为axis+rank(input)。默认为-1,即最后一维
- **name** (str|None)-(可选)该层名称。如果设为空,则自动为该层命名
- **input** (Variable) - 输入的多维 ``Tensor`` ,支持的数据类型:float32、float64。
- **axis** (int,可选) - 指定对输入Tensor进行运算的轴, ``axis`` 的有效范围是[-R, R),R是输入 ``x`` 的Rank, ``axis`` 为负时与 ``axis`` +R 等价。默认值为0
- **name** (str,可选) - 该参数供开发人员打印调试信息时使用,具体用法请参见 :ref:`api_guide_Name`,默认值为None
返回:一组已排序的数据变量和索引
返回:一组已排序的输出(与 ``input`` 维度相同、数据类型相同)和索引(数据类型为int64),
返回类型:元组
返回类型:tuple[Variable]
**代码示例**:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name="x", shape=[3, 4], dtype="float32")
out, indices = fluid.layers.argsort(input=x, axis=0)
import paddle.fluid as fluid
import numpy as np
in1 = np.array([[[5,8,9,5],
[0,0,1,7],
[6,9,2,4]],
[[5,2,4,2],
[4,7,7,9],
[1,7,0,6]]]).astype(np.float32)
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(in1)
out1 = fluid.layers.argsort(input=x, axis=-1) # same as axis==2
out2 = fluid.layers.argsort(input=x, axis=0)
out3 = fluid.layers.argsort(input=x, axis=1)
print(out1[0].numpy())
# [[[5. 5. 8. 9.]
# [0. 0. 1. 7.]
# [2. 4. 6. 9.]]
# [[2. 2. 4. 5.]
# [4. 7. 7. 9.]
# [0. 1. 6. 7.]]]
print(out1[1].numpy())
# [[[0 3 1 2]
# [0 1 2 3]
# [2 3 0 1]]
# [[1 3 2 0]
# [0 1 2 3]
# [2 0 3 1]]]
print(out2[0].numpy())
# [[[5. 2. 4. 2.]
# [0. 0. 1. 7.]
# [1. 7. 0. 4.]]
# [[5. 8. 9. 5.]
# [4. 7. 7. 9.]
# [6. 9. 2. 6.]]]
print(out3[0].numpy())
# [[[0. 0. 1. 4.]
# [5. 8. 2. 5.]
# [6. 9. 9. 7.]]
# [[1. 2. 0. 2.]
# [4. 7. 4. 6.]
# [5. 7. 7. 9.]]]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册