未验证 提交 c16f85f9 编写于 作者: L lilong12 提交者: GitHub

Add the implementation of Gloo for ProcessGroup (#39892)

* add pg_gloo
上级 272b32fd
cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api)
if (WITH_DISTRIBUTE)
cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper)
endif()
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup)
if(WITH_NCCL)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#ifdef _WIN32
#include <gloo/common/win.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <gloo/broadcast.h>
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
#ifdef _WIN32
#define GENERATE_FUNC(type, func, ...) \
switch (type) { \
case experimental::DataType::FLOAT32: \
func<float>(__VA_ARGS__); \
break; \
case experimental::DataType::FLOAT64: \
func<double>(__VA_ARGS__); \
break; \
case experimental::DataType::FLOAT16: \
func<gloo::float16>(__VA_ARGS__); \
break; \
case experimental::DataType::INT32: \
func<int32_t>(__VA_ARGS__); \
break; \
case experimental::DataType::INT64: \
func<int64_t>(__VA_ARGS__); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
}
#define HOST_NAME_MAX 256
#else
#define GENERATE_FUNC(type, func, args...) \
switch (type) { \
case experimental::DataType::FLOAT32: \
func<float>(args); \
break; \
case experimental::DataType::FLOAT64: \
func<double>(args); \
break; \
case experimental::DataType::FLOAT16: \
func<gloo::float16>(args); \
break; \
case experimental::DataType::INT32: \
func<int32_t>(args); \
break; \
case experimental::DataType::INT64: \
func<int64_t>(args); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
}
#endif
typedef void (*reduce_func)(void*, const void*, const void*, size_t);
template <typename T>
reduce_func get_function(const ReduceOp& r) {
switch (r) {
case ReduceOp::SUM:
return reduce_func(&::gloo::sum<T>);
case ReduceOp::PRODUCT:
return reduce_func(&::gloo::product<T>);
case ReduceOp::MIN:
return reduce_func(&::gloo::min<T>);
case ReduceOp::MAX:
return reduce_func(&::gloo::max<T>);
case ReduceOp::AVG:
VLOG(0) << "Error: Unsupported ReduceOp::AVG.";
exit(-1);
}
VLOG(0) << "Error: Unknown ReduceOp.";
exit(-1);
}
bool CheckTensorsInCPUPlace(const std::vector<Tensor>& tensors) {
return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) {
return t.place() == PlaceType::kCPU;
});
}
template <typename T>
T* get_data(const Tensor& tensor) {
auto raw_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
return static_cast<T*>(raw_tensor->data());
}
template <typename T>
std::vector<T*> get_multi_data(const std::vector<Tensor>& tensors) {
std::vector<T*> ret(tensors.size());
for (size_t i = 0; i < tensors.size(); i++) {
ret[i] = get_data<T>(tensors[i]);
}
return ret;
}
template <typename T, typename P>
void set_output(P& opts, const Tensor& tensor) { // NOLINT
opts.setOutput(get_data<T>(tensor), tensor.numel());
}
template <typename T, typename P>
void set_input(P& opts, const Tensor& tensor) { // NOLINT
opts.setInput(get_data<T>(tensor), tensor.numel());
}
template <typename T, typename P>
void set_outputs(P& opts, const std::vector<Tensor>& tensors) { // NOLINT
opts.setOutputs(get_multi_data<T>(tensors), tensors[0].numel());
}
template <typename T, typename P>
void set_inputs(P& opts, const std::vector<Tensor>& tensors) { // NOLINT
opts.setInputs(get_multi_data<T>(tensors), tensors[0].numel());
}
ProcessGroupGloo::GlooTask::GlooTask(int rank,
const std::vector<Tensor>& inputs,
CommType comm_type)
: ProcessGroup::Task(rank, inputs, comm_type) {
PADDLE_ENFORCE_EQ(CheckTensorsInCPUPlace(inputs), true,
platform::errors::Fatal(
"Only CPU place is supported for ProcessGroupGloo."));
}
ProcessGroupGloo::ProcessGroupGloo(const std::shared_ptr<GlooStore>& store,
int rank, int world_size,
const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size), _tag(0), _store(store) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
auto prefix_store =
::gloo::rendezvous::PrefixStore(std::to_string(0), *_store);
_context->connectFullMesh(prefix_store, options->device);
}
class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
public:
BroadcastGlooTask(const std::shared_ptr<gloo::Context>& context,
const std::vector<Tensor>& inputs, int rank, int root,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST),
_context(context),
_root(root),
_inputs(inputs),
_tag(tag) {}
void Run() override { _do_broadcast(_inputs[0]); }
private:
std::shared_ptr<gloo::Context> _context;
const int _root;
std::vector<Tensor> _inputs{};
const uint32_t _tag;
void _do_broadcast(const Tensor& tensor) {
gloo::BroadcastOptions opts(_context);
const auto& dtype = tensor.type();
GENERATE_FUNC(dtype, set_output, opts, tensor);
opts.setRoot(_root);
opts.setTag(_tag);
gloo::broadcast(opts);
}
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
std::vector<Tensor>& inputs, const BroadcastOptions& opts) {
auto root = opts.source_rank;
std::unique_ptr<BroadcastGlooTask> task;
auto tag = next_tag();
auto context = get_context();
task = std::make_unique<BroadcastGlooTask>(context, inputs, rank_, root, tag);
task->Run();
return task;
}
class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
public:
AllreduceGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
std::vector<Tensor>& inputs, ReduceOp reduce_op, // NOLINT
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE),
_context(context),
_inputs(inputs),
_reduce_op(reduce_op),
_tag(tag) {}
void Run() override { _do_allreduce(_inputs); }
private:
std::shared_ptr<gloo::Context> _context;
std::vector<Tensor> _inputs;
const ReduceOp _reduce_op;
uint32_t _tag;
gloo::AllreduceOptions::Func _get_function(const experimental::DataType type,
const ReduceOp op) {
gloo::AllreduceOptions::Func fn;
GENERATE_FUNC(type, _get_function_impl, fn, op);
return fn;
}
template <typename T>
void _get_function_impl(gloo::AllreduceOptions::Func& fn, // NOLINT
const ReduceOp op) {
fn = get_function<T>(op);
}
void _do_allreduce(std::vector<Tensor>& tensors) { // NOLINT
const auto& dtype = tensors[0].type();
gloo::AllreduceOptions opts(_context);
GENERATE_FUNC(dtype, set_inputs, opts, tensors);
GENERATE_FUNC(dtype, set_outputs, opts, tensors);
opts.setReduceFunction(_get_function(dtype, _reduce_op));
opts.setTag(_tag);
gloo::allreduce(opts);
}
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
std::vector<Tensor>& inputs, const AllreduceOptions& opts) {
auto tag = next_tag();
std::shared_ptr<GlooTask> task;
auto context = get_context();
task = std::make_shared<AllreduceGlooTask>(rank_, context, inputs,
opts.reduce_op, tag);
task->Run();
return task;
}
std::shared_ptr<::gloo::transport::Device>
ProcessGroupGloo::createDeviceForInterface(const std::string& ifname) {
::gloo::transport::tcp::attr attr;
attr.iface = ifname;
return ::gloo::transport::tcp::CreateDevice(attr);
}
std::shared_ptr<::gloo::transport::Device>
ProcessGroupGloo::createDeviceForHostname(const std::string& hostname) {
::gloo::transport::tcp::attr attr;
attr.hostname = hostname;
return ::gloo::transport::tcp::CreateDevice(attr);
}
std::shared_ptr<::gloo::transport::Device>
ProcessGroupGloo::createDefaultDevice() {
std::array<char, HOST_NAME_MAX> hostname{};
auto ret = ::gethostname(hostname.data(), HOST_NAME_MAX);
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::Fatal(
"Get hostname error for createDefaultDevice."));
::addrinfo* result;
result = tcputils::get_addr_info(hostname.data(), "", 0, AF_UNSPEC);
::addrinfo* cur;
for (cur = result; cur != nullptr; cur = cur->ai_next) {
SocketType socket =
::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
if (socket == -1) {
continue;
}
ret = ::bind(socket, cur->ai_addr, cur->ai_addrlen);
#ifdef _WIN32
closesocket(socket);
#else
close(socket);
#endif
if (ret == -1) {
continue;
}
break;
}
freeaddrinfo(result);
if (cur != nullptr) {
return createDeviceForHostname(hostname.data());
}
return createDeviceForHostname("127.0.0.1");
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <mutex>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#ifdef PADDLE_WITH_GLOO
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
constexpr const char* GLOO_BACKEND_NAME = "GLOO";
namespace paddle {
namespace distributed {
class ProcessGroupGloo : public ProcessGroup {
public:
class GlooTask : public ProcessGroup::Task,
public std::enable_shared_from_this<GlooTask> {
public:
explicit GlooTask(int rank, const std::vector<Tensor>& input_tensors,
CommType comm_type);
~GlooTask() = default;
virtual void Run() = 0;
bool Wait(std::chrono::milliseconds timeout) override { return true; }
bool IsCompleted() override { return true; }
void Synchronize() override {}
protected:
friend class ProcessGroupGloo;
};
class GlooStore : public ::gloo::rendezvous::Store {
public:
explicit GlooStore(
const std::shared_ptr<paddle::distributed::TCPStore>& store)
: _store(store) {}
~GlooStore() = default;
std::vector<char> get(const std::string& key) override {
VLOG(3) << "GlooStore::get";
auto value = _store->get(key);
return std::vector<char>(value.begin(), value.end());
}
void wait(const std::vector<std::string>& keys) override {
VLOG(3) << "GlooStore::wait";
for (auto& key : keys) {
_store->wait(key);
}
}
void set(const std::string& key, const std::vector<char>& value) override {
VLOG(3) << "GlooStore::set";
std::vector<uint8_t> tmp(value.begin(), value.end());
_store->set(key, tmp);
}
void wait(const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) override {
VLOG(3) << "GlooStore::wait";
for (auto& key : keys) {
_store->wait(key);
}
// wait(keys);
}
protected:
std::shared_ptr<paddle::distributed::TCPStore> _store;
};
class GlooOptions {
public:
GlooOptions() = default;
~GlooOptions() = default;
static std::shared_ptr<GlooOptions> create() {
return std::make_shared<GlooOptions>();
}
std::shared_ptr<::gloo::transport::Device> device;
};
explicit ProcessGroupGloo(const std::shared_ptr<GlooStore>& store, int rank,
int world_size,
std::shared_ptr<GlooOptions> options);
~ProcessGroupGloo() = default;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& inputs,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& inputs,
const AllreduceOptions& opts = AllreduceOptions()) override;
std::shared_ptr<::gloo::Context> get_context() { return _context; }
uint64_t next_tag() { return _tag++; }
const std::string GetBackendName() const override {
return GLOO_BACKEND_NAME;
}
// Helper functions for Gloo.
static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
const std::string& hostname);
static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
const std::string& ifname);
static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
protected:
uint32_t _tag;
std::shared_ptr<gloo::rendezvous::Context> _context;
std::shared_ptr<GlooStore> _store;
};
} // namespace distributed
} // namespace paddle
......@@ -32,6 +32,8 @@ class Store {
virtual int64_t add(const std::string& key, int64_t value) = 0;
virtual std::vector<uint8_t> get(const std::string& key) = 0;
virtual void wait(const std::string& key) = 0;
virtual void set(const std::string& key,
const std::vector<uint8_t>& value) = 0;
virtual const std::chrono::seconds& timeout() const { return _timeout; }
......
......@@ -27,11 +27,13 @@ namespace detail {
constexpr int INFTIME = -1;
std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket) {
return std::make_unique<MasterDaemon>(socket);
std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket,
int nranks) {
return std::make_unique<MasterDaemon>(socket, nranks);
}
MasterDaemon::MasterDaemon(SocketType socket) : _listen_socket(socket) {
MasterDaemon::MasterDaemon(SocketType socket, int nranks)
: _listen_socket(socket), _nranks(nranks) {
_background_thread = std::thread{&MasterDaemon::run, this};
}
......@@ -64,6 +66,13 @@ void MasterDaemon::_do_add(SocketType socket) {
tcputils::send_value<int64_t>(socket, new_value);
}
void MasterDaemon::_do_set(SocketType socket) {
VLOG(3) << "MasterDaemon::_do_set";
std::string key = tcputils::receive_string(socket);
auto value = tcputils::receive_vector<uint8_t>(socket);
_store[key] = value;
}
void MasterDaemon::_do_get(SocketType socket) {
std::string key = tcputils::receive_string(socket);
auto iter = _store.find(key);
......@@ -71,16 +80,15 @@ void MasterDaemon::_do_get(SocketType socket) {
iter, _store.end(),
platform::errors::InvalidArgument("Key %s not found in TCPStore.", key));
std::vector<uint8_t> value = iter->second;
VLOG(3) << "TCPStore: value ("
<< std::stoll(std::string(reinterpret_cast<char*>(value.data()),
value.size()))
<< ") for key (" << key << ").";
tcputils::send_vector<uint8_t>(socket, value);
}
void MasterDaemon::_do_stop(SocketType socket) {
VLOG(3) << "MasterDaemon::_do_stop";
ReplyType value = ReplyType::STOP_WAIT;
_stop = true;
if (--_nranks == 0) {
_stop = true;
}
tcputils::send_value<ReplyType>(socket, value);
}
......@@ -140,21 +148,27 @@ void MasterDaemon::run() {
case Command::GET:
_do_get(fds[i].fd);
break;
case Command::SET:
_do_set(fds[i].fd);
break;
case Command::WAIT:
_do_wait(fds[i].fd);
break;
case Command::STOP:
_do_stop(fds[i].fd);
break;
default:
VLOG(0) << "Unknow command: " << static_cast<int>(command);
exit(-1);
}
}
}
}
std::unique_ptr<TCPServer> TCPServer::create(uint16_t port) {
std::unique_ptr<TCPServer> TCPServer::create(uint16_t port, int nranks) {
int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET);
auto server = std::make_unique<TCPServer>();
server->_master_daemon = MasterDaemon::start(socket);
server->_master_daemon = MasterDaemon::start(socket, nranks);
return server;
}
......@@ -200,7 +214,7 @@ TCPStore::TCPStore(std::string host, uint16_t port, bool is_master,
size_t num_workers, std::chrono::seconds timeout)
: Store(timeout), _is_master(is_master), _num_workers(num_workers) {
if (_is_master) {
_server = detail::TCPServer::create(port);
_server = detail::TCPServer::create(port, num_workers);
}
_client = detail::TCPClient::connect(host, port);
......@@ -213,36 +227,41 @@ void TCPStore::waitWorkers() {
}
add(_init_key, 1);
if (_server) {
auto begin = std::chrono::steady_clock::now();
do {
auto value = get(_init_key);
int completed = std::stoi(std::string(value.begin(), value.end()));
VLOG(3) << completed << " worker ready, total " << _num_workers;
if (completed >= _num_workers) {
break;
}
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - begin);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (_timeout != tcputils::kNoTimeout && elapsed > _timeout) {
PADDLE_ENFORCE_EQ(
completed, _num_workers,
platform::errors::InvalidArgument(
"TCPStore timeouted and not all workers got ready."));
}
} while (true);
}
auto begin = std::chrono::steady_clock::now();
do {
auto value = get(_init_key);
int completed = std::stoi(std::string(value.begin(), value.end()));
VLOG(3) << completed << " worker ready, total " << _num_workers;
if (completed >= _num_workers) {
break;
}
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - begin);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (_timeout != tcputils::kNoTimeout && elapsed > _timeout) {
PADDLE_ENFORCE_EQ(
completed, _num_workers,
platform::errors::InvalidArgument(
"TCPStore timeouted and not all workers got ready."));
}
} while (true);
VLOG(3) << "TCPStore initialized.";
}
int64_t TCPStore::add(const std::string& key, int64_t value) {
VLOG(3) << "TCPStore add.";
_client->send_command_for_key(Command::ADD, _key_prefix + key);
_client->send_value<std::int64_t>(value);
return _client->receive_value<std::int64_t>();
}
void TCPStore::set(const std::string& key, const std::vector<uint8_t>& value) {
VLOG(3) << "TCPStore set.";
_client->send_command_for_key(Command::SET, _key_prefix + key);
_client->send_vector<std::uint8_t>(value);
}
std::vector<uint8_t> TCPStore::get(const std::string& key) {
wait(key);
_client->send_command_for_key(Command::GET, _key_prefix + key);
......@@ -252,6 +271,7 @@ std::vector<uint8_t> TCPStore::get(const std::string& key) {
void TCPStore::wait(const std::string& key) {
ReplyType reply;
VLOG(3) << "TCPStore wait.";
do {
_client->send_command_for_key(Command::WAIT, _key_prefix + key);
......@@ -262,6 +282,7 @@ void TCPStore::wait(const std::string& key) {
TCPStore::~TCPStore() {
_client->send_command_for_key(Command::STOP, "");
VLOG(3) << "~TCPStore";
ReplyType ret = _client->receive_value<ReplyType>();
PADDLE_ENFORCE_EQ(ret, ReplyType::STOP_WAIT,
platform::errors::InvalidArgument(
......
......@@ -27,15 +27,16 @@ namespace paddle {
namespace distributed {
enum class ReplyType { WAITING, STOP_WAIT };
enum class Command { ADD, GET, WAIT, STOP };
enum class Command { ADD, GET, SET, WAIT, STOP };
namespace detail {
class MasterDaemon {
public:
static std::unique_ptr<MasterDaemon> start(SocketType listen_socket);
static std::unique_ptr<MasterDaemon> start(SocketType listen_socket,
int nranks);
MasterDaemon() = delete;
explicit MasterDaemon(SocketType listen_socket);
explicit MasterDaemon(SocketType listen_socket, int nranks);
~MasterDaemon();
private:
......@@ -43,18 +44,20 @@ class MasterDaemon {
void _do_add(SocketType socket);
void _do_wait(SocketType socket);
void _do_get(SocketType socket);
void _do_set(SocketType socket);
void _do_stop(SocketType socket);
SocketType _listen_socket;
std::vector<SocketType> _sockets;
std::unordered_map<std::string, std::vector<uint8_t>> _store;
std::thread _background_thread{};
int _nranks;
bool _stop = false;
};
class TCPServer {
public:
TCPServer() = default;
static std::unique_ptr<TCPServer> create(std::uint16_t port);
static std::unique_ptr<TCPServer> create(std::uint16_t port, int nranks);
private:
std::unique_ptr<MasterDaemon> _master_daemon;
......@@ -97,6 +100,7 @@ class TCPStore : public Store {
int64_t add(const std::string& key, int64_t value) override;
std::vector<uint8_t> get(const std::string& key) override;
void wait(const std::string& key) override;
void set(const std::string& key, const std::vector<uint8_t>& value) override;
private:
void waitWorkers();
......
......@@ -46,9 +46,10 @@ void close_socket(SocketType socket) {
hints.ai_socktype = SOCK_STREAM;
const char* node = host.empty() ? nullptr : host.c_str();
const char* port_cstr = port.empty() ? nullptr : port.c_str();
int n;
n = ::getaddrinfo(node, port.c_str(), &hints, &res);
n = ::getaddrinfo(node, port_cstr, &hints, &res);
const char* gai_err = ::gai_strerror(n);
const char* proto =
(family == AF_INET ? "IPv4" : family == AF_INET6 ? "IPv6" : "");
......
......@@ -85,6 +85,9 @@ if(NOT ON_INFER)
if (WITH_NCCL)
set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_nccl)
endif()
if (WITH_GLOO)
set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_gloo)
endif()
set(PYBIND_SRCS ${PYBIND_SRCS} distributed_py.cc)
endif()
......
......@@ -31,9 +31,15 @@ namespace pybind {
using TCPStore = paddle::distributed::TCPStore;
void BindTCPStore(py::module* m) {
py::class_<TCPStore>(*m, "TCPStore")
.def(
py::init<std::string, uint16_t, bool, size_t, std::chrono::seconds>())
py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore")
.def(py::init([](std::string hostname, uint16_t port, bool is_master,
size_t world_size, std::chrono::seconds timeout) {
return std::make_shared<TCPStore>(hostname, port, is_master,
world_size, timeout);
}),
py::arg("hostname"), py::arg("port"), py::arg("is_master"),
py::arg("world_size"), py::arg("timeout"),
py::call_guard<py::gil_scoped_release>())
.def("add", &TCPStore::add)
.def("get", &TCPStore::get);
}
......
......@@ -35,6 +35,11 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#endif
namespace py = pybind11;
namespace paddle {
......@@ -42,6 +47,14 @@ namespace pybind {
using Tensor = paddle::experimental::Tensor;
#if defined(PADDLE_WITH_GLOO)
using ProcessGroupGloo = paddle::distributed::ProcessGroupGloo;
using GlooStore = paddle::distributed::ProcessGroupGloo::GlooStore;
using GlooOptions = paddle::distributed::ProcessGroupGloo::GlooOptions;
#endif
static std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; // NOLINT
void BindDistributed(py::module *m) {
py::enum_<distributed::ReduceOp>(*m, "ReduceOp")
.value("SUM", distributed::ReduceOp::SUM)
......@@ -129,6 +142,7 @@ void BindDistributed(py::module *m) {
*m, "ProcessGroupNCCL", ProcessGroup)
.def(py::init<const distributed::ProcessGroupStrategy &, int, int>(),
py::call_guard<py::gil_scoped_release>());
#endif
py::class_<distributed::ProcessGroup::Task,
std::shared_ptr<distributed::ProcessGroup::Task>>(*m, "task")
......@@ -138,7 +152,6 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("synchronize", &distributed::ProcessGroup::Task::Synchronize,
py::call_guard<py::gil_scoped_release>());
#endif
// define parallel strategy, it will be removed
py::class_<distributed::ProcessGroupStrategy> pg_strategy(
......@@ -178,6 +191,45 @@ void BindDistributed(py::module *m) {
self.nrings_ = nrings;
});
#if defined(PADDLE_WITH_GLOO)
py::class_<GlooOptions>(*m, "GlooOptions")
.def(py::init<>())
.def_readwrite("_device", &GlooOptions::device)
.def_static("create", &GlooOptions::create);
py::class_<GlooStore, std::shared_ptr<GlooStore>>(*m, "GlooStore")
.def(py::init(
[](const std::shared_ptr<paddle::distributed::TCPStore> &store) {
return std::make_shared<GlooStore>(store);
}),
py::call_guard<py::gil_scoped_release>());
py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
*m, "ProcessGroupGloo", ProcessGroup)
.def(py::init<const std::shared_ptr<GlooStore> &, int, int,
std::shared_ptr<GlooOptions> &>(),
py::call_guard<py::gil_scoped_release>())
.def(py::init([](const std::shared_ptr<GlooStore> &store, int rank,
int world_size) {
auto opts = GlooOptions::create();
char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
if (ifname && strlen(ifname) > 1) {
opts->device = ProcessGroupGloo::createDeviceForInterface(
std::string(ifname));
} else {
opts->device = ProcessGroupGloo::createDefaultDevice();
}
return std::make_shared<ProcessGroupGloo>(store, rank, world_size,
opts);
}),
py::arg("store"), py::arg("rank"),
py::arg("world_size"), // py::arg("timeout") =
// kProcessGroupDefaultTimeout,
py::call_guard<py::gil_scoped_release>())
.def_static("create_default_device",
&ProcessGroupGloo::createDefaultDevice);
#endif
m->def("eager_assign_group_by_size",
[](py::handle py_tensors, std::vector<bool> is_sparse_gradient,
std::vector<size_t> group_size_limits,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import random
import numpy as np
import os
import shutil
import paddle
from paddle.fluid import core
import datetime
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
class TestProcessGroupFp32(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
random.seed(2022)
np.random.seed(2022)
self.config()
def config(self):
self.dtype = "float32"
self.shape = (2, 10, 5)
def test_create_process_group_gloo(self):
with _test_eager_guard():
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master,
nranks, datetime.timedelta(0))
gloo_store = paddle.fluid.core.GlooStore(store)
opt = paddle.fluid.core.GlooOptions()
pg = paddle.fluid.core.ProcessGroupGloo(gloo_store, rank, nranks)
# test allreduce sum
# rank 0
paddle.device.set_device('cpu')
x = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
# rank 1
y = np.random.random(self.shape).astype(self.dtype)
tensor_y = paddle.to_tensor(y)
sum_result = x + y
if rank == 0:
task = pg.allreduce(tensor_x)
task.wait()
assert np.array_equal(tensor_x, sum_result)
else:
task = pg.allreduce(tensor_y)
task.wait()
assert np.array_equal(tensor_y, sum_result)
print("test allreduce sum api ok")
# test allreduce max
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
# rank 1
y = np.random.random(self.shape).astype(self.dtype)
tensor_y = paddle.to_tensor(y)
max_result = paddle.maximum(tensor_x, tensor_y)
if rank == 0:
task = pg.allreduce(tensor_x, core.ReduceOp.MAX)
task.wait()
assert np.array_equal(tensor_x, max_result)
else:
task = pg.allreduce(tensor_y, core.ReduceOp.MAX)
task.wait()
assert np.array_equal(tensor_y, max_result)
print("test allreduce max api ok")
# test broadcast
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
# rank 1
y = np.random.random(self.shape).astype(self.dtype)
tensor_y = paddle.to_tensor(y)
broadcast_result = paddle.assign(tensor_x)
if rank == 0:
task = pg.broadcast(tensor_x, 0)
task.synchronize()
assert task.is_completed()
assert np.array_equal(broadcast_result, tensor_x)
else:
task = pg.broadcast(tensor_y, 0)
task.synchronize()
assert task.is_completed()
assert np.array_equal(broadcast_result, tensor_y)
print("test broadcast api ok")
if __name__ == "__main__":
unittest.main()
......@@ -22,6 +22,9 @@ class TestProcessGroup(TestMultipleGpus):
def test_process_group_nccl(self):
self.run_mnist_2gpu('process_group_nccl.py')
def test_process_group_gloo(self):
self.run_mnist_2gpu('process_group_gloo.py')
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册