未验证 提交 df14dbf0 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Update with collective (#37462)

上级 38f1ef50
...@@ -10,9 +10,10 @@ else() ...@@ -10,9 +10,10 @@ else()
set(BRPC_DEPS "") set(BRPC_DEPS "")
endif() endif()
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS}) DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper
${BRPC_DEPS})
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
...@@ -12,11 +12,14 @@ ...@@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <chrono>
#include <memory> #include <memory>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -32,6 +35,21 @@ void MessageBus::Init( ...@@ -32,6 +35,21 @@ void MessageBus::Init(
rank_to_addr_ = rank_to_addr; rank_to_addr_ = rank_to_addr;
addr_ = addr; addr_ = addr;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
// NOTE: To make the brpc is compatible with collective,
// need release the handler holding the ip address.
if (addr_ != "") {
VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
paddle::platform::SocketServer& socket_server =
paddle::platform::SocketServer::GetInstance(addr_);
int server_fd = socket_server.socket();
if (server_fd != -1) {
socket_server.Release();
}
}
#endif
ListenPort(); ListenPort();
std::call_once(once_flag_, []() { std::call_once(once_flag_, []() {
...@@ -87,7 +105,7 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) { ...@@ -87,7 +105,7 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
void MessageBus::ListenPort() { void MessageBus::ListenPort() {
if (addr_ == "") { if (addr_ == "") {
VLOG(3) << "No need listen to port since training on single card."; LOG(INFO) << "No need listen to port since training on single card.";
return; return;
} }
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
...@@ -103,14 +121,22 @@ void MessageBus::ListenPort() { ...@@ -103,14 +121,22 @@ void MessageBus::ListenPort() {
const char* ip_for_brpc = addr_.c_str(); const char* ip_for_brpc = addr_.c_str();
brpc::ServerOptions options; brpc::ServerOptions options;
options.idle_timeout_sec = -1; options.idle_timeout_sec = -1;
PADDLE_ENFORCE_EQ( int retry_times = 0;
server_.Start(ip_for_brpc, &options), 0, int interval = 1000;
platform::errors::Unavailable("Message bus: start brpc service error.")); while (server_.Start(ip_for_brpc, &options) != 0) {
VLOG(3) << "Message bus's listen port thread starts successful."; ++retry_times;
LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
<< " times. And will retry after " << interval / 1000
<< " seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(interval));
interval += 2000;
}
LOG(INFO) << "Message bus's listen port thread starts successful.";
#else #else
VLOG(3) << "Fleet executor's ListenPort() is a fake function when Paddle is " LOG(WARNING)
"compiled with npu or Paddle isn't compiled " << "Fleet executor's ListenPort() is a fake function when Paddle is "
"with distributed for now."; "compiled with npu or Paddle isn't compiled "
"with distributed for now.";
#endif #endif
} }
......
...@@ -153,6 +153,16 @@ int CreateListenSocket(const std::string& ep) { ...@@ -153,6 +153,16 @@ int CreateListenSocket(const std::string& ep) {
// not enter the TIME-WAIT state. But this is obviously not as convenient // not enter the TIME-WAIT state. But this is obviously not as convenient
// as the reuse method. // as the reuse method.
int opt = 1; int opt = 1;
// NOTE. The linger is used for skipping TIME-WAIT status forcefully.
linger ling;
ling.l_onoff = 1;
ling.l_linger = 0;
CHECK_SYS_CALL(
setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)),
"setsockopt set linger");
#if defined(SO_REUSEPORT) #if defined(SO_REUSEPORT)
// since Linux kernel 3.9 // since Linux kernel 3.9
CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT,
......
...@@ -22,6 +22,8 @@ limitations under the License. */ ...@@ -22,6 +22,8 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "glog/logging.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -46,10 +48,20 @@ class SocketServer { ...@@ -46,10 +48,20 @@ class SocketServer {
public: public:
SocketServer() = default; SocketServer() = default;
~SocketServer() { CloseSocket(server_fd_); } ~SocketServer() {
if (server_fd_ != -1) {
CloseSocket(server_fd_);
}
}
int socket() const { return server_fd_; } int socket() const { return server_fd_; }
void Release() {
VLOG(3) << "Server will be closed by external call.";
CloseSocket(server_fd_);
server_fd_ = -1;
}
static SocketServer& GetInstance(const std::string& end_point); static SocketServer& GetInstance(const std::string& end_point);
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册