// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/distributed/store/tcp_utils.h" namespace paddle { namespace distributed { enum class ReplyType { WAITING, STOP_WAIT }; enum class Command { ADD, GET, SET, WAIT, STOP }; namespace detail { class MasterDaemon { public: static std::unique_ptr start(SocketType listen_socket, int nranks, int stop_check_timeout); MasterDaemon() = delete; explicit MasterDaemon(SocketType listen_socket, int nranks, int stop_check_timeout); ~MasterDaemon(); private: void run(); void _do_add(SocketType socket); void _do_wait(SocketType socket); void _do_get(SocketType socket); void _do_set(SocketType socket); void _do_stop(SocketType socket); SocketType _listen_socket; std::vector _sockets; std::unordered_map> _store; std::thread _background_thread{}; int _nranks; int _stop_check_timeout; bool _stop = false; // all workers stopped std::chrono::time_point _stop_time; bool _has_stop = false; // at least one worker stopped }; class TCPServer { public: TCPServer() = default; static std::unique_ptr create(std::uint16_t port, int nranks, int stop_check_timeout); private: std::unique_ptr _master_daemon; }; class TCPClient { public: explicit TCPClient(SocketType socket) : _socket{socket} {} static std::unique_ptr connect(const std::string host, uint16_t port); ~TCPClient() { tcputils::close_socket(_socket); } void send_command_for_key(Command type, const std::string& key); template void send_value(const T& value); template void send_vector(const std::vector& value); template std::vector receive_vector(); template T receive_value(); private: SocketType _socket; }; } // namespace detail class TCPStore : public Store { public: static constexpr std::uint16_t kDefaultPort = 6170; explicit TCPStore(std::string host, uint16_t port = kDefaultPort, bool is_master = false, size_t num_workers = 1, std::chrono::seconds timeout = tcputils::kDefaultTimeout, int stop_check_timeout = 900); ~TCPStore(); int64_t add(const std::string& key, int64_t value) override; std::vector get(const std::string& key) override; void wait(const std::string& key) override; void set(const std::string& key, const std::vector& value) override; private: void waitWorkers(); std::unique_ptr _server; std::unique_ptr _client; const std::string _init_key = "init/"; const std::string _key_prefix = "/"; std::chrono::seconds _timeout; bool _is_master; int _num_workers; }; } // namespace distributed } // namespace paddle