carrier.cc 9.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
18
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
19
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
20
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
21
#include "paddle/fluid/framework/garbage_collector.h"
22
#include "paddle/fluid/framework/scope.h"
23 24 25 26

namespace paddle {
namespace distributed {

27
// NOTE(review): USE_INTERCEPTOR is defined outside this file; presumably it
// references the registration of the "Compute" interceptor type so that it is
// available to InterceptorFactory::Create below - confirm.
USE_INTERCEPTOR(Compute);
28
// NOTE(review): USE_INTERCEPTOR is defined outside this file; presumably it
// references the registration of the "Amplifier" interceptor type so that it
// is available to InterceptorFactory::Create below - confirm.
USE_INTERCEPTOR(Amplifier);
29

30
void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
31 32 33 34
                   framework::Scope* root_scope,
                   framework::Scope* minibatch_scope,
                   const std::vector<framework::Scope*>& microbatch_scopes,
                   const platform::Place& place) {
35 36
  PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                         "Carrier is already init."));
37
  rank_ = rank;
38
  runtime_graph_ = runtime_graph;
39
  interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
40 41 42
  minibatch_scope_ = minibatch_scope;
  microbatch_scopes_ = microbatch_scopes;
  place_ = place;
43 44
  root_scope_ = root_scope;
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
45 46 47 48 49 50

  // TODO(fleet_exe dev): thread pool
  thread_num_ = 1;
  thread_pool_.SetThreadNum(thread_num_);
  thread_pool_.Start();

51
  CreateInterceptors();
52
  is_init_ = true;
53 54
}

55
// Currently a no-op.
void Carrier::Release() {
}
56

57 58
// Only emits a verbose log line; no explicit cleanup is done in this body.
Carrier::~Carrier() {
  VLOG(3) << "Carrier's destructor.";
}

59 60
bool Carrier::EnqueueInterceptorMessage(
    const InterceptorMessage& interceptor_message) {
61
  if (interceptor_message.ctrl_message()) {
62 63 64
    VLOG(3) << "Receiving control message from rank "
            << interceptor_message.src_id() << " to rank "
            << interceptor_message.dst_id();
65 66
    // for barrier
    msg_bus_->IncreaseBarrierCount();
67 68 69
  } else {
    int64_t dst_id = interceptor_message.dst_id();
    Interceptor* dst_interceptor = GetInterceptor(dst_id);
70
    dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
71
  }
72
  return true;
73 74
}

75 76
// Forwards to the message bus barrier.
void Carrier::Barrier() {
  msg_bus_->Barrier();
}

77 78 79 80 81 82 83 84
// Returns a non-owning pointer to the interceptor registered under
// `interceptor_id`; the carrier's map keeps ownership.
// Enforces (InvalidArgument) that the id has been registered.
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
  auto found = interceptor_idx_to_interceptor_.find(interceptor_id);
  PADDLE_ENFORCE_NE(found, interceptor_idx_to_interceptor_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find interceptor instance for interceptor "
                        "id %lld. Wrong dst? Call before init?",
                        interceptor_id));
  Interceptor* interceptor = found->second.get();
  return interceptor;
}

87 88 89 90 91
// Blocks the caller on cond_var_ until it is notified (see WakeUp()).
//
// NOTE(review): the wait has no predicate, so by std::condition_variable
// semantics it may also return on a spurious wakeup, and a notification
// issued before this call starts waiting is not observed.
void Carrier::Wait() {
  std::unique_lock<std::mutex> guard(running_mutex_);
  cond_var_.wait(guard);
}

92 93 94 95 96
// Notifies every thread currently waiting on cond_var_ (see Wait()).
void Carrier::WakeUp() {
  // May end up notifying twice; that is acceptable for the unit tests.
  cond_var_.notify_all();
}

97
void Carrier::Start() {
98
  PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
99
                    platform::errors::PreconditionNotMet(
100 101 102
                        "Using message bus since it has not been initialized. "
                        "Please invoke MessageBus::Init() before using it or "
                        "neccessary components are not ready."));
103 104 105 106 107 108 109 110 111

  for (int64_t id : source_interceptor_ids_) {
    VLOG(3) << "Carrier Start is sending start to source interceptor " << id
            << ".";
    InterceptorMessage start_msg;
    // source node data_is_ready is send by carrier, so set src_id=-1
    start_msg.set_src_id(-1);
    start_msg.set_dst_id(id);
    start_msg.set_message_type(DATA_IS_READY);
112
    Send(start_msg);
113
  }
114
  // TODO(wangxi): async step
115
  Wait();
116
  dev_ctx_->Wait();
117 118 119 120
}

// Reports whether Init() has completed (is_init_ is set at the end of Init()).
bool Carrier::IsInit() const {
  return is_init_;
}

121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
// Returns the rank recorded for `interceptor_id` in interceptor_id_to_rank_.
// Enforces (NotFound) that the id is present in that table.
int64_t Carrier::GetRank(int64_t interceptor_id) const {
  // One lookup only: the iterator is reused for the result instead of calling
  // find() for the check and at() again for the value.
  auto iter = interceptor_id_to_rank_.find(interceptor_id);
  PADDLE_ENFORCE_NE(
      iter, interceptor_id_to_rank_.end(),
      platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
                                 interceptor_id));
  return iter->second;
}

// Sends `msg` on behalf of an interceptor hosted by this carrier.
//
// A src_id of -1 (used by Start()) is treated as if the message came from the
// destination interceptor itself. The source rank must equal this carrier's
// rank (enforced, Fatal). Messages whose destination lives on the same rank
// are enqueued locally; all others go through the message bus, which must be
// non-null and initialized (both enforced).
//
// Returns the result of the local enqueue or of MessageBus::Send().
bool Carrier::Send(const InterceptorMessage& msg) {
  const int64_t dst_id = msg.dst_id();
  const int64_t src_id = (msg.src_id() != -1) ? msg.src_id() : dst_id;
  const int64_t src_rank = GetRank(src_id);
  const int64_t dst_rank = GetRank(dst_id);
  PADDLE_ENFORCE_EQ(
      src_rank, rank_,
      platform::errors::Fatal("The source rank id %lld, which is not equal to "
                              "the carrier rank id %lld.",
                              src_rank, rank_));

  // Cross-rank delivery is handled first; it needs a usable message bus.
  if (src_rank != dst_rank) {
    PADDLE_ENFORCE_NOT_NULL(
        msg_bus_.get(),
        platform::errors::Unavailable("Message bus is released accidently"));
    PADDLE_ENFORCE_EQ(
        msg_bus_->IsInit(), true,
        platform::errors::PreconditionNotMet(
            "Using message bus since it has not been initialized. "
            "Please invoke MessageBus::Init() before using it or "
            "neccessary components are not ready."));
    VLOG(3) << "Send a message from interceptor " << src_id
            << " to interceptor " << dst_id
            << ", which are in different ranks.";
    return msg_bus_->Send(dst_rank, msg);
  }

  // Same rank: hand the message straight to the local interceptor.
  VLOG(3) << "Send a message from interceptor " << src_id
          << " to interceptor " << dst_id << ", which are in the same ranks.";
  return EnqueueInterceptorMessage(msg);
}

161 162 163 164 165 166 167 168
// Registers `interceptor` under `interceptor_id` and takes ownership of it.
//
// Before it is stored, the interceptor is told about this carrier and is
// given the task loop selected by `interceptor_id % thread_num_`.
// Enforces that the id is not registered yet (AlreadyExists) and that the
// selected task loop is non-null (Fatal).
//
// Returns a non-owning pointer to the stored interceptor.
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
                                     std::unique_ptr<Interceptor> interceptor) {
  auto existing = interceptor_idx_to_interceptor_.find(interceptor_id);
  PADDLE_ENFORCE_EQ(existing, interceptor_idx_to_interceptor_.end(),
                    platform::errors::AlreadyExists(
                        "The interceptor id %lld has already been created! "
                        "The interceptor id should be unique.",
                        interceptor_id));

  // Keep a raw handle; ownership moves into the map at the end.
  Interceptor* raw = interceptor.get();
  raw->RegisterCarrier(this);

  // TODO(fleet_exe dev): get loop
  auto* task_loop = thread_pool_.GetLoop(interceptor_id % thread_num_);
  PADDLE_ENFORCE_NOT_NULL(
      task_loop, platform::errors::Fatal("thread task loop must not null"));
  raw->RegisterTaskLoop(task_loop);

  interceptor_idx_to_interceptor_.emplace(interceptor_id,
                                          std::move(interceptor));
  return raw;
}

183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
// Builds the garbage collector shared by the interceptors of this carrier.
//
// A collector is created only when all of the following hold: the eager
// deletion threshold is non-negative, the build has CUDA or HIP enabled,
// `place` is a GPU place and fast eager deletion mode is enabled. In every
// other case an empty (null) shared_ptr is returned.
static std::shared_ptr<framework::GarbageCollector> GetGC(
    const platform::Place& place) {
  const int64_t threshold = framework::GetEagerDeletionThreshold();
  std::shared_ptr<framework::GarbageCollector> collector;
  if (threshold < 0) {
    return collector;
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place) &&
      framework::IsFastEagerDeletionModeEnabled()) {
    collector.reset(new framework::UnsafeFastGPUGarbageCollector(
        BOOST_GET_CONST(platform::CUDAPlace, place), threshold));
  }
#endif

  return collector;
}

201
void Carrier::CreateInterceptors() {
202
  if (runtime_graph_->interceptor_id_to_node().empty()) return;
203 204 205

  auto gc = GetGC(place_);

206
  // create each Interceptor
207
  // no auto init since there is no config
208
  for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
209 210
    int64_t interceptor_id = item.first;
    TaskNode* task_node = item.second;
211

212 213 214 215 216 217
    PADDLE_ENFORCE_LT(
        task_node->run_at_offset(), task_node->run_per_steps(),
        platform::errors::InvalidArgument(
            "Interceptor's run_at_offset must < run_per_steps, must now "
            "run_at_offset=%ld run_per_steps=%ld",
            task_node->run_at_offset(), task_node->run_per_steps()));
218

219
    std::unique_ptr<Interceptor> interceptor;
220 221 222 223 224 225
    PADDLE_ENFORCE_NE(task_node->type().empty(), true,
                      platform::errors::NotFound(
                          "Cannot found type for task node with id %lld",
                          task_node->task_id()));
    interceptor = InterceptorFactory::Create(task_node->type(), interceptor_id,
                                             task_node);
226 227 228 229 230 231 232 233 234 235 236 237
    interceptor->SetPlace(place_);
    interceptor->SetMiniBatchScope(minibatch_scope_);
    interceptor->SetMicroBatchScope(microbatch_scopes_);
    interceptor->SetRootScope(root_scope_);
    interceptor->SetGC(gc);

    SetInterceptor(interceptor_id, std::move(interceptor));
    VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
            << " with type: " << task_node->type() << ".";

    if (task_node->upstream().empty()) {
      source_interceptor_ids_.emplace_back(interceptor_id);
238
    }
239
  }
240 241 242 243
}

}  // namespace distributed
}  // namespace paddle