message_bus.cc 6.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <chrono>
16
#include <memory>
Y
Yuang Liu 已提交
17
#include <set>
18
#include <thread>
19 20

#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
21
#include "paddle/fluid/platform/gen_comm_id_helper.h"
22 23 24 25

namespace paddle {
namespace distributed {

26
void MessageBus::Init(
27
    int64_t rank, const std::unordered_map<int64_t, std::string>& rank_to_addr,
28 29 30
    const std::string& addr) {
  PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                         "MessageBus is already init."));
31
  rank_ = rank;
32 33 34 35
  is_init_ = true;
  rank_to_addr_ = rank_to_addr;
  addr_ = addr;

36 37 38 39 40 41 42 43 44 45
  if (addr_ != "") {
    const auto& addr = GetAddr(rank_);
    PADDLE_ENFORCE_EQ(addr, addr_,
                      platform::errors::Fatal(
                          "The current rank's addr is %s, while the "
                          "message bus's addr is %s, which are different. "
                          "Init error.",
                          addr, addr_));
  }

46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
  // NOTE: To make the brpc is compatible with collective,
  // need release the handler holding the ip address.
  if (addr_ != "") {
    VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
    paddle::platform::SocketServer& socket_server =
        paddle::platform::SocketServer::GetInstance(addr_);
    int server_fd = socket_server.socket();
    if (server_fd != -1) {
      socket_server.Release();
    }
  }
#endif

61
  ListenPort();
62 63
}

64 65
// Returns the is_init_ flag, which Init() sets to true.
bool MessageBus::IsInit() const {
  return is_init_;
}

66
// Logs the teardown and, when the brpc-based transport is compiled in,
// stops the server and joins it.
MessageBus::~MessageBus() {
  VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
  // NOTE(review): presumably a close-wait time in milliseconds -- confirm
  // against the brpc::Server::Stop declaration.
  constexpr int kServerStopArg = 1000;
  server_.Stop(kServerStopArg);
  server_.Join();
#endif
}

75 76 77 78 79 80 81 82 83
// Looks up the address registered for `rank` in rank_to_addr_.
//
// Returns a reference to the stored address string. Raises NotFound when no
// entry exists for `rank`.
const std::string& MessageBus::GetAddr(int64_t rank) const {
  // One lookup serves both the existence check and the returned value.
  const auto entry = rank_to_addr_.find(rank);
  PADDLE_ENFORCE_NE(
      entry, rank_to_addr_.end(),
      platform::errors::NotFound("Cannot find addr rank id %lld.", rank));
  return entry->second;
}

bool MessageBus::Send(int64_t dst_rank,
                      const InterceptorMessage& interceptor_message) {
84 85
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
86 87 88 89 90 91 92
  int retry_time = 0;  // message bus will retry sending for 10 times
  while (retry_time < 10) {
    ++retry_time;
    if (SendInterRank(dst_rank, interceptor_message)) {
      VLOG(3) << "Message bus sends inter rank successfully with " << retry_time
              << " times retries.";
      return true;
93
    }
94 95 96 97 98
    VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
  return false;
99
#else
100 101 102 103
  PADDLE_THROW(platform::errors::Unavailable(
      "Fleet executor does not support sending message between different "
      "ranks when Paddle is compiled with npu or "
      "isn't compiled with distributed for now."));
104
#endif
105 106 107
  return true;
}

108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
void MessageBus::TestConnection() {
  InterceptorMessage ctrl_msg;
  ctrl_msg.set_ctrl_message(true);
  ctrl_msg.set_src_id(rank_);
  for (const auto& dst_rank_pair : rank_to_addr_) {
    int64_t dst_rank = dst_rank_pair.first;
    if (dst_rank != rank_) {
      ctrl_msg.set_dst_id(dst_rank);
      VLOG(3) << "Send control message bus from rank " << rank_ << " to rank "
              << dst_rank;
      while (!Send(dst_rank, ctrl_msg)) {
        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
      }
      VLOG(3) << "Message bus has connected to rank: " << dst_rank << ".";
    }
  }
}

126
void MessageBus::ListenPort() {
127
  if (addr_ == "") {
128
    LOG(INFO) << "No need listen to port since training on single card.";
129 130
    return;
  }
131 132
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
133
  // function keep listen the port and handle the message
134
  PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service_,
135 136 137 138 139 140 141 142
                                       brpc::SERVER_DOESNT_OWN_SERVICE),
                    0, platform::errors::Unavailable(
                           "Message bus: init brpc service error."));

  // start the server
  const char* ip_for_brpc = addr_.c_str();
  brpc::ServerOptions options;
  options.idle_timeout_sec = -1;
143
  int retry_times = 0;
Y
Yuang Liu 已提交
144
  int interval = 100;
145 146 147 148 149 150
  while (server_.Start(ip_for_brpc, &options) != 0) {
    ++retry_times;
    LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
              << " times. And will retry after " << interval / 1000
              << " seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(interval));
Y
Yuang Liu 已提交
151
    interval += 500;
152 153
  }
  LOG(INFO) << "Message bus's listen port thread starts successful.";
154
  TestConnection();
155
#else
156 157 158 159
  LOG(WARNING)
      << "Fleet executor's ListenPort() is a fake function when Paddle is "
         "compiled with npu or Paddle isn't compiled "
         "with distributed for now.";
160
#endif
161 162
}

163 164
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
165 166 167 168 169
bool MessageBus::SendInterRank(int64_t dst_rank,
                               const InterceptorMessage& interceptor_message) {
  const auto& dst_addr = GetAddr(dst_rank);
  VLOG(3) << "Message bus sending to addr: " << dst_addr;
  const char* dst_addr_for_brpc = dst_addr.c_str();
170 171 172
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
173
  options.connect_timeout_ms = 1000;
174 175 176
  options.timeout_ms = 1000;
  options.max_retry = 5;
  PADDLE_ENFORCE_EQ(
177
      channel.Init(dst_addr_for_brpc, &options), 0,
178 179 180 181 182 183 184 185 186 187 188
      platform::errors::Unavailable("Message bus: init brpc channel error."));
  TheInterceptorMessageService_Stub stub(&channel);
  InterceptorResponse response;
  brpc::Controller ctrl;
  ctrl.set_log_id(0);
  stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL);
  if (!ctrl.Failed()) {
    if (response.rst()) {
      VLOG(3) << "Message bus: brpc sends success.";
      return true;
    } else {
189
      VLOG(4) << "Message bus: InterceptorMessageService error.";
190 191 192
      return false;
    }
  } else {
193
    VLOG(4) << "Message bus: brpc sends failed with error text: "
194 195 196
            << ctrl.ErrorText();
    return false;
  }
197 198 199 200 201
}
#endif

}  // namespace distributed
}  // namespace paddle