diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index d601117b96f12d35756b521b85902bf91ef01bae..7fb4d39cd7338fb3cd57c786bc811b901351eaf9 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -139,6 +139,28 @@ class TestWhereAPI(unittest.TestCase): fetch_list=[result]) assert np.array_equal(out[0], np.where((x_i > 1), x_i, y_i)) + def test_scalar(self): + paddle.enable_static() + main_program = Program() + with fluid.program_guard(main_program): + cond_shape = [2, 4] + cond = fluid.layers.data( + name='cond', shape=cond_shape, dtype='bool') + x_data = 1.0 + y_data = 2.0 + cond_data = np.array([False, False, True, True]).astype('bool') + result = paddle.where(condition=cond, x=x_data, y=y_data) + for use_cuda in [False, True]: + if (use_cuda and (not fluid.core.is_compiled_with_cuda())): + return + place = (fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()) + exe = fluid.Executor(place) + out = exe.run(fluid.default_main_program(), + feed={'cond': cond_data}, + fetch_list=[result]) + expect = np.where(cond_data, x_data, y_data) + assert np.array_equal(out[0], expect) + def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): paddle.enable_static() main_program = Program() @@ -227,6 +249,15 @@ class TestWhereDygraphAPI(unittest.TestCase): out = paddle.where(cond, x, y) assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i)) + def test_scalar(self): + with fluid.dygraph.guard(): + cond_i = np.array([False, False, True, True]).astype('bool') + x = 1.0 + y = 2.0 + cond = fluid.dygraph.to_variable(cond_i) + out = paddle.where(cond, x, y) + assert np.array_equal(out.numpy(), np.where(cond_i, x, y)) + def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape): with fluid.dygraph.guard(): cond_tmp = paddle.rand(cond_shape) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 5c5517e54f71ad8dde7999561953ca4c03680b90..ecf70ffe4a1dd3179d02a2a6ca1e260e8193d1d1 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -543,8 +543,8 @@ def where(condition, x=None, y=None, name=None): Args: condition(Tensor): The condition to choose x or y. - x(Tensor, optional): x is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given. - y(Tensor, optional): y is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given. + x(Tensor or Scalar, optional): x is a Tensor or Scalar with data type float32, float64, int32, int64. Either both or neither of x and y should be given. + y(Tensor or Scalar, optional): y is a Tensor or Scalar with data type float32, float64, int32, int64. Either both or neither of x and y should be given. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please @@ -571,6 +571,12 @@ def where(condition, x=None, y=None, name=None): # [[2], # [3]]),) """ + if np.isscalar(x): + x = layers.fill_constant([1], np.array([x]).dtype.name, x) + + if np.isscalar(y): + y = layers.fill_constant([1], np.array([y]).dtype.name, y) + if x is None and y is None: return nonzero(condition, as_tuple=True)