diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 97a3ebc2135a0649fff88e1a1c14d02dfb7850b1..be6acab8ee4ecf9110ad3c2bd342ce09d04f2e2a 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -61,6 +61,8 @@ if(NOT WITH_GPU OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_allreduce) LIST(REMOVE_ITEM TEST_OPS test_broadcast) LIST(REMOVE_ITEM TEST_OPS test_collective_reduce) + LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv) + LIST(REMOVE_ITEM TEST_OPS test_collective_alltoall) LIST(REMOVE_ITEM TEST_OPS test_collective_scatter) LIST(REMOVE_ITEM TEST_OPS test_collective_reduce_api) LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api) diff --git a/python/paddle/fluid/tests/unittests/collective_alltoall_op.py b/python/paddle/fluid/tests/unittests/collective_alltoall_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c4eb7096b649dce006354cfa77e3b913fc883223 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_alltoall_op.py @@ -0,0 +1,64 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +import socket +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + + +class TestCollectiveAllToAll(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank=None): + ring_id = 0 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = layers.data( + name="toutdata", shape=[10, 1000], dtype='float32') + main_prog.global_block().append_op( + type="c_alltoall", + inputs={'X': tindata}, + outputs={'Out': toutdata}, + attrs={'ring_id': ring_id}) + main_prog.global_block().append_op( + type="c_sync_comm_stream", + inputs={'X': toutdata}, + outputs={'Out': toutdata}, + attrs={'ring_id': ring_id}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveAllToAll, "alltoall", 0) diff --git a/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py b/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d86fb505a631c0895c7e5a5aba3ea658c130e3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py @@ -0,0 +1,73 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +import socket +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + + +class TestCollectiveSendRecv(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank=None): + ring_id = 0 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + if rank == 0: + main_prog.global_block().append_op( + type="c_recv", + outputs={'Out': tindata}, + attrs={ + 'ring_id': ring_id, + 'dtype': tindata.dtype, + 'out_shape': tindata.shape, + 'peer': 1 + }) + else: + main_prog.global_block().append_op( + type="c_send", + inputs={'X': tindata}, + attrs={'ring_id': ring_id, + 'peer': 0}) + main_prog.global_block().append_op( + type="c_sync_comm_stream", + inputs={'X': tindata}, + outputs={'Out': tindata}, + attrs={'ring_id': ring_id}) + return tindata + + +if __name__ == "__main__": + runtime_main(TestCollectiveSendRecv, "sendrecv", 0) diff --git a/python/paddle/fluid/tests/unittests/test_collective_alltoall.py b/python/paddle/fluid/tests/unittests/test_collective_alltoall.py new file mode 100644 index 0000000000000000000000000000000000000000..c35e5e89e6e84d1db79737d78dd7a8a7e205f6d1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_alltoall.py @@ -0,0 +1,31 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np + +from test_collective_base import TestDistBase + + +class TestCAllToAllOp(TestDistBase): + def _setup_config(self): + pass + + def test_alltoall(self): + self.check_with_place("collective_alltoall_op.py", "alltoall") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py new file mode 100644 index 0000000000000000000000000000000000000000..5abfb2184308bb28da98ff7d8a3c276de91f70f0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py @@ -0,0 +1,31 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np + +from test_collective_base import TestDistBase + + +class TestCSendRecvOp(TestDistBase): + def _setup_config(self): + pass + + def test_sendrecv(self): + self.check_with_place("collective_sendrecv_op.py", "sendrecv") + + +if __name__ == '__main__': + unittest.main()