From efb05ba258c1d5534ec9be56afa28fc84cc01843 Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Tue, 27 Aug 2019 08:56:48 -0500 Subject: [PATCH] supports multiple NCCL communicators preserved in NCCLCommContext (#19407) * supports multiple NCCL communicators preserved in NCCLCommContext test=develop * add ut for c_comm_init_all operator and fix cuda resource release problem test=develop --- .../operators/collective/c_allgather_op.cu.cc | 6 +- .../operators/collective/c_allreduce_op.h | 4 +- .../operators/collective/c_broadcast_op.cu.cc | 12 +-- .../collective/c_comm_init_all_op.cc | 93 +++++++++++++++++++ .../collective/c_reducescatter_op.cu.cc | 6 +- .../collective/c_sync_comm_stream_op.cc | 7 +- paddle/fluid/platform/collective_helper.cc | 82 ++++++++++++---- paddle/fluid/platform/collective_helper.h | 50 +++++++--- .../fluid/tests/unittests/CMakeLists.txt | 1 + .../unittests/test_c_comm_init_all_op.py | 50 ++++++++++ 10 files changed, 259 insertions(+), 52 deletions(-) create mode 100644 paddle/fluid/operators/collective/c_comm_init_all_op.cc create mode 100644 python/paddle/fluid/tests/unittests/test_c_comm_init_all_op.py diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc index 330219cd1f8..14e2741e52e 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc +++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc @@ -35,10 +35,10 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel { int nranks = ctx.Attr("nranks"); int rid = ctx.Attr("ring_id"); - auto comm = platform::NCCLCommContext::Instance().Get(rid); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(rid, place); PADDLE_ENFORCE_EQ(nranks, comm->nranks()); - auto place = ctx.GetPlace(); framework::DDim out_dims = in->dims(); out_dims[0] *= nranks; out->mutable_data(out_dims, place); @@ -55,7 +55,7 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel { stream = comm->stream(); } - PADDLE_ENFORCE(platform::dynload::ncclAllGather( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( send_buff, recv_buff, send_numel, static_cast(dtype), comm->comm(), stream)); #else diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 1db5f15595e..02f6210ca4c 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -70,7 +70,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel { void* recvbuff = out->mutable_data(place); int rid = ctx.Attr("ring_id"); - auto comm = platform::NCCLCommContext::Instance().Get(rid); + auto comm = platform::NCCLCommContext::Instance().Get(rid, place); cudaStream_t stream = nullptr; if (ctx.Attr("use_calc_stream")) { @@ -102,7 +102,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel { PADDLE_THROW("Invalid reduce type: %d", red_type); } - PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream)); #else PADDLE_THROW("PaddlePaddle should compile with GPU."); diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc index c0f5bbd2c2f..a4433d0b3d1 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc @@ -33,9 +33,9 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel { ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); int rid = ctx.Attr("ring_id"); - auto comm = platform::NCCLCommContext::Instance().Get(rid); - auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(rid, place); + cudaStream_t stream = nullptr; if (ctx.Attr("use_calc_stream")) { auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); @@ -46,7 +46,7 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel { int root = ctx.Attr("root"); if (root == comm->rank()) { - PADDLE_ENFORCE(platform::dynload::ncclBcast( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( reinterpret_cast(const_cast(x->data())), numel, dtype, root, comm->comm(), stream)); VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent " @@ -59,9 +59,9 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel { static_cast(out)); } } else { - PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data(place), - numel, dtype, root, - comm->comm(), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclBcast(out->mutable_data(place), numel, + dtype, root, comm->comm(), stream)); VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved " << framework::product(out->dims()); } diff --git a/paddle/fluid/operators/collective/c_comm_init_all_op.cc b/paddle/fluid/operators/collective/c_comm_init_all_op.cc new file mode 100644 index 00000000000..758affbd438 --- /dev/null +++ b/paddle/fluid/operators/collective/c_comm_init_all_op.cc @@ -0,0 +1,93 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include +#endif +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/request_handler_impl.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +class CCommInitAllInferShape : public framework::InferShapeBase { + public: + ~CCommInitAllInferShape() {} + void operator()(framework::InferShapeContext* ctx) const override{}; +}; + +class CCommInitAllOp : public framework::OperatorBase { + public: + CCommInitAllOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + PADDLE_ENFORCE_EQ(is_gpu_place(place), true, + "CCommInitAllOp can run on gpu place only."); + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + std::vector devices = Attr>("devices"); + if (devices.empty()) { + devices = platform::GetSelectedDevices(); + } + + int rid = Attr("ring_id"); + + platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +class CCommInitAllOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddComment(R"DOC( +CCommInitAll operator + +Initialize all collective communicatoin context +)DOC"); + AddAttr>( + "devices", + "(std::vector) which devices does the nccl comm initialized on") + .SetDefault({}); + AddAttr("ring_id", "(int default 0) user specified ring id") + .SetDefault(0); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(c_comm_init_all, ops::CCommInitAllOp, + ops::CCommInitAllInferShape, ops::CCommInitAllOpMaker); diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc index 7244aa949eb..da92b65aa9e 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc @@ -31,10 +31,10 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel { auto out = ctx.Output("Out"); int rid = ctx.Attr("ring_id"); - auto comm = platform::NCCLCommContext::Instance().Get(rid); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(rid, place); int nranks = comm->nranks(); - auto place = ctx.GetPlace(); auto out_dims = in->dims(); out_dims[0] = out_dims[0] / nranks; out->mutable_data(out_dims, place); @@ -52,7 +52,7 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel { stream = comm->stream(); } - PADDLE_ENFORCE(platform::dynload::ncclReduceScatter( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclReduceScatter( send_buff, recv_buff, recv_numel, static_cast(dtype), ncclSum, comm->comm(), stream)); #else diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 5170356165f..320c8507038 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -38,12 +38,13 @@ class CSyncCommStreamOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& place) const override { - PADDLE_ENFORCE(is_gpu_place(place), - "Sync stream op can run on gpu place only for now."); + PADDLE_ENFORCE_EQ(is_gpu_place(place), true, + "Sync stream op can run on gpu place only for now."); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) int ring_id = Attr("ring_id"); - auto stream = platform::NCCLCommContext::Instance().Get(ring_id)->stream(); + auto stream = + platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); cudaError_t e_sync = cudaStreamSynchronize(stream); if (e_sync != 0) { LOG(FATAL) << "Fail to sync nccl stream: " << cudaGetErrorString(e_sync); diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index ddd242cda83..2025e5346f6 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -53,46 +53,88 @@ class NCCLCommImpl : public NCCLComm { std::unique_ptr dev_ctx_; }; -// NOTE: not thread-safe NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank, int dev_id, int ring_id) { PADDLE_ENFORCE_NOT_NULL(nccl_id); PADDLE_ENFORCE_GT(nranks, 1); - PADDLE_ENFORCE(rank >= 0 && rank < nranks, - "Expected rank id range [0, %d), but get %d", nranks, rank); + PADDLE_ENFORCE_GE(rank, 0); + PADDLE_ENFORCE_LT(rank, nranks); PADDLE_ENFORCE_GE(dev_id, 0); - if (dev_ctx_map_.count(dev_id) == 0) { - dev_ctx_map_.emplace(dev_id, std::unique_ptr( - new CUDADeviceContext(CUDAPlace(dev_id)))); - } - ncclComm_t comm = nullptr; - PADDLE_ENFORCE(cudaSetDevice(dev_id)); - PADDLE_ENFORCE( + PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id)); + PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank)); std::unique_ptr dev_ctx( new CUDADeviceContext(CUDAPlace(dev_id))); dev_ctx->set_nccl_comm(comm); - NCCLCommImpl* communicator = new NCCLCommImpl; - communicator->set_ring_id(ring_id); - communicator->set_nranks(nranks); - communicator->set_rank(rank); - communicator->set_dev_ctx(std::move(dev_ctx)); + NCCLCommImpl* c = new NCCLCommImpl; + c->set_ring_id(ring_id); + c->set_nranks(nranks); + c->set_rank(rank); + c->set_dev_ctx(std::move(dev_ctx)); + + comm_map_mutex_.lock(); + if (comm_map_.count(ring_id) == 0) { + comm_map_.emplace(ring_id, std::map>()); + } + auto& dev2comm = comm_map_[ring_id]; - comm_map_.emplace(ring_id, std::unique_ptr(communicator)); + dev2comm.emplace(dev_id, std::unique_ptr(c)); + comm_map_mutex_.unlock(); - VLOG(0) << "nccl communicator of rank " << rank << " in ring " << ring_id + VLOG(1) << "nccl communicator of rank " << rank << " in ring " << ring_id << " has been created"; - return comm_map_.at(ring_id).get(); + std::call_once(once_flag_, []() { + std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); }); + }); + + return comm_map_[ring_id][dev_id].get(); +} + +void NCCLCommContext::CreateAllNCCLComms(const std::vector& dev_ids, + int ring_id) { + PADDLE_ENFORCE_GT(dev_ids.size(), 0); + + const int kDevices = dev_ids.size(); + ncclComm_t comms[kDevices]; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll( + comms, dev_ids.size(), dev_ids.data())); + + PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0); + comm_map_.emplace(ring_id, std::map>()); + + auto& dev2comm = comm_map_[ring_id]; + for (size_t i = 0; i < dev_ids.size(); ++i) { + std::unique_ptr dev_ctx( + new CUDADeviceContext(CUDAPlace(dev_ids[i]))); + dev_ctx->set_nccl_comm(comms[i]); + + NCCLCommImpl* c = new NCCLCommImpl; + c->set_ring_id(ring_id); + c->set_nranks(dev_ids.size()); + c->set_rank(i); + c->set_dev_ctx(std::move(dev_ctx)); + + dev2comm.emplace(dev_ids[i], std::unique_ptr(c)); + } + + std::call_once(once_flag_, []() { + std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); }); + }); } -NCCLCommContext::~NCCLCommContext() { +void NCCLCommContext::ReleaseNCCLComms() { + // CUDADeviceContext maintain the lifetime of nccl_comm_t, so we should not + // destroy nccl_comm_t explicitly. Please refer to + // platform::CUDADeviceContext::~CUDADeviceContext() for (auto& p : comm_map_) { - PADDLE_ENFORCE(platform::dynload::ncclCommDestroy(p.second->comm())); + for (auto& q : p.second) { + q.second.reset(); + } } } diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h index 7479ebaf7d2..747e840037e 100644 --- a/paddle/fluid/platform/collective_helper.h +++ b/paddle/fluid/platform/collective_helper.h @@ -15,9 +15,9 @@ #pragma once #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include #include #include -#include #include #include "boost/variant.hpp" @@ -58,37 +58,57 @@ class NCCLComm { virtual ~NCCLComm() = default; }; -// a singleton NCCL communicator context reserves communication ring ids -// Assume multiprocessing mode +// A singleton NCCL communicator context reserves communication ring ids class NCCLCommContext { public: static NCCLCommContext& Instance() { static NCCLCommContext comm_ctx; return comm_ctx; } - ~NCCLCommContext(); NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank, int dev_id, int ring_id = 0); - // retrieve a communicator by the ring id + void CreateAllNCCLComms(const std::vector& dev_ids, int ring_id = 0); + + // retrieve a communicator by the ring id in multiprocessing mode NCCLComm* Get(int ring_id) const { - PADDLE_ENFORCE(comm_map_.count(ring_id), - "comunicator in ring id %d has not been initialized", - ring_id); - return comm_map_.at(ring_id).get(); + PADDLE_ENFORCE_GT(comm_map_.count(ring_id), 0, + "comunicator in ring id %d has not been initialized", + ring_id); + PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1, + "you should specify a device id to retrieve from " + "multiple communicators"); + return comm_map_.at(ring_id).begin()->second.get(); + } + + // retrieve a communicator by the ring id and the device id + NCCLComm* Get(int ring_id, int dev_id) const { + PADDLE_ENFORCE_GT(comm_map_.count(ring_id), 0, + "comunicator of ring id %d has not been initialized", + ring_id); + PADDLE_ENFORCE_GT( + comm_map_.at(ring_id).count(dev_id), 0, + "comunicator at device id %d has not been initialized in ring %d", + dev_id, ring_id); + return comm_map_.at(ring_id).at(dev_id).get(); + } + + // retrieve a communicator by the ring id and place + NCCLComm* Get(int ring_id, Place place) const { + return Get(ring_id, boost::get(place).device); } private: - // ring id to NCCLComm - std::unordered_map> comm_map_; + std::once_flag once_flag_; + std::mutex comm_map_mutex_; + // ring id to dev-NCCLComm + std::map>> comm_map_; - // device id to CUDADeviceContext - std::unordered_map> dev_ctx_map_; + void ReleaseNCCLComms(); NCCLCommContext() = default; - NCCLCommContext(const NCCLCommContext& other) = delete; - NCCLCommContext& operator=(const NCCLCommContext& other) = delete; + DISABLE_COPY_AND_ASSIGN(NCCLCommContext); }; } // namespace platform diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 7150bf83f9e..c35873dc147 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -28,6 +28,7 @@ endif(NOT WITH_DISTRIBUTE) if(NOT WITH_GPU OR WIN32) + LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op) LIST(REMOVE_ITEM TEST_OPS test_allgather) LIST(REMOVE_ITEM TEST_OPS test_allreduce) LIST(REMOVE_ITEM TEST_OPS test_broadcast) diff --git a/python/paddle/fluid/tests/unittests/test_c_comm_init_all_op.py b/python/paddle/fluid/tests/unittests/test_c_comm_init_all_op.py new file mode 100644 index 00000000000..042f03e19ab --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_c_comm_init_all_op.py @@ -0,0 +1,50 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid.core as core +import paddle.fluid as fluid + + +class TestCCommInitAllOp(unittest.TestCase): + def setUp(self): + self.place = fluid.CUDAPlace(0) + self.exe = fluid.Executor(self.place) + + def test_default_attrs(self): + program = fluid.Program() + block = program.global_block() + block.append_op(type='c_comm_init_all', attrs={'ring_id': 0}) + self.exe.run(program) + + def test_init_with_same_ring_id(self): + program = fluid.Program() + block = program.global_block() + block.append_op(type='c_comm_init_all', attrs={'ring_id': 0}) + with self.assertRaises(core.EnforceNotMet): + self.exe.run(program) + + def test_specifying_devices(self): + program = fluid.Program() + block = program.global_block() + block.append_op( + type='c_comm_init_all', attrs={'devices': [0], + 'ring_id': 1}) + self.exe.run(program) + + +if __name__ == "__main__": + unittest.main() -- GitLab