where_cn.rst 1.3 KB
Newer Older
H
Hao Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
.. _cn_api_fluid_layers_where:

where
-------------------------------

.. py:function:: paddle.fluid.layers.where(condition)
     
返回一个秩为2的int64型张量,指定condition中真实元素的坐标。
     
输出的第一维是真实元素的数量,第二维是condition的秩(维数)。如果没有真实元素,则将生成空张量。
        
参数:
    - **condition** (Variable) - 秩至少为1的布尔型张量。

返回:存储一个二维张量的张量变量

返回类型:变量(Variable)
     
**代码示例**:

.. code-block:: python

        import paddle.fluid as fluid
        import paddle.fluid.layers as layers
        import numpy as np
Z
zq19 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39
        # tensor 为 [True, False, True]
        condition = layers.assign(np.array([1, 0, 1], dtype='int32'))
        condition = layers.cast(condition, 'bool')
        out = layers.where(condition) # [[0], [2]]

        # tensor 为 [[True, False], [False, True]]
        condition = layers.assign(np.array([[1, 0], [0, 1]], dtype='int32'))
        condition = layers.cast(condition, 'bool')
        out = layers.where(condition) # [[0, 0], [1, 1]]

        # tensor 为 [False, False, False]
        condition = layers.assign(np.array([0, 0, 0], dtype='int32'))
        condition = layers.cast(condition, 'bool')
        out = layers.where(condition) # [[]]
H
Hao Wang 已提交
40 41 42 43