// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/carrier.h"

#include <mutex>

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
23 24 25 26

namespace paddle {
namespace distributed {

27
USE_INTERCEPTOR(Compute);
28
USE_INTERCEPTOR(Amplifier);
29

30
void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
31 32 33 34
                   framework::Scope* root_scope,
                   framework::Scope* minibatch_scope,
                   const std::vector<framework::Scope*>& microbatch_scopes,
                   const platform::Place& place) {
35 36
  PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                         "Carrier is already init."));
37
  rank_ = rank;
38
  runtime_graph_ = runtime_graph;
39
  interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
40 41 42
  minibatch_scope_ = minibatch_scope;
  microbatch_scopes_ = microbatch_scopes;
  place_ = place;
43 44
  root_scope_ = root_scope;
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
45
  CreateInterceptors();
46
  is_init_ = true;
47 48
}

49
void Carrier::Release() {
50 51
  // NOTE(wangxi): must join before `Derived Interceptor` destruct,
  // otherwise Derived object will be destructed before thread complete.
52 53 54 55 56 57 58 59 60

  for (int64_t id : source_interceptor_ids_) {
    VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
            << ".";
    InterceptorMessage stop_msg;
    // source node STOP is send by carrier, so set src_id=-1
    stop_msg.set_src_id(-1);
    stop_msg.set_dst_id(id);
    stop_msg.set_message_type(STOP);
61
    Send(stop_msg);
62 63
  }

64 65 66 67 68 69
  // TODO(wangxi): Maybe need a better to use thread.
  for (auto& interceptor : interceptor_idx_to_interceptor_) {
    interceptor.second->Join();
  }
}

70 71
// Destructor. Only emits a log line; it does not call Release().
Carrier::~Carrier() {
  VLOG(3) << "Carrier's destructor.";
}

72 73
bool Carrier::EnqueueInterceptorMessage(
    const InterceptorMessage& interceptor_message) {
74
  if (interceptor_message.ctrl_message()) {
75 76 77
    VLOG(3) << "Receiving control message from rank "
            << interceptor_message.src_id() << " to rank "
            << interceptor_message.dst_id();
78
  } else {
79 80 81 82 83 84 85 86 87 88
    {
      std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
      if (creating_interceptors_) {
        std::unique_lock<std::mutex> lock_message(tmp_message_mutex_);
        // Cannot handle the message to interceptor since interceptors
        // are still under creating. Will enqueue into a tmp stack.
        VLOG(3) << "Receiving message while creating interceptors.";
        message_tmp_.emplace_back(interceptor_message);
        return true;
      }
89
    }
90 91
    int64_t dst_id = interceptor_message.dst_id();
    Interceptor* dst_interceptor = GetInterceptor(dst_id);
92
    dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
93
  }
94
  return true;
95 96 97 98 99 100 101 102 103 104
}

// Returns a non-owning pointer to the interceptor registered under
// `interceptor_id`. Raises InvalidArgument if no interceptor has been
// registered with that id.
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
  auto pos = interceptor_idx_to_interceptor_.find(interceptor_id);
  PADDLE_ENFORCE_NE(pos, interceptor_idx_to_interceptor_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find interceptor instance for interceptor "
                        "id %lld. Wrong dst? Call before init?",
                        interceptor_id));
  Interceptor* found = pos->second.get();
  return found;
}

107 108 109 110 111
void Carrier::Wait() {
  std::unique_lock<std::mutex> lock(running_mutex_);
  cond_var_.wait(lock);
}

112
void Carrier::Start() {
113
  PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
114
                    platform::errors::PreconditionNotMet(
115 116 117
                        "Using message bus since it has not been initialized. "
                        "Please invoke MessageBus::Init() before using it or "
                        "neccessary components are not ready."));
118 119 120 121 122 123 124 125 126

  for (int64_t id : source_interceptor_ids_) {
    VLOG(3) << "Carrier Start is sending start to source interceptor " << id
            << ".";
    InterceptorMessage start_msg;
    // source node data_is_ready is send by carrier, so set src_id=-1
    start_msg.set_src_id(-1);
    start_msg.set_dst_id(id);
    start_msg.set_message_type(DATA_IS_READY);
127
    Send(start_msg);
128
  }
129
  Wait();
130
  dev_ctx_->Wait();
131 132
}

133 134
// Gives access to the condition variable that Wait() blocks on.
std::condition_variable& Carrier::GetCondVar() {
  return cond_var_;
}

135 136
// Reports whether Init() has completed on this carrier.
bool Carrier::IsInit() const {
  return is_init_;
}

137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
// Returns the rank recorded for `interceptor_id` in interceptor_id_to_rank_.
// Raises NotFound if the id has no entry in the mapping.
int64_t Carrier::GetRank(int64_t interceptor_id) const {
  // One lookup serves both the existence check and the result.
  auto iter = interceptor_id_to_rank_.find(interceptor_id);
  PADDLE_ENFORCE_NE(
      iter, interceptor_id_to_rank_.end(),
      platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
                                 interceptor_id));
  return iter->second;
}

// Sends `msg` from an interceptor on this carrier to its destination.
//
// A src_id of -1 marks a message issued by the carrier itself; the
// destination id is then used as the source when resolving ranks. The source
// rank must equal this carrier's rank (Fatal otherwise). Messages whose
// destination is on the same rank are enqueued locally; all others go through
// the message bus, which must be non-null and initialized.
//
// Returns the result of the local enqueue or of MessageBus::Send.
bool Carrier::Send(const InterceptorMessage& msg) {
  int64_t src_id = msg.src_id();
  if (src_id == -1) {
    src_id = msg.dst_id();
  }
  int64_t dst_id = msg.dst_id();

  int64_t src_rank = GetRank(src_id);
  int64_t dst_rank = GetRank(dst_id);
  PADDLE_ENFORCE_EQ(
      src_rank, rank_,
      platform::errors::Fatal("The source rank id %lld, which is not equal to "
                              "the carrier rank id %lld.",
                              src_rank, rank_));

  // Local delivery: both interceptors live on this carrier's rank.
  if (src_rank == dst_rank) {
    VLOG(3) << "Send a message from interceptor " << src_id
            << " to interceptor " << dst_id << ", which are in the same ranks.";
    return EnqueueInterceptorMessage(msg);
  }

  // Remote delivery: hand the message to the message bus.
  PADDLE_ENFORCE_NOT_NULL(
      msg_bus_.get(),
      platform::errors::Unavailable("Message bus is released accidently"));
  PADDLE_ENFORCE_EQ(
      msg_bus_->IsInit(), true,
      platform::errors::PreconditionNotMet(
          "Using message bus since it has not been initialized. "
          "Please invoke MessageBus::Init() before using it or "
          "neccessary components are not ready."));
  VLOG(3) << "Send a message from interceptor " << src_id
          << " to interceptor " << dst_id
          << ", which are in different ranks.";
  return msg_bus_->Send(dst_rank, msg);
}

177 178 179 180 181 182 183 184
// Registers `interceptor` under `interceptor_id` and takes ownership of it.
//
// The interceptor is told about this carrier via RegisterCarrier() before it
// is stored. Returns a non-owning pointer to the stored interceptor.
// Raises AlreadyExists if the id is already in use.
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
                                     std::unique_ptr<Interceptor> interceptor) {
  auto existing = interceptor_idx_to_interceptor_.find(interceptor_id);
  PADDLE_ENFORCE_EQ(existing, interceptor_idx_to_interceptor_.end(),
                    platform::errors::AlreadyExists(
                        "The interceptor id %lld has already been created! "
                        "The interceptor id should be unique.",
                        interceptor_id));
  interceptor->RegisterCarrier(this);
  // Grab the raw pointer before ownership moves into the map.
  Interceptor* raw = interceptor.get();
  interceptor_idx_to_interceptor_.emplace(interceptor_id,
                                          std::move(interceptor));
  return raw;
}

192 193
void Carrier::SetCreatingFlag(bool flag) {
  // set the creating flag
194
  creating_flag_mutex_.lock();
195 196 197
  VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_
          << " to " << flag << ".";
  creating_interceptors_ = flag;
198
  creating_flag_mutex_.unlock();
199
  if (!flag) {
200 201 202 203 204 205 206 207 208 209 210
    for (auto& pair : interceptor_idx_to_interceptor_) {
      // update the source interceptor id
      if (std::find(source_interceptor_ids_.begin(),
                    source_interceptor_ids_.end(),
                    pair.first) == source_interceptor_ids_.end()) {
        auto task = pair.second->GetTaskNode();
        if (task != nullptr && task->upstream().empty()) {
          source_interceptor_ids_.emplace_back(pair.first);
        }
      }
    }
211 212 213 214 215 216
    // finish create interceptors outside, handle tmp messsages
    HandleTmpMessages();
  }
}

void Carrier::HandleTmpMessages() {
217 218 219 220 221 222
  // NOTE: It's ok lock on the tmp_message_mutex_ here, when enter this
  // `HandleTmpMessages` method, the creating_interceptors_ flag
  // must be false, therefore, there won't have conflict with the
  // lock on the tmp_message_mutex_ inside `EnqueueInterceptorMessage`
  // on the same thread.
  std::unique_lock<std::mutex> lock(tmp_message_mutex_);
223 224 225 226 227 228 229 230
  VLOG(3) << "Carrier has received " << message_tmp_.size()
          << " messages during creating interceptors.";
  for (const auto& msg : message_tmp_) {
    EnqueueInterceptorMessage(msg);
  }
  message_tmp_.clear();
}

231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
// Builds the garbage collector shared by the interceptors on `place`.
//
// A collector is created only when all of the following hold: the eager
// deletion threshold is non-negative, the build has CUDA or HIP enabled,
// `place` is a GPU place, and fast eager deletion mode is enabled. In every
// other case an empty shared_ptr is returned.
static std::shared_ptr<framework::GarbageCollector> GetGC(
    const platform::Place& place) {
  std::shared_ptr<framework::GarbageCollector> collector;
  int64_t threshold = framework::GetEagerDeletionThreshold();
  if (threshold < 0) {
    return collector;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place) &&
      framework::IsFastEagerDeletionModeEnabled()) {
    collector.reset(new framework::UnsafeFastGPUGarbageCollector(
        BOOST_GET_CONST(platform::CUDAPlace, place), threshold));
  }
#endif
  return collector;
}

249
void Carrier::CreateInterceptors() {
250
  if (runtime_graph_->interceptor_id_to_node().empty()) return;
251 252 253

  auto gc = GetGC(place_);

254
  // create each Interceptor
255
  // no auto init since there is no config
256
  for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
257 258
    int64_t interceptor_id = item.first;
    TaskNode* task_node = item.second;
259

260 261 262 263 264 265
    PADDLE_ENFORCE_LT(
        task_node->run_at_offset(), task_node->run_per_steps(),
        platform::errors::InvalidArgument(
            "Interceptor's run_at_offset must < run_per_steps, must now "
            "run_at_offset=%ld run_per_steps=%ld",
            task_node->run_at_offset(), task_node->run_per_steps()));
266

267
    std::unique_ptr<Interceptor> interceptor;
268 269 270 271 272 273
    PADDLE_ENFORCE_NE(task_node->type().empty(), true,
                      platform::errors::NotFound(
                          "Cannot found type for task node with id %lld",
                          task_node->task_id()));
    interceptor = InterceptorFactory::Create(task_node->type(), interceptor_id,
                                             task_node);
274 275 276 277 278 279 280 281 282 283 284 285
    interceptor->SetPlace(place_);
    interceptor->SetMiniBatchScope(minibatch_scope_);
    interceptor->SetMicroBatchScope(microbatch_scopes_);
    interceptor->SetRootScope(root_scope_);
    interceptor->SetGC(gc);

    SetInterceptor(interceptor_id, std::move(interceptor));
    VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
            << " with type: " << task_node->type() << ".";

    if (task_node->upstream().empty()) {
      source_interceptor_ids_.emplace_back(interceptor_id);
286
    }
287
  }
288 289 290 291 292 293
  // The carrier will be always waiting for outside initializer
  // since there is no interceptor has been created during auto init
  creating_flag_mutex_.lock();
  creating_interceptors_ = false;
  creating_flag_mutex_.unlock();
  HandleTmpMessages();
294 295 296 297
}

}  // namespace distributed
}  // namespace paddle