diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 29990043206509e4192bfff84832f09ef127d9dd..3a2a423486170320d82aea7ad1e97138c2df7e69 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -14,8 +14,9 @@ #pragma once -#include +#include // NOLINT #include +#include #include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/enforce.h" @@ -29,6 +30,8 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) { return ncclDouble; } else if (type == typeid(int)) { // NOLINT return ncclInt; + } else if (type == typeid(int64_t)) { // NOLINT + return ncclInt64; } else { PADDLE_THROW("Not supported"); } @@ -66,23 +69,23 @@ struct NCCLContext { return boost::get(ctx_->GetPlace()).device; } - static void InitNCCLContext(std::unordered_map &contexts, + static void InitNCCLContext(std::unordered_map *contexts, const std::vector &places) { std::vector comms; std::vector devs; - comms.resize(contexts.size()); - devs.reserve(contexts.size()); + comms.resize(contexts->size()); + devs.reserve(contexts->size()); for (auto &p : places) { devs.push_back(boost::get(p).device); } PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - &comms[0], static_cast(contexts.size()), &devs[0])); + &comms[0], static_cast(contexts->size()), &devs[0])); int i = 0; for (auto &dev_id : devs) { - contexts.at(dev_id).comm_ = comms[i++]; + contexts->at(dev_id).comm_ = comms[i++]; } } }; @@ -91,7 +94,7 @@ struct NCCLContextMap { std::unordered_map contexts_; std::vector order_; - NCCLContextMap(const std::vector &places) { + explicit NCCLContextMap(const std::vector &places) { order_.reserve(places.size()); for (auto &p : places) { int dev_id = boost::get(p).device;