interceptor.cc 4.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

#include <utility>

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

23
Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
24 25 26 27 28 29 30 31
    : interceptor_id_(interceptor_id), node_(node) {}

// Destroys the interceptor. Any messages still queued in messages_ are
// discarded along with the member. The "mailbox must be empty" check is
// intentionally disabled: a PADDLE_ENFORCE failure throws, and throwing out of
// a (implicitly noexcept) destructor would call std::terminate. The FIXME
// proposes moving the check into a stop function instead.
Interceptor::~Interceptor() {
  // FIXME(wangxi): throw in stop function
  // std::lock_guard<std::mutex> lock(mutex_);
  // PADDLE_ENFORCE_EQ(messages_.empty(), true,
  //                  platform::errors::PreconditionNotMet(
  //                      "Interceptor must destruct with messages empty"));
}

34
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
35 36

void Interceptor::Handle(const InterceptorMessage& msg) {
37 38 39
  PADDLE_ENFORCE_NOT_NULL(handle_, platform::errors::PreconditionNotMet(
                                       "Message handle is not registered."));
  handle_(msg);
40 41
}

42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
void Interceptor::LoopOnce() {
  std::deque<InterceptorMessage> tmp_messages;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    messages_.swap(tmp_messages);
  }
  PADDLE_ENFORCE_EQ(tmp_messages.empty(), false,
                    platform::errors::PreconditionNotMet(
                        "tmp_messages must not empty in task loop"));

  for (auto& msg : tmp_messages) {
    const MessageType message_type = msg.message_type();
    VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
            << " from interceptor " << msg.src_id()
            << " with message: " << message_type << ".";

    Handle(msg);
  }
}

62
void Interceptor::StopCarrier() {
63 64
  PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
                                        "Carrier is not registered."));
65
  carrier_->WakeUp();
66 67
}

68
void Interceptor::EnqueueRemoteInterceptorMessage(
69
    const InterceptorMessage& message) {
70
  // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
71 72 73 74 75 76 77 78 79 80 81 82
  VLOG(3) << "Enqueue message: " << message.message_type() << " into "
          << interceptor_id_ << "'s remote mailbox.";

  bool empty = false;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    empty = messages_.empty();
    messages_.emplace_back(message);
  }
  if (empty) {
    loop_->QueueInLoop([this]() { LoopOnce(); });
  }
83 84
}

85
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
86 87
  PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
                                        "Carrier is not registered."));
88 89
  msg.set_src_id(interceptor_id_);
  msg.set_dst_id(dst_id);
90
  return carrier_->Send(msg);
91 92
}

93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
  static InterceptorFactory::CreateInterceptorMap interceptorMap;
  return interceptorMap;
}

std::unique_ptr<Interceptor> InterceptorFactory::Create(const std::string& type,
                                                        int64_t id,
                                                        TaskNode* node) {
  auto& interceptor_map = GetInterceptorMap();
  auto iter = interceptor_map.find(type);
  PADDLE_ENFORCE_NE(
      iter, interceptor_map.end(),
      platform::errors::NotFound("interceptor %s is not register", type));
  return iter->second(id, node);
}

// Registers `func` as the factory used by InterceptorFactory::Create() for
// interceptors of kind `type`.
// NOTE(review): insertion uses emplace. For std::map/unordered_map (confirm
// CreateInterceptorMap in interceptor.h), emplace keeps an existing entry, so
// a second registration under the same `type` is silently ignored.
void InterceptorFactory::Register(
    const std::string& type, InterceptorFactory::CreateInterceptorFunc func) {
  auto& interceptor_map = GetInterceptorMap();
  // `func` is taken by value; move it into the registry instead of copying it.
  interceptor_map.emplace(type, std::move(func));
}

}  // namespace distributed
}  // namespace paddle