parallel_executor.cc 22.8 KB
Newer Older
Y
Yang Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/parallel_executor.h"
Y
Yu Yang 已提交
16 17
#include "ThreadPool.h"
#include "executor.h"
Y
Yu Yang 已提交
18
#include "lod_tensor.h"
Y
Yu Yang 已提交
19
#include "lod_tensor_array.h"
Y
Yu Yang 已提交
20
#include "op_registry.h"
Y
Yu Yang 已提交
21
#include "paddle/fluid/operators/math/concat.h"
Y
Yang Yang 已提交
22 23

namespace paddle {
Y
Yu Yang 已提交
24 25
namespace framework {

Y
Yu Yang 已提交
26 27 28 29 30 31
#ifdef PADDLE_WITH_CUDA

// Invoke an NCCL API call and turn a non-success ncclResult_t into a
// platform::EnforceNotMet (via PADDLE_ENFORCE) instead of silently ignoring it.
#define NCCL_INVOKE(x) PADDLE_ENFORCE(x)
#endif

Y
Yu Yang 已提交
32 33
struct OpHandle;

Y
Yu Yang 已提交
34 35 36 37 38
struct VarHandleBase {
  virtual ~VarHandleBase() {}
  virtual std::string DebugString() const = 0;

  OpHandle *generated_op_;
Y
Yu Yang 已提交
39
  std::unordered_set<OpHandle *> pending_ops_;
Y
Yu Yang 已提交
40 41 42 43 44 45 46 47 48
};

// A concrete, versioned variable that lives on one place. Each write to a
// variable creates a new version of it in the graph.
struct VarHandle : public VarHandleBase {
  std::string DebugString() const override {
    std::stringstream ss;
    ss << name_ << ":" << place_;
    return ss.str();
  }

  // version field currently is not used, however, just store the version to
  // debug easily.
  size_t version_{0};
  std::string name_;
  platform::Place place_;
};
Y
Yu Yang 已提交
55

Y
Yu Yang 已提交
56
struct DependencyVarHandle : public VarHandleBase {
Y
Yu Yang 已提交
57
  std::string DebugString() const override { return "Dependency Variable"; }
Y
Yu Yang 已提交
58 59 60
};

struct OpHandle {
Y
Yu Yang 已提交
61 62 63 64 65
  std::vector<VarHandleBase *> inputs_;
  std::vector<VarHandleBase *> outputs_;
  std::unordered_map<platform::Place, platform::DeviceContext *,
                     platform::PlaceHash>
      dev_ctx_;
Y
Yu Yang 已提交
66 67 68 69 70

  std::string DebugString() {
    std::stringstream ss;
    ss << "(";
    for (auto *var : inputs_) {
Y
Yu Yang 已提交
71
      ss << var->DebugString() << ", ";
Y
Yu Yang 已提交
72 73 74
    }
    ss << ") --> (";
    for (auto *var : outputs_) {
Y
Yu Yang 已提交
75
      ss << var->DebugString() << ", ";
Y
Yu Yang 已提交
76 77 78 79 80 81
    }
    ss << ")\n";
    return ss.str();
  }

  virtual ~OpHandle() {}
Y
Yu Yang 已提交
82

Y
Yu Yang 已提交
83
  virtual void Run() { PADDLE_THROW("Not implemented"); }
Y
Yu Yang 已提交
84
  virtual void Wait(platform::DeviceContext *waited_dev) {}
Y
Yu Yang 已提交
85 86 87 88
};

struct ComputationOpHandle : public OpHandle {
  std::unique_ptr<OperatorBase> op_;
Y
Yu Yang 已提交
89 90
  Scope *scope_;
  platform::Place place_;
Y
Yu Yang 已提交
91

Y
Yu Yang 已提交
92 93
  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                               platform::Place place)
Y
Yu Yang 已提交
94
      : op_(framework::OpRegistry::CreateOp(op_desc)),
Y
Yu Yang 已提交
95
        scope_(scope),
Y
Yu Yang 已提交
96 97 98 99
        place_(place) {}

  void Run() override {
    // Wait other op if necessary
Y
Yu Yang 已提交
100 101 102 103
    if (platform::is_gpu_place(place_)) {
      int dev_id = boost::get<platform::CUDAPlace>(place_).device;
      cudaSetDevice(dev_id);
    }
Y
Yu Yang 已提交
104 105 106
    auto *cur_ctx = dev_ctx_[place_];
    for (auto *in : inputs_) {
      if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
Y
Yu Yang 已提交
107
        in->generated_op_->Wait(cur_ctx);
Y
Yu Yang 已提交
108 109 110 111 112
      }
    }

    op_->Run(*scope_, place_);
  }
Y
Yu Yang 已提交
113 114 115 116

  void Wait(platform::DeviceContext *waited_dev) override {
    this->dev_ctx_.at(place_)->Wait();
  }
Y
Yu Yang 已提交
117 118
};

Y
Yu Yang 已提交
119 120 121 122 123 124 125 126 127 128 129 130 131
struct ScaleLossGradOpHandle : public OpHandle {
  float coeff_;
  Scope *scope_;
  platform::Place place_;

  explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
                                 platform::Place place)
      : coeff_(static_cast<float>(1.0 / num_dev)),
        scope_(scope),
        place_(place) {}

  void Run() override {
    std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
Y
Yu Yang 已提交
132

Y
Yu Yang 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145 146
    float *tmp = scope_->FindVar(var_name)
                     ->GetMutable<framework::LoDTensor>()
                     ->mutable_data<float>(make_ddim({1}), place_);

    if (platform::is_cpu_place(place_)) {
      *tmp = coeff_;
    } else {
      memory::Copy(
          boost::get<platform::CUDAPlace>(place_), tmp, platform::CPUPlace(),
          &coeff_, sizeof(float),
          static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
              ->stream());
    }
  }
Y
Yu Yang 已提交
147 148 149 150

  void Wait(platform::DeviceContext *waited_dev) override {
    this->dev_ctx_.at(place_)->Wait();
  }
Y
Yu Yang 已提交
151 152
};

Y
Yu Yang 已提交
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
// Result buffer shared by the fetch ops of one ParallelExecutor::Run; slot i
// receives the merged value of the i-th fetched variable.
struct FetchedData {
 public:
  std::vector<framework::LoDTensor> tensors_;

  explicit FetchedData(size_t num_fetched) : tensors_(num_fetched) {}
};

struct FetchOpHandle : public OpHandle {
  std::shared_ptr<FetchedData> data_;
  size_t offset_;
  std::vector<Scope *> *local_scopes_;
  std::vector<LoDTensor> tensors_;

  ~FetchOpHandle() {
    for (auto *input_var : inputs_) {
      input_var->pending_ops_.erase(this);
    }
    for (auto &pair : dev_ctx_) {
      pair.second->Wait();
    }

    // Lazily merge tensors. Will faster code.
    MergeTensors();
  }

  void Run() override {
Y
Debug  
Yu Yang 已提交
179 180 181 182
    for (auto *input : inputs_) {
      input->generated_op_->Wait(nullptr);
    }

Y
Yu Yang 已提交
183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
    tensors_.resize(inputs_.size());
    auto *var = static_cast<VarHandle *>(inputs_[0]);
    auto &var_name = var->name_;
    platform::CPUPlace cpu;
    auto &scopes = *local_scopes_;

    for (size_t i = 0; i < scopes.size(); ++i) {
      auto &scope = scopes[i];
      auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
      if (platform::is_gpu_place(var->place_)) {
        TensorCopy(t, cpu, *dev_ctx_[t.place()], &tensors_[i]);
      } else {
        tensors_[i].ShareDataWith(t);
        tensors_[i].set_lod(t.lod());
      }
    }
  }

  void Wait(platform::DeviceContext *waited_dev) override {
    PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
  }

 private:
  void MergeTensors() const {
    std::vector<const LoDTensor *> tensors_ptr;
    for (auto &t : tensors_) {
      tensors_ptr.emplace_back(&t);
    }
    data_->tensors_[offset_].MergeLoDTensor(tensors_ptr, platform::CPUPlace());
  }
};

Y
Yu Yang 已提交
215 216
class ParallelExecutorPrivate {
 public:
Y
Yu Yang 已提交
217 218 219
  explicit ParallelExecutorPrivate(size_t num_threads = 12)
      : pool_(num_threads) {}

Y
Stash  
Yu Yang 已提交
220 221
  std::vector<platform::Place> places_;

Y
Yu Yang 已提交
222
  std::vector<Scope *> local_scopes_;
Y
Yu Yang 已提交
223
  Scope *global_scope_;
Y
Yu Yang 已提交
224

Y
Yu Yang 已提交
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
#ifdef PADDLE_WITH_CUDA
  struct NCCLContext {
    std::unique_ptr<platform::CUDADeviceContext> ctx_;
    ncclComm_t comm;

    explicit NCCLContext(int dev_id) {
      ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id)));
    }

    cudaStream_t stream() const { return ctx_->stream(); }

    int device_id() const {
      return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
    }

Y
Update  
Yu Yang 已提交
240 241
    static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
                                const std::vector<platform::Place> &places) {
Y
Yu Yang 已提交
242 243 244 245 246
      std::vector<ncclComm_t> comms;
      std::vector<int> devs;
      comms.resize(contexts.size());
      devs.reserve(contexts.size());

Y
Update  
Yu Yang 已提交
247 248
      for (auto &p : places) {
        devs.push_back(boost::get<platform::CUDAPlace>(p).device);
Y
Yu Yang 已提交
249 250 251 252 253 254
      }

      NCCL_INVOKE(platform::dynload::ncclCommInitAll(
          &comms[0], static_cast<int>(contexts.size()), &devs[0]));

      int i = 0;
Y
Update  
Yu Yang 已提交
255 256
      for (auto &dev_id : devs) {
        contexts.at(dev_id).comm = comms[i++];
Y
Yu Yang 已提交
257 258 259 260
      }
    }
  };

Y
Update  
Yu Yang 已提交
261
  std::unordered_map<int, NCCLContext> communication_streams_;
Y
Yu Yang 已提交
262 263 264 265 266 267 268 269

  NCCLContext &GetNCCLCtx(platform::Place p) {
    int dev_id = boost::get<platform::CUDAPlace>(p).device;
    return communication_streams_.at(dev_id);
  }

#endif

Y
Yu Yang 已提交
270 271 272 273 274 275 276 277 278 279 280 281 282
  platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
    if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
      return const_cast<platform::DeviceContext *>(
          platform::DeviceContextPool::Instance().Get(place));
    } else {
#ifdef PADDLE_WITH_CUDA
      return GetNCCLCtx(place).ctx_.get();
#else
      PADDLE_THROW("Not compiled with CUDA")
#endif
    }
  }

Y
Yu Yang 已提交
283 284 285 286 287 288
  platform::Place main_place_;

  std::unordered_map<platform::Place,
                     std::unordered_map<std::string, std::map<int, VarHandle>>,
                     platform::PlaceHash>
      vars_;
Y
Yu Yang 已提交
289 290
  std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;

Y
Yu Yang 已提交
291
  std::vector<std::unique_ptr<OpHandle>> ops_;
Y
Yu Yang 已提交
292

Y
Yu Yang 已提交
293
  // Use a simpler thread pool, might be faster.
Y
Yu Yang 已提交
294
  ThreadPool pool_;
Y
Yu Yang 已提交
295 296

  std::unique_ptr<platform::EnforceNotMet> exception_;
Y
Yu Yang 已提交
297 298 299 300
};

// TODO(yy): Move this function somewhere
// Maps a tensor element type to the matching NCCL data type. Only float,
// double and int are handled; any other type throws "Not supported".
ncclDataType_t ToNCCLDataType(std::type_index type) {
  if (type == typeid(float)) return ncclFloat;    // NOLINT
  if (type == typeid(double)) return ncclDouble;  // NOLINT
  if (type == typeid(int)) return ncclInt;        // NOLINT
  PADDLE_THROW("Not supported");
}

Y
Yu Yang 已提交
312 313 314 315 316 317 318 319 320 321 322 323 324 325 326
struct NCCLAllReduceOpHandle : public OpHandle {
  ParallelExecutorPrivate *member_;

  explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
      : member_(member) {}

  void Run() override {
    if (this->inputs_.size() == 1) {
      return;  // No need to all reduce when GPU count = 1;
    } else {
      auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;

      int dtype = -1;
      size_t numel = 0;

Y
Update  
Yu Yang 已提交
327 328
      platform::dynload::ncclGroupStart();

Y
Yu Yang 已提交
329 330 331
      for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
        auto &p = member_->places_[i];
        auto *s = member_->local_scopes_[i];
Y
Yu Yang 已提交
332 333 334 335 336 337 338 339 340 341 342 343 344 345
        int dev_id = boost::get<platform::CUDAPlace>(p).device;

        auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
        void *buffer = const_cast<void *>(lod_tensor.data<void>());
        if (dtype == -1) {
          dtype = ToNCCLDataType(lod_tensor.type());
        }

        if (numel == 0) {
          numel = static_cast<size_t>(lod_tensor.numel());
        }

        auto &nccl_ctx = member_->communication_streams_.at(dev_id);

Y
Update  
Yu Yang 已提交
346 347 348
        platform::dynload::ncclAllReduce(
            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
            nccl_ctx.comm, nccl_ctx.stream());
Y
Yu Yang 已提交
349 350
      }

Y
Update  
Yu Yang 已提交
351
      platform::dynload::ncclGroupEnd();
Y
Yu Yang 已提交
352 353
    }
  }
Y
Yu Yang 已提交
354 355

  void Wait(platform::DeviceContext *waited_dev) override {
Y
Debug  
Yu Yang 已提交
356 357 358
    for (auto &pair : member_->communication_streams_) {
      pair.second.ctx_->Wait();
    }
Y
Yu Yang 已提交
359
  }
Y
Yu Yang 已提交
360 361
};

Y
Yu Yang 已提交
362 363 364 365 366 367
ParallelExecutor::ParallelExecutor(
    const std::vector<platform::Place> &places,
    const std::unordered_set<std::string> &params,
    const ProgramDesc &startup_program, const ProgramDesc &main_program,
    const std::string &loss_var_name, Scope *scope)
    : member_(new ParallelExecutorPrivate()) {
  // places[0] is used below as the main place.
  PADDLE_ENFORCE(!places.empty(), "ParallelExecutor needs at least one place");
  member_->places_ = places;
  member_->global_scope_ = scope;

  // Step 1. RunStartupProgram and Bcast the params to devs.
  Executor exe(places[0]);
  exe.Run(startup_program, scope, 0);
  // Create one local scope (child of `scope`) per place.
  for (size_t i = 0; i < member_->places_.size(); ++i) {
    member_->local_scopes_.push_back(&scope->NewScope());
  }
  member_->main_place_ = places[0];

  // Bcast Parameters to all GPUs
  BuildNCCLCommunicator();
  if (platform::is_gpu_place(member_->main_place_) &&
      member_->local_scopes_.size() != 1) {  // Is CUDA
    BCastParamsToGPUs(startup_program);
  }
  // Startup Program has been run. All local scopes has correct parameters.

  // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
  // ncclOp
  ConstructDependencyGraph(params, main_program, loss_var_name);

  // Step 3. Create the remaining vars of the main program in each scope.
  for (auto *local_scope : member_->local_scopes_) {
    for (auto *var : main_program.Block(0).AllVars()) {
      if (local_scope->FindVar(var->Name()) != nullptr) {
        continue;
      }
      InitializeVariable(local_scope->Var(var->Name()), var->GetType());
    }
  }
}

void ParallelExecutor::ConstructDependencyGraph(
    const std::unordered_set<std::string> &params,
    const ProgramDesc &main_program, const std::string &loss_var_name) const {
Y
Yu Yang 已提交
406
  std::unordered_set<std::string> grads;
Y
Yu Yang 已提交
407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422
  for (auto &each_param : params) {
    grads.insert(each_param + "@GRAD");
  }

  bool is_forwarding = true;
  for (auto *op : main_program.Block(0).AllOps()) {
    bool change_forward = false;

    if (!is_forwarding) {
      // FIXME(yy): Do not hard code like this
      if (op->OutputArgumentNames().size() == 1 &&
          op->OutputArgumentNames()[0] == loss_var_name + "@GRAD") {
        continue;  // Drop fill 1. for backward coeff;
      }
    }

Y
Yu Yang 已提交
423 424 425 426 427
    for (size_t i = 0; i < member_->places_.size(); ++i) {
      auto &p = member_->places_[i];
      auto *s = member_->local_scopes_[i];

      member_->ops_.emplace_back(new ComputationOpHandle(*op, s, p));
Y
Yu Yang 已提交
428
      auto *op_handle = member_->ops_.back().get();
Y
Yu Yang 已提交
429 430
      op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
          platform::DeviceContextPool::Instance().Get(p));
Y
Yu Yang 已提交
431 432 433 434

      auto var_names = op->InputArgumentNames();

      for (auto &each_var_name : var_names) {
Y
Yu Yang 已提交
435
        VarHandle *var = GetVarHandle(each_var_name, p);
Y
Yu Yang 已提交
436
        op_handle->inputs_.emplace_back(var);
Y
Yu Yang 已提交
437
        var->pending_ops_.emplace(op_handle);
Y
Yu Yang 已提交
438 439 440 441
      }
      var_names = op->OutputArgumentNames();

      for (auto &each_var_name : var_names) {
Y
Yu Yang 已提交
442
        GenerateVar(op_handle, each_var_name, p);
Y
Yu Yang 已提交
443 444 445 446 447
      }

      if (is_forwarding) {
        if (var_names.size() == 1 && var_names[0] == loss_var_name) {
          // Insert ScaleCost OpHandle
Y
Yu Yang 已提交
448
          member_->ops_.emplace_back(new ScaleLossGradOpHandle(
Y
Yu Yang 已提交
449
              this->member_->local_scopes_.size(), s, p));
Y
Yu Yang 已提交
450
          op_handle = member_->ops_.back().get();
Y
Yu Yang 已提交
451

Y
Yu Yang 已提交
452
          op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
Y
Yu Yang 已提交
453

Y
Yu Yang 已提交
454 455 456 457 458 459
          // FIXME: Currently ScaleLossGradOp only use device_count as scale
          // factor. So it does not depend on any other operators.
          // VarHandle *loss = GetVarHandle(loss_var_name, place);
          // loss->pending_ops_.emplace_back(op_handle);
          // op_handle->inputs_.emplace_back(loss);

Y
Yu Yang 已提交
460
          GenerateVar(op_handle, loss_var_name + "@GRAD", p);
Y
Yu Yang 已提交
461 462 463 464 465 466 467 468 469 470 471 472 473 474
          change_forward = true;
        }
      }
    }

    if (change_forward) {
      is_forwarding = false;
    }

    if (!is_forwarding) {
      auto var_names = op->OutputArgumentNames();
      for (auto &og : var_names) {
        if (grads.count(og) != 0) {  // is param grad
          // Insert NCCL AllReduce Op
Y
Yu Yang 已提交
475
          member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_));
Y
Yu Yang 已提交
476 477
          auto *op_handle = member_->ops_.back().get();

Y
Yu Yang 已提交
478 479 480
          for (size_t i = 0; i < member_->places_.size(); ++i) {
            auto &p = member_->places_[i];
            auto &vars = member_->vars_[p][og];
Y
Yu Yang 已提交
481 482 483 484 485 486

            if (vars.empty()) {  // This device has no data. continue.
              continue;
            }
            auto *prev_grad = &vars[vars.size() - 1];
            op_handle->inputs_.emplace_back(prev_grad);
Y
Yu Yang 已提交
487
            prev_grad->pending_ops_.emplace(op_handle);
Y
Yu Yang 已提交
488
            auto &var = vars[vars.size()];
Y
Yu Yang 已提交
489
            var.place_ = p;
Y
Yu Yang 已提交
490 491 492 493
            var.generated_op_ = op_handle;
            var.name_ = og;
            var.version_ = vars.size() - 1;
            op_handle->outputs_.emplace_back(&var);
Y
Yu Yang 已提交
494

Y
Yu Yang 已提交
495
            op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
Y
Yu Yang 已提交
496 497 498 499 500
          }
        }
      }
    }
  }
Y
Yu Yang 已提交
501

Y
Yu Yang 已提交
502 503 504
  /*
    Dependency graph has been constructed. However, there are still data
    harzaeds need to be handled.
Y
Yu Yang 已提交
505
   */
Y
Yu Yang 已提交
506 507
  PolishGraphToSupportDataHarzaeds();
}
Y
Yu Yang 已提交
508

Y
Yu Yang 已提交
509 510 511 512 513 514 515 516
/**
 * We only handle write after read(WAR), since it should not have a write
 * after write in program. If there are write after write operators, we need
 * prune them.
 *
 * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
 */
void ParallelExecutor::PolishGraphToSupportDataHarzaeds() const {
Y
Yu Yang 已提交
517 518 519 520 521 522 523 524 525 526 527
  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      if (name_pair.second.size() <= 1) {
        return;
      }
      auto it_new = name_pair.second.rbegin();
      auto it_old = name_pair.second.rbegin();
      ++it_old;
      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
        auto *write_op = it_new->second.generated_op_;
        auto &read_ops = it_old->second.pending_ops_;
Y
Yu Yang 已提交
528 529 530 531 532 533
        auto *ex_write_op = it_old->second.generated_op_;

        if (ex_write_op == nullptr) {  // Nobody write this var.
          continue;
        }

Y
Yu Yang 已提交
534 535
        for (auto *read_op : read_ops) {
          // Manually add a dependency var from read_op to write_op;
Y
Yu Yang 已提交
536 537 538 539
          if (read_op == write_op) {
            // Read Write is the same op.
            continue;
          }
Y
Yu Yang 已提交
540 541

          auto *dep_var = new DependencyVarHandle();
Y
Yu Yang 已提交
542

Y
Yu Yang 已提交
543 544 545
          dep_var->generated_op_ = read_op;
          read_op->outputs_.emplace_back(dep_var);

Y
Yu Yang 已提交
546
          dep_var->pending_ops_.emplace(write_op);
Y
Yu Yang 已提交
547 548 549 550 551 552
          write_op->inputs_.emplace_back(dep_var);
          member_->dep_vars_.emplace(dep_var);
        }
      }
    }
  }
Y
Yu Yang 已提交
553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587
}

// Records that `op_handle` writes `each_var_name` on `place`: appends a new
// version keyed by the current version count and registers it as an output.
void ParallelExecutor::GenerateVar(OpHandle *op_handle,
                                   const std::string &each_var_name,
                                   const platform::Place &place) const {
  auto &versions = member_->vars_[place][each_var_name];
  const size_t next_version = versions.size();
  VarHandle &handle = versions[next_version];
  handle.version_ = next_version;
  handle.name_ = each_var_name;
  handle.place_ = place;
  handle.generated_op_ = op_handle;
  op_handle->outputs_.push_back(&handle);
}

// Returns the newest version of `each_var_name` on `place`. If the variable
// has no version yet, creates version 0 with no producing op.
VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
                                          const platform::Place &place) const {
  auto &versions = member_->vars_[place][each_var_name];
  if (!versions.empty()) {
    return &versions.rbegin()->second;
  }

  VarHandle &initial = versions[0];
  initial.place_ = place;
  initial.name_ = each_var_name;
  initial.generated_op_ = nullptr;
  initial.version_ = 0;
  return &initial;
}

void ParallelExecutor::BCastParamsToGPUs(
    const ProgramDesc &startup_program) const {
Y
Yu Yang 已提交
588
#ifdef PADDLE_WITH_CUDA
Y
Yu Yang 已提交
589
  auto *main_scope = member_->local_scopes_[0];
Y
Yu Yang 已提交
590

Y
Yu Yang 已提交
591 592 593 594 595 596 597 598
  for (auto *var_desc : startup_program.Block(0).AllVars()) {
    if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
      auto &main_tensor =
          main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
      ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
      auto &dims = main_tensor.dims();
      size_t numel = main_tensor.numel();

Y
Stash  
Yu Yang 已提交
599
      platform::dynload::ncclGroupStart();
Y
Yu Yang 已提交
600

Y
Update  
Yu Yang 已提交
601 602 603 604 605 606
      for (size_t i = 0; i < member_->places_.size(); ++i) {
        auto place = member_->places_[i];
        void *buffer;
        if (i == 0) {
          buffer = const_cast<void *>(main_tensor.data<void>());
        } else {
Y
Yu Yang 已提交
607
          auto local_scope = member_->local_scopes_[i];
Y
Update  
Yu Yang 已提交
608 609 610 611 612
          auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
          t->Resize(dims);
          buffer = t->mutable_data(place, main_tensor.type());
        }

Y
Stash  
Yu Yang 已提交
613
        auto &nccl_ctx = member_->GetNCCLCtx(place);
Y
Update  
Yu Yang 已提交
614
        platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm,
Y
Stash  
Yu Yang 已提交
615
                                     nccl_ctx.stream());
Y
Yu Yang 已提交
616
      }
Y
Stash  
Yu Yang 已提交
617 618 619
      platform::dynload::ncclGroupEnd();
    }
  }
Y
Yu Yang 已提交
620 621 622 623
#else
  PADDLE_THROW("Not compiled with CUDA");
#endif
}
Y
Yu Yang 已提交
624

Y
Yu Yang 已提交
625 626
// With CUDA: creates one NCCLContext per place (keyed by device id) and then
// initializes their communicators together. Without CUDA: does nothing.
void ParallelExecutor::BuildNCCLCommunicator() const {
#ifdef PADDLE_WITH_CUDA
  auto &contexts = member_->communication_streams_;
  for (const auto &place : member_->places_) {
    const int dev_id = boost::get<platform::CUDAPlace>(place).device;
    contexts.emplace(dev_id, ParallelExecutorPrivate::NCCLContext(dev_id));
  }
  ParallelExecutorPrivate::NCCLContext::InitNCCLContext(contexts,
                                                        member_->places_);
#endif
}

Y
Yu Yang 已提交
639 640 641
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                           const std::string &fetched_var_name) {
  auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
Y
Yu Yang 已提交
642
  // Version --> VarHandle
Y
Yu Yang 已提交
643
  member_->exception_.reset();
Y
Use mtx  
Yu Yang 已提交
644
  std::unordered_map<VarHandleBase *, GuardedBool> pending_vars;
Y
Yu Yang 已提交
645 646 647 648 649
  std::unordered_map<OpHandle *, size_t> pending_ops;

  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      for (auto &version_pair : name_pair.second) {
Y
Yu Yang 已提交
650 651
        pending_vars[&version_pair.second] =
            version_pair.second.generated_op_ == nullptr;
Y
Yu Yang 已提交
652 653 654 655
      }
    }
  }

Y
Yu Yang 已提交
656
  for (auto &var : member_->dep_vars_) {
Y
Yu Yang 已提交
657
    pending_vars[var.get()] = var->generated_op_ == nullptr;
Y
Yu Yang 已提交
658 659
  }

Y
Yu Yang 已提交
660 661
  std::vector<OpHandle *> to_run;

Y
Yu Yang 已提交
662
  for (auto &op : member_->ops_) {
Y
Yu Yang 已提交
663 664 665 666 667 668 669
    if (op->inputs_.empty()) {  // Special case, Op has no input.
      to_run.emplace_back(op.get());
    } else {
      pending_ops.insert({op.get(), op->inputs_.size()});
    }
  }

Y
Yu Yang 已提交
670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;

  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &pair : member_->vars_) {
      auto it = pair.second.find(fetch_var_name);
      if (it != pair.second.end()) {
        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
      }
    }
  }

  std::vector<FetchOpHandle> fetch_ops;

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto &vars = fetched_vars[var_name];
    fetch_ops.emplace_back();
    FetchOpHandle *op = &fetch_ops.back();
    op->data_ = fetched_data;
    op->offset_ = i;
    op->local_scopes_ = &member_->local_scopes_;
    for (auto &p : member_->places_) {
      op->dev_ctx_[p] = this->member_->GetNCCLCtx(p).ctx_.get();
    }

    for (auto *var : vars) {
      var->pending_ops_.emplace(op);
      op->inputs_.emplace_back(var);
    }
    pending_ops.insert({op, op->inputs_.size()});
  }

Y
Yu Yang 已提交
702 703 704
  std::vector<std::future<void>> op_threads;
  op_threads.reserve(pending_ops.size() + to_run.size());

Y
Yu Yang 已提交
705
  for (auto *op : to_run) {
Y
Yu Yang 已提交
706
    op_threads.emplace_back(RunOp(pending_vars, op));
Y
Yu Yang 已提交
707 708
  }

Y
Yu Yang 已提交
709
  while (!pending_ops.empty()) {
Y
Yu Yang 已提交
710
    VarHandleBase *ready_var = nullptr;
Y
Yu Yang 已提交
711
    for (auto &pair : pending_vars) {
Y
Yu Yang 已提交
712
      if (pair.second) {
Y
Yu Yang 已提交
713
        ready_var = pair.first;
Y
Yu Yang 已提交
714 715
      }
    }
Y
Yu Yang 已提交
716
    if (ready_var == nullptr) {
Y
Yu Yang 已提交
717 718 719 720 721 722
      // FIXME use conditional var instead of busy wait.

      if (member_->exception_) {
        throw * member_->exception_;
      }

Y
Yu Yang 已提交
723
      VLOG(3) << pending_vars.size();
Y
Yu Yang 已提交
724
      continue;
Y
Yu Yang 已提交
725
    }
Y
Yu Yang 已提交
726
    pending_vars.erase(ready_var);
Y
Yu Yang 已提交
727
    to_run.clear();
Y
Yu Yang 已提交
728 729 730 731 732
    for (auto *op : ready_var->pending_ops_) {
      auto &deps = pending_ops[op];
      --deps;
      if (deps == 0) {
        to_run.emplace_back(op);
Y
Yu Yang 已提交
733 734 735 736
      }
    }
    for (auto *op : to_run) {
      pending_ops.erase(op);
Y
Yu Yang 已提交
737
      op_threads.emplace_back(RunOp(pending_vars, op));
Y
Yu Yang 已提交
738 739
    }
  }
Y
Yu Yang 已提交
740 741 742 743 744

  for (auto &t : op_threads) {
    t.get();  // Join all workers
  }

Y
Yu Yang 已提交
745 746 747
  fetch_ops.clear();
  *member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
      fetched_data->tensors_;
Y
Yu Yang 已提交
748
}
Y
Yu Yang 已提交
749

Y
Yu Yang 已提交
750
std::future<void> ParallelExecutor::RunOp(
Y
Use mtx  
Yu Yang 已提交
751
    std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars,
Y
Yu Yang 已提交
752
    OpHandle *op) const {
Y
Use mtx  
Yu Yang 已提交
753
  std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>();
Y
Yu Yang 已提交
754
  for (auto *var : op->outputs_) {
Y
Debug  
Yu Yang 已提交
755
    ready_buffer->emplace_back(&pending_vars[var]);
Y
Yu Yang 已提交
756 757 758 759 760
  }

  auto op_run = [ready_buffer, op, this] {
    try {
      op->Run();
Y
Debug  
Yu Yang 已提交
761
      for (auto *ready : *ready_buffer) {
Y
Yu Yang 已提交
762
        *ready = true;
Y
Yu Yang 已提交
763
      }
Y
Debug  
Yu Yang 已提交
764
      delete ready_buffer;
Y
Yu Yang 已提交
765 766 767 768 769 770
    } catch (platform::EnforceNotMet ex) {
      member_->exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception catched";
    }
  };
Y
Yu Yang 已提交
771
  return member_->pool_.enqueue(op_run);
Y
Yu Yang 已提交
772
}
Y
Yu Yang 已提交
773
}  // namespace framework
Y
Yang Yang 已提交
774
}  // namespace paddle