From ed9dd7c9f0036c55745b441ea9cae79c3cf602b2 Mon Sep 17 00:00:00 2001
From: lilong12
Date: Fri, 13 Nov 2020 15:49:55 +0800
Subject: [PATCH] add send and recv ops (#28590)

* update, test=develop
---
 .../fluid/operators/collective/recv_v2_op.cc  |  91 +++++++++++++++
 .../operators/collective/recv_v2_op.cu.cc     | 104 ++++++++++++++++++
 .../fluid/operators/collective/recv_v2_op.h   |  37 +++++++
 .../fluid/operators/collective/send_v2_op.cc  |  77 +++++++++++++
 .../operators/collective/send_v2_op.cu.cc     |  90 +++++++++++++++
 .../fluid/operators/collective/send_v2_op.h   |  38 +++++++
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../tests/unittests/collective_sendrecv_op.py |  74 +++++++++++++
 .../tests/unittests/test_collective_base.py   |   6 +
 .../unittests/test_collective_sendrecv.py     |  34 ++++++
 10 files changed, 552 insertions(+)
 create mode 100644 paddle/fluid/operators/collective/recv_v2_op.cc
 create mode 100644 paddle/fluid/operators/collective/recv_v2_op.cu.cc
 create mode 100644 paddle/fluid/operators/collective/recv_v2_op.h
 create mode 100644 paddle/fluid/operators/collective/send_v2_op.cc
 create mode 100644 paddle/fluid/operators/collective/send_v2_op.cu.cc
 create mode 100644 paddle/fluid/operators/collective/send_v2_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_collective_sendrecv.py

diff --git a/paddle/fluid/operators/collective/recv_v2_op.cc b/paddle/fluid/operators/collective/recv_v2_op.cc
new file mode 100644
index 00000000000..10408820387
--- /dev/null
+++ b/paddle/fluid/operators/collective/recv_v2_op.cc
@@ -0,0 +1,91 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/recv_v2_op.h"
+#include <string>
+
+namespace paddle {
+namespace operators {
+
+class RecvOpV2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Recv_V2");
+    int peer = ctx->Attrs().Get<int>("peer");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        peer, 0,
+        platform::errors::InvalidArgument(
+            "The peer (%d) for recv_v2 op must be non-negative.", peer));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));
+    auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
+    PADDLE_ENFORCE_GE(out_shape.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "The size of the output shape must be greater than 0 "
+                          "but the value given is %d.",
+                          out_shape.size()));
+    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    int dtype = ctx.Attr<int>("dtype");
+    framework::proto::VarType::Type type =
+        framework::proto::VarType::Type(dtype);
+    return framework::OpKernelType(type, ctx.GetPlace());
+  }
+};
+
+class RecvOpV2Maker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddOutput("Out", "(Tensor) tensor to receive.");
+    AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
+        .SetDefault(0);
+    AddAttr<int>("peer", "(int default 0) rank id for sender.").SetDefault(0);
+    AddAttr<int>("dtype", "(int default 5('float32')) data type of tensor.")
+        .SetDefault(5);
+    AddAttr<std::vector<int>>("out_shape", "shape of the output tensor.")
+        .SetDefault(std::vector<int>());
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default false) eject CUDA operations to calculation stream.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+Recv Operator
+
+Reference: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#sendrecv
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_WITHOUT_GRADIENT(recv_v2, ops::RecvOpV2, ops::RecvOpV2Maker);
+
+REGISTER_OP_CPU_KERNEL(recv_v2, ops::RecvOpV2CPUKernel<float>,
+                       ops::RecvOpV2CPUKernel<double>,
+                       ops::RecvOpV2CPUKernel<int>,
+                       ops::RecvOpV2CPUKernel<int64_t>,
+                       ops::RecvOpV2CPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
new file mode 100644
index 00000000000..f0dd8aee235
--- /dev/null
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -0,0 +1,104 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/recv_v2_op.h"
+
+#if defined(PADDLE_WITH_NCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    int rid = ctx.Attr<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        rid, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for recv_v2 op must be non-negative.", rid));
+
+    int peer = ctx.Attr<int>("peer");
+    PADDLE_ENFORCE_GE(
+        peer, 0,
+        platform::errors::InvalidArgument(
+            "The peer (%d) for recv_v2 op must be non-negative.", peer));
+
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    auto out_dims = out->dims();
+    int data_type = ctx.Attr<int>("dtype");
+    framework::proto::VarType::Type type =
+        framework::proto::VarType::Type(data_type);
+
+#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
+    cudaStream_t stream = nullptr;
+    auto place = ctx.GetPlace();
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::CUDADeviceContext *>(dev_ctx)->stream();
+    } else {
+      stream = comm->stream();
+    }
+
+    PADDLE_ENFORCE_LT(
+        peer, comm->nranks(),
+        platform::errors::InvalidArgument("The value of peer (%d) you set must "
+                                          "be less than comm->nranks (%d).",
+                                          peer, comm->nranks()));
+    ncclDataType_t dtype = platform::ToNCCLDataType(type);
+
+    // Recv the number of elements to receive first
+    int numel = 0;
+    int *numel_ptr = nullptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::ncclRecv(static_cast<void *>(numel_ptr), 1, ncclInt,
+                                    peer, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaMemcpy(&numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost));
+
+    int rest_numel = 1;
+    for (int i = 1; i < out_dims.size(); ++i) {
+      rest_numel = rest_numel * out_dims[i];
+    }
+    out_dims[0] = numel / rest_numel;
+    out->mutable_data<T>(out_dims, place);
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+        out->data<T>(), numel, dtype, peer, comm->comm(), stream));
+    VLOG(3) << "rank " << comm->rank() << " recv "
+            << framework::product(out->dims()) << " from " << peer;
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "PaddlePaddle should be compiled with NCCL and "
+        "NCCL version >= 2.7.3 is needed."));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(recv_v2, ops::RecvOpV2CUDAKernel<float>,
+                        ops::RecvOpV2CUDAKernel<double>,
+                        ops::RecvOpV2CUDAKernel<int>,
+                        ops::RecvOpV2CUDAKernel<int64_t>,
+                        ops::RecvOpV2CUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/recv_v2_op.h b/paddle/fluid/operators/collective/recv_v2_op.h
new file mode 100644
index 00000000000..f9e21003f8f
--- /dev/null
+++ b/paddle/fluid/operators/collective/recv_v2_op.h
@@ -0,0 +1,37 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class RecvOpV2CPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Do not support recv for cpu kernel now."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/collective/send_v2_op.cc b/paddle/fluid/operators/collective/send_v2_op.cc
new file mode 100644
index 00000000000..c5a86b4f088
--- /dev/null
+++ b/paddle/fluid/operators/collective/send_v2_op.cc
@@ -0,0 +1,77 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/send_v2_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SendOpV2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SendV2");
+    int peer = ctx->Attrs().Get<int>("peer");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        peer, 0,
+        platform::errors::InvalidArgument(
+            "The peer (%d) for send_v2 op must be non-negative.", peer));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for send_v2 op must be non-negative.", ring_id));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class SendOpV2Maker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) tensor to be sent.");
+    AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
+        .SetDefault(0);
+    AddAttr<int>("peer", "(int default 0) rank id for receiver.").SetDefault(0);
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default false) eject CUDA operations to calculation stream.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+Send Operator
+
+Reference: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#sendrecv
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_WITHOUT_GRADIENT(send_v2, ops::SendOpV2, ops::SendOpV2Maker);
+
+REGISTER_OP_CPU_KERNEL(send_v2, ops::SendOpV2CPUKernel<float>,
+                       ops::SendOpV2CPUKernel<double>,
+                       ops::SendOpV2CPUKernel<int>,
+                       ops::SendOpV2CPUKernel<int64_t>,
+                       ops::SendOpV2CPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
new file mode 100644
index 00000000000..9f925b2eede
--- /dev/null
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -0,0 +1,90 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/send_v2_op.h"
+
+#if defined(PADDLE_WITH_NCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class SendOpV2CUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    int numel = x->numel();
+
+    int rid = ctx.Attr<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        rid, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for send_v2 op must be non-negative.", rid));
+
+    int peer = ctx.Attr<int>("peer");
+    PADDLE_ENFORCE_GE(
+        peer, 0,
+        platform::errors::InvalidArgument(
+            "The peer (%d) for send_v2 op must be non-negative.", peer));
+    cudaStream_t stream = nullptr;
+    auto place = ctx.GetPlace();
+#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+    } else {
+      stream = comm->stream();
+    }
+    PADDLE_ENFORCE_LT(
+        peer, comm->nranks(),
+        platform::errors::InvalidArgument("The value of peer (%d) you set must "
+                                          "be less than comm->nranks (%d).",
+                                          peer, comm->nranks()));
+    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
+    // Send number of elements to the receiver, as the receiver may have
+    // no information of the Tensor size.
+    int* numel_ptr = nullptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaMemcpy(numel_ptr, &numel, sizeof(int), cudaMemcpyHostToDevice));
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+        numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+        x->data<T>(), numel, dtype, peer, comm->comm(), stream));
+    VLOG(3) << "rank " << comm->rank() << " send "
+            << framework::product(x->dims()) << " to " << peer;
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "PaddlePaddle should be compiled with NCCL "
+        "and NCCL version >= 2.7.3 is needed."));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(send_v2, ops::SendOpV2CUDAKernel<float>,
+                        ops::SendOpV2CUDAKernel<double>,
+                        ops::SendOpV2CUDAKernel<int>,
+                        ops::SendOpV2CUDAKernel<int64_t>,
+                        ops::SendOpV2CUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/send_v2_op.h b/paddle/fluid/operators/collective/send_v2_op.h
new file mode 100644
index 00000000000..6215fb1f3b6
--- /dev/null
+++ b/paddle/fluid/operators/collective/send_v2_op.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class SendOpV2CPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Do not support send for cpu kernel now."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 01c5cfa0aae..6fcc8b96917 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -64,6 +64,7 @@ if(NOT WITH_GPU OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_broadcast)
     LIST(REMOVE_ITEM TEST_OPS test_collective_reduce)
     LIST(REMOVE_ITEM TEST_OPS test_collective_scatter)
+    LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv)
     LIST(REMOVE_ITEM TEST_OPS test_reducescatter)
     LIST(REMOVE_ITEM TEST_OPS test_reducescatter_api)
     LIST(REMOVE_ITEM TEST_OPS test_collective_reduce_api)
diff --git a/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py b/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
new file mode 100644
index 00000000000..0a1967aa658
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import os
+import sys
+import signal
+import time
+import socket
+from contextlib import closing
+from six import string_types
+import math
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+import paddle.fluid.unique_name as nameGen
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import paddle.fluid.layers as layers
+from functools import reduce
+from test_collective_base import TestCollectiveRunnerBase, runtime_main
+
+paddle.enable_static()
+
+
+class TestCollectiveSendRecv(TestCollectiveRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program):
+        ring_id = self.global_ring_id
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = layers.data(
+                name="tindata", shape=[10, 1000], dtype='float64')
+            if self.rank == 0:
+                main_prog.global_block().append_op(
+                    type="send_v2",
+                    inputs={'X': tindata},
+                    attrs={
+                        'ring_id': ring_id,
+                        'peer': 1,
+                        'use_calc_stream': True
+                    })
+            else:
+                main_prog.global_block().append_op(
+                    type="recv_v2",
+                    outputs={'Out': tindata},
+                    attrs={
+                        'peer': 0,
+                        'ring_id': ring_id,
+                        'dtype': tindata.dtype,
+                        'out_shape': tindata.shape,
+                        'use_calc_stream': True,
+                    })
+        return tindata
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveSendRecv, "sendrecv", 0)
diff --git a/python/paddle/fluid/tests/unittests/test_collective_base.py b/python/paddle/fluid/tests/unittests/test_collective_base.py
index 512b2967e02..fc267ed914e 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_base.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_base.py
@@ -103,6 +103,7 @@ class TestCollectiveRunnerBase(object):
         nranks = 2
         self.initCommunicator(startup_prog, rank, nranks, True,
                               current_endpoint, endpoints)
+        self.rank = rank
         result = self.get_model(train_prog, startup_prog)
         device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
         place = fluid.CUDAPlace(
@@ -268,6 +269,11 @@ class TestDistBase(unittest.TestCase):
             self.assertTrue(
                 np.allclose(
                     tr1_out, need_result2, rtol=1e-05, atol=1e-05))
+        elif col_type == "sendrecv":
+            need_result = input1
+            self.assertTrue(
+                np.allclose(
+                    tr1_out, need_result, rtol=1e-05, atol=1e-05))
         elif col_type == "reduce_slicegather":
             slicesize = input1.shape[0] // 2
             tmp10 = input1[0:slicesize]
diff --git a/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
new file mode 100644
index 00000000000..67c84a71bb3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import numpy as np
+import paddle
+
+from test_collective_base import TestDistBase
+
+paddle.enable_static()
+
+
+class TestSendRecvOp(TestDistBase):
+    def _setup_config(self):
+        pass
+
+    def test_sendrecv(self):
+        self.check_with_place("collective_sendrecv_op.py", "sendrecv")
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
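
A minimal standalone driver for the new send_v2/recv_v2 ops, sketched here for context; it is not part of the patch. It mirrors the communicator bootstrap (c_gen_nccl_id followed by c_comm_init) from test_collective_base.py and the append_op pattern from collective_sendrecv_op.py above. The endpoints, ports, script name, and launch command lines are assumptions for illustration, and running it requires Paddle built with NCCL >= 2.7.3 plus two visible GPUs.

# Hypothetical two-rank driver, a minimal sketch (not part of the patch).
# Launch one process per rank, e.g.:
#   PADDLE_TRAINER_ID=0 FLAGS_selected_gpus=0 python sendrecv_demo.py
#   PADDLE_TRAINER_ID=1 FLAGS_selected_gpus=1 python sendrecv_demo.py
# The endpoints and the script name are made-up values for illustration.
import os

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core

paddle.enable_static()

rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
ring_id = 0
endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]  # assumed free ports

main_prog = fluid.Program()
startup_prog = fluid.Program()

# Bootstrap the NCCL communicator in the startup program, using the same
# ops that test_collective_base.py's initCommunicator() appends.
block = startup_prog.global_block()
nccl_id_var = block.create_var(
    name="nccl_id", persistable=True, type=core.VarDesc.VarType.RAW)
block.append_op(
    type='c_gen_nccl_id',
    inputs={},
    outputs={'NCCLID': nccl_id_var},
    attrs={
        'rank': rank,
        'endpoint': endpoints[rank],
        'other_endpoints': [endpoints[1 - rank]]
    })
block.append_op(
    type='c_comm_init',
    inputs={'X': nccl_id_var},
    outputs={},
    attrs={'nranks': 2, 'rank': rank, 'ring_id': ring_id})

# Rank 0 sends its tensor to rank 1; rank 1 receives into the same variable.
with fluid.program_guard(main_prog, startup_prog):
    tindata = fluid.layers.data(
        name="tindata", shape=[10, 1000], dtype='float32')
    if rank == 0:
        main_prog.global_block().append_op(
            type="send_v2",
            inputs={'X': tindata},
            attrs={'ring_id': ring_id, 'peer': 1, 'use_calc_stream': True})
    else:
        main_prog.global_block().append_op(
            type="recv_v2",
            outputs={'Out': tindata},
            attrs={
                'ring_id': ring_id,
                'peer': 0,
                'dtype': tindata.dtype,
                'out_shape': tindata.shape,
                'use_calc_stream': True
            })

place = fluid.CUDAPlace(int(os.getenv("FLAGS_selected_gpus", "0")))
exe = fluid.Executor(place)
exe.run(startup_prog)

# Both ranks feed tindata, as the unit test does; on rank 1 the fed value
# is overwritten by recv_v2, so rank 1 should fetch rank 0's data back.
indata = np.full((10, 1000), rank, dtype='float32')
out = exe.run(main_prog, feed={"tindata": indata}, fetch_list=[tindata.name])
print("rank", rank, "first values:", out[0].flatten()[:4])

The element-count handshake visible in the CUDA kernels is why the receiver tolerates an unknown leading dimension: send_v2 first transmits numel as a single ncclInt, and recv_v2 recomputes dim 0 as numel divided by the product of the trailing dims of out_shape before receiving the payload.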