Unverified commit 4e00d2bb, authored by Baibaifan and committed by GitHub

add_new_comm_primitive (#40040)

Parent commit: aa47297a
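This commit extends the ProcessGroup API with barrier, send, and recv primitives, implements them in ProcessGroupNCCL, exposes them through pybind, and exercises them in the process-group unit test. The sketch below is adapted from that test and only illustrates how the new calls are driven from Python: it assumes a two-rank ProcessGroupNCCL instance pg has already been created (the test's setup code is collapsed out of this diff), and the shape and dtype are placeholders.

# Minimal usage sketch, assuming a two-rank ProcessGroupNCCL `pg` already
# exists (its construction sits in the collapsed part of the test) and each
# rank owns one GPU. Shape and dtype are illustrative placeholders.
import numpy as np
import paddle

# Barrier: every rank submits a barrier task and waits for it to finish.
task = pg.barrier()
task.wait()

# Point-to-point: rank 0 sends a tensor to rank 1; rank 1 receives into a
# tensor of the same shape and dtype.
data = np.random.random([2, 2]).astype("float32")
tensor = paddle.to_tensor(data)
if pg.rank() == 0:
    task = pg.send(tensor, dst=1)
else:
    task = pg.recv(tensor, src=0)
task.wait()
paddle.device.cuda.synchronize()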
@@ -96,7 +96,25 @@ class ProcessGroup {
std::vector<Tensor>& /* tensors */,
const BroadcastOptions& = BroadcastOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce", GetBackendName()));
"ProcessGroup%s does not support broadcast", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<Tensor>& tensors /* tensors */, int dst_rank) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<Tensor>& tensors /* tensors */, int src_rank) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
}
protected:
......
@@ -14,6 +14,9 @@
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/place.h"
DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);
@@ -139,6 +142,14 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
}
}
if (!barrierTensors_.empty()) {
// If this task is used as a barrier, block the CPU until the GPU work completes
for (auto& place : places_) {
platform::CUDADeviceGuard gpuGuard(place);
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
}
}
return true;
}
@@ -193,6 +204,10 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
nccl_ids.resize(1);
auto& nccl_id = nccl_ids.front();
for (auto& place : places) {
used_place_ids_.insert(place.GetDeviceId());
}
if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
}
@@ -274,6 +289,54 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
return task;
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::vector<Tensor>& tensors, Fn fn, int dst_rank, CommType op_type) {
const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
CreateNCCLManagerCache(key, places);
}
}
auto& nccl_comms = places_to_ncclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, tensors);
// construct an uninitialized device guard
platform::CUDADeviceGuard cuda_guard;
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
memory::RecordStream(dense_tensor->Holder(),
places_to_ctx_[key][i]->stream());
}
}
{
platform::NCCLGroupGuard nccl_guard;
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream();
fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
}
}
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
@@ -317,5 +380,98 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
CommType::BROADCAST);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) {
std::vector<phi::GPUPlace> places;
if (!opts.place_ids.empty()) {
for (auto place_id : opts.place_ids) {
places.emplace_back(place_id);
}
} else if (!used_place_ids_.empty()) {
for (auto place_id : used_place_ids_) {
places.emplace_back(place_id);
}
} else {
auto numGPUs = GetSize();
int place_id = static_cast<int>(rank_ % numGPUs);
places.emplace_back(place_id);
}
std::vector<Tensor> barrierTensors;
barrierTensors.reserve(places.size());
platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) {
gpuGuard.SetDeviceIndex(place.GetDeviceId());
auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::Backend::GPU);
barrierTensors.push_back(dt);
}
auto task = ProcessGroupNCCL::AllReduce(barrierTensors);
auto nccl_task = dynamic_cast<ProcessGroupNCCL::NCCLTask*>(task.get());
nccl_task->barrierTensors_ = std::move(barrierTensors);
return task;
}
void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
const size_t num_devices) {
PADDLE_ENFORCE_EQ(
tensors.size() == 0, false,
platform::errors::InvalidArgument("Tensor list must be nonempty."));
PADDLE_ENFORCE_LE(
tensors.size(), num_devices,
platform::errors::InvalidArgument(
"Tensor list mustn't be larger than the number of available GPUs."));
std::set<Place> used_devices;
for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(t.is_cuda() && t.is_dense_tensor(), true,
platform::errors::InvalidArgument(
"Tensors must be CUDA and dense tensor."));
const auto inserted = used_devices.insert(t.inner_place()).second;
PADDLE_ENFORCE_EQ(inserted, true,
platform::errors::InvalidArgument(
"Tensors must be on distinct GPU devices."));
}
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
std::vector<Tensor>& tensors, int dst_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](Tensor& input, ncclComm_t comm, const gpuStream_t& stream,
int dst_rank) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
return platform::dynload::ncclSend(
input_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()), dst_rank, comm, stream);
},
dst_rank, CommType::SEND);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<Tensor>& tensors, int src_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](Tensor& output, ncclComm_t comm, const gpuStream_t& stream,
int src_rank) {
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclRecv(
output_tensor->data(), output_tensor->numel(),
platform::ToNCCLDataType(output.type()), src_rank, comm, stream);
},
src_rank, CommType::RECV);
return task;
}
} // namespace distributed
} // namespace paddle
@@ -65,6 +65,7 @@ class ProcessGroupNCCL : public ProcessGroup {
virtual ~NCCLTask();
std::vector<EventManager> control_events_;
std::vector<Tensor> barrierTensors_;
protected:
std::vector<Place> places_;
@@ -88,6 +89,15 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<Tensor>& tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(std::vector<Tensor>& tensors,
int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(std::vector<Tensor>& tensors,
int src_rank) override;
protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
@@ -106,6 +116,8 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<std::unique_ptr<CUDADeviceContext>>>
places_to_ctx_;
std::set<int> used_place_ids_;
private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root, // NOLINT
int server_fd);
@@ -118,6 +130,11 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<Tensor>& outputs, // NOLINT
Fn fn, CommType op_type);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<Tensor>& tensors, // NOLINT
Fn fn, int dst_rank, CommType op_type);
void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
};
......
@@ -32,5 +32,9 @@ struct BroadcastOptions {
int source_root = 0;
};
struct BarrierOptions {
std::vector<int> place_ids;
};
} // namespace distributed
} // namespace paddle
@@ -60,6 +60,10 @@ void BindDistributed(py::module *m) {
.def_readwrite("source_root",
&distributed::BroadcastOptions::source_root);
py::class_<distributed::BarrierOptions>(*m, "BarrierOptions")
.def(py::init<>())
.def_readwrite("place_ids", &distributed::BarrierOptions::place_ids);
auto ProcessGroup =
py::class_<distributed::ProcessGroup,
std::shared_ptr<distributed::ProcessGroup>>(*m, "ProcessGroup")
@@ -88,6 +92,35 @@ void BindDistributed(py::module *m) {
return self.Broadcast(tensors, opts);
},
py::arg("tensor"), py::arg("source_rank"),
py::call_guard<py::gil_scoped_release>())
.def("barrier",
[](distributed::ProcessGroup &self, std::vector<int> place_ids) {
distributed::BarrierOptions opts;
opts.place_ids = place_ids;
return self.Barrier(opts);
},
py::arg("place_ids") = std::vector<int>{},
py::call_guard<py::gil_scoped_release>())
.def("send",
[](distributed::ProcessGroup &self, py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
std::vector<Tensor> tensors = {tensor};
return self.Send(tensors, dst);
},
py::arg("tensor"), py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def("recv",
[](distributed::ProcessGroup &self, py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
std::vector<Tensor> tensors = {tensor};
return self.Recv(tensors, src);
},
py::arg("tensor"), py::arg("src"),
py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_NCCL)
......
@@ -132,6 +132,36 @@ class TestProcessGroupFp32(unittest.TestCase):
print("test broadcast api ok")
# test barrier
# rank 0
if pg.rank() == 0:
task = pg.barrier()
task.wait()
# rank 1
else:
task = pg.barrier()
task.wait()
print("test barrier api ok\n")
# test send/recv
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
if pg.rank() == 0:
task = pg.send(tensor_x, dst=1)
task.wait()
paddle.device.cuda.synchronize()
# rank 1
else:
y = np.random.random(self.shape).astype(self.dtype)
tensor_y = paddle.to_tensor(y)
task = pg.recv(tensor_y, src=0)
task.wait()
paddle.device.cuda.synchronize()
assert np.array_equal(tensor_x, tensor_y)
print("test send/recv api ok\n")
class TestProcessGroupFp16(TestProcessGroupFp32):
def setUp(self):
......
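For reference, the barrier binding above also accepts an optional place_ids list (from BarrierOptions). When it is omitted, ProcessGroupNCCL falls back to the device IDs recorded in used_place_ids_ during CreateNCCLManagerCache, and finally to rank_ % GetSize(). A hedged sketch of pinning the barrier to an explicit device, assuming pg is an initialized ProcessGroupNCCL whose communicator for this rank lives on GPU 0:

# Illustrative only: `pg` is assumed to be an initialized ProcessGroupNCCL and
# GPU 0 is assumed to be this rank's device. place_ids selects the devices on
# which the temporary barrier tensors are allocated.
task = pg.barrier(place_ids=[0])
task.wait()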