message_bus.cc 8.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <chrono>
16
#include <memory>
Y
Yuang Liu 已提交
17
#include <set>
18
#include <thread>
19

20 21
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
22
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
23
#include "paddle/fluid/platform/gen_comm_id_helper.h"
24 25 26 27

namespace paddle {
namespace distributed {

28
void MessageBus::Init(
29
    int64_t rank, const std::unordered_map<int64_t, std::string>& rank_to_addr,
30 31 32
    const std::string& addr) {
  PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                         "MessageBus is already init."));
33
  rank_ = rank;
34 35 36 37
  is_init_ = true;
  rank_to_addr_ = rank_to_addr;
  addr_ = addr;

38 39 40 41 42 43 44 45 46 47
  if (addr_ != "") {
    const auto& addr = GetAddr(rank_);
    PADDLE_ENFORCE_EQ(addr, addr_,
                      platform::errors::Fatal(
                          "The current rank's addr is %s, while the "
                          "message bus's addr is %s, which are different. "
                          "Init error.",
                          addr, addr_));
  }

48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
  // NOTE: To make the brpc is compatible with collective,
  // need release the handler holding the ip address.
  if (addr_ != "") {
    VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
    paddle::platform::SocketServer& socket_server =
        paddle::platform::SocketServer::GetInstance(addr_);
    int server_fd = socket_server.socket();
    if (server_fd != -1) {
      socket_server.Release();
    }
  }
#endif

63
  ListenPort();
64 65
}

66 67
// Returns the is_init_ flag, which Init() sets to true.
bool MessageBus::IsInit() const {
  return is_init_;
}

68
// Logs the teardown and, when built with both PADDLE_WITH_DISTRIBUTE and
// PADDLE_WITH_PSCORE, shuts the server_ member down.
MessageBus::~MessageBus() {
  VLOG(3) << "Message bus releases resource.";

#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  // Stop the server, then Join() it.
  // NOTE(review): Join() presumably blocks until shutdown has completed, and
  // the meaning of the 1000 argument should be confirmed against brpc docs.
  server_.Stop(1000);
  server_.Join();
#endif
}

76 77 78 79 80 81 82 83 84
// Returns the address registered for `rank` in rank_to_addr_.
// Raises NotFound when no address is registered for that rank.
const std::string& MessageBus::GetAddr(int64_t rank) const {
  // Look the rank up once and reuse the iterator for both check and result.
  const auto entry = rank_to_addr_.find(rank);
  PADDLE_ENFORCE_NE(
      entry, rank_to_addr_.end(),
      platform::errors::NotFound("Cannot find addr rank id %lld.", rank));
  return entry->second;
}

bool MessageBus::Send(int64_t dst_rank,
                      const InterceptorMessage& interceptor_message) {
85 86 87 88
  PADDLE_ENFORCE_EQ(
      IsInit(), true,
      platform::errors::PreconditionNotMet(
          "Using message bus since it has not been initialized."));
89
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
90 91 92 93 94 95 96
  int retry_time = 0;  // message bus will retry sending for 10 times
  while (retry_time < 10) {
    ++retry_time;
    if (SendInterRank(dst_rank, interceptor_message)) {
      VLOG(3) << "Message bus sends inter rank successfully with " << retry_time
              << " times retries.";
      return true;
97
    }
98 99 100 101 102
    VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
  return false;
103
#else
104 105 106 107
  PADDLE_THROW(platform::errors::Unavailable(
      "Fleet executor does not support sending message between different "
      "ranks when Paddle is compiled with npu or "
      "isn't compiled with distributed for now."));
108
#endif
109 110 111
  return true;
}

112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
// Increments the barrier counter count_ under mutex_ and wakes one thread
// waiting on cv_ (see Barrier()).
void MessageBus::IncreaseBarrierCount() {
  VLOG(3) << "IncreaseBarrierCount";
  {
    // The notification is issued while the mutex is still held.
    std::lock_guard<std::mutex> guard(mutex_);
    ++count_;
    cv_.notify_one();
  }
  VLOG(3) << "End IncreaseBarrierCount";
}

// Synchronizes all ranks through rank 0 in two phases.
//
// Gather:  every non-zero rank sends a control message to rank 0, and rank 0
//          waits until count_ equals the number of other ranks.
// Scatter: rank 0 sends a control message to every other rank, and each of
//          them waits until count_ equals 1.
//
// count_ is reset to 0 after each wait. A failed Send() is retried every
// second without an upper bound.
void MessageBus::Barrier() {
  // Builds a control message src -> dst and keeps sending it until Send()
  // reports success.
  const auto send_ctrl_until_ok = [this](int64_t src, int64_t dst) {
    InterceptorMessage ctrl_msg;
    ctrl_msg.set_ctrl_message(true);
    ctrl_msg.set_src_id(src);
    ctrl_msg.set_dst_id(dst);
    while (!Send(dst, ctrl_msg)) {
      std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
  };

  if (rank_ == 0) {
    // gather to root
    VLOG(3) << "Barrier 0 wait others rank ready";
    {
      // Scoped so that the mutex is released before the scatter phase.
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] {
        return count_ == static_cast<int>(rank_to_addr_.size() - 1);
      });
      count_ = 0;
    }

    // scatter from root
    for (int peer = 1; peer < static_cast<int>(rank_to_addr_.size());
         ++peer) {
      VLOG(3) << "Barrier Scatter ctrl message from 0 to " << peer;
      send_ctrl_until_ok(0, peer);
    }
    return;
  }

  // gather to root
  VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0";
  send_ctrl_until_ok(rank_, 0);

  // scatter from root
  VLOG(3) << "Barrier " << rank_ << " wait others rank ready";
  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [this] { return count_ == 1; });
  count_ = 0;
}

162 163
bool MessageBus::DispatchMsgToCarrier(
    const InterceptorMessage& interceptor_message) {
164 165 166
  const std::string& carrier_id = *GlobalVal<std::string>::Get();
  return GlobalMap<std::string, Carrier>::Get(carrier_id)
      ->EnqueueInterceptorMessage(interceptor_message);
167 168
}

169
void MessageBus::ListenPort() {
170
  if (addr_ == "") {
171
    LOG(INFO) << "No need listen to port since training on single card.";
172 173
    return;
  }
174
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
175
  // function keep listen the port and handle the message
176 177 178
  PADDLE_ENFORCE_EQ(
      server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
      platform::errors::Unavailable("Message bus: init brpc service error."));
179 180 181 182 183

  // start the server
  const char* ip_for_brpc = addr_.c_str();
  brpc::ServerOptions options;
  options.idle_timeout_sec = -1;
184
  int retry_times = 0;
Y
Yuang Liu 已提交
185
  int interval = 100;
186 187 188 189 190 191
  while (server_.Start(ip_for_brpc, &options) != 0) {
    ++retry_times;
    LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
              << " times. And will retry after " << interval / 1000
              << " seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(interval));
Y
Yuang Liu 已提交
192
    interval += 500;
193 194
  }
  LOG(INFO) << "Message bus's listen port thread starts successful.";
195
#else
196 197 198 199
  LOG(WARNING)
      << "Fleet executor's ListenPort() is a fake function when Paddle is "
         "compiled with npu or Paddle isn't compiled "
         "with distributed for now.";
200
#endif
201 202
}

203
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
204 205 206 207 208
bool MessageBus::SendInterRank(int64_t dst_rank,
                               const InterceptorMessage& interceptor_message) {
  const auto& dst_addr = GetAddr(dst_rank);
  VLOG(3) << "Message bus sending to addr: " << dst_addr;
  const char* dst_addr_for_brpc = dst_addr.c_str();
209 210 211
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
212
  options.connect_timeout_ms = 1000;
213 214 215
  options.timeout_ms = 1000;
  options.max_retry = 5;
  PADDLE_ENFORCE_EQ(
216
      channel.Init(dst_addr_for_brpc, &options), 0,
217
      platform::errors::Unavailable("Message bus: init brpc channel error."));
218
  MessageService_Stub stub(&channel);
219 220 221
  InterceptorResponse response;
  brpc::Controller ctrl;
  ctrl.set_log_id(0);
222 223 224 225 226 227
  if (interceptor_message.ctrl_message()) {
    stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
  } else {
    stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response,
                                   NULL);
  }
228 229 230 231 232
  if (!ctrl.Failed()) {
    if (response.rst()) {
      VLOG(3) << "Message bus: brpc sends success.";
      return true;
    } else {
233
      VLOG(4) << "Message bus: InterceptorMessageService error.";
234 235 236
      return false;
    }
  } else {
237
    VLOG(4) << "Message bus: brpc sends failed with error text: "
238 239 240
            << ctrl.ErrorText();
    return false;
  }
241
}
242

243 244 245 246
#endif

}  // namespace distributed
}  // namespace paddle