From d02e3e7f79309b0e3fd20ce4522c8dcef2d526bc Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 2 Apr 2019 17:26:06 +0800 Subject: [PATCH] add a simple test test=develop --- .../tests/unittests/test_imperative_basic.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 13f2d66217..8b3094cb2a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -348,6 +348,55 @@ class TestImperative(unittest.TestCase): self.assertEqual(mlp._fc2, sublayers[1]) self.assertEqual(len(sublayers), 2) + def test_dygraph_vs_static(self): + inp1 = np.random.rand(4, 3, 3) + inp2 = np.random.rand(4, 3, 3) + + # dynamic graph + with fluid.dygraph.guard(): + if np.sum(inp1) < np.sum(inp2): + x = fluid.layers.elementwise_add(inp1, inp2) + else: + x = fluid.layers.elementwise_sub(inp1, inp2) + dygraph_result = x._numpy() + + # static graph + with new_program_scope(): + inp_data1 = fluid.layers.data( + name='inp1', shape=[3, 3], dtype=np.float32) + inp_data2 = fluid.layers.data( + name='inp2', shape=[3, 3], dtype=np.float32) + + a = fluid.layers.expand( + fluid.layers.reshape( + fluid.layers.reduce_sum(inp_data1), [1, 1]), [4, 1]) + b = fluid.layers.expand( + fluid.layers.reshape( + fluid.layers.reduce_sum(inp_data2), [1, 1]), [4, 1]) + cond = fluid.layers.less_than(x=a, y=b) + + ie = fluid.layers.IfElse(cond) + with ie.true_block(): + d1 = ie.input(inp_data1) + d2 = ie.input(inp_data2) + d3 = fluid.layers.elementwise_add(d1, d2) + ie.output(d3) + + with ie.false_block(): + d1 = ie.input(inp_data1) + d2 = ie.input(inp_data2) + d3 = fluid.layers.elementwise_sub(d1, d2) + ie.output(d3) + out = ie() + + exe = fluid.Executor(fluid.CPUPlace( + ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) + static_result = exe.run(fluid.default_main_program(), + feed={'inp1': inp1, + 'inp2': inp2}, + fetch_list=out)[0] + self.assertTrue(np.allclose(dygraph_result, static_result)) + def test_rnn(self): np_inp = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]) -- GitLab