未验证 提交 75227c9e 编写于 作者: L lilong12 提交者: GitHub

Use group id to differentiate keys for TCP store (#41496)

上级 dfb47986
...@@ -110,7 +110,8 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID( ...@@ -110,7 +110,8 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(
std::vector<ncclUniqueId>& nccl_ids) { // NOLINT std::vector<ncclUniqueId>& nccl_ids) { // NOLINT
if (rank_ == 0) { if (rank_ == 0) {
for (size_t i = 0; i < nccl_ids.size(); i++) { for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(i); auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
auto nccl_id = std::vector<uint8_t>( auto nccl_id = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(&nccl_ids[i]), reinterpret_cast<uint8_t*>(&nccl_ids[i]),
reinterpret_cast<uint8_t*>(&nccl_ids[i]) + NCCL_UNIQUE_ID_BYTES); reinterpret_cast<uint8_t*>(&nccl_ids[i]) + NCCL_UNIQUE_ID_BYTES);
...@@ -118,7 +119,8 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID( ...@@ -118,7 +119,8 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(
} }
} else { } else {
for (size_t i = 0; i < nccl_ids.size(); i++) { for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(i); auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
auto ret = store_->get(key); auto ret = store_->get(key);
std::memcpy(&nccl_ids[i], ret.data(), ret.size()); std::memcpy(&nccl_ids[i], ret.data(), ret.size());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册