From 482e5b6cd10171a54b5470cf288fce314d65480a Mon Sep 17 00:00:00 2001 From: lilong12 Date: Fri, 15 Apr 2022 14:12:30 +0800 Subject: [PATCH] update (#41762) --- .../collective/ProcessGroupHeter.cc | 44 +++++++------------ .../distributed/ps/service/heter_client.cc | 5 +-- 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/distributed/collective/ProcessGroupHeter.cc b/paddle/fluid/distributed/collective/ProcessGroupHeter.cc index a48bda06323..354a8e23ae4 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupHeter.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupHeter.cc @@ -105,11 +105,12 @@ std::shared_ptr ProcessGroupHeter::AllReduce( if (local_rank_ == 0) { std::vector cpu_tensors; cpu_tensors.reserve(in_tensors.size()); + phi::DenseTensor cpu_tensor; for (size_t i = 0; i < in_tensors.size(); i++) { auto gpu_tensor = in_tensors[i]; - auto cpu_tensor = cpu_tensors[i]; cpu_tensor.Resize(gpu_tensor.dims()); framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor); + cpu_tensors.push_back(cpu_tensor); } // Step3: do inter cluster allreduce if (with_switch_) { @@ -125,37 +126,32 @@ std::shared_ptr ProcessGroupHeter::AllReduce( framework::DataTypeSize(dense_cpu_tensor.dtype())); PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( "Send to the switch module error.")); - phi::DenseTensorMeta meta = phi::DenseTensorMeta( - dense_cpu_tensor.dtype(), dense_cpu_tensor.dims()); - std::shared_ptr dense_cpu_tensor2 = - std::make_shared( - std::make_unique( - paddle::platform::CPUPlace()) - .get(), - meta); - dense_cpu_tensor2->ResizeAndAllocate(dense_cpu_tensor.dims()); + phi::DenseTensor cpu_tensor2; + cpu_tensor2.AllocateFrom( + std::make_unique( + paddle::platform::CPUPlace()) + .get(), + dense_cpu_tensor.dtype(), dense_cpu_tensor.numel()); ret = client_->Recv( - gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor2->data(), - dense_cpu_tensor2->numel() * - framework::DataTypeSize(dense_cpu_tensor2->dtype())); + gid_, {dense_cpu_tensor.name()}, cpu_tensor2.data(), + cpu_tensor2.numel() * framework::DataTypeSize(cpu_tensor2.dtype())); PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( "Recv from the switch module error.")); switch (dense_cpu_tensor.dtype()) { case DataType::FLOAT32: _do_add(reinterpret_cast(dense_cpu_tensor.data()), - reinterpret_cast(dense_cpu_tensor2->data()), + reinterpret_cast(cpu_tensor2.data()), dense_cpu_tensor.numel()); break; case DataType::FLOAT64: - _do_add( - reinterpret_cast(dense_cpu_tensor.data()), - reinterpret_cast(dense_cpu_tensor2->data()), - dense_cpu_tensor.numel()); + _do_add(reinterpret_cast(dense_cpu_tensor.data()), + reinterpret_cast(cpu_tensor2.data()), + dense_cpu_tensor.numel()); break; case DataType::INT32: _do_add(reinterpret_cast(dense_cpu_tensor.data()), - reinterpret_cast(dense_cpu_tensor2->data()), + reinterpret_cast(cpu_tensor2.data()), dense_cpu_tensor.numel()); break; default: @@ -207,9 +203,10 @@ std::shared_ptr ProcessGroupHeter::Broadcast( cpu_tensors.reserve(in_tensors.size()); for (size_t i = 0; i < in_tensors.size(); i++) { auto gpu_tensor = in_tensors[i]; - auto cpu_tensor = cpu_tensors[i]; + phi::DenseTensor cpu_tensor; cpu_tensor.Resize(gpu_tensor.dims()); framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor); + cpu_tensors.push_back(cpu_tensor); } if (with_switch_) { if (local_rank_ == 0) { @@ -234,13 +231,6 @@ std::shared_ptr ProcessGroupHeter::Broadcast( PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( "Receive from the switch module error.")); - ret = client_->Recv( - gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor.data(), - dense_cpu_tensor.numel() * - framework::DataTypeSize(dense_cpu_tensor.dtype())); - PADDLE_ENFORCE_EQ(ret, 0, - platform::errors::PreconditionNotMet( - "Receive from the switch module error.")); } } } else { diff --git a/paddle/fluid/distributed/ps/service/heter_client.cc b/paddle/fluid/distributed/ps/service/heter_client.cc index 4ca25dac826..16c1ff764dc 100644 --- a/paddle/fluid/distributed/ps/service/heter_client.cc +++ b/paddle/fluid/distributed/ps/service/heter_client.cc @@ -286,8 +286,7 @@ int HeterClient::Send(int group_id, const std::vector& var_names, request.add_vars_len(var_len); } auto& request_buffer = closure->cntl.request_attachment(); - request_buffer.append(reinterpret_cast(data_ptr), - data_size * sizeof(float)); + request_buffer.append(reinterpret_cast(data_ptr), data_size); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); @@ -387,7 +386,7 @@ int HeterClient::Recv(int group_id, const std::vector& var_names, if (xpu_channels_.size() < 2) { LOG(ERROR) << "xpu_channels_ is null"; } - recv_switch_channels_.push_back(xpu_channels_[1]); + recv_switch_channels_.push_back(xpu_channels_[0]); } brpc::Channel* channel = recv_switch_channels_[0].get(); ::paddle::distributed::PsService_Stub stub(channel); -- GitLab