未验证 提交 c5ae43a2 编写于 作者: R ronnywang 提交者: GitHub

fix paddle.where torch diff (#39859)

上级 b695fd95
......@@ -139,6 +139,28 @@ class TestWhereAPI(unittest.TestCase):
fetch_list=[result])
assert np.array_equal(out[0], np.where((x_i > 1), x_i, y_i))
def test_scalar(self):
paddle.enable_static()
main_program = Program()
with fluid.program_guard(main_program):
cond_shape = [2, 4]
cond = fluid.layers.data(
name='cond', shape=cond_shape, dtype='bool')
x_data = 1.0
y_data = 2.0
cond_data = np.array([False, False, True, True]).astype('bool')
result = paddle.where(condition=cond, x=x_data, y=y_data)
for use_cuda in [False, True]:
if (use_cuda and (not fluid.core.is_compiled_with_cuda())):
return
place = (fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace())
exe = fluid.Executor(place)
out = exe.run(fluid.default_main_program(),
feed={'cond': cond_data},
fetch_list=[result])
expect = np.where(cond_data, x_data, y_data)
assert np.array_equal(out[0], expect)
def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape):
paddle.enable_static()
main_program = Program()
......@@ -227,6 +249,15 @@ class TestWhereDygraphAPI(unittest.TestCase):
out = paddle.where(cond, x, y)
assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
def test_scalar(self):
with fluid.dygraph.guard():
cond_i = np.array([False, False, True, True]).astype('bool')
x = 1.0
y = 2.0
cond = fluid.dygraph.to_variable(cond_i)
out = paddle.where(cond, x, y)
assert np.array_equal(out.numpy(), np.where(cond_i, x, y))
def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape):
with fluid.dygraph.guard():
cond_tmp = paddle.rand(cond_shape)
......
......@@ -543,8 +543,8 @@ def where(condition, x=None, y=None, name=None):
Args:
condition(Tensor): The condition to choose x or y.
x(Tensor, optional): x is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
y(Tensor, optional): y is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
x(Tensor or Scalar, optional): x is a Tensor or Scalar with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
y(Tensor or Scalar, optional): y is a Tensor or Scalar with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
......@@ -571,6 +571,12 @@ def where(condition, x=None, y=None, name=None):
# [[2],
# [3]]),)
"""
if np.isscalar(x):
x = layers.fill_constant([1], np.array([x]).dtype.name, x)
if np.isscalar(y):
y = layers.fill_constant([1], np.array([y]).dtype.name, y)
if x is None and y is None:
return nonzero(condition, as_tuple=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册