Commit d054cfea authored by Xin Pan

Avoid init_nccl for every step.

Parent 158d5674
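For context: before this change, communicator state lived inside each `Communicator` instance, so any code path that built one per step re-ran `ncclCommInitAll` every iteration. After it, the communicators live in process-wide globals behind an `inited` guard, and repeated `InitAll` calls with an unchanged GPU count return immediately. A minimal sketch of that init-once guard, using a plain `int` as a stand-in for `ncclComm_t` so it compiles and runs without NCCL (`FakeComm` and the `printf`s are illustrative, not part of the commit):

```cpp
#include <cstdio>
#include <memory>
#include <unordered_map>
#include <vector>

// Stand-in for ncclComm_t so this sketch builds without NCCL (assumption).
using FakeComm = int;

namespace {
// Process-wide state, mirroring the anonymous-namespace globals in the diff.
std::unique_ptr<std::vector<FakeComm>> global_comms;
std::unique_ptr<std::unordered_map<int, int>> comm_id_map;
bool inited = false;
size_t last_num_gpus = -1;  // wraps to SIZE_MAX, so the first call never matches
}  // namespace

void InitAll(const std::vector<int>& gpus) {
  // Fast path: same GPU count as last time, reuse existing communicators.
  if (inited && last_num_gpus == gpus.size()) {
    std::printf("reusing %zu comms\n", global_comms->size());
    return;
  }
  last_num_gpus = gpus.size();
  // The real code destroys any previous comms with ncclCommDestroy here.
  global_comms.reset(new std::vector<FakeComm>(gpus.size()));
  comm_id_map.reset(new std::unordered_map<int, int>());
  for (size_t i = 0; i < gpus.size(); ++i) {
    (*comm_id_map)[gpus[i]] = static_cast<int>(i);
  }
  // The real code calls ncclCommInitAll(global_comms->data(), ...) here.
  inited = true;
  std::printf("initialized %zu comms\n", global_comms->size());
}

int main() {
  InitAll({0, 1});  // step 1: pays the initialization cost
  InitAll({0, 1});  // later steps: return on the fast path
  return 0;
}
```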
@@ -16,5 +16,44 @@ limitations under the License. */

 #include "paddle/fluid/platform/gpu_info.h"

 namespace paddle {
-namespace platform {}  // namespace platform
+namespace platform {
+
+namespace {
+// TODO(panyx0718): Where to destroy them.
+std::unique_ptr<std::vector<ncclComm_t>> global_comms;
+std::unique_ptr<std::unordered_map<int, int>> comm_id_map;
+bool inited = false;
+size_t last_num_gpus = -1;
+}
+
+int Communicator::GetCommId(int device_id) const {
+  return comm_id_map->at(device_id);
+}
+
+void Communicator::InitAll(const std::vector<int>& gpus) {
+  if (inited && last_num_gpus == gpus.size()) {
+    return;
+  }
+  last_num_gpus = gpus.size();
+  if (global_comms) {
+    for (size_t i = 0; i < global_comms->size(); ++i) {
+      // FIXME(dzh) : PADDLE_ENFORCE return void
+      dynload::ncclCommDestroy((*global_comms)[i]);
+    }
+  }
+  global_comms.reset(new std::vector<ncclComm_t>());
+  comm_id_map.reset(new std::unordered_map<int, int>());
+  global_comms->resize(gpus.size());
+  for (size_t i = 0; i < gpus.size(); ++i) {
+    (*comm_id_map)[gpus[i]] = i;
+  }
+  PADDLE_ENFORCE(
+      dynload::ncclCommInitAll(global_comms->data(), gpus.size(), gpus.data()));
+  inited = true;
+}
+
+const std::vector<ncclComm_t>& Communicator::comms() const {
+  return *global_comms;
+}
+
+}  // namespace platform
 }  // namespace paddle
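Two details in the hunk above are easy to miss. First, `size_t last_num_gpus = -1;` relies on unsigned wrap-around: the sentinel starts at `SIZE_MAX`, can never equal a real `gpus.size()`, and so forces the first `InitAll` call down the slow path. Second, the guard compares only the GPU count, not the device list itself, so two same-size device sets would reuse the cached communicators. A one-liner to confirm the wrap-around:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t last_num_gpus = -1;       // -1 wraps to SIZE_MAX
  std::printf("%zu\n", last_num_gpus);  // 18446744073709551615 on 64-bit
  return 0;
}
```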
@@ -29,39 +29,16 @@ limitations under the License. */

 namespace paddle {
 namespace platform {

 constexpr int kInvalidGPUId = -1;

 struct Communicator {
-  std::vector<ncclComm_t> comms_;
-  std::unordered_map<int, int> comm_id_map_;
-  bool inited_;
-
   Communicator() {}
-  int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
-
-  void InitAll(const std::vector<int>& gpus) {
-    comms_.resize(gpus.size());
-    inited_ = false;
-    for (size_t i = 0; i < gpus.size(); ++i) {
-      comm_id_map_[gpus[i]] = i;
-    }
-    PADDLE_ENFORCE(
-        dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
-    inited_ = true;
-  }
+  int GetCommId(int device_id) const;

-  ~Communicator() {
-    if (inited_) {
-      for (size_t i = 0; i < comms_.size(); ++i) {
-        // FIXME(dzh) : PADDLE_ENFORCE return void
-        dynload::ncclCommDestroy(comms_[i]);
-      }
-    }
-  }
+  void InitAll(const std::vector<int>& gpus);

-  DISABLE_COPY_AND_ASSIGN(Communicator);
+  const std::vector<ncclComm_t>& comms() const;
 };

 }  // namespace platform
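The header side of the change strips `Communicator` down to declarations: the per-instance fields, the inline `InitAll`, the destructor, and (by the hunk's line counts) `DISABLE_COPY_AND_ASSIGN` are gone. With state inside the struct, constructing a fresh `Communicator` each step paid for `ncclCommInitAll` on init and `ncclCommDestroy` in the destructor on every iteration, which is the per-step cost the commit title refers to. A simplified model of that old behavior, with `OldCommunicator` and the `int` handles as hypothetical stand-ins:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Simplified model of the pre-change struct: state is per-instance,
// so its lifetime is tied to whoever constructed it.
struct OldCommunicator {
  std::vector<int> comms_;  // stand-in for ncclComm_t handles
  void InitAll(std::size_t n) {
    comms_.assign(n, 0);  // models the expensive ncclCommInitAll
    std::printf("init %zu comms\n", n);
  }
  ~OldCommunicator() {
    std::printf("destroy %zu comms\n", comms_.size());  // models ncclCommDestroy
  }
};

int main() {
  for (int step = 0; step < 3; ++step) {
    OldCommunicator c;  // one instance per step, as an op might do
    c.InitAll(2);       // re-initializes every iteration...
  }                     // ...and tears down every iteration
  return 0;
}
```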
@@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
           ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
           outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
-          comm->comms_[idx], stream));
+          comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
       VLOG(1) << "gpu : "
@@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     std::hash<std::string> hasher;
     for (size_t i = 0; i < ins.size(); ++i) {
       if (root == platform::kInvalidGPUId) {
-        root = hasher(ins_names[i]) % comm->comms_.size();
+        root = hasher(ins_names[i]) % comm->comms().size();
       }
       T* recvbuffer = nullptr;
       if (root == gpu_id) {
@@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclReduce(
           ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
-          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
+          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
           stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
@@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
       VLOG(1) << " before ncclBcast";
       PADDLE_ENFORCE(platform::dynload::ncclBcast(
           (void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
-          root, comm->comms_[idx], stream));
+          root, comm->comms().at(idx), stream));
       VLOG(1) << " after ncclBcast";
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
@@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclBcast(
           outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
-          NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
+          NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
       VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
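Across the kernel hunks above, every `comm->comms_[idx]` becomes `comm->comms().at(idx)`. Besides routing through the new accessor, `std::vector::at` is bounds-checked: an invalid index throws `std::out_of_range` rather than triggering the undefined behavior of `operator[]`. A self-contained illustration:

```cpp
#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  std::vector<int> comms = {11, 22};   // stand-ins for ncclComm_t handles
  std::cout << comms.at(1) << "\n";    // ok: prints 22
  try {
    std::cout << comms.at(5) << "\n";  // out of range: at() throws
  } catch (const std::out_of_range& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}
```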