提交 c64190ec 编写于 作者: Y Yu Yang

Polish NCCLHelper

上级 7483555a
......@@ -61,7 +61,7 @@ struct NCCLContext {
ncclComm_t comm_;
explicit NCCLContext(int dev_id)
: ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {}
: ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}
cudaStream_t stream() const { return ctx_->stream(); }
......@@ -95,6 +95,7 @@ struct NCCLContextMap {
std::vector<int> order_;
explicit NCCLContextMap(const std::vector<platform::Place> &places) {
PADDLE_ENFORCE(!places.empty());
order_.reserve(places.size());
for (auto &p : places) {
int dev_id = boost::get<CUDAPlace>(p).device;
......@@ -105,15 +106,17 @@ struct NCCLContextMap {
order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device");
std::vector<ncclComm_t> comms;
comms.resize(order_.size());
if (places.size() > 1) {
std::vector<ncclComm_t> comms;
comms.resize(order_.size());
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(order_.size()), &order_[0]));
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(order_.size()), &order_[0]));
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册