未验证 提交 3e8b6bbc 编写于 作者: B Baibaifan 提交者: GitHub

fix_nccl_barrier (#41970)

上级 9b54bf93
......@@ -353,21 +353,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) {
std::vector<phi::GPUPlace> places;
if (!opts.place_ids.empty()) {
for (auto place_id : opts.place_ids) {
places.emplace_back(place_id);
}
} else if (!used_place_ids_.empty()) {
for (auto place_id : used_place_ids_) {
places.emplace_back(place_id);
}
} else {
auto numGPUs = GetSize();
int place_id = static_cast<int>(rank_ % numGPUs);
places.emplace_back(place_id);
}
// Only support single card single process
std::vector<phi::GPUPlace> places = {place_};
std::vector<phi::DenseTensor> barrierTensors;
barrierTensors.reserve(places.size());
......@@ -375,7 +362,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) {
gpuGuard.SetDeviceIndex(place.GetDeviceId());
auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::GPUPlace());
auto dt = full({1}, 0, phi::DataType::FLOAT32, place);
barrierTensors.push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(dt.impl()));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册