Commit d8aebaf5 authored by Dong Zhihong

"fix enforce error"

Parent d2be7ec3
@@ -8,27 +8,27 @@ NCCLManager::NCCLManager() {}
 NCCLManager::~NCCLManager() {
   for (auto& p : comm_table) {
-    auto* comm = p.second;
+    auto& comm = p.second;
     auto& gpus_ = comm->gpus_;
-    for (int i = 0; i < gpus_.size(); ++i) {
+    for (size_t i = 0; i < gpus_.size(); ++i) {
       int gid = gpus_[i];
       platform::SetDeviceId(gid);
       // mapping gid to idx
       int idx = gid % gpus_.size();
       // wait finish
-      NCCL_CHECK(
+      PADDLE_ENFORCE(
           cudaStreamWaitEvent(*comm->streams_[idx], comm->events_[idx], 0));
-      NCCL_CHECK(cudaEventDestroy(comm->events_[idx]));
+      PADDLE_ENFORCE(cudaEventDestroy(comm->events_[idx]));
-      NCCL_CHECK(ncclCommDestroy(comm->comms_[idx]));
+      PADDLE_ENFORCE(ncclCommDestroy(comm->comms_[idx]));
     }
-    delete comm;
+    comm.reset(nullptr);
   }
 }
-Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
+Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) {
   std::string key;
   for (auto& id : gpus) {
     key += std::to_string(id);
@@ -37,21 +37,24 @@ Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
   std::mutex mu;
   std::lock_guard<std::mutex> lk(mu);
-  auto* comm = comm_table[key];
-  if (comm == nullptr) {
-    comm = new Communicator(gpus.size());
-    NCCL_CHECK(ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
+  auto it = comm_table.find(key);
+  if (it->second == nullptr) {
+    auto* comm = new Communicator(gpus);
+    PADDLE_ENFORCE(
+        ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
     for (size_t i = 0; i < gpus.size(); ++i) {
       platform::SetDeviceId(gpus[i]);
       // block wait
-      NCCL_CHECK(cudaEventCreateWithFlags(
-          &events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
+      PADDLE_ENFORCE(cudaEventCreateWithFlags(
+          &comm->events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
     }
-    comm_table[key] = comm;
+    comm_table[key].reset(comm);
   }
-  return comm;
+  return comm_table[key].get();
 }
 }  // namespace operators
......
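Note on the ownership change above: comm_table now stores std::unique_ptr<Communicator>, so the destructor releases communicators automatically instead of calling delete, and GetCommunicator lazily creates and caches an entry per GPU set. A minimal, self-contained sketch of the same lazy-create-and-cache pattern (hypothetical Resource/Cache names, not the Paddle types):

// Sketch only: unique_ptr-owned cache, mirroring the comm_table change.
#include <memory>
#include <string>
#include <unordered_map>

struct Resource {
  explicit Resource(int n) : size(n) {}
  int size;
};

class Cache {
 public:
  Resource* Get(const std::string& key, int n) {
    auto& slot = table_[key];       // default-constructs an empty unique_ptr
    if (slot == nullptr) {
      slot.reset(new Resource(n));  // the cache owns the object from here on
    }
    return slot.get();              // callers borrow a non-owning raw pointer
  }
  // No manual delete in ~Cache(): each unique_ptr frees its Resource.

 private:
  std::unordered_map<std::string, std::unique_ptr<Resource>> table_;
};

int main() {
  Cache cache;
  Resource* r = cache.Get("0123", 4);        // first call creates and caches
  return cache.Get("0123", 4) == r ? 0 : 1;  // later calls reuse the object
}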
 #pragma once
 #include <nccl.h>
 #include <algorithm>
 #include <condition_variable>
@@ -10,17 +9,11 @@
 #include <vector>
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
 namespace paddle {
 namespace platform {
-#define NCCL_CHECK(condition)                                             \
-  do {                                                                    \
-    ncclResult_t ret = (condition);                                       \
-    PADDLE_ENFORCE(ret == ncclSuccess, "Error invoking NCCL: ", __FILE__, \
-                   __LINE__, ncclGetErrorString(ret));                    \
-  } while (0)
 class WaitGroup {
  public:
   inline void Add(int n) {
@@ -101,7 +94,7 @@ class NCCLManager {
   ~NCCLManager();
   // for each card only have one communicator
-  Communicator* GetCommunicator(const std::vector<int>& gpus) const;
+  Communicator* GetCommunicator(const std::vector<int>& gpus);
  private:
   // // the gpu id list available. Note that only support
@@ -109,7 +102,8 @@ class NCCLManager {
   // std::vector<int> _gpu_worlds;
   // communicator list
-  std::unordered_map<std::string /* key*/, Communicator*> comm_table;
+  std::unordered_map<std::string /* key*/, std::unique_ptr<Communicator>>
+      comm_table;
 };
 }  // namespace operators
......
@@ -54,14 +54,15 @@ class NCCLAllReduceKernel : public framework::OpKernel {
       comm->streams_[idx] = stream;
       for (size_t i = 0; i < ins.size(); ++i) {
-        NCCL_CHECK(ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
-                                 outs[i]->numel() * sizeof(T),
-                                 NCCLTypeWrapper<T>::type, op_type,
-                                 &comm->comms_[idx], comm->streams_[idx]));
-        NCCL_CHECK(cudaEventRecord(comm->events_[idx], *comms_->streams_[idx]));
+        PADDLE_ENFORCE(
+            ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
+                          outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type,
+                          op_type, &comm->comms_[idx], comm->streams_[idx]));
+        PADDLE_ENFORCE(
+            cudaEventRecord(comm->events_[idx], *comms_->streams_[idx]));
         // wait finish
-        NCCL_CHECK(
+        PADDLE_ENFORCE(
             cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
       }
......
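The kernel hunk above records an event on the stream and then makes a stream wait on it. For reference, the general CUDA record-then-wait idiom looks roughly like the sketch below (generic stream names, independent of the Paddle kernel):

// Sketch only: make `consumer` wait (on the GPU) for work already queued on
// `producer`, without blocking the host thread.
#include <cuda_runtime.h>

void WaitForProducer(cudaStream_t producer, cudaStream_t consumer) {
  cudaEvent_t done;
  // Timing is disabled because the event is used purely for ordering.
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
  cudaEventRecord(done, producer);         // mark the tail of producer's queue
  cudaStreamWaitEvent(consumer, done, 0);  // consumer resumes after `done`
  cudaEventDestroy(done);                  // safe: destruction is deferred
}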
@@ -30,13 +30,13 @@ extern void* nccl_dso_handle;
 #define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name)                    \
   struct DynLoad__##__name {                                      \
     template <typename... Args>                                   \
-    ncclResult_t operator()(Args... args) {                       \
-      typedef ncclResult_t (*ncclFunc)(Args...);                  \
+    auto operator()(Args... args) -> decltype(__name(args...)) {  \
+      using nccl_func = decltype(__name(args...)) (*)(Args...);   \
       std::call_once(nccl_dso_flag,                               \
                      paddle::platform::dynload::GetNcclDsoHandle, \
                      &nccl_dso_handle);                           \
       void* p_##__name = dlsym(nccl_dso_handle, #__name);         \
-      return reinterpret_cast<ncclFunc>(p_##__name)(args...);     \
+      return reinterpret_cast<nccl_func>(p_##__name)(args...);    \
     }                                                             \
   };                                                              \
   extern DynLoad__##__name __name
@@ -65,7 +65,7 @@ extern void* nccl_dso_handle;
   __macro(ncclReduce);        \
   __macro(ncclGetErrorString);
-NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP);
+NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
 }  // namespace dynload
 }  // namespace platform
......
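The dynload hunk switches the wrapper's return type from a hard-coded ncclResult_t to decltype(__name(args...)), so each generated functor forwards whatever its NCCL symbol returns. A simplified sketch of the underlying dlsym lazy-binding pattern (library path and helper names are illustrative assumptions, not the exact macro expansion):

// Sketch only: dlsym-based lazy binding, the idea behind the generated wrappers.
#include <dlfcn.h>
#include <mutex>

static std::once_flag nccl_dso_flag;
static void* nccl_dso_handle = nullptr;

template <typename Ret, typename... Args>
Ret CallNcclSymbol(const char* symbol, Args... args) {
  // Open the NCCL shared library exactly once, then resolve the symbol by name.
  std::call_once(nccl_dso_flag,
                 [] { nccl_dso_handle = dlopen("libnccl.so", RTLD_LAZY); });
  using Func = Ret (*)(Args...);
  auto fn = reinterpret_cast<Func>(dlsym(nccl_dso_handle, symbol));
  return fn(args...);  // forward the arguments and the symbol's return value
}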
@@ -29,6 +29,8 @@ limitations under the License. */
 #include <cxxabi.h>  // for __cxa_demangle
 #endif
 #include <glog/logging.h>
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/platform/dynload/cublas.h"
......
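The enforce.h hunk is truncated here. Conceptually, replacing NCCL_CHECK with PADDLE_ENFORCE requires the enforce machinery to accept ncclResult_t (and cudaError_t) status codes directly; a hedged sketch of such an overload set, not Paddle's actual implementation, is:

// Sketch only: status-type overloads so one enforce-style macro can check
// both CUDA and NCCL return codes.
#include <cuda_runtime.h>
#include <nccl.h>
#include <stdexcept>
#include <string>

inline void ThrowOnError(cudaError_t e, const std::string& expr) {
  if (e != cudaSuccess) {
    throw std::runtime_error(expr + ": " + cudaGetErrorString(e));
  }
}

inline void ThrowOnError(ncclResult_t e, const std::string& expr) {
  if (e != ncclSuccess) {
    throw std::runtime_error(expr + ": " + ncclGetErrorString(e));
  }
}

// Usage: MY_ENFORCE(cudaEventDestroy(ev)); MY_ENFORCE(ncclCommDestroy(comm));
#define MY_ENFORCE(stmt) ThrowOnError((stmt), #stmt)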