diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index 5eaf140461bce8442c0f11b4e6ae007f7907549a..908b2577a826b0f1a06f75827b54d10b4129b3ba 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -140,6 +140,92 @@ class TestWhereAPI(unittest.TestCase): fetch_list=[result]) assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i)) + def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): + paddle.enable_static() + + main_program = Program() + with fluid.program_guard(main_program): + cond = fluid.layers.data( + name='cond', shape=cond_shape, dtype='bool') + x = fluid.layers.data(name='x', shape=x_shape, dtype='float32') + y = fluid.layers.data(name='y', shape=y_shape, dtype='float32') + + cond_data_tmp = np.random.random(size=cond_shape).astype("float32") + cond_data = cond_data_tmp < 0.3 + x_data = np.random.random(size=x_shape).astype("float32") + y_data = np.random.random(size=y_shape).astype("float32") + + result = paddle.where(condition=cond, x=x, y=y) + + for use_cuda in [False, True]: + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + + exe = fluid.Executor(place) + out = exe.run( + fluid.default_main_program(), + feed={'cond': cond_data, + 'x': x_data, + 'y': y_data}, + fetch_list=[result]) + + expect = np.where(cond_data, x_data, y_data) + + assert np.array_equal(out[0], expect) + + def test_static_api_broadcast_1(self): + cond_shape = [2, 4] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + + def test_static_api_broadcast_2(self): + cond_shape = [2, 1] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + + def test_static_api_broadcast_3(self): + cond_shape = [2, 2, 1] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + + def test_static_api_broadcast_4(self): + cond_shape = [2, 1, 4] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + + # @Note Now, maybe not compatibility with old version + def test_static_api_broadcast_5(self): + cond_shape = [3, 2, 2, 4] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + + # @Note Now, maybe not compatibility with old version + def test_static_api_broadcast_6(self): + cond_shape = [2, 2, 4] + a_shape = [2, 2, 1] + b_shape = [2, 2, 1] + self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + + # @Note Now, maybe not compatibility with old version + def test_static_api_broadcast_7(self): + cond_shape = [2, 2, 4] + a_shape = [2, 1, 4] + b_shape = [2, 1, 4] + self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + + # @Note Now, maybe not compatibility with old version + def test_static_api_broadcast_8(self): + cond_shape = [3, 2, 2, 4] + a_shape = [2, 2, 1] + b_shape = [2, 2, 1] + self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + class TestWhereDygraphAPI(unittest.TestCase): def test_api(self): @@ -153,6 +239,72 @@ class TestWhereDygraphAPI(unittest.TestCase): out = paddle.where(cond, x, y) assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i)) + def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape): + with fluid.dygraph.guard(): + cond_tmp = paddle.rand(cond_shape) + cond = cond_tmp < 0.3 + a = paddle.rand(a_shape) + b = paddle.rand(b_shape) + + result = paddle.where(cond, a, b) + result = result.numpy() + + expect = np.where(cond, a, b) + + self.assertTrue(np.array_equal(expect, result)) + + def test_dygraph_api_broadcast_1(self): + cond_shape = [2, 4] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + + def test_dygraph_api_broadcast_2(self): + cond_shape = [2, 1] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + + def test_dygraph_api_broadcast_3(self): + cond_shape = [2, 2, 1] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + + def test_dygraph_api_broadcast_4(self): + cond_shape = [2, 1, 4] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + + # @Note Now, maybe not compatibility with old version + def test_dygraph_api_broadcast_5(self): + cond_shape = [3, 2, 2, 4] + a_shape = [2, 2, 4] + b_shape = [2, 2, 4] + self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + + # @Note Now, maybe not compatibility with old version + def test_dygraph_api_broadcast_6(self): + cond_shape = [2, 2, 4] + a_shape = [2, 2, 1] + b_shape = [2, 2, 1] + self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + + # @Note Now, maybe not compatibility with old version + def test_dygraph_api_broadcast_7(self): + cond_shape = [2, 2, 4] + a_shape = [2, 1, 4] + b_shape = [2, 1, 4] + self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + + # @Note Now, maybe not compatibility with old version + def test_dygraph_api_broadcast_8(self): + cond_shape = [3, 2, 2, 4] + a_shape = [2, 2, 1] + b_shape = [2, 2, 1] + self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) + class TestWhereOpError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 071a26151905a22404c7c405bf0ad342811c0458..79eeae78a41c69c160e332926f85ebde3b66916a 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -514,9 +514,10 @@ def where(condition, x, y, name=None): check_variable_and_dtype( y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where') + condition_shape = list(condition.shape) x_shape = list(x.shape) y_shape = list(y.shape) - if x_shape == y_shape: + if x_shape == y_shape and condition_shape == x_shape: if in_dygraph_mode(): return _C_ops.where(condition, x, y) else: