From 26a92083941770cd6f380eef2de040f85779afbc Mon Sep 17 00:00:00 2001 From: Varun Arora Date: Thu, 15 Mar 2018 15:14:12 -0700 Subject: [PATCH] New PingPong test for testing channels / concurrency (#9132) * New test for testing channels / concurrency * Formatting fix --- python/paddle/fluid/tests/test_concurrency.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/python/paddle/fluid/tests/test_concurrency.py b/python/paddle/fluid/tests/test_concurrency.py index 3aa51610cd..924895a9af 100644 --- a/python/paddle/fluid/tests/test_concurrency.py +++ b/python/paddle/fluid/tests/test_concurrency.py @@ -217,6 +217,57 @@ class TestRoutineOp(unittest.TestCase): exe_result = exe.run(fetch_list=[result]) self.assertEqual(exe_result[0][0], 34) + def test_ping_pong(self): + """ + Mimics Ping Pong example: https://gobyexample.com/channel-directions + """ + with framework.program_guard(framework.Program()): + result = self._create_tensor('return_value', + core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.FP64) + + ping_result = self._create_tensor('ping_return_value', + core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.FP64) + + pong_result = self._create_tensor('pong_return_value', + core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.FP64) + + def ping(ch, message): + message_to_send_tmp = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.FP64, value=0) + + assign(input=message, output=message_to_send_tmp) + fluid.channel_send(ch, message_to_send_tmp) + + def pong(ch1, ch2): + fluid.channel_recv(ch1, ping_result) + assign(input=ping_result, output=pong_result) + fluid.channel_send(ch2, pong_result) + + pings = fluid.make_channel( + dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1) + pongs = fluid.make_channel( + dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1) + + msg = fill_constant( + shape=[1], dtype=core.VarDesc.VarType.FP64, value=9) + + ping(pings, msg) + pong(pings, pongs) + + fluid.channel_recv(pongs, result) + + fluid.channel_close(pings) + fluid.channel_close(pongs) + + cpu = core.CPUPlace() + exe = Executor(cpu) + + exe_result = exe.run(fetch_list=[result]) + self.assertEqual(exe_result[0][0], 9) + if __name__ == '__main__': unittest.main() -- GitLab