/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/parallel_executor.h"
#include <algorithm>
#include <string>
#include <tuple>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"

#include "paddle/fluid/framework/ir/graph.h"

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif

#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"

#ifdef WITH_GPERFTOOLS
#include "gperftools/profiler.h"
#endif
DEFINE_string(pe_profile_fname, "",
              "Profiler filename for PE, which is generated by gperftools. "
              "Only valid when compiled with `WITH_PROFILER=ON`. "
              "Empty if disabled.");

namespace paddle {
namespace framework {

static std::once_flag gProfileOnce;
#ifdef WITH_GPERFTOOLS
static bool gProfileStarted = false;
#endif
class ParallelExecutorPrivate {
 public:
  explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
      : places_(places) {
    if (!FLAGS_pe_profile_fname.empty()) {
      std::call_once(gProfileOnce, [] {
#ifdef WITH_GPERFTOOLS
        ProfilerStart(FLAGS_pe_profile_fname.c_str());
        gProfileStarted = true;
#else
        LOG(WARNING) << "Paddle is not compiled with gperftools. "
                        "FLAGS_pe_profile_fname will be ignored";
#endif
      });
    }
  }

  ~ParallelExecutorPrivate() {
    if (own_local_scope_) {
      for (size_t i = 1; i < local_scopes_.size(); ++i) {
        // Skip the first scope, since it is the global scope.
        Scope *local_scope = local_scopes_[i];
        if (global_scope_->HasKid(local_scope)) {
          global_scope_->DeleteScope(local_scope);
        }
      }
    }
  }

  std::unique_ptr<ir::Graph> PrepareGCAndRefCnts(
      std::unique_ptr<ir::Graph> graph, size_t max_memory_size);

  inline bool HasGarbageCollectors() const { return !gcs_.empty(); }

  void ResetRuntimeReferenceCount(const std::vector<std::string> &fetch_tensors,
                                  const std::string &fetched_var_name) {
    for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) {
      for (auto &pair : global_ref_cnts_[i]) {
        runtime_ref_cnts_[i][pair.first] = pair.second;
      }

      for (auto &fetch_name : fetch_tensors) {
        runtime_ref_cnts_[i].erase(fetch_name);
      }
      runtime_ref_cnts_[i].erase(fetched_var_name);
    }
  }

  BuildStrategy build_strategy_;
  std::vector<platform::Place> places_;
  std::vector<Scope *> local_scopes_;
  Scope *global_scope_;  // not owned
  std::unique_ptr<details::SSAGraphExecutor> executor_;

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif
  bool own_local_scope_;
  bool use_cuda_;
  bool use_all_reduce_;
  size_t num_parallel_devices_;

  // global_ref_cnts_ is initialized only when the ParallelExecutor is
  // constructed and stays unchanged afterwards. Before each iteration,
  // runtime_ref_cnts_ is reset to global_ref_cnts_.
  std::vector<details::ReferenceCountMap> global_ref_cnts_;
  std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
  details::GarbageCollectorMap gcs_;
};

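// Attach a garbage collector to every place that does not yet have one, and
// wire the reference-count and eager-deletion passes into the graph so that
// variables can be freed as soon as they are no longer needed.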
std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
    std::unique_ptr<ir::Graph> graph, size_t max_memory_size) {
  for (size_t i = 0; i < places_.size(); ++i) {
    auto &place = places_[i];
    if (gcs_.count(place) > 0) {
      continue;
    }
    std::unique_ptr<GarbageCollector> gc;
#ifdef PADDLE_WITH_CUDA
    if (platform::is_gpu_place(place)) {
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            boost::get<platform::CUDAPlace>(place), max_memory_size));
      } else {
        gc.reset(new StreamGarbageCollector(
            boost::get<platform::CUDAPlace>(place), max_memory_size));
      }
      VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
    } else {
#endif
      if (platform::is_cpu_place(place)) {
        gc.reset(new CPUGarbageCollector(boost::get<platform::CPUPlace>(place),
                                         max_memory_size));
        VLOG(10) << "Created GarbageCollector at " << place;
      } else {
        PADDLE_THROW("Unsupported place for garbage collection");
      }
#ifdef PADDLE_WITH_CUDA
    }
#endif

    gcs_.emplace(place, std::move(gc));
  }

  if (!gcs_.empty()) {
    std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;

    auto ref_cnt_pass =
        ir::PassRegistry::Instance().Get("reference_count_pass");
    ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
                              &global_ref_cnts_);
    ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
                              &last_live_ops_of_vars);
    graph = ref_cnt_pass->Apply(std::move(graph));
    VLOG(10) << "ReferenceCountPass Applied";

    auto eager_deletion_pass =
        ir::PassRegistry::Instance().Get("eager_deletion_pass");
    eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
                                     &runtime_ref_cnts_);
    eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
    eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
                                     &last_live_ops_of_vars);
    eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
    graph = eager_deletion_pass->Apply(std::move(graph));
    VLOG(10) << "EagerDeletionPass Applied";

    if (build_strategy_.memory_early_delete_) {
      auto early_delete_pass =
          ir::PassRegistry::Instance().Get("memory_early_delete_pass");
      early_delete_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
      graph = early_delete_pass->Apply(std::move(graph));
      VLOG(10) << "MemoryEarlyDeletePass Applied.";
    }
  }

  return graph;
}

std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
  return member_->local_scopes_;
}

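// The constructor wires everything together: it creates (or adopts) one local
// scope per place, sets up NCCL contexts when running on CUDA, converts
// main_program into one or more SSA graphs via the build strategy, optionally
// prepares garbage collection, and finally composes the SSA graph executors.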
ParallelExecutor::ParallelExecutor(
    const std::vector<platform::Place> &places,
    const std::unordered_set<std::string> &bcast_vars,
    const ProgramDesc &main_program, const std::string &loss_var_name,
    Scope *scope, const std::vector<Scope *> &local_scopes,
    const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
    size_t num_trainers, size_t trainer_id)
    : member_(new ParallelExecutorPrivate(places)) {
  member_->global_scope_ = scope;
  member_->use_cuda_ = exec_strategy.use_cuda_;
  member_->build_strategy_ = build_strategy;
  member_->use_all_reduce_ =
      build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
  member_->num_parallel_devices_ = num_trainers * places.size();

  if (!member_->use_all_reduce_) {
    PADDLE_ENFORCE(places.size() > 1,
                   "If you set build_strategy.reduce with 'Reduce', "
                   "the number of places must be greater than 1.");
  }

  if (build_strategy.enable_parallel_graph_) {
    PADDLE_ENFORCE(
        member_->use_all_reduce_,
        "build_strategy.reduce should be `AllReduce` if you want to enable "
        "ParallelGraph.");
    PADDLE_ENFORCE(
        member_->use_cuda_,
        "execution_strategy.use_cuda should be True if you want to enable "
        "ParallelGraph.");
  }

  // Step 1. Bcast the bcast_vars to devs.
  // Create local scopes
  if (local_scopes.empty()) {
    member_->own_local_scope_ = true;
    member_->local_scopes_.emplace_back(member_->global_scope_);
    for (size_t i = 1; i < member_->places_.size(); ++i) {
      member_->local_scopes_.emplace_back(&scope->NewScope());
    }
  } else {
    member_->own_local_scope_ = false;
    PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
    for (size_t i = 0; i < member_->places_.size(); ++i) {
      member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope());
    }
  }

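  // When running on CUDA, set up one NCCL context per place. An nccl_id that
  // is already present in the scope (filled in by the gen_nccl_id op) is
  // reused; otherwise a fresh id is generated for the ParallelGraph case.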
  if (member_->use_cuda_) {
// Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
    ncclUniqueId *nccl_id = nullptr;
    // The nccl collective mode broadcasts the nccl id via the gen_nccl_id
    // operator.
    if (nccl_id_var != nullptr) {
      nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
    }
    if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
      if (nccl_id == nullptr) {
        nccl_id = new ncclUniqueId();
        PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
      }
    }

    member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
        member_->places_, nccl_id, num_trainers, trainer_id));
#else
    PADDLE_THROW("Not compiled with CUDA");
#endif
  }
  if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
    BCastParamsToDevices(bcast_vars);
  }
  // The startup program has been run; all local scopes now hold the correct
  // parameters.

  // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
  // ncclOp
  std::vector<std::unique_ptr<ir::Graph>> graphs;
  member_->num_parallel_devices_ = member_->places_.size() * num_trainers;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  if (build_strategy.enable_parallel_graph_) {
    for (size_t i = 0; i < member_->places_.size(); ++i) {
      std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
          main_program, {member_->places_[i]}, loss_var_name,
          {member_->local_scopes_[i]}, member_->num_parallel_devices_,
          member_->use_cuda_, member_->nccl_ctxs_.get());
      graphs.push_back(std::move(graph));
    }
  } else {
    std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
        main_program, member_->places_, loss_var_name, member_->local_scopes_,
        member_->num_parallel_devices_, member_->use_cuda_,
        member_->nccl_ctxs_.get());
    graphs.push_back(std::move(graph));
  }
#else
  std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
      main_program, member_->places_, loss_var_name, member_->local_scopes_,
      member_->num_parallel_devices_, member_->use_cuda_);
  graphs.push_back(std::move(graph));
#endif
  auto max_memory_size = GetEagerDeletionThreshold();
  // TODO(Yancey1989): fix gc failed on ParallelGraph strategy.
  if (max_memory_size >= 0 && !build_strategy.enable_parallel_graph_) {
    graphs[0] = member_->PrepareGCAndRefCnts(
        std::move(graphs[0]), static_cast<size_t>(max_memory_size));
  }

  // Step 3. Create vars in each scope. Passes may also create new vars.
  //         skip control vars and empty vars
  std::vector<details::VariableInfo> var_infos;
  for (auto &graph : graphs) {
    for (auto &node : graph->Nodes()) {
      if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
        var_infos.emplace_back();
        var_infos.back().name_ = node->Var()->Name();
        var_infos.back().type_ = node->Var()->GetType();
        var_infos.back().persistable_ = node->Var()->Persistable();
      }
    }
  }

  // If the loss_var_name is given, the number of graphs should be only one.
  if (loss_var_name.size()) {
    size_t graph_num = ir::GraphNum(*graphs[0]);
    if (graph_num > 1) {
      LOG(WARNING)
          << "The number of graphs should be only one, "
             "but the current graph has "
          << ir::GraphNum(*graphs[0])
          << " sub_graphs. If you want to see the nodes of the "
             "sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
             "to specify the output dir. NOTE: if you are not doing training, "
             "please don't pass loss_var_name.";
    }
  }

  if (build_strategy.enable_parallel_graph_) {
    member_->executor_.reset(new details::ParallelSSAGraphExecutor(
        exec_strategy, member_->local_scopes_, member_->places_,
        std::move(graphs)));
  } else {
    if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
      member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
          exec_strategy, member_->local_scopes_, member_->places_,
          std::move(graphs[0])));
    } else {
      member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
          exec_strategy, member_->local_scopes_, member_->places_,
          std::move(graphs[0])));
    }
  }

  member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
      exec_strategy, member_->local_scopes_, std::move(var_infos),
      member_->places_, std::move(member_->executor_)));
}

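// Broadcast the given parameters from device(0) to all other devices, using
// ncclBcast on GPU places and a tensor copy (or shared buffer) on CPU places.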
void ParallelExecutor::BCastParamsToDevices(
    const std::unordered_set<std::string> &vars) const {
  // In the initial broadcast, all vars are broadcast from device(0).
  for (auto &var : vars) {
    framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
    if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
      continue;
    }

    auto &main_tensor = main_var->Get<LoDTensor>();
    if (!main_tensor.IsInitialized()) {
      VLOG(3) << "one of the vars is not initialized, skip it!";
      continue;
    }
    auto &dims = main_tensor.dims();
    if (paddle::platform::is_gpu_place(main_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      std::vector<void *> buffers;
      size_t numel = main_tensor.numel();
      ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
      for (size_t i = 0; i < member_->places_.size(); ++i) {
        auto place = member_->places_[i];
        void *buffer;

        if (i == 0) {
          buffer = const_cast<void *>(main_tensor.data<void>());
        } else {
          auto local_scope = member_->local_scopes_[i];
          auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
          t->Resize(dims);
          buffer = t->mutable_data(place, main_tensor.type());
        }
        buffers.push_back(buffer);
      }

      PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(),
                        "variables' buffer size to bcast NOT equal to places");
      {
        platform::NCCLGroupGuard guard;
        for (size_t i = 0; i < member_->places_.size(); ++i) {
          auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
                                       nccl_ctx.comm_, nccl_ctx.stream());
        }
        member_->nccl_ctxs_->WaitAll();
      }
#else
      PADDLE_THROW("Not compiled with CUDA");
#endif
    } else {
      platform::CPUPlace cpu;
      for (size_t i = 0; i < member_->places_.size(); ++i) {
        if (i == 0) continue;

        auto local_scope = member_->local_scopes_[i];
        auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();

        // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
        if (member_->use_all_reduce_ || member_->use_cuda_ ||
            var == "@LR_DECAY_COUNTER@") {
          t->Resize(dims);
          t->mutable_data(cpu, main_tensor.type());
          paddle::framework::TensorCopy(main_tensor, cpu, t);
        } else {
          t->ShareDataWith(main_tensor);
        }
      }
    }
  }
}

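// Run one iteration: flush the gperftools profiler if it is active, reset the
// runtime reference counts used by the garbage collectors, execute the graph,
// and store the fetch results into `fetched_var_name` in the global scope.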
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                           const std::string &fetched_var_name) {
#ifdef WITH_GPERFTOOLS
  if (gProfileStarted) {
    ProfilerFlush();
  }
#endif

  platform::RecordBlock b(0);
  if (member_->HasGarbageCollectors()) {
    member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
  }
  auto fetch_data = member_->executor_->Run(fetch_tensors);
  *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
      fetch_data;
}

void ParallelExecutor::FeedTensorsIntoLocalScopes(
    const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
  PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i) {
    auto &map = tensors[i];
    auto *scope = member_->local_scopes_[i];
    for (auto &pair : map) {
      auto *trg = scope->Var(pair.first)->GetMutable<LoDTensor>();
      trg->ShareDataWith(pair.second);
      trg->set_lod(pair.second.lod());
    }
  }
}

void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
    const std::unordered_map<std::string, LoDTensor> &tensors) {
  for (auto pair : tensors) {
    auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
    PADDLE_ENFORCE_EQ(
        member_->places_.size(), lod_tensors.size(),
        "The number of samples of current batch is less than the count of "
        "devices, currently, it is not allowed. (%d vs %d)",
        member_->places_.size(), lod_tensors.size());
    for (size_t j = 0; j < member_->places_.size(); ++j) {
      // TODO(panxy0718): Do I need to delete this var?
      auto t =
          member_->local_scopes_[j]->Var(pair.first)->GetMutable<LoDTensor>();
      t->ShareDataWith(lod_tensors[j]);
      t->set_lod(lod_tensors[j].lod());
    }
  }
}

ParallelExecutor::~ParallelExecutor() {
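  // Wait until all pending work on every place has finished before releasing
  // the executor's internal state.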
  for (auto &p : member_->places_) {
    platform::DeviceContextPool::Instance().Get(p)->Wait();
  }
  delete member_;
}

}  // namespace framework
}  // namespace paddle

USE_PASS(memory_early_delete_pass);
USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass);