Commit d8aebaf5 authored by Dong Zhihong

"fix enforce error"

Parent d2be7ec3
@@ -8,27 +8,27 @@ NCCLManager::NCCLManager() {}
 NCCLManager::~NCCLManager() {
   for (auto& p : comm_table) {
-    auto* comm = p.second;
+    auto& comm = p.second;
     auto& gpus_ = comm->gpus_;
-    for (int i = 0; i < gpus_.size(); ++i) {
+    for (size_t i = 0; i < gpus_.size(); ++i) {
       int gid = gpus_[i];
       platform::SetDeviceId(gid);
       // mapping gid to idx
       int idx = gid % gpus_.size();
       // wait finish
-      NCCL_CHECK(
+      PADDLE_ENFORCE(
           cudaStreamWaitEvent(*comm->streams_[idx], comm->events_[idx], 0));
-      NCCL_CHECK(cudaEventDestroy(comm->events_[idx]));
-      NCCL_CHECK(ncclCommDestroy(comm->comms_[idx]));
+      PADDLE_ENFORCE(cudaEventDestroy(comm->events_[idx]));
+      PADDLE_ENFORCE(ncclCommDestroy(comm->comms_[idx]));
     }
-    delete comm;
+    comm.reset(nullptr);
   }
 }
-Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
+Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) {
   std::string key;
   for (auto& id : gpus) {
     key += std::to_string(id);
@@ -37,21 +37,24 @@ Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
   std::mutex mu;
   std::lock_guard<std::mutex> lk(mu);
-  auto* comm = comm_table[key];
-  if (comm == nullptr) {
-    comm = new Communicator(gpus.size());
-    NCCL_CHECK(ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
+
+  auto it = comm_table.find(key);
+
+  if (it->second == nullptr) {
+    auto* comm = new Communicator(gpus);
+    PADDLE_ENFORCE(
+        ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
     for (size_t i = 0; i < gpus.size(); ++i) {
       platform::SetDeviceId(gpus[i]);
       // block wait
-      NCCL_CHECK(cudaEventCreateWithFlags(
-          &events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
+      PADDLE_ENFORCE(cudaEventCreateWithFlags(
+          &comm->events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
     }
-    comm_table[key] = comm;
+    comm_table[key].reset(comm);
   }
-  return comm;
+  return comm_table[key].get();
 }
 }  // namespace operators
...
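The rewritten GetCommunicator keys each communicator by the concatenated GPU ids and keeps it in the table as a std::unique_ptr, so the destructor above can drop its manual delete. Below is a minimal standalone sketch of that ownership pattern; Resource and Cache are illustrative stand-ins rather than Paddle types, and the sketch uses operator[] so a missing key simply default-constructs an empty slot before it is filled.

```cpp
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for Communicator; only the caching pattern matters here.
struct Resource {
  explicit Resource(const std::vector<int>& ids) : ids_(ids) {}
  std::vector<int> ids_;
};

class Cache {
 public:
  // Returns a raw, non-owning pointer; the map keeps ownership, mirroring
  // comm_table[key].reset(comm) / comm_table[key].get() in the diff.
  Resource* Get(const std::vector<int>& ids) {
    std::string key;
    for (int id : ids) key += std::to_string(id);
    auto& slot = table_[key];         // default-constructs an empty unique_ptr
    if (slot == nullptr) {
      slot.reset(new Resource(ids));  // created once, reused afterwards
    }
    return slot.get();
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Resource>> table_;
};
```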
 #pragma once
-#include <nccl.h>
 #include <algorithm>
 #include <condition_variable>
@@ -10,17 +9,11 @@
 #include <vector>
 #include "paddle/platform/device_context.h"
+#include "paddle/platform/enforce.h"
 namespace paddle {
 namespace platform {
-#define NCCL_CHECK(condition)                                               \
-  do {                                                                      \
-    ncclResult_t ret = (condition);                                         \
-    PADDLE_ENFORCE(ret == ncclSuccess, "Error invoking NCCL: ", __FILE__,   \
-                   __LINE__, ncclGetErrorString(ret));                      \
-  } while (0)
 class WaitGroup {
  public:
   inline void Add(int n) {
@@ -101,7 +94,7 @@ class NCCLManager {
   ~NCCLManager();
   // for each card only have one communicator
-  Communicator* GetCommunicator(const std::vector<int>& gpus) const;
+  Communicator* GetCommunicator(const std::vector<int>& gpus);
  private:
   // // the gpu id list available. Note that only support
@@ -109,7 +102,8 @@ class NCCLManager {
   // std::vector<int> _gpu_worlds;
   // communicator list
-  std::unordered_map<std::string /* key*/, Communicator*> comm_table;
+  std::unordered_map<std::string /* key*/, std::unique_ptr<Communicator>>
+      comm_table;
 };
 }  // namespace operators
...
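The NCCL_CHECK macro deleted above turned a non-ncclSuccess result into a PADDLE_ENFORCE failure carrying the file, line, and ncclGetErrorString text; after this commit the call result is handed to PADDLE_ENFORCE directly. As a reference for the idiom the macro implemented, here is a self-contained check-and-throw sketch that uses a stand-in status enum and error-string helper instead of the real NCCL symbols:

```cpp
#include <sstream>
#include <stdexcept>

// Stand-ins for ncclResult_t / ncclGetErrorString; not the real NCCL API.
enum StatusT { kSuccess = 0, kSystemError = 1 };

inline const char* StatusString(StatusT s) {
  return s == kSuccess ? "no error" : "system error";
}

// Same shape as the removed NCCL_CHECK: evaluate the call once, test the
// status, and throw with file/line plus the library's error string.
#define STATUS_CHECK(condition)                                        \
  do {                                                                 \
    StatusT ret = (condition);                                         \
    if (ret != kSuccess) {                                             \
      std::ostringstream msg;                                          \
      msg << "Error invoking call at " << __FILE__ << ":" << __LINE__  \
          << ": " << StatusString(ret);                                \
      throw std::runtime_error(msg.str());                             \
    }                                                                  \
  } while (0)

// Usage: STATUS_CHECK(SomeCallReturningStatusT());
```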
@@ -54,14 +54,15 @@ class NCCLAllReduceKernel : public framework::OpKernel {
     comm->streams_[idx] = stream;
     for (size_t i = 0; i < ins.size(); ++i) {
-      NCCL_CHECK(ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
-                               outs[i]->numel() * sizeof(T),
-                               NCCLTypeWrapper<T>::type, op_type,
-                               &comm->comms_[idx], comm->streams_[idx]));
-      NCCL_CHECK(cudaEventRecord(comm->events_[idx], *comms_->streams_[idx]));
+      PADDLE_ENFORCE(
+          ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
+                        outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type,
+                        op_type, &comm->comms_[idx], comm->streams_[idx]));
+      PADDLE_ENFORCE(
+          cudaEventRecord(comm->events_[idx], *comms_->streams_[idx]));
       // wait finish
-      NCCL_CHECK(
+      PADDLE_ENFORCE(
           cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
     }
...
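The kernel records an event on the stream that issued the collective and then has a stream wait on that event, which orders later work without blocking the host. A short sketch of that record-then-wait pattern with plain CUDA runtime calls follows; the function name and the producer/consumer streams are illustrative, and return codes are left unchecked for brevity.

```cpp
#include <cuda_runtime.h>

// Make `consumer` wait for all work currently enqueued on `producer`,
// without blocking the host thread.
void WaitOnProducer(cudaStream_t producer, cudaStream_t consumer) {
  cudaEvent_t done;
  // Same flags as in the diff: blocking sync, timing disabled.
  cudaEventCreateWithFlags(&done,
                           cudaEventBlockingSync | cudaEventDisableTiming);
  cudaEventRecord(done, producer);         // mark the current tail of producer
  cudaStreamWaitEvent(consumer, done, 0);  // consumer stalls until `done` fires
  cudaEventDestroy(done);                  // released once the pending wait completes
}
```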
@@ -30,13 +30,13 @@ extern void* nccl_dso_handle;
 #define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name)                      \
   struct DynLoad__##__name {                                        \
     template <typename... Args>                                     \
-    ncclResult_t operator()(Args... args) {                         \
-      typedef ncclResult_t (*ncclFunc)(Args...);                    \
+    auto operator()(Args... args) -> decltype(__name(args...)) {    \
+      using nccl_func = decltype(__name(args...)) (*)(Args...);     \
       std::call_once(nccl_dso_flag,                                 \
                      paddle::platform::dynload::GetNcclDsoHandle,   \
                      &nccl_dso_handle);                             \
       void* p_##__name = dlsym(nccl_dso_handle, #__name);           \
-      return reinterpret_cast<ncclFunc>(p_##__name)(args...);       \
+      return reinterpret_cast<nccl_func>(p_##__name)(args...);      \
     }                                                               \
   };                                                                \
   extern DynLoad__##__name __name
@@ -65,7 +65,7 @@ extern void* nccl_dso_handle;
   __macro(ncclReduce);                \
   __macro(ncclGetErrorString);
-NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP);
+NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
 }  // namespace dynload
 }  // namespace platform
...
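The wrapper macro now deduces its return type from the declared NCCL function with decltype instead of hard-coding ncclResult_t, which lets the same macro wrap entries such as ncclGetErrorString that do not return a status code. The following self-contained sketch shows the same dlopen/dlsym/std::call_once pattern applied to a libm symbol; the macro name, the dyn_ prefix, and the "libm.so.6" path are assumptions made for the example, not Paddle code.

```cpp
#include <dlfcn.h>   // dlopen, dlsym
#include <math.h>    // declares ::cos, which decltype inspects below
#include <mutex>

static std::once_flag dso_flag;
static void* dso_handle = nullptr;

// Same shape as DECLARE_DYNAMIC_LOAD_NCCL_WRAP: open the shared object once,
// resolve the symbol by name, and forward the call through a function pointer
// whose return type is deduced from the declared function, not hard-coded.
#define DECLARE_DYNAMIC_WRAP(__name, __so)                          \
  struct DynLoad__##__name {                                        \
    template <typename... Args>                                     \
    auto operator()(Args... args) -> decltype(__name(args...)) {    \
      using func_t = decltype(__name(args...)) (*)(Args...);        \
      std::call_once(dso_flag,                                      \
                     [] { dso_handle = dlopen(__so, RTLD_LAZY); }); \
      void* p = dlsym(dso_handle, #__name);                         \
      return reinterpret_cast<func_t>(p)(args...);                  \
    }                                                               \
  };                                                                \
  static DynLoad__##__name dyn_##__name

DECLARE_DYNAMIC_WRAP(cos, "libm.so.6");  // hypothetical example symbol

// Usage: double y = dyn_cos(1.0);  // resolved via dlsym on first call
```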
@@ -29,6 +29,8 @@ limitations under the License. */
 #include <cxxabi.h>  // for __cxa_demangle
 #endif
+#include <glog/logging.h>
+
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/platform/dynload/cublas.h"
...