diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 62b0afe8f160779f845babf6fd1a6cca69420afc..a9ec30aa6232199885693a585141aa668ed8617e 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -28,7 +28,7 @@ from ..framework import Variable, OpProtoHolder, in_dygraph_mode
 from ..dygraph import base
 from ..param_attr import ParamAttr
 from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
-from .tensor import concat, assign, fill_constant, zeros
+from .tensor import concat, assign, fill_constant, zeros, cast
 from . import utils
 from .. import unique_name
 from functools import reduce
@@ -183,6 +183,7 @@ __all__ = [
     'hard_swish',
     'gather_tree',
     'uniform_random',
+    'masked_select',
 ]
 
 
@@ -13581,3 +13582,63 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
         outputs={"Out": out})
 
     return helper.append_activation(out)
+
+
+def masked_select(input, mask):
+    """
+    This OP selects elements of the input tensor according to the mask tensor.
+    The shape of the mask tensor does not have to match the shape of the input tensor, but it must be broadcastable to it, and the result is a new 1-D tensor.
+
+    NOTE: The meaning of broadcastable here is consistent with expand_as.
+
+    Parameters:
+
+        input(Variable): The input tensor, the data type should be int32, float32 or float64.
+        mask(Variable): The boolean mask tensor, the data type should be bool.
+
+    Returns:
+        Variable: The selected elements as a 1-D tensor, with the same data type as the input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import numpy as np
+            mask_shape = [4, 1]
+            shape = [4, 4]
+            data = np.random.random(mask_shape).astype("float32")
+            input_data = np.random.random(shape).astype("float32")
+            mask_data = data > 0.5
+
+            # print(input_data)
+            # [[0.38972723 0.36218056 0.7892614  0.50122297]
+            #  [0.14408113 0.85540855 0.30984417 0.7577004 ]
+            #  [0.97263193 0.5248062  0.07655851 0.75549215]
+            #  [0.26214206 0.32359877 0.6314582  0.2128865 ]]
+
+            # print(mask_data)
+            # [[ True]
+            #  [ True]
+            #  [False]
+            #  [ True]]
+
+            input = fluid.data(name="input", shape=[4, 4], dtype="float32")
+            mask = fluid.data(name="mask", shape=[4, 1], dtype="bool")
+            result = fluid.layers.masked_select(input=input, mask=mask)
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            start = fluid.default_startup_program()
+            main = fluid.default_main_program()
+            exe.run(start)
+            masked_select_result = exe.run(main, feed={'input': input_data, 'mask': mask_data}, fetch_list=[result])
+            # print(masked_select_result)
+            # [0.38972723 0.36218056 0.7892614  0.50122297 0.14408113 0.85540855
+            #  0.30984417 0.7577004  0.26214206 0.32359877 0.6314582  0.2128865 ]
+
+    """
+    # Cast the bool mask so it can be broadcast, expand it to the input's
+    # shape, then gather the elements at the positions where the mask is True.
+    mask_cast = cast(x=mask, dtype=input.dtype)
+    mask_expand = expand_as(x=mask_cast, target_tensor=input)
+    mask_expand_cast_back_bool = cast(x=mask_expand, dtype="bool")
+    select = where(mask_expand_cast_back_bool)
+    result = gather_nd(input, select)
+    return result
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 6db2267fca9a26f8ba9686e17a1dc9f517e9fb4b..718f3f6efa56d521d0de078a40165f9d4cdbbb4e 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -1640,6 +1640,9 @@ class TestBook(LayerTest):
         elif dtype == 'int64':
             return np.random.randint(self._low_data_bound,
                                      self._high_data_bound, shape).astype(dtype)
+        elif dtype == 'bool':
+            return np.random.randint(self._low_data_bound,
+                                     self._high_data_bound, shape).astype(dtype)
 
     def _get_data(self,
                   name,
@@ -2557,6 +2560,14 @@ class TestBook(LayerTest):
         out = layers.square_error_cost(input=x, label=y)
         return (out)
 
+    def make_masked_select(self):
+        with program_guard(fluid.default_main_program(),
+                           fluid.default_startup_program()):
+            x = self._get_data(name="X", shape=[4, 4], dtype="float32")
+            y = self._get_data(name="Y", shape=[1, 4], dtype="bool")
+            out = layers.masked_select(input=x, mask=y)
+            return (out)
+
     def test_dynamic_lstmp(self):
         # TODO(minqiyang): dygraph do not support lod now
         with self.static_graph():
diff --git a/python/paddle/fluid/tests/unittests/test_masked_select.py b/python/paddle/fluid/tests/unittests/test_masked_select.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2e087fdd9a0c430146ff340ec42344bccfae75b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_masked_select.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+from paddle.fluid.executor import Executor
+
+
+class TestMaskedSelect(unittest.TestCase):
+    def test_masked_select(self):
+
+        mask_shape = [4, 1]
+        shape = [4, 4]
+        data = np.random.random(mask_shape).astype("float32")
+        input_data = np.random.random(shape).astype("float32")
+        mask_data = data > 0.5
+        mask_data_b = np.broadcast_to(mask_data, shape)
+        npresult = input_data[np.where(mask_data_b)]
+
+        input_var = layers.create_tensor(dtype="float32", name="input")
+        mask_var = layers.create_tensor(dtype="bool", name="mask")
+
+        output = layers.masked_select(input=input_var, mask=mask_var)
+        for use_cuda in ([False, True]
+                         if core.is_compiled_with_cuda() else [False]):
+            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+            exe = Executor(place)
+            result = exe.run(fluid.default_main_program(),
+                             feed={"input": input_data,
+                                   "mask": mask_data},
+                             fetch_list=[output])
+
+            self.assertTrue(np.isclose(npresult, result).all())
+
+
+if __name__ == "__main__":
+    unittest.main()
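
# Reviewer note (not part of the patch): a minimal NumPy sketch of the semantics the new
# layer composes out of cast, expand_as, where and gather_nd -- broadcast the mask to the
# input's shape, take the indices of the True entries, then gather them into a 1-D result.
# The helper name np_masked_select is hypothetical and used only for illustration.

import numpy as np


def np_masked_select(input_data, mask_data):
    # Broadcast the boolean mask to the input's shape (the "expand_as" step).
    mask_b = np.broadcast_to(mask_data, input_data.shape)
    # Indices of the True entries (the "where" step), then gather those
    # elements from the input (the "gather_nd" step); the result is 1-D.
    return input_data[np.where(mask_b)]


if __name__ == "__main__":
    x = np.random.random([4, 4]).astype("float32")
    m = np.random.random([4, 1]) > 0.5
    print(np_masked_select(x, m))  # 1-D array of the selected elements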