未验证 提交 e4ebf383 编写于 作者: W Wen Sun 提交者: GitHub

Update `ProcessGroupCustom` for `sync_op` compatibility (#47976)

* refactor: update pg custom

* fix: use new api in ut

* fix: typo

* revert: recover legacy apis

* fix: add GetDeviceContext
上级 39c85064
......@@ -202,38 +202,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllGather(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER);
}
void* XcclGetPointerByOffset(void* raw_pointer,
size_t offset,
experimental::DataType type) {
......@@ -259,13 +227,13 @@ void* XcclGetPointerByOffset(void* raw_pointer,
return nullptr;
}
// NOTE: this is ONLY for compatibility
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op) {
bool sync_op // for compatibility, no use now
) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
......@@ -287,6 +255,105 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
CommType::ALLGATHER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op // for compatibility, no use now
) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return AllReduce(in_wrapper, out_wrapper, opts);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op // for compatibility, no use now
) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Broadcast(in_wrapper, out_wrapper, opts);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
const BarrierOptions& opts) {
// Only support single card single process
PADDLE_ENFORCE_GE(opts.device_id,
0,
platform::errors::PreconditionNotMet(
"The barrier device id must greater or equal than 0."));
platform::CustomPlace place(device_type_, opts.device_id);
auto allocator = std::unique_ptr<phi::Allocator>(
new paddle::experimental::DefaultAllocator(place));
phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
phi::DenseTensor barrier_tensor{allocator.get(), meta};
auto task = ProcessGroupCustom::AllReduce(&barrier_tensor,
barrier_tensor,
{},
/*sync_op*/ true);
auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
xccl_task->barrierTensors_ = {barrier_tensor};
return task;
}
const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
const Place& place) const {
const std::string key = GetKeyFromPlace(place);
const auto& iter = places_to_ctx_.find(key);
PADDLE_ENFORCE_NE(
iter,
places_to_ctx_.end(),
platform::errors::NotFound(
"Cannot find the device context in this process group."));
return *iter->second[0];
}
phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
std::vector<Place> places = {place};
const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
PADDLE_ENFORCE_NE(iter,
places_to_customcomm_.end(),
platform::errors::InvalidArgument(
"Cannot find nccl comm in process group."));
return iter->second[0]->GetCustomCCLComm();
}
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllGather(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......@@ -366,40 +433,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
CommType::BROADCAST);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
const BarrierOptions& opts) {
// Only support single card single process
PADDLE_ENFORCE_GE(opts.device_id,
0,
platform::errors::PreconditionNotMet(
"The barrier device id must greater or equal than 0."));
platform::CustomPlace place(device_type_, opts.device_id);
std::vector<phi::CustomPlace> places = {place};
std::vector<phi::DenseTensor> barrierTensors;
barrierTensors.reserve(places.size());
for (auto& place : places) {
phi::DeviceGuard guard(place);
phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
auto allocator = std::unique_ptr<phi::Allocator>(
new paddle::experimental::DefaultAllocator(place));
barrierTensors.emplace_back(allocator.get(), meta);
}
auto task = ProcessGroupCustom::AllReduce(barrierTensors, barrierTensors);
auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
xccl_task->barrierTensors_ = std::move(barrierTensors);
return task;
}
phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
std::vector<Place> places = {place};
const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
PADDLE_ENFORCE_NE(iter,
places_to_customcomm_.end(),
platform::errors::InvalidArgument(
"Cannot find nccl comm in process group."));
return iter->second[0]->GetCustomCCLComm();
}
} // namespace distributed
} // namespace paddle
......@@ -78,6 +78,26 @@ class ProcessGroupCustom : public ProcessGroup {
int64_t numel,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
......@@ -92,11 +112,6 @@ class ProcessGroupCustom : public ProcessGroup {
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
protected:
virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
std::vector<Place> places,
......
......@@ -63,11 +63,11 @@ class TestProcessGroupFp32(unittest.TestCase):
sum_result = tensor_x + tensor_y
if pg.rank() == 0:
task = pg.allreduce(tensor_x)
task = pg.all_reduce(tensor_x, core.ReduceOp.SUM, sync_op=True)
task.wait()
# assert np.array_equal(tensor_x, sum_result)
else:
task = pg.allreduce(tensor_y)
task = pg.all_reduce(tensor_y, core.ReduceOp.SUM, sync_op=True)
task.wait()
# assert np.array_equal(tensor_y, sum_result)
......@@ -81,11 +81,11 @@ class TestProcessGroupFp32(unittest.TestCase):
max_result = paddle.maximum(tensor_x, tensor_y)
if pg.rank() == 0:
task = pg.allreduce(tensor_x, core.ReduceOp.MAX)
task = pg.all_reduce(tensor_x, core.ReduceOp.MAX, sync_op=True)
task.wait()
# assert np.array_equal(tensor_x, max_result)
else:
task = pg.allreduce(tensor_y, core.ReduceOp.MAX)
task = pg.all_reduce(tensor_y, core.ReduceOp.MAX, sync_op=True)
task.wait()
# assert np.array_equal(tensor_y, max_result)
......@@ -101,14 +101,14 @@ class TestProcessGroupFp32(unittest.TestCase):
broadcast_result = paddle.assign(tensor_x)
if pg.rank() == 0:
task = pg.broadcast(tensor_x, 0)
task.synchronize()
task = pg.broadcast(tensor_x, 0, sync_op=True)
task.wait()
# paddle.fluid.core._custom_device_synchronize("custom_cpu", -1)
assert task.is_completed()
# assert np.array_equal(broadcast_result, tensor_x)
else:
task = pg.broadcast(tensor_y, 0)
task.synchronize()
task = pg.broadcast(tensor_y, 0, sync_op=True)
task.wait()
# paddle.fluid.core._custom_device_synchronize("custom_cpu", -1)
assert task.is_completed()
# assert np.array_equal(broadcast_result, tensor_y)
......@@ -139,12 +139,12 @@ class TestProcessGroupFp32(unittest.TestCase):
out = np.random.random(out_shape).astype(self.dtype)
tensor_out = paddle.to_tensor(out)
if pg.rank() == 0:
task = pg.all_gather(tensor_x, tensor_out)
task = pg.all_gather(tensor_out, tensor_x, sync_op=True)
task.wait()
# paddle.fluid.core._custom_device_synchronize("custom_cpu", -1)
# rank 1
else:
task = pg.all_gather(tensor_y, tensor_out)
task = pg.all_gather(tensor_out, tensor_y, sync_op=True)
task.wait()
# paddle.fluid.core._custom_device_synchronize("custom_cpu", -1)
out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册