parallel_executor.cc 33.5 KB
Newer Older
Y
Yang Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/parallel_executor.h"
D
dzhwinter 已提交
16
#include <algorithm>
Q
qingqing01 已提交
17
#include <memory>
C
chengduoZH 已提交
18
#include <string>
19
#include <tuple>
Q
Qiao Longfei 已提交
20
#include <utility>
Q
qiaolongfei 已提交
21
#include <vector>
Q
Qiao Longfei 已提交
22
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
Y
yuyang18 已提交
23
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
24
#include "paddle/fluid/framework/details/multi_devices_helper.h"
25
#include "paddle/fluid/framework/details/op_handle_base.h"
Y
Yancey1989 已提交
26
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
Y
yuyang18 已提交
27
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
Y
Yu Yang 已提交
28
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
29 30
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
31
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
32
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
33
#include "paddle/fluid/platform/profiler.h"
Y
Yu Yang 已提交
34

35 36
DECLARE_bool(use_ngraph);

37 38
DECLARE_double(eager_delete_tensor_gb);

Y
Yu Yang 已提交
39
#ifdef WITH_GPERFTOOLS
Y
Yu Yang 已提交
40
#include "gperftools/profiler.h"
Y
Yu Yang 已提交
41
#endif
Y
Yu Yang 已提交
42
// Output file name for the gperftools CPU profiler. When non-empty,
// ParallelExecutorPrivate starts the profiler with this name (once per
// process, guarded by a std::once_flag).
DEFINE_string(pe_profile_fname, "",
              "Profiler filename for PE, which is generated by gperftools. "
              "Only valid when compiled with `WITH_PROFILER=ON`. Empty to "
              "disable.");
// When false, parallel graph execution mode is force-disabled.
DEFINE_bool(enable_parallel_graph, false,
            "Force disable parallel graph execution mode if set false.");
Y
Yu Yang 已提交
47

Y
Yang Yang 已提交
48
namespace paddle {
Y
Yu Yang 已提交
49 50
namespace framework {

Y
Yu Yang 已提交
51
namespace {
// Ensures the gperftools profiler is started at most once per process, no
// matter how many ParallelExecutor instances are constructed.
std::once_flag gProfileOnce;
#ifdef WITH_GPERFTOOLS
// Flipped to true right after ProfilerStart() has been called.
bool gProfileStarted{false};
#endif
}  // namespace
55

Y
Yu Yang 已提交
56 57 58
class ParallelExecutorPrivate {
 public:
  explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
Y
Yu Yang 已提交
59
      : places_(places) {
Y
Yu Yang 已提交
60
    if (!FLAGS_pe_profile_fname.empty()) {
Y
Yu Yang 已提交
61 62
      std::call_once(gProfileOnce, [] {
#ifdef WITH_GPERFTOOLS
Y
Yu Yang 已提交
63
        ProfilerStart(FLAGS_pe_profile_fname.c_str());
Y
Yu Yang 已提交
64 65 66
        gProfileStarted = true;
#else
        LOG(WARNING) << "Paddle is not compiled with gperftools. "
67
          "FLAGS_pe_profile_fname will be ignored";
Y
Yu Yang 已提交
68 69 70 71
#endif
      });
    }
  }
Y
Yu Yang 已提交
72

73 74 75 76 77 78 79 80 81 82 83
  ~ParallelExecutorPrivate() {
    if (own_local_scope_) {
      for (size_t i = 1; i < local_scopes_.size(); ++i) {
        // Skip the first scope, since it is the global scope.
        Scope *local_scope = local_scopes_[i];
        if (global_scope_->HasKid(local_scope)) {
          global_scope_->DeleteScope(local_scope);
        }
      }
    }
  }
S
sneaxiy 已提交
84

85
  ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph);
S
sneaxiy 已提交
86 87 88

  inline bool HasGarbageCollectors() const { return !gcs_.empty(); }

89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
  /**
   * NOTE(zengjinle): the feeded variables of users should not be reused,
   * because users may feed them into another network. Changing the feeded
   * variables that users can visit may cause calculation wrong, which is
   * a very subtle bug when traning networks. However, these variables
   * can be garbage collected.
   *
   * ParallelExecutor provides 2 methods to feed variables:
   *
   *  - FeedTensorsIntoLocalScopes: this method would share memory of feeded
   *                                variables, so we have to skip these.
   *
   *  - FeedAndSplitTensorIntoLocalScopes: this method would copy data of feeded
   *                                       variables, so we do not need to skip
   *                                       them.
   */
  inline void SetSkipMemoryReuse(size_t scope_idx, const std::string &name) {
    auto iter = mem_opt_var_infos_[scope_idx].find(name);
    if (iter != mem_opt_var_infos_[scope_idx].end()) {
      iter->second->SetSkipMemoryReuse(true);
    }
  }

112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  void InitNCCLCtxs(framework::Scope *scope, const BuildStrategy &bst) {
    VLOG(1) << "nccl comm num:" << bst.nccl_comm_num_ << ", nranks:" << nranks_
            << ", num_trainers:" << bst.num_trainers_
            << ", trainer_id:" << bst.trainer_id_;

    if (bst.use_hierarchical_allreduce_) {
      VLOG(1) << ", use_hierarchical_allreduce:"
              << bst.use_hierarchical_allreduce_ << ", inter_trainers_num:"
              << bst.hierarchical_allreduce_inter_nranks_
              << ", exter_trainers_num:"
              << bst.hierarchical_allreduce_exter_nranks_;
    }

    std::vector<ncclUniqueId *> flat_nccl_ids;
    if (nranks_ == 1) {
      // FIXME(gongwb): need not to create ncclid when nranks==1
129 130
      nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
                               bst.trainer_id_);
131 132 133 134 135 136 137 138 139 140 141 142
      return;
    }

    if (bst.enable_parallel_graph_) {
      VLOG(1) << "use only one ncclid in pg model";

      ncclUniqueId *nccl_id = nullptr;

      std::string var_name = platform::GetFlatNCCLVarName(0);
      auto nccl_id_var = scope->FindVar(var_name);
      if (nccl_id_var) {
        nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
143
        VLOG(10) << "find nccl_id_var:" << var_name << ", nccl_id:" << nccl_id;
144 145 146
      } else {
        nccl_id = new ncclUniqueId();
        PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
147 148
        VLOG(10) << "can't find nccl_id_var:" << var_name
                 << ", nccl_id:" << nccl_id;
149 150 151 152
      }

      flat_nccl_ids.push_back(nccl_id);

153 154
      nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
                               bst.trainer_id_);
155 156 157 158 159 160
      VLOG(1) << "init bst nccl context complete!";
      return;
    }

    // num_trainers ==1 && places > 1
    if (bst.num_trainers_ == 1) {
161 162
      nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
                               bst.trainer_id_);
163 164 165 166 167 168 169 170 171 172 173
      return;
    }

    for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
      std::string var_name = platform::GetFlatNCCLVarName(i);
      auto nccl_id_var = scope->FindVar(var_name);
      PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name);
      auto nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
      flat_nccl_ids.push_back(nccl_id);
    }

174 175
    nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
                             bst.trainer_id_);
176 177

    if (bst.use_hierarchical_allreduce_) {
G
gongweibao 已提交
178 179 180 181 182 183 184 185
      std::vector<ncclUniqueId *> inter_nccl_ids;
      for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
        std::string var_name = platform::GetHierarchicalInterNCCLVarName(i);
        auto nccl_id_var = scope->FindVar(var_name);
        PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name);
        auto inter_nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
        inter_nccl_ids.push_back(inter_nccl_id);
      }
186 187 188 189 190 191 192 193 194

      std::vector<ncclUniqueId *> exter_nccl_ids;
      for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
        std::string var_name = platform::GetHierarchicalExterNCCLVarName(i);
        auto nccl_id_var = scope->FindVar(var_name);
        PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name);
        auto nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
        exter_nccl_ids.push_back(nccl_id);
      }
G
gongweibao 已提交
195

196 197 198 199
      nccl_ctxs_->InitHierarchicalCtxs(
          places_, inter_nccl_ids, exter_nccl_ids, bst.num_trainers_,
          bst.trainer_id_, bst.hierarchical_allreduce_inter_nranks_,
          bst.hierarchical_allreduce_exter_nranks_);
200 201
    }
  }
202

203
  void InitOrGetNCCLCommunicator(framework::Scope *scope, BuildStrategy *bst) {
204 205 206 207 208 209 210 211 212 213 214
    const std::string var_name = "NCCLCommunicator";
    auto var = scope->FindVar(var_name);
    if (var != nullptr) {
      PADDLE_ENFORCE(var->IsInitialized(),
                     "if %s exists, it must be initialized", var_name);
      VLOG(1) << "find " << var_name
              << " in scope, so use it and does not recreate!";
      nccl_ctxs_ = var->GetMutable<platform::NCCLCommunicator>();
      return;
    }

215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
    if (bst->use_hierarchical_allreduce_) {
      PADDLE_ENFORCE(bst->num_trainers_ > 1, "num_trainers:%llu < 1",
                     bst->num_trainers_);
      PADDLE_ENFORCE(bst->hierarchical_allreduce_inter_nranks_ > 1,
                     "inter_nranks:%d < 1",
                     bst->hierarchical_allreduce_inter_nranks_);
      PADDLE_ENFORCE(
          (bst->num_trainers_ % bst->hierarchical_allreduce_inter_nranks_ == 0),
          "num_trainers:%llu mod inter_nranks:%d != 0", bst->num_trainers_,
          bst->hierarchical_allreduce_inter_nranks_);

      bst->hierarchical_allreduce_exter_nranks_ =
          bst->num_trainers_ / bst->hierarchical_allreduce_inter_nranks_;
    }

230 231
    VLOG(1) << "not find " << var_name << " in scope, so recreate it!";
    nccl_ctxs_ = scope->Var(var_name)->GetMutable<platform::NCCLCommunicator>();
232
    InitNCCLCtxs(scope, *bst);
233
  }
234 235
#endif

236 237 238 239 240
  inline bool IsPersistable(const std::string &name) const {
    auto iter = is_persistable_.find(name);
    return iter != is_persistable_.end() && iter->second;
  }

D
dzhwinter 已提交
241
  BuildStrategy build_strategy_;
Y
Yu Yang 已提交
242 243
  std::vector<platform::Place> places_;
  std::vector<Scope *> local_scopes_;
244
  std::vector<Scope *> local_exec_scopes_;
245
  Scope *global_scope_;  // not owned
Y
Yu Yang 已提交
246
  std::unique_ptr<details::SSAGraphExecutor> executor_;
Y
Yu Yang 已提交
247

248 249
  std::unordered_map<std::string, bool> is_persistable_;

P
peizhilin 已提交
250
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
251
  platform::NCCLCommunicator *nccl_ctxs_{nullptr};
Y
Yu Yang 已提交
252
#endif
C
chengduoZH 已提交
253 254
  bool own_local_scope_;
  bool use_cuda_;
255
  bool use_all_reduce_;
256
  size_t nranks_;
S
sneaxiy 已提交
257

258
  ir::MemOptVarInfoMapList mem_opt_var_infos_;
259
  ir::GarbageCollectorMap gcs_;
Y
Yu Yang 已提交
260 261
};

262
/**
 * Applies the memory-optimization passes selected by build_strategy_ and the
 * eager-deletion (garbage collection) threshold to `graph`, and returns the
 * resulting graph.
 *
 * Returns `graph` unchanged when FLAGS_use_ngraph is set, or when neither
 * inplace, memory_optimize nor GC is enabled. Otherwise the passes run in
 * this order:
 *  1. reference_count_pass (always);
 *  2. buffer_shared_inplace_pass when enable_inplace_;
 *  3. buffer_shared_cross_op_memory_reuse_pass when memory_optimize_;
 *  4. eager_deletion_pass when GC is enabled, after creating one
 *     GarbageCollector per place that is not yet in gcs_.
 */
ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
  if (FLAGS_use_ngraph) {
    LOG_FIRST_N(WARNING, 1)
        << "FLAGS_use_ngraph=True, memory optimization strategy is "
           "disabled in ParallelExecutor";
    return graph;
  }

  /**
   * NOTE(zengjinle): If BuildStrategy.memory_optimize = None in Python,
   * set BuildStrategy.memory_optimize according to whether gc is enabled.
   * If gc is enabled, BuildStrategy.memory_optimize = False.
   * If gc is disabled, BuildStrategy.memory_optimize = True.
   * This is because gc+memory_optimize is worse than gc only.
   *
   * As an option, users can force BuildStrategy.memory_optimize on by
   * setting True, or force it off by setting False.
   */
  bool is_gc_enabled = (GetEagerDeletionThreshold() >= 0);
  if (!build_strategy_.memory_optimize_) {
    build_strategy_.memory_optimize_ = !is_gc_enabled;
  }

  bool need_mem_opt = build_strategy_.enable_inplace_ ||
                      build_strategy_.memory_optimize_.get() || is_gc_enabled;

  if (!need_mem_opt) return graph;

  std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;

  auto ref_cnt_pass = ir::PassRegistry::Instance().Get("reference_count_pass");
  ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
  ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
  graph = ref_cnt_pass->Apply(graph);
  VLOG(10) << "ReferenceCountPass Applied";

  if (build_strategy_.enable_inplace_) {
    auto inplace_pass =
        ir::PassRegistry::Instance().Get("buffer_shared_inplace_pass");
    inplace_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
    inplace_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
    inplace_pass->SetNotOwned(ir::kUseCuda, &use_cuda_);
    VLOG(10) << "Start to apply buffer_shared_inplace_pass";
    graph = inplace_pass->Apply(graph);
    VLOG(10) << "buffer_shared_inplace_pass Applied";
    LOG_FIRST_N(INFO, 1) << "Inplace strategy is enabled, when "
                            "build_strategy.enable_inplace = True";
  }

  if (build_strategy_.memory_optimize_.get()) {
    auto cross_op_memory_reuse_pass = ir::PassRegistry::Instance().Get(
        "buffer_shared_cross_op_memory_reuse_pass");
    cross_op_memory_reuse_pass->SetNotOwned(ir::kMemOptVarInfoMapList,
                                            &mem_opt_var_infos_);
    cross_op_memory_reuse_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
                                            &last_live_ops_of_vars);
    cross_op_memory_reuse_pass->SetNotOwned(ir::kUseCuda, &use_cuda_);
    VLOG(10) << "Start to apply buffer_shared_cross_op_memory_reuse_pass";
    graph = cross_op_memory_reuse_pass->Apply(graph);
    VLOG(10) << "buffer_shared_cross_op_memory_reuse_pass Applied";
    // Logged once, like the other strategy notices in this function, so that
    // building many executors does not flood the log.
    LOG_FIRST_N(INFO, 1)
        << "Cross op memory reuse strategy is enabled, when "
           "build_strategy.memory_optimize = True or garbage collection "
           "strategy is disabled, which is not recommended";
  }

  if (!is_gc_enabled) {
    return graph;
  }
  size_t max_memory_size = static_cast<size_t>(GetEagerDeletionThreshold());

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &place = places_[i];
    if (gcs_.count(place) > 0) {
      // Keep the collector created by an earlier call.
      continue;
    }
    std::unique_ptr<GarbageCollector> gc;
    if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            boost::get<platform::CUDAPlace>(place), max_memory_size));
      } else {
        gc.reset(new StreamGarbageCollector(
            boost::get<platform::CUDAPlace>(place), max_memory_size));
      }
      VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
#else
      // GPU garbage collectors only exist in CUDA builds.
      PADDLE_THROW("Unsupported place for garbage collection");
#endif
    } else if (platform::is_cpu_place(place)) {
      gc.reset(new CPUGarbageCollector(boost::get<platform::CPUPlace>(place),
                                       max_memory_size));
      VLOG(10) << "Created GarbageCollector at " << place;
    } else {
      PADDLE_THROW("Unsupported place for garbage collection");
    }

    gcs_.emplace(place, std::move(gc));
  }

  if (!gcs_.empty()) {
    auto eager_deletion_pass =
        ir::PassRegistry::Instance().Get("eager_deletion_pass");
    eager_deletion_pass->SetNotOwned(ir::kMemOptVarInfoMapList,
                                     &mem_opt_var_infos_);
    eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_);
    eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
                                     &last_live_ops_of_vars);
    eager_deletion_pass->SetNotOwned(ir::kAllPlaces, &places_);
    graph = eager_deletion_pass->Apply(graph);
    VLOG(10) << "EagerDeletionPass Applied";
    LOG_FIRST_N(INFO, 1) << "Garbage collection strategy is enabled, when "
                         << "FLAGS_eager_delete_tensor_gb = "
                         << FLAGS_eager_delete_tensor_gb;
  }
  return graph;
}

382 383
// Number of devices (places) this executor was constructed with.
size_t ParallelExecutor::DeviceCount() const {
  const auto &device_places = member_->places_;
  return device_places.size();
}

384 385 386 387
// Mutable access to the per-device local scopes (one entry per place).
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
  auto &per_device_scopes = member_->local_scopes_;
  return per_device_scopes;
}

388 389 390 391 392 393 394 395 396 397 398 399 400 401
// Forwards to ScopeBufferedSSAGraphExecutor::DropLocalExeScopes when that is
// the active executor; any other executor type makes this a no-op.
void ParallelExecutor::DropLocalExeScopes() {
  auto *buffered_executor =
      dynamic_cast<details::ScopeBufferedSSAGraphExecutor *>(
          member_->executor_.get());
  if (buffered_executor == nullptr) {
    return;
  }
  buffered_executor->DropLocalExeScopes();
}

// Asks the ScopeBufferedSSAGraphExecutor (if that is the active executor)
// whether local execution scopes need to be created; false otherwise.
bool ParallelExecutor::NeedCreateLocalExeScope() {
  auto *buffered_executor =
      dynamic_cast<details::ScopeBufferedSSAGraphExecutor *>(
          member_->executor_.get());
  if (buffered_executor == nullptr) {
    return false;
  }
  return buffered_executor->NeedCreateLocalExeScope();
}

Y
Yan Xu 已提交
402 403 404 405 406 407 408 409
/**
 * Builds a ParallelExecutor over `places`.
 *
 * Steps, in order:
 *  - copy the strategies into member_; fall back to AllReduce when Reduce was
 *    requested but nranks_ (num_trainers_ * places.size()) is 1;
 *  - create one local scope per place. The first is `scope` itself when
 *    `local_scopes` is empty; otherwise a child of each given local scope is
 *    used;
 *  - in async mode (CPU only), clone `graph` once per extra place;
 *  - initialize NCCL communicators when running on CUDA with nranks_ > 1;
 *  - broadcast `bcast_vars` to all devices when needed;
 *  - apply the build strategy and memory-optimization passes to the graph(s);
 *  - create the SSA graph executor (async / parallel-graph / threaded /
 *    fast-threaded), wrapped in ScopeBufferedSSAGraphExecutor unless in async
 *    mode.
 */
ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
                                   const std::vector<std::string> &bcast_vars,
                                   const std::string &loss_var_name,
                                   Scope *scope,
                                   const std::vector<Scope *> &local_scopes,
                                   const ExecutionStrategy &exec_strategy,
                                   const BuildStrategy &build_strategy,
                                   ir::Graph *graph)
    : member_(new ParallelExecutorPrivate(places)) {
  member_->global_scope_ = scope;
  member_->use_cuda_ = exec_strategy.use_cuda_;
  member_->build_strategy_ = build_strategy;
  member_->use_all_reduce_ = member_->build_strategy_.reduce_ ==
                             BuildStrategy::ReduceStrategy::kAllReduce;
  member_->nranks_ = build_strategy.num_trainers_ * places.size();
  if (!member_->use_all_reduce_ && member_->nranks_ == 1) {
    LOG(INFO) << "If you set build_strategy.reduce with 'Reduce', "
                 "the number of places should be greater than 1.";
    member_->build_strategy_.reduce_ =
        BuildStrategy::ReduceStrategy::kAllReduce;
    member_->use_all_reduce_ = true;
  }
#if defined(PADDLE_WITH_CUDA) && defined(_WIN32)
  if (member_->use_cuda_) {
    PADDLE_ENFORCE(places.size() == 1, "Windows can support Single GPU only.");
  }
#endif

  LOG(INFO) << string::Sprintf(
      "The Program will be executed on %s using ParallelExecutor, %lu "
      "cards are used, so %lu programs are executed in parallel.",
      (member_->use_cuda_ ? "CUDA" : "CPU"), places.size(), places.size());

  // Step 1. Bcast the bcast_vars to devs.
  // Create local scopes
  if (local_scopes.empty()) {
    member_->own_local_scope_ = true;
    member_->local_scopes_.emplace_back(member_->global_scope_);
    for (size_t i = 1; i < member_->places_.size(); ++i) {
      member_->local_scopes_.emplace_back(&scope->NewScope());
    }
  } else {
    member_->own_local_scope_ = false;
    PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
    for (size_t i = 0; i < member_->places_.size(); ++i) {
      member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope());
    }
  }

  std::vector<ir::Graph *> graphs;
  if (member_->build_strategy_.async_mode_) {
    PADDLE_ENFORCE(!member_->use_cuda_,
                   "gpu mode does not support async_mode_ now!");
    graphs.push_back(graph);
    for (size_t i = 1; i < places.size(); ++i) {
      auto *tmp_graph = new ir::Graph(graph->OriginProgram());
      async_graphs_.emplace_back(tmp_graph);
      graphs.push_back(tmp_graph);
    }
  }

  // FIXME(Yancey1989): parallel graph mode get better performance
  // in GPU allreduce distributed training. Need an elegant way to
  // choice the execution strategy.
  member_->build_strategy_.enable_parallel_graph_ =
      EnableParallelGraphExecution(*graph, exec_strategy,
                                   member_->build_strategy_);
  if (member_->build_strategy_.enable_parallel_graph_) {
    LOG(INFO) << "The Executor would execute the graph by ParallelGraph "
                 "Execution which can get better performance, "
              << "you can force it off by env FLAGS_enable_parallel_graph=0";
  }

  if (member_->use_cuda_ && member_->nranks_ > 1) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    member_->InitOrGetNCCLCommunicator(scope, &member_->build_strategy_);

    // Initialize device context's nccl comm, will be used by normal
    // Operators like sync_batch_norm, and collective ops.
    // NOTE: with more than one ParallelExecutor on the same place, the nccl
    // comm will be overwritten and there will be some problem.
    // NOTE: NCCL group-calls and non-group-calls can not use the same
    // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
    // same communicators.
    auto *nccl_ctxs =
        member_->nccl_ctxs_->GetSyncBatchNormCtx(scope, member_->places_);
    auto &pool = platform::DeviceContextPool::Instance();
    for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
      auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
          pool.Get(member_->places_[dev_id]));
      auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]);
      dev_ctx->set_nccl_comm(nccl_ctx.comm());
    }
#endif
  }
  // broadcast parameters from the 0th device to others:
  auto need_broadcast = [&]() -> bool {
    if (member_->build_strategy_.num_trainers_ > 1) {
      // 1. num_trainers would be greater than 1 for nccl distributed training.
      return true;
    } else if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
      // 2. Only one trainer process, but ParallelExecutor holds multiple
      // devices.
      return true;
    }
    return false;
  };
  // Bcast Parameters to all GPUs
  if (need_broadcast()) {
    BCastParamsToDevices(bcast_vars, member_->build_strategy_.trainer_id_);
  }

  // Startup Program has been run. All local scopes have correct parameters.

  // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
  // ncclOp
  std::vector<ir::Graph *> async_graphs(places.size());
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  if (member_->build_strategy_.async_mode_) {
    VLOG(3) << "use local async mode";
    graph = member_->build_strategy_.Apply(
        graph, {member_->places_[0]}, loss_var_name,
        {member_->local_scopes_[0]}, 1, member_->use_cuda_,
        member_->nccl_ctxs_);
    for (size_t i = 1; i < member_->places_.size(); ++i) {
      graphs[i] = member_->build_strategy_.Apply(
          graphs[i], {member_->places_[i]}, loss_var_name,
          {member_->local_scopes_[i]}, 1, member_->use_cuda_,
          member_->nccl_ctxs_);
      async_graphs[i] = graphs[i];
    }
  } else {
    graph = member_->build_strategy_.Apply(
        graph, member_->places_, loss_var_name, member_->local_scopes_,
        member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_);
  }
#else
  if (member_->build_strategy_.async_mode_) {
    VLOG(3) << "use local async mode";
    graph = member_->build_strategy_.Apply(
        graph, {member_->places_[0]}, loss_var_name,
        {member_->local_scopes_[0]}, 1, member_->use_cuda_);
    for (size_t i = 1; i < member_->places_.size(); ++i) {
      graphs[i] = member_->build_strategy_.Apply(
          graphs[i], {member_->places_[i]}, loss_var_name,
          {member_->local_scopes_[i]}, 1, member_->use_cuda_);
      async_graphs[i] = graphs[i];
    }
  } else {
    graph = member_->build_strategy_.Apply(
        graph, member_->places_, loss_var_name, member_->local_scopes_,
        member_->nranks_, member_->use_cuda_);
  }
#endif

  graph = member_->ApplyMemoryOptimizePass(graph);

  async_graphs[0] = graph;

  // Step 3. Create vars in each scope. Passes may also create new vars.
  //         skip control vars and empty vars
  std::vector<details::VariableInfo> var_infos;
  for (auto &node : graph->Nodes()) {
    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
      var_infos.emplace_back();
      var_infos.back().name_ = node->Var()->Name();
      var_infos.back().type_ = node->Var()->GetType();
      var_infos.back().persistable_ = node->Var()->Persistable();

      member_->is_persistable_.emplace(node->Var()->Name(),
                                       node->Var()->Persistable());
    }
  }

  // Each local scope gets a child "execution" scope; scope_map records the
  // local scope -> execution scope pairing for the op handles below.
  std::unordered_map<Scope *, Scope *> scope_map;
  for (auto *local_scope : member_->local_scopes_) {
    auto &local_exec_scope = local_scope->NewScope();
    member_->local_exec_scopes_.emplace_back(&local_exec_scope);
    scope_map.emplace(local_scope, &local_exec_scope);
  }

  PADDLE_ENFORCE_EQ(member_->local_scopes_.size(),
                    member_->local_exec_scopes_.size());

  std::vector<ir::Graph *> final_graphs;

  if (member_->build_strategy_.async_mode_) {
    VLOG(3) << "use AsyncSSAGraphExecutor";
    member_->executor_.reset(new details::AsyncSSAGraphExecutor(
        exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
        member_->places_, async_graphs));
    final_graphs = async_graphs;
  } else if (member_->build_strategy_.enable_parallel_graph_) {
    VLOG(3) << "use ParallelSSAGraphExecutor";
#ifdef PADDLE_WITH_CUDA
    // TODO(Yancey1989): Remove passing in the main_program when
    // allreduce_seq_pass doesn't need it as the attr.
    auto *pg_exe = new details::ParallelSSAGraphExecutor(
        exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
        member_->places_, graph);
    final_graphs = pg_exe->Graphs();
    member_->executor_.reset(pg_exe);
#else
    PADDLE_THROW(
        "Paddle should be compiled with CUDA for ParallelGraph Execution.");
#endif
  } else {
    if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
      VLOG(3) << "use ThreadedSSAGraphExecutor";
      member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
          exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
          member_->places_, graph));
    } else {
      VLOG(3) << "use FastThreadedSSAGraphExecutor";
      member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
          exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
          member_->places_, graph));
    }
    final_graphs.emplace_back(graph);
  }

  VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
  if (!member_->build_strategy_.async_mode_) {
    member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
        exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
        std::move(var_infos), member_->places_, std::move(member_->executor_)));
  }

  for (auto *g : final_graphs) {
    auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*g);
    for (auto *op : ops) {
      op->SetLocalExecScopes(scope_map);
    }
  }
}

Y
Yancey1989 已提交
638
void ParallelExecutor::BCastParamsToDevices(
Y
Yan Xu 已提交
639
    const std::vector<std::string> &vars, int trainer_id) const {
Q
Qiao Longfei 已提交
640
  VLOG(3) << "BCastParamsToDevices";
X
Xin Pan 已提交
641
  // the initializing bcast, all vars would be bcast from device(0).
642
  for (auto &var : vars) {
X
Xin Pan 已提交
643
    framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
J
JiayiFeng 已提交
644
    if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
645 646 647 648
      continue;
    }

    auto &main_tensor = main_var->Get<LoDTensor>();
649
    if (!main_tensor.IsInitialized()) {
M
minqiyang 已提交
650
      VLOG(3) << "one in var not inited, return!";
651 652
      continue;
    }
653 654
    auto &dims = main_tensor.dims();
    if (paddle::platform::is_gpu_place(main_tensor.place())) {
P
peizhilin 已提交
655
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
656
      std::vector<void *> buffers;
C
chengduo 已提交
657
      buffers.reserve(member_->places_.size());
658 659 660 661 662
      size_t numel = main_tensor.numel();
      ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
      for (size_t i = 0; i < member_->places_.size(); ++i) {
        auto place = member_->places_[i];
        void *buffer;
663

Y
Yan Xu 已提交
664
        if (i == 0 && trainer_id == 0) {
665 666
          buffer = const_cast<void *>(main_tensor.data<void>());
        } else {
Y
Yu Yang 已提交
667
          auto local_scope = member_->local_scopes_[i];
668
          auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
Y
Update  
Yu Yang 已提交
669
          t->Resize(dims);
670
          buffer = t->mutable_data(place, main_tensor.type());
Y
Update  
Yu Yang 已提交
671
        }
672
        buffers.push_back(buffer);
673
      }
674

675 676 677
      PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(),
                        "variables' buffer size to bcast NOT equal to places");
      {
678
        auto *nccl_ctxs = member_->nccl_ctxs_->DefaultFlatCtx();
679 680
        platform::NCCLGroupGuard guard;
        for (size_t i = 0; i < member_->places_.size(); ++i) {
681
          auto &nccl_ctx = nccl_ctxs->at(member_->places_[i]);
X
Xin Pan 已提交
682 683
          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
                                       nccl_ctx.comm_, nccl_ctx.stream());
684
        }
685
        nccl_ctxs->WaitAll();
686
      }
C
chengduoZH 已提交
687
#endif
688 689
    } else {
      platform::CPUPlace cpu;
C
chengduo 已提交
690
      for (size_t i = 1; i < member_->places_.size(); ++i) {
691 692
        auto local_scope = member_->local_scopes_[i];
        auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
C
chengduo 已提交
693

Q
Qiao Longfei 已提交
694
        auto copy_memory = [&] {
695 696 697
          t->Resize(dims);
          t->mutable_data(cpu, main_tensor.type());
          paddle::framework::TensorCopy(main_tensor, cpu, t);
Q
can run  
Qiao Longfei 已提交
698 699
        };

Q
Qiao Longfei 已提交
700
        auto share_memory = [&] { t->ShareDataWith(main_tensor); };
Q
can run  
Qiao Longfei 已提交
701 702 703 704 705 706 707

        // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
        if (member_->build_strategy_.async_mode_) {
          share_memory();
        } else if (member_->use_all_reduce_ || member_->use_cuda_ ||
                   var == "@LR_DECAY_COUNTER@") {
          copy_memory();
708
        } else {
Q
can run  
Qiao Longfei 已提交
709
          share_memory();
710
        }
Y
Yu Yang 已提交
711
      }
Y
Stash  
Yu Yang 已提交
712 713
    }
  }
Y
Yu Yang 已提交
714
}
Y
Yu Yang 已提交
715

716 717
/// Executes the wrapped SSA-graph executor once.
///
/// @param fetch_tensors names of the variables whose values should be
///                      returned after the run.
/// @return whatever member_->executor_->Run(fetch_tensors) produces.
FeedFetchList ParallelExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  VLOG(3) << "enter ParallelExecutor Run";
#ifdef WITH_GPERFTOOLS
  if (gProfileStarted) {
    ProfilerFlush();
  }
#endif

  // Presumably a profiler record spanning the whole run (platform/profiler.h).
  platform::RecordBlock record_block(0);

  // Lives until the function returns; judging by its name it keeps the
  // fetch targets out of memory optimization during the run — confirm.
  ir::SkipMemOptVarsGuard skip_mem_opt_guard(&(member_->mem_opt_var_infos_),
                                             fetch_tensors,
                                             member_->HasGarbageCollectors());

  VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
  return member_->executor_->Run(fetch_tensors);
}
Y
Yu Yang 已提交
734

Y
Yu Yang 已提交
735 736 737 738 739 740 741
/// Feeds one pre-split map of tensors per device into that device's scope.
///
/// @param tensors tensors[d] maps variable names to the tensor for device d;
///                its size must equal the number of local scopes.
///
/// Persistable variables are written into local_scopes_[d]; all others are
/// marked with SetSkipMemoryReuse and written into local_exec_scopes_[d].
/// The destination tensor shares the fed tensor's data and copies its LoD.
void ParallelExecutor::FeedTensorsIntoLocalScopes(
    const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
  PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size());

  for (size_t dev_idx = 0; dev_idx < tensors.size(); ++dev_idx) {
    for (const auto &name_and_tensor : tensors[dev_idx]) {
      const auto &name = name_and_tensor.first;
      const auto &src = name_and_tensor.second;

      const bool persistable = member_->IsPersistable(name);
      if (!persistable) {
        member_->SetSkipMemoryReuse(dev_idx, name);
      }

      const auto &scopes =
          persistable ? member_->local_scopes_ : member_->local_exec_scopes_;
      auto *dst = scopes[dev_idx]->Var(name)->GetMutable<LoDTensor>();
      dst->ShareDataWith(src);
      dst->set_lod(src.lod());
    }
  }
}

void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
    const std::unordered_map<std::string, LoDTensor> &tensors) {
759
  size_t num_places = member_->places_.size();
760
  for (auto &pair : tensors) {
761 762 763 764
    bool is_persistable = member_->IsPersistable(pair.first);
    VLOG(3) << "Split " << (is_persistable ? "persistable" : "no persistable")
            << " data (" << pair.first << "), dim:" << pair.second.dims()
            << ", place: " << pair.second.place();
Y
Yu Yang 已提交
765
    auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
766 767
    bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
    if (!is_persistable && num_places != lod_tensors.size()) {
C
chengduo 已提交
768
      auto error_info = string::Sprintf(
769 770 771
          "The number(%d) of samples[%s] of current batch is less than the "
          "count(%d) of devices(%s), currently, it is not allowed. ",
          lod_tensors.size(), pair.first, num_places,
C
chengduo 已提交
772 773 774 775 776 777 778
          (is_cpu_place ? "CPU" : "GPU"));
      if (is_cpu_place) {
        error_info +=
            "You should set the environment variable CPU_NUM in the system "
            "to determine the number of devices you need.";
      }
      PADDLE_THROW(error_info);
779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804
    } else if (is_persistable) {
      if (lod_tensors.size() == 1) {
        lod_tensors.reserve(num_places);
        auto &tensor = lod_tensors.front();
        PADDLE_ENFORCE_EQ(tensor.dims(), pair.second.dims(),
                          "The dim doesn't match.");
        PADDLE_ENFORCE_EQ(tensor.place(), member_->places_.at(0),
                          "The place doesn't match.");
        for (size_t i = 1; i < num_places; ++i) {
          lod_tensors.emplace_back();
          auto &tmp = lod_tensors.back();
          framework::TensorCopy(pair.second, member_->places_.at(i), &tmp);
        }
      }
      if (lod_tensors.size() != num_places) {
        auto error_info = string::Sprintf(
            "The number(%d) of samples[%s] of the current batch does not match "
            "the count(%d) of devices(%s). Because that %s is a persistable "
            "variable, you can feed just one sample, in that case, the input "
            "sample will be copied in %d copies and be sent to different "
            "places separately. If you need that different place has different "
            "value, you should feed %d samples.",
            lod_tensors.size(), pair.first, num_places,
            (is_cpu_place ? "CPU" : "GPU"), pair.first, num_places, num_places);
        PADDLE_THROW(error_info);
      }
C
chengduo 已提交
805
    }
806

807
    for (size_t j = 0; j < num_places; ++j) {
808 809 810 811 812
      auto *feed_scope = is_persistable ? member_->local_scopes_[j]
                                        : member_->local_exec_scopes_[j];
      auto *feed_var = feed_scope->Var(pair.first);

      auto t = feed_var->GetMutable<LoDTensor>();
813 814
      t->ShareDataWith(lod_tensors[j]);
      t->set_lod(lod_tensors[j].lod());
X
Xin Pan 已提交
815 816 817 818
    }
  }
}

X
Xin Pan 已提交
819 820 821 822 823 824 825
/// Calls Wait() on the device context of every place used by this executor,
/// then frees the private implementation object.
ParallelExecutor::~ParallelExecutor() {
  for (const auto &place : member_->places_) {
    auto *device_ctx = platform::DeviceContextPool::Instance().Get(place);
    device_ctx->Wait();
  }
  delete member_;
}

826
bool ParallelExecutor::EnableParallelGraphExecution(
X
Xin Pan 已提交
827
    const ir::Graph &graph, const ExecutionStrategy &exec_strategy,
828
    const BuildStrategy &build_strategy) const {
829 830 831
  if (!FLAGS_enable_parallel_graph) {
    return false;
  }
832

Y
Yancey1989 已提交
833
  bool enable_parallel_graph = true;
834

X
Xin Pan 已提交
835 836 837 838 839 840 841 842 843 844 845 846 847
  for (ir::Node *node : graph.Nodes()) {
    if (node->IsVar() && node->Var()) {
      // TODO(Yancey1989): support sparse update in ParallelGraph mode.
      if (node->Var()->GetType() == proto::VarType::SELECTED_ROWS) {
        enable_parallel_graph = false;
        break;
      }
    } else if (node->IsOp() && node->Op()) {
      // TODO(Yancey1989): support pserver mode
      if (node->Op()->Type() == "send" || node->Op()->Type() == "recv") {
        enable_parallel_graph = false;
        break;
      }
848 849 850
    }
  }

851
  if (!member_->use_all_reduce_ || !member_->use_cuda_) {
Y
Yancey1989 已提交
852
    if (build_strategy.enable_sequential_execution_ ||
853
        exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental) {
Y
Yancey1989 已提交
854
      enable_parallel_graph = false;
855 856 857 858 859 860 861 862 863
    }
  }

#ifdef WIN32
  VLOG(1) << "Windows has no support to parallel graph, enable_parallel_graph "
             "would be forced to false.";
  enable_parallel_graph = false;
#endif

Y
Yancey1989 已提交
864
  return enable_parallel_graph;
865 866
}

Y
Yu Yang 已提交
867
}  // namespace framework
Y
Yang Yang 已提交
868
}  // namespace paddle
S
sneaxiy 已提交
869

S
sneaxiy 已提交
870
USE_PASS(reference_count_pass);
S
sneaxiy 已提交
871
USE_PASS(eager_deletion_pass);
872
USE_PASS(buffer_shared_inplace_pass);
873
USE_PASS(buffer_shared_cross_op_memory_reuse_pass);