Unverified commit efb05ba2, authored by Yi Liu, committed by GitHub

supports multiple NCCL communicators preserved in NCCLCommContext (#19407)

* supports multiple NCCL communicators preserved in NCCLCommContext
test=develop

* add ut for c_comm_init_all operator and fix cuda resource release problem
test=develop
Parent 56dd7653
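The core of the change: `comm_map_` becomes a two-level map keyed first by ring id and then by device id, so a single process can hold one communicator per device within the same ring. A minimal self-contained sketch of that registry shape (names like `Registry` and `FakeComm` are illustrative only, not Paddle's API):

```cpp
// Sketch of a two-level communicator registry: ring id -> device id -> comm.
#include <iostream>
#include <map>
#include <memory>

struct FakeComm {  // stand-in for a per-device NCCL communicator
  int ring_id, dev_id;
};

class Registry {
 public:
  static Registry& Instance() {
    static Registry r;
    return r;
  }

  void Create(int ring_id, int dev_id) {
    comm_map_[ring_id][dev_id] =
        std::unique_ptr<FakeComm>(new FakeComm{ring_id, dev_id});
  }

  // Rings holding a single communicator may omit the device id.
  FakeComm* Get(int ring_id) const {
    return comm_map_.at(ring_id).begin()->second.get();
  }

  FakeComm* Get(int ring_id, int dev_id) const {
    return comm_map_.at(ring_id).at(dev_id).get();
  }

 private:
  std::map<int, std::map<int, std::unique_ptr<FakeComm>>> comm_map_;
};

int main() {
  Registry::Instance().Create(/*ring_id=*/0, /*dev_id=*/0);
  Registry::Instance().Create(/*ring_id=*/0, /*dev_id=*/1);
  std::cout << Registry::Instance().Get(0, 1)->dev_id << "\n";  // prints 1
}
```

The single-argument `Get(ring_id)` keeps the old one-communicator-per-ring call sites working, while `Get(ring_id, dev_id)` disambiguates when one process drives several devices.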
......@@ -35,10 +35,10 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
     int nranks = ctx.Attr<int>("nranks");
     int rid = ctx.Attr<int>("ring_id");
-    auto comm = platform::NCCLCommContext::Instance().Get(rid);
+    auto place = ctx.GetPlace();
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     PADDLE_ENFORCE_EQ(nranks, comm->nranks());
-    auto place = ctx.GetPlace();
     framework::DDim out_dims = in->dims();
     out_dims[0] *= nranks;
     out->mutable_data<T>(out_dims, place);
......@@ -55,7 +55,7 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
       stream = comm->stream();
     }
-    PADDLE_ENFORCE(platform::dynload::ncclAllGather(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
         send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
         comm->comm(), stream));
 #else
......
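Several hunks in this commit make the same mechanical swap from `PADDLE_ENFORCE` to `PADDLE_ENFORCE_CUDA_SUCCESS` around CUDA and NCCL calls. As a rough standalone illustration of what such a status-checking macro does (the `CHECK_CUDA` name here is hypothetical, not Paddle's macro):

```cpp
// Illustrative CUDA status-checking macro, analogous in spirit to
// PADDLE_ENFORCE_CUDA_SUCCESS; build against the CUDA runtime.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_CUDA(call)                                           \
  do {                                                             \
    cudaError_t err__ = (call);                                    \
    if (err__ != cudaSuccess) {                                    \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",             \
                   cudaGetErrorString(err__), __FILE__, __LINE__); \
      std::exit(EXIT_FAILURE);                                     \
    }                                                              \
  } while (0)

int main() {
  void* p = nullptr;
  CHECK_CUDA(cudaMalloc(&p, 1024));  // aborts with a message on failure
  CHECK_CUDA(cudaFree(p));
  return 0;
}
```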
......@@ -70,7 +70,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
     void* recvbuff = out->mutable_data<T>(place);
     int rid = ctx.Attr<int>("ring_id");
-    auto comm = platform::NCCLCommContext::Instance().Get(rid);
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     cudaStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
......@@ -102,7 +102,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
       PADDLE_THROW("Invalid reduce type: %d", red_type);
     }
-    PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
 #else
     PADDLE_THROW("PaddlePaddle should compile with GPU.");
......
......@@ -33,9 +33,9 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
     int rid = ctx.Attr<int>("ring_id");
-    auto comm = platform::NCCLCommContext::Instance().Get(rid);
     auto place = ctx.GetPlace();
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     cudaStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
......@@ -46,7 +46,7 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     int root = ctx.Attr<int>("root");
     if (root == comm->rank()) {
-      PADDLE_ENFORCE(platform::dynload::ncclBcast(
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
           reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
           root, comm->comm(), stream));
       VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
......@@ -59,9 +59,9 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
           static_cast<framework::Tensor*>(out));
       }
     } else {
-      PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data<T>(place),
-                                                  numel, dtype, root,
-                                                  comm->comm(), stream));
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::ncclBcast(out->mutable_data<T>(place), numel,
+                                       dtype, root, comm->comm(), stream));
       VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
               << framework::product(out->dims());
     }
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

class CCommInitAllInferShape : public framework::InferShapeBase {
 public:
  ~CCommInitAllInferShape() {}
  void operator()(framework::InferShapeContext* ctx) const override{};
};

class CCommInitAllOp : public framework::OperatorBase {
 public:
  CCommInitAllOp(const std::string& type,
                 const framework::VariableNameMap& inputs,
                 const framework::VariableNameMap& outputs,
                 const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
                      "CCommInitAllOp can run on gpu place only.");

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    std::vector<int> devices = Attr<std::vector<int>>("devices");
    if (devices.empty()) {
      devices = platform::GetSelectedDevices();
    }

    int rid = Attr<int>("ring_id");
    platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid);
#else
    PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
  }
};

class CCommInitAllOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddComment(R"DOC(
CCommInitAll operator

Initialize all collective communication contexts.
)DOC");
    AddAttr<std::vector<int>>(
        "devices",
        "(std::vector<int>) which devices the nccl comm is initialized on")
        .SetDefault({});
    AddAttr<int>("ring_id", "(int default 0) user specified ring id")
        .SetDefault(0);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(c_comm_init_all, ops::CCommInitAllOp,
                  ops::CCommInitAllInferShape, ops::CCommInitAllOpMaker);
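`CreateAllNCCLComms`, which this operator calls (see the collective_helper.cc hunk further down), is a thin wrapper over NCCL's single-process initializer `ncclCommInitAll`. Stripped of the Paddle plumbing, the underlying call looks roughly like this (a sketch assuming an NCCL installation and that GPUs 0 and 1 exist; error handling omitted for brevity):

```cpp
// Standalone sketch of single-process, multi-GPU communicator setup.
#include <vector>
#include <nccl.h>

int main() {
  std::vector<int> dev_ids = {0, 1};  // assumed visible devices
  std::vector<ncclComm_t> comms(dev_ids.size());

  // One call creates a clique of communicators, one per listed device.
  ncclCommInitAll(comms.data(), static_cast<int>(dev_ids.size()),
                  dev_ids.data());

  for (auto& comm : comms) ncclCommDestroy(comm);
  return 0;
}
```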
......@@ -31,10 +31,10 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
     auto out = ctx.Output<framework::Tensor>("Out");
     int rid = ctx.Attr<int>("ring_id");
-    auto comm = platform::NCCLCommContext::Instance().Get(rid);
+    auto place = ctx.GetPlace();
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     int nranks = comm->nranks();
-    auto place = ctx.GetPlace();
     auto out_dims = in->dims();
     out_dims[0] = out_dims[0] / nranks;
     out->mutable_data<T>(out_dims, place);
......@@ -52,7 +52,7 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
       stream = comm->stream();
     }
-    PADDLE_ENFORCE(platform::dynload::ncclReduceScatter(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclReduceScatter(
         send_buff, recv_buff, recv_numel, static_cast<ncclDataType_t>(dtype),
         ncclSum, comm->comm(), stream));
 #else
......
......@@ -38,12 +38,13 @@ class CSyncCommStreamOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
-    PADDLE_ENFORCE(is_gpu_place(place),
-                   "Sync stream op can run on gpu place only for now.");
+    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
+                      "Sync stream op can run on gpu place only for now.");
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     int ring_id = Attr<int>("ring_id");
-    auto stream = platform::NCCLCommContext::Instance().Get(ring_id)->stream();
+    auto stream =
+        platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
     cudaError_t e_sync = cudaStreamSynchronize(stream);
     if (e_sync != 0) {
       LOG(FATAL) << "Fail to sync nccl stream: " << cudaGetErrorString(e_sync);
......
......@@ -53,46 +53,88 @@ class NCCLCommImpl : public NCCLComm {
   std::unique_ptr<CUDADeviceContext> dev_ctx_;
 };

 // NOTE: not thread-safe
 NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
                                           int rank, int dev_id, int ring_id) {
   PADDLE_ENFORCE_NOT_NULL(nccl_id);
   PADDLE_ENFORCE_GT(nranks, 1);
-  PADDLE_ENFORCE(rank >= 0 && rank < nranks,
-                 "Expected rank id range [0, %d), but get %d", nranks, rank);
+  PADDLE_ENFORCE_GE(rank, 0);
+  PADDLE_ENFORCE_LT(rank, nranks);
   PADDLE_ENFORCE_GE(dev_id, 0);

-  if (dev_ctx_map_.count(dev_id) == 0) {
-    dev_ctx_map_.emplace(dev_id, std::unique_ptr<CUDADeviceContext>(
-                                     new CUDADeviceContext(CUDAPlace(dev_id))));
-  }
-
   ncclComm_t comm = nullptr;
-  PADDLE_ENFORCE(cudaSetDevice(dev_id));
-  PADDLE_ENFORCE(
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
+  PADDLE_ENFORCE_CUDA_SUCCESS(
       platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));

   std::unique_ptr<CUDADeviceContext> dev_ctx(
       new CUDADeviceContext(CUDAPlace(dev_id)));
   dev_ctx->set_nccl_comm(comm);

-  NCCLCommImpl* communicator = new NCCLCommImpl;
-  communicator->set_ring_id(ring_id);
-  communicator->set_nranks(nranks);
-  communicator->set_rank(rank);
-  communicator->set_dev_ctx(std::move(dev_ctx));
+  NCCLCommImpl* c = new NCCLCommImpl;
+  c->set_ring_id(ring_id);
+  c->set_nranks(nranks);
+  c->set_rank(rank);
+  c->set_dev_ctx(std::move(dev_ctx));

-  comm_map_.emplace(ring_id, std::unique_ptr<NCCLComm>(communicator));
+  comm_map_mutex_.lock();
+  if (comm_map_.count(ring_id) == 0) {
+    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
+  }
+  auto& dev2comm = comm_map_[ring_id];
+
+  dev2comm.emplace(dev_id, std::unique_ptr<NCCLComm>(c));
+  comm_map_mutex_.unlock();

-  VLOG(0) << "nccl communicator of rank " << rank << " in ring " << ring_id
+  VLOG(1) << "nccl communicator of rank " << rank << " in ring " << ring_id
           << " has been created";

-  return comm_map_.at(ring_id).get();
+  std::call_once(once_flag_, []() {
+    std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
+  });
+
+  return comm_map_[ring_id][dev_id].get();
 }

+void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
+                                         int ring_id) {
+  PADDLE_ENFORCE_GT(dev_ids.size(), 0);
+
+  const int kDevices = dev_ids.size();
+  ncclComm_t comms[kDevices];
+  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
+      comms, dev_ids.size(), dev_ids.data()));
+
+  PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0);
+  comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
+
+  auto& dev2comm = comm_map_[ring_id];
+  for (size_t i = 0; i < dev_ids.size(); ++i) {
+    std::unique_ptr<CUDADeviceContext> dev_ctx(
+        new CUDADeviceContext(CUDAPlace(dev_ids[i])));
+    dev_ctx->set_nccl_comm(comms[i]);
+
+    NCCLCommImpl* c = new NCCLCommImpl;
+    c->set_ring_id(ring_id);
+    c->set_nranks(dev_ids.size());
+    c->set_rank(i);
+    c->set_dev_ctx(std::move(dev_ctx));
+
+    dev2comm.emplace(dev_ids[i], std::unique_ptr<NCCLComm>(c));
+  }
+
+  std::call_once(once_flag_, []() {
+    std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
+  });
+}

-NCCLCommContext::~NCCLCommContext() {
+void NCCLCommContext::ReleaseNCCLComms() {
+  // CUDADeviceContext maintains the lifetime of nccl_comm_t, so we should not
+  // destroy nccl_comm_t explicitly. Please refer to
+  // platform::CUDADeviceContext::~CUDADeviceContext()
   for (auto& p : comm_map_) {
-    PADDLE_ENFORCE(platform::dynload::ncclCommDestroy(p.second->comm()));
+    for (auto& q : p.second) {
+      q.second.reset();
+    }
   }
 }
......
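Note the release strategy above: rather than relying on a destructor, the context registers `ReleaseNCCLComms` via `std::atexit`, guarded by `std::call_once` so registration happens only on the first creation. A minimal self-contained sketch of that idiom in isolation:

```cpp
// Sketch of once-only atexit registration for a singleton's cleanup.
#include <cstdlib>
#include <iostream>
#include <mutex>

class Context {
 public:
  static Context& Instance() {
    static Context ctx;
    return ctx;
  }

  void Create() {
    // Register cleanup exactly once, however many times Create runs.
    std::call_once(once_flag_, []() {
      std::atexit([]() { Context::Instance().Release(); });
    });
  }

  void Release() { std::cout << "releasing resources\n"; }

 private:
  std::once_flag once_flag_;
};

int main() {
  Context::Instance().Create();
  Context::Instance().Create();  // no second registration
  return 0;  // "releasing resources" printed once at exit
}
```

Running the cleanup from an atexit hook, and letting each `CUDADeviceContext` own its `ncclComm_t`, avoids calling `ncclCommDestroy` from a static destructor; this appears to be the "cuda resource release problem" the commit message refers to.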
......@@ -15,9 +15,9 @@
 #pragma once

 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#include <map>
 #include <memory>
 #include <string>
-#include <unordered_map>
 #include <vector>

 #include "boost/variant.hpp"
......@@ -58,37 +58,57 @@ class NCCLComm {
   virtual ~NCCLComm() = default;
 };

-// a singleton NCCL communicator context reserves communication ring ids
-// Assume multiprocessing mode
+// A singleton NCCL communicator context reserves communication ring ids
 class NCCLCommContext {
  public:
   static NCCLCommContext& Instance() {
     static NCCLCommContext comm_ctx;
     return comm_ctx;
   }

-  ~NCCLCommContext();

   NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank,
                            int dev_id, int ring_id = 0);

-  // retrieve a communicator by the ring id
+  void CreateAllNCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
+
+  // retrieve a communicator by the ring id in multiprocessing mode
   NCCLComm* Get(int ring_id) const {
-    PADDLE_ENFORCE(comm_map_.count(ring_id),
-                   "communicator in ring id %d has not been initialized",
-                   ring_id);
-    return comm_map_.at(ring_id).get();
+    PADDLE_ENFORCE_GT(comm_map_.count(ring_id), 0,
+                      "communicator in ring id %d has not been initialized",
+                      ring_id);
+    PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1,
+                      "you should specify a device id to retrieve from "
+                      "multiple communicators");
+    return comm_map_.at(ring_id).begin()->second.get();
   }

+  // retrieve a communicator by the ring id and the device id
+  NCCLComm* Get(int ring_id, int dev_id) const {
+    PADDLE_ENFORCE_GT(comm_map_.count(ring_id), 0,
+                      "communicator of ring id %d has not been initialized",
+                      ring_id);
+    PADDLE_ENFORCE_GT(
+        comm_map_.at(ring_id).count(dev_id), 0,
+        "communicator at device id %d has not been initialized in ring %d",
+        dev_id, ring_id);
+    return comm_map_.at(ring_id).at(dev_id).get();
+  }
+
+  // retrieve a communicator by the ring id and place
+  NCCLComm* Get(int ring_id, Place place) const {
+    return Get(ring_id, boost::get<CUDAPlace>(place).device);
+  }
+
  private:
-  // ring id to NCCLComm
-  std::unordered_map<int, std::unique_ptr<NCCLComm>> comm_map_;
+  std::once_flag once_flag_;
+  std::mutex comm_map_mutex_;
+  // ring id to dev-NCCLComm
+  std::map<int, std::map<int, std::unique_ptr<NCCLComm>>> comm_map_;

-  // device id to CUDADeviceContext
-  std::unordered_map<int, std::unique_ptr<CUDADeviceContext>> dev_ctx_map_;
+  void ReleaseNCCLComms();

   NCCLCommContext() = default;
-  NCCLCommContext(const NCCLCommContext& other) = delete;
-  NCCLCommContext& operator=(const NCCLCommContext& other) = delete;
+  DISABLE_COPY_AND_ASSIGN(NCCLCommContext);
 };

 }  // namespace platform
......
......@@ -28,6 +28,7 @@ endif(NOT WITH_DISTRIBUTE)

 if(NOT WITH_GPU OR WIN32)
+    LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op)
     LIST(REMOVE_ITEM TEST_OPS test_allgather)
     LIST(REMOVE_ITEM TEST_OPS test_allreduce)
     LIST(REMOVE_ITEM TEST_OPS test_broadcast)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle.fluid.core as core
import paddle.fluid as fluid


class TestCCommInitAllOp(unittest.TestCase):
    def setUp(self):
        self.place = fluid.CUDAPlace(0)
        self.exe = fluid.Executor(self.place)

    def test_default_attrs(self):
        program = fluid.Program()
        block = program.global_block()
        block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
        self.exe.run(program)

    def test_init_with_same_ring_id(self):
        program = fluid.Program()
        block = program.global_block()
        block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
        with self.assertRaises(core.EnforceNotMet):
            self.exe.run(program)

    def test_specifying_devices(self):
        program = fluid.Program()
        block = program.global_block()
        block.append_op(
            type='c_comm_init_all', attrs={'devices': [0],
                                           'ring_id': 1})
        self.exe.run(program)


if __name__ == "__main__":
    unittest.main()