提交 c64190ec 编写于 作者: Y Yu Yang

Polish NCCLHelper

上级 7483555a
...@@ -61,7 +61,7 @@ struct NCCLContext { ...@@ -61,7 +61,7 @@ struct NCCLContext {
ncclComm_t comm_; ncclComm_t comm_;
explicit NCCLContext(int dev_id) explicit NCCLContext(int dev_id)
: ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {} : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}
cudaStream_t stream() const { return ctx_->stream(); } cudaStream_t stream() const { return ctx_->stream(); }
...@@ -95,6 +95,7 @@ struct NCCLContextMap { ...@@ -95,6 +95,7 @@ struct NCCLContextMap {
std::vector<int> order_; std::vector<int> order_;
explicit NCCLContextMap(const std::vector<platform::Place> &places) { explicit NCCLContextMap(const std::vector<platform::Place> &places) {
PADDLE_ENFORCE(!places.empty());
order_.reserve(places.size()); order_.reserve(places.size());
for (auto &p : places) { for (auto &p : places) {
int dev_id = boost::get<CUDAPlace>(p).device; int dev_id = boost::get<CUDAPlace>(p).device;
...@@ -105,15 +106,17 @@ struct NCCLContextMap { ...@@ -105,15 +106,17 @@ struct NCCLContextMap {
order_.size(), contexts_.size(), order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device"); "NCCL Context Map does not support contain two or more same device");
std::vector<ncclComm_t> comms; if (places.size() > 1) {
comms.resize(order_.size()); std::vector<ncclComm_t> comms;
comms.resize(order_.size());
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(order_.size()), &order_[0])); &comms[0], static_cast<int>(order_.size()), &order_[0]));
int i = 0; int i = 0;
for (auto &dev_id : order_) { for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++]; contexts_.at(dev_id).comm_ = comms[i++];
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册