From 096afbe1f5b473ed994a44d8d08715f0f44ad2f5 Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Mon, 17 Jan 2022 10:42:09 +0800 Subject: [PATCH] fix paddle.where torch diff (#38870) * fix paddle.where torch diff * update --- .../fluid/tests/unittests/test_where_op.py | 38 +++++++++++++++++++ python/paddle/tensor/search.py | 29 ++++++++++---- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index 908b2577a8..5b92fcf52d 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -305,6 +305,36 @@ class TestWhereDygraphAPI(unittest.TestCase): b_shape = [2, 2, 1] self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + def test_where_condition(self): + data = np.array([[True, False], [False, True]]) + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[-1, 2]) + y = paddle.where(x) + self.assertEqual(type(y), tuple) + self.assertEqual(len(y), 2) + z = fluid.layers.concat(list(y), axis=1) + exe = fluid.Executor(fluid.CPUPlace()) + + res, = exe.run(feed={'x': data}, + fetch_list=[z.name], + return_numpy=False) + expect_out = np.array([[0, 0], [1, 1]]) + self.assertTrue(np.allclose(expect_out, np.array(res))) + + data = np.array([True, True, False]) + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[-1]) + y = paddle.where(x) + self.assertEqual(type(y), tuple) + self.assertEqual(len(y), 1) + z = fluid.layers.concat(list(y), axis=1) + exe = fluid.Executor(fluid.CPUPlace()) + res, = exe.run(feed={'x': data}, + fetch_list=[z.name], + return_numpy=False) + expect_out = np.array([[0], [1]]) + self.assertTrue(np.allclose(expect_out, np.array(res))) + class TestWhereOpError(unittest.TestCase): def test_errors(self): @@ -326,6 +356,14 @@ class TestWhereOpError(unittest.TestCase): self.assertRaises(TypeError, test_type) + def test_value_error(self): + with fluid.dygraph.guard(): + cond_shape = [2, 2, 4] + cond_tmp = paddle.rand(cond_shape) + cond = cond_tmp < 0.3 + a = paddle.rand(cond_shape) + self.assertRaises(ValueError, paddle.where, cond, a) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 0685e27645..e15d2d49d5 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -523,23 +523,26 @@ def mode(x, axis=-1, keepdim=False, name=None): return values, indices -def where(condition, x, y, name=None): +def where(condition, x=None, y=None, name=None): r""" Return a tensor of elements selected from either $x$ or $y$, depending on $condition$. + **Note**: + ``paddle.where(condition)`` is identical to ``paddle.nonzero(condition, as_tuple=True)``. + .. math:: out_i = - \\begin{cases} - x_i, \quad \\text{if} \\ condition_i \\ is \\ True \\\\ - y_i, \quad \\text{if} \\ condition_i \\ is \\ False \\\\ - \\end{cases} + \begin{cases} + x_i, \quad \text{if} \ condition_i \ is \ True \\ + y_i, \quad \text{if} \ condition_i \ is \ False \\ + \end{cases} Args: condition(Tensor): The condition to choose x or y. - x(Tensor): x is a Tensor with data type float32, float64, int32, int64. - y(Tensor): y is a Tensor with data type float32, float64, int32, int64. + x(Tensor, optional): x is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given. + y(Tensor, optional): y is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please @@ -559,7 +562,19 @@ def where(condition, x, y, name=None): print(out) #out: [1.0, 1.0, 3.2, 1.2] + + out = paddle.where(x>1) + print(out) + #out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True, + # [[2], + # [3]]),) """ + if x is None and y is None: + return nonzero(condition, as_tuple=True) + + if x is None or y is None: + raise ValueError("either both or neither of x and y should be given") + if not in_dygraph_mode(): check_variable_and_dtype(condition, 'condition', ['bool'], 'where') check_variable_and_dtype( -- GitLab