From 0c617377a9290493f0af83e63ac2681e5b2fba22 Mon Sep 17 00:00:00 2001 From: songyouwei Date: Tue, 3 Mar 2020 14:23:18 +0800 Subject: [PATCH] add case and switch_case unittests for dygraph mode (#22790) test=develop --- .../fluid/tests/unittests/test_layers.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index c0d4284dcc..e94c352da4 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1564,6 +1564,110 @@ class TestLayer(LayerTest): self.assertTrue(np.array_equal(static_res, dynamic_res)) + def test_case(self): + def fn_1(): + return layers.fill_constant(shape=[1, 2], dtype='float32', value=1) + + def fn_2(): + return layers.fill_constant(shape=[2, 2], dtype='int32', value=2) + + def fn_3(): + return layers.fill_constant(shape=[3], dtype='int32', value=3) + + with self.static_graph(): + x = layers.fill_constant(shape=[1], dtype='float32', value=0.3) + y = layers.fill_constant(shape=[1], dtype='float32', value=0.1) + z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) + + pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3 + pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1 + pred_3 = layers.equal(x, y) # false: 0.3 == 0.1 + + out_1 = layers.case( + pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3) + out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + static_res1, static_res2 = exe.run(fetch_list=[out_1, out_2]) + + with self.dynamic_graph(): + x = layers.fill_constant(shape=[1], dtype='float32', value=0.3) + y = layers.fill_constant(shape=[1], dtype='float32', value=0.1) + z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) + + pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3 + pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1 + pred_3 = layers.equal(x, y) # false: 0.3 == 0.1 + + out_1 = layers.case( + pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3) + out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]) + dynamic_res1 = out_1.numpy() + dynamic_res2 = out_2.numpy() + + self.assertTrue(np.array_equal(static_res1, dynamic_res1)) + self.assertTrue(np.array_equal(static_res2, dynamic_res2)) + + def test_switch_case(self): + def fn_1(): + return layers.fill_constant(shape=[1, 2], dtype='float32', value=1) + + def fn_2(): + return layers.fill_constant(shape=[2, 2], dtype='int32', value=2) + + def fn_3(): + return layers.fill_constant(shape=[3], dtype='int32', value=3) + + with self.static_graph(): + index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1) + index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2) + + out_1 = layers.switch_case( + branch_index=index_1, + branch_fns={1: fn_1, + 2: fn_2}, + default=fn_3) + out_2 = layers.switch_case( + branch_index=index_2, + branch_fns=[(1, fn_1), (2, fn_2)], + default=fn_3) + out_3 = layers.switch_case( + branch_index=index_2, + branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)]) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + static_res1, static_res2, static_res3 = exe.run( + fetch_list=[out_1, out_2, out_3]) + + with self.dynamic_graph(): + index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1) + index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2) + + out_1 = layers.switch_case( + branch_index=index_1, + branch_fns={1: fn_1, + 2: fn_2}, + default=fn_3) + out_2 = layers.switch_case( + branch_index=index_2, + branch_fns=[(1, fn_1), (2, fn_2)], + default=fn_3) + out_3 = layers.switch_case( + branch_index=index_2, + branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)]) + + dynamic_res1 = out_1.numpy() + dynamic_res2 = out_2.numpy() + dynamic_res3 = out_3.numpy() + + self.assertTrue(np.array_equal(static_res1, dynamic_res1)) + self.assertTrue(np.array_equal(static_res2, dynamic_res2)) + self.assertTrue(np.array_equal(static_res3, dynamic_res3)) + def test_crop_tensor(self): with self.static_graph(): x = fluid.layers.data(name="x1", shape=[6, 5, 8]) -- GitLab