未验证 提交 4c9330d6 编写于 作者: G gongweibao 提交者: GitHub

Fix hang bug of TCPStore (#43724)

* tmp fix

* init

* compile ok

* compile ok

* add vlogs

* add test

* fix termination error

* add testfile

* add

* fix window compile

* fix window compile

* fix windows compile

* fix windows compile

* fix windows compile

* fix windows compile

* fix windows compile

* fix windows compile

* fix kunlun compile

* fix compilation

* fix compilation

* fix compilation

* tmp fix

* add windows

* add windows

* add more logs

* change timeout to protected

* SB

* add

* add

* fix timeout

* add

* fix test

* fix test

* fix test

* fix ut

* fix ut

* fix ut
上级 491b87b4
cc_library( cc_library(
tcp_store tcp_store
SRCS tcp_store.cc tcp_utils.cc SRCS tcp_store.cc tcp_utils.cc socket.cpp
DEPS enforce glog) DEPS enforce glog)
# C++ unit test for the tcp_store library (test_tcp_store.cc).
# It is only registered on non-Windows platforms.
if(NOT WIN32)
cc_test(
test_c_tcp_store
SRCS test_tcp_store.cc
DEPS tcp_store)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/store/socket.h"
#ifndef _WIN32
#include <arpa/inet.h>
#include <netinet/ip.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <errno.h>
#include <stdio.h>
namespace paddle {
namespace distributed {
#ifdef _WIN32
// Windows placeholder for the peer-address lookup.
//
// No lookup is performed: `sock` is ignored and a fixed notice is written
// into `out` (at most `out_len` bytes, as bounded by snprintf).
// Always returns 0.
static int _get_sockname_of_win(int sock, char* out, int out_len) {
  (void)sock;  // intentionally unused on this platform
  static const char kNotice[] = "not support win now";
  snprintf(out, out_len, "%s", kNotice);
  return 0;
}
#else
// Formats the address of the peer connected to `sock` as "ip:port".
//
// sock:    socket descriptor to query with getpeername().
// out:     destination buffer. On success it receives "ip:port"; on failure
//          it receives a short diagnostic message instead.
// out_len: capacity of `out` in bytes; must be positive.
//
// Returns 0 on success and -1 on failure (invalid buffer, getpeername()
// error, an address family other than AF_INET/AF_INET6, or a formatting
// error).
static int _get_sockname(int sock, char *out, int out_len) {
  // A non-positive length would be converted to a huge size_t by snprintf.
  if (out == nullptr || out_len <= 0) {
    return -1;
  }

  // sockaddr_storage is large enough for every address family. A plain
  // sockaddr_in is too small for an IPv6 peer: getpeername() would truncate
  // the address and reading it as sockaddr_in6 would run past the buffer.
  struct sockaddr_storage addr = {};
  socklen_t s_len = sizeof(addr);
  if (::getpeername(sock, reinterpret_cast<sockaddr *>(&addr), &s_len)) {
    ::snprintf(
        out, out_len, "can't getsocketname of %d, errno:%d", sock, errno);
    return -1;
  }

  char ip[INET6_ADDRSTRLEN] = {0};
  int port = 0;
  const void *raw_ip = nullptr;
  const int family = addr.ss_family;
  if (family == AF_INET) {
    const auto *v4 = reinterpret_cast<const sockaddr_in *>(&addr);
    port = ntohs(v4->sin_port);
    raw_ip = &v4->sin_addr;
  } else if (family == AF_INET6) {
    const auto *v6 = reinterpret_cast<const sockaddr_in6 *>(&addr);
    port = ntohs(v6->sin6_port);
    raw_ip = &v6->sin6_addr;
  } else {
    // Do not guess: any other family has a different address layout.
    ::snprintf(out,
               out_len,
               "unsupported address family %d of socket %d",
               family,
               sock);
    return -1;
  }

  if (::inet_ntop(family, raw_ip, ip, sizeof(ip)) == nullptr) {
    ::snprintf(out,
               out_len,
               "can't format peer address of %d, errno:%d",
               sock,
               errno);
    return -1;
  }

  ::snprintf(out, out_len, "%s:%d", ip, port);
  return 0;
}
#endif
// Writes a textual description of the socket `sock` into `out` (capacity
// `out_len` bytes) by delegating to the platform-specific helper, and
// forwards that helper's return code unchanged.
int GetSockName(int sock, char* out, int out_len) {
#ifndef _WIN32
  return _get_sockname(sock, out, out_len);
#else
  return _get_sockname_of_win(sock, out, out_len);
#endif
}
// Convenience overload: returns whatever text the buffer-based GetSockName
// produced for `fd`, as a std::string. The status code of the buffer-based
// call is not inspected.
std::string GetSockName(int fd) {
  char name[256];
  GetSockName(fd, name, sizeof(name));
  return name;
}
}; // namespace distributed
}; // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace distributed {
// Writes a textual description of the socket `fd` into `out`, whose capacity
// is `out_len` bytes, and returns an int status code.
int GetSockName(int fd, char* out, int out_len);
// Same as above, but returns the text as a std::string.
std::string GetSockName(int fd);
}; // namespace distributed
}; // namespace paddle
...@@ -25,8 +25,8 @@ namespace distributed { ...@@ -25,8 +25,8 @@ namespace distributed {
class Store { class Store {
public: public:
Store() : _timeout(tcputils::kNoTimeout) {} Store() : _timeout(900) {}
explicit Store(const std::chrono::seconds& timeout) : _timeout(timeout) {} explicit Store(const int timeout) : _timeout(timeout) {}
virtual ~Store() = default; virtual ~Store() = default;
virtual int64_t add(const std::string& key, int64_t value) { virtual int64_t add(const std::string& key, int64_t value) {
...@@ -46,10 +46,10 @@ class Store { ...@@ -46,10 +46,10 @@ class Store {
"Implement the add method in the subclass.")); "Implement the add method in the subclass."));
} }
virtual const std::chrono::seconds& timeout() const { return _timeout; } virtual int timeout() { return _timeout; }
private: protected:
std::chrono::seconds _timeout; int _timeout;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -29,25 +29,28 @@ namespace detail { ...@@ -29,25 +29,28 @@ namespace detail {
constexpr int INFTIME = 10000; // 10 seconds constexpr int INFTIME = 10000; // 10 seconds
std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket, int nranks, std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket,
int stop_check_timeout) { int nranks,
return std::make_unique<MasterDaemon>(socket, nranks, stop_check_timeout); int timeout) {
VLOG(4) << ("begin to run start");
return std::make_unique<MasterDaemon>(socket, nranks, timeout);
} }
MasterDaemon::MasterDaemon(SocketType socket, int nranks, MasterDaemon::MasterDaemon(SocketType socket, int nranks, int timeout)
int stop_check_timeout) : _listen_socket(socket), _nranks(nranks), _timeout(timeout) {
: _listen_socket(socket), InitControlFd();
_nranks(nranks),
_stop_check_timeout(stop_check_timeout) {
_background_thread = std::thread{&MasterDaemon::run, this}; _background_thread = std::thread{&MasterDaemon::run, this};
} }
MasterDaemon::~MasterDaemon() { MasterDaemon::~MasterDaemon() {
VLOG(4) << ("begin to destruct MasterDaemon");
StopByControlFd();
_background_thread.join(); _background_thread.join();
tcputils::close_socket(_listen_socket); tcputils::close_socket(_listen_socket);
for (SocketType socket : _sockets) { for (SocketType socket : _sockets) {
tcputils::close_socket(socket); tcputils::close_socket(socket);
} }
CloseControlFd();
} }
void MasterDaemon::_do_add(SocketType socket) { void MasterDaemon::_do_add(SocketType socket) {
...@@ -66,31 +69,34 @@ void MasterDaemon::_do_add(SocketType socket) { ...@@ -66,31 +69,34 @@ void MasterDaemon::_do_add(SocketType socket) {
std::string new_value_str = std::to_string(new_value); std::string new_value_str = std::to_string(new_value);
_store[key] = _store[key] =
std::vector<uint8_t>(new_value_str.begin(), new_value_str.end()); std::vector<uint8_t>(new_value_str.begin(), new_value_str.end());
VLOG(3) << "TCPStore: new value (" << new_value << ") for key (" << key VLOG(4) << "TCPStore: new value (" << new_value << ") for key (" << key
<< ")."; << ") " << GetSockName(socket);
tcputils::send_value<int64_t>(socket, new_value); tcputils::send_value<int64_t>(socket, new_value);
} }
void MasterDaemon::_do_set(SocketType socket) { void MasterDaemon::_do_set(SocketType socket) {
VLOG(3) << "MasterDaemon::_do_set";
std::string key = tcputils::receive_string(socket); std::string key = tcputils::receive_string(socket);
VLOG(4) << "MasterDaemon::_do_set key(" << key << ") " << GetSockName(socket);
auto value = tcputils::receive_vector<uint8_t>(socket); auto value = tcputils::receive_vector<uint8_t>(socket);
_store[key] = value; _store[key] = value;
} }
void MasterDaemon::_do_get(SocketType socket) { void MasterDaemon::_do_get(SocketType socket) {
VLOG(3) << "MasterDaemon::_do_get";
std::string key = tcputils::receive_string(socket); std::string key = tcputils::receive_string(socket);
VLOG(4) << "MasterDaemon::_do_get key(" << key << ") " << GetSockName(socket);
auto iter = _store.find(key); auto iter = _store.find(key);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
iter, _store.end(), iter,
_store.end(),
platform::errors::InvalidArgument("Key %s not found in TCPStore.", key)); platform::errors::InvalidArgument("Key %s not found in TCPStore.", key));
std::vector<uint8_t> value = iter->second; std::vector<uint8_t> value = iter->second;
tcputils::send_vector<uint8_t>(socket, value); tcputils::send_vector<uint8_t>(socket, value);
} }
void MasterDaemon::_do_stop(SocketType socket) { void MasterDaemon::_do_stop(SocketType socket) {
VLOG(3) << "MasterDaemon::_do_stop"; VLOG(4) << "MasterDaemon::_do_stop " << GetSockName(socket);
if (!_has_stop) { if (!_has_stop) {
_stop_time = std::chrono::system_clock::now(); _stop_time = std::chrono::system_clock::now();
} }
...@@ -102,9 +108,40 @@ void MasterDaemon::_do_stop(SocketType socket) { ...@@ -102,9 +108,40 @@ void MasterDaemon::_do_stop(SocketType socket) {
} }
} }
#ifndef _WIN32
// Creates the control pipe and stores its two descriptors in _control_fd
// ([0] = read end, [1] = write end, as defined by pipe()).
// Raises a fatal error carrying errno if pipe() fails.
void MasterDaemon::InitControlFd() {
  PADDLE_ENFORCE_NE(
      pipe(_control_fd.data()),
      -1,
      platform::errors::Fatal("failed to create control pipe errno:%d",
                              errno));
}
// Closes every control-pipe descriptor that is still open.
// Each closed slot is reset to -1, so calling this more than once never
// closes the same descriptor number twice.
void MasterDaemon::CloseControlFd() {
  for (int& fd : _control_fd) {
    if (fd != -1) {
      ::close(fd);
      fd = -1;
    }
  }
}
void MasterDaemon::StopByControlFd() {
VLOG(4) << ("begin to run StopByControlFd");
if (_control_fd[1] != -1) {
::write(_control_fd[1], "\0", 1);
// close the write end of the pipe
::close(_control_fd[1]);
_control_fd[1] = -1;
}
}
#else
// Windows build: there is no control pipe, so the three control-fd hooks are
// intentionally empty.
// NOTE(review): with StopByControlFd() a no-op here, the destructor has no
// visible way to wake run() before join() - confirm Windows shutdown does not
// hang.
void MasterDaemon::InitControlFd() {}
void MasterDaemon::CloseControlFd() {}
void MasterDaemon::StopByControlFd() {}
#endif
void MasterDaemon::_do_wait(SocketType socket) { void MasterDaemon::_do_wait(SocketType socket) {
VLOG(3) << "MasterDaemon::_do_wait";
std::string key = tcputils::receive_string(socket); std::string key = tcputils::receive_string(socket);
VLOG(4) << "MasterDaemon::_do_wait key(" << key << ") "
<< GetSockName(socket);
auto iter = _store.find(key); auto iter = _store.find(key);
auto reply = ReplyType::STOP_WAIT; auto reply = ReplyType::STOP_WAIT;
if (iter == _store.end()) { if (iter == _store.end()) {
...@@ -115,12 +152,67 @@ void MasterDaemon::_do_wait(SocketType socket) { ...@@ -115,12 +152,67 @@ void MasterDaemon::_do_wait(SocketType socket) {
tcputils::send_value<ReplyType>(socket, reply); tcputils::send_value<ReplyType>(socket, reply);
} }
// Serves one pending command on every client socket that reported an event.
//
// p_fds: the poll set used by run(). Entries before the first client slot
//        (the listen socket, plus the control pipe on non-Windows) are
//        skipped. Entries of clients whose handling throws are closed and
//        removed from both *p_fds and _sockets.
void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
  std::vector<struct pollfd>& fds = *p_fds;
#ifdef _WIN32
  // 0: listen socket, so clients start at 1.
  constexpr size_t kFirstClientIdx = 1;
#else
  // 0: listen socket, 1: controller pipe, so clients start at 2.
  constexpr size_t kFirstClientIdx = 2;
#endif
  // FIXME(gongwb): Don't loop all fds of set just the fds who have event.
  size_t i = kFirstClientIdx;
  while (i < fds.size()) {
    if (fds[i].revents == 0) {
      ++i;
      continue;
    }
    try {
      Command command = tcputils::receive_value<Command>(fds[i].fd);
      VLOG(3) << "TCPStore: recv command: " << static_cast<int>(command) << ".";
      switch (command) {
        case Command::ADD:
          _do_add(fds[i].fd);
          break;
        case Command::GET:
          _do_get(fds[i].fd);
          break;
        case Command::SET:
          _do_set(fds[i].fd);
          break;
        case Command::WAIT:
          _do_wait(fds[i].fd);
          break;
        case Command::STOP:
          _do_stop(fds[i].fd);
          break;
        default:
          LOG(WARNING) << "Unknown command: " << static_cast<int>(command)
                       << " from addr info:" << GetSockName(fds[i].fd);
      }
      ++i;
    } catch (const std::exception& ex) {
      // Close the failing socket BEFORE erasing its entry. After the erase,
      // fds[i] is the next client (or past the end).
      tcputils::close_socket(fds[i].fd);
      fds.erase(fds.begin() + i);
      _sockets.erase(_sockets.begin() + (i - kFirstClientIdx));
      VLOG(3) << "Meet some exceptions during run:" << ex.what();
      // Do not advance i: the next entry has shifted into slot i and must
      // still be examined in this pass.
    }
  }
}
void MasterDaemon::run() { void MasterDaemon::run() {
VLOG(4) << "begin to run run _stop:" << _stop << " _has_stop:" << _has_stop;
std::vector<struct pollfd> fds; std::vector<struct pollfd> fds;
#ifdef _WIN32 #ifdef _WIN32
fds.push_back({_listen_socket, POLLIN}); fds.push_back({_listen_socket, POLLIN});
#else #else
fds.push_back({.fd = _listen_socket, .events = POLLIN, .revents = 0}); fds.push_back({.fd = _listen_socket, .events = POLLIN, .revents = 0});
fds.push_back(
{.fd = _control_fd[0], .events = POLLIN | POLLHUP, .revents = 0});
#endif #endif
while (!_stop) { while (!_stop) {
...@@ -129,7 +221,8 @@ void MasterDaemon::run() { ...@@ -129,7 +221,8 @@ void MasterDaemon::run() {
std::chrono::duration<double> diff = end_time - _stop_time; std::chrono::duration<double> diff = end_time - _stop_time;
int elapsed_seconds = static_cast<int>(diff.count()); int elapsed_seconds = static_cast<int>(diff.count());
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
elapsed_seconds, _stop_check_timeout, elapsed_seconds,
_timeout,
platform::errors::Fatal( platform::errors::Fatal(
"%d seconds elapsed after the first worker " "%d seconds elapsed after the first worker "
"stopped, so we think there may be something wrong and will " "stopped, so we think there may be something wrong and will "
...@@ -138,16 +231,34 @@ void MasterDaemon::run() { ...@@ -138,16 +231,34 @@ void MasterDaemon::run() {
" to change the timeout value in seconds. The default one is 900", " to change the timeout value in seconds. The default one is 900",
elapsed_seconds)); elapsed_seconds));
} }
for (size_t i = 0; i < fds.size(); i++) { for (size_t i = 0; i < fds.size(); i++) {
fds[i].revents = 0; fds[i].revents = 0;
} }
VLOG(9) << "begin to poll fds_size:"
<< paddle::string::Sprintf("%d", fds.size());
#ifdef _WIN32 #ifdef _WIN32
::WSAPoll(fds.data(), fds.size(), INFTIME); ::WSAPoll(fds.data(), fds.size(), INFTIME);
#else #else
::poll(fds.data(), fds.size(), INFTIME); ::poll(fds.data(), fds.size(), INFTIME);
VLOG(9) << "begin to fds[1].revents:"
<< paddle::string::Sprintf("%d", fds[1].revents);
// The control pipe receive shutdown event, and begin to close it.
if (fds[1].revents != 0) {
if (fds[1].revents & ~(POLLIN | POLLHUP)) {
PADDLE_THROW(paddle::platform::errors::Fatal("Undefined event type:%d",
fds[1].revents));
}
VLOG(0)
<< "receive shutdown event and so quit from MasterDaemon run loop";
_stop = true;
break;
}
#endif #endif
// accept connect request.
if (fds[0].revents != 0) { if (fds[0].revents != 0) {
auto socket = tcputils::tcp_accept(_listen_socket); auto socket = tcputils::tcp_accept(_listen_socket);
_sockets.emplace_back(socket); _sockets.emplace_back(socket);
...@@ -158,45 +269,12 @@ void MasterDaemon::run() { ...@@ -158,45 +269,12 @@ void MasterDaemon::run() {
#endif #endif
} }
for (size_t i = 1; i < fds.size(); i++) { ProcessCommands(&fds);
try {
if (fds[i].revents == 0) {
continue;
}
Command command = tcputils::receive_value<Command>(fds[i].fd);
VLOG(3) << "TCPStore: recv command: " << static_cast<int>(command)
<< ".";
switch (command) {
case Command::ADD:
_do_add(fds[i].fd);
break;
case Command::GET:
_do_get(fds[i].fd);
break;
case Command::SET:
_do_set(fds[i].fd);
break;
case Command::WAIT:
_do_wait(fds[i].fd);
break;
case Command::STOP:
_do_stop(fds[i].fd);
break;
default:
VLOG(0) << "Unknow command: " << static_cast<int>(command);
exit(-1);
}
} catch (...) {
fds.erase(fds.begin() + i);
_sockets.erase(_sockets.begin() + i - 1);
}
}
} }
} }
std::unique_ptr<TCPServer> TCPServer::create(uint16_t port, int nranks, std::unique_ptr<TCPServer> TCPServer::create(uint16_t port,
int nranks,
int stop_check_timeout) { int stop_check_timeout) {
int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET); int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET);
auto server = std::make_unique<TCPServer>(); auto server = std::make_unique<TCPServer>();
...@@ -243,12 +321,21 @@ std::vector<T> TCPClient::receive_vector() { ...@@ -243,12 +321,21 @@ std::vector<T> TCPClient::receive_vector() {
} // namespace detail } // namespace detail
TCPStore::TCPStore(std::string host, uint16_t port, bool is_master, TCPStore::TCPStore(std::string host,
size_t num_workers, std::chrono::seconds timeout, uint16_t port,
int stop_check_timeout) bool is_master,
size_t num_workers,
int timeout)
: Store(timeout), _is_master(is_master), _num_workers(num_workers) { : Store(timeout), _is_master(is_master), _num_workers(num_workers) {
_timeout = timeout;
PADDLE_ENFORCE_GT(
timeout,
0,
platform::errors::InvalidArgument("timeout must >= %d", timeout));
VLOG(3) << "input timeout" << timeout << ", member timeout:" << _timeout;
if (_is_master) { if (_is_master) {
_server = detail::TCPServer::create(port, num_workers, stop_check_timeout); _server = detail::TCPServer::create(port, num_workers, timeout);
} }
_client = detail::TCPClient::connect(host, port); _client = detail::TCPClient::connect(host, port);
...@@ -261,11 +348,13 @@ void TCPStore::waitWorkers() { ...@@ -261,11 +348,13 @@ void TCPStore::waitWorkers() {
} }
add(_init_key, 1); add(_init_key, 1);
VLOG(3) << paddle::string::Sprintf("_timeout:%d", _timeout);
auto begin = std::chrono::steady_clock::now(); auto begin = std::chrono::steady_clock::now();
do { do {
auto value = get(_init_key); auto value = get(_init_key);
int completed = std::stoi(std::string(value.begin(), value.end())); int completed = std::stoi(std::string(value.begin(), value.end()));
VLOG(3) << completed << " worker ready, total " << _num_workers; VLOG(3) << completed << " worker ready, total " << _num_workers
<< ", _timeout:" << _timeout;
if (completed >= _num_workers) { if (completed >= _num_workers) {
break; break;
} }
...@@ -273,9 +362,16 @@ void TCPStore::waitWorkers() { ...@@ -273,9 +362,16 @@ void TCPStore::waitWorkers() {
std::chrono::steady_clock::now() - begin); std::chrono::steady_clock::now() - begin);
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (_timeout != tcputils::kNoTimeout && elapsed > _timeout) { if (_timeout != 0 && elapsed.count() > _timeout) {
LOG(FATAL) << paddle::string::Sprintf(
"_timeout:%d elapsed:%d (elapsed > _timeout)=%d",
_timeout,
elapsed.count(),
elapsed.count() > _timeout);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
completed, _num_workers, completed,
_num_workers,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"TCPStore timeouted and not all workers got ready.")); "TCPStore timeouted and not all workers got ready."));
} }
...@@ -293,7 +389,7 @@ int64_t TCPStore::add(const std::string& key, int64_t value) { ...@@ -293,7 +389,7 @@ int64_t TCPStore::add(const std::string& key, int64_t value) {
void TCPStore::set(const std::string& key, const std::vector<uint8_t>& value) { void TCPStore::set(const std::string& key, const std::vector<uint8_t>& value) {
VLOG(3) << "TCPStore set."; VLOG(3) << "TCPStore set.";
_client->send_command_for_key(Command::SET, _key_prefix + key); _client->send_command_for_key(Command::SET, _key_prefix + key);
_client->send_vector<std::uint8_t>(value); _client->send_vector<uint8_t>(value);
} }
std::vector<uint8_t> TCPStore::get(const std::string& key) { std::vector<uint8_t> TCPStore::get(const std::string& key) {
...@@ -314,14 +410,7 @@ void TCPStore::wait(const std::string& key) { ...@@ -314,14 +410,7 @@ void TCPStore::wait(const std::string& key) {
} while (reply != ReplyType::STOP_WAIT); } while (reply != ReplyType::STOP_WAIT);
} }
TCPStore::~TCPStore() { TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; }
VLOG(3) << "~TCPStore";
_client->send_command_for_key(Command::STOP, "");
ReplyType ret = _client->receive_value<ReplyType>();
PADDLE_ENFORCE_EQ(ret, ReplyType::STOP_WAIT,
platform::errors::InvalidArgument(
"The reply for TCPStore destructure must be 0."));
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -14,12 +14,23 @@ ...@@ -14,12 +14,23 @@
#pragma once #pragma once
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#ifndef _WIN32
#include <unistd.h>
#endif
#include <array>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/distributed/store/socket.h"
#include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h" #include "paddle/fluid/distributed/store/tcp_utils.h"
...@@ -35,14 +46,16 @@ class MasterDaemon { ...@@ -35,14 +46,16 @@ class MasterDaemon {
public: public:
static std::unique_ptr<MasterDaemon> start(SocketType listen_socket, static std::unique_ptr<MasterDaemon> start(SocketType listen_socket,
int nranks, int nranks,
int stop_check_timeout); int timeout);
MasterDaemon() = delete; MasterDaemon() = delete;
explicit MasterDaemon(SocketType listen_socket, int nranks, explicit MasterDaemon(SocketType listen_socket,
int nranks,
int stop_check_timeout); int stop_check_timeout);
~MasterDaemon(); ~MasterDaemon();
private: private:
void run(); void run();
void ProcessCommands(std::vector<struct pollfd>* p_fds);
void _do_add(SocketType socket); void _do_add(SocketType socket);
void _do_wait(SocketType socket); void _do_wait(SocketType socket);
void _do_get(SocketType socket); void _do_get(SocketType socket);
...@@ -52,17 +65,26 @@ class MasterDaemon { ...@@ -52,17 +65,26 @@ class MasterDaemon {
std::vector<SocketType> _sockets; std::vector<SocketType> _sockets;
std::unordered_map<std::string, std::vector<uint8_t>> _store; std::unordered_map<std::string, std::vector<uint8_t>> _store;
std::thread _background_thread{}; std::thread _background_thread{};
int _nranks; int _nranks = -1;
int _stop_check_timeout; int _timeout = 0;
bool _stop = false; // all workers stopped bool _stop = false; // all workers stopped
std::chrono::time_point<std::chrono::system_clock> _stop_time; std::chrono::time_point<std::chrono::system_clock> _stop_time;
bool _has_stop = false; // at least one worker stopped bool _has_stop = false; // at least one worker stopped
void InitControlFd();
void CloseControlFd();
void StopByControlFd();
#ifdef _WIN32
#else
std::array<int, 2> _control_fd{{-1, -1}};
#endif
}; };
class TCPServer { class TCPServer {
public: public:
TCPServer() = default; TCPServer() = default;
static std::unique_ptr<TCPServer> create(std::uint16_t port, int nranks, static std::unique_ptr<TCPServer> create(std::uint16_t port,
int nranks,
int stop_check_timeout); int stop_check_timeout);
private: private:
...@@ -94,13 +116,15 @@ class TCPClient { ...@@ -94,13 +116,15 @@ class TCPClient {
} // namespace detail } // namespace detail
// TODO(gongwb) :Add IP6 support.
class TCPStore : public Store { class TCPStore : public Store {
public: public:
static constexpr std::uint16_t kDefaultPort = 6170; static constexpr std::uint16_t kDefaultPort = 6170;
explicit TCPStore(std::string host, uint16_t port = kDefaultPort, explicit TCPStore(std::string host,
bool is_master = false, size_t num_workers = 1, uint16_t port = kDefaultPort,
std::chrono::seconds timeout = tcputils::kDefaultTimeout, bool is_master = false,
int stop_check_timeout = 900); size_t num_workers = 1,
int timeout = 900);
~TCPStore(); ~TCPStore();
...@@ -116,7 +140,7 @@ class TCPStore : public Store { ...@@ -116,7 +140,7 @@ class TCPStore : public Store {
const std::string _init_key = "init/"; const std::string _init_key = "init/";
const std::string _key_prefix = "/"; const std::string _key_prefix = "/";
std::chrono::seconds _timeout;
bool _is_master; bool _is_master;
int _num_workers; int _num_workers;
}; };
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h"
#ifdef _WIN32
#include <windows.h>
#endif
namespace paddle {
namespace distributed {
// Smoke test for MasterDaemon start-up and shutdown: listen on port "0",
// start a daemon on that socket, let it run for two seconds, then destroy it.
// The test passes if the destruction returns.
TEST(MasterDaemon, init) {
  const std::string port = std::to_string(0);
  int listen_fd = tcputils::tcp_listen("", port, AF_INET);
  auto daemon = detail::MasterDaemon::start(listen_fd, 1, 100);
  printf("started to sleep 2s\n");
  // Portable replacement for the platform-specific Sleep()/usleep() pair.
  std::this_thread::sleep_for(std::chrono::seconds(2));
  printf("end to reset\n");
  daemon.reset();
}
/* now for only c compile test
TEST(TCPStore, init) {
TCPStore store("127.0.0.1", 6170, true, 1);
store.add("my", 3);
auto ret1 = store.get("my");
store.add("my", 3);
auto ret2 = store.get("my");
PADDLE_ENFORCE_EQ(ret1[0] + 3, ret2[0],
paddle::errors::Fatal("result of add is not right"));
}
*/
}; // namespace distributed
}; // namespace paddle
...@@ -39,12 +39,14 @@ void BindTCPStore(py::module *m) { ...@@ -39,12 +39,14 @@ void BindTCPStore(py::module *m) {
.def(py::init<>()) .def(py::init<>())
.def( .def(
"set", "set",
[](distributed::Store &self, const std::string &key, [](distributed::Store &self,
const std::string &key,
const std::string &value) { const std::string &value) {
std::vector<uint8_t> data(value.begin(), value.end()); std::vector<uint8_t> data(value.begin(), value.end());
self.set(key, data); self.set(key, data);
}, },
py::arg("key"), py::arg("value"), py::arg("key"),
py::arg("value"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
"get", "get",
...@@ -54,24 +56,29 @@ void BindTCPStore(py::module *m) { ...@@ -54,24 +56,29 @@ void BindTCPStore(py::module *m) {
return py::bytes(reinterpret_cast<char *>(data.data()), return py::bytes(reinterpret_cast<char *>(data.data()),
data.size()); data.size());
}, },
py::arg("key"), py::call_guard<py::gil_scoped_release>()) py::arg("key"),
.def("add", &distributed::Store::add, py::call_guard<py::gil_scoped_release>())
.def("add",
&distributed::Store::add,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("wait", &distributed::Store::wait, .def("wait",
&distributed::Store::wait,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store) py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store)
.def(py::init([](std::string hostname, uint16_t port, bool is_master, .def(py::init([](std::string hostname,
size_t world_size, std::chrono::seconds timeout, uint16_t port,
int stop_check_timeout) { bool is_master,
return std::make_shared<TCPStore>(hostname, port, is_master, size_t world_size,
world_size, timeout, int timeout) {
stop_check_timeout); return std::make_shared<TCPStore>(
hostname, port, is_master, world_size, timeout);
}), }),
py::arg("hostname"), py::arg("port"), py::arg("is_master"), py::arg("hostname"),
py::arg("port"),
py::arg("is_master"),
py::arg("world_size"), py::arg("world_size"),
py::arg("timeout") = distributed::tcputils::kNoTimeout, py::arg("timeout") = 900,
py::arg("stop_check_timeout") = 900,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} }
......
...@@ -276,6 +276,7 @@ def get_cluster_from_args(args, device_mode, devices_per_proc): ...@@ -276,6 +276,7 @@ def get_cluster_from_args(args, device_mode, devices_per_proc):
free_ports = find_free_ports(len(devices_per_proc)) free_ports = find_free_ports(len(devices_per_proc))
if free_ports is not None: if free_ports is not None:
free_ports = list(free_ports) free_ports = list(free_ports)
logger.info("find free ports:{}".format(free_ports))
else: else:
start_port = 6070 start_port = 6070
if os.environ.get('FLAGS_START_PORT') is not None: if os.environ.get('FLAGS_START_PORT') is not None:
......
...@@ -240,7 +240,7 @@ def init_parallel_env(): ...@@ -240,7 +240,7 @@ def init_parallel_env():
master_port, master_port,
is_master, is_master,
world_size, world_size,
stop_check_timeout=stop_check_timeout) timeout=stop_check_timeout)
_set_default_store(default_store) _set_default_store(default_store)
pg = _new_process_group_impl(backend, pg = _new_process_group_impl(backend,
default_store, default_store,
......
...@@ -129,6 +129,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial) ...@@ -129,6 +129,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_cost_model) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_cost_model)
list(APPEND MIXED_DIST_TEST_OPS test_tcp_store)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP}) list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach() endforeach()
...@@ -868,6 +869,7 @@ if(WITH_DISTRIBUTE) ...@@ -868,6 +869,7 @@ if(WITH_DISTRIBUTE)
test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS}) test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_cost_model MODULES py_test_modules(test_auto_parallel_cost_model MODULES
test_auto_parallel_cost_model ENVS ${dist_ENVS}) test_auto_parallel_cost_model ENVS ${dist_ENVS})
if(WITH_GPU if(WITH_GPU
OR WITH_XPU OR WITH_XPU
OR WITH_ASCEND OR WITH_ASCEND
...@@ -895,6 +897,21 @@ if(WITH_DISTRIBUTE) ...@@ -895,6 +897,21 @@ if(WITH_DISTRIBUTE)
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_dgc_nccl") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_dgc_nccl")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_se_resnext_dgc") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_se_resnext_dgc")
endif() endif()
# port range (20000, 23000) is reserved for dist-ops
set(dist_ut_port 20001)
if(NOT WIN32)
bash_test_modules(
test_tcp_store
START_BASH
dist_test.sh
LABELS
"RUN_TYPE=EXCLUSIVE"
ENVS
"PADDLE_DIST_UT_PORT=${dist_ut_port}")
math(EXPR dist_ut_port "${dist_ut_port}+1")
endif()
if(NOT APPLE) if(NOT APPLE)
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
bash_test_modules(test_c_comm_init_op START_BASH test_c_comm_init_op.sh bash_test_modules(test_c_comm_init_op START_BASH test_c_comm_init_op.sh
...@@ -930,7 +947,6 @@ if(WITH_DISTRIBUTE) ...@@ -930,7 +947,6 @@ if(WITH_DISTRIBUTE)
endif() endif()
# port range (20000, 23000) is reserved for dist-ops # port range (20000, 23000) is reserved for dist-ops
set(dist_ut_port 20001)
foreach(TEST_OP ${DIST_TEST_OPS}) foreach(TEST_OP ${DIST_TEST_OPS})
bash_test_modules( bash_test_modules(
${TEST_OP} ${TEST_OP}
......
...@@ -47,7 +47,7 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -47,7 +47,7 @@ class TestProcessGroupFp32(unittest.TestCase):
rank = ParallelEnv().local_rank rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6272, is_master, store = paddle.fluid.core.TCPStore("127.0.0.1", 6272, is_master,
nranks, datetime.timedelta(0)) nranks, 30)
place = paddle.fluid.core.CPUPlace() place = paddle.fluid.core.CPUPlace()
pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks, place) pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks, place)
...@@ -64,11 +64,11 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -64,11 +64,11 @@ class TestProcessGroupFp32(unittest.TestCase):
if rank == 0: if rank == 0:
task = pg.allreduce(tensor_x) task = pg.allreduce(tensor_x)
task.wait() task.wait()
assert np.array_equal(tensor_x, sum_result) np.testing.assert_equal(tensor_x, sum_result)
else: else:
task = pg.allreduce(tensor_y) task = pg.allreduce(tensor_y)
task.wait() task.wait()
assert np.array_equal(tensor_y, sum_result) np.testing.assert_equal(tensor_y, sum_result)
print("test allreduce sum api ok") print("test allreduce sum api ok")
......
...@@ -17,13 +17,15 @@ from __future__ import print_function ...@@ -17,13 +17,15 @@ from __future__ import print_function
import unittest import unittest
import datetime import datetime
import paddle import paddle
import os
class TestTCPStore(unittest.TestCase): class TestTCPStore(unittest.TestCase):
def test_tcp_store(self): def test_tcp_store(self):
store = paddle.fluid.core.TCPStore("127.0.0.1", 6170, True, 1, dist_port = int(os.getenv("PADDLE_DIST_UT_PORT", 6170))
datetime.timedelta(0)) print("get dist_port:", dist_port)
store = paddle.fluid.core.TCPStore("127.0.0.1", dist_port, True, 1, 1)
store.add("my", 3) store.add("my", 3)
ret1 = store.get('my') ret1 = store.get('my')
store.add("my", 3) store.add("my", 3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册