// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/core/utils/dim.h"

namespace paddle {
namespace distributed {

// Builds the interceptor: wires up the ready/buffer bookkeeping for the
// task node's upstreams/downstreams, then routes every incoming
// InterceptorMessage to Compute().
ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node) {
  PrepareDeps();
  auto handler = [this](const InterceptorMessage& msg) { Compute(msg); };
  RegisterMsgHandle(handler);
}

void ComputeInterceptor::PrepareDeps() {
39 40
  auto& upstream = node_->upstream();
  auto& downstream = node_->downstream();
41

42
  for (auto up : upstream) {
43 44 45 46 47
    std::map<int64_t, int64_t> ready_size_map;
    for (int64_t i = 0; i < node_->max_run_times(); ++i) {
      ready_size_map.emplace(i, 0);
    }
    in_readys_.emplace(up.first, std::make_pair(up.second, ready_size_map));
48
  }
49 50
  for (auto down : downstream) {
    out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
51 52 53
  }
}

54
void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) {
55
  auto it = in_readys_.find(up_id);
56 57
  PADDLE_ENFORCE_NE(it,
                    in_readys_.end(),
58 59 60 61
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_readys.", up_id));

  auto max_ready_size = it->second.first;
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
  const auto& ready_scope_map = it->second.second;
  int64_t ready_size = 0;
  for (auto& scope_iter : ready_scope_map) {
    ready_size += scope_iter.second;
  }
  if (max_ready_size != INFINITE_BUFFER_SIZE) {
    PADDLE_ENFORCE_LE(
        ready_size,
        max_ready_size,
        platform::errors::OutOfRange(
            "upstream=%lld ready_size must <= max_ready_size, but "
            "now ready_size=%lld, max_ready_size=%lld",
            up_id,
            ready_size,
            max_ready_size));
  }
  PADDLE_ENFORCE_NE(
      it->second.second.find(scope_id),
      it->second.second.end(),
      platform::errors::OutOfRange(
          "Interceptor %lld can not find scope %lld in upstream ready map",
          interceptor_id_,
          scope_id));
  it->second.second.at(scope_id) = ready_scope_map.at(scope_id) + 1;
86 87 88 89
}

void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
  auto it = out_buffs_.find(down_id);
90 91
  PADDLE_ENFORCE_NE(it,
                    out_buffs_.end(),
92 93 94 95 96
                    platform::errors::NotFound(
                        "Cannot find downstream=%lld in out_buffs.", down_id));
  auto used_size = it->second.second;
  used_size -= 1;
  PADDLE_ENFORCE_GE(
97 98
      used_size,
      0,
99 100
      platform::errors::OutOfRange(
          "downstream=%lld used buff size must >= 0, but now equal %lld",
101 102
          down_id,
          used_size));
103 104 105 106
  it->second.second = used_size;
}

// Scans micro-batch scopes [start_micro_step_, start_micro_step_ +
// num_micro_step_) looking for one whose every upstream has at least one
// ready datum. On success sets cur_scope_id_ and returns true; otherwise
// returns false. When gen-step tracking is active (START_LOOP was received),
// a scope is additionally gated on all earlier tracked scopes in the current
// (smallest) gen step having finished.
bool ComputeInterceptor::IsInputReady() {
  std::map<int64_t, bool> scope_id_to_finish_flag;
  if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
    // Only the earliest gen step (begin() of the ordered map) is considered.
    scope_id_to_finish_flag =
        gen_step_to_scope_id_to_finish_flag_.begin()->second;
    VLOG(3) << "Is Input Ready in gen step "
            << gen_step_to_scope_id_to_finish_flag_.begin()->first;
  }
  // -1 means "not set by a message"; fall back to the full scope range.
  int64_t num_micro_step =
      (num_micro_step_ == -1 ? node_->max_run_times() : num_micro_step_);
  int64_t start_micro_step = (start_micro_step_ == -1 ? 0 : start_micro_step_);
  for (int64_t i = start_micro_step; i < start_micro_step + num_micro_step;
       ++i) {
    // flag == true iff every upstream has a nonzero ready count for scope i.
    bool flag = true;
    for (auto& ins : in_readys_) {
      auto ready_size_map = ins.second.second;
      flag = flag && (ready_size_map.at(i) != 0);
    }
    if (flag) {
      if (scope_id_to_finish_flag.empty()) {
        // No gen-step gating: the first data-ready scope wins.
        cur_scope_id_ = i;
        return true;
      } else if (scope_id_to_finish_flag.find(i) !=
                 scope_id_to_finish_flag.end()) {
        // Scope i is tracked: make sure all tracked scopes ordered before it
        // have already finished (flags are erased on finish, so a present
        // false flag means "still pending" — see Run()).
        for (auto iter : scope_id_to_finish_flag) {
          if (iter.first == i) {
            break;
          } else if (!iter.second) {
            VLOG(3) << "The previous scope is not ready, waiting for the "
                       "previous scope "
                    << iter.first << " in gen_step "
                    << gen_step_to_scope_id_to_finish_flag_.begin()->first;
            return false;
          }
        }
        cur_scope_id_ = i;
        return true;
      } else {
        // Scope i is not part of the current gen step; keep scanning.
        VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
                << " is larger than gen_step "
                << gen_step_to_scope_id_to_finish_flag_.begin()->first;
      }
    } else {
      VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
              << "'s upstreams aren't all ready.";
    }
  }
  return false;
}

bool ComputeInterceptor::CanWriteOutput() {
  for (auto& outs : out_buffs_) {
    auto max_buffer_size = outs.second.first;
    auto used_size = outs.second.second;
160 161 162
    if (max_buffer_size == INFINITE_BUFFER_SIZE) {
      continue;
    }
163
    // full, return false
Y
Yuang Liu 已提交
164 165 166 167 168
    if (used_size == max_buffer_size) {
      VLOG(3) << "Interceptor " << GetInterceptorId()
              << "'s out buffer is full.";
      return false;
    }
169 170
  }
  return true;
171 172 173
}

void ComputeInterceptor::SendDataReadyToDownStream() {
174 175 176 177 178 179 180 181 182 183
  bool need_send_vars = !(node_->vars_to_dtype().empty());
  InterceptorMessage ready_msg;
  ready_msg.set_start_micro_step(start_micro_step_);
  ready_msg.set_num_micro_step(num_micro_step_);
  if (need_send_vars) {
    ready_msg = PrepareVarsMsg();
  } else {
    ready_msg.set_message_type(DATA_IS_READY);
    ready_msg.set_scope_idx(cur_scope_id_);
  }
184 185 186 187 188
  for (auto& outs : out_buffs_) {
    auto down_id = outs.first;
    auto max_buff_size = outs.second.first;
    auto used_size = outs.second.second;
    used_size += 1;
189 190 191 192 193 194 195 196 197 198 199
    if (max_buff_size != INFINITE_BUFFER_SIZE) {
      PADDLE_ENFORCE_LE(
          used_size,
          max_buff_size,
          platform::errors::OutOfRange("downstream=%lld used buff size must <= "
                                       "max_buff_size, but now used_size=%lld, "
                                       "max_buff_size=%lld",
                                       down_id,
                                       used_size,
                                       max_buff_size));
    }
200 201
    outs.second.second = used_size;

202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
    if (need_send_vars) {
      VLOG(3) << "ComputeInterceptor " << interceptor_id_
              << " Send data_with_vars msg to " << down_id
              << " in scope: " << cur_scope_id_;
      Send(down_id, ready_msg);
    } else {
      VLOG(3) << "ComputeInterceptor " << interceptor_id_
              << " Send data_is_ready msg to " << down_id
              << " in scope: " << cur_scope_id_;
      Send(down_id, ready_msg);
    }
  }
}

// Builds a DATA_WITH_VARS message for cur_scope_id_: every var listed in the
// task node's vars_to_dtype() is looked up in the micro-batch scope,
// serialized, and appended to the message's vars_list.
InterceptorMessage ComputeInterceptor::PrepareVarsMsg() {
  PADDLE_ENFORCE_LT(cur_scope_id_,
                    microbatch_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "Step out of range. There are %ld "
                        "microbatch_scopes, but recevice scope index %ld",
                        microbatch_scopes_.size(),
                        cur_scope_id_));
  auto* scope = microbatch_scopes_[cur_scope_id_];

  InterceptorMessage ready_msg;
  ready_msg.set_message_type(DATA_WITH_VARS);
  ready_msg.set_scope_idx(cur_scope_id_);
  auto& pool = platform::DeviceContextPool::Instance();
  for (const auto& var_to_dtype : node_->vars_to_dtype()) {
    const auto& var_name = var_to_dtype.first;
    VarList* vars = ready_msg.add_vars_list();
    vars->set_name(var_name);
    auto* var = scope->FindVar(var_name);
    PADDLE_ENFORCE(
        var,
        platform::errors::NotFound(
            "Variable %s not exists in scope %ld", var_name, cur_scope_id_));
    const auto& tensor = var->Get<phi::DenseTensor>();
    // Serialize the tensor into the message as an opaque byte string.
    std::ostringstream buffer;
    SerializeToStream(buffer, tensor, *pool.Get(place_));
    vars->set_stensor(buffer.str());
    VLOG(3) << "Prepare vars msg " << var_name << " with dimension "
            << tensor.dims() << " dtype " << tensor.dtype();
  }
  return ready_msg;
}

void ComputeInterceptor::ReplyCompletedToUpStream() {
  for (auto& ins : in_readys_) {
    auto up_id = ins.first;
253
    auto ready_size = ins.second.second.at(cur_scope_id_);
254 255
    ready_size -= 1;
    PADDLE_ENFORCE_GE(
256 257
        ready_size,
        0,
258
        platform::errors::OutOfRange(
259 260
            "upstream=%lld ready_size must >= 0, but now got %lld",
            up_id,
261
            ready_size));
262
    ins.second.second[cur_scope_id_] = ready_size;
263

Y
Yuang Liu 已提交
264 265
    VLOG(3) << "ComputeInterceptor " << interceptor_id_
            << " Reply data_is_useless msg to " << up_id
266
            << " in scope: " << cur_scope_id_;
Y
Yuang Liu 已提交
267

268
    InterceptorMessage reply_msg;
269
    reply_msg.set_message_type(DATA_IS_USELESS);
270
    reply_msg.set_scope_idx(cur_scope_id_);
271 272 273 274
    Send(up_id, reply_msg);
  }
}

275 276
void ComputeInterceptor::RunOps() {
  for (auto op : node_->ops()) {
277 278 279 280 281 282 283 284
    PADDLE_ENFORCE_LT(cur_scope_id_,
                      microbatch_scopes_.size(),
                      platform::errors::InvalidArgument(
                          "Step out of range. There are %ld "
                          "microbatch_scopes, but recevice scope index %ld",
                          microbatch_scopes_.size(),
                          cur_scope_id_));
    op->Run(*microbatch_scopes_[cur_scope_id_], place_);
285
    if (gc_) {
286 287 288 289
      framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
                                     op,
                                     node_->unused_vars(),
                                     gc_.get());
290
    }
291 292 293
  }
}

// Main scheduling loop: keeps executing micro-batch scopes while (a) some
// scope has all upstream inputs ready (IsInputReady also selects
// cur_scope_id_) and (b) every downstream has buffer room. After running the
// ops it updates gen-step bookkeeping, then notifies downstreams before
// releasing upstream credits.
void ComputeInterceptor::Run() {
  while (IsInputReady() && CanWriteOutput()) {
    VLOG(3) << "id=" << GetInterceptorId()
            << " ComputeInterceptor running in scope " << cur_scope_id_;

    RunOps();

    if (!gen_step_to_scope_id_to_finish_flag_.empty()) {
      // Only the earliest gen step is being executed; mark cur_scope_id_
      // finished by erasing its flag (IsInputReady treats a present flag as
      // "still pending"). When the gen step has no pending scopes left, drop
      // the gen step entirely.
      auto iter = gen_step_to_scope_id_to_finish_flag_.begin();
      VLOG(3) << "id=" << GetInterceptorId()
              << " ComputeInterceptor running in scope " << cur_scope_id_
              << " with gen_step " << iter->first;
      auto& scope_id_to_finish_flag = iter->second;
      PADDLE_ENFORCE_NE(
          scope_id_to_finish_flag.find(cur_scope_id_),
          scope_id_to_finish_flag.end(),
          platform::errors::NotFound(
              "Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
      scope_id_to_finish_flag.erase(cur_scope_id_);
      if (scope_id_to_finish_flag.empty()) {
        gen_step_to_scope_id_to_finish_flag_.erase(iter);
      }
    }

    // send to downstream and increase buff used
    SendDataReadyToDownStream();
    // reply to upstream and decrease ready data
    ReplyCompletedToUpStream();
  }
}

325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
void ComputeInterceptor::DecodeMsgVars(const InterceptorMessage& msg) {
  int64_t scope_id = msg.scope_idx();
  PADDLE_ENFORCE_LT(scope_id,
                    microbatch_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "Step out of range. There are %ld "
                        "microbatch_scopes, but recevice scope index %ld",
                        microbatch_scopes_.size(),
                        scope_id));
  auto* scope = microbatch_scopes_[scope_id];
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  for (const auto& var_iter : msg.vars_list()) {
    const std::string& name = var_iter.name();
    auto& dev_ctx = *pool.Get(place_);
    std::istringstream ss(var_iter.stensor());
    auto* var = scope->Var(name);
    auto* tensor = var->GetMutable<phi::DenseTensor>();
    DeserializeFromStream(ss, tensor, dev_ctx);

    VLOG(3) << "Set vars " << name << " with value in scope " << scope_id
            << " with dims " << tensor->dims() << " with dtype "
            << tensor->dtype();
347 348 349
  }
}

350 351
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY) {
352 353 354
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive data_is_ready " << msg.src_id() << " "
            << msg.scope_idx() << " ";
355 356
    start_micro_step_ = msg.start_micro_step();
    num_micro_step_ = msg.num_micro_step();
357
    IncreaseReady(msg.src_id(), msg.scope_idx());
358
    Run();
359
  } else if (msg.message_type() == DATA_IS_USELESS) {
360 361 362
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive data_is_useless " << msg.src_id() << " "
            << msg.scope_idx() << " ";
363 364
    DecreaseBuff(msg.src_id());
    Run();
365 366 367 368 369 370 371 372 373
  } else if (msg.message_type() == DATA_WITH_VARS) {
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive data_with_vars " << msg.src_id() << " "
            << msg.scope_idx() << " ";
    DecodeMsgVars(msg);
    IncreaseReady(msg.src_id(), msg.scope_idx());
    Run();
  } else if (msg.message_type() == START_LOOP) {
    VLOG(3) << "Compute interceptor " << interceptor_id_
374 375 376 377
            << " receive start_loop " << msg.src_id() << " in scope "
            << msg.scope_idx() << " with gen_step " << msg.gen_step();
    start_micro_step_ = msg.start_micro_step();
    num_micro_step_ = msg.num_micro_step();
378
    IncreaseReady(msg.src_id(), msg.scope_idx());
379 380 381
    int64_t gen_step = msg.gen_step();
    gen_step_to_scope_id_to_finish_flag_[gen_step].emplace(msg.scope_idx(),
                                                           false);
382
    Run();
383 384 385 386 387 388 389
  }
}

// Register this class under the "Compute" interceptor type so the carrier
// can instantiate it by name.
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);

}  // namespace distributed
}  // namespace paddle