message_bus.cc 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <chrono>
16
#include <memory>
17
#include <thread>
18

19
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
20 21
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
22
#include "paddle/fluid/platform/gen_comm_id_helper.h"
23 24 25 26

namespace paddle {
namespace distributed {

27
void MessageBus::Init(
28 29
    const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
    const std::unordered_map<int64_t, std::string>& rank_to_addr,
30 31 32 33 34 35 36 37
    const std::string& addr) {
  PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                         "MessageBus is already init."));
  is_init_ = true;
  interceptor_id_to_rank_ = interceptor_id_to_rank;
  rank_to_addr_ = rank_to_addr;
  addr_ = addr;

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
  // NOTE: To make the brpc is compatible with collective,
  // need release the handler holding the ip address.
  if (addr_ != "") {
    VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
    paddle::platform::SocketServer& socket_server =
        paddle::platform::SocketServer::GetInstance(addr_);
    int server_fd = socket_server.socket();
    if (server_fd != -1) {
      socket_server.Release();
    }
  }
#endif

53
  ListenPort();
54 55
}

56 57
// Reports whether Init() has already been called on this message bus.
bool MessageBus::IsInit() const {
  return is_init_;
}

58
// Tears down the message bus: releases the carrier first and then, in builds
// where the brpc server is compiled in, stops and joins the server.
MessageBus::~MessageBus() {
  VLOG(3) << "Message bus releases resource.";
  // NOTE: fleet_executor inits carrier before message bus,
  // therefore the message bus's destructor will be called first
  Carrier::Instance().Release();
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
  server_.Stop(1000);
  server_.Join();
#endif
}

bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
  // called by Interceptor, send InterceptorMessage to dst
73 74 75
  int64_t src_id = interceptor_message.src_id();
  int64_t dst_id = interceptor_message.dst_id();
  if (IsSameRank(src_id, dst_id)) {
76 77
    VLOG(3) << "Send a message from interceptor " << src_id
            << " to interceptor " << dst_id << ", which are in the same ranks.";
78 79
    return SendIntraRank(interceptor_message);
  } else {
80 81 82
    VLOG(3) << "Send a message from interceptor " << src_id
            << " to interceptor " << dst_id
            << ", which are in different ranks.";
83 84
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
85 86 87 88 89 90 91 92 93 94 95
    int retry_time = 0;  // message bus will retry sending for 10 times
    while (retry_time < 10) {
      ++retry_time;
      if (SendInterRank(interceptor_message)) {
        VLOG(3) << "Message bus sends inter rank successfully with "
                << retry_time << " times retries.";
        return true;
      }
    }
    VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
    return false;
96 97 98 99 100 101 102
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Fleet executor does not support sending message between different "
        "ranks when Paddle is compiled with npu or "
        "isn't compiled with distributed for now."));
#endif
  }
103 104 105 106
  return true;
}

void MessageBus::ListenPort() {
107
  if (addr_ == "") {
108
    LOG(INFO) << "No need listen to port since training on single card.";
109 110
    return;
  }
111 112
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
113
  // function keep listen the port and handle the message
114
  PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service_,
115 116 117 118 119 120 121 122
                                       brpc::SERVER_DOESNT_OWN_SERVICE),
                    0, platform::errors::Unavailable(
                           "Message bus: init brpc service error."));

  // start the server
  const char* ip_for_brpc = addr_.c_str();
  brpc::ServerOptions options;
  options.idle_timeout_sec = -1;
123 124 125 126 127 128 129 130 131 132 133
  int retry_times = 0;
  int interval = 1000;
  while (server_.Start(ip_for_brpc, &options) != 0) {
    ++retry_times;
    LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
              << " times. And will retry after " << interval / 1000
              << " seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(interval));
    interval += 2000;
  }
  LOG(INFO) << "Message bus's listen port thread starts successful.";
134
#else
135 136 137 138
  LOG(WARNING)
      << "Fleet executor's ListenPort() is a fake function when Paddle is "
         "compiled with npu or Paddle isn't compiled "
         "with distributed for now.";
139
#endif
140 141 142
}

// Tells whether the interceptors `src_id` and `dst_id` live on the same rank.
//
// A `src_id` of -1 (sent by the carrier to a source interceptor) is treated
// as if the message came from `dst_id` itself.
//
// Enforces (NotFound) that both ids have a rank and, when `addr_` is set,
// that the source rank has an address; enforces (Fatal) that this address
// equals the message bus's own `addr_`.
bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) {
  // -1 is sent by carrier to source interceptor
  const int64_t real_src_id = (src_id == -1) ? dst_id : src_id;

  // Look both interceptors up in the id -> rank table.
  const auto src_rank_it = interceptor_id_to_rank_.find(real_src_id);
  const auto dst_rank_it = interceptor_id_to_rank_.find(dst_id);
  PADDLE_ENFORCE_NE(
      src_rank_it, interceptor_id_to_rank_.end(),
      platform::errors::NotFound(
          "Cannot find rank for src interceptor id %lld. Init error.",
          real_src_id));
  PADDLE_ENFORCE_NE(
      dst_rank_it, interceptor_id_to_rank_.end(),
      platform::errors::NotFound(
          "Cannot find rank for dst interceptor id %lld. Init error.", dst_id));

  // single card training, must be same rank
  if (addr_.empty()) return true;

  // The source interceptor must be hosted at this message bus's address.
  const auto src_addr_it = rank_to_addr_.find(src_rank_it->second);
  PADDLE_ENFORCE_NE(src_addr_it, rank_to_addr_.end(),
                    platform::errors::NotFound(
                        "Cannot find addr for src rank id %lld. Init error.",
                        src_rank_it->second));
  PADDLE_ENFORCE_EQ(
      src_addr_it->second, addr_,
      platform::errors::Fatal("The src interceptor's addr is %s, while the "
                              "message bus's addr is %s, which are different. "
                              "Init error.",
                              src_addr_it->second, addr_));

  const bool same_rank = (src_rank_it->second == dst_rank_it->second);
  return same_rank;
}

175 176
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
177 178
bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
  // send the message inter rank (dst is different rank with src)
179 180 181 182 183 184 185 186
  int64_t dst_id = interceptor_message.dst_id();
  int64_t dst_rank = interceptor_id_to_rank_[dst_id];
  auto dst_ip = rank_to_addr_.find(dst_rank);
  PADDLE_ENFORCE_NE(dst_ip, rank_to_addr_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find rank for dst interceptor id %lld. "
                        "Init error.",
                        dst_id));
187
  VLOG(3) << "Message bus sending to addr: " << dst_ip->second;
188 189 190 191
  const char* dst_ip_for_brpc = dst_ip->second.c_str();
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
192
  options.connect_timeout_ms = 1000;
193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
  options.timeout_ms = 1000;
  options.max_retry = 5;
  PADDLE_ENFORCE_EQ(
      channel.Init(dst_ip_for_brpc, &options), 0,
      platform::errors::Unavailable("Message bus: init brpc channel error."));
  TheInterceptorMessageService_Stub stub(&channel);
  InterceptorResponse response;
  brpc::Controller ctrl;
  ctrl.set_log_id(0);
  stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL);
  if (!ctrl.Failed()) {
    if (response.rst()) {
      VLOG(3) << "Message bus: brpc sends success.";
      return true;
    } else {
208
      VLOG(4) << "Message bus: InterceptorMessageService error.";
209 210 211
      return false;
    }
  } else {
212
    VLOG(4) << "Message bus: brpc sends failed with error text: "
213 214 215
            << ctrl.ErrorText();
    return false;
  }
216 217 218 219 220
}
#endif

bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
  // send the message intra rank (dst is the same rank with src)
221
  return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message);
222 223 224 225
}

}  // namespace distributed
}  // namespace paddle