parallel_executor.cc 24.6 KB
Newer Older
Y
Yang Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/parallel_executor.h"
Y
Yu Yang 已提交
16 17
#include "ThreadPool.h"
#include "executor.h"
Y
Yu Yang 已提交
18
#include "lod_tensor.h"
Y
Yu Yang 已提交
19
#include "lod_tensor_array.h"
Y
Yu Yang 已提交
20
#include "op_registry.h"
Y
Yu Yang 已提交
21
#include "paddle/fluid/operators/math/concat.h"
Y
Yang Yang 已提交
22 23

namespace paddle {
Y
Yu Yang 已提交
24 25
namespace framework {

Y
Yu Yang 已提交
26 27 28 29 30 31
#ifdef PADDLE_WITH_CUDA

// FIXME: CHECK the return value of x;
#define NCCL_INVOKE(x) x
#endif

Y
Yu Yang 已提交
32 33
struct OpHandle;

Y
Yu Yang 已提交
34 35 36 37 38
struct VarHandleBase {
  virtual ~VarHandleBase() {}
  virtual std::string DebugString() const = 0;

  OpHandle *generated_op_;
Y
Yu Yang 已提交
39
  std::unordered_set<OpHandle *> pending_ops_;
Y
Yu Yang 已提交
40 41 42 43 44 45 46 47 48
};

struct VarHandle : public VarHandleBase {
  std::string DebugString() const override {
    std::stringstream ss;
    ss << name_ << ":" << place_;
    return ss.str();
  }

Y
Yu Yang 已提交
49 50
  // version field currently is not used, however, just store the version to
  // debug easily.
Y
Yu Yang 已提交
51 52 53
  size_t version_;
  std::string name_;
  platform::Place place_;
Y
Yu Yang 已提交
54
};
Y
Yu Yang 已提交
55

Y
Yu Yang 已提交
56 57 58 59
struct DummyVarHandle : public VarHandleBase {
  std::string DebugString() const override { return "dummy"; }
};

Y
Yu Yang 已提交
60
struct DependencyVarHandle : public VarHandleBase {
Y
Yu Yang 已提交
61
  std::string DebugString() const override { return "Dependency Variable"; }
Y
Yu Yang 已提交
62 63 64
};

struct OpHandle {
Y
Yu Yang 已提交
65 66 67 68 69
  std::vector<VarHandleBase *> inputs_;
  std::vector<VarHandleBase *> outputs_;
  std::unordered_map<platform::Place, platform::DeviceContext *,
                     platform::PlaceHash>
      dev_ctx_;
Y
Yu Yang 已提交
70

Y
Yu Yang 已提交
71 72
  std::unordered_map<int, cudaEvent_t> events_;

Y
Yu Yang 已提交
73 74 75 76
  std::string DebugString() {
    std::stringstream ss;
    ss << "(";
    for (auto *var : inputs_) {
Y
Yu Yang 已提交
77
      ss << var->DebugString() << ", ";
Y
Yu Yang 已提交
78 79 80
    }
    ss << ") --> (";
    for (auto *var : outputs_) {
Y
Yu Yang 已提交
81
      ss << var->DebugString() << ", ";
Y
Yu Yang 已提交
82 83 84 85 86 87
    }
    ss << ")\n";
    return ss.str();
  }

  virtual ~OpHandle() {}
Y
Yu Yang 已提交
88

Y
Yu Yang 已提交
89 90
  void Run(bool use_event) {
    if (events_.empty() && use_event) {
Y
Yu Yang 已提交
91 92 93 94 95 96 97 98 99
      for (auto &p : dev_ctx_) {
        int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
        cudaSetDevice(dev_id);
        cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
      }
    }

    RunImpl();

Y
Yu Yang 已提交
100 101 102 103 104 105 106
    if (use_event) {
      for (auto &p : dev_ctx_) {
        int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
        auto stream =
            static_cast<platform::CUDADeviceContext *>(p.second)->stream();
        cudaEventRecord(events_.at(dev_id), stream);
      }
Y
Yu Yang 已提交
107 108 109 110
    }
  }

  virtual void Wait(platform::DeviceContext *waited_dev) {
Y
Fix bug  
Yu Yang 已提交
111
    if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
Y
Yu Yang 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124 125
      for (auto &dev_ctx : dev_ctx_) {
        dev_ctx.second->Wait();
      }
    } else {
      auto stream =
          static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
      for (auto &ev : events_) {
        PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
      }
    }
  }

 protected:
  virtual void RunImpl() = 0;
Y
Yu Yang 已提交
126 127 128 129
};

struct ComputationOpHandle : public OpHandle {
  std::unique_ptr<OperatorBase> op_;
Y
Yu Yang 已提交
130 131
  Scope *scope_;
  platform::Place place_;
Y
Yu Yang 已提交
132

Y
Yu Yang 已提交
133 134
  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                               platform::Place place)
Y
Yu Yang 已提交
135
      : op_(framework::OpRegistry::CreateOp(op_desc)),
Y
Yu Yang 已提交
136
        scope_(scope),
Y
Yu Yang 已提交
137
        place_(place) {}
Y
Yu Yang 已提交
138

Y
Yu Yang 已提交
139 140
 protected:
  void RunImpl() override {
Y
Yu Yang 已提交
141
    // Wait other op if necessary
Y
Yu Yang 已提交
142 143 144 145
    if (platform::is_gpu_place(place_)) {
      int dev_id = boost::get<platform::CUDAPlace>(place_).device;
      cudaSetDevice(dev_id);
    }
Y
Yu Yang 已提交
146 147 148
    auto *cur_ctx = dev_ctx_[place_];
    for (auto *in : inputs_) {
      if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
Y
Yu Yang 已提交
149
        in->generated_op_->Wait(cur_ctx);
Y
Yu Yang 已提交
150 151 152 153
      }
    }

    op_->Run(*scope_, place_);
Y
Yu Yang 已提交
154
  }
Y
Yu Yang 已提交
155 156
};

Y
Yu Yang 已提交
157 158 159 160 161 162 163 164 165
struct ScaleLossGradOpHandle : public OpHandle {
  float coeff_;
  Scope *scope_;
  platform::Place place_;

  explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
                                 platform::Place place)
      : coeff_(static_cast<float>(1.0 / num_dev)),
        scope_(scope),
Y
Yu Yang 已提交
166
        place_(place) {
Y
Set dev  
Yu Yang 已提交
167
    cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
Y
Yu Yang 已提交
168 169
  }

Y
Log  
Yu Yang 已提交
170
  ~ScaleLossGradOpHandle() {
Y
SetDev  
Yu Yang 已提交
171
    cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
Y
Log  
Yu Yang 已提交
172
  }
Y
Yu Yang 已提交
173

Y
Yu Yang 已提交
174 175
 protected:
  void RunImpl() override {
Y
Yu Yang 已提交
176
    std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
Y
Yu Yang 已提交
177

Y
Yu Yang 已提交
178 179 180 181 182 183 184
    float *tmp = scope_->FindVar(var_name)
                     ->GetMutable<framework::LoDTensor>()
                     ->mutable_data<float>(make_ddim({1}), place_);

    if (platform::is_cpu_place(place_)) {
      *tmp = coeff_;
    } else {
Y
Yu Yang 已提交
185
      auto stream =
Y
Yu Yang 已提交
186
          static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
Y
Yu Yang 已提交
187 188 189 190
              ->stream();
      memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
                   platform::CPUPlace(), &coeff_, sizeof(float), stream);
    }
Y
Yu Yang 已提交
191
  }
Y
Yu Yang 已提交
192 193
};

Y
Yu Yang 已提交
194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
struct FetchedData {
 public:
  std::vector<framework::LoDTensor> tensors_;

  explicit FetchedData(size_t num_fetched) { tensors_.resize(num_fetched); }
};

struct FetchOpHandle : public OpHandle {
  std::shared_ptr<FetchedData> data_;
  size_t offset_;
  std::vector<Scope *> *local_scopes_;
  std::vector<LoDTensor> tensors_;

  ~FetchOpHandle() {
    for (auto *input_var : inputs_) {
      input_var->pending_ops_.erase(this);
    }

    // Lazily merge tensors. Will faster code.
    MergeTensors();
  }

Y
Yu Yang 已提交
216 217 218 219 220 221
  void Wait(platform::DeviceContext *waited_dev) override {
    PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
  }

 protected:
  void RunImpl() override {
Y
Debug  
Yu Yang 已提交
222
    for (auto *input : inputs_) {
Y
Yu Yang 已提交
223 224
      auto *var = static_cast<VarHandle *>(input);
      var->generated_op_->Wait(this->dev_ctx_[var->place_]);
Y
Debug  
Yu Yang 已提交
225 226
    }

Y
Yu Yang 已提交
227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
    tensors_.resize(inputs_.size());
    auto *var = static_cast<VarHandle *>(inputs_[0]);
    auto &var_name = var->name_;
    platform::CPUPlace cpu;
    auto &scopes = *local_scopes_;

    for (size_t i = 0; i < scopes.size(); ++i) {
      auto &scope = scopes[i];
      auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
      if (platform::is_gpu_place(var->place_)) {
        TensorCopy(t, cpu, *dev_ctx_[t.place()], &tensors_[i]);
      } else {
        tensors_[i].ShareDataWith(t);
        tensors_[i].set_lod(t.lod());
      }
    }
  }

 private:
  void MergeTensors() const {
    std::vector<const LoDTensor *> tensors_ptr;
    for (auto &t : tensors_) {
      tensors_ptr.emplace_back(&t);
    }
    data_->tensors_[offset_].MergeLoDTensor(tensors_ptr, platform::CPUPlace());
  }
};

Y
Yu Yang 已提交
255 256
class ParallelExecutorPrivate {
 public:
Y
Yu Yang 已提交
257
  explicit ParallelExecutorPrivate(size_t num_threads = 12)
Y
Yu Yang 已提交
258
      : pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {}
Y
Yu Yang 已提交
259

Y
Stash  
Yu Yang 已提交
260 261
  std::vector<platform::Place> places_;

Y
Yu Yang 已提交
262
  std::vector<Scope *> local_scopes_;
Y
Yu Yang 已提交
263
  Scope *global_scope_;
Y
Yu Yang 已提交
264

Y
Yu Yang 已提交
265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
#ifdef PADDLE_WITH_CUDA
  struct NCCLContext {
    std::unique_ptr<platform::CUDADeviceContext> ctx_;
    ncclComm_t comm;

    explicit NCCLContext(int dev_id) {
      ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id)));
    }

    cudaStream_t stream() const { return ctx_->stream(); }

    int device_id() const {
      return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
    }

Y
Update  
Yu Yang 已提交
280 281
    static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
                                const std::vector<platform::Place> &places) {
Y
Yu Yang 已提交
282 283 284 285 286
      std::vector<ncclComm_t> comms;
      std::vector<int> devs;
      comms.resize(contexts.size());
      devs.reserve(contexts.size());

Y
Update  
Yu Yang 已提交
287 288
      for (auto &p : places) {
        devs.push_back(boost::get<platform::CUDAPlace>(p).device);
Y
Yu Yang 已提交
289 290 291 292 293 294
      }

      NCCL_INVOKE(platform::dynload::ncclCommInitAll(
          &comms[0], static_cast<int>(contexts.size()), &devs[0]));

      int i = 0;
Y
Update  
Yu Yang 已提交
295 296
      for (auto &dev_id : devs) {
        contexts.at(dev_id).comm = comms[i++];
Y
Yu Yang 已提交
297 298 299 300
      }
    }
  };

Y
Update  
Yu Yang 已提交
301
  std::unordered_map<int, NCCLContext> communication_streams_;
Y
Yu Yang 已提交
302 303 304 305 306 307 308 309

  NCCLContext &GetNCCLCtx(platform::Place p) {
    int dev_id = boost::get<platform::CUDAPlace>(p).device;
    return communication_streams_.at(dev_id);
  }

#endif

Y
Yu Yang 已提交
310 311 312 313 314 315 316 317 318 319 320 321 322
  platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
    if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
      return const_cast<platform::DeviceContext *>(
          platform::DeviceContextPool::Instance().Get(place));
    } else {
#ifdef PADDLE_WITH_CUDA
      return GetNCCLCtx(place).ctx_.get();
#else
      PADDLE_THROW("Not compiled with CUDA")
#endif
    }
  }

Y
Yu Yang 已提交
323 324 325 326 327 328
  platform::Place main_place_;

  std::unordered_map<platform::Place,
                     std::unordered_map<std::string, std::map<int, VarHandle>>,
                     platform::PlaceHash>
      vars_;
Y
Yu Yang 已提交
329 330
  std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;

Y
Yu Yang 已提交
331
  std::vector<std::unique_ptr<OpHandle>> ops_;
Y
Yu Yang 已提交
332

Y
Yu Yang 已提交
333
  // Use a simpler thread pool, might be faster.
Y
Yu Yang 已提交
334
  std::unique_ptr<ThreadPool> pool_;
Y
Yu Yang 已提交
335 336

  std::unique_ptr<platform::EnforceNotMet> exception_;
Y
Yu Yang 已提交
337 338 339 340
};

// TODO(yy): Move this function somewhere
ncclDataType_t ToNCCLDataType(std::type_index type) {
Y
Stash  
Yu Yang 已提交
341 342 343 344 345 346 347 348 349
  if (type == typeid(float)) {  // NOLINT
    return ncclFloat;
  } else if (type == typeid(double)) {  // NOLINT
    return ncclDouble;
  } else if (type == typeid(int)) {  // NOLINT
    return ncclInt;
  } else {
    PADDLE_THROW("Not supported");
  }
Y
Yu Yang 已提交
350 351
}

Y
Yu Yang 已提交
352 353 354 355
struct NCCLAllReduceOpHandle : public OpHandle {
  ParallelExecutorPrivate *member_;

  explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
Y
Yu Yang 已提交
356
      : member_(member) {}
Y
Yu Yang 已提交
357

Y
Yu Yang 已提交
358 359
 protected:
  void RunImpl() override {
Y
Yu Yang 已提交
360 361 362
    if (this->inputs_.size() == 1) {
      return;  // No need to all reduce when GPU count = 1;
    } else {
Y
Yu Yang 已提交
363 364 365 366 367
      // Wait input done
      for (auto *in : inputs_) {
        auto &p = static_cast<VarHandle *>(in)->place_;
        in->generated_op_->Wait(dev_ctx_[p]);
      }
Y
Yu Yang 已提交
368
      VLOG(3) << "Before NCCL";
Y
Yu Yang 已提交
369
      PADDLE_ENFORCE(cudaDeviceSynchronize());
Y
Yu Yang 已提交
370

Y
Yu Yang 已提交
371 372 373 374
      auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
      int dtype = -1;
      size_t numel = 0;

Y
Update  
Yu Yang 已提交
375 376
      platform::dynload::ncclGroupStart();

Y
Yu Yang 已提交
377 378 379
      for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
        auto &p = member_->places_[i];
        auto *s = member_->local_scopes_[i];
Y
Yu Yang 已提交
380 381 382 383 384 385 386 387 388 389 390 391 392
        int dev_id = boost::get<platform::CUDAPlace>(p).device;

        auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
        void *buffer = const_cast<void *>(lod_tensor.data<void>());
        if (dtype == -1) {
          dtype = ToNCCLDataType(lod_tensor.type());
        }

        if (numel == 0) {
          numel = static_cast<size_t>(lod_tensor.numel());
        }

        auto &nccl_ctx = member_->communication_streams_.at(dev_id);
Y
Update  
Yu Yang 已提交
393 394 395
        platform::dynload::ncclAllReduce(
            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
            nccl_ctx.comm, nccl_ctx.stream());
Y
Yu Yang 已提交
396
      }
Y
Update  
Yu Yang 已提交
397
      platform::dynload::ncclGroupEnd();
Y
Yu Yang 已提交
398
      PADDLE_ENFORCE(cudaDeviceSynchronize());
Y
Yu Yang 已提交
399 400

      VLOG(3) << "After NCCL";
Y
Debug  
Yu Yang 已提交
401
    }
Y
Yu Yang 已提交
402
  }
Y
Yu Yang 已提交
403 404
};

Y
Yu Yang 已提交
405 406 407 408 409 410
ParallelExecutor::ParallelExecutor(
    const std::vector<platform::Place> &places,
    const std::unordered_set<std::string> &params,
    const ProgramDesc &startup_program, const ProgramDesc &main_program,
    const std::string &loss_var_name, Scope *scope)
    : member_(new ParallelExecutorPrivate()) {
Y
Stash  
Yu Yang 已提交
411
  member_->places_ = places;
Y
Yu Yang 已提交
412
  member_->global_scope_ = scope;
Y
Yu Yang 已提交
413 414 415 416
  // Step 1. RunStartupProgram and Bcast the params to devs.
  Executor exe(places[0]);
  exe.Run(startup_program, scope, 0);
  // Create local scopes
Y
Yu Yang 已提交
417 418
  for (size_t i = 0; i < member_->places_.size(); ++i) {
    member_->local_scopes_.push_back(&scope->NewScope());
Y
Yu Yang 已提交
419 420 421 422
  }
  member_->main_place_ = places[0];

  // Bcast Parameters to all GPUs
Y
Yu Yang 已提交
423
  BuildNCCLCommunicator();
Y
Yu Yang 已提交
424 425 426
  if (platform::is_gpu_place(member_->main_place_) &&
      member_->local_scopes_.size() != 1) {  // Is CUDA
    BCastParamsToGPUs(startup_program);
Y
Yu Yang 已提交
427 428 429 430 431 432
  }
  // Startup Program has been run. All local scopes has correct parameters.

  // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
  // ncclOp
  ConstructDependencyGraph(params, main_program, loss_var_name);
Y
Yu Yang 已提交
433 434

  // Step 3. Create vars in each scope;
Y
Yu Yang 已提交
435
  for (auto *scope : member_->local_scopes_) {
Y
Yu Yang 已提交
436 437 438 439 440 441 442 443
    for (auto *var : main_program.Block(0).AllVars()) {
      if (scope->FindVar(var->Name()) != nullptr) {
        continue;
      }

      InitializeVariable(scope->Var(var->Name()), var->GetType());
    }
  }
Y
Yu Yang 已提交
444 445 446 447 448
}

void ParallelExecutor::ConstructDependencyGraph(
    const std::unordered_set<std::string> &params,
    const ProgramDesc &main_program, const std::string &loss_var_name) const {
Y
Yu Yang 已提交
449
  std::unordered_set<std::string> grads;
Y
Yu Yang 已提交
450 451 452 453 454 455 456 457 458 459 460 461 462 463 464
  for (auto &each_param : params) {
    grads.insert(each_param + "@GRAD");
  }

  bool is_forwarding = true;
  for (auto *op : main_program.Block(0).AllOps()) {
    bool change_forward = false;
    if (!is_forwarding) {
      // FIXME(yy): Do not hard code like this
      if (op->OutputArgumentNames().size() == 1 &&
          op->OutputArgumentNames()[0] == loss_var_name + "@GRAD") {
        continue;  // Drop fill 1. for backward coeff;
      }
    }

Y
Yu Yang 已提交
465 466 467 468 469
    for (size_t i = 0; i < member_->places_.size(); ++i) {
      auto &p = member_->places_[i];
      auto *s = member_->local_scopes_[i];

      member_->ops_.emplace_back(new ComputationOpHandle(*op, s, p));
Y
Yu Yang 已提交
470
      auto *op_handle = member_->ops_.back().get();
Y
Yu Yang 已提交
471 472
      op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
          platform::DeviceContextPool::Instance().Get(p));
Y
Yu Yang 已提交
473 474 475 476

      auto var_names = op->InputArgumentNames();

      for (auto &each_var_name : var_names) {
Y
Yu Yang 已提交
477
        VarHandle *var = GetVarHandle(each_var_name, p);
Y
Yu Yang 已提交
478
        op_handle->inputs_.emplace_back(var);
Y
Yu Yang 已提交
479
        var->pending_ops_.emplace(op_handle);
Y
Yu Yang 已提交
480 481 482 483
      }
      var_names = op->OutputArgumentNames();

      for (auto &each_var_name : var_names) {
Y
Yu Yang 已提交
484
        GenerateVar(op_handle, each_var_name, p);
Y
Yu Yang 已提交
485 486 487 488 489
      }

      if (is_forwarding) {
        if (var_names.size() == 1 && var_names[0] == loss_var_name) {
          // Insert ScaleCost OpHandle
Y
Yu Yang 已提交
490
          member_->ops_.emplace_back(new ScaleLossGradOpHandle(
Y
Yu Yang 已提交
491
              this->member_->local_scopes_.size(), s, p));
Y
Yu Yang 已提交
492
          op_handle = member_->ops_.back().get();
Y
Yu Yang 已提交
493

Y
Yu Yang 已提交
494
          op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
Y
Yu Yang 已提交
495

Y
Yu Yang 已提交
496 497 498 499 500 501
          // FIXME: Currently ScaleLossGradOp only use device_count as scale
          // factor. So it does not depend on any other operators.
          // VarHandle *loss = GetVarHandle(loss_var_name, place);
          // loss->pending_ops_.emplace_back(op_handle);
          // op_handle->inputs_.emplace_back(loss);

Y
Yu Yang 已提交
502
          GenerateVar(op_handle, loss_var_name + "@GRAD", p);
Y
Yu Yang 已提交
503 504 505 506 507 508 509 510 511 512 513 514 515 516
          change_forward = true;
        }
      }
    }

    if (change_forward) {
      is_forwarding = false;
    }

    if (!is_forwarding) {
      auto var_names = op->OutputArgumentNames();
      for (auto &og : var_names) {
        if (grads.count(og) != 0) {  // is param grad
          // Insert NCCL AllReduce Op
Y
Yu Yang 已提交
517
          member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_));
Y
Yu Yang 已提交
518 519
          auto *op_handle = member_->ops_.back().get();

Y
Yu Yang 已提交
520 521 522
          for (size_t i = 0; i < member_->places_.size(); ++i) {
            auto &p = member_->places_[i];
            auto &vars = member_->vars_[p][og];
Y
Yu Yang 已提交
523 524 525 526 527 528

            if (vars.empty()) {  // This device has no data. continue.
              continue;
            }
            auto *prev_grad = &vars[vars.size() - 1];
            op_handle->inputs_.emplace_back(prev_grad);
Y
Yu Yang 已提交
529
            prev_grad->pending_ops_.emplace(op_handle);
Y
Yu Yang 已提交
530
            auto &var = vars[vars.size()];
Y
Yu Yang 已提交
531
            var.place_ = p;
Y
Yu Yang 已提交
532 533 534 535
            var.generated_op_ = op_handle;
            var.name_ = og;
            var.version_ = vars.size() - 1;
            op_handle->outputs_.emplace_back(&var);
Y
Yu Yang 已提交
536
            op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
Y
Yu Yang 已提交
537 538 539 540 541
          }
        }
      }
    }
  }
Y
Yu Yang 已提交
542

Y
Yu Yang 已提交
543 544 545
  /*
    Dependency graph has been constructed. However, there are still data
    harzaeds need to be handled.
Y
Yu Yang 已提交
546
   */
547
  PolishGraphToSupportDataHazards();
Y
Yu Yang 已提交
548
}
Y
Yu Yang 已提交
549

Y
Yu Yang 已提交
550 551 552 553 554 555 556
/**
 * We only handle write after read(WAR), since it should not have a write
 * after write in program. If there are write after write operators, we need
 * prune them.
 *
 * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
 */
557
void ParallelExecutor::PolishGraphToSupportDataHazards() const {
Y
Yu Yang 已提交
558 559 560 561 562 563 564 565 566 567 568
  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      if (name_pair.second.size() <= 1) {
        return;
      }
      auto it_new = name_pair.second.rbegin();
      auto it_old = name_pair.second.rbegin();
      ++it_old;
      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
        auto *write_op = it_new->second.generated_op_;
        auto &read_ops = it_old->second.pending_ops_;
Y
Yu Yang 已提交
569 570 571 572 573 574
        auto *ex_write_op = it_old->second.generated_op_;

        if (ex_write_op == nullptr) {  // Nobody write this var.
          continue;
        }

Y
Yu Yang 已提交
575 576
        for (auto *read_op : read_ops) {
          // Manually add a dependency var from read_op to write_op;
Y
Yu Yang 已提交
577 578 579 580
          if (read_op == write_op) {
            // Read Write is the same op.
            continue;
          }
Y
Yu Yang 已提交
581 582

          auto *dep_var = new DependencyVarHandle();
Y
Yu Yang 已提交
583

Y
Yu Yang 已提交
584 585 586
          dep_var->generated_op_ = read_op;
          read_op->outputs_.emplace_back(dep_var);

Y
Yu Yang 已提交
587
          dep_var->pending_ops_.emplace(write_op);
Y
Yu Yang 已提交
588 589 590 591 592 593
          write_op->inputs_.emplace_back(dep_var);
          member_->dep_vars_.emplace(dep_var);
        }
      }
    }
  }
Y
Yu Yang 已提交
594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628
}

void ParallelExecutor::GenerateVar(OpHandle *op_handle,
                                   const std::string &each_var_name,
                                   const platform::Place &place) const {
  auto &vars = member_->vars_[place][each_var_name];
  size_t version = vars.size();
  auto &var = vars[version];
  var.version_ = version;
  var.generated_op_ = op_handle;
  var.name_ = each_var_name;
  var.place_ = place;
  op_handle->outputs_.emplace_back(&var);
}

VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
                                          const platform::Place &place) const {
  auto &var_holders = member_->vars_[place];
  auto &var_holder = var_holders[each_var_name];
  VarHandle *var = nullptr;
  if (var_holder.empty()) {
    auto &init_var = var_holder[0];
    init_var.place_ = place;
    init_var.name_ = each_var_name;
    init_var.generated_op_ = nullptr;
    init_var.version_ = 0;
    var = &init_var;
  } else {
    var = &var_holder.rbegin()->second;
  }
  return var;
}

void ParallelExecutor::BCastParamsToGPUs(
    const ProgramDesc &startup_program) const {
Y
Yu Yang 已提交
629
#ifdef PADDLE_WITH_CUDA
Y
Yu Yang 已提交
630
  auto *main_scope = member_->local_scopes_[0];
Y
Yu Yang 已提交
631

Y
Yu Yang 已提交
632 633 634 635 636 637 638 639
  for (auto *var_desc : startup_program.Block(0).AllVars()) {
    if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
      auto &main_tensor =
          main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
      ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
      auto &dims = main_tensor.dims();
      size_t numel = main_tensor.numel();

Y
Stash  
Yu Yang 已提交
640
      platform::dynload::ncclGroupStart();
Y
Yu Yang 已提交
641

Y
Update  
Yu Yang 已提交
642 643 644 645 646 647
      for (size_t i = 0; i < member_->places_.size(); ++i) {
        auto place = member_->places_[i];
        void *buffer;
        if (i == 0) {
          buffer = const_cast<void *>(main_tensor.data<void>());
        } else {
Y
Yu Yang 已提交
648
          auto local_scope = member_->local_scopes_[i];
Y
Update  
Yu Yang 已提交
649 650 651 652 653
          auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
          t->Resize(dims);
          buffer = t->mutable_data(place, main_tensor.type());
        }

Y
Stash  
Yu Yang 已提交
654
        auto &nccl_ctx = member_->GetNCCLCtx(place);
Y
Update  
Yu Yang 已提交
655
        platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm,
Y
Stash  
Yu Yang 已提交
656
                                     nccl_ctx.stream());
Y
Yu Yang 已提交
657
      }
Y
Stash  
Yu Yang 已提交
658 659
      platform::dynload::ncclGroupEnd();
    }
Y
Yu Yang 已提交
660 661 662 663

    for (auto &stream : member_->communication_streams_) {
      stream.second.ctx_->Wait();
    }
Y
Stash  
Yu Yang 已提交
664
  }
Y
Yu Yang 已提交
665 666 667 668
#else
  PADDLE_THROW("Not compiled with CUDA");
#endif
}
Y
Yu Yang 已提交
669

Y
Yu Yang 已提交
670 671
void ParallelExecutor::BuildNCCLCommunicator() const {
#ifdef PADDLE_WITH_CUDA
Y
Yu Yang 已提交
672
  for (auto &place : member_->places_) {
Y
Yu Yang 已提交
673
    int dev_id = boost::get<platform::CUDAPlace>(place).device;
Y
Yu Yang 已提交
674

Y
Yu Yang 已提交
675 676
    member_->communication_streams_.emplace(
        dev_id, ParallelExecutorPrivate::NCCLContext(dev_id));
Y
Yu Yang 已提交
677
  }
Y
Yu Yang 已提交
678 679

  ParallelExecutorPrivate::NCCLContext::InitNCCLContext(
Y
Update  
Yu Yang 已提交
680
      member_->communication_streams_, member_->places_);
Y
Yu Yang 已提交
681
#endif
Y
Yu Yang 已提交
682 683
}

Y
Yu Yang 已提交
684 685
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                           const std::string &fetched_var_name) {
Y
Yu Yang 已提交
686
  bool use_event = false;
Y
Yu Yang 已提交
687
  auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
Y
Yu Yang 已提交
688
  // Version --> VarHandle
Y
Yu Yang 已提交
689
  member_->exception_.reset();
Y
Yu Yang 已提交
690
  std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
Y
Yu Yang 已提交
691
  std::unordered_map<OpHandle *, size_t> pending_ops;
Y
Yu Yang 已提交
692
  std::vector<DummyVarHandle> dummy_vars;
Y
Yu Yang 已提交
693 694 695 696

  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      for (auto &version_pair : name_pair.second) {
Y
Yu Yang 已提交
697 698
        pending_vars[&version_pair.second] =
            version_pair.second.generated_op_ == nullptr;
Y
Yu Yang 已提交
699 700 701 702
      }
    }
  }

Y
Yu Yang 已提交
703
  for (auto &var : member_->dep_vars_) {
Y
Yu Yang 已提交
704
    pending_vars[var.get()] = var->generated_op_ == nullptr;
Y
Yu Yang 已提交
705 706
  }

Y
Yu Yang 已提交
707 708
  std::vector<OpHandle *> to_run;

Y
Yu Yang 已提交
709
  for (auto &op : member_->ops_) {
Y
Yu Yang 已提交
710 711 712 713 714 715 716
    if (op->inputs_.empty()) {  // Special case, Op has no input.
      to_run.emplace_back(op.get());
    } else {
      pending_ops.insert({op.get(), op->inputs_.size()});
    }
  }

Y
Yu Yang 已提交
717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;

  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &pair : member_->vars_) {
      auto it = pair.second.find(fetch_var_name);
      if (it != pair.second.end()) {
        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
      }
    }
  }

  std::vector<FetchOpHandle> fetch_ops;

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto &vars = fetched_vars[var_name];
    fetch_ops.emplace_back();
    FetchOpHandle *op = &fetch_ops.back();
    op->data_ = fetched_data;
    op->offset_ = i;
    op->local_scopes_ = &member_->local_scopes_;
    for (auto &p : member_->places_) {
Y
Yu Yang 已提交
739
      op->dev_ctx_[p] = member_->GetNCCLCtx(p).ctx_.get();
Y
Yu Yang 已提交
740 741 742 743 744 745
    }

    for (auto *var : vars) {
      var->pending_ops_.emplace(op);
      op->inputs_.emplace_back(var);
    }
Y
Yu Yang 已提交
746 747 748 749 750 751 752

    dummy_vars.emplace_back();
    auto *var = &dummy_vars.back();
    op->outputs_.emplace_back(var);
    var->generated_op_ = op;
    pending_vars[var] = false;

Y
Yu Yang 已提交
753 754 755
    pending_ops.insert({op, op->inputs_.size()});
  }

Y
Yu Yang 已提交
756
  for (auto *op : to_run) {
Y
Yu Yang 已提交
757
    RunOp(use_event, pending_vars, op);
Y
Yu Yang 已提交
758 759
  }

Y
Yu Yang 已提交
760
  while (!pending_vars.empty()) {
Y
Yu Yang 已提交
761
    VarHandleBase *ready_var = nullptr;
Y
Yu Yang 已提交
762
    for (auto &pair : pending_vars) {
Y
Yu Yang 已提交
763
      if (pair.second.load(std::memory_order_consume)) {
Y
Yu Yang 已提交
764
        ready_var = pair.first;
Y
Yu Yang 已提交
765 766
      }
    }
Y
Yu Yang 已提交
767
    if (ready_var == nullptr) {
Y
Yu Yang 已提交
768 769 770 771
      // FIXME use conditional var instead of busy wait.
      if (member_->exception_) {
        throw * member_->exception_;
      }
Y
Yu Yang 已提交
772
      continue;
Y
Yu Yang 已提交
773
    }
Y
Yu Yang 已提交
774
    pending_vars.erase(ready_var);
Y
Yu Yang 已提交
775
    to_run.clear();
Y
Yu Yang 已提交
776 777 778 779 780
    for (auto *op : ready_var->pending_ops_) {
      auto &deps = pending_ops[op];
      --deps;
      if (deps == 0) {
        to_run.emplace_back(op);
Y
Yu Yang 已提交
781 782 783 784
      }
    }
    for (auto *op : to_run) {
      pending_ops.erase(op);
Y
Yu Yang 已提交
785
      RunOp(use_event, pending_vars, op);
Y
Yu Yang 已提交
786 787
    }
  }
Y
Yu Yang 已提交
788

Y
Yu Yang 已提交
789 790 791
  for (auto &p : member_->places_) {
    platform::DeviceContextPool::Instance().Get(p)->Wait();
  }
Y
Yu Yang 已提交
792 793 794 795

  fetch_ops.clear();
  *member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
      fetched_data->tensors_;
Y
Yu Yang 已提交
796
}
Y
Yu Yang 已提交
797

Y
Yu Yang 已提交
798
void ParallelExecutor::RunOp(
Y
Yu Yang 已提交
799
    bool use_event,
Y
Yu Yang 已提交
800
    std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
Y
Yu Yang 已提交
801
    OpHandle *op) const {
Y
Yu Yang 已提交
802 803
  std::vector<std::atomic<bool> *> *ready_buffer =
      new std::vector<std::atomic<bool> *>();
Y
Yu Yang 已提交
804
  for (auto *var : op->outputs_) {
Y
Debug  
Yu Yang 已提交
805
    ready_buffer->emplace_back(&pending_vars[var]);
Y
Yu Yang 已提交
806 807
  }

Y
Yu Yang 已提交
808
  auto op_run = [ready_buffer, op, this, use_event] {
Y
Yu Yang 已提交
809
    try {
Y
Add log  
Yu Yang 已提交
810
      VLOG(10) << op->DebugString();
Y
Yu Yang 已提交
811
      op->Run(use_event);
Y
Debug  
Yu Yang 已提交
812
      for (auto *ready : *ready_buffer) {
Y
Yu Yang 已提交
813
        ready->store(true, std::memory_order_release);
Y
Yu Yang 已提交
814
      }
Y
Debug  
Yu Yang 已提交
815
      delete ready_buffer;
Y
Yu Yang 已提交
816 817 818 819 820 821
    } catch (platform::EnforceNotMet ex) {
      member_->exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception catched";
    }
  };
Y
Yu Yang 已提交
822 823 824 825 826
  if (member_->pool_) {
    member_->pool_->enqueue(op_run);
  } else {
    op_run();
  }
Y
Yu Yang 已提交
827
}
Y
Yu Yang 已提交
828
}  // namespace framework
Y
Yang Yang 已提交
829
}  // namespace paddle