diff --git a/doc/fluid/api_cn/layers_cn.rst b/doc/fluid/api_cn/layers_cn.rst index ebf0cf2e384558eadb2c9802d64b541a48391cd3..f4b8e8b680aadb99118e51e8f12b9040cce78971 100644 --- a/doc/fluid/api_cn/layers_cn.rst +++ b/doc/fluid/api_cn/layers_cn.rst @@ -169,6 +169,7 @@ fluid.layers layers_cn/lstm_unit_cn.rst layers_cn/LSTMCell_cn.rst layers_cn/margin_rank_loss_cn.rst + layers_cn/masked_select_cn.rst layers_cn/matmul_cn.rst layers_cn/maxout_cn.rst layers_cn/mean_cn.rst diff --git a/doc/fluid/api_cn/layers_cn/masked_select_cn.rst b/doc/fluid/api_cn/layers_cn/masked_select_cn.rst new file mode 100644 index 0000000000000000000000000000000000000000..52462854377085ebb39bb5f0c88fedb2ccbc6884 --- /dev/null +++ b/doc/fluid/api_cn/layers_cn/masked_select_cn.rst @@ -0,0 +1,57 @@ +.. _cn_api_fluid_layers_masked_select: + +masked_select +------------------------------- + +.. py:function:: paddle.fluid.layers.masked_select(input, mask) + +该OP将根据mask Tensor的真值选取输入Tensor元素,并返回一个一维Tensor + +参数: + - **input** (Variable)- 输入Tensor,数据类型为int32, float32, float64。 + - **mask** (Variable)- mask Tensor, 数据类型为bool。 + + +返回:根据mask选择后的tensor + +返回类型: Variable + + +**示例代码** + +.. code-block:: python + + import paddle.fluid as fluid + import numpy as np + mask_shape = [4,1] + shape = [4,4] + data = np.random.random(mask_shape).astype("float32") + input_data = np.random.randint(5,size=shape).astype("float32") + mask_data = data > 0.5 + + # print(input_data) + # [[0.38972723 0.36218056 0.7892614 0.50122297] + # [0.14408113 0.85540855 0.30984417 0.7577004 ] + # [0.97263193 0.5248062 0.07655851 0.75549215] + # [0.26214206 0.32359877 0.6314582 0.2128865 ]] + + # print(mask_data) + # [[ True] + # [ True] + # [False] + # [ True]] + + input = fluid.data(name="input",shape=[4,4],dtype="float32") + mask = fluid.data(name="mask",shape=[4,1],dtype="bool") + result = fluid.layers.masked_select(input=input, mask=mask) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + start = fluid.default_startup_program() + main = fluid.default_main_program() + exe.run(start) + masked_select_result= exe.run(main, feed={'input':input_data, 'mask':mask_data}, fetch_list=[result]) + # print(masked_select_result) + # [0.38972723 0.36218056 0.7892614 0.50122297 0.14408113 0.85540855 + # 0.30984417 0.7577004 0.26214206 0.32359877 0.6314582 0.2128865 ] + +