未验证 提交 c5ae43a2 编写于 作者: R ronnywang 提交者: GitHub

fix paddle.where torch diff (#39859)

上级 b695fd95
...@@ -139,6 +139,28 @@ class TestWhereAPI(unittest.TestCase): ...@@ -139,6 +139,28 @@ class TestWhereAPI(unittest.TestCase):
fetch_list=[result]) fetch_list=[result])
assert np.array_equal(out[0], np.where((x_i > 1), x_i, y_i)) assert np.array_equal(out[0], np.where((x_i > 1), x_i, y_i))
def test_scalar(self):
paddle.enable_static()
main_program = Program()
with fluid.program_guard(main_program):
cond_shape = [2, 4]
cond = fluid.layers.data(
name='cond', shape=cond_shape, dtype='bool')
x_data = 1.0
y_data = 2.0
cond_data = np.array([False, False, True, True]).astype('bool')
result = paddle.where(condition=cond, x=x_data, y=y_data)
for use_cuda in [False, True]:
if (use_cuda and (not fluid.core.is_compiled_with_cuda())):
return
place = (fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace())
exe = fluid.Executor(place)
out = exe.run(fluid.default_main_program(),
feed={'cond': cond_data},
fetch_list=[result])
expect = np.where(cond_data, x_data, y_data)
assert np.array_equal(out[0], expect)
def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape):
paddle.enable_static() paddle.enable_static()
main_program = Program() main_program = Program()
...@@ -227,6 +249,15 @@ class TestWhereDygraphAPI(unittest.TestCase): ...@@ -227,6 +249,15 @@ class TestWhereDygraphAPI(unittest.TestCase):
out = paddle.where(cond, x, y) out = paddle.where(cond, x, y)
assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i)) assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
def test_scalar(self):
with fluid.dygraph.guard():
cond_i = np.array([False, False, True, True]).astype('bool')
x = 1.0
y = 2.0
cond = fluid.dygraph.to_variable(cond_i)
out = paddle.where(cond, x, y)
assert np.array_equal(out.numpy(), np.where(cond_i, x, y))
def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape): def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
cond_tmp = paddle.rand(cond_shape) cond_tmp = paddle.rand(cond_shape)
......
...@@ -543,8 +543,8 @@ def where(condition, x=None, y=None, name=None): ...@@ -543,8 +543,8 @@ def where(condition, x=None, y=None, name=None):
Args: Args:
condition(Tensor): The condition to choose x or y. condition(Tensor): The condition to choose x or y.
x(Tensor, optional): x is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given. x(Tensor or Scalar, optional): x is a Tensor or Scalar with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
y(Tensor, optional): y is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given. y(Tensor or Scalar, optional): y is a Tensor or Scalar with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
name(str, optional): The default value is None. Normally there is no name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please need for user to set this property. For more information, please
...@@ -571,6 +571,12 @@ def where(condition, x=None, y=None, name=None): ...@@ -571,6 +571,12 @@ def where(condition, x=None, y=None, name=None):
# [[2], # [[2],
# [3]]),) # [3]]),)
""" """
if np.isscalar(x):
x = layers.fill_constant([1], np.array([x]).dtype.name, x)
if np.isscalar(y):
y = layers.fill_constant([1], np.array([y]).dtype.name, y)
if x is None and y is None: if x is None and y is None:
return nonzero(condition, as_tuple=True) return nonzero(condition, as_tuple=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册