未验证 提交 482e5b6c 编写于 作者: L lilong12 提交者: GitHub

update (#41762)

上级 30a1213b
...@@ -105,11 +105,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce( ...@@ -105,11 +105,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
if (local_rank_ == 0) { if (local_rank_ == 0) {
std::vector<phi::DenseTensor> cpu_tensors; std::vector<phi::DenseTensor> cpu_tensors;
cpu_tensors.reserve(in_tensors.size()); cpu_tensors.reserve(in_tensors.size());
phi::DenseTensor cpu_tensor;
for (size_t i = 0; i < in_tensors.size(); i++) { for (size_t i = 0; i < in_tensors.size(); i++) {
auto gpu_tensor = in_tensors[i]; auto gpu_tensor = in_tensors[i];
auto cpu_tensor = cpu_tensors[i];
cpu_tensor.Resize(gpu_tensor.dims()); cpu_tensor.Resize(gpu_tensor.dims());
framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor); framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor);
cpu_tensors.push_back(cpu_tensor);
} }
// Step3: do inter cluster allreduce // Step3: do inter cluster allreduce
if (with_switch_) { if (with_switch_) {
...@@ -125,37 +126,32 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce( ...@@ -125,37 +126,32 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
framework::DataTypeSize(dense_cpu_tensor.dtype())); framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"Send to the switch module error.")); "Send to the switch module error."));
phi::DenseTensorMeta meta = phi::DenseTensorMeta( phi::DenseTensor cpu_tensor2;
dense_cpu_tensor.dtype(), dense_cpu_tensor.dims()); cpu_tensor2.AllocateFrom(
std::shared_ptr<phi::DenseTensor> dense_cpu_tensor2 = std::make_unique<paddle::experimental::DefaultAllocator>(
std::make_shared<phi::DenseTensor>( paddle::platform::CPUPlace())
std::make_unique<paddle::experimental::DefaultAllocator>( .get(),
paddle::platform::CPUPlace()) dense_cpu_tensor.dtype(), dense_cpu_tensor.numel());
.get(),
meta);
dense_cpu_tensor2->ResizeAndAllocate(dense_cpu_tensor.dims());
ret = client_->Recv( ret = client_->Recv(
gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor2->data(), gid_, {dense_cpu_tensor.name()}, cpu_tensor2.data(),
dense_cpu_tensor2->numel() * cpu_tensor2.numel() * framework::DataTypeSize(cpu_tensor2.dtype()));
framework::DataTypeSize(dense_cpu_tensor2->dtype()));
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"Recv from the switch module error.")); "Recv from the switch module error."));
switch (dense_cpu_tensor.dtype()) { switch (dense_cpu_tensor.dtype()) {
case DataType::FLOAT32: case DataType::FLOAT32:
_do_add<float>(reinterpret_cast<float*>(dense_cpu_tensor.data()), _do_add<float>(reinterpret_cast<float*>(dense_cpu_tensor.data()),
reinterpret_cast<float*>(dense_cpu_tensor2->data()), reinterpret_cast<float*>(cpu_tensor2.data()),
dense_cpu_tensor.numel()); dense_cpu_tensor.numel());
break; break;
case DataType::FLOAT64: case DataType::FLOAT64:
_do_add<double>( _do_add<double>(reinterpret_cast<double*>(dense_cpu_tensor.data()),
reinterpret_cast<double*>(dense_cpu_tensor.data()), reinterpret_cast<double*>(cpu_tensor2.data()),
reinterpret_cast<double*>(dense_cpu_tensor2->data()), dense_cpu_tensor.numel());
dense_cpu_tensor.numel());
break; break;
case DataType::INT32: case DataType::INT32:
_do_add<int>(reinterpret_cast<int*>(dense_cpu_tensor.data()), _do_add<int>(reinterpret_cast<int*>(dense_cpu_tensor.data()),
reinterpret_cast<int*>(dense_cpu_tensor2->data()), reinterpret_cast<int*>(cpu_tensor2.data()),
dense_cpu_tensor.numel()); dense_cpu_tensor.numel());
break; break;
default: default:
...@@ -207,9 +203,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast( ...@@ -207,9 +203,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
cpu_tensors.reserve(in_tensors.size()); cpu_tensors.reserve(in_tensors.size());
for (size_t i = 0; i < in_tensors.size(); i++) { for (size_t i = 0; i < in_tensors.size(); i++) {
auto gpu_tensor = in_tensors[i]; auto gpu_tensor = in_tensors[i];
auto cpu_tensor = cpu_tensors[i]; phi::DenseTensor cpu_tensor;
cpu_tensor.Resize(gpu_tensor.dims()); cpu_tensor.Resize(gpu_tensor.dims());
framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor); framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor);
cpu_tensors.push_back(cpu_tensor);
} }
if (with_switch_) { if (with_switch_) {
if (local_rank_ == 0) { if (local_rank_ == 0) {
...@@ -234,13 +231,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast( ...@@ -234,13 +231,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
PADDLE_ENFORCE_EQ(ret, 0, PADDLE_ENFORCE_EQ(ret, 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Receive from the switch module error.")); "Receive from the switch module error."));
ret = client_->Recv(
gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor.data(),
dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
platform::errors::PreconditionNotMet(
"Receive from the switch module error."));
} }
} }
} else { } else {
......
...@@ -286,8 +286,7 @@ int HeterClient::Send(int group_id, const std::vector<std::string>& var_names, ...@@ -286,8 +286,7 @@ int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
request.add_vars_len(var_len); request.add_vars_len(var_len);
} }
auto& request_buffer = closure->cntl.request_attachment(); auto& request_buffer = closure->cntl.request_attachment();
request_buffer.append(reinterpret_cast<void*>(data_ptr), request_buffer.append(reinterpret_cast<void*>(data_ptr), data_size);
data_size * sizeof(float));
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise); closure->add_promise(promise);
std::future<int> fut = promise->get_future(); std::future<int> fut = promise->get_future();
...@@ -387,7 +386,7 @@ int HeterClient::Recv(int group_id, const std::vector<std::string>& var_names, ...@@ -387,7 +386,7 @@ int HeterClient::Recv(int group_id, const std::vector<std::string>& var_names,
if (xpu_channels_.size() < 2) { if (xpu_channels_.size() < 2) {
LOG(ERROR) << "xpu_channels_ is null"; LOG(ERROR) << "xpu_channels_ is null";
} }
recv_switch_channels_.push_back(xpu_channels_[1]); recv_switch_channels_.push_back(xpu_channels_[0]);
} }
brpc::Channel* channel = recv_switch_channels_[0].get(); brpc::Channel* channel = recv_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel); ::paddle::distributed::PsService_Stub stub(channel);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册