diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 162cbd8a7b5204761e521485b26470a708bc71ba..641110802f1fd3b5e8ec90056541ee671d9ba94d 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -10,9 +10,10 @@ else() set(BRPC_DEPS "") endif() -cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc +cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc - DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS}) + DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper + ${BRPC_DEPS}) if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 309982bc04bebd5df9b75ce43d7605440abe6fec..23e1b2a31d88b14e391ee15f9121c671304295c9 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -12,11 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include +#include #include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" namespace paddle { namespace distributed { @@ -32,6 +35,21 @@ void MessageBus::Init( rank_to_addr_ = rank_to_addr; addr_ = addr; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) + // NOTE: To make the brpc is compatible with collective, + // need release the handler holding the ip address. + if (addr_ != "") { + VLOG(3) << "Message bus is releasing the fd held by gen_comm_id."; + paddle::platform::SocketServer& socket_server = + paddle::platform::SocketServer::GetInstance(addr_); + int server_fd = socket_server.socket(); + if (server_fd != -1) { + socket_server.Release(); + } + } +#endif + ListenPort(); std::call_once(once_flag_, []() { @@ -87,7 +105,7 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) { void MessageBus::ListenPort() { if (addr_ == "") { - VLOG(3) << "No need listen to port since training on single card."; + LOG(INFO) << "No need listen to port since training on single card."; return; } #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ @@ -103,14 +121,22 @@ void MessageBus::ListenPort() { const char* ip_for_brpc = addr_.c_str(); brpc::ServerOptions options; options.idle_timeout_sec = -1; - PADDLE_ENFORCE_EQ( - server_.Start(ip_for_brpc, &options), 0, - platform::errors::Unavailable("Message bus: start brpc service error.")); - VLOG(3) << "Message bus's listen port thread starts successful."; + int retry_times = 0; + int interval = 1000; + while (server_.Start(ip_for_brpc, &options) != 0) { + ++retry_times; + LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times + << " times. And will retry after " << interval / 1000 + << " seconds."; + std::this_thread::sleep_for(std::chrono::milliseconds(interval)); + interval += 2000; + } + LOG(INFO) << "Message bus's listen port thread starts successful."; #else - VLOG(3) << "Fleet executor's ListenPort() is a fake function when Paddle is " - "compiled with npu or Paddle isn't compiled " - "with distributed for now."; + LOG(WARNING) + << "Fleet executor's ListenPort() is a fake function when Paddle is " + "compiled with npu or Paddle isn't compiled " + "with distributed for now."; #endif } diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc index e9fe2a38c6c43cea391516187b6bcbaccd471c38..1b77eb42837d4193cf8f7decbfd41634f16ac882 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -153,6 +153,16 @@ int CreateListenSocket(const std::string& ep) { // not enter the TIME-WAIT state. But this is obviously not as convenient // as the reuse method. int opt = 1; + + // NOTE. The linger is used for skipping TIME-WAIT status forcefully. + linger ling; + ling.l_onoff = 1; + ling.l_linger = 0; + + CHECK_SYS_CALL( + setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)), + "setsockopt set linger"); + #if defined(SO_REUSEPORT) // since Linux kernel 3.9 CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, diff --git a/paddle/fluid/platform/gen_comm_id_helper.h b/paddle/fluid/platform/gen_comm_id_helper.h index 6198519eb06df8dfac1540f5bca8387a03afc71f..9bbbb1f424a74fe080b5dfcb7bfc0df9c7272356 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.h +++ b/paddle/fluid/platform/gen_comm_id_helper.h @@ -22,6 +22,8 @@ limitations under the License. */ #include #include +#include "glog/logging.h" + namespace paddle { namespace platform { @@ -46,10 +48,20 @@ class SocketServer { public: SocketServer() = default; - ~SocketServer() { CloseSocket(server_fd_); } + ~SocketServer() { + if (server_fd_ != -1) { + CloseSocket(server_fd_); + } + } int socket() const { return server_fd_; } + void Release() { + VLOG(3) << "Server will be closed by external call."; + CloseSocket(server_fd_); + server_fd_ = -1; + } + static SocketServer& GetInstance(const std::string& end_point); private: