未验证 提交 05df6973 编写于 作者: W Wen Sun 提交者: GitHub

Use `CommContextManager` to init comm op using gloo backend (#49666)

* refactor: gloo comm context migration

* fix: headers & avoid mutable_data usage

* fix: cmake gloo dep

* style: rename funcs

* refactor: move to new files

* fix: gloo deps

* refactor: simplify create device
上级 98693428
......@@ -18,7 +18,7 @@ foreach(src ${OPS})
endforeach()
if(WITH_GLOO)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper comm_context_manager)
endif()
register_operators(
......
......@@ -21,11 +21,13 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
namespace paddle {
......@@ -40,9 +42,19 @@ class CBroadcastOpCPUKernel : public framework::OpKernel<T> {
auto out = ctx.Output<phi::DenseTensor>("Out");
auto root = ctx.Attr<int>("root");
auto place = ctx.GetPlace();
int rid = ctx.Attr<int>("ring_id");
ctx.device_context().Alloc<T>(out);
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(rid)) {
auto* comm_context = static_cast<phi::distributed::GlooCommContext*>(
comm_context_manager.Get(rid));
comm_context->Broadcast(out, *in, root);
} else {
// NOTE: This will be removed after moving this operator to phi.
int64_t send_numel = in->numel();
T* recv_buff = out->mutable_data<T>(in->dims(), place);
T* recv_buff = reinterpret_cast<T*>(out->data());
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(),
......@@ -53,6 +65,7 @@ class CBroadcastOpCPUKernel : public framework::OpKernel<T> {
opts.setOutput(recv_buff, send_numel);
opts.setRoot(root);
gloo::broadcast(opts);
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
......
......@@ -42,6 +42,12 @@ void BindCommContextManager(py::module *m) {
"create_nccl_comm_context",
&phi::distributed::CommContextManager::CreateNCCLCommContext,
py::call_guard<py::gil_scoped_release>())
#endif
#if defined(PADDLE_WITH_GLOO)
.def_static(
"create_gloo_comm_context",
&phi::distributed::CommContextManager::CreateGlooCommContext,
py::call_guard<py::gil_scoped_release>())
#endif
.def("set_store", &phi::distributed::CommContextManager::SetStore);
}
......
......@@ -10,6 +10,20 @@ if(WITH_NCCL OR WITH_RCCL)
list(APPEND COMM_CONTEXT_MANAGER_DEPS nccl_comm_context)
endif()
if(WITH_GLOO)
cc_library(
gloo_utils
SRCS gloo_utils.cc
DEPS gloo dense_tensor enforce)
cc_library(
gloo_comm_context
SRCS gloo_comm_context.cc
DEPS gloo_utils)
list(APPEND COMM_CONTEXT_MANAGER_DEPS gloo_comm_context gloo_store)
endif()
cc_library(
comm_context_manager
SRCS comm_context_manager.cc
......
......@@ -12,6 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_GLOO)
#include "gloo/rendezvous/prefix_store.h"
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "paddle/phi/core/distributed/gloo_utils.h"
#include "paddle/phi/core/distributed/store/gloo_store.h"
#endif
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include <memory>
......@@ -60,6 +68,24 @@ void CommContextManager::CreateNCCLCommContext(
}
#endif
#if defined(PADDLE_WITH_GLOO)
void CommContextManager::CreateGlooCommContext(
const std::shared_ptr<Store>& store, int ring_id, int rank, int size) {
GlooStore store_wrapper(store);
auto gloo_store = std::make_shared<gloo::rendezvous::PrefixStore>(
std::to_string(ring_id), store_wrapper);
auto gloo_device = CreateGlooDevice();
auto gloo_comm_context =
std::make_unique<GlooCommContext>(rank, size, gloo_store, gloo_device);
auto& comm_context_manager = CommContextManager::GetInstance();
// set actual store to manager
comm_context_manager.SetStore(store);
comm_context_manager.Emplace(ring_id, std::move(gloo_comm_context));
}
#endif
CommContext* CommContextManager::Emplace(
int ring_id, std::unique_ptr<CommContext> comm_context) {
PADDLE_ENFORCE_EQ(
......
......@@ -52,6 +52,13 @@ class CommContextManager {
int size);
#endif
#if defined(PADDLE_WITH_GLOO)
static void CreateGlooCommContext(const std::shared_ptr<Store>& store,
int ring_id,
int rank,
int size);
#endif
private:
DISABLE_COPY_AND_ASSIGN(CommContextManager);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "gloo/broadcast.h"
#include "gloo/types.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/gloo_utils.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace distributed {
GlooCommContext::GlooCommContext(
int rank,
int size,
std::shared_ptr<gloo::rendezvous::Store> store,
std::shared_ptr<gloo::transport::Device> device)
: CommContext(rank, size) {
gloo_context_ = std::make_shared<gloo::rendezvous::Context>(rank, size);
gloo_context_->connectFullMesh(*store, device);
}
void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root) {
gloo::BroadcastOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
if (rank_ == root) {
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
}
opts.setRoot(root);
gloo::broadcast(opts);
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "gloo/rendezvous/context.h"
#include "gloo/rendezvous/store.h"
#include "gloo/transport/tcp/device.h"
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
namespace phi {
class DenseTensor;
namespace distributed {
class GlooCommContext final : public CommContext {
public:
GlooCommContext(int rank,
int size,
std::shared_ptr<gloo::rendezvous::Store> store,
std::shared_ptr<gloo::transport::Device> device);
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root);
private:
DISABLE_COPY_AND_ASSIGN(GlooCommContext);
std::shared_ptr<gloo::rendezvous::Context> gloo_context_;
};
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#include "gloo/common/win.h"
#else
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <cstdlib>
#include <cstring>
#include "paddle/phi/core/distributed/gloo_utils.h"
#include "paddle/phi/core/distributed/store/tcp_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace phi {
namespace distributed {
std::shared_ptr<gloo::transport::Device> CreateDeviceForInterface(
const std::string& ifname) {
gloo::transport::tcp::attr attr;
attr.iface = ifname;
return gloo::transport::tcp::CreateDevice(attr);
}
std::shared_ptr<gloo::transport::Device> CreateDeviceForHostname(
const std::string& hostname) {
gloo::transport::tcp::attr attr;
attr.hostname = hostname;
return gloo::transport::tcp::CreateDevice(attr);
}
std::shared_ptr<gloo::transport::Device> CreateDefaultDevice() {
std::array<char, HOST_NAME_MAX> hostname;
auto ret = ::gethostname(hostname.data(), HOST_NAME_MAX);
PADDLE_ENFORCE_EQ(
ret,
0,
phi::errors::Fatal("Get hostname error for createDefaultDevice."));
::addrinfo* result;
result = phi::distributed::tcputils::get_addr_info(
hostname.data(), "", 0, AF_UNSPEC);
::addrinfo* cur;
for (cur = result; cur != nullptr; cur = cur->ai_next) {
phi::distributed::SocketType socket =
::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
if (socket == -1) {
continue;
}
ret = ::bind(socket, cur->ai_addr, cur->ai_addrlen);
#ifdef _WIN32
closesocket(socket);
#else
close(socket);
#endif
if (ret == -1) {
continue;
}
break;
}
freeaddrinfo(result);
if (cur != nullptr) {
return CreateDeviceForHostname(hostname.data());
}
return CreateDeviceForHostname("127.0.0.1");
}
std::shared_ptr<gloo::transport::Device> CreateGlooDevice() {
char* ifname = std::getenv("GLOO_SOCKET_IFNAME");
if (ifname && std::strlen(ifname) > 1) {
return CreateDeviceForInterface(std::string(ifname));
} else {
return CreateDefaultDevice();
}
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <climits>
#include <memory>
#include <string>
#include "gloo/transport/tcp/device.h"
#include "gloo/types.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace distributed {
// data preparation
#ifdef _WIN32
#define GENERATE_FUNC(type, func, ...) \
switch (type) { \
case phi::DataType::FLOAT32: \
func<float>(__VA_ARGS__); \
break; \
case phi::DataType::FLOAT64: \
func<double>(__VA_ARGS__); \
break; \
case phi::DataType::FLOAT16: \
func<gloo::float16>(__VA_ARGS__); \
break; \
case phi::DataType::INT32: \
func<int32_t>(__VA_ARGS__); \
break; \
case phi::DataType::INT64: \
func<int64_t>(__VA_ARGS__); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
}
#define HOST_NAME_MAX 256
#else
#define GENERATE_FUNC(type, func, args...) \
switch (type) { \
case phi::DataType::FLOAT32: \
func<float>(args); \
break; \
case phi::DataType::FLOAT64: \
func<double>(args); \
break; \
case phi::DataType::FLOAT16: \
func<gloo::float16>(args); \
break; \
case phi::DataType::INT32: \
func<int32_t>(args); \
break; \
case phi::DataType::INT64: \
func<int64_t>(args); \
break; \
case phi::DataType::INT8: \
func<int8_t>(args); \
break; \
case phi::DataType::UINT8: \
func<uint8_t>(args); \
break; \
case phi::DataType::BOOL: \
func<bool>(args); \
break; \
case phi::DataType::BFLOAT16: \
func<phi::dtype::bfloat16>(args); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
}
#endif
template <typename T, typename P>
void SetOutput(P* opts, phi::DenseTensor* tensor) {
opts->setOutput(reinterpret_cast<T*>(tensor->data()), tensor->numel());
}
template <typename T, typename P>
void SetInput(P* opts, const phi::DenseTensor& tensor) {
// gloo only support mutable data input
opts->setInput(reinterpret_cast<T*>(const_cast<void*>(tensor.data())),
tensor.numel());
}
// env preparation
std::shared_ptr<gloo::transport::Device> CreateGlooDevice();
} // namespace distributed
} // namespace phi
......@@ -3,6 +3,13 @@ cc_library(
SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc
DEPS enforce glog)
if(WITH_GLOO)
cc_library(
gloo_store
SRCS gloo_store.cc
DEPS gloo)
endif()
if(NOT WIN32)
cc_test(
test_c_tcp_store
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/store/gloo_store.h"
namespace phi {
namespace distributed {
GlooStore::GlooStore(const std::shared_ptr<phi::distributed::Store>& store)
: store_(store) {}
std::vector<char> GlooStore::get(const std::string& key) {
auto value = store_->get(key);
return std::vector<char>(value.begin(), value.end());
}
void GlooStore::wait(const std::vector<std::string>& keys) {
for (auto& key : keys) {
store_->wait(key);
}
}
void GlooStore::set(const std::string& key, const std::vector<char>& value) {
std::vector<uint8_t> tmp(value.begin(), value.end());
store_->set(key, tmp);
}
void GlooStore::wait(const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) {
for (auto& key : keys) {
store_->wait(key);
}
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <memory>
#include <vector>
#include "gloo/rendezvous/store.h"
#include "paddle/phi/core/distributed/store/store.h"
namespace phi {
namespace distributed {
class GlooStore : public gloo::rendezvous::Store {
public:
explicit GlooStore(const std::shared_ptr<phi::distributed::Store>& store);
~GlooStore() = default;
std::vector<char> get(const std::string& key) override;
void wait(const std::vector<std::string>& keys) override;
void set(const std::string& key, const std::vector<char>& value) override;
void wait(const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) override;
protected:
std::shared_ptr<phi::distributed::Store> store_;
};
} // namespace distributed
} // namespace phi
......@@ -344,7 +344,11 @@ def _init_parallel_env(backend):
is_master,
world_size,
)
if backend == "nccl":
if backend == "gloo":
core.CommContextManager.create_gloo_comm_context(
store, 0, rank, world_size
)
elif backend == "nccl":
core.CommContextManager.create_nccl_comm_context(
store, dev_id, 0, rank, world_size
)
......@@ -43,6 +43,14 @@ class TestCollectiveBroadcastAPI(TestDistBase):
"collective_broadcast_api.py", "broadcast", "gloo", "0"
)
def test_broadcast_gloo_with_comm_context(self):
self.check_with_place(
"collective_broadcast_api.py",
"broadcast",
"gloo",
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_broadcast_nccl_dygraph(self):
dtypes_to_test = [
"float16",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册