未验证 提交 096afbe1 编写于 作者: R ronnywang 提交者: GitHub

fix paddle.where torch diff (#38870)

* fix paddle.where torch diff

* update
上级 724d49da
...@@ -305,6 +305,36 @@ class TestWhereDygraphAPI(unittest.TestCase): ...@@ -305,6 +305,36 @@ class TestWhereDygraphAPI(unittest.TestCase):
b_shape = [2, 2, 1] b_shape = [2, 2, 1]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
def test_where_condition(self):
data = np.array([[True, False], [False, True]])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 2])
y = paddle.where(x)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 2)
z = fluid.layers.concat(list(y), axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0, 0], [1, 1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
data = np.array([True, True, False])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1])
y = paddle.where(x)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 1)
z = fluid.layers.concat(list(y), axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0], [1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
class TestWhereOpError(unittest.TestCase): class TestWhereOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
...@@ -326,6 +356,14 @@ class TestWhereOpError(unittest.TestCase): ...@@ -326,6 +356,14 @@ class TestWhereOpError(unittest.TestCase):
self.assertRaises(TypeError, test_type) self.assertRaises(TypeError, test_type)
def test_value_error(self):
with fluid.dygraph.guard():
cond_shape = [2, 2, 4]
cond_tmp = paddle.rand(cond_shape)
cond = cond_tmp < 0.3
a = paddle.rand(cond_shape)
self.assertRaises(ValueError, paddle.where, cond, a)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -523,23 +523,26 @@ def mode(x, axis=-1, keepdim=False, name=None): ...@@ -523,23 +523,26 @@ def mode(x, axis=-1, keepdim=False, name=None):
return values, indices return values, indices
def where(condition, x, y, name=None): def where(condition, x=None, y=None, name=None):
r""" r"""
Return a tensor of elements selected from either $x$ or $y$, depending on $condition$. Return a tensor of elements selected from either $x$ or $y$, depending on $condition$.
**Note**:
``paddle.where(condition)`` is identical to ``paddle.nonzero(condition, as_tuple=True)``.
.. math:: .. math::
out_i = out_i =
\\begin{cases} \begin{cases}
x_i, \quad \\text{if} \\ condition_i \\ is \\ True \\\\ x_i, \quad \text{if} \ condition_i \ is \ True \\
y_i, \quad \\text{if} \\ condition_i \\ is \\ False \\\\ y_i, \quad \text{if} \ condition_i \ is \ False \\
\\end{cases} \end{cases}
Args: Args:
condition(Tensor): The condition to choose x or y. condition(Tensor): The condition to choose x or y.
x(Tensor): x is a Tensor with data type float32, float64, int32, int64. x(Tensor, optional): x is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
y(Tensor): y is a Tensor with data type float32, float64, int32, int64. y(Tensor, optional): y is a Tensor with data type float32, float64, int32, int64. Either both or neither of x and y should be given.
name(str, optional): The default value is None. Normally there is no name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please need for user to set this property. For more information, please
...@@ -559,7 +562,19 @@ def where(condition, x, y, name=None): ...@@ -559,7 +562,19 @@ def where(condition, x, y, name=None):
print(out) print(out)
#out: [1.0, 1.0, 3.2, 1.2] #out: [1.0, 1.0, 3.2, 1.2]
out = paddle.where(x>1)
print(out)
#out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
# [[2],
# [3]]),)
""" """
if x is None and y is None:
return nonzero(condition, as_tuple=True)
if x is None or y is None:
raise ValueError("either both or neither of x and y should be given")
if not in_dygraph_mode(): if not in_dygraph_mode():
check_variable_and_dtype(condition, 'condition', ['bool'], 'where') check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
check_variable_and_dtype( check_variable_and_dtype(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册