// message_bus.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"

17
#include <chrono>
18
#include <memory>
Y
Yuang Liu 已提交
19
#include <set>
20
#include <thread>
21

22 23
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
24
#include "paddle/fluid/platform/gen_comm_id_helper.h"
25 26 27 28

namespace paddle {
namespace distributed {

29
void MessageBus::Init(
30 31
    int64_t rank,
    const std::unordered_map<int64_t, std::string>& rank_to_addr,
32
    const std::string& addr) {
33
  PADDLE_ENFORCE_EQ(
34 35
      is_init_,
      false,
36
      platform::errors::AlreadyExists("MessageBus is already init."));
37
  rank_ = rank;
38 39 40 41
  is_init_ = true;
  rank_to_addr_ = rank_to_addr;
  addr_ = addr;

42 43
  if (addr_ != "") {
    const auto& addr = GetAddr(rank_);
44 45
    PADDLE_ENFORCE_EQ(addr,
                      addr_,
46 47 48 49
                      platform::errors::Fatal(
                          "The current rank's addr is %s, while the "
                          "message bus's addr is %s, which are different. "
                          "Init error.",
50 51
                          addr,
                          addr_));
52 53
  }

54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
  // NOTE: To make the brpc is compatible with collective,
  // need release the handler holding the ip address.
  if (addr_ != "") {
    VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
    paddle::platform::SocketServer& socket_server =
        paddle::platform::SocketServer::GetInstance(addr_);
    int server_fd = socket_server.socket();
    if (server_fd != -1) {
      socket_server.Release();
    }
  }
#endif

69
  ListenPort();
70 71
}

72 73
// Returns the init flag, which Init() sets to true.
bool MessageBus::IsInit() const {
  return is_init_;
}

74
// Releases the resources held by the message bus. In builds with
// PADDLE_WITH_DISTRIBUTE the server is stopped and then joined; otherwise
// only the log line is emitted.
MessageBus::~MessageBus() {
  VLOG(3) << "Message bus releases resource.";

#if defined(PADDLE_WITH_DISTRIBUTE)
  // Stop first, then Join.
  // NOTE(review): presumably Join() blocks until the server has fully
  // stopped -- confirm against the brpc server documentation.
  server_.Stop(1000);
  server_.Join();
#endif
}

82 83
const std::string& MessageBus::GetAddr(int64_t rank) const {
  PADDLE_ENFORCE_NE(
84 85
      rank_to_addr_.find(rank),
      rank_to_addr_.end(),
86 87 88 89 90 91
      platform::errors::NotFound("Cannot find addr rank id %lld.", rank));
  return rank_to_addr_.at(rank);
}

bool MessageBus::Send(int64_t dst_rank,
                      const InterceptorMessage& interceptor_message) {
92
  PADDLE_ENFORCE_EQ(
93 94
      IsInit(),
      true,
95 96
      platform::errors::PreconditionNotMet(
          "Using message bus since it has not been initialized."));
L
LiYuRio 已提交
97
#if defined(PADDLE_WITH_DISTRIBUTE)
98 99 100 101 102 103 104
  int retry_time = 0;  // message bus will retry sending for 10 times
  while (retry_time < 10) {
    ++retry_time;
    if (SendInterRank(dst_rank, interceptor_message)) {
      VLOG(3) << "Message bus sends inter rank successfully with " << retry_time
              << " times retries.";
      return true;
105
    }
106 107 108 109 110
    VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
  return false;
111
#else
112 113 114 115
  PADDLE_THROW(platform::errors::Unavailable(
      "Fleet executor does not support sending message between different "
      "ranks when Paddle is compiled with npu or "
      "isn't compiled with distributed for now."));
116
#endif
117 118 119
  return true;
}

120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
// Increments the barrier counter under mutex_ and wakes one thread waiting on
// cv_ (Barrier() waits on this condition variable).
void MessageBus::IncreaseBarrierCount() {
  VLOG(3) << "IncreaseBarrierCount";
  {
    // The notification is issued while the mutex is still held.
    std::lock_guard<std::mutex> guard(mutex_);
    count_ += 1;
    cv_.notify_one();
  }
  VLOG(3) << "End IncreaseBarrierCount";
}

void MessageBus::Barrier() {
  // gather to root
  if (rank_ != 0) {
    InterceptorMessage ctrl_msg;
    ctrl_msg.set_ctrl_message(true);
    ctrl_msg.set_src_id(rank_);
    ctrl_msg.set_dst_id(0);
    VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0";
    while (!Send(0, ctrl_msg)) {
      std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
  } else {
    VLOG(3) << "Barrier 0 wait others rank ready";
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] {
      return count_ == static_cast<int>(rank_to_addr_.size() - 1);
    });
    count_ = 0;
  }

  // scatter from root
  if (rank_ == 0) {
    for (int i = 1; i < static_cast<int>(rank_to_addr_.size()); ++i) {
      InterceptorMessage ctrl_msg;
      ctrl_msg.set_ctrl_message(true);
      ctrl_msg.set_src_id(0);
      ctrl_msg.set_dst_id(i);
      VLOG(3) << "Barrier Scatter ctrl message from 0 to " << i;
      while (!Send(i, ctrl_msg)) {
159 160 161
        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
      }
    }
162 163 164 165 166
  } else {
    VLOG(3) << "Barrier " << rank_ << " wait others rank ready";
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return count_ == 1; });
    count_ = 0;
167 168 169
  }
}

170 171
bool MessageBus::DispatchMsgToCarrier(
    const InterceptorMessage& interceptor_message) {
172 173 174
  const std::string& carrier_id = *GlobalVal<std::string>::Get();
  return GlobalMap<std::string, Carrier>::Get(carrier_id)
      ->EnqueueInterceptorMessage(interceptor_message);
175 176
}

177
void MessageBus::ListenPort() {
178
  if (addr_ == "") {
179
    LOG(INFO) << "No need listen to port since training on single card.";
180 181
    return;
  }
L
LiYuRio 已提交
182
#if defined(PADDLE_WITH_DISTRIBUTE)
183
  // function keep listen the port and handle the message
184
  PADDLE_ENFORCE_EQ(
185 186
      server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE),
      0,
187
      platform::errors::Unavailable("Message bus: init brpc service error."));
188 189 190 191 192

  // start the server
  const char* ip_for_brpc = addr_.c_str();
  brpc::ServerOptions options;
  options.idle_timeout_sec = -1;
193
  int retry_times = 0;
Y
Yuang Liu 已提交
194
  int interval = 100;
195 196 197 198 199 200
  while (server_.Start(ip_for_brpc, &options) != 0) {
    ++retry_times;
    LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
              << " times. And will retry after " << interval / 1000
              << " seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(interval));
Y
Yuang Liu 已提交
201
    interval += 500;
202 203
  }
  LOG(INFO) << "Message bus's listen port thread starts successful.";
204
#else
205 206 207 208
  LOG(WARNING)
      << "Fleet executor's ListenPort() is a fake function when Paddle is "
         "compiled with npu or Paddle isn't compiled "
         "with distributed for now.";
209
#endif
210 211
}

L
LiYuRio 已提交
212
#if defined(PADDLE_WITH_DISTRIBUTE)
213 214 215 216 217
bool MessageBus::SendInterRank(int64_t dst_rank,
                               const InterceptorMessage& interceptor_message) {
  const auto& dst_addr = GetAddr(dst_rank);
  VLOG(3) << "Message bus sending to addr: " << dst_addr;
  const char* dst_addr_for_brpc = dst_addr.c_str();
218 219 220
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
221
  options.connect_timeout_ms = 1000;
222 223 224
  options.timeout_ms = 1000;
  options.max_retry = 5;
  PADDLE_ENFORCE_EQ(
225 226
      channel.Init(dst_addr_for_brpc, &options),
      0,
227
      platform::errors::Unavailable("Message bus: init brpc channel error."));
228
  MessageService_Stub stub(&channel);
229 230 231
  InterceptorResponse response;
  brpc::Controller ctrl;
  ctrl.set_log_id(0);
232 233 234
  if (interceptor_message.ctrl_message()) {
    stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
  } else {
235 236
    stub.ReceiveInterceptorMessage(
        &ctrl, &interceptor_message, &response, NULL);
237
  }
238 239 240 241 242
  if (!ctrl.Failed()) {
    if (response.rst()) {
      VLOG(3) << "Message bus: brpc sends success.";
      return true;
    } else {
243
      VLOG(4) << "Message bus: InterceptorMessageService error.";
244 245 246
      return false;
    }
  } else {
247
    VLOG(4) << "Message bus: brpc sends failed with error text: "
248 249 250
            << ctrl.ErrorText();
    return false;
  }
251
}
252

253 254 255 256
#endif

}  // namespace distributed
}  // namespace paddle