未验证 提交 df14dbf0 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Update with collective (#37462)

上级 38f1ef50
......@@ -10,9 +10,10 @@ else()
set(BRPC_DEPS "")
endif()
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS})
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper
${BRPC_DEPS})
if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
......@@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <chrono>
#include <memory>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
namespace paddle {
namespace distributed {
......@@ -32,6 +35,21 @@ void MessageBus::Init(
rank_to_addr_ = rank_to_addr;
addr_ = addr;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
// NOTE: To make brpc compatible with collective communication, the fd that
// gen_comm_id holds on this ip address has to be released first.
if (addr_ != "") {
VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
paddle::platform::SocketServer& socket_server =
paddle::platform::SocketServer::GetInstance(addr_);
int server_fd = socket_server.socket();
if (server_fd != -1) {
socket_server.Release();
}
}
#endif
ListenPort();
std::call_once(once_flag_, []() {
......@@ -87,7 +105,7 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
void MessageBus::ListenPort() {
if (addr_ == "") {
VLOG(3) << "No need listen to port since training on single card.";
LOG(INFO) << "No need listen to port since training on single card.";
return;
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
......@@ -103,14 +121,22 @@ void MessageBus::ListenPort() {
const char* ip_for_brpc = addr_.c_str();
brpc::ServerOptions options;
options.idle_timeout_sec = -1;
PADDLE_ENFORCE_EQ(
server_.Start(ip_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: start brpc service error."));
VLOG(3) << "Message bus's listen port thread starts successful.";
int retry_times = 0;
int interval = 1000;
while (server_.Start(ip_for_brpc, &options) != 0) {
++retry_times;
LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
<< " times. And will retry after " << interval / 1000
<< " seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(interval));
interval += 2000;
}
LOG(INFO) << "Message bus's listen port thread starts successful.";
#else
VLOG(3) << "Fleet executor's ListenPort() is a fake function when Paddle is "
"compiled with npu or Paddle isn't compiled "
"with distributed for now.";
LOG(WARNING)
<< "Fleet executor's ListenPort() is a fake function when Paddle is "
"compiled with npu or Paddle isn't compiled "
"with distributed for now.";
#endif
}
......
......@@ -153,6 +153,16 @@ int CreateListenSocket(const std::string& ep) {
// not enter the TIME-WAIT state. But this is obviously not as convenient
// as the reuse method.
int opt = 1;
// NOTE: SO_LINGER is enabled with a zero timeout (l_onoff = 1, l_linger = 0)
// to forcefully skip the TIME-WAIT state.
linger ling;
ling.l_onoff = 1;
ling.l_linger = 0;
CHECK_SYS_CALL(
setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)),
"setsockopt set linger");
#if defined(SO_REUSEPORT)
// since Linux kernel 3.9
CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT,
......
......@@ -22,6 +22,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "glog/logging.h"
namespace paddle {
namespace platform {
......@@ -46,10 +48,20 @@ class SocketServer {
public:
SocketServer() = default;
~SocketServer() { CloseSocket(server_fd_); }
// Closes the server socket on destruction, unless the descriptor has already
// been given up (server_fd_ == -1, which is what Release() leaves behind).
~SocketServer() {
  const bool nothing_to_close = (server_fd_ == -1);
  if (nothing_to_close) {
    return;
  }
  CloseSocket(server_fd_);
}
// Returns the stored server socket descriptor; Release() resets it to -1.
int socket() const { return server_fd_; }
// Closes the server socket on request of an external caller, ahead of
// destruction, and marks the descriptor as released by setting it to -1.
//
// The call is idempotent: when the descriptor is already -1 nothing is
// closed, mirroring the guard used by the destructor. This avoids handing an
// invalid descriptor to CloseSocket if Release() is called more than once.
void Release() {
  VLOG(3) << "Server will be closed by external call.";
  if (server_fd_ == -1) {
    // Nothing left to close: already released (or never opened).
    return;
  }
  CloseSocket(server_fd_);
  // Mark as released so the destructor does not close it a second time.
  server_fd_ = -1;
}
// Returns a reference to a SocketServer for the given endpoint string.
// NOTE(review): presumably a shared (singleton-style) instance -- confirm
// against the definition, which is not visible in this header excerpt.
static SocketServer& GetInstance(const std::string& end_point);
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册