From c16f85f95d0c42989e22c5ebae709f60506111a0 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Thu, 3 Mar 2022 01:24:26 +0800 Subject: [PATCH] Add the implementation of Gloo for ProcessGroup (#39892) * add pg_gloo --- .../distributed/collective/CMakeLists.txt | 3 + .../collective/ProcessGroupGloo.cc | 308 ++++++++++++++++++ .../distributed/collective/ProcessGroupGloo.h | 138 ++++++++ paddle/fluid/distributed/store/store.h | 2 + paddle/fluid/distributed/store/tcp_store.cc | 85 +++-- paddle/fluid/distributed/store/tcp_store.h | 12 +- paddle/fluid/distributed/store/tcp_utils.cc | 3 +- paddle/fluid/pybind/CMakeLists.txt | 3 + paddle/fluid/pybind/communication.cc | 12 +- paddle/fluid/pybind/distributed_py.cc | 54 ++- .../tests/unittests/process_group_gloo.py | 119 +++++++ .../test_collective_process_group.py | 3 + 12 files changed, 701 insertions(+), 41 deletions(-) create mode 100644 paddle/fluid/distributed/collective/ProcessGroupGloo.cc create mode 100644 paddle/fluid/distributed/collective/ProcessGroupGloo.h create mode 100644 python/paddle/fluid/tests/unittests/process_group_gloo.py diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index a5b40f8aa07..96bc4a710f8 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -1,4 +1,7 @@ cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api) +if (WITH_DISTRIBUTE) + cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper) +endif() cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup) if(WITH_NCCL) diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc new file mode 100644 index 00000000000..03ad48f560a --- /dev/null +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc @@ -0,0 +1,308 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#ifdef _WIN32 +#include +#include +#include +#else +#include +#include +#include +#endif + +#include +#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h" +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace distributed { + +#ifdef _WIN32 +#define GENERATE_FUNC(type, func, ...) \ + switch (type) { \ + case experimental::DataType::FLOAT32: \ + func(__VA_ARGS__); \ + break; \ + case experimental::DataType::FLOAT64: \ + func(__VA_ARGS__); \ + break; \ + case experimental::DataType::FLOAT16: \ + func(__VA_ARGS__); \ + break; \ + case experimental::DataType::INT32: \ + func(__VA_ARGS__); \ + break; \ + case experimental::DataType::INT64: \ + func(__VA_ARGS__); \ + break; \ + default: \ + VLOG(0) << "Error: Unknown DataType."; \ + exit(-1); \ + } + +#define HOST_NAME_MAX 256 + +#else +#define GENERATE_FUNC(type, func, args...) \ + switch (type) { \ + case experimental::DataType::FLOAT32: \ + func(args); \ + break; \ + case experimental::DataType::FLOAT64: \ + func(args); \ + break; \ + case experimental::DataType::FLOAT16: \ + func(args); \ + break; \ + case experimental::DataType::INT32: \ + func(args); \ + break; \ + case experimental::DataType::INT64: \ + func(args); \ + break; \ + default: \ + VLOG(0) << "Error: Unknown DataType."; \ + exit(-1); \ + } +#endif + +typedef void (*reduce_func)(void*, const void*, const void*, size_t); + +template +reduce_func get_function(const ReduceOp& r) { + switch (r) { + case ReduceOp::SUM: + return reduce_func(&::gloo::sum); + case ReduceOp::PRODUCT: + return reduce_func(&::gloo::product); + case ReduceOp::MIN: + return reduce_func(&::gloo::min); + case ReduceOp::MAX: + return reduce_func(&::gloo::max); + case ReduceOp::AVG: + VLOG(0) << "Error: Unsupported ReduceOp::AVG."; + exit(-1); + } + + VLOG(0) << "Error: Unknown ReduceOp."; + exit(-1); +} + +bool CheckTensorsInCPUPlace(const std::vector& tensors) { + return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) { + return t.place() == PlaceType::kCPU; + }); +} + +template +T* get_data(const Tensor& tensor) { + auto raw_tensor = std::dynamic_pointer_cast(tensor.impl()); + return static_cast(raw_tensor->data()); +} + +template +std::vector get_multi_data(const std::vector& tensors) { + std::vector ret(tensors.size()); + for (size_t i = 0; i < tensors.size(); i++) { + ret[i] = get_data(tensors[i]); + } + return ret; +} + +template +void set_output(P& opts, const Tensor& tensor) { // NOLINT + opts.setOutput(get_data(tensor), tensor.numel()); +} + +template +void set_input(P& opts, const Tensor& tensor) { // NOLINT + opts.setInput(get_data(tensor), tensor.numel()); +} + +template +void set_outputs(P& opts, const std::vector& tensors) { // NOLINT + opts.setOutputs(get_multi_data(tensors), tensors[0].numel()); +} + +template +void set_inputs(P& opts, const std::vector& tensors) { // NOLINT + opts.setInputs(get_multi_data(tensors), tensors[0].numel()); +} + +ProcessGroupGloo::GlooTask::GlooTask(int rank, + const std::vector& inputs, + CommType comm_type) + : ProcessGroup::Task(rank, inputs, comm_type) { + PADDLE_ENFORCE_EQ(CheckTensorsInCPUPlace(inputs), true, + platform::errors::Fatal( + "Only CPU place is supported for ProcessGroupGloo.")); +} + +ProcessGroupGloo::ProcessGroupGloo(const std::shared_ptr& store, + int rank, int world_size, + const std::shared_ptr options) + : ProcessGroup(rank, world_size), _tag(0), _store(store) { + _context = std::make_shared(rank, world_size); + auto prefix_store = + ::gloo::rendezvous::PrefixStore(std::to_string(0), *_store); + _context->connectFullMesh(prefix_store, options->device); +} + +class BroadcastGlooTask : public ProcessGroupGloo::GlooTask { + public: + BroadcastGlooTask(const std::shared_ptr& context, + const std::vector& inputs, int rank, int root, + uint32_t tag) + : ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST), + _context(context), + _root(root), + _inputs(inputs), + _tag(tag) {} + + void Run() override { _do_broadcast(_inputs[0]); } + + private: + std::shared_ptr _context; + const int _root; + std::vector _inputs{}; + const uint32_t _tag; + + void _do_broadcast(const Tensor& tensor) { + gloo::BroadcastOptions opts(_context); + const auto& dtype = tensor.type(); + GENERATE_FUNC(dtype, set_output, opts, tensor); + opts.setRoot(_root); + opts.setTag(_tag); + gloo::broadcast(opts); + } +}; + +std::shared_ptr ProcessGroupGloo::Broadcast( + std::vector& inputs, const BroadcastOptions& opts) { + auto root = opts.source_rank; + std::unique_ptr task; + auto tag = next_tag(); + auto context = get_context(); + task = std::make_unique(context, inputs, rank_, root, tag); + task->Run(); + return task; +} + +class AllreduceGlooTask : public ProcessGroupGloo::GlooTask { + public: + AllreduceGlooTask(int rank, const std::shared_ptr& context, + std::vector& inputs, ReduceOp reduce_op, // NOLINT + uint32_t tag) + : ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE), + _context(context), + _inputs(inputs), + _reduce_op(reduce_op), + _tag(tag) {} + + void Run() override { _do_allreduce(_inputs); } + + private: + std::shared_ptr _context; + std::vector _inputs; + const ReduceOp _reduce_op; + uint32_t _tag; + + gloo::AllreduceOptions::Func _get_function(const experimental::DataType type, + const ReduceOp op) { + gloo::AllreduceOptions::Func fn; + GENERATE_FUNC(type, _get_function_impl, fn, op); + return fn; + } + + template + void _get_function_impl(gloo::AllreduceOptions::Func& fn, // NOLINT + const ReduceOp op) { + fn = get_function(op); + } + + void _do_allreduce(std::vector& tensors) { // NOLINT + const auto& dtype = tensors[0].type(); + gloo::AllreduceOptions opts(_context); + GENERATE_FUNC(dtype, set_inputs, opts, tensors); + GENERATE_FUNC(dtype, set_outputs, opts, tensors); + opts.setReduceFunction(_get_function(dtype, _reduce_op)); + opts.setTag(_tag); + gloo::allreduce(opts); + } +}; + +std::shared_ptr ProcessGroupGloo::AllReduce( + std::vector& inputs, const AllreduceOptions& opts) { + auto tag = next_tag(); + std::shared_ptr task; + auto context = get_context(); + task = std::make_shared(rank_, context, inputs, + opts.reduce_op, tag); + task->Run(); + return task; +} + +std::shared_ptr<::gloo::transport::Device> +ProcessGroupGloo::createDeviceForInterface(const std::string& ifname) { + ::gloo::transport::tcp::attr attr; + attr.iface = ifname; + return ::gloo::transport::tcp::CreateDevice(attr); +} + +std::shared_ptr<::gloo::transport::Device> +ProcessGroupGloo::createDeviceForHostname(const std::string& hostname) { + ::gloo::transport::tcp::attr attr; + attr.hostname = hostname; + return ::gloo::transport::tcp::CreateDevice(attr); +} + +std::shared_ptr<::gloo::transport::Device> +ProcessGroupGloo::createDefaultDevice() { + std::array hostname{}; + auto ret = ::gethostname(hostname.data(), HOST_NAME_MAX); + PADDLE_ENFORCE_EQ(ret, 0, platform::errors::Fatal( + "Get hostname error for createDefaultDevice.")); + ::addrinfo* result; + result = tcputils::get_addr_info(hostname.data(), "", 0, AF_UNSPEC); + ::addrinfo* cur; + for (cur = result; cur != nullptr; cur = cur->ai_next) { + SocketType socket = + ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol); + if (socket == -1) { + continue; + } + ret = ::bind(socket, cur->ai_addr, cur->ai_addrlen); +#ifdef _WIN32 + closesocket(socket); +#else + close(socket); +#endif + if (ret == -1) { + continue; + } + break; + } + freeaddrinfo(result); + if (cur != nullptr) { + return createDeviceForHostname(hostname.data()); + } + return createDeviceForHostname("127.0.0.1"); +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.h b/paddle/fluid/distributed/collective/ProcessGroupGloo.h new file mode 100644 index 00000000000..d989939fcb8 --- /dev/null +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.h @@ -0,0 +1,138 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/distributed/collective/ProcessGroup.h" + +#ifdef PADDLE_WITH_GLOO +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +#include "paddle/fluid/distributed/store/store.h" +#include "paddle/fluid/distributed/store/tcp_store.h" + +constexpr const char* GLOO_BACKEND_NAME = "GLOO"; + +namespace paddle { +namespace distributed { + +class ProcessGroupGloo : public ProcessGroup { + public: + class GlooTask : public ProcessGroup::Task, + public std::enable_shared_from_this { + public: + explicit GlooTask(int rank, const std::vector& input_tensors, + CommType comm_type); + + ~GlooTask() = default; + + virtual void Run() = 0; + bool Wait(std::chrono::milliseconds timeout) override { return true; } + bool IsCompleted() override { return true; } + void Synchronize() override {} + + protected: + friend class ProcessGroupGloo; + }; + + class GlooStore : public ::gloo::rendezvous::Store { + public: + explicit GlooStore( + const std::shared_ptr& store) + : _store(store) {} + + ~GlooStore() = default; + + std::vector get(const std::string& key) override { + VLOG(3) << "GlooStore::get"; + auto value = _store->get(key); + return std::vector(value.begin(), value.end()); + } + + void wait(const std::vector& keys) override { + VLOG(3) << "GlooStore::wait"; + for (auto& key : keys) { + _store->wait(key); + } + } + + void set(const std::string& key, const std::vector& value) override { + VLOG(3) << "GlooStore::set"; + std::vector tmp(value.begin(), value.end()); + _store->set(key, tmp); + } + + void wait(const std::vector& keys, + const std::chrono::milliseconds& timeout) override { + VLOG(3) << "GlooStore::wait"; + for (auto& key : keys) { + _store->wait(key); + } + // wait(keys); + } + + protected: + std::shared_ptr _store; + }; + + class GlooOptions { + public: + GlooOptions() = default; + ~GlooOptions() = default; + static std::shared_ptr create() { + return std::make_shared(); + } + std::shared_ptr<::gloo::transport::Device> device; + }; + + explicit ProcessGroupGloo(const std::shared_ptr& store, int rank, + int world_size, + std::shared_ptr options); + + ~ProcessGroupGloo() = default; + + std::shared_ptr Broadcast( + std::vector& inputs, + const BroadcastOptions& = BroadcastOptions()) override; + + std::shared_ptr AllReduce( + std::vector& inputs, + const AllreduceOptions& opts = AllreduceOptions()) override; + + std::shared_ptr<::gloo::Context> get_context() { return _context; } + uint64_t next_tag() { return _tag++; } + + const std::string GetBackendName() const override { + return GLOO_BACKEND_NAME; + } + + // Helper functions for Gloo. + static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname( + const std::string& hostname); + static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface( + const std::string& ifname); + static std::shared_ptr<::gloo::transport::Device> createDefaultDevice(); + + protected: + uint32_t _tag; + std::shared_ptr _context; + std::shared_ptr _store; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/store/store.h b/paddle/fluid/distributed/store/store.h index 2673314d222..2581a74d7e8 100644 --- a/paddle/fluid/distributed/store/store.h +++ b/paddle/fluid/distributed/store/store.h @@ -32,6 +32,8 @@ class Store { virtual int64_t add(const std::string& key, int64_t value) = 0; virtual std::vector get(const std::string& key) = 0; virtual void wait(const std::string& key) = 0; + virtual void set(const std::string& key, + const std::vector& value) = 0; virtual const std::chrono::seconds& timeout() const { return _timeout; } diff --git a/paddle/fluid/distributed/store/tcp_store.cc b/paddle/fluid/distributed/store/tcp_store.cc index de85ac0d910..8675981955d 100644 --- a/paddle/fluid/distributed/store/tcp_store.cc +++ b/paddle/fluid/distributed/store/tcp_store.cc @@ -27,11 +27,13 @@ namespace detail { constexpr int INFTIME = -1; -std::unique_ptr MasterDaemon::start(SocketType socket) { - return std::make_unique(socket); +std::unique_ptr MasterDaemon::start(SocketType socket, + int nranks) { + return std::make_unique(socket, nranks); } -MasterDaemon::MasterDaemon(SocketType socket) : _listen_socket(socket) { +MasterDaemon::MasterDaemon(SocketType socket, int nranks) + : _listen_socket(socket), _nranks(nranks) { _background_thread = std::thread{&MasterDaemon::run, this}; } @@ -64,6 +66,13 @@ void MasterDaemon::_do_add(SocketType socket) { tcputils::send_value(socket, new_value); } +void MasterDaemon::_do_set(SocketType socket) { + VLOG(3) << "MasterDaemon::_do_set"; + std::string key = tcputils::receive_string(socket); + auto value = tcputils::receive_vector(socket); + _store[key] = value; +} + void MasterDaemon::_do_get(SocketType socket) { std::string key = tcputils::receive_string(socket); auto iter = _store.find(key); @@ -71,16 +80,15 @@ void MasterDaemon::_do_get(SocketType socket) { iter, _store.end(), platform::errors::InvalidArgument("Key %s not found in TCPStore.", key)); std::vector value = iter->second; - VLOG(3) << "TCPStore: value (" - << std::stoll(std::string(reinterpret_cast(value.data()), - value.size())) - << ") for key (" << key << ")."; tcputils::send_vector(socket, value); } void MasterDaemon::_do_stop(SocketType socket) { + VLOG(3) << "MasterDaemon::_do_stop"; ReplyType value = ReplyType::STOP_WAIT; - _stop = true; + if (--_nranks == 0) { + _stop = true; + } tcputils::send_value(socket, value); } @@ -140,21 +148,27 @@ void MasterDaemon::run() { case Command::GET: _do_get(fds[i].fd); break; + case Command::SET: + _do_set(fds[i].fd); + break; case Command::WAIT: _do_wait(fds[i].fd); break; case Command::STOP: _do_stop(fds[i].fd); break; + default: + VLOG(0) << "Unknow command: " << static_cast(command); + exit(-1); } } } } -std::unique_ptr TCPServer::create(uint16_t port) { +std::unique_ptr TCPServer::create(uint16_t port, int nranks) { int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET); auto server = std::make_unique(); - server->_master_daemon = MasterDaemon::start(socket); + server->_master_daemon = MasterDaemon::start(socket, nranks); return server; } @@ -200,7 +214,7 @@ TCPStore::TCPStore(std::string host, uint16_t port, bool is_master, size_t num_workers, std::chrono::seconds timeout) : Store(timeout), _is_master(is_master), _num_workers(num_workers) { if (_is_master) { - _server = detail::TCPServer::create(port); + _server = detail::TCPServer::create(port, num_workers); } _client = detail::TCPClient::connect(host, port); @@ -213,36 +227,41 @@ void TCPStore::waitWorkers() { } add(_init_key, 1); - if (_server) { - auto begin = std::chrono::steady_clock::now(); - do { - auto value = get(_init_key); - int completed = std::stoi(std::string(value.begin(), value.end())); - VLOG(3) << completed << " worker ready, total " << _num_workers; - if (completed >= _num_workers) { - break; - } - const auto elapsed = std::chrono::duration_cast( - std::chrono::steady_clock::now() - begin); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (_timeout != tcputils::kNoTimeout && elapsed > _timeout) { - PADDLE_ENFORCE_EQ( - completed, _num_workers, - platform::errors::InvalidArgument( - "TCPStore timeouted and not all workers got ready.")); - } - } while (true); - } + auto begin = std::chrono::steady_clock::now(); + do { + auto value = get(_init_key); + int completed = std::stoi(std::string(value.begin(), value.end())); + VLOG(3) << completed << " worker ready, total " << _num_workers; + if (completed >= _num_workers) { + break; + } + const auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - begin); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (_timeout != tcputils::kNoTimeout && elapsed > _timeout) { + PADDLE_ENFORCE_EQ( + completed, _num_workers, + platform::errors::InvalidArgument( + "TCPStore timeouted and not all workers got ready.")); + } + } while (true); VLOG(3) << "TCPStore initialized."; } int64_t TCPStore::add(const std::string& key, int64_t value) { + VLOG(3) << "TCPStore add."; _client->send_command_for_key(Command::ADD, _key_prefix + key); _client->send_value(value); return _client->receive_value(); } +void TCPStore::set(const std::string& key, const std::vector& value) { + VLOG(3) << "TCPStore set."; + _client->send_command_for_key(Command::SET, _key_prefix + key); + _client->send_vector(value); +} + std::vector TCPStore::get(const std::string& key) { wait(key); _client->send_command_for_key(Command::GET, _key_prefix + key); @@ -252,6 +271,7 @@ std::vector TCPStore::get(const std::string& key) { void TCPStore::wait(const std::string& key) { ReplyType reply; + VLOG(3) << "TCPStore wait."; do { _client->send_command_for_key(Command::WAIT, _key_prefix + key); @@ -262,6 +282,7 @@ void TCPStore::wait(const std::string& key) { TCPStore::~TCPStore() { _client->send_command_for_key(Command::STOP, ""); + VLOG(3) << "~TCPStore"; ReplyType ret = _client->receive_value(); PADDLE_ENFORCE_EQ(ret, ReplyType::STOP_WAIT, platform::errors::InvalidArgument( diff --git a/paddle/fluid/distributed/store/tcp_store.h b/paddle/fluid/distributed/store/tcp_store.h index cd706dd6640..17c1d8ea30a 100644 --- a/paddle/fluid/distributed/store/tcp_store.h +++ b/paddle/fluid/distributed/store/tcp_store.h @@ -27,15 +27,16 @@ namespace paddle { namespace distributed { enum class ReplyType { WAITING, STOP_WAIT }; -enum class Command { ADD, GET, WAIT, STOP }; +enum class Command { ADD, GET, SET, WAIT, STOP }; namespace detail { class MasterDaemon { public: - static std::unique_ptr start(SocketType listen_socket); + static std::unique_ptr start(SocketType listen_socket, + int nranks); MasterDaemon() = delete; - explicit MasterDaemon(SocketType listen_socket); + explicit MasterDaemon(SocketType listen_socket, int nranks); ~MasterDaemon(); private: @@ -43,18 +44,20 @@ class MasterDaemon { void _do_add(SocketType socket); void _do_wait(SocketType socket); void _do_get(SocketType socket); + void _do_set(SocketType socket); void _do_stop(SocketType socket); SocketType _listen_socket; std::vector _sockets; std::unordered_map> _store; std::thread _background_thread{}; + int _nranks; bool _stop = false; }; class TCPServer { public: TCPServer() = default; - static std::unique_ptr create(std::uint16_t port); + static std::unique_ptr create(std::uint16_t port, int nranks); private: std::unique_ptr _master_daemon; @@ -97,6 +100,7 @@ class TCPStore : public Store { int64_t add(const std::string& key, int64_t value) override; std::vector get(const std::string& key) override; void wait(const std::string& key) override; + void set(const std::string& key, const std::vector& value) override; private: void waitWorkers(); diff --git a/paddle/fluid/distributed/store/tcp_utils.cc b/paddle/fluid/distributed/store/tcp_utils.cc index d0561d0b9a9..a28cba28833 100644 --- a/paddle/fluid/distributed/store/tcp_utils.cc +++ b/paddle/fluid/distributed/store/tcp_utils.cc @@ -46,9 +46,10 @@ void close_socket(SocketType socket) { hints.ai_socktype = SOCK_STREAM; const char* node = host.empty() ? nullptr : host.c_str(); + const char* port_cstr = port.empty() ? nullptr : port.c_str(); int n; - n = ::getaddrinfo(node, port.c_str(), &hints, &res); + n = ::getaddrinfo(node, port_cstr, &hints, &res); const char* gai_err = ::gai_strerror(n); const char* proto = (family == AF_INET ? "IPv4" : family == AF_INET6 ? "IPv6" : ""); diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 48d42f803a8..5e61133510d 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -85,6 +85,9 @@ if(NOT ON_INFER) if (WITH_NCCL) set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_nccl) endif() + if (WITH_GLOO) + set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_gloo) + endif() set(PYBIND_SRCS ${PYBIND_SRCS} distributed_py.cc) endif() diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index a0d2777f825..c01accaf598 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -31,9 +31,15 @@ namespace pybind { using TCPStore = paddle::distributed::TCPStore; void BindTCPStore(py::module* m) { - py::class_(*m, "TCPStore") - .def( - py::init()) + py::class_>(*m, "TCPStore") + .def(py::init([](std::string hostname, uint16_t port, bool is_master, + size_t world_size, std::chrono::seconds timeout) { + return std::make_shared(hostname, port, is_master, + world_size, timeout); + }), + py::arg("hostname"), py::arg("port"), py::arg("is_master"), + py::arg("world_size"), py::arg("timeout"), + py::call_guard()) .def("add", &TCPStore::add) .def("get", &TCPStore::get); } diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index a4a1d07db2c..3b5644764a5 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -35,6 +35,11 @@ limitations under the License. */ #include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" #endif +#if defined(PADDLE_WITH_GLOO) +#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h" +#include "paddle/fluid/distributed/store/tcp_store.h" +#endif + namespace py = pybind11; namespace paddle { @@ -42,6 +47,14 @@ namespace pybind { using Tensor = paddle::experimental::Tensor; +#if defined(PADDLE_WITH_GLOO) +using ProcessGroupGloo = paddle::distributed::ProcessGroupGloo; +using GlooStore = paddle::distributed::ProcessGroupGloo::GlooStore; +using GlooOptions = paddle::distributed::ProcessGroupGloo::GlooOptions; +#endif + +static std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; // NOLINT + void BindDistributed(py::module *m) { py::enum_(*m, "ReduceOp") .value("SUM", distributed::ReduceOp::SUM) @@ -129,6 +142,7 @@ void BindDistributed(py::module *m) { *m, "ProcessGroupNCCL", ProcessGroup) .def(py::init(), py::call_guard()); +#endif py::class_>(*m, "task") @@ -138,7 +152,6 @@ void BindDistributed(py::module *m) { py::call_guard()) .def("synchronize", &distributed::ProcessGroup::Task::Synchronize, py::call_guard()); -#endif // define parallel strategy, it will be removed py::class_ pg_strategy( @@ -178,6 +191,45 @@ void BindDistributed(py::module *m) { self.nrings_ = nrings; }); +#if defined(PADDLE_WITH_GLOO) + py::class_(*m, "GlooOptions") + .def(py::init<>()) + .def_readwrite("_device", &GlooOptions::device) + .def_static("create", &GlooOptions::create); + + py::class_>(*m, "GlooStore") + .def(py::init( + [](const std::shared_ptr &store) { + return std::make_shared(store); + }), + py::call_guard()); + + py::class_>( + *m, "ProcessGroupGloo", ProcessGroup) + .def(py::init &, int, int, + std::shared_ptr &>(), + py::call_guard()) + .def(py::init([](const std::shared_ptr &store, int rank, + int world_size) { + auto opts = GlooOptions::create(); + char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str()); + if (ifname && strlen(ifname) > 1) { + opts->device = ProcessGroupGloo::createDeviceForInterface( + std::string(ifname)); + } else { + opts->device = ProcessGroupGloo::createDefaultDevice(); + } + return std::make_shared(store, rank, world_size, + opts); + }), + py::arg("store"), py::arg("rank"), + py::arg("world_size"), // py::arg("timeout") = + // kProcessGroupDefaultTimeout, + py::call_guard()) + .def_static("create_default_device", + &ProcessGroupGloo::createDefaultDevice); +#endif + m->def("eager_assign_group_by_size", [](py::handle py_tensors, std::vector is_sparse_gradient, std::vector group_size_limits, diff --git a/python/paddle/fluid/tests/unittests/process_group_gloo.py b/python/paddle/fluid/tests/unittests/process_group_gloo.py new file mode 100644 index 00000000000..5420e1d36b3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/process_group_gloo.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import random +import numpy as np +import os +import shutil + +import paddle +from paddle.fluid import core +import datetime +from datetime import timedelta +import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard +from paddle.fluid.dygraph.parallel import ParallelEnv + + +class TestProcessGroupFp32(unittest.TestCase): + def setUp(self): + paddle.seed(2022) + random.seed(2022) + np.random.seed(2022) + self.config() + + def config(self): + self.dtype = "float32" + self.shape = (2, 10, 5) + + def test_create_process_group_gloo(self): + with _test_eager_guard(): + nranks = ParallelEnv().nranks + rank = ParallelEnv().local_rank + is_master = True if rank == 0 else False + store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master, + nranks, datetime.timedelta(0)) + gloo_store = paddle.fluid.core.GlooStore(store) + opt = paddle.fluid.core.GlooOptions() + pg = paddle.fluid.core.ProcessGroupGloo(gloo_store, rank, nranks) + + # test allreduce sum + # rank 0 + paddle.device.set_device('cpu') + x = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + # rank 1 + y = np.random.random(self.shape).astype(self.dtype) + tensor_y = paddle.to_tensor(y) + + sum_result = x + y + if rank == 0: + task = pg.allreduce(tensor_x) + task.wait() + assert np.array_equal(tensor_x, sum_result) + else: + task = pg.allreduce(tensor_y) + task.wait() + assert np.array_equal(tensor_y, sum_result) + + print("test allreduce sum api ok") + + # test allreduce max + # rank 0 + x = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + # rank 1 + y = np.random.random(self.shape).astype(self.dtype) + tensor_y = paddle.to_tensor(y) + + max_result = paddle.maximum(tensor_x, tensor_y) + + if rank == 0: + task = pg.allreduce(tensor_x, core.ReduceOp.MAX) + task.wait() + assert np.array_equal(tensor_x, max_result) + else: + task = pg.allreduce(tensor_y, core.ReduceOp.MAX) + task.wait() + assert np.array_equal(tensor_y, max_result) + + print("test allreduce max api ok") + + # test broadcast + # rank 0 + x = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + # rank 1 + y = np.random.random(self.shape).astype(self.dtype) + tensor_y = paddle.to_tensor(y) + + broadcast_result = paddle.assign(tensor_x) + if rank == 0: + task = pg.broadcast(tensor_x, 0) + task.synchronize() + assert task.is_completed() + assert np.array_equal(broadcast_result, tensor_x) + else: + task = pg.broadcast(tensor_y, 0) + task.synchronize() + assert task.is_completed() + assert np.array_equal(broadcast_result, tensor_y) + print("test broadcast api ok") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_process_group.py b/python/paddle/fluid/tests/unittests/test_collective_process_group.py index 6ae5424a882..58baa0a2fa9 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_process_group.py +++ b/python/paddle/fluid/tests/unittests/test_collective_process_group.py @@ -22,6 +22,9 @@ class TestProcessGroup(TestMultipleGpus): def test_process_group_nccl(self): self.run_mnist_2gpu('process_group_nccl.py') + def test_process_group_gloo(self): + self.run_mnist_2gpu('process_group_gloo.py') + if __name__ == "__main__": unittest.main() -- GitLab