executor.cc 25.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/executor.h"
16

17
#include <memory>
18

Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/feed_fetch_method.h"
D
dongdaxiang 已提交
20 21
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
Z
Zeng Jinle 已提交
22
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
23
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
S
sneaxiy 已提交
24
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
Y
Yi Wang 已提交
25
#include "paddle/fluid/platform/place.h"
X
Xin Pan 已提交
26
#include "paddle/fluid/platform/profiler.h"
27
#include "paddle/fluid/platform/profiler/event_tracing.h"
28 29 30
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
31
#include "paddle/fluid/framework/executor_gc_helper.h"
Y
Yang Yu 已提交
32

D
dzhwinter 已提交
33
DECLARE_bool(benchmark);
34
DECLARE_bool(use_mkldnn);
Q
qijun 已提交
35 36 37

namespace paddle {
namespace framework {
X
Xin Pan 已提交
38 39 40 41 42
namespace {
// Block ids start from 0, so -1 is used as a pseudo block id that represents
// the code block wrapping block 0. It is only ever read (passed to
// platform::RecordBlock), hence a compile-time constant rather than a mutable
// global.
constexpr int kProgramId = -1;
}  // namespace
Q
qijun 已提交
43

Q
Qiao Longfei 已提交
44
// Records which program and which block of it this context refers to.
// Nothing else is initialized here; the remaining members are populated
// later (see Executor::Prepare and PrepareUnusedVars).
ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog),
      block_id_(block_id) {
}

void ExecutorPrepareContext::PrepareUnusedVars(
    const std::vector<std::string>& keep_vars, bool force_disable_gc) {
Z
Zeng Jinle 已提交
50 51 52
  // If gc is enabled and block size > 1
  if (prog_.Size() > 1) {
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
53
        prog_, block_id_, ops_);
54 55
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
        prog_, block_id_, ops_);
Z
Zeng Jinle 已提交
56
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
57
        prog_, block_id_, ops_);
Z
Zeng Jinle 已提交
58
  }
59 60 61 62 63 64

  force_disable_gc_ = force_disable_gc;
  if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
    return;
  }

S
sneaxiy 已提交
65
  unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
S
sneaxiy 已提交
66
}
Y
Yu Yang 已提交
67

Q
Qiao Longfei 已提交
68
// Only emits a verbose log line; members are released by their own
// destructors.
ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}
Y
Yu Yang 已提交
71

D
dzhwinter 已提交
72
// Creates an executor bound to `place`; every op it runs is run on place_.
Executor::Executor(const platform::Place& place)
    : place_(place) {
}
Q
qijun 已提交
73

74 75
// In MKL-DNN builds, clears the MKL-DNN cache for this executor's place;
// otherwise the destructor does nothing.
Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
  // Required for the mkl-dnn unit tests to work.
  platform::ClearMKLDNNCache(place_, this);
#endif
}

Y
Yancey1989 已提交
82
// Currently a no-op. The distributed "send complete" notification that used
// to live here is kept below, disabled, for reference.
void Executor::Close() {
  // #ifdef PADDLE_WITH_DISTRIBUTE
  //   // TODO(typhoonzero): complete message will need to use real trainer_id,
  //   // except 0.
  //   auto client =
  //       paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  //   client->SendComplete();
  // #endif
}
W
Wu Yi 已提交
91

92 93
void Executor::CreateVariables(const ProgramDesc& pdesc,
                               Scope* scope,
L
Liu Yiqun 已提交
94
                               int block_id) {
95
  VLOG(3) << "Creating Variables for block " << block_id;
L
Liu Yiqun 已提交
96
  auto& global_block = pdesc.Block(block_id);
97 98 99 100 101 102 103 104 105 106 107 108
  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }
  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
S
Steffy-zxf 已提交
109 110

        VLOG(3) << "Initialize Variable " << var->Name();
111
        InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
112
        VLOG(3) << "Create Variable " << var->Name()
S
Steffy-zxf 已提交
113 114
                << " global, which pointer is " << ptr << " type is "
                << static_cast<int>(var->GetType());
115 116
      } else {
        auto* ptr = scope->Var(var->Name());
117
        InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
118
        VLOG(3) << "Create Variable " << var->Name()
S
Steffy-zxf 已提交
119 120
                << " locally, which pointer is " << ptr << "Variable Type "
                << static_cast<int>(var->GetType());
121 122 123 124 125
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
126
      InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
127 128
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
129 130 131 132
    }
  }
}

133
std::shared_ptr<TrainerBase> Executor::InitForDataset(
134 135 136 137
    const ProgramDesc& main_program,
    const std::string& trainer_desc_str,
    Scope* scope,
    Dataset* dataset) {
138
  VLOG(3) << "Start to InitForDataset in executor";
D
dongdaxiang 已提交
139
  TrainerDesc trainer_desc;
H
hutuxian 已提交
140
  bool success = trainer_desc.ParseFromString(trainer_desc_str);
141 142
  PADDLE_ENFORCE_EQ(success,
                    true,
143 144 145
                    platform::errors::PreconditionNotMet(
                        "Fail to parse TrainerDesc from string:\n%s",
                        trainer_desc_str.c_str()));
D
dongdaxiang 已提交
146 147 148 149 150 151 152 153
  VLOG(3) << "Going to create trainer, trainer class is "
          << trainer_desc.class_name();
  std::shared_ptr<TrainerBase> trainer;
  trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  // initialize trainer
  VLOG(3) << "Going to initialize trainer";
  trainer->Initialize(trainer_desc, dataset);
  VLOG(3) << "Set root scope here";
D
dongdaxiang 已提交
154
  trainer->SetScope(scope);
D
dongdaxiang 已提交
155 156 157 158 159
  // prepare training environment and helper environment
  VLOG(3) << "Try to init train environment";
  trainer->InitTrainerEnv(main_program, place_);
  VLOG(3) << "Try to init other environment";
  trainer->InitOtherEnv(main_program);
160 161 162 163
  return trainer;
}

void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
164
  PADDLE_ENFORCE_NOT_NULL(
165 166 167
      trainer,
      platform::errors::InvalidArgument(
          "Trainer is nullptr, invoke InitForDataset first"));
D
dongdaxiang 已提交
168 169 170
  // training and finalize training
  VLOG(3) << "Trainer starts to run";
  trainer->Run();
D
Dong Daxiang 已提交
171 172 173
}

void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
D
dongdaxiang 已提交
174 175 176
  VLOG(3) << "Trainer going to finalize";
  trainer->Finalize();
}
D
dongdaxiang 已提交
177

178 179 180 181 182
void Executor::Run(const ProgramDesc& pdesc,
                   Scope* scope,
                   int block_id,
                   bool create_local_scope,
                   bool create_vars,
S
sneaxiy 已提交
183
                   const std::vector<std::string>& skip_ref_cnt_vars,
184 185
                   bool force_disable_gc,
                   bool keep_kid_scopes) {
186
  LOG_FIRST_N(INFO, 1) << "Old Executor is Running.";
187 188
  platform::RecordEvent record_run(
      "Executor::Run", platform::TracerEventType::UserDefined, 1);
X
Xin Pan 已提交
189
  platform::RecordBlock b(block_id);
190
  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
J
Jacek Czaja 已提交
191
  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
192 193
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
J
Jacek Czaja 已提交
194
  platform::RegisterModelLayout(ctx->ops_, place_);
195
#endif
196 197
  RunPreparedContext(
      ctx.get(), scope, create_local_scope, create_vars, keep_kid_scopes);
Q
qijun 已提交
198 199
}

200 201 202 203 204 205 206
// Check whether the block already has feed operators and feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
207
    const BlockDesc& block,
L
Liu Yiqun 已提交
208
    const std::map<std::string, const LoDTensor*>& feed_targets,
209 210
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
211
  for (auto* op : block.AllOps()) {
212 213
    if (op->Type() == kFeedOpType) {
      feed_count++;
L
Liu Yiqun 已提交
214
      // The input variable's name of feed_op should be feed_holder_name.
215
      PADDLE_ENFORCE_EQ(
216 217
          op->Input("X")[0],
          feed_holder_name,
218 219
          platform::errors::PreconditionNotMet(
              "Input to feed op should be '%s', but received '%s'.",
220 221
              feed_holder_name,
              op->Input("X")[0]));
222
      std::string feed_target_name = op->Output("Out")[0];
223 224
      PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name),
                        feed_targets.end(),
225 226 227 228
                        platform::errors::PreconditionNotMet(
                            "Feed operator output name '%s' cannot be found in "
                            "'feed_targets'",
                            feed_target_name));
229 230 231 232 233
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
234 235
        feed_count,
        feed_targets.size(),
236 237 238
        platform::errors::PreconditionNotMet(
            "The number of feed operators should match 'feed_targets', but "
            "received feed_count: %zu, required feed_targets.size(): %zu.",
239 240
            feed_count,
            feed_targets.size()));
241

242
    if (!feed_holder_name.empty()) {
L
Liu Yiqun 已提交
243
      // When feed operator are present, so should be feed_holder.
244
      auto var = block.FindVar(feed_holder_name);
245 246 247 248 249
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable", feed_holder_name));
      PADDLE_ENFORCE_EQ(
250 251
          var->GetType(),
          proto::VarType::FEED_MINIBATCH,
252 253 254
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FEED_MINIBATCH' type, but received "
              "'%s'.",
255 256
              feed_holder_name,
              DataTypeToString(var->GetType())));
257
    }
258 259 260 261 262 263 264 265 266 267 268 269
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
L
Liu Yiqun 已提交
270
    const BlockDesc& block,
271
    const std::map<std::string, FetchType*>& fetch_targets,
272 273
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
274
  for (auto* op : block.AllOps()) {
275 276
    if (op->Type() == kFetchOpType) {
      fetch_count++;
L
Liu Yiqun 已提交
277
      // The output variable's name of fetch_op should be fetch_holder_name.
278
      PADDLE_ENFORCE_EQ(
279 280
          op->Output("Out")[0],
          fetch_holder_name,
281 282
          platform::errors::PreconditionNotMet(
              "Output of fetch op should be '%s', but received '%s'.",
283 284
              fetch_holder_name,
              op->Output("Out")[0]));
285
      std::string fetch_target_name = op->Input("X")[0];
286 287 288 289 290 291
      PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
                        fetch_targets.end(),
                        platform::errors::NotFound(
                            "Fetch operator input name '%s' cannot be found in "
                            "'fetch_targets'.",
                            fetch_target_name));
292 293 294 295 296
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
297 298
        fetch_count,
        fetch_targets.size(),
299 300 301
        platform::errors::PreconditionNotMet(
            "The number of fetch operators should match 'fetch_targets', but "
            "received fetch_count: %zu, required fetch_targets.size(): %zu.",
302 303
            fetch_count,
            fetch_targets.size()));
304

305
    if (!fetch_holder_name.empty()) {
L
Liu Yiqun 已提交
306
      // When fetch operator are present, so should be fetch_holder.
307
      auto var = block.FindVar(fetch_holder_name);
308 309 310 311 312
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable.", fetch_holder_name));
      PADDLE_ENFORCE_EQ(
313 314
          var->GetType(),
          proto::VarType::FETCH_LIST,
315 316
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
317 318
              fetch_holder_name,
              DataTypeToString(var->GetType())));
319
    }
320 321 322 323 324
  }

  return fetch_count > 0;
}

325 326
void Executor::Run(const ProgramDesc& program,
                   Scope* scope,
327
                   std::map<std::string, const LoDTensor*>* feed_targets,
328
                   std::map<std::string, FetchType*>* fetch_targets,
329 330
                   bool create_local_scope,
                   bool create_vars,
W
Wu Yi 已提交
331
                   const std::string& feed_holder_name,
332
                   const std::string& fetch_holder_name) {
333 334
  platform::RecordEvent record_run(
      "Executor::Run", platform::TracerEventType::UserDefined, 1);
X
Xin Pan 已提交
335
  platform::RecordBlock b(kProgramId);
336
  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
337 338 339
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
340
  bool has_feed_ops =
341
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
342
  bool has_fetch_ops =
343
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);
344 345

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
S
sneaxiy 已提交
346
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
347
  if (!has_feed_ops || !has_fetch_ops) {
S
sneaxiy 已提交
348 349
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
350
  }
351 352
  auto* global_block = copy_program->MutableBlock(0);

353
  if (!has_feed_ops) {
354 355
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
356
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
357 358 359
    feed_holder->SetPersistable(true);

    int i = 0;
360
    for (auto& feed_target : (*feed_targets)) {
361
      std::string var_name = feed_target.first;
M
minqiyang 已提交
362
      VLOG(3) << "feed target's name: " << var_name;
363 364 365 366 367 368 369 370 371 372 373 374 375

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

376
  if (!has_fetch_ops) {
377 378
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
379
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
380 381 382
    fetch_holder->SetPersistable(true);

    int i = 0;
383
    for (auto& fetch_target : (*fetch_targets)) {
384
      std::string var_name = fetch_target.first;
M
minqiyang 已提交
385
      VLOG(3) << "fetch target's name: " << var_name;
386 387 388 389 390 391 392 393 394 395 396 397 398

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

399
  auto ctx = Prepare(*copy_program, 0);
400 401 402 403 404 405 406
  RunPreparedContext(ctx.get(),
                     scope,
                     feed_targets,
                     fetch_targets,
                     create_local_scope,
                     create_vars,
                     feed_holder_name,
W
Wu Yi 已提交
407
                     fetch_holder_name);
408 409
}

Q
Qiao Longfei 已提交
410
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
411 412 413 414
    const ProgramDesc& program,
    int block_id,
    const std::vector<std::string>& skip_ref_cnt_vars,
    bool force_disable_gc) {
S
sneaxiy 已提交
415 416
  std::unique_ptr<ExecutorPrepareContext> ctx(
      new ExecutorPrepareContext(program, block_id));
417 418
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id),
                    program.Size(),
419 420 421
                    platform::errors::InvalidArgument(
                        "Input block id = %d, but it should be less than "
                        "program.size() which is %d",
422 423
                        static_cast<size_t>(block_id),
                        program.Size()));
Y
Yu Yang 已提交
424 425 426 427
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
S
sneaxiy 已提交
428
  ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
Q
Qiyang Min 已提交
429
  return ctx;
Y
Yu Yang 已提交
430 431
}

T
refine  
typhoonzero 已提交
432
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
433 434
    const ProgramDesc& program,
    const std::vector<int>& block_ids,
S
sneaxiy 已提交
435 436
    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
    bool force_disable_gc) {
437
  PADDLE_ENFORCE_EQ(
S
fix bug  
sneaxiy 已提交
438
      skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
439 440 441 442
      true,
      platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
                                        "empty or equals to block number %d",
                                        block_ids.size()));
T
typhoonzero 已提交
443
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
S
fix bug  
sneaxiy 已提交
444
  size_t idx = 0;
T
typhoonzero 已提交
445
  for (auto& bid : block_ids) {
446 447
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid),
                      program.Size(),
448 449 450
                      platform::errors::InvalidArgument(
                          "Input block id = %zu, but it should be less than "
                          "program.size() which is %zu",
451 452
                          static_cast<size_t>(bid),
                          program.Size()));
S
sneaxiy 已提交
453
    auto* ctx = new ExecutorPrepareContext(program, bid);
T
typhoonzero 已提交
454 455 456 457
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
S
sneaxiy 已提交
458 459 460 461 462
    if (skip_ref_cnt_vars.empty()) {
      ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
    } else {
      ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
    }
T
typhoonzero 已提交
463
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
S
fix bug  
sneaxiy 已提交
464
    ++idx;
T
typhoonzero 已提交
465 466 467 468
  }
  return result;
}

469
void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
470 471
                                         Scope* scope,
                                         int64_t start_op_index,
472 473
                                         int64_t end_op_index,
                                         bool create_local_scope,
474 475
                                         bool create_vars,
                                         bool keep_kids) {
L
liutiexing 已提交
476
  platform::RecordEvent record_run("Executor::RunPartialPreparedContext",
477 478
                                   platform::TracerEventType::UserDefined,
                                   1);
479
  platform::RecordBlock b(kProgramId);
480 481
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
Y
Yu Yang 已提交
482 483 484 485
  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
486 487
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
L
Liu Yiqun 已提交
488
  }
Y
Yu Yang 已提交
489

S
sneaxiy 已提交
490
  int64_t max_memory_size = GetEagerDeletionThreshold();
S
sneaxiy 已提交
491
  std::unique_ptr<GarbageCollector> gc;
S
sneaxiy 已提交
492
  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
S
sneaxiy 已提交
493
    if (platform::is_gpu_place(place_)) {
494
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
S
fix bug  
sneaxiy 已提交
495
      if (IsFastEagerDeletionModeEnabled()) {
496
        gc.reset(new UnsafeFastGPUGarbageCollector(place_, max_memory_size));
S
fix bug  
sneaxiy 已提交
497
      } else {
498
        gc.reset(new DefaultStreamGarbageCollector(place_, max_memory_size));
S
fix bug  
sneaxiy 已提交
499
      }
500 501 502
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
S
sneaxiy 已提交
503
#endif
504
    } else if (platform::is_cpu_place(place_)) {
505
      gc.reset(new CPUGarbageCollector(place_, max_memory_size));
506 507
    } else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
508
      gc.reset(new XPUGarbageCollector(place_, max_memory_size));
509 510 511
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
J
jianghaicheng 已提交
512 513 514
#endif
    } else if (platform::is_ipu_place(place_)) {
#ifdef PADDLE_WITH_IPU
515
      gc.reset(new IPUGarbageCollector(place_, max_memory_size));
J
jianghaicheng 已提交
516 517 518
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No IPU gc found in CPU/IPU paddle"));
519 520 521
#endif
    } else if (platform::is_npu_place(place_)) {
#ifdef PADDLE_WITH_ASCEND_CL
522 523
      if (IsFastEagerDeletionModeEnabled()) {
        VLOG(4) << "Use unsafe fast gc for NPU.";
524
        gc.reset(new NPUUnsafeFastGarbageCollector(place_, max_memory_size));
525 526 527 528 529 530
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Please set FLAGS_fast_eager_deletion_mode=true to use "
            "GarbageCollector on NPU."));
        // TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
        VLOG(4) << "Use default stream gc for NPU.";
531
        gc.reset(new NPUDefaultStreamGarbageCollector(place_, max_memory_size));
532
      }
533
#else
534 535
      PADDLE_THROW(
          platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
F
fwenguang 已提交
536 537 538 539
#endif
    } else if (platform::is_mlu_place(place_)) {
#ifdef PADDLE_WITH_MLU
      if (IsFastEagerDeletionModeEnabled()) {
540
        gc.reset(new MLUUnsafeFastGarbageCollector(place_, max_memory_size));
F
fwenguang 已提交
541
      } else {
542
        gc.reset(new MLUDefaultStreamGarbageCollector(place_, max_memory_size));
F
fwenguang 已提交
543 544 545 546
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No MLU gc found in CPU/MLU paddle"));
547 548 549 550 551 552 553 554 555 556 557 558 559 560
#endif
    } else if (platform::is_custom_place(place_)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      if (IsFastEagerDeletionModeEnabled()) {
        VLOG(4) << "Use unsafe fast gc for " << place_ << ".";
        gc.reset(new CustomDeviceUnsafeFastGarbageCollector(place_,
                                                            max_memory_size));
      } else {
        VLOG(4) << "Use default stream gc for " << place_ << ".";
        gc.reset(
            new CustomDefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(platform::errors::Unimplemented("No CustomDevice gc found"));
S
sneaxiy 已提交
561
#endif
562
    }
S
sneaxiy 已提交
563 564
  }

565 566
  for (int64_t i = start_op_index; i < end_op_index; ++i) {
    auto& op = ctx->ops_[i];
567
    op->Run(*local_scope, place_);
S
fix bug  
sneaxiy 已提交
568
    if (gc) {
569 570
      platform::RecordEvent record(
          "CheckGC", platform::TracerEventType::UserDefined, 10);
S
sneaxiy 已提交
571
      DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
S
sneaxiy 已提交
572
    }
Y
Yu Yang 已提交
573
  }
S
sneaxiy 已提交
574

L
Leo Chen 已提交
575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595
  auto callback = [scope, local_scope, keep_kids]() {
    if (local_scope != scope) {
      VLOG(4) << "Delete scope: " << local_scope;
      scope->DeleteScope(local_scope);
    } else {
      if (!keep_kids) {
        VLOG(4) << "Drop kids: " << scope;
        // By default, we should delete all kid scopes after run executor
        // because
        // some operators may create local scope when running, such as while_op.
        // But when while_op also create a local executor to run it's sub block,
        // the sub scopes it created should not be dropped immediately, because
        // while_grad_op will use some variables created during while_op run, so
        // we need to keep the kids and wait for the outer executor to drop
        // them.

        scope->DropKids();
      }
      VLOG(4) << "Keep kids: " << scope;
    }
  };
S
sneaxiy 已提交
596

L
Leo Chen 已提交
597 598 599
  if (gc) {
    VLOG(4) << "Async deleting scope";
    gc->DirectClearCallback(callback);
600
  } else {
L
Leo Chen 已提交
601 602 603
    VLOG(4) << "Sync deleting scope";
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
    callback();
Y
Yu Yang 已提交
604 605 606
  }
}

607 608 609 610
// Runs every operator of the prepared context: delegates to
// RunPartialPreparedContext with the op range [0, ctx->ops_.size()).
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx,
                                  Scope* scope,
                                  bool create_local_scope,
                                  bool create_vars,
                                  bool keep_kids) {
  const int64_t first_op = 0;
  const int64_t past_last_op = static_cast<int64_t>(ctx->ops_.size());
  RunPartialPreparedContext(ctx,
                            scope,
                            first_op,
                            past_last_op,
                            create_local_scope,
                            create_vars,
                            keep_kids);
}

623
void Executor::RunPreparedContext(
624 625
    ExecutorPrepareContext* ctx,
    Scope* scope,
626
    std::map<std::string, const LoDTensor*>* feed_targets,
627 628 629 630
    std::map<std::string, FetchType*>* fetch_targets,
    bool create_local_scope,
    bool create_vars,
    const std::string& feed_holder_name,
W
Wu Yi 已提交
631
    const std::string& fetch_holder_name) {
632 633
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

634
  PADDLE_ENFORCE_EQ(
635 636
      has_feed_operators(global_block, *feed_targets, feed_holder_name),
      true,
637 638 639
      platform::errors::PreconditionNotMet(
          "Program in ExecutorPrepareContext should has feed_ops."));
  PADDLE_ENFORCE_EQ(
640
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
641 642 643
      true,
      platform::errors::PreconditionNotMet(
          "Program in the prepared context should has fetch_ops."));
644

645 646 647 648
  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
R
Ruibiao Chen 已提交
649
      int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
650 651
      SetFeedVariable(
          scope, *(*feed_targets)[feed_target_name], feed_holder_name, idx);
652 653 654
    }
  }

W
Wu Yi 已提交
655
  RunPreparedContext(ctx, scope, create_local_scope, create_vars);
656 657 658 659 660

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
R
Ruibiao Chen 已提交
661
      int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
662
      *(*fetch_targets)[fetch_target_name] =
663 664 665 666 667
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

668 669
// In MKL-DNN builds, sets the "use_mkldnn" attribute to true on every op (in
// every block of `program`) that has such an attribute. Note that this
// modifies the program although it is passed as const. In other builds only
// a warning is logged.
void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  VLOG(3) << "use_mkldnn=True";
  auto& mutable_program = const_cast<ProgramDesc&>(program);
  for (size_t block_idx = 0; block_idx < program.Size(); ++block_idx) {
    auto* block = mutable_program.MutableBlock(block_idx);
    for (auto* op : block->AllOps()) {
      if (!op->HasAttr("use_mkldnn")) {
        continue;
      }
      op->SetAttr("use_mkldnn", true);
    }
  }
#else
  LOG(WARNING)
      << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
Q
qijun 已提交
684 685
}  // namespace framework
}  // namespace paddle