message_bus.cc 7.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include <memory>

17
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
18 19
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
20 21 22 23

namespace paddle {
namespace distributed {

24
void MessageBus::Init(
25 26
    const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
    const std::unordered_map<int64_t, std::string>& rank_to_addr,
27 28 29 30 31 32 33 34
    const std::string& addr) {
  PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                         "MessageBus is already init."));
  is_init_ = true;
  interceptor_id_to_rank_ = interceptor_id_to_rank;
  rank_to_addr_ = rank_to_addr;
  addr_ = addr;

35
  ListenPort();
36 37 38 39

  std::call_once(once_flag_, []() {
    std::atexit([]() { MessageBus::Instance().Release(); });
  });
40 41
}

42 43
// Returns the value of the is_init_ flag, which Init() sets to true.
bool MessageBus::IsInit() const {
  return is_init_;
}

44
void MessageBus::Release() {
45
  VLOG(3) << "Message bus releases resource.";
46 47 48 49 50
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
  server_.Stop(1000);
  server_.Join();
#endif
51 52 53 54
}

bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
  // called by Interceptor, send InterceptorMessage to dst
55 56 57
  int64_t src_id = interceptor_message.src_id();
  int64_t dst_id = interceptor_message.dst_id();
  if (IsSameRank(src_id, dst_id)) {
58 59
    VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id
            << ", which are same ranks.";
60 61
    return SendIntraRank(interceptor_message);
  } else {
62 63
    VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id
            << ", which are different ranks.";
64 65
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
66 67 68 69 70 71 72 73 74 75 76
    int retry_time = 0;  // message bus will retry sending for 10 times
    while (retry_time < 10) {
      ++retry_time;
      if (SendInterRank(interceptor_message)) {
        VLOG(3) << "Message bus sends inter rank successfully with "
                << retry_time << " times retries.";
        return true;
      }
    }
    VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
    return false;
77 78 79 80 81 82 83
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Fleet executor does not support sending message between different "
        "ranks when Paddle is compiled with npu or "
        "isn't compiled with distributed for now."));
#endif
  }
84 85 86 87
  return true;
}

void MessageBus::ListenPort() {
88 89 90 91
  if (addr_ == "") {
    VLOG(3) << "No need listen to port since training on single card.";
    return;
  }
92 93
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
94
  // function keep listen the port and handle the message
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
  InterceptorMessageServiceImpl interceptor_message_service;
  PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service,
                                       brpc::SERVER_DOESNT_OWN_SERVICE),
                    0, platform::errors::Unavailable(
                           "Message bus: init brpc service error."));

  // start the server
  const char* ip_for_brpc = addr_.c_str();
  brpc::ServerOptions options;
  options.idle_timeout_sec = -1;
  PADDLE_ENFORCE_EQ(
      server_.Start(ip_for_brpc, &options), 0,
      platform::errors::Unavailable("Message bus: start brpc service error."));
  VLOG(3) << "Message bus's listen port thread starts successful.";
#else
  VLOG(3) << "Fleet executor's ListenPort() is a fake function when Paddle is "
             "compiled with npu or Paddle isn't compiled "
             "with distributed for now.";
#endif
114 115 116 117
}

// Tells whether the source and destination interceptors are on the same rank.
//
// Both ids must be present in interceptor_id_to_rank_, otherwise NotFound is
// raised. When addr_ is empty (single card training) the answer is always
// true. Otherwise the address registered for the source rank must exist
// (NotFound) and must equal addr_ (Fatal); the result is then whether the
// two ranks are equal.
bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) {
  // Look both interceptors up before validating either of them.
  const auto src_it = interceptor_id_to_rank_.find(src_id);
  const auto dst_it = interceptor_id_to_rank_.find(dst_id);
  PADDLE_ENFORCE_NE(
      src_it, interceptor_id_to_rank_.end(),
      platform::errors::NotFound(
          "Cannot find rank for src interceptor id %lld. Init error.", src_id));
  PADDLE_ENFORCE_NE(
      dst_it, interceptor_id_to_rank_.end(),
      platform::errors::NotFound(
          "Cannot find rank for dst interceptor id %lld. Init error.", dst_id));

  // single card training, must be same rank
  if (addr_.empty()) {
    return true;
  }

  const int64_t src_rank = src_it->second;
  const int64_t dst_rank = dst_it->second;

  // The sender is expected to live behind this message bus's own address.
  const auto src_addr_it = rank_to_addr_.find(src_rank);
  PADDLE_ENFORCE_NE(src_addr_it, rank_to_addr_.end(),
                    platform::errors::NotFound(
                        "Cannot find addr for src rank id %lld. Init error.",
                        src_rank));
  PADDLE_ENFORCE_EQ(
      src_addr_it->second, addr_,
      platform::errors::Fatal("The src interceptor's addr is %s, while the "
                              "message bus's addr is %s, which are different. "
                              "Init error.",
                              src_addr_it->second, addr_));
  return src_rank == dst_rank;
}

146 147
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
    !defined(PADDLE_WITH_ASCEND_CL)
// Sends the message to an interceptor on a different rank through brpc.
//
// Resolves the destination interceptor id to its rank and the rank to its
// address, opens a brpc channel to that address and issues one synchronous
// InterceptorMessageService call.
//
// Returns true only if the RPC did not fail and the response's rst() is
// true; returns false otherwise. Raises NotFound if the destination
// interceptor has no rank, InvalidArgument if the rank has no address, and
// Unavailable if the brpc channel cannot be initialized.
bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
  const int64_t dst_id = interceptor_message.dst_id();

  // Use find() instead of operator[]: operator[] would silently insert a
  // default rank (0) for an unknown id and mutate the shared table on the
  // send path.
  const auto dst_rank_it = interceptor_id_to_rank_.find(dst_id);
  PADDLE_ENFORCE_NE(
      dst_rank_it, interceptor_id_to_rank_.end(),
      platform::errors::NotFound(
          "Cannot find rank for dst interceptor id %lld. Init error.", dst_id));
  const int64_t dst_rank = dst_rank_it->second;

  const auto dst_ip = rank_to_addr_.find(dst_rank);
  PADDLE_ENFORCE_NE(dst_ip, rank_to_addr_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find addr for dst rank id %lld. "
                        "Init error.",
                        dst_rank));
  const char* dst_ip_for_brpc = dst_ip->second.c_str();

  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.connect_timeout_ms = 1000;
  options.timeout_ms = 1000;
  options.max_retry = 5;
  PADDLE_ENFORCE_EQ(
      channel.Init(dst_ip_for_brpc, &options), 0,
      platform::errors::Unavailable("Message bus: init brpc channel error."));

  TheInterceptorMessageService_Stub stub(&channel);
  InterceptorResponse response;
  brpc::Controller ctrl;
  ctrl.set_log_id(0);
  // No completion callback is passed (nullptr).
  stub.InterceptorMessageService(&ctrl, &interceptor_message, &response,
                                 nullptr);
  if (ctrl.Failed()) {
    VLOG(4) << "Message bus: brpc sends failed with error text: "
            << ctrl.ErrorText();
    return false;
  }
  if (!response.rst()) {
    VLOG(4) << "Message bus: InterceptorMessageService error.";
    return false;
  }
  VLOG(3) << "Message bus: brpc sends success.";
  return true;
}
#endif

bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
  // send the message intra rank (dst is the same rank with src)
191
  return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message);
192 193 194 195
}

}  // namespace distributed
}  // namespace paddle