提交 093d227a 编写于 作者: Y Yu Yang

Use mutex to stablize ncclCtxMap

上级 494c262a
...@@ -39,20 +39,19 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) { ...@@ -39,20 +39,19 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
class NCCLGroupGuard { class NCCLGroupGuard {
public: public:
static std::mutex &NCCLMutex() {
static std::mutex mtx;
return mtx;
}
inline NCCLGroupGuard() { inline NCCLGroupGuard() {
mutex().lock(); NCCLMutex().lock();
PADDLE_ENFORCE(dynload::ncclGroupStart()); PADDLE_ENFORCE(dynload::ncclGroupStart());
} }
inline ~NCCLGroupGuard() { inline ~NCCLGroupGuard() {
PADDLE_ENFORCE(dynload::ncclGroupEnd()); PADDLE_ENFORCE(dynload::ncclGroupEnd());
mutex().unlock(); NCCLMutex().unlock();
}
private:
static std::mutex &mutex() {
static std::mutex mtx;
return mtx;
} }
}; };
...@@ -68,26 +67,6 @@ struct NCCLContext { ...@@ -68,26 +67,6 @@ struct NCCLContext {
int device_id() const { int device_id() const {
return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device; return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
} }
static void InitNCCLContext(std::unordered_map<int, NCCLContext> *contexts,
const std::vector<platform::Place> &places) {
std::vector<ncclComm_t> comms;
std::vector<int> devs;
comms.resize(contexts->size());
devs.reserve(contexts->size());
for (auto &p : places) {
devs.push_back(boost::get<platform::CUDAPlace>(p).device);
}
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(contexts->size()), &devs[0]));
int i = 0;
for (auto &dev_id : devs) {
contexts->at(dev_id).comm_ = comms[i++];
}
}
}; };
struct NCCLContextMap { struct NCCLContextMap {
...@@ -107,12 +86,12 @@ struct NCCLContextMap { ...@@ -107,12 +86,12 @@ struct NCCLContextMap {
"NCCL Context Map does not support contain two or more same device"); "NCCL Context Map does not support contain two or more same device");
if (places.size() > 1) { if (places.size() > 1) {
std::vector<ncclComm_t> comms; std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
comms.resize(order_.size()); {
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(order_.size()), &order_[0])); comms.get(), static_cast<int>(order_.size()), order_.data()));
}
int i = 0; int i = 0;
for (auto &dev_id : order_) { for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++]; contexts_.at(dev_id).comm_ = comms[i++];
...@@ -120,6 +99,9 @@ struct NCCLContextMap { ...@@ -120,6 +99,9 @@ struct NCCLContextMap {
} }
} }
NCCLContextMap(const NCCLContextMap &other) = delete;
NCCLContextMap &operator=(const NCCLContextMap &other) = delete;
CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); } CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
CUDADeviceContext *DevCtx(platform::Place p) const { CUDADeviceContext *DevCtx(platform::Place p) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册