parallel_executor.cc 18.5 KB
Newer Older
Y
Yang Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/parallel_executor.h"
Y
Yu Yang 已提交
16 17
#include "ThreadPool.h"
#include "executor.h"
Y
Yu Yang 已提交
18 19
#include "lod_tensor.h"
#include "op_registry.h"
Y
Yang Yang 已提交
20 21

namespace paddle {
Y
Yu Yang 已提交
22 23
namespace framework {

Y
Yu Yang 已提交
24 25 26 27 28 29
#ifdef PADDLE_WITH_CUDA

// Wrapper for calls into the NCCL library. It currently expands to the
// wrapped expression unchanged, so the value returned by `x` is discarded.
// FIXME: CHECK the return value of x;
#define NCCL_INVOKE(x) x
#endif

Y
Yu Yang 已提交
30 31
struct OpHandle;

Y
Yu Yang 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
/**
 * Common base of every variable node in the dependency graph.
 *
 * A variable node records the operator that writes it and the operators that
 * read it; the executor uses these links to decide when an operator can run.
 */
struct VarHandleBase {
  virtual ~VarHandleBase() {}

  // Human readable description of this node, used for logging.
  virtual std::string DebugString() const = 0;

  // Operator that writes this variable. nullptr means no operator in the
  // graph writes it; the executor treats such a variable as ready from the
  // start. Defaulted so a handle can never carry an indeterminate pointer.
  OpHandle *generated_op_{nullptr};

  // Operators that take this variable as an input.
  std::vector<OpHandle *> pending_ops_;
};

/**
 * One version of a named variable on one place.
 *
 * Every write to a variable creates a new VarHandle with a higher version, so
 * the graph is in SSA form.
 */
struct VarHandle : public VarHandleBase {
  // Formats the handle as "<name>:<place>".
  std::string DebugString() const override {
    std::stringstream ss;
    ss << name_ << ":" << place_;
    return ss.str();
  }

  // Version number of this variable. Defaulted so it is never indeterminate.
  size_t version_{0};
  // Name of the variable in the scope.
  std::string name_;
  // Place on which this version of the variable lives.
  platform::Place place_;
};
Y
Yu Yang 已提交
51

Y
Yu Yang 已提交
52
struct DependencyVarHandle : public VarHandleBase {
Y
Yu Yang 已提交
53
  std::string DebugString() const override { return "Dependency Variable"; }
Y
Yu Yang 已提交
54 55 56
};

struct OpHandle {
Y
Yu Yang 已提交
57 58 59 60 61
  std::vector<VarHandleBase *> inputs_;
  std::vector<VarHandleBase *> outputs_;
  std::unordered_map<platform::Place, platform::DeviceContext *,
                     platform::PlaceHash>
      dev_ctx_;
Y
Yu Yang 已提交
62 63 64 65 66

  std::string DebugString() {
    std::stringstream ss;
    ss << "(";
    for (auto *var : inputs_) {
Y
Yu Yang 已提交
67
      ss << var->DebugString() << ", ";
Y
Yu Yang 已提交
68 69 70
    }
    ss << ") --> (";
    for (auto *var : outputs_) {
Y
Yu Yang 已提交
71
      ss << var->DebugString() << ", ";
Y
Yu Yang 已提交
72 73 74 75 76 77
    }
    ss << ")\n";
    return ss.str();
  }

  virtual ~OpHandle() {}
Y
Yu Yang 已提交
78

Y
Yu Yang 已提交
79
  virtual void Run() { PADDLE_THROW("Not implemented"); }
Y
Yu Yang 已提交
80
  virtual void Wait() {}
Y
Yu Yang 已提交
81 82 83 84
};

struct ComputationOpHandle : public OpHandle {
  std::unique_ptr<OperatorBase> op_;
Y
Yu Yang 已提交
85 86
  Scope *scope_;
  platform::Place place_;
Y
Yu Yang 已提交
87

Y
Yu Yang 已提交
88 89
  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                               platform::Place place)
Y
Yu Yang 已提交
90
      : op_(framework::OpRegistry::CreateOp(op_desc)),
Y
Yu Yang 已提交
91
        scope_(scope),
Y
Yu Yang 已提交
92 93 94 95
        place_(place) {}

  void Run() override {
    // Wait other op if necessary
Y
Yu Yang 已提交
96
    LOG(INFO) << "Run " << this << " " << DebugString();
Y
Yu Yang 已提交
97 98 99 100 101 102 103 104
    auto *cur_ctx = dev_ctx_[place_];
    for (auto *in : inputs_) {
      if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
        in->generated_op_->Wait();
      }
    }

    op_->Run(*scope_, place_);
Y
Yu Yang 已提交
105
    LOG(INFO) << "Done " << this;
Y
Yu Yang 已提交
106
  }
Y
Yu Yang 已提交
107 108
};

Y
Yu Yang 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
struct ScaleLossGradOpHandle : public OpHandle {
  float coeff_;
  Scope *scope_;
  platform::Place place_;

  explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
                                 platform::Place place)
      : coeff_(static_cast<float>(1.0 / num_dev)),
        scope_(scope),
        place_(place) {}

  void Run() override {
    LOG(INFO) << "Run Scale Loss Grad";

    std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
Y
Yu Yang 已提交
124

Y
Yu Yang 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
    float *tmp = scope_->FindVar(var_name)
                     ->GetMutable<framework::LoDTensor>()
                     ->mutable_data<float>(make_ddim({1}), place_);

    if (platform::is_cpu_place(place_)) {
      *tmp = coeff_;
    } else {
      memory::Copy(
          boost::get<platform::CUDAPlace>(place_), tmp, platform::CPUPlace(),
          &coeff_, sizeof(float),
          static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
              ->stream());
    }
  }
};

struct NCCLAllReduceOpHandle : public OpHandle {
  void Run() override {
    if (this->inputs_.size() == 1) {
      return;  // No need to all reduce when GPU count = 1;
    }
  }
};
Y
Yu Yang 已提交
148 149 150

class ParallelExecutorPrivate {
 public:
Y
Yu Yang 已提交
151 152 153
  explicit ParallelExecutorPrivate(size_t num_threads = 12)
      : pool_(num_threads) {}

Y
Yu Yang 已提交
154 155
  std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
      local_scopes_;
Y
Yu Yang 已提交
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200

#ifdef PADDLE_WITH_CUDA
  struct NCCLContext {
    std::unique_ptr<platform::CUDADeviceContext> ctx_;
    ncclComm_t comm;

    explicit NCCLContext(int dev_id) {
      ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id)));
    }

    cudaStream_t stream() const { return ctx_->stream(); }

    int device_id() const {
      return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
    }

    static void InitNCCLContext(std::map<int, NCCLContext> &contexts) {
      std::vector<ncclComm_t> comms;
      std::vector<int> devs;
      comms.resize(contexts.size());
      devs.reserve(contexts.size());

      for (auto &ctx : contexts) {
        devs.push_back(ctx.first);
      }

      NCCL_INVOKE(platform::dynload::ncclCommInitAll(
          &comms[0], static_cast<int>(contexts.size()), &devs[0]));

      int i = 0;
      for (auto &ctx : contexts) {
        ctx.second.comm = comms[i++];
      }
    }
  };

  std::map<int, NCCLContext> communication_streams_;

  NCCLContext &GetNCCLCtx(platform::Place p) {
    int dev_id = boost::get<platform::CUDAPlace>(p).device;
    return communication_streams_.at(dev_id);
  }

#endif

Y
Yu Yang 已提交
201 202 203 204 205 206 207 208 209 210 211 212 213
  platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
    if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
      return const_cast<platform::DeviceContext *>(
          platform::DeviceContextPool::Instance().Get(place));
    } else {
#ifdef PADDLE_WITH_CUDA
      return GetNCCLCtx(place).ctx_.get();
#else
      PADDLE_THROW("Not compiled with CUDA")
#endif
    }
  }

Y
Yu Yang 已提交
214 215 216 217 218 219
  platform::Place main_place_;

  std::unordered_map<platform::Place,
                     std::unordered_map<std::string, std::map<int, VarHandle>>,
                     platform::PlaceHash>
      vars_;
Y
Yu Yang 已提交
220 221
  std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;

Y
Yu Yang 已提交
222
  std::vector<std::unique_ptr<OpHandle>> ops_;
Y
Yu Yang 已提交
223

Y
Yu Yang 已提交
224
  // Use a simpler thread pool, might be faster.
Y
Yu Yang 已提交
225
  ThreadPool pool_;
Y
Yu Yang 已提交
226 227

  std::unique_ptr<platform::EnforceNotMet> exception_;
Y
Yu Yang 已提交
228 229 230 231
};

#ifdef PADDLE_WITH_CUDA
// TODO(yy): Move this function somewhere
/**
 * Maps a C++ element type to the matching NCCL data type.
 *
 * Supports float, double and int; any other type throws via PADDLE_THROW.
 * The definition is guarded because it uses an NCCL type and, in this file,
 * is only called from CUDA-only code.
 */
ncclDataType_t ToNCCLDataType(std::type_index type) {
  if (type == typeid(float)) {  // NOLINT
    return ncclFloat;
  } else if (type == typeid(double)) {  // NOLINT
    return ncclDouble;
  } else if (type == typeid(int)) {  // NOLINT
    return ncclInt;
  } else {
    PADDLE_THROW("Not supported");
  }
}
#endif

/**
 * Builds a parallel executor.
 *
 * @param places          Places to run on. Must not be empty; places[0] runs
 *                        the startup program and becomes the main place.
 * @param params          Names of the parameters; their gradients are the
 *                        variables that get an all-reduce node.
 * @param startup_program Program run once on places[0] before the graph is
 *                        built.
 * @param main_program    Program converted into the dependency graph.
 * @param loss_var_name   Name of the loss variable.
 * @param scope           Scope the startup program runs in and the parent of
 *                        every per-place scope.
 */
ParallelExecutor::ParallelExecutor(
    const std::vector<platform::Place> &places,
    const std::unordered_set<std::string> &params,
    const ProgramDesc &startup_program, const ProgramDesc &main_program,
    const std::string &loss_var_name, Scope *scope)
    : member_(new ParallelExecutorPrivate()) {
  // places[0] is read below, so reject an empty list up front.
  // NOTE(review): if member_ is a raw owning pointer, an exception thrown
  // from this constructor leaks it -- confirm against the header.
  if (places.empty()) {
    PADDLE_THROW("ParallelExecutor requires at least one place");
  }

  // Step 1. RunStartupProgram and Bcast the params to devs.
  Executor exe(places[0]);
  exe.Run(startup_program, scope, 0);
  // Create local scopes
  for (auto &place : places) {
    member_->local_scopes_[place] = &scope->NewScope();
  }
  member_->main_place_ = places[0];

  // Bcast Parameters to all GPUs
  if (platform::is_gpu_place(member_->main_place_) &&
      member_->local_scopes_.size() != 1) {  // Is CUDA
    BuildNCCLCommunicator();
    BCastParamsToGPUs(startup_program);
  }
  // Startup Program has been run. All local scopes has correct parameters.

  // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
  // ncclOp
  ConstructDependencyGraph(params, main_program, loss_var_name);

  // Step 3. Create vars in each scope;
  for (auto &pair : member_->local_scopes_) {
    // Named local_scope so it does not shadow the `scope` parameter.
    auto *local_scope = pair.second;

    for (auto *var : main_program.Block(0).AllVars()) {
      if (local_scope->FindVar(var->Name()) != nullptr) {
        continue;
      }

      InitializeVariable(local_scope->Var(var->Name()), var->GetType());
    }
  }
}

void ParallelExecutor::ConstructDependencyGraph(
    const std::unordered_set<std::string> &params,
    const ProgramDesc &main_program, const std::string &loss_var_name) const {
Y
Yu Yang 已提交
287
  std::unordered_set<std::string> grads;
Y
Yu Yang 已提交
288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
  for (auto &each_param : params) {
    grads.insert(each_param + "@GRAD");
  }

  bool is_forwarding = true;
  for (auto *op : main_program.Block(0).AllOps()) {
    bool change_forward = false;

    if (!is_forwarding) {
      // FIXME(yy): Do not hard code like this
      if (op->OutputArgumentNames().size() == 1 &&
          op->OutputArgumentNames()[0] == loss_var_name + "@GRAD") {
        continue;  // Drop fill 1. for backward coeff;
      }
    }

    for (auto &pair : member_->local_scopes_) {
Y
Yu Yang 已提交
305 306
      member_->ops_.emplace_back(
          new ComputationOpHandle(*op, pair.second, pair.first));
Y
Yu Yang 已提交
307
      auto *op_handle = member_->ops_.back().get();
Y
Yu Yang 已提交
308 309
      op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
          platform::DeviceContextPool::Instance().Get(pair.first));
Y
Yu Yang 已提交
310 311 312 313 314 315 316

      auto var_names = op->InputArgumentNames();

      for (auto &each_var_name : var_names) {
        auto &place = pair.first;
        VarHandle *var = GetVarHandle(each_var_name, place);
        op_handle->inputs_.emplace_back(var);
Y
Yu Yang 已提交
317
        var->pending_ops_.emplace_back(op_handle);
Y
Yu Yang 已提交
318 319 320 321 322 323 324 325 326 327 328
      }
      var_names = op->OutputArgumentNames();

      for (auto &each_var_name : var_names) {
        auto &place = pair.first;
        GenerateVar(op_handle, each_var_name, place);
      }

      if (is_forwarding) {
        if (var_names.size() == 1 && var_names[0] == loss_var_name) {
          // Insert ScaleCost OpHandle
Y
Yu Yang 已提交
329 330
          member_->ops_.emplace_back(new ScaleLossGradOpHandle(
              this->member_->local_scopes_.size(), pair.second, pair.first));
Y
Yu Yang 已提交
331
          op_handle = member_->ops_.back().get();
Y
Yu Yang 已提交
332 333 334 335

          op_handle->dev_ctx_[pair.first] =
              member_->CommunicationDevCtx(pair.first);

Y
Yu Yang 已提交
336
          auto &place = pair.first;
Y
Yu Yang 已提交
337 338 339 340 341 342
          // FIXME: Currently ScaleLossGradOp only use device_count as scale
          // factor. So it does not depend on any other operators.
          // VarHandle *loss = GetVarHandle(loss_var_name, place);
          // loss->pending_ops_.emplace_back(op_handle);
          // op_handle->inputs_.emplace_back(loss);

Y
Yu Yang 已提交
343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370
          GenerateVar(op_handle, loss_var_name + "@GRAD", place);
          change_forward = true;
          LOG(INFO) << "Scale Loss " << op_handle->DebugString();
        }
      }
    }

    if (change_forward) {
      is_forwarding = false;
    }

    if (!is_forwarding) {
      auto var_names = op->OutputArgumentNames();
      for (auto &og : var_names) {
        if (grads.count(og) != 0) {  // is param grad
          // Insert NCCL AllReduce Op
          member_->ops_.emplace_back(new NCCLAllReduceOpHandle());
          auto *op_handle = member_->ops_.back().get();

          for (auto &pair : member_->local_scopes_) {
            auto &place = pair.first;
            auto &vars = member_->vars_[place][og];

            if (vars.empty()) {  // This device has no data. continue.
              continue;
            }
            auto *prev_grad = &vars[vars.size() - 1];
            op_handle->inputs_.emplace_back(prev_grad);
Y
Yu Yang 已提交
371
            prev_grad->pending_ops_.emplace_back(op_handle);
Y
Yu Yang 已提交
372 373 374 375 376 377
            auto &var = vars[vars.size()];
            var.place_ = place;
            var.generated_op_ = op_handle;
            var.name_ = og;
            var.version_ = vars.size() - 1;
            op_handle->outputs_.emplace_back(&var);
Y
Yu Yang 已提交
378 379 380 381 382

            for (auto &pair : member_->local_scopes_) {
              op_handle->dev_ctx_[pair.first] =
                  member_->CommunicationDevCtx(pair.first);
            }
Y
Yu Yang 已提交
383 384 385 386 387
          }
        }
      }
    }
  }
Y
Yu Yang 已提交
388

Y
Yu Yang 已提交
389 390 391
  /*
    Dependency graph has been constructed. However, there are still data
    harzaeds need to be handled.
Y
Yu Yang 已提交
392
   */
Y
Yu Yang 已提交
393 394
  PolishGraphToSupportDataHarzaeds();
}
Y
Yu Yang 已提交
395

Y
Yu Yang 已提交
396 397 398 399 400 401 402 403
/**
 * We only handle write after read(WAR), since it should not have a write
 * after write in program. If there are write after write operators, we need
 * prune them.
 *
 * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
 */
void ParallelExecutor::PolishGraphToSupportDataHarzaeds() const {
Y
Yu Yang 已提交
404 405 406 407 408 409 410 411 412 413 414
  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      if (name_pair.second.size() <= 1) {
        return;
      }
      auto it_new = name_pair.second.rbegin();
      auto it_old = name_pair.second.rbegin();
      ++it_old;
      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
        auto *write_op = it_new->second.generated_op_;
        auto &read_ops = it_old->second.pending_ops_;
Y
Yu Yang 已提交
415 416 417 418 419 420 421 422 423
        auto *ex_write_op = it_old->second.generated_op_;

        if (ex_write_op == nullptr) {  // Nobody write this var.
          continue;
        }

        LOG(INFO) << "Link " << it_new->second.DebugString() << " From "
                  << it_old->second.version_ << " To "
                  << it_new->second.version_;
Y
Yu Yang 已提交
424 425 426

        for (auto *read_op : read_ops) {
          // Manually add a dependency var from read_op to write_op;
Y
Yu Yang 已提交
427 428 429 430
          if (read_op == write_op) {
            // Read Write is the same op.
            continue;
          }
Y
Yu Yang 已提交
431 432

          auto *dep_var = new DependencyVarHandle();
Y
Yu Yang 已提交
433

Y
Yu Yang 已提交
434 435 436 437 438 439 440 441 442 443
          dep_var->generated_op_ = read_op;
          read_op->outputs_.emplace_back(dep_var);

          dep_var->pending_ops_.emplace_back(write_op);
          write_op->inputs_.emplace_back(dep_var);
          member_->dep_vars_.emplace(dep_var);
        }
      }
    }
  }
Y
Yu Yang 已提交
444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478
}

/**
 * Appends a new version of `each_var_name` on `place` and records
 * `op_handle` as the operator that writes it.
 */
void ParallelExecutor::GenerateVar(OpHandle *op_handle,
                                   const std::string &each_var_name,
                                   const platform::Place &place) const {
  auto &versions = member_->vars_[place][each_var_name];
  // The size is taken before the insertion, so it is the new version number.
  const size_t next_version = versions.size();

  auto &fresh = versions[next_version];
  fresh.place_ = place;
  fresh.name_ = each_var_name;
  fresh.generated_op_ = op_handle;
  fresh.version_ = next_version;

  op_handle->outputs_.emplace_back(&fresh);
}

/**
 * Returns the newest version of `each_var_name` on `place`. If the variable
 * has no version yet, version 0 is created with no generating operator.
 */
VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
                                          const platform::Place &place) const {
  auto &versions = member_->vars_[place][each_var_name];

  if (!versions.empty()) {
    // The map is ordered by version, so the last entry is the newest one.
    return &versions.rbegin()->second;
  }

  auto &first = versions[0];
  first.place_ = place;
  first.name_ = each_var_name;
  first.generated_op_ = nullptr;
  first.version_ = 0;
  return &first;
}

void ParallelExecutor::BCastParamsToGPUs(
    const ProgramDesc &startup_program) const {
Y
Yu Yang 已提交
479
#ifdef PADDLE_WITH_CUDA
Y
Yu Yang 已提交
480
  auto *main_scope = member_->local_scopes_[member_->main_place_];
Y
Yu Yang 已提交
481

Y
Yu Yang 已提交
482 483 484 485 486 487 488 489
  for (auto *var_desc : startup_program.Block(0).AllVars()) {
    if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
      auto &main_tensor =
          main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
      ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
      auto &dims = main_tensor.dims();
      size_t numel = main_tensor.numel();

Y
Stash  
Yu Yang 已提交
490
      platform::dynload::ncclGroupStart();
Y
Yu Yang 已提交
491

Y
Stash  
Yu Yang 已提交
492
      for (auto &pair : member_->local_scopes_) {
Y
Yu Yang 已提交
493 494 495
        auto local_scope = pair.second;
        auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
        t->Resize(dims);
Y
Stash  
Yu Yang 已提交
496 497 498 499
        auto &nccl_ctx = member_->GetNCCLCtx(pair.first);
        platform::dynload::ncclBcast(
            t->mutable_data(pair.first, main_tensor.type()), numel, data_type,
            0, nccl_ctx.comm, nccl_ctx.stream());
Y
Yu Yang 已提交
500
      }
Y
Stash  
Yu Yang 已提交
501 502 503
      platform::dynload::ncclGroupEnd();
    }
  }
Y
Yu Yang 已提交
504

Y
Stash  
Yu Yang 已提交
505 506
  for (auto &pair : member_->local_scopes_) {
    member_->GetNCCLCtx(pair.first).ctx_->Wait();
Y
Yu Yang 已提交
507

Y
Stash  
Yu Yang 已提交
508 509 510 511 512
    auto &b = pair.second->FindVar("fc_1.b_0")->Get<framework::LoDTensor>();
    framework::LoDTensor cpu;
    framework::TensorCopy(b, platform::CPUPlace(), &cpu);
    platform::DeviceContextPool::Instance().Get(b.place())->Wait();
    LOG(INFO) << *cpu.data<float>();
Y
Yu Yang 已提交
513
  }
Y
Stash  
Yu Yang 已提交
514

Y
Yu Yang 已提交
515 516 517 518
#else
  PADDLE_THROW("Not compiled with CUDA");
#endif
}
Y
Yu Yang 已提交
519

Y
Yu Yang 已提交
520 521 522 523 524
/**
 * Creates one NCCL context per local scope, keyed by CUDA device id, and then
 * initializes their communicators. Does nothing in builds without CUDA.
 */
void ParallelExecutor::BuildNCCLCommunicator() const {
#ifdef PADDLE_WITH_CUDA
  using NCCLContext = ParallelExecutorPrivate::NCCLContext;
  auto &streams = member_->communication_streams_;

  for (auto &scope_entry : member_->local_scopes_) {
    const int device =
        boost::get<platform::CUDAPlace>(scope_entry.first).device;
    streams.emplace(device, NCCLContext(device));
  }

  NCCLContext::InitNCCLContext(streams);
#endif
}

std::vector<LoDTensor> ParallelExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // Version --> VarHandle
Y
Yu Yang 已提交
538
  member_->exception_.reset();
Y
Yu Yang 已提交
539
  std::unordered_map<VarHandleBase *, bool> pending_vars;
Y
Yu Yang 已提交
540 541 542 543 544
  std::unordered_map<OpHandle *, size_t> pending_ops;

  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      for (auto &version_pair : name_pair.second) {
Y
Yu Yang 已提交
545 546
        pending_vars[&version_pair.second] =
            version_pair.second.generated_op_ == nullptr;
Y
Yu Yang 已提交
547 548 549 550
      }
    }
  }

Y
Yu Yang 已提交
551 552 553 554
  for (auto &var : member_->dep_vars_) {
    pending_vars[var.get()] = var->generated_op_ == nullptr;
  }

Y
Yu Yang 已提交
555 556
  std::vector<OpHandle *> to_run;

Y
Yu Yang 已提交
557
  for (auto &op : member_->ops_) {
Y
Yu Yang 已提交
558 559 560 561 562 563 564 565 566
    if (op->inputs_.empty()) {  // Special case, Op has no input.
      to_run.emplace_back(op.get());
    } else {
      pending_ops.insert({op.get(), op->inputs_.size()});
    }
  }

  for (auto *op : to_run) {
    RunOp(pending_vars, op);
Y
Yu Yang 已提交
567 568
  }

Y
Yu Yang 已提交
569
  while (!pending_ops.empty()) {
Y
Yu Yang 已提交
570
    VarHandleBase *ready_var = nullptr;
Y
Yu Yang 已提交
571 572 573
    for (auto &pair : pending_vars) {
      if (pair.second) {
        ready_var = pair.first;
Y
Yu Yang 已提交
574 575
      }
    }
Y
Yu Yang 已提交
576 577

    if (ready_var == nullptr) {
Y
Yu Yang 已提交
578 579 580 581 582 583 584
      // FIXME use conditional var instead of busy wait.

      if (member_->exception_) {
        throw * member_->exception_;
      }

      std::this_thread::yield();
Y
Yu Yang 已提交
585
      continue;
Y
Yu Yang 已提交
586 587
    }

Y
Yu Yang 已提交
588 589
    pending_vars.erase(ready_var);

Y
Yu Yang 已提交
590
    to_run.clear();
Y
Yu Yang 已提交
591 592 593 594 595 596

    for (auto *op : ready_var->pending_ops_) {
      auto &deps = pending_ops[op];
      --deps;
      if (deps == 0) {
        to_run.emplace_back(op);
Y
Yu Yang 已提交
597 598 599 600 601
      }
    }

    for (auto *op : to_run) {
      pending_ops.erase(op);
Y
Yu Yang 已提交
602
      RunOp(pending_vars, op);
Y
Yu Yang 已提交
603 604 605 606
    }
  }
  return std::vector<LoDTensor>();
}
Y
Yu Yang 已提交
607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631

/**
 * Schedules `op` on the thread pool.
 *
 * When the operator finishes, the ready flag of each of its outputs in
 * `pending_vars` is set to true. A platform::EnforceNotMet thrown by the
 * operator is stored in member_->exception_; any other exception aborts
 * through LOG(FATAL).
 *
 * @param pending_vars Variable -> ready flag map owned by the caller. The
 *                     scheduled task keeps pointers to its mapped values.
 * @param op           Operator to run.
 */
void ParallelExecutor::RunOp(
    std::unordered_map<VarHandleBase *, bool> &pending_vars,
    OpHandle *op) const {
  // Resolve the flags now, on the calling thread, so the worker never touches
  // the map structure itself.
  std::vector<bool *> ready_buffer;
  for (auto *var : op->outputs_) {
    ready_buffer.emplace_back(&pending_vars[var]);
  }

  auto op_run = [ready_buffer, op, this] {
    try {
      // TODO(yy) Check Previous Op has same dev ctx.
      op->Run();
      for (auto *ready : ready_buffer) {
        *ready = true;
      }
    } catch (const platform::EnforceNotMet &ex) {
      // Caught by const reference so no extra copy is made before the one
      // stored for the scheduling thread.
      member_->exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception catched";
    }
  };

  member_->pool_.enqueue(op_run);
}
Y
Yu Yang 已提交
632
}  // namespace framework
Y
Yang Yang 已提交
633
}  // namespace paddle