未验证 提交 ffc3d364 编写于 作者: F furnace 提交者: GitHub

[WIP] paddle.where api add broadcast, when x_shape == y_shape, and x_shape != cond_shape (#35092)

* where op add broadcast, when x_shape == y_shape, and x_shape != cond_shape

* add static api tests, and delete debug codes
上级 e8772486
...@@ -140,6 +140,92 @@ class TestWhereAPI(unittest.TestCase): ...@@ -140,6 +140,92 @@ class TestWhereAPI(unittest.TestCase):
fetch_list=[result]) fetch_list=[result])
assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i)) assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i))
def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape):
paddle.enable_static()
main_program = Program()
with fluid.program_guard(main_program):
cond = fluid.layers.data(
name='cond', shape=cond_shape, dtype='bool')
x = fluid.layers.data(name='x', shape=x_shape, dtype='float32')
y = fluid.layers.data(name='y', shape=y_shape, dtype='float32')
cond_data_tmp = np.random.random(size=cond_shape).astype("float32")
cond_data = cond_data_tmp < 0.3
x_data = np.random.random(size=x_shape).astype("float32")
y_data = np.random.random(size=y_shape).astype("float32")
result = paddle.where(condition=cond, x=x, y=y)
for use_cuda in [False, True]:
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
out = exe.run(
fluid.default_main_program(),
feed={'cond': cond_data,
'x': x_data,
'y': y_data},
fetch_list=[result])
expect = np.where(cond_data, x_data, y_data)
assert np.array_equal(out[0], expect)
def test_static_api_broadcast_1(self):
cond_shape = [2, 4]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
def test_static_api_broadcast_2(self):
cond_shape = [2, 1]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
def test_static_api_broadcast_3(self):
cond_shape = [2, 2, 1]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
def test_static_api_broadcast_4(self):
cond_shape = [2, 1, 4]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_static_api_broadcast_5(self):
cond_shape = [3, 2, 2, 4]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_static_api_broadcast_6(self):
cond_shape = [2, 2, 4]
a_shape = [2, 2, 1]
b_shape = [2, 2, 1]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_static_api_broadcast_7(self):
cond_shape = [2, 2, 4]
a_shape = [2, 1, 4]
b_shape = [2, 1, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_static_api_broadcast_8(self):
cond_shape = [3, 2, 2, 4]
a_shape = [2, 2, 1]
b_shape = [2, 2, 1]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
class TestWhereDygraphAPI(unittest.TestCase): class TestWhereDygraphAPI(unittest.TestCase):
def test_api(self): def test_api(self):
...@@ -153,6 +239,72 @@ class TestWhereDygraphAPI(unittest.TestCase): ...@@ -153,6 +239,72 @@ class TestWhereDygraphAPI(unittest.TestCase):
out = paddle.where(cond, x, y) out = paddle.where(cond, x, y)
assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i)) assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape):
with fluid.dygraph.guard():
cond_tmp = paddle.rand(cond_shape)
cond = cond_tmp < 0.3
a = paddle.rand(a_shape)
b = paddle.rand(b_shape)
result = paddle.where(cond, a, b)
result = result.numpy()
expect = np.where(cond, a, b)
self.assertTrue(np.array_equal(expect, result))
def test_dygraph_api_broadcast_1(self):
cond_shape = [2, 4]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
def test_dygraph_api_broadcast_2(self):
cond_shape = [2, 1]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
def test_dygraph_api_broadcast_3(self):
cond_shape = [2, 2, 1]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
def test_dygraph_api_broadcast_4(self):
cond_shape = [2, 1, 4]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_dygraph_api_broadcast_5(self):
cond_shape = [3, 2, 2, 4]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_dygraph_api_broadcast_6(self):
cond_shape = [2, 2, 4]
a_shape = [2, 2, 1]
b_shape = [2, 2, 1]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_dygraph_api_broadcast_7(self):
cond_shape = [2, 2, 4]
a_shape = [2, 1, 4]
b_shape = [2, 1, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_dygraph_api_broadcast_8(self):
cond_shape = [3, 2, 2, 4]
a_shape = [2, 2, 1]
b_shape = [2, 2, 1]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
class TestWhereOpError(unittest.TestCase): class TestWhereOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
...@@ -514,9 +514,10 @@ def where(condition, x, y, name=None): ...@@ -514,9 +514,10 @@ def where(condition, x, y, name=None):
check_variable_and_dtype( check_variable_and_dtype(
y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where') y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where')
condition_shape = list(condition.shape)
x_shape = list(x.shape) x_shape = list(x.shape)
y_shape = list(y.shape) y_shape = list(y.shape)
if x_shape == y_shape: if x_shape == y_shape and condition_shape == x_shape:
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.where(condition, x, y) return _C_ops.where(condition, x, y)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册