未验证 提交 7a724ddb 编写于 作者: yaoxuefeng 提交者: GitHub

fix multi-node (#36329)

上级 414c252a
......@@ -117,6 +117,15 @@ class PSGPUWrapper {
resource_ = std::make_shared<HeterPsResource>(dev_ids);
resource_->enable_p2p();
keys_tensor.resize(resource_->total_gpu());
#ifdef PADDLE_WITH_GLOO
auto gloo = paddle::framework::GlooWrapper::GetInstance();
if (gloo->Size() > 1) {
multi_node_ = 1;
}
#else
PADDLE_THROW(
platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
if (multi_node_) {
int dev_size = dev_ids.size();
// init inner comm
......@@ -127,7 +136,6 @@ class PSGPUWrapper {
// init inter comm
#ifdef PADDLE_WITH_GLOO
inter_comms_.resize(dev_size);
auto gloo = paddle::framework::GlooWrapper::GetInstance();
if (gloo->Rank() == 0) {
for (int i = 0; i < dev_size; ++i) {
platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]);
......
......@@ -148,7 +148,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
paddle::platform::errors::InvalidArgument(
"dev ids = [%d], it should greater than 0.", dev_ids.size()));
const int kDevices = dev_ids.size();
VLOG(3) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices
VLOG(1) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices
<< ", ntrainers: " << ntrainers << ", train_id: " << train_id
<< ", rind_id: " << ring_id;
ncclComm_t comms[kDevices];
......@@ -162,10 +162,10 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
#endif
platform::dynload::ncclCommInitRank(comms + i, kDevices * ntrainers,
*nccl_id, train_id * kDevices + i);
VLOG(3) << "ncclCommInitRank: " << i;
VLOG(1) << "ncclCommInitRank: " << i;
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupEnd());
VLOG(3) << "nccl group end seccessss";
VLOG(1) << "nccl group end seccessss";
}
PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
......@@ -174,7 +174,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
for (int i = 0; i < kDevices; ++i) {
AssignNCCLComm(comms[i], kDevices * ntrainers, train_id * kDevices + i,
dev_ids[i], ring_id);
VLOG(3) << "nccl communicator of train_id " << train_id * kDevices + i
VLOG(1) << "nccl communicator of train_id " << train_id * kDevices + i
<< " in ring " << ring_id << " has been created on device "
<< dev_ids[i];
}
......
......@@ -396,6 +396,8 @@ class InMemoryDataset(DatasetBase):
Set data_feed_desc
"""
self.proto_desc.name = data_feed_type
if (self.proto_desc.name == "SlotRecordInMemoryDataFeed"):
self.dataset = core.Dataset("SlotRecordDataset")
@deprecated(
since="2.0.0",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册