parallel_executor.cc 20.4 KB
Newer Older
Y
Yang Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/parallel_executor.h"
Y
Yu Yang 已提交
16 17
#include "ThreadPool.h"
#include "executor.h"
Y
Yu Yang 已提交
18 19
#include "lod_tensor.h"
#include "op_registry.h"
Y
Yang Yang 已提交
20 21

namespace paddle {
Y
Yu Yang 已提交
22 23
namespace framework {

Y
Yu Yang 已提交
24 25 26 27 28 29
#ifdef PADDLE_WITH_CUDA

// Wrapper for NCCL calls. It currently expands to the bare expression, so the
// value returned by the call is not inspected.
// FIXME: CHECK the return value of x;
#define NCCL_INVOKE(x) x
#endif

Y
Yu Yang 已提交
30 31
struct OpHandle;

Y
Yu Yang 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
// Base class of every variable node in the dependency graph.
struct VarHandleBase {
  virtual ~VarHandleBase() {}
  // Returns a human readable description of this node (used for logging).
  virtual std::string DebugString() const = 0;

  // Operator that writes this variable; nullptr when no operator in the graph
  // produces it. Initialised explicitly so a handle never carries an
  // indeterminate pointer, however it is constructed.
  OpHandle *generated_op_{nullptr};
  // Operators that read this variable.
  std::vector<OpHandle *> pending_ops_;
};

// A named, versioned variable that lives on a particular place.
struct VarHandle : public VarHandleBase {
  // Formats the handle as "<name>:<place>".
  std::string DebugString() const override {
    std::stringstream formatted;
    formatted << name_;
    formatted << ":";
    formatted << place_;
    return formatted.str();
  }

  // Version index of this variable within its name's version map.
  size_t version_;
  // Variable name.
  std::string name_;
  // Place this version of the variable belongs to.
  platform::Place place_;
};
Y
Yu Yang 已提交
51

Y
Yu Yang 已提交
52
struct DependencyVarHandle : public VarHandleBase {
Y
Yu Yang 已提交
53
  std::string DebugString() const override { return "Dependency Variable"; }
Y
Yu Yang 已提交
54 55 56
};

struct OpHandle {
Y
Yu Yang 已提交
57 58 59 60 61
  std::vector<VarHandleBase *> inputs_;
  std::vector<VarHandleBase *> outputs_;
  std::unordered_map<platform::Place, platform::DeviceContext *,
                     platform::PlaceHash>
      dev_ctx_;
Y
Yu Yang 已提交
62 63 64 65 66

  std::string DebugString() {
    std::stringstream ss;
    ss << "(";
    for (auto *var : inputs_) {
Y
Yu Yang 已提交
67
      ss << var->DebugString() << ", ";
Y
Yu Yang 已提交
68 69 70
    }
    ss << ") --> (";
    for (auto *var : outputs_) {
Y
Yu Yang 已提交
71
      ss << var->DebugString() << ", ";
Y
Yu Yang 已提交
72 73 74 75 76 77
    }
    ss << ")\n";
    return ss.str();
  }

  virtual ~OpHandle() {}
Y
Yu Yang 已提交
78

Y
Yu Yang 已提交
79
  virtual void Run() { PADDLE_THROW("Not implemented"); }
Y
Yu Yang 已提交
80
  virtual void Wait(platform::DeviceContext *waited_dev) {}
Y
Yu Yang 已提交
81 82 83 84
};

struct ComputationOpHandle : public OpHandle {
  std::unique_ptr<OperatorBase> op_;
Y
Yu Yang 已提交
85 86
  Scope *scope_;
  platform::Place place_;
Y
Yu Yang 已提交
87

Y
Yu Yang 已提交
88 89
  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                               platform::Place place)
Y
Yu Yang 已提交
90
      : op_(framework::OpRegistry::CreateOp(op_desc)),
Y
Yu Yang 已提交
91
        scope_(scope),
Y
Yu Yang 已提交
92 93 94 95
        place_(place) {}

  void Run() override {
    // Wait other op if necessary
Y
Yu Yang 已提交
96
    LOG(INFO) << "Run " << this << " " << DebugString();
Y
Yu Yang 已提交
97 98 99
    auto *cur_ctx = dev_ctx_[place_];
    for (auto *in : inputs_) {
      if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
Y
Yu Yang 已提交
100
        in->generated_op_->Wait(cur_ctx);
Y
Yu Yang 已提交
101 102 103 104
      }
    }

    op_->Run(*scope_, place_);
Y
Yu Yang 已提交
105
    LOG(INFO) << "Done " << this;
Y
Yu Yang 已提交
106
  }
Y
Yu Yang 已提交
107 108 109 110

  void Wait(platform::DeviceContext *waited_dev) override {
    this->dev_ctx_.at(place_)->Wait();
  }
Y
Yu Yang 已提交
111 112
};

Y
Yu Yang 已提交
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
struct ScaleLossGradOpHandle : public OpHandle {
  float coeff_;
  Scope *scope_;
  platform::Place place_;

  explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
                                 platform::Place place)
      : coeff_(static_cast<float>(1.0 / num_dev)),
        scope_(scope),
        place_(place) {}

  void Run() override {
    LOG(INFO) << "Run Scale Loss Grad";

    std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
Y
Yu Yang 已提交
128

Y
Yu Yang 已提交
129 130 131 132 133 134 135 136 137 138 139 140 141 142
    float *tmp = scope_->FindVar(var_name)
                     ->GetMutable<framework::LoDTensor>()
                     ->mutable_data<float>(make_ddim({1}), place_);

    if (platform::is_cpu_place(place_)) {
      *tmp = coeff_;
    } else {
      memory::Copy(
          boost::get<platform::CUDAPlace>(place_), tmp, platform::CPUPlace(),
          &coeff_, sizeof(float),
          static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
              ->stream());
    }
  }
Y
Yu Yang 已提交
143 144 145 146

  void Wait(platform::DeviceContext *waited_dev) override {
    this->dev_ctx_.at(place_)->Wait();
  }
Y
Yu Yang 已提交
147 148
};

Y
Yu Yang 已提交
149 150
class ParallelExecutorPrivate {
 public:
Y
Yu Yang 已提交
151 152 153
  explicit ParallelExecutorPrivate(size_t num_threads = 12)
      : pool_(num_threads) {}

Y
Yu Yang 已提交
154 155
  std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
      local_scopes_;
Y
Yu Yang 已提交
156

Y
Stash  
Yu Yang 已提交
157 158
  std::vector<platform::Place> places_;

Y
Yu Yang 已提交
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
#ifdef PADDLE_WITH_CUDA
  struct NCCLContext {
    std::unique_ptr<platform::CUDADeviceContext> ctx_;
    ncclComm_t comm;

    explicit NCCLContext(int dev_id) {
      ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id)));
    }

    cudaStream_t stream() const { return ctx_->stream(); }

    int device_id() const {
      return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
    }

Y
Update  
Yu Yang 已提交
174 175
    static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
                                const std::vector<platform::Place> &places) {
Y
Yu Yang 已提交
176 177 178 179 180
      std::vector<ncclComm_t> comms;
      std::vector<int> devs;
      comms.resize(contexts.size());
      devs.reserve(contexts.size());

Y
Update  
Yu Yang 已提交
181 182
      for (auto &p : places) {
        devs.push_back(boost::get<platform::CUDAPlace>(p).device);
Y
Yu Yang 已提交
183 184 185 186 187 188
      }

      NCCL_INVOKE(platform::dynload::ncclCommInitAll(
          &comms[0], static_cast<int>(contexts.size()), &devs[0]));

      int i = 0;
Y
Update  
Yu Yang 已提交
189 190
      for (auto &dev_id : devs) {
        contexts.at(dev_id).comm = comms[i++];
Y
Yu Yang 已提交
191 192 193 194
      }
    }
  };

Y
Update  
Yu Yang 已提交
195
  std::unordered_map<int, NCCLContext> communication_streams_;
Y
Yu Yang 已提交
196 197 198 199 200 201 202 203

  NCCLContext &GetNCCLCtx(platform::Place p) {
    int dev_id = boost::get<platform::CUDAPlace>(p).device;
    return communication_streams_.at(dev_id);
  }

#endif

Y
Yu Yang 已提交
204 205 206 207 208 209 210 211 212 213 214 215 216
  platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
    if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
      return const_cast<platform::DeviceContext *>(
          platform::DeviceContextPool::Instance().Get(place));
    } else {
#ifdef PADDLE_WITH_CUDA
      return GetNCCLCtx(place).ctx_.get();
#else
      PADDLE_THROW("Not compiled with CUDA")
#endif
    }
  }

Y
Yu Yang 已提交
217 218 219 220 221 222
  platform::Place main_place_;

  std::unordered_map<platform::Place,
                     std::unordered_map<std::string, std::map<int, VarHandle>>,
                     platform::PlaceHash>
      vars_;
Y
Yu Yang 已提交
223 224
  std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;

Y
Yu Yang 已提交
225
  std::vector<std::unique_ptr<OpHandle>> ops_;
Y
Yu Yang 已提交
226

Y
Yu Yang 已提交
227
  // Use a simpler thread pool, might be faster.
Y
Yu Yang 已提交
228
  ThreadPool pool_;
Y
Yu Yang 已提交
229 230

  std::unique_ptr<platform::EnforceNotMet> exception_;
Y
Yu Yang 已提交
231 232 233 234
};

// TODO(yy): Move this function somewhere
// Maps a C++ element type to the matching NCCL data type. Only float, double
// and int are handled; any other type throws.
ncclDataType_t ToNCCLDataType(std::type_index type) {
  if (type == typeid(float)) {  // NOLINT
    return ncclFloat;
  }
  if (type == typeid(double)) {  // NOLINT
    return ncclDouble;
  }
  if (type == typeid(int)) {  // NOLINT
    return ncclInt;
  }
  PADDLE_THROW("Not supported");
}

Y
Yu Yang 已提交
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
struct NCCLAllReduceOpHandle : public OpHandle {
  ParallelExecutorPrivate *member_;

  explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
      : member_(member) {}

  void Run() override {
    if (this->inputs_.size() == 1) {
      return;  // No need to all reduce when GPU count = 1;
    } else {
      auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;

      int dtype = -1;
      size_t numel = 0;

Y
Update  
Yu Yang 已提交
261 262
      platform::dynload::ncclGroupStart();

Y
Yu Yang 已提交
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
      for (auto &p : member_->places_) {
        int dev_id = boost::get<platform::CUDAPlace>(p).device;

        Scope *s = member_->local_scopes_[p];
        auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
        void *buffer = const_cast<void *>(lod_tensor.data<void>());
        if (dtype == -1) {
          dtype = ToNCCLDataType(lod_tensor.type());
        }

        if (numel == 0) {
          numel = static_cast<size_t>(lod_tensor.numel());
        }

        auto &nccl_ctx = member_->communication_streams_.at(dev_id);

Y
Update  
Yu Yang 已提交
279 280 281
        platform::dynload::ncclAllReduce(
            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
            nccl_ctx.comm, nccl_ctx.stream());
Y
Yu Yang 已提交
282 283
      }

Y
Update  
Yu Yang 已提交
284
      platform::dynload::ncclGroupEnd();
Y
Yu Yang 已提交
285 286
    }
  }
Y
Yu Yang 已提交
287 288 289 290

  void Wait(platform::DeviceContext *waited_dev) override {
    this->dev_ctx_.at(waited_dev->GetPlace())->Wait();
  }
Y
Yu Yang 已提交
291 292
};

Y
Yu Yang 已提交
293 294 295 296 297 298
// Builds a parallel executor.
//
// places:          devices to run on; must not be empty. places[0] is the
//                  main place on which the startup program is run.
// params:          parameter names; "<name>@GRAD" is treated as a parameter
//                  gradient when the graph is built.
// startup_program: run once on places[0] in `scope`.
// main_program:    program converted into the dependency graph.
// loss_var_name:   name of the loss variable.
// scope:           parent scope; one child scope is created per place.
ParallelExecutor::ParallelExecutor(
    const std::vector<platform::Place> &places,
    const std::unordered_set<std::string> &params,
    const ProgramDesc &startup_program, const ProgramDesc &main_program,
    const std::string &loss_var_name, Scope *scope)
    : member_(new ParallelExecutorPrivate()) {
  // places[0] is dereferenced below; reject an empty list instead of reading
  // out of bounds.
  if (places.empty()) {
    PADDLE_THROW("ParallelExecutor requires at least one place");
  }

  member_->places_ = places;

  // Step 1. RunStartupProgram and Bcast the params to devs.
  Executor exe(places[0]);
  exe.Run(startup_program, scope, 0);
  // Create local scopes
  for (auto &place : places) {
    member_->local_scopes_[place] = &scope->NewScope();
  }
  member_->main_place_ = places[0];

  // Bcast Parameters to all GPUs
  if (platform::is_gpu_place(member_->main_place_) &&
      member_->local_scopes_.size() != 1) {  // Is CUDA
    BuildNCCLCommunicator();
    BCastParamsToGPUs(startup_program);
  }
  // Startup Program has been run. All local scopes has correct parameters.

  // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
  // ncclOp
  ConstructDependencyGraph(params, main_program, loss_var_name);

  // Step 3. Create vars in each scope;
  for (auto &pair : member_->local_scopes_) {
    // Named differently from the `scope` parameter to avoid shadowing it.
    auto *local_scope = pair.second;

    for (auto *var : main_program.Block(0).AllVars()) {
      if (local_scope->FindVar(var->Name()) != nullptr) {
        continue;
      }

      InitializeVariable(local_scope->Var(var->Name()), var->GetType());
    }
  }
}

void ParallelExecutor::ConstructDependencyGraph(
    const std::unordered_set<std::string> &params,
    const ProgramDesc &main_program, const std::string &loss_var_name) const {
Y
Yu Yang 已提交
339
  std::unordered_set<std::string> grads;
Y
Yu Yang 已提交
340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356
  for (auto &each_param : params) {
    grads.insert(each_param + "@GRAD");
  }

  bool is_forwarding = true;
  for (auto *op : main_program.Block(0).AllOps()) {
    bool change_forward = false;

    if (!is_forwarding) {
      // FIXME(yy): Do not hard code like this
      if (op->OutputArgumentNames().size() == 1 &&
          op->OutputArgumentNames()[0] == loss_var_name + "@GRAD") {
        continue;  // Drop fill 1. for backward coeff;
      }
    }

    for (auto &pair : member_->local_scopes_) {
Y
Yu Yang 已提交
357 358
      member_->ops_.emplace_back(
          new ComputationOpHandle(*op, pair.second, pair.first));
Y
Yu Yang 已提交
359
      auto *op_handle = member_->ops_.back().get();
Y
Yu Yang 已提交
360 361
      op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
          platform::DeviceContextPool::Instance().Get(pair.first));
Y
Yu Yang 已提交
362 363 364 365 366 367 368

      auto var_names = op->InputArgumentNames();

      for (auto &each_var_name : var_names) {
        auto &place = pair.first;
        VarHandle *var = GetVarHandle(each_var_name, place);
        op_handle->inputs_.emplace_back(var);
Y
Yu Yang 已提交
369
        var->pending_ops_.emplace_back(op_handle);
Y
Yu Yang 已提交
370 371 372 373 374 375 376 377 378 379 380
      }
      var_names = op->OutputArgumentNames();

      for (auto &each_var_name : var_names) {
        auto &place = pair.first;
        GenerateVar(op_handle, each_var_name, place);
      }

      if (is_forwarding) {
        if (var_names.size() == 1 && var_names[0] == loss_var_name) {
          // Insert ScaleCost OpHandle
Y
Yu Yang 已提交
381 382
          member_->ops_.emplace_back(new ScaleLossGradOpHandle(
              this->member_->local_scopes_.size(), pair.second, pair.first));
Y
Yu Yang 已提交
383
          op_handle = member_->ops_.back().get();
Y
Yu Yang 已提交
384 385 386 387

          op_handle->dev_ctx_[pair.first] =
              member_->CommunicationDevCtx(pair.first);

Y
Yu Yang 已提交
388
          auto &place = pair.first;
Y
Yu Yang 已提交
389 390 391 392 393 394
          // FIXME: Currently ScaleLossGradOp only use device_count as scale
          // factor. So it does not depend on any other operators.
          // VarHandle *loss = GetVarHandle(loss_var_name, place);
          // loss->pending_ops_.emplace_back(op_handle);
          // op_handle->inputs_.emplace_back(loss);

Y
Yu Yang 已提交
395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410
          GenerateVar(op_handle, loss_var_name + "@GRAD", place);
          change_forward = true;
          LOG(INFO) << "Scale Loss " << op_handle->DebugString();
        }
      }
    }

    if (change_forward) {
      is_forwarding = false;
    }

    if (!is_forwarding) {
      auto var_names = op->OutputArgumentNames();
      for (auto &og : var_names) {
        if (grads.count(og) != 0) {  // is param grad
          // Insert NCCL AllReduce Op
Y
Yu Yang 已提交
411
          member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_));
Y
Yu Yang 已提交
412 413 414 415 416 417 418 419 420 421 422
          auto *op_handle = member_->ops_.back().get();

          for (auto &pair : member_->local_scopes_) {
            auto &place = pair.first;
            auto &vars = member_->vars_[place][og];

            if (vars.empty()) {  // This device has no data. continue.
              continue;
            }
            auto *prev_grad = &vars[vars.size() - 1];
            op_handle->inputs_.emplace_back(prev_grad);
Y
Yu Yang 已提交
423
            prev_grad->pending_ops_.emplace_back(op_handle);
Y
Yu Yang 已提交
424 425 426 427 428 429
            auto &var = vars[vars.size()];
            var.place_ = place;
            var.generated_op_ = op_handle;
            var.name_ = og;
            var.version_ = vars.size() - 1;
            op_handle->outputs_.emplace_back(&var);
Y
Yu Yang 已提交
430 431 432 433 434

            for (auto &pair : member_->local_scopes_) {
              op_handle->dev_ctx_[pair.first] =
                  member_->CommunicationDevCtx(pair.first);
            }
Y
Yu Yang 已提交
435 436 437 438 439
          }
        }
      }
    }
  }
Y
Yu Yang 已提交
440

Y
Yu Yang 已提交
441 442 443
  /*
    Dependency graph has been constructed. However, there are still data
    harzaeds need to be handled.
Y
Yu Yang 已提交
444
   */
Y
Yu Yang 已提交
445 446
  PolishGraphToSupportDataHarzaeds();
}
Y
Yu Yang 已提交
447

Y
Yu Yang 已提交
448 449 450 451 452 453 454 455
/**
 * We only handle write-after-read (WAR) hazards, since a program should not
 * contain write-after-write. If there are write-after-write operators, they
 * need to be pruned.
 *
 * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
 */
void ParallelExecutor::PolishGraphToSupportDataHarzaeds() const {
Y
Yu Yang 已提交
456 457 458 459 460 461 462 463 464 465 466
  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      if (name_pair.second.size() <= 1) {
        return;
      }
      auto it_new = name_pair.second.rbegin();
      auto it_old = name_pair.second.rbegin();
      ++it_old;
      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
        auto *write_op = it_new->second.generated_op_;
        auto &read_ops = it_old->second.pending_ops_;
Y
Yu Yang 已提交
467 468 469 470 471 472 473 474 475
        auto *ex_write_op = it_old->second.generated_op_;

        if (ex_write_op == nullptr) {  // Nobody write this var.
          continue;
        }

        LOG(INFO) << "Link " << it_new->second.DebugString() << " From "
                  << it_old->second.version_ << " To "
                  << it_new->second.version_;
Y
Yu Yang 已提交
476 477 478

        for (auto *read_op : read_ops) {
          // Manually add a dependency var from read_op to write_op;
Y
Yu Yang 已提交
479 480 481 482
          if (read_op == write_op) {
            // Read Write is the same op.
            continue;
          }
Y
Yu Yang 已提交
483 484

          auto *dep_var = new DependencyVarHandle();
Y
Yu Yang 已提交
485

Y
Yu Yang 已提交
486 487 488 489 490 491 492 493 494 495
          dep_var->generated_op_ = read_op;
          read_op->outputs_.emplace_back(dep_var);

          dep_var->pending_ops_.emplace_back(write_op);
          write_op->inputs_.emplace_back(dep_var);
          member_->dep_vars_.emplace(dep_var);
        }
      }
    }
  }
Y
Yu Yang 已提交
496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530
}

// Appends a new version of variable `each_var_name` on `place`, records
// `op_handle` as its producer and adds it to that operator's outputs.
void ParallelExecutor::GenerateVar(OpHandle *op_handle,
                                   const std::string &each_var_name,
                                   const platform::Place &place) const {
  auto &versions = member_->vars_[place][each_var_name];

  // The next version number equals the number of versions seen so far.
  const size_t next_version = versions.size();
  auto &handle = versions[next_version];

  handle.version_ = next_version;
  handle.generated_op_ = op_handle;
  handle.name_ = each_var_name;
  handle.place_ = place;

  op_handle->outputs_.emplace_back(&handle);
}

// Returns the handle of the newest version of `each_var_name` on `place`.
// When the variable has no version yet, version 0 is created with no
// producing operator and returned.
VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
                                          const platform::Place &place) const {
  auto &per_place = member_->vars_[place];
  auto &versions = per_place[each_var_name];

  if (!versions.empty()) {
    return &versions.rbegin()->second;
  }

  auto &initial = versions[0];
  initial.place_ = place;
  initial.name_ = each_var_name;
  initial.generated_op_ = nullptr;
  initial.version_ = 0;
  return &initial;
}

void ParallelExecutor::BCastParamsToGPUs(
    const ProgramDesc &startup_program) const {
Y
Yu Yang 已提交
531
#ifdef PADDLE_WITH_CUDA
Y
Yu Yang 已提交
532
  auto *main_scope = member_->local_scopes_[member_->main_place_];
Y
Yu Yang 已提交
533

Y
Yu Yang 已提交
534 535 536 537 538 539 540 541
  for (auto *var_desc : startup_program.Block(0).AllVars()) {
    if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
      auto &main_tensor =
          main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
      ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
      auto &dims = main_tensor.dims();
      size_t numel = main_tensor.numel();

Y
Stash  
Yu Yang 已提交
542
      platform::dynload::ncclGroupStart();
Y
Yu Yang 已提交
543

Y
Update  
Yu Yang 已提交
544 545 546 547 548 549 550 551 552 553 554 555
      for (size_t i = 0; i < member_->places_.size(); ++i) {
        auto place = member_->places_[i];
        void *buffer;
        if (i == 0) {
          buffer = const_cast<void *>(main_tensor.data<void>());
        } else {
          auto local_scope = member_->local_scopes_[place];
          auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
          t->Resize(dims);
          buffer = t->mutable_data(place, main_tensor.type());
        }

Y
Stash  
Yu Yang 已提交
556
        auto &nccl_ctx = member_->GetNCCLCtx(place);
Y
Update  
Yu Yang 已提交
557
        platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm,
Y
Stash  
Yu Yang 已提交
558
                                     nccl_ctx.stream());
Y
Yu Yang 已提交
559
      }
Y
Stash  
Yu Yang 已提交
560 561 562
      platform::dynload::ncclGroupEnd();
    }
  }
Y
Yu Yang 已提交
563

Y
Yu Yang 已提交
564
  // Debug code, bias should be 1.0f.
Y
Stash  
Yu Yang 已提交
565 566
  for (auto &pair : member_->local_scopes_) {
    member_->GetNCCLCtx(pair.first).ctx_->Wait();
Y
Yu Yang 已提交
567

Y
Stash  
Yu Yang 已提交
568
    auto &b = pair.second->FindVar("fc_0.b_0")->Get<framework::LoDTensor>();
Y
Stash  
Yu Yang 已提交
569 570 571 572
    framework::LoDTensor cpu;
    framework::TensorCopy(b, platform::CPUPlace(), &cpu);
    platform::DeviceContextPool::Instance().Get(b.place())->Wait();
    LOG(INFO) << *cpu.data<float>();
Y
Yu Yang 已提交
573
  }
Y
Stash  
Yu Yang 已提交
574

Y
Yu Yang 已提交
575 576 577 578
#else
  PADDLE_THROW("Not compiled with CUDA");
#endif
}
Y
Yu Yang 已提交
579

Y
Yu Yang 已提交
580 581 582 583 584
// Creates one NCCL context per local scope's device, then initialises the
// communicators for `member_->places_`. Does nothing without CUDA.
void ParallelExecutor::BuildNCCLCommunicator() const {
#ifdef PADDLE_WITH_CUDA
  auto &streams = member_->communication_streams_;

  for (auto &scope_entry : member_->local_scopes_) {
    const int dev_id =
        boost::get<platform::CUDAPlace>(scope_entry.first).device;
    streams.emplace(dev_id, ParallelExecutorPrivate::NCCLContext(dev_id));
  }

  ParallelExecutorPrivate::NCCLContext::InitNCCLContext(streams,
                                                        member_->places_);
#endif
}

std::vector<LoDTensor> ParallelExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // Version --> VarHandle
Y
Yu Yang 已提交
598
  member_->exception_.reset();
Y
Yu Yang 已提交
599
  std::unordered_map<VarHandleBase *, bool> pending_vars;
Y
Yu Yang 已提交
600 601 602 603 604
  std::unordered_map<OpHandle *, size_t> pending_ops;

  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      for (auto &version_pair : name_pair.second) {
Y
Yu Yang 已提交
605 606
        pending_vars[&version_pair.second] =
            version_pair.second.generated_op_ == nullptr;
Y
Yu Yang 已提交
607 608 609 610
      }
    }
  }

Y
Yu Yang 已提交
611 612 613 614
  for (auto &var : member_->dep_vars_) {
    pending_vars[var.get()] = var->generated_op_ == nullptr;
  }

Y
Yu Yang 已提交
615 616
  std::vector<OpHandle *> to_run;

Y
Yu Yang 已提交
617
  for (auto &op : member_->ops_) {
Y
Yu Yang 已提交
618 619 620 621 622 623 624 625 626
    if (op->inputs_.empty()) {  // Special case, Op has no input.
      to_run.emplace_back(op.get());
    } else {
      pending_ops.insert({op.get(), op->inputs_.size()});
    }
  }

  for (auto *op : to_run) {
    RunOp(pending_vars, op);
Y
Yu Yang 已提交
627 628
  }

Y
Yu Yang 已提交
629
  while (!pending_ops.empty()) {
Y
Yu Yang 已提交
630
    VarHandleBase *ready_var = nullptr;
Y
Yu Yang 已提交
631 632 633
    for (auto &pair : pending_vars) {
      if (pair.second) {
        ready_var = pair.first;
Y
Yu Yang 已提交
634 635
      }
    }
Y
Yu Yang 已提交
636 637

    if (ready_var == nullptr) {
Y
Yu Yang 已提交
638 639 640 641 642 643 644
      // FIXME use conditional var instead of busy wait.

      if (member_->exception_) {
        throw * member_->exception_;
      }

      std::this_thread::yield();
Y
Yu Yang 已提交
645
      continue;
Y
Yu Yang 已提交
646 647
    }

Y
Yu Yang 已提交
648 649
    pending_vars.erase(ready_var);

Y
Yu Yang 已提交
650
    to_run.clear();
Y
Yu Yang 已提交
651 652 653 654 655 656

    for (auto *op : ready_var->pending_ops_) {
      auto &deps = pending_ops[op];
      --deps;
      if (deps == 0) {
        to_run.emplace_back(op);
Y
Yu Yang 已提交
657 658 659 660 661
      }
    }

    for (auto *op : to_run) {
      pending_ops.erase(op);
Y
Yu Yang 已提交
662
      RunOp(pending_vars, op);
Y
Yu Yang 已提交
663 664 665 666
    }
  }
  return std::vector<LoDTensor>();
}
Y
Yu Yang 已提交
667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691

// Enqueues `op` on the thread pool.
//
// After the operator has run, the ready flag in `pending_vars` of each of its
// outputs is set to true. An EnforceNotMet thrown by the operator is copied
// into `member_->exception_`; any other exception is fatal.
void ParallelExecutor::RunOp(
    std::unordered_map<VarHandleBase *, bool> &pending_vars,
    OpHandle *op) const {
  // Take the flag addresses now, on the calling thread; the worker only
  // writes through them.
  std::vector<bool *> ready_buffer;
  for (auto *var : op->outputs_) {
    ready_buffer.emplace_back(&pending_vars[var]);
  }

  auto op_run = [ready_buffer, op, this] {
    try {
      // TODO(yy) Check Previous Op has same dev ctx.
      op->Run();
      for (auto *ready : ready_buffer) {
        *ready = true;
      }
    } catch (const platform::EnforceNotMet &ex) {
      // Caught by const reference to avoid an extra copy of the exception.
      member_->exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception catched";
    }
  };

  member_->pool_.enqueue(op_run);
}
Y
Yu Yang 已提交
692
}  // namespace framework
Y
Yang Yang 已提交
693
}  // namespace paddle