diff --git a/paddle/fluid/distributed/store/CMakeLists.txt b/paddle/fluid/distributed/store/CMakeLists.txt index cfab4aad5f795ec04447e64e73fe907fc832b155..111a8e95d38bb87d9f5370608c200500c0f53321 100644 --- a/paddle/fluid/distributed/store/CMakeLists.txt +++ b/paddle/fluid/distributed/store/CMakeLists.txt @@ -1,4 +1,11 @@ cc_library( tcp_store - SRCS tcp_store.cc tcp_utils.cc + SRCS tcp_store.cc tcp_utils.cc socket.cpp DEPS enforce glog) + +if(NOT WIN32) + cc_test( + test_c_tcp_store + SRCS test_tcp_store.cc + DEPS tcp_store) +endif() diff --git a/paddle/fluid/distributed/store/socket.cpp b/paddle/fluid/distributed/store/socket.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ca6dc0f02902af6f0b5ccf402f1a64def9e4f91f --- /dev/null +++ b/paddle/fluid/distributed/store/socket.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/store/socket.h" + +#ifndef _WIN32 +#include +#include +#include +#include +#endif +#include +#include + +namespace paddle { +namespace distributed { + +#ifdef _WIN32 +static int _get_sockname_of_win(int sock, char* out, int out_len) { + snprintf(out, out_len, "not support win now"); + return 0; +} +#else +static int _get_sockname(int sock, char *out, int out_len) { + struct sockaddr_in addr; + socklen_t s_len = sizeof(addr); + + if (::getpeername(sock, reinterpret_cast(&addr), &s_len)) { + ::snprintf( + out, out_len, "can't getsocketname of %d, errno:%d", sock, errno); + return -1; + } + + char ip[128]; + int port = 0; + + // deal with both IPv4 and IPv6: + if (addr.sin_family == AF_INET) { + struct sockaddr_in *s = (struct sockaddr_in *)&addr; + port = ntohs(s->sin_port); + ::inet_ntop(AF_INET, &s->sin_addr, ip, sizeof(ip)); + } else { // AF_INET6 + struct sockaddr_in6 *s = (struct sockaddr_in6 *)&addr; + port = ntohs(s->sin6_port); + ::inet_ntop(AF_INET6, &s->sin6_addr, ip, sizeof(ip)); + } + + ::snprintf(out, out_len, "%s:%d", ip, port); + return 0; +} +#endif + +int GetSockName(int sock, char* out, int out_len) { +#ifdef _WIN32 + return _get_sockname_of_win(sock, out, out_len); +#else + return _get_sockname(sock, out, out_len); +#endif +} + +std::string GetSockName(int fd) { + char out[256]; + GetSockName(fd, out, sizeof(out)); + return std::string(out); +} + +}; // namespace distributed +}; // namespace paddle diff --git a/paddle/fluid/distributed/store/socket.h b/paddle/fluid/distributed/store/socket.h new file mode 100644 index 0000000000000000000000000000000000000000..f423d2643354bd4d7aa83defc7ea98e35dd26bb3 --- /dev/null +++ b/paddle/fluid/distributed/store/socket.h @@ -0,0 +1,26 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace paddle { +namespace distributed { + +int GetSockName(int fd, char* out, int out_len); + +std::string GetSockName(int fd); +}; // namespace distributed +}; // namespace paddle diff --git a/paddle/fluid/distributed/store/store.h b/paddle/fluid/distributed/store/store.h index 7b4ae7e70ff6f033e038f1c5214f46e0876257d2..eb329276d67b1ac446276c08eab6ad57f7041cfb 100644 --- a/paddle/fluid/distributed/store/store.h +++ b/paddle/fluid/distributed/store/store.h @@ -25,8 +25,8 @@ namespace distributed { class Store { public: - Store() : _timeout(tcputils::kNoTimeout) {} - explicit Store(const std::chrono::seconds& timeout) : _timeout(timeout) {} + Store() : _timeout(900) {} + explicit Store(const int timeout) : _timeout(timeout) {} virtual ~Store() = default; virtual int64_t add(const std::string& key, int64_t value) { @@ -46,10 +46,10 @@ class Store { "Implement the add method in the subclass.")); } - virtual const std::chrono::seconds& timeout() const { return _timeout; } + virtual int timeout() { return _timeout; } - private: - std::chrono::seconds _timeout; + protected: + int _timeout; }; } // namespace distributed diff --git a/paddle/fluid/distributed/store/tcp_store.cc b/paddle/fluid/distributed/store/tcp_store.cc index a46b4b32c9f1857b2cc70be3e7cd05befc35edbc..a67ca29a543ab98ec016590f6d2e7421a56e3125 100644 --- a/paddle/fluid/distributed/store/tcp_store.cc +++ b/paddle/fluid/distributed/store/tcp_store.cc @@ -29,25 +29,28 @@ namespace detail { constexpr int INFTIME = 10000; // 10 seconds -std::unique_ptr MasterDaemon::start(SocketType socket, int nranks, - int stop_check_timeout) { - return std::make_unique(socket, nranks, stop_check_timeout); +std::unique_ptr MasterDaemon::start(SocketType socket, + int nranks, + int timeout) { + VLOG(4) << ("begin to run start"); + return std::make_unique(socket, nranks, timeout); } -MasterDaemon::MasterDaemon(SocketType socket, int nranks, - int stop_check_timeout) - : _listen_socket(socket), - _nranks(nranks), - _stop_check_timeout(stop_check_timeout) { +MasterDaemon::MasterDaemon(SocketType socket, int nranks, int timeout) + : _listen_socket(socket), _nranks(nranks), _timeout(timeout) { + InitControlFd(); _background_thread = std::thread{&MasterDaemon::run, this}; } MasterDaemon::~MasterDaemon() { + VLOG(4) << ("begin to destruct MasterDaemon"); + StopByControlFd(); _background_thread.join(); tcputils::close_socket(_listen_socket); for (SocketType socket : _sockets) { tcputils::close_socket(socket); } + CloseControlFd(); } void MasterDaemon::_do_add(SocketType socket) { @@ -66,31 +69,34 @@ void MasterDaemon::_do_add(SocketType socket) { std::string new_value_str = std::to_string(new_value); _store[key] = std::vector(new_value_str.begin(), new_value_str.end()); - VLOG(3) << "TCPStore: new value (" << new_value << ") for key (" << key - << ")."; + VLOG(4) << "TCPStore: new value (" << new_value << ") for key (" << key + << ") " << GetSockName(socket); tcputils::send_value(socket, new_value); } void MasterDaemon::_do_set(SocketType socket) { - VLOG(3) << "MasterDaemon::_do_set"; std::string key = tcputils::receive_string(socket); + VLOG(4) << "MasterDaemon::_do_set key(" << key << ") " << GetSockName(socket); + auto value = tcputils::receive_vector(socket); _store[key] = value; } void MasterDaemon::_do_get(SocketType socket) { - VLOG(3) << "MasterDaemon::_do_get"; std::string key = tcputils::receive_string(socket); + VLOG(4) << "MasterDaemon::_do_get key(" << key << ") " << GetSockName(socket); + auto iter = _store.find(key); PADDLE_ENFORCE_NE( - iter, _store.end(), + iter, + _store.end(), platform::errors::InvalidArgument("Key %s not found in TCPStore.", key)); std::vector value = iter->second; tcputils::send_vector(socket, value); } void MasterDaemon::_do_stop(SocketType socket) { - VLOG(3) << "MasterDaemon::_do_stop"; + VLOG(4) << "MasterDaemon::_do_stop " << GetSockName(socket); if (!_has_stop) { _stop_time = std::chrono::system_clock::now(); } @@ -102,9 +108,40 @@ void MasterDaemon::_do_stop(SocketType socket) { } } +#ifndef _WIN32 +void MasterDaemon::InitControlFd() { + PADDLE_ENFORCE_NE( + pipe(_control_fd.data()), + -1, + platform::errors::Fatal("failed to cread control pipe errno:%d", errno)); +} +void MasterDaemon::CloseControlFd() { + for (int fd : _control_fd) { + if (fd != -1) { + ::close(fd); + } + } +} +void MasterDaemon::StopByControlFd() { + VLOG(4) << ("begin to run StopByControlFd"); + if (_control_fd[1] != -1) { + ::write(_control_fd[1], "\0", 1); + // close the write end of the pipe + ::close(_control_fd[1]); + _control_fd[1] = -1; + } +} +#else +void MasterDaemon::InitControlFd() {} +void MasterDaemon::CloseControlFd() {} +void MasterDaemon::StopByControlFd() {} +#endif + void MasterDaemon::_do_wait(SocketType socket) { - VLOG(3) << "MasterDaemon::_do_wait"; std::string key = tcputils::receive_string(socket); + VLOG(4) << "MasterDaemon::_do_wait key(" << key << ") " + << GetSockName(socket); + auto iter = _store.find(key); auto reply = ReplyType::STOP_WAIT; if (iter == _store.end()) { @@ -115,12 +152,67 @@ void MasterDaemon::_do_wait(SocketType socket) { tcputils::send_value(socket, reply); } +void MasterDaemon::ProcessCommands(std::vector* p_fds) { + std::vector& fds = *p_fds; + // FIXME(gongwb): Don't loop all fds of set just the fds who have event. +#ifdef _WIN32 + // 0: listen socket, so loop from 1. + for (size_t i = 1; i < fds.size(); i++) { +#else + // 0: listen socket, 1:controller pipe, so loop from 2. + for (size_t i = 2; i < fds.size(); i++) { +#endif + try { + if (fds[i].revents == 0) { + continue; + } + + Command command = tcputils::receive_value(fds[i].fd); + VLOG(3) << "TCPStore: recv command: " << static_cast(command) << "."; + + switch (command) { + case Command::ADD: + _do_add(fds[i].fd); + break; + case Command::GET: + _do_get(fds[i].fd); + break; + case Command::SET: + _do_set(fds[i].fd); + break; + case Command::WAIT: + _do_wait(fds[i].fd); + break; + case Command::STOP: + _do_stop(fds[i].fd); + break; + default: + LOG(WARNING) << "Unknown command: " << static_cast(command) + << " from addr info:" << GetSockName(fds[i].fd); + } + } catch (const std::exception& ex) { + fds.erase(fds.begin() + i); + tcputils::close_socket(fds[i].fd); +#ifdef _WIN32 + _sockets.erase(_sockets.begin() + i - 1); +#else + _sockets.erase(_sockets.begin() + i - 2); +#endif + + VLOG(3) << "Meet some exceptions during run:" << ex.what(); + } + } +} + void MasterDaemon::run() { + VLOG(4) << "begin to run run _stop:" << _stop << " _has_stop:" << _has_stop; std::vector fds; #ifdef _WIN32 fds.push_back({_listen_socket, POLLIN}); #else fds.push_back({.fd = _listen_socket, .events = POLLIN, .revents = 0}); + fds.push_back( + {.fd = _control_fd[0], .events = POLLIN | POLLHUP, .revents = 0}); #endif while (!_stop) { @@ -129,7 +221,8 @@ void MasterDaemon::run() { std::chrono::duration diff = end_time - _stop_time; int elapsed_seconds = static_cast(diff.count()); PADDLE_ENFORCE_LT( - elapsed_seconds, _stop_check_timeout, + elapsed_seconds, + _timeout, platform::errors::Fatal( "%d seconds elapsed after the first worker " "stopped, so we think there may be something wrong and will " @@ -138,16 +231,34 @@ void MasterDaemon::run() { " to change the timeout value in seconds. The default one is 900", elapsed_seconds)); } + for (size_t i = 0; i < fds.size(); i++) { fds[i].revents = 0; } + VLOG(9) << "begin to poll fds_size:" + << paddle::string::Sprintf("%d", fds.size()); #ifdef _WIN32 ::WSAPoll(fds.data(), fds.size(), INFTIME); #else ::poll(fds.data(), fds.size(), INFTIME); + + VLOG(9) << "begin to fds[1].revents:" + << paddle::string::Sprintf("%d", fds[1].revents); + // The control pipe receive shutdown event, and begin to close it. + if (fds[1].revents != 0) { + if (fds[1].revents & ~(POLLIN | POLLHUP)) { + PADDLE_THROW(paddle::platform::errors::Fatal("Undefined event type:%d", + fds[1].revents)); + } + VLOG(0) + << "receive shutdown event and so quit from MasterDaemon run loop"; + _stop = true; + break; + } #endif + // accept connect request. if (fds[0].revents != 0) { auto socket = tcputils::tcp_accept(_listen_socket); _sockets.emplace_back(socket); @@ -158,45 +269,12 @@ void MasterDaemon::run() { #endif } - for (size_t i = 1; i < fds.size(); i++) { - try { - if (fds[i].revents == 0) { - continue; - } - - Command command = tcputils::receive_value(fds[i].fd); - VLOG(3) << "TCPStore: recv command: " << static_cast(command) - << "."; - - switch (command) { - case Command::ADD: - _do_add(fds[i].fd); - break; - case Command::GET: - _do_get(fds[i].fd); - break; - case Command::SET: - _do_set(fds[i].fd); - break; - case Command::WAIT: - _do_wait(fds[i].fd); - break; - case Command::STOP: - _do_stop(fds[i].fd); - break; - default: - VLOG(0) << "Unknow command: " << static_cast(command); - exit(-1); - } - } catch (...) { - fds.erase(fds.begin() + i); - _sockets.erase(_sockets.begin() + i - 1); - } - } + ProcessCommands(&fds); } } -std::unique_ptr TCPServer::create(uint16_t port, int nranks, +std::unique_ptr TCPServer::create(uint16_t port, + int nranks, int stop_check_timeout) { int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET); auto server = std::make_unique(); @@ -243,12 +321,21 @@ std::vector TCPClient::receive_vector() { } // namespace detail -TCPStore::TCPStore(std::string host, uint16_t port, bool is_master, - size_t num_workers, std::chrono::seconds timeout, - int stop_check_timeout) +TCPStore::TCPStore(std::string host, + uint16_t port, + bool is_master, + size_t num_workers, + int timeout) : Store(timeout), _is_master(is_master), _num_workers(num_workers) { + _timeout = timeout; + PADDLE_ENFORCE_GT( + timeout, + 0, + platform::errors::InvalidArgument("timeout must >= %d", timeout)); + + VLOG(3) << "input timeout" << timeout << ", member timeout:" << _timeout; if (_is_master) { - _server = detail::TCPServer::create(port, num_workers, stop_check_timeout); + _server = detail::TCPServer::create(port, num_workers, timeout); } _client = detail::TCPClient::connect(host, port); @@ -261,11 +348,13 @@ void TCPStore::waitWorkers() { } add(_init_key, 1); + VLOG(3) << paddle::string::Sprintf("_timeout:%d", _timeout); auto begin = std::chrono::steady_clock::now(); do { auto value = get(_init_key); int completed = std::stoi(std::string(value.begin(), value.end())); - VLOG(3) << completed << " worker ready, total " << _num_workers; + VLOG(3) << completed << " worker ready, total " << _num_workers + << ", _timeout:" << _timeout; if (completed >= _num_workers) { break; } @@ -273,9 +362,16 @@ void TCPStore::waitWorkers() { std::chrono::steady_clock::now() - begin); std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (_timeout != tcputils::kNoTimeout && elapsed > _timeout) { + if (_timeout != 0 && elapsed.count() > _timeout) { + LOG(FATAL) << paddle::string::Sprintf( + "_timeout:%d elapsed:%d (elapsed > _timeout)=%d", + _timeout, + elapsed.count(), + elapsed.count() > _timeout); + PADDLE_ENFORCE_EQ( - completed, _num_workers, + completed, + _num_workers, platform::errors::InvalidArgument( "TCPStore timeouted and not all workers got ready.")); } @@ -293,7 +389,7 @@ int64_t TCPStore::add(const std::string& key, int64_t value) { void TCPStore::set(const std::string& key, const std::vector& value) { VLOG(3) << "TCPStore set."; _client->send_command_for_key(Command::SET, _key_prefix + key); - _client->send_vector(value); + _client->send_vector(value); } std::vector TCPStore::get(const std::string& key) { @@ -314,14 +410,7 @@ void TCPStore::wait(const std::string& key) { } while (reply != ReplyType::STOP_WAIT); } -TCPStore::~TCPStore() { - VLOG(3) << "~TCPStore"; - _client->send_command_for_key(Command::STOP, ""); - ReplyType ret = _client->receive_value(); - PADDLE_ENFORCE_EQ(ret, ReplyType::STOP_WAIT, - platform::errors::InvalidArgument( - "The reply for TCPStore destructure must be 0.")); -} +TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; } } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/store/tcp_store.h b/paddle/fluid/distributed/store/tcp_store.h index 4ca9a673bf57562609e45d090741eab96f92a6c8..2d25987dc19e2bb542d7478efcbf3e9a5af08079 100644 --- a/paddle/fluid/distributed/store/tcp_store.h +++ b/paddle/fluid/distributed/store/tcp_store.h @@ -14,12 +14,23 @@ #pragma once +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#endif + +#include #include #include #include #include #include +#include "paddle/fluid/distributed/store/socket.h" #include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/distributed/store/tcp_utils.h" @@ -35,14 +46,16 @@ class MasterDaemon { public: static std::unique_ptr start(SocketType listen_socket, int nranks, - int stop_check_timeout); + int timeout); MasterDaemon() = delete; - explicit MasterDaemon(SocketType listen_socket, int nranks, + explicit MasterDaemon(SocketType listen_socket, + int nranks, int stop_check_timeout); ~MasterDaemon(); private: void run(); + void ProcessCommands(std::vector* p_fds); void _do_add(SocketType socket); void _do_wait(SocketType socket); void _do_get(SocketType socket); @@ -52,17 +65,26 @@ class MasterDaemon { std::vector _sockets; std::unordered_map> _store; std::thread _background_thread{}; - int _nranks; - int _stop_check_timeout; + int _nranks = -1; + int _timeout = 0; bool _stop = false; // all workers stopped std::chrono::time_point _stop_time; bool _has_stop = false; // at least one worker stopped + + void InitControlFd(); + void CloseControlFd(); + void StopByControlFd(); +#ifdef _WIN32 +#else + std::array _control_fd{{-1, -1}}; +#endif }; class TCPServer { public: TCPServer() = default; - static std::unique_ptr create(std::uint16_t port, int nranks, + static std::unique_ptr create(std::uint16_t port, + int nranks, int stop_check_timeout); private: @@ -94,13 +116,15 @@ class TCPClient { } // namespace detail +// TODO(gongwb) :Add IP6 support. class TCPStore : public Store { public: static constexpr std::uint16_t kDefaultPort = 6170; - explicit TCPStore(std::string host, uint16_t port = kDefaultPort, - bool is_master = false, size_t num_workers = 1, - std::chrono::seconds timeout = tcputils::kDefaultTimeout, - int stop_check_timeout = 900); + explicit TCPStore(std::string host, + uint16_t port = kDefaultPort, + bool is_master = false, + size_t num_workers = 1, + int timeout = 900); ~TCPStore(); @@ -116,7 +140,7 @@ class TCPStore : public Store { const std::string _init_key = "init/"; const std::string _key_prefix = "/"; - std::chrono::seconds _timeout; + bool _is_master; int _num_workers; }; diff --git a/paddle/fluid/distributed/store/test_tcp_store.cc b/paddle/fluid/distributed/store/test_tcp_store.cc new file mode 100644 index 0000000000000000000000000000000000000000..45bf56953598a9a4541b588c87e761f2a030e75c --- /dev/null +++ b/paddle/fluid/distributed/store/test_tcp_store.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/store/tcp_store.h" +#include "paddle/fluid/distributed/store/tcp_utils.h" + +#ifdef _WIN32 +#include +#endif + +namespace paddle { +namespace distributed { + +TEST(MasterDaemon, init) { + int socket = tcputils::tcp_listen("", std::to_string(0), AF_INET); + auto d = detail::MasterDaemon::start(socket, 1, 100); + printf("started to sleep 2s\n"); +#ifdef _WIN32 + Sleep(2 * 1000); +#else + usleep(2 * 1000 * 1000); +#endif + printf("end to reset\n"); + + d.reset(); +} + +/* now for only c compile test +TEST(TCPStore, init) { + TCPStore store("127.0.0.1", 6170, true, 1); + store.add("my", 3); + auto ret1 = store.get("my"); + store.add("my", 3); + auto ret2 = store.get("my"); + PADDLE_ENFORCE_EQ(ret1[0] + 3, ret2[0], + paddle::errors::Fatal("result of add is not right")); +} +*/ + +}; // namespace distributed +}; // namespace paddle diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index 418804df02879a5b6d20bd965280bc33b1a2f4c8..3803d132515a5beb7d5bd84e7f45b8107f1236e8 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -39,12 +39,14 @@ void BindTCPStore(py::module *m) { .def(py::init<>()) .def( "set", - [](distributed::Store &self, const std::string &key, + [](distributed::Store &self, + const std::string &key, const std::string &value) { std::vector data(value.begin(), value.end()); self.set(key, data); }, - py::arg("key"), py::arg("value"), + py::arg("key"), + py::arg("value"), py::call_guard()) .def( "get", @@ -54,24 +56,29 @@ void BindTCPStore(py::module *m) { return py::bytes(reinterpret_cast(data.data()), data.size()); }, - py::arg("key"), py::call_guard()) - .def("add", &distributed::Store::add, + py::arg("key"), + py::call_guard()) + .def("add", + &distributed::Store::add, py::call_guard()) - .def("wait", &distributed::Store::wait, + .def("wait", + &distributed::Store::wait, py::call_guard()); py::class_>(*m, "TCPStore", Store) - .def(py::init([](std::string hostname, uint16_t port, bool is_master, - size_t world_size, std::chrono::seconds timeout, - int stop_check_timeout) { - return std::make_shared(hostname, port, is_master, - world_size, timeout, - stop_check_timeout); + .def(py::init([](std::string hostname, + uint16_t port, + bool is_master, + size_t world_size, + int timeout) { + return std::make_shared( + hostname, port, is_master, world_size, timeout); }), - py::arg("hostname"), py::arg("port"), py::arg("is_master"), + py::arg("hostname"), + py::arg("port"), + py::arg("is_master"), py::arg("world_size"), - py::arg("timeout") = distributed::tcputils::kNoTimeout, - py::arg("stop_check_timeout") = 900, + py::arg("timeout") = 900, py::call_guard()); } diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index 583043c186abfac12c53d32f5aaf328294edd280..78c0d9f1a741af72b7bef82ad4c094ac7606002b 100644 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -276,6 +276,7 @@ def get_cluster_from_args(args, device_mode, devices_per_proc): free_ports = find_free_ports(len(devices_per_proc)) if free_ports is not None: free_ports = list(free_ports) + logger.info("find free ports:{}".format(free_ports)) else: start_port = 6070 if os.environ.get('FLAGS_START_PORT') is not None: diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 79b680ef2d187700718f4835fede6b4bcdbeb112..52d19ae52b2bafca056ac4906d5849bb319842fb 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -240,7 +240,7 @@ def init_parallel_env(): master_port, is_master, world_size, - stop_check_timeout=stop_check_timeout) + timeout=stop_check_timeout) _set_default_store(default_store) pg = _new_process_group_impl(backend, default_store, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 45a6c815728d1728b15090bb9eff120ca399a7eb..cf717bd84fa0c9ad9274ad3ece7fea057dac50ca 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -129,6 +129,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_cost_model) +list(APPEND MIXED_DIST_TEST_OPS test_tcp_store) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -868,6 +869,7 @@ if(WITH_DISTRIBUTE) test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_cost_model MODULES test_auto_parallel_cost_model ENVS ${dist_ENVS}) + if(WITH_GPU OR WITH_XPU OR WITH_ASCEND @@ -895,6 +897,21 @@ if(WITH_DISTRIBUTE) list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_dgc_nccl") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_se_resnext_dgc") endif() + + # port range (20000, 23000) is reserved for dist-ops + set(dist_ut_port 20001) + if(NOT WIN32) + bash_test_modules( + test_tcp_store + START_BASH + dist_test.sh + LABELS + "RUN_TYPE=EXCLUSIVE" + ENVS + "PADDLE_DIST_UT_PORT=${dist_ut_port}") + math(EXPR dist_ut_port "${dist_ut_port}+1") + endif() + if(NOT APPLE) if(WITH_GPU OR WITH_ROCM) bash_test_modules(test_c_comm_init_op START_BASH test_c_comm_init_op.sh @@ -930,7 +947,6 @@ if(WITH_DISTRIBUTE) endif() # port range (20000, 23000) is reserved for dist-ops - set(dist_ut_port 20001) foreach(TEST_OP ${DIST_TEST_OPS}) bash_test_modules( ${TEST_OP} diff --git a/python/paddle/fluid/tests/unittests/process_group_gloo.py b/python/paddle/fluid/tests/unittests/process_group_gloo.py index f18d73842bdb6aadb590fa3eaac3ca3384772778..0dcf740586e78dd14d3ee69008bb223945af7b7d 100644 --- a/python/paddle/fluid/tests/unittests/process_group_gloo.py +++ b/python/paddle/fluid/tests/unittests/process_group_gloo.py @@ -47,7 +47,7 @@ class TestProcessGroupFp32(unittest.TestCase): rank = ParallelEnv().local_rank is_master = True if rank == 0 else False store = paddle.fluid.core.TCPStore("127.0.0.1", 6272, is_master, - nranks, datetime.timedelta(0)) + nranks, 30) place = paddle.fluid.core.CPUPlace() pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks, place) @@ -64,11 +64,11 @@ class TestProcessGroupFp32(unittest.TestCase): if rank == 0: task = pg.allreduce(tensor_x) task.wait() - assert np.array_equal(tensor_x, sum_result) + np.testing.assert_equal(tensor_x, sum_result) else: task = pg.allreduce(tensor_y) task.wait() - assert np.array_equal(tensor_y, sum_result) + np.testing.assert_equal(tensor_y, sum_result) print("test allreduce sum api ok") diff --git a/python/paddle/fluid/tests/unittests/test_tcp_store.py b/python/paddle/fluid/tests/unittests/test_tcp_store.py index a051519d634a55f068d2111252b33a5061f9414f..56f40e25e27b249e0ecc21ac3187183960f26085 100644 --- a/python/paddle/fluid/tests/unittests/test_tcp_store.py +++ b/python/paddle/fluid/tests/unittests/test_tcp_store.py @@ -17,13 +17,15 @@ from __future__ import print_function import unittest import datetime import paddle +import os class TestTCPStore(unittest.TestCase): def test_tcp_store(self): - store = paddle.fluid.core.TCPStore("127.0.0.1", 6170, True, 1, - datetime.timedelta(0)) + dist_port = int(os.getenv("PADDLE_DIST_UT_PORT", 6170)) + print("get dist_port:", dist_port) + store = paddle.fluid.core.TCPStore("127.0.0.1", dist_port, True, 1, 1) store.add("my", 3) ret1 = store.get('my') store.add("my", 3)