未验证 提交 096afbe1 编写于 作者: R ronnywang 提交者: GitHub

fix paddle.where torch diff (#38870)

* fix paddle.where torch diff

* update
上级 724d49da
......@@ -305,6 +305,36 @@ class TestWhereDygraphAPI(unittest.TestCase):
b_shape = [2, 2, 1]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
def test_where_condition(self):
data = np.array([[True, False], [False, True]])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 2])
y = paddle.where(x)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 2)
z = fluid.layers.concat(list(y), axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0, 0], [1, 1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
data = np.array([True, True, False])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1])
y = paddle.where(x)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 1)
z = fluid.layers.concat(list(y), axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0], [1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
class TestWhereOpError(unittest.TestCase):
def test_errors(self):
......@@ -326,6 +356,14 @@ class TestWhereOpError(unittest.TestCase):
self.assertRaises(TypeError, test_type)
def test_value_error(self):
with fluid.dygraph.guard():
cond_shape = [2, 2, 4]
cond_tmp = paddle.rand(cond_shape)
cond = cond_tmp < 0.3
a = paddle.rand(cond_shape)
self.assertRaises(ValueError, paddle.where, cond, a)
if __name__ == '__main__':
unittest.main()
......@@ -523,23 +523,26 @@ def mode(x, axis=-1, keepdim=False, name=None):
return values, indices
def where(condition, x, y, name=None):
def where(condition, x=None, y=None, name=None):
r"""
Return a tensor of elements selected from either $x$ or $y$, depending on $condition$.
**Note**:
``paddle.where(condition)`` is identical to ``paddle.nonzero(condition, as_tuple=True)``.
.. math::
out_i =
\\begin{cases}
x_i, \quad \\text{if} \\ condition_i \\ is \\ True \\\\
y_i, \quad \\text{if} \\ condition_i \\ is \\ False \\\\
\\end{cases}
\begin{cases}
x_i, \quad \text{if} \ condition_i \ is \ True \\
y_i, \quad \text{if} \ condition_i \ is \ False \\
\end{cases}
Args:
condition(Tensor): The condition to choose x or y.
x(Tensor): x is a Tensor with data type float32, float64, int32, int64.
y(Tensor): y is a Tensor with data type float32, float64, int32, int64.
x(Tensor, optional): x is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
y(Tensor, optional): y is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
......@@ -559,7 +562,19 @@ def where(condition, x, y, name=None):
print(out)
#out: [1.0, 1.0, 3.2, 1.2]
out = paddle.where(x>1)
print(out)
#out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
# [[2],
# [3]]),)
"""
if x is None and y is None:
return nonzero(condition, as_tuple=True)
if x is None or y is None:
raise ValueError("either both or neither of x and y should be given")
if not in_dygraph_mode():
check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
check_variable_and_dtype(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册