Unverified commit 482e5b6c, authored by lilong12, committed by GitHub

update (#41762)

Parent 30a1213b
@@ -105,11 +105,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
   if (local_rank_ == 0) {
     std::vector<phi::DenseTensor> cpu_tensors;
     cpu_tensors.reserve(in_tensors.size());
+    phi::DenseTensor cpu_tensor;
     for (size_t i = 0; i < in_tensors.size(); i++) {
       auto gpu_tensor = in_tensors[i];
-      auto cpu_tensor = cpu_tensors[i];
       cpu_tensor.Resize(gpu_tensor.dims());
       framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor);
+      cpu_tensors.push_back(cpu_tensor);
     }
     // Step3: do inter cluster allreduce
     if (with_switch_) {
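The loop change above replaces indexing into cpu_tensors with push_back: std::vector::reserve() only sets capacity, so reading cpu_tensors[i] from the still-empty vector was undefined behavior. A minimal standalone sketch of the corrected pattern, in plain C++ with a hypothetical staging vector rather than the Paddle tensor types:

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> staging;
      staging.reserve(4);          // capacity only; size() is still 0
      // staging[0] = 1;           // undefined behavior: no elements exist yet
      for (int i = 0; i < 4; ++i) {
        int value = i * i;         // stands in for the CPU copy of each tensor
        staging.push_back(value);  // push_back actually creates the element
      }
      assert(staging.size() == 4);
      return 0;
    }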
@@ -125,37 +126,32 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
                 framework::DataTypeSize(dense_cpu_tensor.dtype()));
         PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
                                       "Send to the switch module error."));
-        phi::DenseTensorMeta meta = phi::DenseTensorMeta(
-            dense_cpu_tensor.dtype(), dense_cpu_tensor.dims());
-        std::shared_ptr<phi::DenseTensor> dense_cpu_tensor2 =
-            std::make_shared<phi::DenseTensor>(
-                std::make_unique<paddle::experimental::DefaultAllocator>(
-                    paddle::platform::CPUPlace())
-                    .get(),
-                meta);
-        dense_cpu_tensor2->ResizeAndAllocate(dense_cpu_tensor.dims());
+        phi::DenseTensor cpu_tensor2;
+        cpu_tensor2.AllocateFrom(
+            std::make_unique<paddle::experimental::DefaultAllocator>(
+                paddle::platform::CPUPlace())
+                .get(),
+            dense_cpu_tensor.dtype(), dense_cpu_tensor.numel());
         ret = client_->Recv(
-            gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor2->data(),
-            dense_cpu_tensor2->numel() *
-                framework::DataTypeSize(dense_cpu_tensor2->dtype()));
+            gid_, {dense_cpu_tensor.name()}, cpu_tensor2.data(),
+            cpu_tensor2.numel() * framework::DataTypeSize(cpu_tensor2.dtype()));
         PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
                                       "Recv from the switch module error."));
         switch (dense_cpu_tensor.dtype()) {
           case DataType::FLOAT32:
             _do_add<float>(reinterpret_cast<float*>(dense_cpu_tensor.data()),
-                           reinterpret_cast<float*>(dense_cpu_tensor2->data()),
+                           reinterpret_cast<float*>(cpu_tensor2.data()),
                            dense_cpu_tensor.numel());
             break;
           case DataType::FLOAT64:
-            _do_add<double>(
-                reinterpret_cast<double*>(dense_cpu_tensor.data()),
-                reinterpret_cast<double*>(dense_cpu_tensor2->data()),
-                dense_cpu_tensor.numel());
+            _do_add<double>(reinterpret_cast<double*>(dense_cpu_tensor.data()),
+                            reinterpret_cast<double*>(cpu_tensor2.data()),
+                            dense_cpu_tensor.numel());
             break;
           case DataType::INT32:
             _do_add<int>(reinterpret_cast<int*>(dense_cpu_tensor.data()),
-                         reinterpret_cast<int*>(dense_cpu_tensor2->data()),
+                         reinterpret_cast<int*>(cpu_tensor2.data()),
                          dense_cpu_tensor.numel());
             break;
           default:
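The dtype switch above dispatches an element-wise accumulate that folds the buffer received from the switch module into the local allreduce result. A rough standalone sketch of what a _do_add<T>-style helper does (a hypothetical stand-in, not the Paddle implementation):

    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-in for the _do_add<T> helper referenced above:
    // accumulate `count` elements of `src` into `dst` in place.
    template <typename T>
    void do_add(T* dst, const T* src, std::size_t count) {
      for (std::size_t i = 0; i < count; ++i) {
        dst[i] += src[i];
      }
    }

    int main() {
      float local[3] = {1.f, 2.f, 3.f};        // result of the inner-cluster allreduce
      float received[3] = {10.f, 20.f, 30.f};  // buffer received from the switch module
      do_add(local, received, 3);
      std::printf("%g %g %g\n", local[0], local[1], local[2]);  // prints: 11 22 33
      return 0;
    }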
@@ -207,9 +203,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
     cpu_tensors.reserve(in_tensors.size());
     for (size_t i = 0; i < in_tensors.size(); i++) {
       auto gpu_tensor = in_tensors[i];
-      auto cpu_tensor = cpu_tensors[i];
+      phi::DenseTensor cpu_tensor;
       cpu_tensor.Resize(gpu_tensor.dims());
       framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor);
+      cpu_tensors.push_back(cpu_tensor);
     }
     if (with_switch_) {
       if (local_rank_ == 0) {
@@ -234,13 +231,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
           PADDLE_ENFORCE_EQ(ret, 0,
                             platform::errors::PreconditionNotMet(
                                 "Receive from the switch module error."));
-          ret = client_->Recv(
-              gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor.data(),
-              dense_cpu_tensor.numel() *
-                  framework::DataTypeSize(dense_cpu_tensor.dtype()));
-          PADDLE_ENFORCE_EQ(ret, 0,
-                            platform::errors::PreconditionNotMet(
-                                "Receive from the switch module error."));
         }
       }
     } else {
@@ -286,8 +286,7 @@ int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
     request.add_vars_len(var_len);
   }
   auto& request_buffer = closure->cntl.request_attachment();
-  request_buffer.append(reinterpret_cast<void*>(data_ptr),
-                        data_size * sizeof(float));
+  request_buffer.append(reinterpret_cast<void*>(data_ptr), data_size);
   auto promise = std::make_shared<std::promise<int32_t>>();
   closure->add_promise(promise);
   std::future<int> fut = promise->get_future();
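With this change, data_size is forwarded to the request attachment unchanged, so the caller is expected to pass a byte count (for example numel() * framework::DataTypeSize(dtype), as in the AllReduce path above); the removed multiplication by sizeof(float) would have inflated an already-byte-sized length and assumed float data. A minimal sketch of the byte-count convention, using std::string in place of the RPC attachment buffer:

    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <vector>

    int main() {
      // Payload of doubles; the caller computes the length in bytes, not elements.
      std::vector<double> payload = {1.0, 2.0, 3.0};
      const void* data_ptr = payload.data();
      const std::int64_t data_size =
          static_cast<std::int64_t>(payload.size() * sizeof(double));  // byte count

      std::string request_buffer;  // stands in for the RPC request attachment
      request_buffer.append(static_cast<const char*>(data_ptr),
                            static_cast<std::size_t>(data_size));  // append exactly data_size bytes

      return request_buffer.size() == payload.size() * sizeof(double) ? 0 : 1;
    }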
@@ -387,7 +386,7 @@ int HeterClient::Recv(int group_id, const std::vector<std::string>& var_names,
     if (xpu_channels_.size() < 2) {
       LOG(ERROR) << "xpu_channels_ is null";
     }
-    recv_switch_channels_.push_back(xpu_channels_[1]);
+    recv_switch_channels_.push_back(xpu_channels_[0]);
   }
   brpc::Channel* channel = recv_switch_channels_[0].get();
   ::paddle::distributed::PsService_Stub stub(channel);