Unverified · Commit c7b7291b authored by Xin Pan, committed by GitHub

Merge pull request #8758 from panyx0718/nccl

[Speed] Avoid init_nccl for every step.

This change moves the NCCL communicator state out of the Communicator struct and into process-wide, mutex-guarded globals in the .cc file. InitAll now returns early when communicators already exist for the same GPU count, so per-step callers no longer re-run ncclCommInitAll.
@@ -16,5 +16,50 @@ limitations under the License. */
 #include "paddle/fluid/platform/gpu_info.h"
 
 namespace paddle {
-namespace platform {}  // namespace platform
+namespace platform {
+
+namespace {
+// TODO(panyx0718): Where to destroy them.
+std::unique_ptr<std::vector<ncclComm_t>> global_comms;
+std::unique_ptr<std::unordered_map<int, int>> comm_id_map;
+bool inited = false;
+size_t last_num_gpus = -1;
+// TODO(panyx0718): Need to decide whether Paddle supports parallel
+// runs with different number GPUs. If true, current solution is not enough.
+std::mutex comm_mu;
+}  // namespace
+
+int Communicator::GetCommId(int device_id) const {
+  std::lock_guard<std::mutex> guard(comm_mu);
+  return comm_id_map->at(device_id);
+}
+
+void Communicator::InitAll(const std::vector<int>& gpus) {
+  std::lock_guard<std::mutex> guard(comm_mu);
+  if (inited && last_num_gpus == gpus.size()) {
+    return;
+  }
+  last_num_gpus = gpus.size();
+  if (global_comms) {
+    for (size_t i = 0; i < global_comms->size(); ++i) {
+      // FIXME(dzh) : PADDLE_ENFORCE return void
+      dynload::ncclCommDestroy((*global_comms)[i]);
+    }
+  }
+  global_comms.reset(new std::vector<ncclComm_t>());
+  comm_id_map.reset(new std::unordered_map<int, int>());
+  global_comms->resize(gpus.size());
+  for (size_t i = 0; i < gpus.size(); ++i) {
+    (*comm_id_map)[gpus[i]] = i;
+  }
+  PADDLE_ENFORCE(
+      dynload::ncclCommInitAll(global_comms->data(), gpus.size(), gpus.data()));
+  inited = true;
+}
+
+const std::vector<ncclComm_t>& Communicator::comms() const {
+  std::lock_guard<std::mutex> guard(comm_mu);
+  return *global_comms;
+}
+
+}  // namespace platform
 }  // namespace paddle
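
For orientation, here is a minimal usage sketch of the refactored API. It is not part of the commit: the header path and device ids are assumptions, and it needs a machine with CUDA, NCCL, and the Paddle headers to actually build. The point it illustrates is that only the first InitAll call pays the ncclCommInitAll cost; later per-step calls with an unchanged GPU count take the early-return path.

#include <vector>

#include "paddle/fluid/platform/nccl_gpu_common.h"  // assumed header path

void RunStep(paddle::platform::Communicator* comm) {
  std::vector<int> gpus = {0, 1};  // illustrative device ids
  comm->InitAll(gpus);  // first call initializes; later calls return early
  int idx = comm->GetCommId(/*device_id=*/1);    // device id -> comm index
  ncclComm_t nccl_comm = comm->comms().at(idx);  // handle for NCCL collectives
  (void)nccl_comm;  // would be passed to ncclAllReduce etc.
}
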
@@ -29,39 +29,16 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
 constexpr int kInvalidGPUId = -1;
 
 struct Communicator {
-  std::vector<ncclComm_t> comms_;
-  std::unordered_map<int, int> comm_id_map_;
-  bool inited_;
-
   Communicator() {}
 
-  int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
+  int GetCommId(int device_id) const;
 
-  void InitAll(const std::vector<int>& gpus) {
-    comms_.resize(gpus.size());
-    inited_ = false;
-    for (size_t i = 0; i < gpus.size(); ++i) {
-      comm_id_map_[gpus[i]] = i;
-    }
-    PADDLE_ENFORCE(
-        dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
-    inited_ = true;
-  }
+  void InitAll(const std::vector<int>& gpus);
 
-  ~Communicator() {
-    if (inited_) {
-      for (size_t i = 0; i < comms_.size(); ++i) {
-        // FIXME(dzh) : PADDLE_ENFORCE return void
-        dynload::ncclCommDestroy(comms_[i]);
-      }
-    }
-  }
-
-  DISABLE_COPY_AND_ASSIGN(Communicator);
+  const std::vector<ncclComm_t>& comms() const;
 };
 
 }  // namespace platform
...
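
A note on the shape of this refactor: the header now exposes only declarations, while the state lives in an anonymous namespace in the .cc file. That gives the globals internal linkage (no other translation unit can touch them directly), and every accessor takes comm_mu before reading or writing them. A generic sketch of that pattern, with made-up names and not Paddle code:

#include <mutex>

namespace {
// internal linkage: invisible outside this translation unit
int cached = 0;
bool inited = false;
std::mutex mu;
}  // namespace

int GetOrInit(int v) {
  std::lock_guard<std::mutex> guard(mu);
  if (!inited) {  // the expensive initialization runs only once
    cached = v;
    inited = true;
  }
  return cached;
}
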
@@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
           ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
           outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
-          comm->comms_[idx], stream));
+          comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
       VLOG(1) << "gpu : "
@@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     std::hash<std::string> hasher;
     for (size_t i = 0; i < ins.size(); ++i) {
       if (root == platform::kInvalidGPUId) {
-        root = hasher(ins_names[i]) % comm->comms_.size();
+        root = hasher(ins_names[i]) % comm->comms().size();
       }
       T* recvbuffer = nullptr;
       if (root == gpu_id) {
@@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclReduce(
           ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
-          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
+          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
           stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
@@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
         VLOG(1) << " before ncclBcast";
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             (void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
-            root, comm->comms_[idx], stream));
+            root, comm->comms().at(idx), stream));
         VLOG(1) << " after ncclBcast";
         PADDLE_ENFORCE(cudaStreamSynchronize(stream));
@@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
-            NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
+            NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
         PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
         VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
...
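
One behavioral side effect of replacing direct member access (comm->comms_[idx]) with comm->comms().at(idx) in the kernels above: std::vector::at bounds-checks the index and throws std::out_of_range on a bad device index, whereas operator[] would be undefined behavior. A standalone illustration, not Paddle code:

#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  std::vector<int> comms = {10, 20};
  std::cout << comms.at(1) << "\n";  // in range: prints 20
  try {
    comms.at(2);  // out of range: throws instead of undefined behavior
  } catch (const std::out_of_range& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}
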