未验证 提交 77376727 编写于 作者: H Huihuang Zheng 提交者: GitHub

Enrich 0d Tensor Dygraph and Shape Unit Test for `case` and `switch_case` (#49889)

Followed PR https://github.com/PaddlePaddle/Paddle/pull/49842 , added Digraph and Shape unit test for `case` and `switch_case`. This PR only contained test changes because `case` and `switch_case` call `cond`. The PR https://github.com/PaddlePaddle/Paddle/pull/49842 has already solved the 0d tensor support.
上级 ce045890
......@@ -22,6 +22,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard
paddle.enable_static()
......@@ -145,10 +146,101 @@ class TestAPICase(unittest.TestCase):
)
np.testing.assert_allclose(res[0], 1, rtol=1e-05)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(res[1], 2, rtol=1e-05)
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(res[2], 3, rtol=1e-05)
self.assertEqual(res[2].shape, ())
np.testing.assert_allclose(res[3], 2, rtol=1e-05)
self.assertEqual(res[3].shape, ())
np.testing.assert_allclose(res[4], 2, rtol=1e-05)
self.assertEqual(res[4].shape, ())
def test_0d_tensor_backward(self):
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
x.stop_gradient = False
pred = paddle.full(shape=[], dtype='bool', fill_value=0)
# pred is False, so out = -x
out = paddle.static.nn.case(
pred_fn_pairs=[(pred, lambda: x)], default=lambda: -x
)
append_backward(out)
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
np.testing.assert_allclose(
np.asarray(res[0]), np.array(2.0), rtol=1e-05
)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(
np.asarray(res[1]), np.array(-1.0), rtol=1e-05
)
self.assertEqual(res[1].shape, ())
def test_0d_tensor_dygraph(self):
paddle.disable_static()
def fn_1():
return paddle.full(shape=[], dtype='int32', fill_value=1)
def fn_2():
return paddle.full(shape=[], dtype='int32', fill_value=2)
def fn_3():
return paddle.full(shape=[], dtype='int32', fill_value=3)
x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
# call fn_1
out_0 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3
)
# call fn_2
out_1 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
)
# call default fn_3
out_2 = paddle.static.nn.control_flow.case(
pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3
)
# no default, call fn_2
out_3 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_2)]
)
# no default, call fn_2. but pred_2 is false
out_4 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_2)]
)
np.testing.assert_allclose(out_0, 1, rtol=1e-05)
self.assertEqual(out_0.shape, [])
np.testing.assert_allclose(out_1, 2, rtol=1e-05)
self.assertEqual(out_1.shape, [])
np.testing.assert_allclose(out_2, 3, rtol=1e-05)
self.assertEqual(out_2.shape, [])
np.testing.assert_allclose(out_3, 2, rtol=1e-05)
self.assertEqual(out_3.shape, [])
np.testing.assert_allclose(out_4, 2, rtol=1e-05)
self.assertEqual(out_4.shape, [])
paddle.enable_static()
def test_return_var_tuple(self):
def fn_1():
......@@ -394,8 +486,11 @@ class TestAPICase_Nested(unittest.TestCase):
res = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
np.testing.assert_allclose(res[0], 1, rtol=1e-05)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(res[1], 2, rtol=1e-05)
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(res[2], 3, rtol=1e-05)
self.assertEqual(res[2].shape, ())
class TestAPICase_Error(unittest.TestCase):
......
......@@ -21,6 +21,7 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard
paddle.enable_static()
......@@ -93,25 +94,25 @@ class TestAPISwitchCase(unittest.TestCase):
res[1],
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2),
err_msg='result is {} but answer is {}'.format(res[1], 2),
)
np.testing.assert_allclose(
res[2],
3,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 3),
err_msg='result is {} but answer is {}'.format(res[2], 3),
)
np.testing.assert_allclose(
res[3],
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2),
err_msg='result is {} but answer is {}'.format(res[3], 2),
)
np.testing.assert_allclose(
res[4],
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2),
err_msg='result is {} but answer is {}'.format(res[4], 2),
)
def test_0d_tensor(self):
......@@ -176,30 +177,148 @@ class TestAPISwitchCase(unittest.TestCase):
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 1),
)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(
res[1],
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2),
err_msg='result is {} but answer is {}'.format(res[1], 2),
)
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(
res[2],
3,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 3),
err_msg='result is {} but answer is {}'.format(res[2], 3),
)
self.assertEqual(res[2].shape, ())
np.testing.assert_allclose(
res[3],
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2),
err_msg='result is {} but answer is {}'.format(res[3], 2),
)
self.assertEqual(res[3].shape, ())
np.testing.assert_allclose(
res[4],
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 2),
err_msg='result is {} but answer is {}'.format(res[4], 2),
)
self.assertEqual(res[4].shape, ())
def test_0d_tensor_backward(self):
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
x.stop_gradient = False
pred = paddle.full(shape=[], dtype='int32', fill_value=2)
# pred is 2, so out = 2 * x
out = paddle.static.nn.switch_case(
branch_index=pred,
branch_fns=[(1, lambda: x), (2, lambda: 2 * x)],
default=lambda: -x,
)
append_backward(out)
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
np.testing.assert_allclose(
np.asarray(res[0]), np.array(-4.0), rtol=1e-05
)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(
np.asarray(res[1]), np.array(2.0), rtol=1e-05
)
self.assertEqual(res[1].shape, ())
def test_0d_tensor_dygraph(self):
paddle.disable_static()
def fn_1():
return paddle.full(shape=[], dtype='int32', fill_value=1)
def fn_2():
return paddle.full(shape=[], dtype='int32', fill_value=2)
def fn_3():
return paddle.full(shape=[], dtype='int32', fill_value=3)
index_1 = paddle.full(shape=[], dtype='int32', fill_value=1)
index_2 = paddle.full(shape=[], dtype='int32', fill_value=2)
index_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
# call fn_1
out_0 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
# call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
out_1 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
)
# call default fn_3
out_2 = paddle.static.nn.switch_case(
branch_index=index_5,
branch_fns=((1, fn_1), (2, fn_2)),
default=fn_3,
)
# no default, call fn_2
out_3 = paddle.static.nn.switch_case(
branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
)
# no default, call fn_2 but branch_index is 5
out_4 = paddle.static.nn.switch_case(
branch_index=index_5,
branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
)
np.testing.assert_allclose(
out_0,
1,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_0, 1),
)
self.assertEqual(out_0.shape, [])
np.testing.assert_allclose(
out_1,
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_1, 2),
)
self.assertEqual(out_1.shape, [])
np.testing.assert_allclose(
out_2,
3,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_2, 3),
)
self.assertEqual(out_2.shape, [])
np.testing.assert_allclose(
out_3,
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_3, 2),
)
self.assertEqual(out_3.shape, [])
np.testing.assert_allclose(
out_4,
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(out_4, 2),
)
self.assertEqual(out_4.shape, [])
paddle.enable_static()
def test_return_var_tuple(self):
def fn_1():
......@@ -426,18 +545,21 @@ class TestAPISwitchCase_Nested(unittest.TestCase):
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[0], 1),
)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(
res[1],
2,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[1], 2),
)
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(
res[2],
3,
rtol=1e-05,
err_msg='result is {} but answer is {}'.format(res[2], 3),
)
self.assertEqual(res[2].shape, ())
# test TypeError and ValueError of api switch_case
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册