未验证 提交 77376727 编写于 作者: H Huihuang Zheng 提交者: GitHub

Enrich 0d Tensor Dygraph and Shape Unit Test for `case` and `switch_case` (#49889)

Followed PR https://github.com/PaddlePaddle/Paddle/pull/49842 , added Digraph and Shape unit test for `case` and `switch_case`. This PR only contained test changes because `case` and `switch_case` call `cond`. The PR https://github.com/PaddlePaddle/Paddle/pull/49842 has already solved the 0d tensor support.
上级 ce045890
...@@ -22,6 +22,7 @@ import paddle.fluid as fluid ...@@ -22,6 +22,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer import paddle.fluid.optimizer as optimizer
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
paddle.enable_static() paddle.enable_static()
...@@ -145,10 +146,101 @@ class TestAPICase(unittest.TestCase): ...@@ -145,10 +146,101 @@ class TestAPICase(unittest.TestCase):
) )
np.testing.assert_allclose(res[0], 1, rtol=1e-05) np.testing.assert_allclose(res[0], 1, rtol=1e-05)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(res[1], 2, rtol=1e-05) np.testing.assert_allclose(res[1], 2, rtol=1e-05)
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(res[2], 3, rtol=1e-05) np.testing.assert_allclose(res[2], 3, rtol=1e-05)
self.assertEqual(res[2].shape, ())
np.testing.assert_allclose(res[3], 2, rtol=1e-05) np.testing.assert_allclose(res[3], 2, rtol=1e-05)
self.assertEqual(res[3].shape, ())
np.testing.assert_allclose(res[4], 2, rtol=1e-05) np.testing.assert_allclose(res[4], 2, rtol=1e-05)
self.assertEqual(res[4].shape, ())
def test_0d_tensor_backward(self):
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
x.stop_gradient = False
pred = paddle.full(shape=[], dtype='bool', fill_value=0)
# pred is False, so out = -x
out = paddle.static.nn.case(
pred_fn_pairs=[(pred, lambda: x)], default=lambda: -x
)
append_backward(out)
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
np.testing.assert_allclose(
np.asarray(res[0]), np.array(2.0), rtol=1e-05
)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(
np.asarray(res[1]), np.array(-1.0), rtol=1e-05
)
self.assertEqual(res[1].shape, ())
def test_0d_tensor_dygraph(self):
paddle.disable_static()
def fn_1():
return paddle.full(shape=[], dtype='int32', fill_value=1)
def fn_2():
return paddle.full(shape=[], dtype='int32', fill_value=2)
def fn_3():
return paddle.full(shape=[], dtype='int32', fill_value=3)
x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
# call fn_1
out_0 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3
)
# call fn_2
out_1 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
)
# call default fn_3
out_2 = paddle.static.nn.control_flow.case(
pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3
)
# no default, call fn_2
out_3 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_2)]
)
# no default, call fn_2. but pred_2 is false
out_4 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_2)]
)
np.testing.assert_allclose(out_0, 1, rtol=1e-05)
self.assertEqual(out_0.shape, [])
np.testing.assert_allclose(out_1, 2, rtol=1e-05)
self.assertEqual(out_1.shape, [])
np.testing.assert_allclose(out_2, 3, rtol=1e-05)
self.assertEqual(out_2.shape, [])
np.testing.assert_allclose(out_3, 2, rtol=1e-05)
self.assertEqual(out_3.shape, [])
np.testing.assert_allclose(out_4, 2, rtol=1e-05)
self.assertEqual(out_4.shape, [])
paddle.enable_static()
def test_return_var_tuple(self): def test_return_var_tuple(self):
def fn_1(): def fn_1():
...@@ -394,8 +486,11 @@ class TestAPICase_Nested(unittest.TestCase): ...@@ -394,8 +486,11 @@ class TestAPICase_Nested(unittest.TestCase):
res = exe.run(main_program, fetch_list=[out_1, out_2, out_3]) res = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
np.testing.assert_allclose(res[0], 1, rtol=1e-05) np.testing.assert_allclose(res[0], 1, rtol=1e-05)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(res[1], 2, rtol=1e-05) np.testing.assert_allclose(res[1], 2, rtol=1e-05)
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(res[2], 3, rtol=1e-05) np.testing.assert_allclose(res[2], 3, rtol=1e-05)
self.assertEqual(res[2].shape, ())
class TestAPICase_Error(unittest.TestCase): class TestAPICase_Error(unittest.TestCase):
......
...@@ -21,6 +21,7 @@ import paddle ...@@ -21,6 +21,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
paddle.enable_static() paddle.enable_static()
...@@ -93,25 +94,25 @@ class TestAPISwitchCase(unittest.TestCase): ...@@ -93,25 +94,25 @@ class TestAPISwitchCase(unittest.TestCase):
res[1], res[1],
2, 2,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2), err_msg='result is {} but answer is {}'.format(res[1], 2),
) )
np.testing.assert_allclose( np.testing.assert_allclose(
res[2], res[2],
3, 3,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 3), err_msg='result is {} but answer is {}'.format(res[2], 3),
) )
np.testing.assert_allclose( np.testing.assert_allclose(
res[3], res[3],
2, 2,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2), err_msg='result is {} but answer is {}'.format(res[3], 2),
) )
np.testing.assert_allclose( np.testing.assert_allclose(
res[4], res[4],
2, 2,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2), err_msg='result is {} but answer is {}'.format(res[4], 2),
) )
def test_0d_tensor(self): def test_0d_tensor(self):
...@@ -176,30 +177,148 @@ class TestAPISwitchCase(unittest.TestCase): ...@@ -176,30 +177,148 @@ class TestAPISwitchCase(unittest.TestCase):
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 1), err_msg='result is {} but answer is {}'.format(res[0], 1),
) )
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose( np.testing.assert_allclose(
res[1], res[1],
2, 2,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2), err_msg='result is {} but answer is {}'.format(res[1], 2),
) )
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose( np.testing.assert_allclose(
res[2], res[2],
3, 3,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 3), err_msg='result is {} but answer is {}'.format(res[2], 3),
) )
self.assertEqual(res[2].shape, ())
np.testing.assert_allclose( np.testing.assert_allclose(
res[3], res[3],
2, 2,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2), err_msg='result is {} but answer is {}'.format(res[3], 2),
) )
self.assertEqual(res[3].shape, ())
np.testing.assert_allclose( np.testing.assert_allclose(
res[4], res[4],
2, 2,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2), err_msg='result is {} but answer is {}'.format(res[4], 2),
)
self.assertEqual(res[4].shape, ())
def test_0d_tensor_backward(self):
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
x.stop_gradient = False
pred = paddle.full(shape=[], dtype='int32', fill_value=2)
# pred is 2, so out = 2 * x
out = paddle.static.nn.switch_case(
branch_index=pred,
branch_fns=[(1, lambda: x), (2, lambda: 2 * x)],
default=lambda: -x,
)
append_backward(out)
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
np.testing.assert_allclose(
np.asarray(res[0]), np.array(-4.0), rtol=1e-05
)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(
np.asarray(res[1]), np.array(2.0), rtol=1e-05
) )
self.assertEqual(res[1].shape, ())
def test_0d_tensor_dygraph(self):
paddle.disable_static()
def fn_1():
return paddle.full(shape=[], dtype='int32', fill_value=1)
def fn_2():
return paddle.full(shape=[], dtype='int32', fill_value=2)
def fn_3():
return paddle.full(shape=[], dtype='int32', fill_value=3)
index_1 = paddle.full(shape=[], dtype='int32', fill_value=1)
index_2 = paddle.full(shape=[], dtype='int32', fill_value=2)
index_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
# call fn_1
out_0 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
# call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
out_1 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
)
# call default fn_3
out_2 = paddle.static.nn.switch_case(
branch_index=index_5,
branch_fns=((1, fn_1), (2, fn_2)),
default=fn_3,
)
# no default, call fn_2
out_3 = paddle.static.nn.switch_case(
branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
)
# no default, call fn_2 but branch_index is 5
out_4 = paddle.static.nn.switch_case(
branch_index=index_5,
branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
)
np.testing.assert_allclose(
out_0,
1,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_0, 1),
)
self.assertEqual(out_0.shape, [])
np.testing.assert_allclose(
out_1,
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_1, 2),
)
self.assertEqual(out_1.shape, [])
np.testing.assert_allclose(
out_2,
3,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_2, 3),
)
self.assertEqual(out_2.shape, [])
np.testing.assert_allclose(
out_3,
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_3, 2),
)
self.assertEqual(out_3.shape, [])
np.testing.assert_allclose(
out_4,
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_4, 2),
)
self.assertEqual(out_4.shape, [])
paddle.enable_static()
def test_return_var_tuple(self): def test_return_var_tuple(self):
def fn_1(): def fn_1():
...@@ -426,18 +545,21 @@ class TestAPISwitchCase_Nested(unittest.TestCase): ...@@ -426,18 +545,21 @@ class TestAPISwitchCase_Nested(unittest.TestCase):
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 1), err_msg='result is {} but answer is {}'.format(res[0], 1),
) )
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose( np.testing.assert_allclose(
res[1], res[1],
2, 2,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[1], 2), err_msg='result is {} but answer is {}'.format(res[1], 2),
) )
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose( np.testing.assert_allclose(
res[2], res[2],
3, 3,
rtol=1e-05, rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[2], 3), err_msg='result is {} but answer is {}'.format(res[2], 3),
) )
self.assertEqual(res[2].shape, ())
# test TypeError and ValueError of api switch_case # test TypeError and ValueError of api switch_case
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册