From 47f51e0784637525b6c15edd740862b65e29b11f Mon Sep 17 00:00:00 2001
From: sandyhouse
Date: Fri, 4 Sep 2020 13:13:43 +0000
Subject: [PATCH] add gather op

---
 .../fluid/operators/collective/c_gather_op.cc | 95 ++++++++++++++++++
 .../operators/collective/c_gather_op.cu.cc    | 99 +++++++++++++++++++
 .../fluid/operators/collective/c_gather_op.h  | 78 +++++++++++++++
 .../fluid/operators/collective/c_recv_op.cc   | 22 ++++-
 .../operators/collective/c_recv_op.cu.cc      | 12 ++-
 .../operators/collective/c_scatter_op.cc      |  7 +-
 .../operators/collective/c_scatter_op.cu.cc   | 39 +++-----
 .../fluid/operators/collective/c_send_op.cc   | 14 ++-
 .../operators/collective/c_send_op.cu.cc      |  6 ++
 .../fluid/tests/unittests/CMakeLists.txt      |  2 +
 .../tests/unittests/collective_alltoall_op.py | 64 ++++++++++++
 .../tests/unittests/collective_gather_op.py   | 66 +++++++++++++
 .../tests/unittests/collective_sendrecv_op.py | 10 +-
 .../unittests/test_collective_alltoall.py     | 31 ++++++
 .../tests/unittests/test_collective_base.py   | 16 +++
 .../tests/unittests/test_collective_gather.py | 31 ++++++
 .../unittests/test_collective_sendrecv.py     |  2 +-
 17 files changed, 555 insertions(+), 39 deletions(-)
 create mode 100644 paddle/fluid/operators/collective/c_gather_op.cc
 create mode 100644 paddle/fluid/operators/collective/c_gather_op.cu.cc
 create mode 100644 paddle/fluid/operators/collective/c_gather_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/collective_alltoall_op.py
 create mode 100644 python/paddle/fluid/tests/unittests/collective_gather_op.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_collective_alltoall.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_collective_gather.py

diff --git a/paddle/fluid/operators/collective/c_gather_op.cc b/paddle/fluid/operators/collective/c_gather_op.cc
new file mode 100644
index 00000000000..2daf69283ec
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_gather_op.cc
@@ -0,0 +1,95 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_gather_op.h"
+
+namespace paddle {
+namespace operators {
+
+class CGatherOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CGather");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CGather");
+    int root_id = ctx->Attrs().Get<int>("root");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    int nranks = ctx->Attrs().Get<int>("nranks");
+    PADDLE_ENFORCE_GE(nranks, 2,
+                      platform::errors::InvalidArgument(
+                          "The number of ranks (%d) must be greater than 1 "
+                          "to use collective op (c_gather op).",
+                          nranks));
+    PADDLE_ENFORCE_GE(
+        root_id, 0,
+        platform::errors::InvalidArgument(
+            "The root_id (%d) for c_gather_op must be non-negative.",
+            root_id));
+    PADDLE_ENFORCE_LT(
+        root_id, nranks,
+        platform::errors::InvalidArgument(
+            "The root_id (%d) for c_gather_op must be less than nranks (%d).",
+            root_id, nranks));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for c_gather_op must be non-negative.",
+            ring_id));
+    framework::DDim dim = ctx->GetInputDim("X");
+    dim[0] = dim[0] * nranks;
+    if (dim[0] < 0) dim[0] = -1;
+    ctx->SetOutputDim("Out", dim);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class CGatherOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) tensor to be gathered.");
+    AddOutput("Out", "(Tensor) the result of gather.");
+    AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
+        .SetDefault(0);
+    AddAttr<int>("root", "(int default 0) root id for the gather.")
+        .SetDefault(0);
+    AddAttr<int>("nranks", "(int default 0) number of ranks.").SetDefault(0);
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default false) eject CUDA operations to calculation stream.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+CGather Operator
+Gather tensors from all participators to the root.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_WITHOUT_GRADIENT(c_gather, ops::CGatherOp, ops::CGatherOpMaker);
+
+REGISTER_OP_CPU_KERNEL(c_gather, ops::CGatherOpCPUKernel<float>,
+                       ops::CGatherOpCPUKernel<double>,
+                       ops::CGatherOpCPUKernel<int>,
+                       ops::CGatherOpCPUKernel<int64_t>,
+                       ops::CGatherOpCPUKernel<plat::float16>);
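
Reviewer note (illustrative only, not part of the patch): the InferShape rule above, out dim[0] = in dim[0] * nranks, matches the following NumPy model of gather semantics. gather_result is a hypothetical helper; equal-shaped per-rank inputs are assumed.

import numpy as np

# Semantic model: every rank contributes a tensor of identical shape
# [d0, ...]; the root ends up with their concatenation along axis 0,
# i.e. shape [d0 * nranks, ...], matching the InferShape rule above.
def gather_result(rank_inputs):
    return np.concatenate(rank_inputs, axis=0)

xs = [np.full((10, 1000), r, dtype=np.float32) for r in range(2)]
out = gather_result(xs)
assert out.shape == (20, 1000)  # dim[0] * nranks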
diff --git a/paddle/fluid/operators/collective/c_gather_op.cu.cc b/paddle/fluid/operators/collective/c_gather_op.cu.cc
new file mode 100644
index 00000000000..84f9e877d41
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_gather_op.cu.cc
@@ -0,0 +1,99 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_gather_op.h"
+
+#if defined(PADDLE_WITH_NCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class CGatherOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#if defined(PADDLE_WITH_NCCL)
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    int send_numel = x->numel();
+    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
+
+    int nranks = ctx.Attr<int>("nranks");
+    int root_id = ctx.Attr<int>("root");
+    int ring_id = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
+    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
+    PADDLE_ENFORCE_EQ(nranks, comm->nranks(),
+                      platform::errors::InvalidArgument(
+                          "The number of ranks (%d) you set must "
+                          "be equal to comm->nranks (%d).",
+                          nranks, comm->nranks()));
+    PADDLE_ENFORCE_GE(
+        root_id, 0,
+        platform::errors::InvalidArgument(
+            "The root_id (%d) for c_gather_op must be non-negative.",
+            root_id));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for c_gather_op must be non-negative.",
+            ring_id));
+
+    cudaStream_t stream = nullptr;
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+    } else {
+      stream = comm->stream();
+    }
+
+    framework::DDim x_dims = x->dims();
+    framework::DDim out_dims(x_dims);
+    out_dims[0] *= nranks;
+    auto send_buf = x->data<T>();
+    auto offset = 0;
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+        send_buf, send_numel, dtype, root_id, comm->comm(), stream));
+    if (root_id == comm->rank()) {
+      auto recv_buf = out->mutable_data<T>(out_dims, place);
+      for (auto i = 0; i < nranks; ++i) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+            recv_buf + offset, send_numel, dtype, i, comm->comm(), stream));
+        offset += send_numel;
+      }
+    }
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
+#else
+    PADDLE_ENFORCE_EQ(
+        true, false,
+        platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(c_gather, ops::CGatherOpCUDAKernel<float>,
+                        ops::CGatherOpCUDAKernel<double>,
+                        ops::CGatherOpCUDAKernel<int>,
+                        ops::CGatherOpCUDAKernel<int64_t>,
+                        ops::CGatherOpCUDAKernel<plat::float16>);
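
Reviewer note (illustrative only, not part of the patch): the root's receive loop above places rank i's contribution at flat offset i * send_numel, so reshaping the receive buffer to [d0 * nranks, ...] puts rank i's rows in block i. A minimal NumPy simulation of that bookkeeping, assuming equal-sized contributions; simulate_root_recv is a hypothetical helper:

import numpy as np

# Mirror the kernel: the root "receives" rank i's flat buffer at
# offset i * send_numel in recv_buf, in rank order 0..nranks-1.
def simulate_root_recv(send_bufs):
    send_numel = send_bufs[0].size
    recv_buf = np.empty(send_numel * len(send_bufs),
                        dtype=send_bufs[0].dtype)
    offset = 0
    for buf in send_bufs:  # one ncclRecv per rank, in rank order
        recv_buf[offset:offset + send_numel] = buf.ravel()
        offset += send_numel
    return recv_buf

bufs = [np.full((10, 1000), r, dtype=np.float32) for r in range(2)]
flat = simulate_root_recv(bufs)
assert (flat.reshape(20, 1000)[10:] == 1).all()  # rank 1's block is second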
diff --git a/paddle/fluid/operators/collective/c_gather_op.h b/paddle/fluid/operators/collective/c_gather_op.h
new file mode 100644
index 00000000000..ca4ed7510c6
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_gather_op.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+#if defined(PADDLE_WITH_GLOO)
+#include <gloo/gather.h>
+#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class CGatherOpCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#if defined(PADDLE_WITH_GLOO)
+    auto in = ctx.Input<framework::Tensor>("X");
+    auto out = ctx.Output<framework::Tensor>("Out");
+    auto root_id = ctx.Attr<int>("root");
+    auto nranks = ctx.Attr<int>("nranks");
+    auto place = ctx.GetPlace();
+
+    auto gloo = paddle::framework::GlooWrapper::GetInstance();
+    PADDLE_ENFORCE_EQ(
+        gloo->IsInitialized(), true,
+        platform::errors::PreconditionNotMet(
+            "You must initialize the gloo environment first to use it."));
+    PADDLE_ENFORCE_EQ(nranks, gloo->Size(),
+                      platform::errors::InvalidArgument(
+                          "The number of ranks (%d) you set must "
+                          "be equal to gloo->Size() (%d).",
+                          nranks, gloo->Size()));
+    int64_t send_numel = in->numel();
+    int64_t recv_numel = send_numel * nranks;
+    auto in_dim = in->dims();
+    auto out_dim = framework::DDim(in_dim);
+    out_dim[0] *= nranks;
+    auto rank = gloo->Rank();
+    T* send_buff = const_cast<T*>(in->data<T>());
+    gloo::GatherOptions opts(gloo->GetContext());
+    if (root_id == rank) {
+      T* recv_buff = out->mutable_data<T>(out_dim, place);
+      opts.setOutput(recv_buff, recv_numel);
+    }
+    opts.setInput(send_buff, send_numel);
+    opts.setRoot(root_id);
+
+    gloo::gather(opts);
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

diff --git a/paddle/fluid/operators/collective/c_recv_op.cc b/paddle/fluid/operators/collective/c_recv_op.cc
index ed786974432..a3d59648e09 100644
--- a/paddle/fluid/operators/collective/c_recv_op.cc
+++ b/paddle/fluid/operators/collective/c_recv_op.cc
@@ -21,20 +21,34 @@ class CRecvOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {}
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CRecv");
+    int peer = ctx->Attrs().Get<int>("peer");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        peer, 0,
+        platform::errors::InvalidArgument(
+            "The peer (%d) for c_recv_op must be non-negative.", peer));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for c_recv_op must be non-negative.", ring_id));
+  }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    auto dtype = out->type();
+    return framework::OpKernelType(dtype, ctx.GetPlace());
   }
 };
 
 class CRecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
-    AddInput("Out", "(Tensor) tensor to receive.");
+    AddOutput("Out", "(Tensor) tensor to receive.");
     AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
         .SetDefault(0);
     AddAttr<int>("peer", "(int default 0) rank id for sender.").SetDefault(0);
diff --git a/paddle/fluid/operators/collective/c_recv_op.cu.cc b/paddle/fluid/operators/collective/c_recv_op.cu.cc
index 69f5d5beb9d..4a716ab61b5 100644
--- a/paddle/fluid/operators/collective/c_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_recv_op.cu.cc
@@ -27,7 +27,7 @@ class CRecvOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_NCCL)
-    auto out = ctx.Input<framework::LoDTensor>("Out");
+    auto out = ctx.Output<framework::LoDTensor>("Out");
     int numel = out->numel();
     ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
 
@@ -44,9 +44,13 @@ class CRecvOpCUDAKernel : public framework::OpKernel<T> {
     }
 
     int peer = ctx.Attr<int>("peer");
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::ncclRecv(const_cast<T*>(out->data<T>()), numel,
-                                    dtype, peer, comm->comm(), stream));
+    PADDLE_ENFORCE_LT(
+        peer, comm->nranks(),
+        platform::errors::InvalidArgument("The value of peer (%d) you set must "
+                                          "be less than comm->nranks (%d).",
+                                          peer, comm->nranks()));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+        out->data<T>(), numel, dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " recv "
             << framework::product(out->dims()) << " from " << peer;
 #else

diff --git a/paddle/fluid/operators/collective/c_scatter_op.cc b/paddle/fluid/operators/collective/c_scatter_op.cc
index 908708e6e32..2567febc65d 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cc
+++ b/paddle/fluid/operators/collective/c_scatter_op.cc
@@ -37,11 +37,16 @@ class CScatterOp : public framework::OperatorWithKernel {
         platform::errors::InvalidArgument(
             "The root_id (%d) for c_scatter_op must be non-negative.",
             root_id));
+    PADDLE_ENFORCE_LT(root_id, nranks,
+                      platform::errors::InvalidArgument(
+                          "The root_id (%d) for c_scatter_op must be less "
+                          "than the number of ranks (%d).",
+                          root_id, nranks));
     PADDLE_ENFORCE_GE(
         ring_id, 0,
         platform::errors::InvalidArgument(
             "The ring_id (%d) for c_scatter_op must be non-negative.",
-            root_id));
+            ring_id));
     framework::DDim dim = ctx->GetInputDim("X");
     dim[0] = dim[0] / nranks;
     if (dim[0] < 0) dim[0] = -1;
diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
index 8d9e6b4b7d9..4209ff43a8a 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
@@ -39,7 +39,7 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     PADDLE_ENFORCE_EQ(nranks, comm->nranks(),
                       platform::errors::InvalidArgument(
-                          "The number of ranks (%d) you set of must "
+                          "The number of ranks (%d) you set must "
                           "be equal to comm->nranks (%d).",
                           nranks, comm->nranks()));
     PADDLE_ENFORCE_GE(
@@ -63,30 +63,23 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
 
     framework::DDim x_dims = x->dims();
     framework::DDim out_dims(x_dims);
-    framework::Tensor temp;
-    auto out_ptr = temp.mutable_data<T>(out_dims, place);
+    out_dims[0] /= nranks;
+    auto send_buf = x->data<T>();
+    auto send_numel = numel / nranks;
+    auto recv_buf = out->mutable_data<T>(out_dims, place);
+    auto offset = 0;
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
     if (root_id == comm->rank()) {
-      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
-          reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
-          root_id, comm->comm(), stream));
-
-      framework::TensorCopy(*static_cast<const framework::Tensor*>(x), place,
-                            *platform::DeviceContextPool::Instance().Get(place),
-                            static_cast<framework::Tensor*>(&temp));
-    } else {
-      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
-          out_ptr, numel, dtype, root_id, comm->comm(), stream));
+      for (auto i = 0; i < nranks; ++i) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            platform::dynload::ncclSend(send_buf + offset, send_numel, dtype,
+                                        i, comm->comm(), stream));
+        offset += send_numel;
+      }
     }
-
-    out_dims[0] = out_dims[0] / nranks;
-    auto start_index = out_dims[0] * comm->rank();
-    auto end_index = start_index + out_dims[0];
-    temp = temp.Slice(start_index, end_index);
-    temp.Resize(out_dims);
-    out->mutable_data<T>(out_dims, place);
-    framework::TensorCopySync(*static_cast<const framework::Tensor*>(&temp),
-                              place, static_cast<framework::Tensor*>(out));
-    out->Resize(out_dims);
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+        recv_buf, send_numel, dtype, root_id, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
 #else
     PADDLE_ENFORCE_EQ(
         true, false,

diff --git a/paddle/fluid/operators/collective/c_send_op.cc b/paddle/fluid/operators/collective/c_send_op.cc
index 54c4b86bbdb..70e93f8d73c 100644
--- a/paddle/fluid/operators/collective/c_send_op.cc
+++ b/paddle/fluid/operators/collective/c_send_op.cc
@@ -21,7 +21,19 @@ class CSendOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {}
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CSend");
+    int peer = ctx->Attrs().Get<int>("peer");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        peer, 0,
+        platform::errors::InvalidArgument(
+            "The peer (%d) for c_send_op must be non-negative.", peer));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for c_send_op must be non-negative.", ring_id));
+  }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(

diff --git a/paddle/fluid/operators/collective/c_send_op.cu.cc b/paddle/fluid/operators/collective/c_send_op.cu.cc
index 97b62849ca6..44ef0fa019d 100644
--- a/paddle/fluid/operators/collective/c_send_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_send_op.cu.cc
@@ -44,6 +44,12 @@ class CSendOpCUDAKernel : public framework::OpKernel<T> {
     }
 
     int peer = ctx.Attr<int>("peer");
+    PADDLE_ENFORCE_LT(
+        peer, comm->nranks(),
+        platform::errors::InvalidArgument("The value of peer (%d) you set must "
+                                          "be less than comm->nranks (%d).",
+                                          peer, comm->nranks()));
+
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
         x->data<T>(), numel, dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " send "

diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index cdd94a1fa03..df22fa873cd 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -58,6 +58,8 @@ if(NOT WITH_GPU OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_broadcast)
     LIST(REMOVE_ITEM TEST_OPS test_collective_reduce)
     LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv)
+    LIST(REMOVE_ITEM TEST_OPS test_collective_gather)
+    LIST(REMOVE_ITEM TEST_OPS test_collective_alltoall)
     LIST(REMOVE_ITEM TEST_OPS test_collective_scatter)
     LIST(REMOVE_ITEM TEST_OPS test_collective_reduce_api)
     LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api)
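
Reviewer note (illustrative only, not part of the patch): the scatter kernel rewritten above replaces the old bcast-plus-slice path with one point-to-point send per rank from the root. A minimal NumPy model of the intended semantics, assuming dim 0 divides evenly by nranks as InferShape requires; simulate_scatter is a hypothetical helper.

import numpy as np

# Model of the rewritten kernel: the root splits its input into nranks
# equal chunks along axis 0 and sends chunk i to rank i; every rank
# (root included) receives exactly one chunk of shape [d0 / nranks, ...].
def simulate_scatter(x, nranks):
    return np.split(x, nranks, axis=0)

x = np.arange(8, dtype=np.float32).reshape(4, 2)  # input on the root
chunks = simulate_scatter(x, nranks=2)
assert chunks[0].shape == (2, 2)   # dim[0] / nranks, per InferShape
assert (chunks[1] == x[2:]).all()  # rank 1 receives the second chunk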
diff --git a/python/paddle/fluid/tests/unittests/collective_alltoall_op.py b/python/paddle/fluid/tests/unittests/collective_alltoall_op.py
new file mode 100644
index 00000000000..c4eb7096b64
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_alltoall_op.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import os
+import sys
+import signal
+import time
+import socket
+from contextlib import closing
+from six import string_types
+import math
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+import paddle.fluid.unique_name as nameGen
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import paddle.fluid.layers as layers
+from functools import reduce
+from test_collective_base import TestCollectiveRunnerBase, runtime_main
+
+
+class TestCollectiveAllToAll(TestCollectiveRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program, rank=None):
+        ring_id = 0
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = layers.data(
+                name="tindata", shape=[10, 1000], dtype='float32')
+            toutdata = layers.data(
+                name="toutdata", shape=[10, 1000], dtype='float32')
+            main_prog.global_block().append_op(
+                type="c_alltoall",
+                inputs={'X': tindata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id})
+            main_prog.global_block().append_op(
+                type="c_sync_comm_stream",
+                inputs={'X': toutdata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id})
+            return toutdata
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveAllToAll, "alltoall", 0)
diff --git a/python/paddle/fluid/tests/unittests/collective_gather_op.py b/python/paddle/fluid/tests/unittests/collective_gather_op.py
new file mode 100644
index 00000000000..9991694722b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_gather_op.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import os
+import sys
+import signal
+import time
+import socket
+from contextlib import closing
+from six import string_types
+import math
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+import paddle.fluid.unique_name as nameGen
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import paddle.fluid.layers as layers
+from functools import reduce
+from test_collective_base import TestCollectiveRunnerBase, runtime_main
+
+
+class TestCollectiveGather(TestCollectiveRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program, rank=None):
+        ring_id = 0
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = layers.data(
+                name="tindata", shape=[10, 1000], dtype='float32')
+            toutdata = layers.data(
+                name="toutdata", shape=[20, 1000], dtype='float32')
+            main_prog.global_block().append_op(
+                type="c_gather",
+                inputs={'X': tindata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id,
+                       'nranks': 2,
+                       'root': 1})
+            main_prog.global_block().append_op(
+                type="c_sync_comm_stream",
+                inputs={'X': toutdata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id})
+            return toutdata
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveGather, "gather", 0)

diff --git a/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py b/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
index 822bd3c4293..a97f6866ba8 100644
--- a/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
+++ b/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
@@ -36,7 +36,7 @@ from functools import reduce
 from test_collective_base import TestCollectiveRunnerBase, runtime_main
 
 
-class TestCollectiveScatter(TestCollectiveRunnerBase):
+class TestCollectiveSendRecv(TestCollectiveRunnerBase):
     def __init__(self):
         self.global_ring_id = 0
 
@@ -48,7 +48,7 @@ class TestCollectiveScatter(TestCollectiveRunnerBase):
         if rank == 0:
             main_prog.global_block().append_op(
                 type="c_recv",
-                inputs={'Out': tindata},
+                outputs={'Out': tindata},
                 attrs={'ring_id': ring_id,
                        'peer': 1})
         else:
@@ -59,11 +59,11 @@ class TestCollectiveScatter(TestCollectiveRunnerBase):
                        'peer': 0})
         main_prog.global_block().append_op(
             type="c_sync_comm_stream",
-            inputs={'X': toutdata},
-            outputs={'Out': toutdata},
+            inputs={'X': tindata},
+            outputs={'Out': tindata},
             attrs={'ring_id': ring_id})
         return tindata
 
 
 if __name__ == "__main__":
-    runtime_main(TestCollectiveScatter, "scatter", 0)
+    runtime_main(TestCollectiveSendRecv, "sendrecv", 0)

diff --git a/python/paddle/fluid/tests/unittests/test_collective_alltoall.py b/python/paddle/fluid/tests/unittests/test_collective_alltoall.py
new file mode 100644
index 00000000000..c35e5e89e6e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_collective_alltoall.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import numpy as np
+
+from test_collective_base import TestDistBase
+
+
+class TestCAllToAllOp(TestDistBase):
+    def _setup_config(self):
+        pass
+
+    def test_alltoall(self):
+        self.check_with_place("collective_alltoall_op.py", "alltoall")
+
+
+if __name__ == '__main__':
+    unittest.main()
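
Reviewer note (illustrative only, not part of the patch): the expected outputs asserted for the alltoall case in test_collective_base.py below follow from this NumPy model of two-rank alltoall; simulate_alltoall is a hypothetical helper.

import numpy as np

# Model of alltoall: each rank splits its tensor into nranks chunks
# along axis 0 and sends chunk j to rank j; rank j stacks what it
# receives in rank order.
def simulate_alltoall(rank_inputs):
    nranks = len(rank_inputs)
    parts = [np.split(x, nranks, axis=0) for x in rank_inputs]
    return [np.vstack([parts[src][dst] for src in range(nranks)])
            for dst in range(nranks)]

input1 = np.random.random((10, 1000)).astype(np.float32)
input2 = np.random.random((10, 1000)).astype(np.float32)
out0, out1 = simulate_alltoall([input1, input2])
temp11, temp12 = np.split(input1, 2)
temp21, temp22 = np.split(input2, 2)
assert np.allclose(out0, np.vstack((temp11, temp21)))
assert np.allclose(out1, np.vstack((temp12, temp22)))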
diff --git a/python/paddle/fluid/tests/unittests/test_collective_base.py b/python/paddle/fluid/tests/unittests/test_collective_base.py
index e97be6b7a55..2d13991e82a 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_base.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_base.py
@@ -261,6 +261,22 @@ class TestDistBase(unittest.TestCase):
         elif col_type == "sendrecv":
             need_result = input2
             self.assertTrue(np.allclose(tr0_out, need_result))
+        elif col_type == "gather":
+            need_result = np.vstack((input1, input2))
+            self.assertTrue(np.allclose(tr1_out, need_result))
+        elif col_type == "alltoall":
+            temp11, temp12 = np.split(input1, 2)
+            temp21, temp22 = np.split(input2, 2)
+            need_result1 = np.vstack((temp11, temp21))
+            need_result2 = np.vstack((temp12, temp22))
+            print("input1:", input1)
+            print("input2:", input2)
+            print("need_result1:", need_result1)
+            print("need_result2:", need_result2)
+            print("tr0_out:", tr0_out)
+            print("tr1_out:", tr1_out)
+            self.assertTrue(np.allclose(tr0_out, need_result1))
+            self.assertTrue(np.allclose(tr1_out, need_result2))
         elif col_type == "reduce_scatter":
             tmp = input1 + input2
             need_result1 = tmp[0:tmp.shape[0] // 2]

diff --git a/python/paddle/fluid/tests/unittests/test_collective_gather.py b/python/paddle/fluid/tests/unittests/test_collective_gather.py
new file mode 100644
index 00000000000..1ce1cb8e7b8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_collective_gather.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import numpy as np
+
+from test_collective_base import TestDistBase
+
+
+class TestCGatherOp(TestDistBase):
+    def _setup_config(self):
+        pass
+
+    def test_gather(self):
+        self.check_with_place("collective_gather_op.py", "gather")
+
+
+if __name__ == '__main__':
+    unittest.main()

diff --git a/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
index ed6596faa49..5abfb218430 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
@@ -19,7 +19,7 @@ import numpy as np
 from test_collective_base import TestDistBase
 
 
-class TestCScatterOp(TestDistBase):
+class TestCSendRecvOp(TestDistBase):
     def _setup_config(self):
         pass
--
GitLab