From 27f245cdefd784e7e9e07d4976e3ae96f85b3fde Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Fri, 4 Sep 2020 14:28:08 +0000 Subject: [PATCH] fix alltoall ut, test=develop --- .../operators/collective/c_alltoall_op.cc | 77 +++++++++++++++++ .../operators/collective/c_alltoall_op.cu.cc | 86 +++++++++++++++++++ .../operators/collective/c_alltoall_op.h | 68 +++++++++++++++ .../tests/unittests/test_collective_base.py | 14 +-- 4 files changed, 235 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/operators/collective/c_alltoall_op.cc create mode 100644 paddle/fluid/operators/collective/c_alltoall_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_alltoall_op.h diff --git a/paddle/fluid/operators/collective/c_alltoall_op.cc b/paddle/fluid/operators/collective/c_alltoall_op.cc new file mode 100644 index 00000000000..60cdb50b8fd --- /dev/null +++ b/paddle/fluid/operators/collective/c_alltoall_op.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_alltoall_op.h" + +namespace paddle { +namespace operators { + +class CAllToAllOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CAllToAll"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CAllToAll"); + int ring_id = ctx->Attrs().Get("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for c_scatter_op must be non-negative.", + ring_id)); + framework::DDim dim = ctx->GetInputDim("X"); + if (dim[0] < 0) dim[0] = -1; + ctx->SetOutputDim("Out", dim); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class CAllToAllOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) tensor send."); + AddOutput("Out", "(Tensor) the result of alltoall."); + AddAttr("ring_id", "(int default 0) nccl communication ring id.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(R"DOC( +CAllToAll Operator +Gather tensors from all participators to all participators. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_alltoall, ops::CAllToAllOp, + ops::CAllToAllOpMaker); + +REGISTER_OP_CPU_KERNEL(c_alltoall, ops::CAllToAllOpCPUKernel, + ops::CAllToAllOpCPUKernel, + ops::CAllToAllOpCPUKernel, + ops::CAllToAllOpCPUKernel, + ops::CAllToAllOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_alltoall_op.cu.cc b/paddle/fluid/operators/collective/c_alltoall_op.cu.cc new file mode 100644 index 00000000000..2bbb5b53a1a --- /dev/null +++ b/paddle/fluid/operators/collective/c_alltoall_op.cu.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_alltoall_op.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CAllToAllOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + int send_numel = x->numel(); + ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); + + int ring_id = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + int nranks = comm->nranks(); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for c_scatter_op must be non-negative.", + ring_id)); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + framework::DDim x_dims = x->dims(); + framework::DDim out_dims(x_dims); + auto send_buf = x->data(); + auto recv_buf = out->mutable_data(out_dims, place); + size_t offset = 0; + send_numel /= nranks; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < nranks; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend( + send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv( + recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + offset += send_numel; + } + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); +#else + PADDLE_ENFORCE_EQ( + true, false, + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(c_alltoall, ops::CAllToAllOpCUDAKernel, + ops::CAllToAllOpCUDAKernel, + ops::CAllToAllOpCUDAKernel, + ops::CAllToAllOpCUDAKernel, + ops::CAllToAllOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/c_alltoall_op.h b/paddle/fluid/operators/collective/c_alltoall_op.h new file mode 100644 index 00000000000..f850be7cac3 --- /dev/null +++ b/paddle/fluid/operators/collective/c_alltoall_op.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_GLOO) +#include +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CAllToAllOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_GLOO) + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + auto root_id = ctx.Attr("root"); + + auto gloo = paddle::framework::GlooWrapper::GetInstance(); + PADDLE_ENFORCE_EQ( + gloo->IsInitialized(), true, + platform::errors::PreconditionNotMet( + "You must initialize the gloo environment first to use it.")); + + int64_t send_numel = in->numel(); + int64_t recv_numel = out->numel(); + auto nranks = gloo->Size(); + auto rank = gloo->Rank(); + T* recv_buff = out->data(); + T* send_buff = in->data(); + gloo::GatherOptions opts(gloo->GetContext()); + opts.setOutput(recv_buff, recv_numel); + opts.setInput(send_buff, send_numel); + opts.setRoot(root_id); + + gloo::alltoall(opts); +#else + PADDLE_THROW(platform::errors::Unavailable( + "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON")); +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_collective_base.py b/python/paddle/fluid/tests/unittests/test_collective_base.py index 2d13991e82a..adfc31a5159 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_base.py @@ -267,16 +267,10 @@ class TestDistBase(unittest.TestCase): elif col_type == "alltoall": temp11, temp12 = np.split(input1, 2) temp21, temp22 = np.split(input2, 2) - need_result1 = np.hstack((temp11, temp21)) - need_result2 = np.hstack((temp12, temp22)) - print("input1:", input1) - print("input2:", input2) - print("need_result1:", need_result1) - print("need_result2:", need_result2) - print("tr0_out:", tr0_out) - print("tr1_out:", tr1_out) - self.assertTrue(np.allclose(tr1_out, need_result1)) - self.assertTrue(np.allclose(tr2_out, need_result2)) + need_result1 = np.vstack((temp11, temp21)) + need_result2 = np.vstack((temp12, temp22)) + self.assertTrue(np.allclose(tr0_out, need_result1)) + self.assertTrue(np.allclose(tr1_out, need_result2)) elif col_type == "reduce_scatter": tmp = input1 + input2 need_result1 = tmp[0:tmp.shape[0] // 2] -- GitLab