carrier.cc 13.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/distributed/fleet_executor/carrier.h"

#include <algorithm>
#include <limits>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
30 31 32 33

namespace paddle {
namespace distributed {

34
// Reference every interceptor type this carrier can instantiate through
// InterceptorFactory (the names match the type strings used below, e.g.
// "Source" / "Sink"). NOTE(review): the exact effect of USE_INTERCEPTOR is
// defined in interceptor.h; presumably it forces each type's registration to
// be linked in -- confirm there.
USE_INTERCEPTOR(Source);
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);
USE_INTERCEPTOR(Sink);
USE_INTERCEPTOR(Cond);
USE_INTERCEPTOR(Start);
40

41 42
void Carrier::Init(
    int64_t rank,
43
    const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
44
  rank_ = rank;
45 46 47 48 49 50 51 52 53 54 55 56
  interceptor_id_to_rank_ = interceptor_id_to_rank;

  // TODO(fleet_exe dev): thread pool
  thread_num_ = 1;
  thread_pool_.SetThreadNum(thread_num_);
  thread_pool_.Start();
}

void Carrier::Init(
    int64_t rank,
    const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
    const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
57 58 59 60
    const framework::ProgramDesc& program,
    framework::Scope* scope,
    int64_t num_micro_batches,
    const platform::Place& place,
61 62
    const std::vector<std::string>& inference_root_scope_vars,
    const std::vector<framework::Scope*>& micro_scope_list) {
63 64 65
  rank_ = rank;
  interceptor_id_to_rank_ = interceptor_id_to_rank;
  interceptor_id_to_node_ = interceptor_id_to_node;
66
  place_ = place;
67
  root_scope_ = scope;
68
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
69
  bool need_create_scope = micro_scope_list.empty();
70

71 72 73
  PADDLE_ENFORCE_NOT_NULL(
      root_scope_,
      platform::errors::InvalidArgument("root_scope can not be nullptr"));
74 75 76 77 78 79 80 81 82 83 84 85 86

  if (need_create_scope) {
    minibatch_scope_ = &root_scope_->NewScope();
    microbatch_scopes_.resize(num_micro_batches);
    for (int i = 0; i < num_micro_batches; ++i) {
      microbatch_scopes_[i] = &minibatch_scope_->NewScope();
      CopyParameters(i, program, inference_root_scope_vars);
    }
  } else {
    microbatch_scopes_ = micro_scope_list;
    for (int i = 0; i < num_micro_batches; ++i) {
      CopyParameters(i, program, inference_root_scope_vars);
    }
87 88
  }

89 90 91 92
  // Add source and sink interceptor id to rank
  interceptor_id_to_rank_.emplace(SOURCE_ID, rank);
  interceptor_id_to_rank_.emplace(SINK_ID, rank);

93 94 95 96 97
  // TODO(fleet_exe dev): thread pool
  thread_num_ = 1;
  thread_pool_.SetThreadNum(thread_num_);
  thread_pool_.Start();

98
  CreateInterceptors();
99
  is_init_ = true;
100 101
}

102 103 104 105 106
// Drops every child scope of the root scope. When Init created the scopes,
// the mini-batch scope (and therefore the micro-batch scopes under it) are
// among these children. Does nothing when no root scope has been set.
void Carrier::Release() {
  if (root_scope_ == nullptr) {
    return;
  }
  root_scope_->DropKids();
}
107

108 109
Carrier::~Carrier() {
  // The destructor only logs; scope cleanup is performed by Release().
  VLOG(3) << "Carrier's destructor.";
}

110
void Carrier::CopyParameters(
111 112
    int microbatch_id,
    const framework::ProgramDesc& program,
113 114 115 116 117
    const std::vector<std::string>& inference_root_scope_vars) {
  std::map<std::string, int> inference_root_scope_var_map;
  for (auto var_name : inference_root_scope_vars) {
    inference_root_scope_var_map.insert({var_name, 1});
  }
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
  for (size_t i = 0; i < program.Size(); ++i) {
    for (auto& var : program.Block(i).AllVars()) {
      std::string var_name = var->Name();
      bool force_root = inference_root_scope_var_map.find(var_name) !=
                        inference_root_scope_var_map.end();
      if (force_root) {
        VLOG(4) << var_name
                << " will be forced to be created in the root scope.";
      }
      if ((var->Persistable() || force_root) && microbatch_id == 0) {
        auto* ptr = root_scope_->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(5) << "Create persistable var: " << var->Name()
                << ", which pointer is " << ptr;
      } else if (!var->Persistable()) {
        auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
        VLOG(5) << "Create variable " << var->Name() << " for microbatch "
                << microbatch_id << ", which pointer is " << ptr << ".";
        InitializeVariable(ptr, var->GetType());
      }
138 139 140 141
    }
  }
}

142 143
bool Carrier::EnqueueInterceptorMessage(
    const InterceptorMessage& interceptor_message) {
144
  PADDLE_ENFORCE_EQ(
145 146
      interceptor_message.ctrl_message(),
      false,
147 148 149 150 151
      platform::errors::Fatal(
          "Control message should be only send inter rank using message bus."));
  int64_t dst_id = interceptor_message.dst_id();
  Interceptor* dst_interceptor = GetInterceptor(dst_id);
  dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
152
  return true;
153 154 155 156
}

// Returns the (carrier-owned) interceptor registered under `interceptor_id`.
// Raises InvalidArgument when no such interceptor exists, e.g. a wrong
// destination id or a call made before Init.
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
  const auto found = interceptor_idx_to_interceptor_.find(interceptor_id);
  PADDLE_ENFORCE_NE(found,
                    interceptor_idx_to_interceptor_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find interceptor instance for interceptor "
                        "id %lld. Wrong dst? Call before init?",
                        interceptor_id));
  Interceptor* interceptor = found->second.get();
  return interceptor;
}

166 167 168 169 170
// Blocks the calling thread on cond_var_ until WakeUp() notifies it.
// NOTE(review): the wait has no predicate. A WakeUp() that happens before this
// call starts waiting is lost, and a spurious wakeup returns early. Consider
// guarding it with a "done" flag protected by running_mutex_.
void Carrier::Wait() {
  std::unique_lock<std::mutex> guard(running_mutex_);
  cond_var_.wait(guard);
}

171 172 173 174 175
// Releases every thread currently blocked in Wait(). Being called more than
// once per run is tolerated (the original author notes it is acceptable for
// unit tests).
void Carrier::WakeUp() { cond_var_.notify_all(); }

176
void Carrier::Start() {
177 178
  PADDLE_ENFORCE_EQ(is_init_,
                    true,
179 180
                    platform::errors::PreconditionNotMet(
                        "Using carrier before initialized."));
181 182 183 184 185
  InterceptorMessage start_msg;
  start_msg.set_src_id(SOURCE_ID);
  start_msg.set_dst_id(SOURCE_ID);
  start_msg.set_message_type(START);
  Send(start_msg);
186
  // TODO(wangxi): async step
187
  Wait();
188
  dev_ctx_->Wait();
189 190 191 192 193 194 195 196 197
  for (auto* micro_scope : microbatch_scopes_) {
    // By default, we should delete all kid scopes after run executor because
    // some operators may create local scope when running, such as while_op.
    // But when while_op also create a local executor to run it's sub block,
    // the sub scopes it created should not be dropped immediately, because
    // while_grad_op will use some variables created during while_op run, so
    // we need to keep the kids and wait for the outer executor to drop them.
    micro_scope->DropKids();
  }
198 199 200 201
}

bool Carrier::IsInit() const {
  // Set to true only at the end of the full Init() overload.
  const bool initialized = is_init_;
  return initialized;
}

202 203 204 205 206 207 208 209 210 211
// Returns the rank that owns `interceptor_id`.
// Raises NotFound when the id is not in the routing table.
int64_t Carrier::GetRank(int64_t interceptor_id) const {
  // A single lookup serves both the existence check and the result. The
  // previous find()-then-at() hashed the key twice.
  const auto iter = interceptor_id_to_rank_.find(interceptor_id);
  PADDLE_ENFORCE_NE(
      iter,
      interceptor_id_to_rank_.end(),
      platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
                                 interceptor_id));
  return iter->second;
}

bool Carrier::Send(const InterceptorMessage& msg) {
212 213 214 215 216 217 218
  int64_t src_id = msg.src_id();
  // TODO(liyurui): compatible solution, will be removed completely in the
  // future
  if (interceptor_id_to_rank_.find(src_id) == interceptor_id_to_rank_.end() &&
      src_id == SOURCE_ID) {
    src_id = msg.dst_id();
  }
219 220 221 222
  int64_t dst_id = msg.dst_id();
  int64_t src_rank = GetRank(src_id);
  int64_t dst_rank = GetRank(dst_id);
  PADDLE_ENFORCE_EQ(
223 224
      src_rank,
      rank_,
225 226
      platform::errors::Fatal("The source rank id %lld, which is not equal to "
                              "the carrier rank id %lld.",
227 228
                              src_rank,
                              rank_));
229 230 231
  if (src_rank == dst_rank) {
    VLOG(3) << "Send a message from interceptor " << src_id
            << " to interceptor " << dst_id << ", which are in the same ranks.";
232
    return EnqueueInterceptorMessage(msg);
233 234 235 236
  } else {
    VLOG(3) << "Send a message from interceptor " << src_id
            << " to interceptor " << dst_id
            << ", which are in different ranks.";
237
    return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
238
  }
239 240
}

241 242 243
// Registers `interceptor` under `interceptor_id`. The interceptor is bound to
// this carrier and to one of the thread pool's task loops, and the carrier
// takes ownership of it.
//
// @param interceptor_id  Unique id; raises AlreadyExists if already taken.
// @param interceptor     Interceptor to adopt; must not be null.
// @return Non-owning pointer to the registered interceptor.
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
                                     std::unique_ptr<Interceptor> interceptor) {
  auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
  PADDLE_ENFORCE_EQ(iter,
                    interceptor_idx_to_interceptor_.end(),
                    platform::errors::AlreadyExists(
                        "The interceptor id %lld has already been created! "
                        "The interceptor id should be unique.",
                        interceptor_id));
  // Fail loudly instead of dereferencing a null interceptor below.
  PADDLE_ENFORCE_NOT_NULL(
      interceptor.get(),
      platform::errors::InvalidArgument(
          "The interceptor for id %lld must not be nullptr.", interceptor_id));

  interceptor->RegisterCarrier(this);

  // TODO(fleet_exe dev): get loop
  // In C++, `%` takes the sign of the dividend. A negative interceptor id
  // would therefore yield a negative loop index once thread_num_ > 1, so the
  // remainder is folded back into [0, thread_num_). Non-negative ids, and the
  // current thread_num_ == 1, get the same index as before.
  int64_t loop_idx = interceptor_id % thread_num_;
  if (loop_idx < 0) {
    loop_idx += thread_num_;
  }
  auto* loop = thread_pool_.GetLoop(loop_idx);
  PADDLE_ENFORCE_NOT_NULL(
      loop, platform::errors::Fatal("thread task loop must not null"));
  interceptor->RegisterTaskLoop(loop);

  auto* ptr = interceptor.get();
  interceptor_idx_to_interceptor_.emplace(interceptor_id,
                                          std::move(interceptor));
  return ptr;
}

264 265 266 267 268 269 270 271
// Builds the garbage collector handed to interceptors.
// Returns an empty pointer when eager deletion is disabled (negative threshold)
// or when no collector applies. Only CUDA/HIP builds can produce one: an
// UnsafeFastGPUGarbageCollector, for GPU places with fast eager deletion on.
static std::shared_ptr<framework::GarbageCollector> GetGC(
    const platform::Place& place) {
  std::shared_ptr<framework::GarbageCollector> collector;
  const int64_t threshold = framework::GetEagerDeletionThreshold();
  if (threshold < 0) {
    return collector;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place) &&
      framework::IsFastEagerDeletionModeEnabled()) {
    collector.reset(
        new framework::UnsafeFastGPUGarbageCollector(place, threshold));
  }
#endif
  return collector;
}

282
void Carrier::CreateInterceptors() {
283
  if (interceptor_id_to_node_.empty()) return;
284 285 286

  auto gc = GetGC(place_);

287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
  // create source and sink task node
  auto max_run_times = microbatch_scopes_.size();
  TaskNode* source = new TaskNode(
      rank_, SOURCE_ID, max_run_times);  // rank, task_id, max_run_times
  TaskNode* sink = new TaskNode(rank_, SINK_ID, max_run_times);
  // find nodes without upstreams or without downstreams
  std::vector<TaskNode*> origin_sources, origin_sinks;
  for (const auto& item : interceptor_id_to_node_) {
    TaskNode* task_node = item.second;
    if (task_node->upstream().empty()) {
      origin_sources.emplace_back(task_node);
    }
    if (task_node->downstream().empty()) {
      origin_sinks.emplace_back(task_node);
    }
  }
  // link source node with origin source
  for (const auto& node : origin_sources) {
    source->AddDownstreamTask(node->task_id(),
                              std::numeric_limits<int64_t>::max());
    node->AddUpstreamTask(SOURCE_ID, std::numeric_limits<int64_t>::max());
  }
  // link sink node with origin sink
  for (const auto& node : origin_sinks) {
    sink->AddUpstreamTask(node->task_id(), std::numeric_limits<int64_t>::max());
    node->AddDownstreamTask(SINK_ID, std::numeric_limits<int64_t>::max());
  }
  // create source and sink interceptor
  SetInterceptor(SOURCE_ID,
                 InterceptorFactory::Create("Source", SOURCE_ID, source));
  SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));

319
  // create each Interceptor
320
  // no auto init since there is no config
321 322 323
  for (const auto& item : interceptor_id_to_node_) {
    int64_t interceptor_id = item.first;
    TaskNode* task_node = item.second;
324

325
    PADDLE_ENFORCE_LT(
326 327
        task_node->run_at_offset(),
        task_node->run_per_steps(),
328 329 330
        platform::errors::InvalidArgument(
            "Interceptor's run_at_offset must < run_per_steps, must now "
            "run_at_offset=%ld run_per_steps=%ld",
331 332
            task_node->run_at_offset(),
            task_node->run_per_steps()));
333

334
    std::unique_ptr<Interceptor> interceptor;
335 336
    PADDLE_ENFORCE_NE(task_node->type().empty(),
                      true,
337 338 339
                      platform::errors::NotFound(
                          "Cannot found type for task node with id %lld",
                          task_node->task_id()));
340 341
    interceptor = InterceptorFactory::Create(
        task_node->type(), interceptor_id, task_node);
342 343 344 345 346 347 348 349 350 351
    interceptor->SetPlace(place_);
    interceptor->SetMiniBatchScope(minibatch_scope_);
    interceptor->SetMicroBatchScope(microbatch_scopes_);
    interceptor->SetRootScope(root_scope_);
    interceptor->SetGC(gc);

    SetInterceptor(interceptor_id, std::move(interceptor));
    VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
            << " with type: " << task_node->type() << ".";

352 353 354 355 356 357 358 359 360
    PADDLE_ENFORCE_EQ(
        task_node->upstream().empty(),
        false,
        platform::errors::PreconditionNotMet(
            "There should not have normal nodes as source nodes"));
    PADDLE_ENFORCE_EQ(task_node->downstream().empty(),
                      false,
                      platform::errors::PreconditionNotMet(
                          "There should not have normal nodes as sink nodes"));
361
  }
362 363 364 365
}

}  // namespace distributed
}  // namespace paddle