diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt index 5ae2e26e87c7b33a75325f5b585ca115bd3b6308..1527b752c6906cd6c79f5250b4876a233961e02a 100644 --- a/paddle/fluid/distributed/CMakeLists.txt +++ b/paddle/fluid/distributed/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(store) + if(NOT WITH_PSCORE) add_subdirectory(fleet_executor) return() diff --git a/paddle/fluid/distributed/store/CMakeLists.txt b/paddle/fluid/distributed/store/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1fde447d97dd99783a77a9a2ad89b4457b55ca74 --- /dev/null +++ b/paddle/fluid/distributed/store/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(tcp_store SRCS tcp_store.cc tcp_utils.cc DEPS enforce glog) diff --git a/paddle/fluid/distributed/store/store.h b/paddle/fluid/distributed/store/store.h new file mode 100644 index 0000000000000000000000000000000000000000..2673314d222d2b32e42c42a3a94df71a1887914a --- /dev/null +++ b/paddle/fluid/distributed/store/store.h @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include + +#include "paddle/fluid/distributed/store/tcp_utils.h" + +namespace paddle { +namespace distributed { + +class Store { + public: + Store() = delete; + explicit Store(const std::chrono::seconds& timeout) : _timeout(timeout) {} + virtual ~Store() = default; + + virtual int64_t add(const std::string& key, int64_t value) = 0; + virtual std::vector get(const std::string& key) = 0; + virtual void wait(const std::string& key) = 0; + + virtual const std::chrono::seconds& timeout() const { return _timeout; } + + private: + std::chrono::seconds _timeout; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/store/tcp_store.cc b/paddle/fluid/distributed/store/tcp_store.cc new file mode 100644 index 0000000000000000000000000000000000000000..de85ac0d910e93257a308052ca1fcf193680a183 --- /dev/null +++ b/paddle/fluid/distributed/store/tcp_store.cc @@ -0,0 +1,272 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/distributed/store/tcp_store.h" +#include "paddle/fluid/distributed/store/tcp_utils.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace distributed { + +namespace detail { + +constexpr int INFTIME = -1; + +std::unique_ptr MasterDaemon::start(SocketType socket) { + return std::make_unique(socket); +} + +MasterDaemon::MasterDaemon(SocketType socket) : _listen_socket(socket) { + _background_thread = std::thread{&MasterDaemon::run, this}; +} + +MasterDaemon::~MasterDaemon() { + _background_thread.join(); + tcputils::close_socket(_listen_socket); + for (SocketType socket : _sockets) { + tcputils::close_socket(socket); + } +} + +void MasterDaemon::_do_add(SocketType socket) { + int64_t new_value{}; + std::string key = tcputils::receive_string(socket); + new_value = tcputils::receive_value(socket); + std::vector old_value; + auto it = _store.find(key); + if (it != _store.end()) { + old_value = it->second; + char* buffer = reinterpret_cast(it->second.data()); + size_t len = old_value.size(); + new_value += std::stoll(std::string(buffer, len)); + } + + std::string new_value_str = std::to_string(new_value); + _store[key] = + std::vector(new_value_str.begin(), new_value_str.end()); + VLOG(3) << "TCPStore: new value (" << new_value << ") for key (" << key + << ")."; + tcputils::send_value(socket, new_value); +} + +void MasterDaemon::_do_get(SocketType socket) { + std::string key = tcputils::receive_string(socket); + auto iter = _store.find(key); + PADDLE_ENFORCE_NE( + iter, _store.end(), + platform::errors::InvalidArgument("Key %s not found in TCPStore.", key)); + std::vector value = iter->second; + VLOG(3) << "TCPStore: value (" + << std::stoll(std::string(reinterpret_cast(value.data()), + value.size())) + << ") for key (" << key << ")."; + tcputils::send_vector(socket, value); +} + +void MasterDaemon::_do_stop(SocketType socket) { + ReplyType value = ReplyType::STOP_WAIT; + _stop = true; + tcputils::send_value(socket, value); +} + +void MasterDaemon::_do_wait(SocketType socket) { + std::string key = tcputils::receive_string(socket); + auto iter = _store.find(key); + auto reply = ReplyType::STOP_WAIT; + if (iter == _store.end()) { + reply = ReplyType::WAITING; + } + VLOG(3) << "TCPStore: wait reply (" << static_cast(reply) + << ") for key (" << key << ")."; + tcputils::send_value(socket, reply); +} + +void MasterDaemon::run() { + std::vector fds; +#ifdef _WIN32 + fds.push_back({_listen_socket, POLLIN}); +#else + fds.push_back({.fd = _listen_socket, .events = POLLIN, .revents = 0}); +#endif + + while (!_stop) { + for (size_t i = 0; i < fds.size(); i++) { + fds[i].revents = 0; + } + +#ifdef _WIN32 + ::WSAPoll(fds.data(), fds.size(), INFTIME); +#else + ::poll(fds.data(), fds.size(), INFTIME); +#endif + + if (fds[0].revents != 0) { + auto socket = tcputils::tcp_accept(_listen_socket); + _sockets.emplace_back(socket); +#ifdef _WIN32 + fds.push_back({socket, POLLIN}); +#else + fds.push_back({.fd = socket, .events = POLLIN, .revents = 0}); +#endif + } + + for (size_t i = 1; i < fds.size(); i++) { + if (fds[i].revents == 0) { + continue; + } + + Command command = tcputils::receive_value(fds[i].fd); + VLOG(3) << "TCPStore: recv command: " << static_cast(command) << "."; + + switch (command) { + case Command::ADD: + _do_add(fds[i].fd); + break; + case Command::GET: + _do_get(fds[i].fd); + break; + case Command::WAIT: + _do_wait(fds[i].fd); + break; + case Command::STOP: + _do_stop(fds[i].fd); + break; + } + } + } +} + +std::unique_ptr TCPServer::create(uint16_t port) { + int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET); + auto server = std::make_unique(); + server->_master_daemon = MasterDaemon::start(socket); + return server; +} + +std::unique_ptr TCPClient::connect(const std::string host, + uint16_t port) { + int socket = tcputils::tcp_connect(host, std::to_string(port), AF_INET); + return std::make_unique(socket); +} + +void TCPClient::send_command_for_key(Command type, const std::string& key) { + tcputils::send_value(_socket, type); + if (key.empty()) { + return; + } + tcputils::send_string(_socket, key); +} + +template +void TCPClient::send_value(const T& value) { + tcputils::send_bytes(_socket, &value, 1); +} + +template +T TCPClient::receive_value() { + T res; + tcputils::receive_bytes(_socket, &res, 1); + return res; +} + +template +void TCPClient::send_vector(const std::vector& value) { + tcputils::send_vector(_socket, value); +} + +template +std::vector TCPClient::receive_vector() { + return tcputils::receive_vector(_socket); +} + +} // namespace detail + +TCPStore::TCPStore(std::string host, uint16_t port, bool is_master, + size_t num_workers, std::chrono::seconds timeout) + : Store(timeout), _is_master(is_master), _num_workers(num_workers) { + if (_is_master) { + _server = detail::TCPServer::create(port); + } + + _client = detail::TCPClient::connect(host, port); + waitWorkers(); +} + +void TCPStore::waitWorkers() { + if (_num_workers == 0) { + return; + } + add(_init_key, 1); + + if (_server) { + auto begin = std::chrono::steady_clock::now(); + do { + auto value = get(_init_key); + int completed = std::stoi(std::string(value.begin(), value.end())); + VLOG(3) << completed << " worker ready, total " << _num_workers; + if (completed >= _num_workers) { + break; + } + const auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - begin); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (_timeout != tcputils::kNoTimeout && elapsed > _timeout) { + PADDLE_ENFORCE_EQ( + completed, _num_workers, + platform::errors::InvalidArgument( + "TCPStore timeouted and not all workers got ready.")); + } + } while (true); + } + VLOG(3) << "TCPStore initialized."; +} + +int64_t TCPStore::add(const std::string& key, int64_t value) { + _client->send_command_for_key(Command::ADD, _key_prefix + key); + _client->send_value(value); + return _client->receive_value(); +} + +std::vector TCPStore::get(const std::string& key) { + wait(key); + _client->send_command_for_key(Command::GET, _key_prefix + key); + VLOG(3) << "TCPStore get."; + return _client->receive_vector(); +} + +void TCPStore::wait(const std::string& key) { + ReplyType reply; + do { + _client->send_command_for_key(Command::WAIT, _key_prefix + key); + + reply = _client->receive_value(); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } while (reply != ReplyType::STOP_WAIT); +} + +TCPStore::~TCPStore() { + _client->send_command_for_key(Command::STOP, ""); + ReplyType ret = _client->receive_value(); + PADDLE_ENFORCE_EQ(ret, ReplyType::STOP_WAIT, + platform::errors::InvalidArgument( + "The reply for TCPStore destructure must be 0.")); +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/store/tcp_store.h b/paddle/fluid/distributed/store/tcp_store.h new file mode 100644 index 0000000000000000000000000000000000000000..cd706dd6640acf5e0b5b3714175dac7a6cecb25a --- /dev/null +++ b/paddle/fluid/distributed/store/tcp_store.h @@ -0,0 +1,114 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/store/store.h" +#include "paddle/fluid/distributed/store/tcp_utils.h" + +namespace paddle { +namespace distributed { + +enum class ReplyType { WAITING, STOP_WAIT }; +enum class Command { ADD, GET, WAIT, STOP }; + +namespace detail { + +class MasterDaemon { + public: + static std::unique_ptr start(SocketType listen_socket); + MasterDaemon() = delete; + explicit MasterDaemon(SocketType listen_socket); + ~MasterDaemon(); + + private: + void run(); + void _do_add(SocketType socket); + void _do_wait(SocketType socket); + void _do_get(SocketType socket); + void _do_stop(SocketType socket); + SocketType _listen_socket; + std::vector _sockets; + std::unordered_map> _store; + std::thread _background_thread{}; + bool _stop = false; +}; + +class TCPServer { + public: + TCPServer() = default; + static std::unique_ptr create(std::uint16_t port); + + private: + std::unique_ptr _master_daemon; +}; + +class TCPClient { + public: + explicit TCPClient(SocketType socket) : _socket{socket} {} + static std::unique_ptr connect(const std::string host, + uint16_t port); + ~TCPClient() { tcputils::close_socket(_socket); } + void send_command_for_key(Command type, const std::string& key); + + template + void send_value(const T& value); + + template + void send_vector(const std::vector& value); + template + std::vector receive_vector(); + + template + T receive_value(); + + private: + SocketType _socket; +}; + +} // namespace detail + +class TCPStore : public Store { + public: + static constexpr std::uint16_t kDefaultPort = 6170; + explicit TCPStore(std::string host, uint16_t port = kDefaultPort, + bool is_master = false, size_t num_workers = 1, + std::chrono::seconds timeout = tcputils::kDefaultTimeout); + + ~TCPStore(); + + int64_t add(const std::string& key, int64_t value) override; + std::vector get(const std::string& key) override; + void wait(const std::string& key) override; + + private: + void waitWorkers(); + std::unique_ptr _server; + std::unique_ptr _client; + + const std::string _init_key = "init/"; + const std::string _key_prefix = "/"; + std::chrono::seconds _timeout; + bool _is_master; + int _num_workers; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/store/tcp_utils.cc b/paddle/fluid/distributed/store/tcp_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..d0561d0b9a9c5b01c32620e72d21ed562e42637e --- /dev/null +++ b/paddle/fluid/distributed/store/tcp_utils.cc @@ -0,0 +1,201 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/store/tcp_utils.h" +#include +#include +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace distributed { +namespace tcputils { + +std::error_code socket_error() { +#ifdef _WIN32 + return std::error_code{::WSAGetLastError(), std::generic_category()}; +#else + return std::error_code{errno, std::generic_category()}; +#endif +} + +void close_socket(SocketType socket) { +#ifdef _WIN32 + ::closesocket(socket); +#else + ::close(socket); +#endif +} + +::addrinfo* get_addr_info(const std::string host, const std::string port, + int ai_flags, int family) { + ::addrinfo hints{}, *res; + hints.ai_flags = ai_flags; + hints.ai_family = family; + hints.ai_socktype = SOCK_STREAM; + + const char* node = host.empty() ? nullptr : host.c_str(); + + int n; + n = ::getaddrinfo(node, port.c_str(), &hints, &res); + const char* gai_err = ::gai_strerror(n); + const char* proto = + (family == AF_INET ? "IPv4" : family == AF_INET6 ? "IPv6" : ""); + PADDLE_ENFORCE_EQ( + n, 0, platform::errors::InvalidArgument( + "%s network %s:%s cannot be obtained. Details: %s.", proto, + host, port, gai_err)); + + return res; +} + +void free_addr_info(::addrinfo* hint) { + PADDLE_ENFORCE_NOT_NULL( + hint, platform::errors::InvalidArgument( + "The parameter for free_addr_info cannot be null.")); + ::freeaddrinfo(hint); +} + +SocketType tcp_connect(const std::string host, const std::string port, + int family, std::chrono::seconds timeout) { + int ai_flags = AI_NUMERICSERV | AI_V4MAPPED | AI_ALL; + ::addrinfo* res = get_addr_info(host, port, ai_flags, family); + + SocketType sockfd = -1; + bool retry = true; + auto deadline = std::chrono::steady_clock::now() + timeout; + do { + for (::addrinfo* cur = res; cur != nullptr; cur = cur->ai_next) { + sockfd = ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol); + PADDLE_ENFORCE_GT(sockfd, 0, platform::errors::InvalidArgument( + "Create socket to connect %s:%s failed. " + "Details: %s. ", + host, port, socket_error().message())); + + if (::connect(sockfd, cur->ai_addr, cur->ai_addrlen) == 0) { + retry = false; + break; + } + VLOG(0) << "Retry to connect to " << host << ":" << port + << " while the server is not yet listening."; + close_socket(sockfd); + sockfd = -1; + std::this_thread::sleep_for(kDelay); + if (timeout != kNoTimeout && + std::chrono::steady_clock::now() >= deadline) { + retry = false; + break; + } + } + + if (timeout != kNoTimeout && std::chrono::steady_clock::now() >= deadline) { + retry = false; + } + } while (retry); + + free_addr_info(res); + + PADDLE_ENFORCE_GT(sockfd, 0, + platform::errors::InvalidArgument( + "Network %s:%s cannot be connected.", host, port)); + VLOG(0) << "Successfully connected to " << host << ":" << port; + + return sockfd; +} + +SocketType tcp_listen(const std::string host, const std::string port, + int family) { + int ai_flags = AI_PASSIVE | AI_NUMERICSERV; + ::addrinfo* res = get_addr_info(host, port, ai_flags, family); + ::addrinfo* cur = res; + SocketType sockfd{}; + + std::string node = host.empty() ? "IP_ANY" : host; + while (cur) { + sockfd = ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol); + if (sockfd < 0) { + VLOG(0) << "Cannot create socket on " << node << ":" << port + << ". Details: " << socket_error().message(); + cur = cur->ai_next; + continue; + } + + int on = 1; +#ifdef _WIN32 + int ret = ::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&on), sizeof(on)); +#else + int ret = ::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); +#endif + if (ret < 0) { + VLOG(0) << "Set the address reuse option failed on the server."; + } + if (::bind(sockfd, res->ai_addr, res->ai_addrlen) == 0) { + break; + } + close_socket(sockfd); + sockfd = -1; + cur = cur->ai_next; + } + + PADDLE_ENFORCE_GT(sockfd, 0, + platform::errors::InvalidArgument( + "Bind network on %s:%s failedd.", node, port)); + + ::listen(sockfd, LISTENQ); + + VLOG(0) << "The server starts to listen on " << node << ":" << port; + return sockfd; +} + +SocketType tcp_accept(SocketType socket) { + ::sockaddr_storage addr_s{}; + ::socklen_t addr_len = sizeof(addr_s); + SocketType new_socket = + ::accept(socket, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len); + PADDLE_ENFORCE_GT( + new_socket, 0, + platform::errors::InvalidArgument( + "The server failed to accept a new connection. Details: %s.", + socket_error().message())); +#ifndef _WIN32 + ::fcntl(new_socket, F_SETFD, FD_CLOEXEC); +#endif + auto value = 1; +#ifdef _WIN32 + ::setsockopt(new_socket, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast(&value), sizeof(value)); +#else + ::setsockopt(new_socket, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)); +#endif + return new_socket; +} + +void send_string(SocketType socket, const std::string& s) { + std::string::size_type size = s.size(); + send_bytes(socket, &size, 1); + send_bytes(socket, s.data(), size); +} + +std::string receive_string(SocketType socket) { + std::string::size_type size; + receive_bytes(socket, &size, 1); + std::vector v(size); + receive_bytes(socket, v.data(), size); + return std::string(v.data(), v.size()); +} + +} // namespace tcputils +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/store/tcp_utils.h b/paddle/fluid/distributed/store/tcp_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..60cb3de124da3593f3d07ffadcf3b12c2deedf29 --- /dev/null +++ b/paddle/fluid/distributed/store/tcp_utils.h @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef _WIN32 +#include +#include +#pragma comment(lib, "Ws2_32.lib") +#else +#include +#include +#include +#include +#include +#include +#endif +#include +#include +#include +#include "paddle/fluid/platform/enforce.h" + +// Utility functions for TCP socket. +namespace paddle { +namespace distributed { + +#ifdef _WIN32 +using SocketType = SOCKET; +#else +using SocketType = int; +#endif + +namespace tcputils { + +constexpr int LISTENQ = 2048; +constexpr std::chrono::seconds kDelay = std::chrono::seconds(3); +constexpr std::chrono::seconds kNoTimeout = std::chrono::seconds::zero(); +constexpr std::chrono::seconds kDefaultTimeout = std::chrono::seconds(360); + +std::error_code socket_error(); +void close_socket(SocketType socket); +::addrinfo* get_addr_info(const std::string host, const std::string port, + int ai_flags, int family); +void free_addr_info(::addrinfo*); +SocketType tcp_connect(const std::string host, const std::string port, + int family, std::chrono::seconds timeout = kNoTimeout); +SocketType tcp_listen(const std::string host, const std::string port, + int family); +SocketType tcp_accept(SocketType socket); + +void send_string(SocketType socket, const std::string& s); +std::string receive_string(SocketType socket); + +template +void send_bytes(SocketType socket, const T* buffer, size_t len) { + size_t to_send = len * sizeof(T); + if (to_send == 0) { + return; + } + + auto ptr = reinterpret_cast(buffer); + + while (to_send > 0) { + auto byte_sent = ::send(socket, ptr, to_send, 0); + PADDLE_ENFORCE_GT(byte_sent, 0, platform::errors::InvalidArgument( + "TCP send error. Details: %s.", + socket_error().message())); + to_send -= byte_sent; + ptr += byte_sent; + } +} + +template +void receive_bytes(SocketType socket, T* buffer, size_t len) { + size_t to_recv = len * sizeof(T); + if (to_recv == 0) { + return; + } + auto ptr = reinterpret_cast(buffer); + + while (to_recv > 0) { + auto byte_received = ::recv(socket, ptr, to_recv, 0); + PADDLE_ENFORCE_GT(byte_received, 0, platform::errors::InvalidArgument( + "TCP receive error. Details: %s.", + socket_error().message())); + + to_recv -= byte_received; + ptr += byte_received; + } +} + +template +void send_vector(SocketType socket, const std::vector& v) { + size_t size = v.size(); + send_bytes(socket, &size, 1); + send_bytes(socket, v.data(), size); +} + +template +std::vector receive_vector(SocketType socket) { + size_t size; + receive_bytes(socket, &size, 1); + std::vector res(size); + receive_bytes(socket, res.data(), size); + return res; +} + +template +void send_value(SocketType socket, const T& v) { + send_bytes(socket, &v, 1); +} + +template +T receive_value(SocketType socket) { + T v; + receive_bytes(socket, &v, 1); + return v; +} + +} // namespace tcputils +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 3453cff30f5ad2d1016dcd786733a7024ed0ae4a..26c35167f404a1dc475cc200ff2495eae9c09ab4 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -2,7 +2,7 @@ set(PYBIND_DEPS init pybind python proto_desc memory executor fleet_wrapper box_ feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator - cost_model cuda_graph_with_memory_pool fleet_executor global_utils pten_utils) + cost_model cuda_graph_with_memory_pool fleet_executor global_utils pten_utils tcp_store) if (WITH_PSCORE) set(PYBIND_DEPS ${PYBIND_DEPS} ps_service) @@ -73,6 +73,7 @@ set(PYBIND_SRCS compatible.cc io.cc generator_py.cc + communication.cc cuda_streams_py.cc) if(WITH_ASCEND) diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc new file mode 100644 index 0000000000000000000000000000000000000000..a0d2777f825dc592e19230bc2ba4412f943d0c2b --- /dev/null +++ b/paddle/fluid/pybind/communication.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/store/tcp_store.h" +#include "paddle/fluid/pybind/communication.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +using TCPStore = paddle::distributed::TCPStore; + +void BindTCPStore(py::module* m) { + py::class_(*m, "TCPStore") + .def( + py::init()) + .def("add", &TCPStore::add) + .def("get", &TCPStore::get); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/communication.h b/paddle/fluid/pybind/communication.h new file mode 100644 index 0000000000000000000000000000000000000000..17045ccfe65cae25471ceff3abf0129b2a21acb0 --- /dev/null +++ b/paddle/fluid/pybind/communication.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "pybind11/chrono.h" +#include "pybind11/complex.h" +#include "pybind11/functional.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace paddle { +namespace pybind { + +void BindTCPStore(pybind11::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f653070b2eff7765aa4359a8405e1f27c6addf0b..58205041b80411cab305f79858df11b11ae0d075 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -91,6 +91,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/bind_cost_model.h" #include "paddle/fluid/pybind/bind_fleet_executor.h" #include "paddle/fluid/pybind/box_helper_py.h" +#include "paddle/fluid/pybind/communication.h" #include "paddle/fluid/pybind/compatible.h" #include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/data_set_py.h" @@ -2621,6 +2622,7 @@ All parameter, weight, gradient are variables in Paddle. BindGlobalValueGetterSetter(&m); BindProcessMeshDesc(&m); BindFleetExecutor(&m); + BindTCPStore(&m); py::class_(m, "LodRankTable") .def("items", [](framework::LoDRankTable &table) { diff --git a/python/paddle/fluid/tests/unittests/test_tcp_store.py b/python/paddle/fluid/tests/unittests/test_tcp_store.py new file mode 100644 index 0000000000000000000000000000000000000000..11e1e8cd059c8b5611bc2f25a2323df06fb00df6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tcp_store.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import datetime +import paddle + + +class TestTCPStore(unittest.TestCase): + def test_tcp_store(self): + store = paddle.fluid.core.TCPStore("127.0.0.1", 6170, True, 1, + datetime.timedelta(0)) + store.add("my", 3) + ret1 = store.get('my') + store.add("my", 3) + ret2 = store.get('my') + self.assertEqual(ret1[0] + 3, ret2[0]) + + +if __name__ == "__main__": + unittest.main()