From 05df6973b128b350f149e579496c0e29e95ff3ce Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Tue, 10 Jan 2023 21:50:18 +0800 Subject: [PATCH] Use `CommContextManager` to init comm op using gloo backend (#49666) * refactor: gloo comm context migration * fix: headers & avoid mutable_data usage * fix: cmake gloo dep * style: rename funcs * refactor: move to new files * fix: gloo deps * refactor: simplify create device --- .../fluid/operators/collective/CMakeLists.txt | 2 +- .../operators/collective/c_broadcast_op.h | 39 ++++--- paddle/fluid/pybind/communication.cc | 6 + paddle/phi/core/distributed/CMakeLists.txt | 14 +++ .../core/distributed/comm_context_manager.cc | 26 +++++ .../core/distributed/comm_context_manager.h | 7 ++ .../phi/core/distributed/gloo_comm_context.cc | 52 +++++++++ .../phi/core/distributed/gloo_comm_context.h | 47 ++++++++ paddle/phi/core/distributed/gloo_utils.cc | 94 ++++++++++++++++ paddle/phi/core/distributed/gloo_utils.h | 106 ++++++++++++++++++ .../phi/core/distributed/store/CMakeLists.txt | 7 ++ .../phi/core/distributed/store/gloo_store.cc | 47 ++++++++ .../phi/core/distributed/store/gloo_store.h | 48 ++++++++ python/paddle/distributed/collective.py | 6 +- .../test_collective_broadcast_api.py | 8 ++ 15 files changed, 494 insertions(+), 15 deletions(-) create mode 100644 paddle/phi/core/distributed/gloo_comm_context.cc create mode 100644 paddle/phi/core/distributed/gloo_comm_context.h create mode 100644 paddle/phi/core/distributed/gloo_utils.cc create mode 100644 paddle/phi/core/distributed/gloo_utils.h create mode 100644 paddle/phi/core/distributed/store/gloo_store.cc create mode 100644 paddle/phi/core/distributed/store/gloo_store.h diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index 7221832191..c20200f6be 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -18,7 +18,7 @@ foreach(src ${OPS}) endforeach() if(WITH_GLOO) - set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper) + set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper comm_context_manager) endif() register_operators( diff --git a/paddle/fluid/operators/collective/c_broadcast_op.h b/paddle/fluid/operators/collective/c_broadcast_op.h index 140a438321..8b714c1fe7 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.h +++ b/paddle/fluid/operators/collective/c_broadcast_op.h @@ -21,11 +21,13 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" #if defined(PADDLE_WITH_GLOO) #include #include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#include "paddle/phi/core/distributed/gloo_comm_context.h" #endif namespace paddle { @@ -40,19 +42,30 @@ class CBroadcastOpCPUKernel : public framework::OpKernel { auto out = ctx.Output("Out"); auto root = ctx.Attr("root"); - auto place = ctx.GetPlace(); - int64_t send_numel = in->numel(); - T* recv_buff = out->mutable_data(in->dims(), place); - auto gloo = paddle::framework::GlooWrapper::GetInstance(); - PADDLE_ENFORCE_EQ( - gloo->IsInitialized(), - true, - platform::errors::PreconditionNotMet( - "You must initialize the gloo environment first to use it.")); - gloo::BroadcastOptions opts(gloo->GetContext()); - opts.setOutput(recv_buff, send_numel); - opts.setRoot(root); - gloo::broadcast(opts); + int rid = ctx.Attr("ring_id"); + ctx.device_context().Alloc(out); + + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + if (comm_context_manager.Has(rid)) { + auto* comm_context = static_cast( + comm_context_manager.Get(rid)); + comm_context->Broadcast(out, *in, root); + } else { + // NOTE: This will be removed after moving this operator to phi. + int64_t send_numel = in->numel(); + T* recv_buff = reinterpret_cast(out->data()); + auto gloo = paddle::framework::GlooWrapper::GetInstance(); + PADDLE_ENFORCE_EQ( + gloo->IsInitialized(), + true, + platform::errors::PreconditionNotMet( + "You must initialize the gloo environment first to use it.")); + gloo::BroadcastOptions opts(gloo->GetContext()); + opts.setOutput(recv_buff, send_numel); + opts.setRoot(root); + gloo::broadcast(opts); + } #else PADDLE_THROW(platform::errors::Unavailable( "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON")); diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index 5be2519161..b9b57f4339 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -42,6 +42,12 @@ void BindCommContextManager(py::module *m) { "create_nccl_comm_context", &phi::distributed::CommContextManager::CreateNCCLCommContext, py::call_guard()) +#endif +#if defined(PADDLE_WITH_GLOO) + .def_static( + "create_gloo_comm_context", + &phi::distributed::CommContextManager::CreateGlooCommContext, + py::call_guard()) #endif .def("set_store", &phi::distributed::CommContextManager::SetStore); } diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt index 92a5f10787..4e0794e042 100644 --- a/paddle/phi/core/distributed/CMakeLists.txt +++ b/paddle/phi/core/distributed/CMakeLists.txt @@ -10,6 +10,20 @@ if(WITH_NCCL OR WITH_RCCL) list(APPEND COMM_CONTEXT_MANAGER_DEPS nccl_comm_context) endif() +if(WITH_GLOO) + cc_library( + gloo_utils + SRCS gloo_utils.cc + DEPS gloo dense_tensor enforce) + + cc_library( + gloo_comm_context + SRCS gloo_comm_context.cc + DEPS gloo_utils) + + list(APPEND COMM_CONTEXT_MANAGER_DEPS gloo_comm_context gloo_store) +endif() + cc_library( comm_context_manager SRCS comm_context_manager.cc diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 818952736e..7ad44f29f4 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -12,6 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#if defined(PADDLE_WITH_GLOO) +#include "gloo/rendezvous/prefix_store.h" + +#include "paddle/phi/core/distributed/gloo_comm_context.h" +#include "paddle/phi/core/distributed/gloo_utils.h" +#include "paddle/phi/core/distributed/store/gloo_store.h" +#endif + #include "paddle/phi/core/distributed/comm_context_manager.h" #include @@ -60,6 +68,24 @@ void CommContextManager::CreateNCCLCommContext( } #endif +#if defined(PADDLE_WITH_GLOO) +void CommContextManager::CreateGlooCommContext( + const std::shared_ptr& store, int ring_id, int rank, int size) { + GlooStore store_wrapper(store); + auto gloo_store = std::make_shared( + std::to_string(ring_id), store_wrapper); + + auto gloo_device = CreateGlooDevice(); + + auto gloo_comm_context = + std::make_unique(rank, size, gloo_store, gloo_device); + auto& comm_context_manager = CommContextManager::GetInstance(); + // set actual store to manager + comm_context_manager.SetStore(store); + comm_context_manager.Emplace(ring_id, std::move(gloo_comm_context)); +} +#endif + CommContext* CommContextManager::Emplace( int ring_id, std::unique_ptr comm_context) { PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index 5d57856eee..ed77c5ac9e 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -52,6 +52,13 @@ class CommContextManager { int size); #endif +#if defined(PADDLE_WITH_GLOO) + static void CreateGlooCommContext(const std::shared_ptr& store, + int ring_id, + int rank, + int size); +#endif + private: DISABLE_COPY_AND_ASSIGN(CommContextManager); diff --git a/paddle/phi/core/distributed/gloo_comm_context.cc b/paddle/phi/core/distributed/gloo_comm_context.cc new file mode 100644 index 0000000000..d51db3bee8 --- /dev/null +++ b/paddle/phi/core/distributed/gloo_comm_context.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/gloo_comm_context.h" + +#include "gloo/broadcast.h" +#include "gloo/types.h" + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/gloo_utils.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { +namespace distributed { + +GlooCommContext::GlooCommContext( + int rank, + int size, + std::shared_ptr store, + std::shared_ptr device) + : CommContext(rank, size) { + gloo_context_ = std::make_shared(rank, size); + gloo_context_->connectFullMesh(*store, device); +} + +void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root) { + gloo::BroadcastOptions opts(gloo_context_); + const auto& dtype = in_tensor.dtype(); + GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor); + if (rank_ == root) { + GENERATE_FUNC(dtype, SetInput, &opts, in_tensor); + } + opts.setRoot(root); + gloo::broadcast(opts); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/gloo_comm_context.h b/paddle/phi/core/distributed/gloo_comm_context.h new file mode 100644 index 0000000000..f3bcd7c1e5 --- /dev/null +++ b/paddle/phi/core/distributed/gloo_comm_context.h @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include + +#include "gloo/rendezvous/context.h" +#include "gloo/rendezvous/store.h" +#include "gloo/transport/tcp/device.h" + +#include "paddle/phi/core/distributed/comm_context.h" +#include "paddle/phi/core/macros.h" + +namespace phi { +class DenseTensor; +namespace distributed { + +class GlooCommContext final : public CommContext { + public: + GlooCommContext(int rank, + int size, + std::shared_ptr store, + std::shared_ptr device); + + void Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root); + + private: + DISABLE_COPY_AND_ASSIGN(GlooCommContext); + + std::shared_ptr gloo_context_; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/gloo_utils.cc b/paddle/phi/core/distributed/gloo_utils.cc new file mode 100644 index 0000000000..76ef17f0f0 --- /dev/null +++ b/paddle/phi/core/distributed/gloo_utils.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef _WIN32 +#include +#include +#include "gloo/common/win.h" +#else +#include +#include +#include +#endif + +#include +#include + +#include "paddle/phi/core/distributed/gloo_utils.h" +#include "paddle/phi/core/distributed/store/tcp_utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" + +namespace phi { +namespace distributed { +std::shared_ptr CreateDeviceForInterface( + const std::string& ifname) { + gloo::transport::tcp::attr attr; + attr.iface = ifname; + return gloo::transport::tcp::CreateDevice(attr); +} + +std::shared_ptr CreateDeviceForHostname( + const std::string& hostname) { + gloo::transport::tcp::attr attr; + attr.hostname = hostname; + return gloo::transport::tcp::CreateDevice(attr); +} + +std::shared_ptr CreateDefaultDevice() { + std::array hostname; + auto ret = ::gethostname(hostname.data(), HOST_NAME_MAX); + PADDLE_ENFORCE_EQ( + ret, + 0, + phi::errors::Fatal("Get hostname error for createDefaultDevice.")); + ::addrinfo* result; + result = phi::distributed::tcputils::get_addr_info( + hostname.data(), "", 0, AF_UNSPEC); + ::addrinfo* cur; + for (cur = result; cur != nullptr; cur = cur->ai_next) { + phi::distributed::SocketType socket = + ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol); + if (socket == -1) { + continue; + } + ret = ::bind(socket, cur->ai_addr, cur->ai_addrlen); +#ifdef _WIN32 + closesocket(socket); +#else + close(socket); +#endif + if (ret == -1) { + continue; + } + break; + } + freeaddrinfo(result); + if (cur != nullptr) { + return CreateDeviceForHostname(hostname.data()); + } + return CreateDeviceForHostname("127.0.0.1"); +} + +std::shared_ptr CreateGlooDevice() { + char* ifname = std::getenv("GLOO_SOCKET_IFNAME"); + if (ifname && std::strlen(ifname) > 1) { + return CreateDeviceForInterface(std::string(ifname)); + } else { + return CreateDefaultDevice(); + } +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/gloo_utils.h b/paddle/phi/core/distributed/gloo_utils.h new file mode 100644 index 0000000000..3101c7949b --- /dev/null +++ b/paddle/phi/core/distributed/gloo_utils.h @@ -0,0 +1,106 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "gloo/transport/tcp/device.h" +#include "gloo/types.h" + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace distributed { + +// data preparation +#ifdef _WIN32 +#define GENERATE_FUNC(type, func, ...) \ + switch (type) { \ + case phi::DataType::FLOAT32: \ + func(__VA_ARGS__); \ + break; \ + case phi::DataType::FLOAT64: \ + func(__VA_ARGS__); \ + break; \ + case phi::DataType::FLOAT16: \ + func(__VA_ARGS__); \ + break; \ + case phi::DataType::INT32: \ + func(__VA_ARGS__); \ + break; \ + case phi::DataType::INT64: \ + func(__VA_ARGS__); \ + break; \ + default: \ + VLOG(0) << "Error: Unknown DataType."; \ + exit(-1); \ + } +#define HOST_NAME_MAX 256 +#else +#define GENERATE_FUNC(type, func, args...) \ + switch (type) { \ + case phi::DataType::FLOAT32: \ + func(args); \ + break; \ + case phi::DataType::FLOAT64: \ + func(args); \ + break; \ + case phi::DataType::FLOAT16: \ + func(args); \ + break; \ + case phi::DataType::INT32: \ + func(args); \ + break; \ + case phi::DataType::INT64: \ + func(args); \ + break; \ + case phi::DataType::INT8: \ + func(args); \ + break; \ + case phi::DataType::UINT8: \ + func(args); \ + break; \ + case phi::DataType::BOOL: \ + func(args); \ + break; \ + case phi::DataType::BFLOAT16: \ + func(args); \ + break; \ + default: \ + VLOG(0) << "Error: Unknown DataType."; \ + exit(-1); \ + } +#endif + +template +void SetOutput(P* opts, phi::DenseTensor* tensor) { + opts->setOutput(reinterpret_cast(tensor->data()), tensor->numel()); +} + +template +void SetInput(P* opts, const phi::DenseTensor& tensor) { + // gloo only support mutable data input + opts->setInput(reinterpret_cast(const_cast(tensor.data())), + tensor.numel()); +} + +// env preparation +std::shared_ptr CreateGlooDevice(); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/store/CMakeLists.txt b/paddle/phi/core/distributed/store/CMakeLists.txt index ac5c8ae9f5..d6b35eb342 100644 --- a/paddle/phi/core/distributed/store/CMakeLists.txt +++ b/paddle/phi/core/distributed/store/CMakeLists.txt @@ -3,6 +3,13 @@ cc_library( SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc DEPS enforce glog) +if(WITH_GLOO) + cc_library( + gloo_store + SRCS gloo_store.cc + DEPS gloo) +endif() + if(NOT WIN32) cc_test( test_c_tcp_store diff --git a/paddle/phi/core/distributed/store/gloo_store.cc b/paddle/phi/core/distributed/store/gloo_store.cc new file mode 100644 index 0000000000..4da028e55b --- /dev/null +++ b/paddle/phi/core/distributed/store/gloo_store.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/store/gloo_store.h" + +namespace phi { +namespace distributed { + +GlooStore::GlooStore(const std::shared_ptr& store) + : store_(store) {} + +std::vector GlooStore::get(const std::string& key) { + auto value = store_->get(key); + return std::vector(value.begin(), value.end()); +} + +void GlooStore::wait(const std::vector& keys) { + for (auto& key : keys) { + store_->wait(key); + } +} + +void GlooStore::set(const std::string& key, const std::vector& value) { + std::vector tmp(value.begin(), value.end()); + store_->set(key, tmp); +} + +void GlooStore::wait(const std::vector& keys, + const std::chrono::milliseconds& timeout) { + for (auto& key : keys) { + store_->wait(key); + } +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/store/gloo_store.h b/paddle/phi/core/distributed/store/gloo_store.h new file mode 100644 index 0000000000..d785636e2a --- /dev/null +++ b/paddle/phi/core/distributed/store/gloo_store.h @@ -0,0 +1,48 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "gloo/rendezvous/store.h" + +#include "paddle/phi/core/distributed/store/store.h" + +namespace phi { +namespace distributed { + +class GlooStore : public gloo::rendezvous::Store { + public: + explicit GlooStore(const std::shared_ptr& store); + + ~GlooStore() = default; + + std::vector get(const std::string& key) override; + + void wait(const std::vector& keys) override; + + void set(const std::string& key, const std::vector& value) override; + + void wait(const std::vector& keys, + const std::chrono::milliseconds& timeout) override; + + protected: + std::shared_ptr store_; +}; + +} // namespace distributed +} // namespace phi diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 90a5653015..9ffb03760f 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -344,7 +344,11 @@ def _init_parallel_env(backend): is_master, world_size, ) - if backend == "nccl": + if backend == "gloo": + core.CommContextManager.create_gloo_comm_context( + store, 0, rank, world_size + ) + elif backend == "nccl": core.CommContextManager.create_nccl_comm_context( store, dev_id, 0, rank, world_size ) diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py index 6a3fc9ba1e..e1b85242ef 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py @@ -43,6 +43,14 @@ class TestCollectiveBroadcastAPI(TestDistBase): "collective_broadcast_api.py", "broadcast", "gloo", "0" ) + def test_broadcast_gloo_with_comm_context(self): + self.check_with_place( + "collective_broadcast_api.py", + "broadcast", + "gloo", + need_envs={"USE_COMM_CONTEXT": "1"}, + ) + def test_broadcast_nccl_dygraph(self): dtypes_to_test = [ "float16", -- GitLab