/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/executor.h"

#include <memory>

#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/executor_gc_helper.h"

DECLARE_bool(benchmark);
DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace {
// Block ids start from 0. kProgramId (-1) is a sentinel used to represent
// the code block wrapping the outermost block 0.
int kProgramId = -1;
}  // namespace

ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog), block_id_(block_id) {}

void ExecutorPrepareContext::PrepareUnusedVars(
    const std::vector<std::string>& keep_vars, bool force_disable_gc) {
  // When the program has more than one block, prepare the control-flow ops
  // (conditional_block, while, recurrent) for safe eager deletion, since
  // their sub-blocks may still reference variables owned by this block.
  if (prog_.Size() > 1) {
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
        prog_, block_id_, ops_);
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
        prog_, block_id_, ops_);
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
        prog_, block_id_, ops_);
  }

  force_disable_gc_ = force_disable_gc;
  if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
    return;
  }

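  // GetUnusedVars records, for every op, the variables whose last use is that
  // op, so the garbage collector can free them right after the op runs.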
  unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
}

ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}

Executor::Executor(const platform::Place& place) : place_(place) {}

Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
  // Clear the MKL-DNN cache; this is needed for MKL-DNN unit tests to work.
  platform::ClearMKLDNNCache(place_, this);
#endif
}

void Executor::Close() {
  // #ifdef PADDLE_WITH_DISTRIBUTE
  //   // TODO(typhoonzero): complete message will need to use real trainer_id,
  //   // except 0.
  //   auto client =
  //       paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  //   client->SendComplete();
  // #endif
}

void Executor::CreateVariables(const ProgramDesc& pdesc,
                               Scope* scope,
                               int block_id) {
  VLOG(3) << "Creating Variables for block " << block_id;
  auto& global_block = pdesc.Block(block_id);
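  // Persistable variables are created in the outermost (ancestor) scope so
  // they outlive a single run; everything else lives in the current scope.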
  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }
  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());

        VLOG(3) << "Initialize Variable " << var->Name();
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " global, which pointer is " << ptr << " type is "
                << static_cast<int>(var->GetType());
      } else {
        auto* ptr = scope->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " locally, which pointer is " << ptr << "Variable Type "
                << static_cast<int>(var->GetType());
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
    }
  }
}

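// Typical dataset-training flow (a sketch only; `exe`, `program`, `desc_str`,
// `scope` and `dataset` are illustrative names assumed to exist):
//   auto trainer = exe.InitForDataset(program, desc_str, &scope, dataset);
//   exe.RunFromDataset(trainer);
//   exe.ReleaseTrainer(trainer);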
std::shared_ptr<TrainerBase> Executor::InitForDataset(
    const ProgramDesc& main_program,
    const std::string& trainer_desc_str,
    Scope* scope,
    Dataset* dataset) {
  VLOG(3) << "Start to InitForDataset in executor";
  TrainerDesc trainer_desc;
  bool success = trainer_desc.ParseFromString(trainer_desc_str);
  PADDLE_ENFORCE_EQ(success,
                    true,
                    platform::errors::PreconditionNotMet(
                        "Fail to parse TrainerDesc from string:\n%s",
                        trainer_desc_str.c_str()));
  VLOG(3) << "Going to create trainer, trainer class is "
          << trainer_desc.class_name();
  std::shared_ptr<TrainerBase> trainer;
  trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  // initialize trainer
  VLOG(3) << "Going to initialize trainer";
  trainer->Initialize(trainer_desc, dataset);
  VLOG(3) << "Set root scope here";
  trainer->SetScope(scope);
  // prepare training environment and helper environment
  VLOG(3) << "Try to init train environment";
  trainer->InitTrainerEnv(main_program, place_);
  VLOG(3) << "Try to init other environment";
  trainer->InitOtherEnv(main_program);
  return trainer;
}

void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
  PADDLE_ENFORCE_NOT_NULL(
      trainer,
      platform::errors::InvalidArgument(
          "Trainer is nullptr, invoke InitForDataset first"));
  // run training; finalization happens later in ReleaseTrainer
  VLOG(3) << "Trainer starts to run";
  trainer->Run();
}

void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
  VLOG(3) << "Trainer going to finalize";
  trainer->Finalize();
}

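// Example usage (a minimal sketch; `program` and `scope` are assumed to
// exist, and the trailing parameters are assumed to take the defaults
// declared in executor.h):
//   Executor exe(platform::CPUPlace());
//   exe.Run(program, &scope, /*block_id=*/0);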
void Executor::Run(const ProgramDesc& pdesc,
                   Scope* scope,
                   int block_id,
                   bool create_local_scope,
                   bool create_vars,
                   const std::vector<std::string>& skip_ref_cnt_vars,
                   bool force_disable_gc,
                   bool keep_kid_scopes) {
  platform::RecordEvent record_run(
      "Executor::Run", platform::TracerEventType::UserDefined, 1);
  platform::RecordBlock b(block_id);
  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
  platform::RegisterModelLayout(ctx->ops_, place_);
#endif
  RunPreparedContext(
      ctx.get(), scope, create_local_scope, create_vars, keep_kid_scopes);
}

// Check whether the block already has feed operators and feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
    const BlockDesc& block,
    const std::map<std::string, const LoDTensor*>& feed_targets,
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      feed_count++;
      // The feed_op's input variable must be named feed_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Input("X")[0],
          feed_holder_name,
          platform::errors::PreconditionNotMet(
              "Input to feed op should be '%s', but received '%s'.",
              feed_holder_name,
              op->Input("X")[0]));
      std::string feed_target_name = op->Output("Out")[0];
      PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name),
                        feed_targets.end(),
                        platform::errors::PreconditionNotMet(
                            "Feed operator output name '%s' cannot be found in "
                            "'feed_targets'",
                            feed_target_name));
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count,
        feed_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of feed operators should match 'feed_targets', but "
            "received feed_count: %zu, required feed_targets.size(): %zu.",
            feed_count,
            feed_targets.size()));

    if (!feed_holder_name.empty()) {
      // When feed operators are present, the feed_holder must be present too.
      auto var = block.FindVar(feed_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable", feed_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(),
          proto::VarType::FEED_MINIBATCH,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FEED_MINIBATCH' type, but received "
              "'%s'.",
              feed_holder_name,
              DataTypeToString(var->GetType())));
    }
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
    const BlockDesc& block,
    const std::map<std::string, FetchType*>& fetch_targets,
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      fetch_count++;
      // The fetch_op's output variable must be named fetch_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Output("Out")[0],
          fetch_holder_name,
          platform::errors::PreconditionNotMet(
              "Output of fetch op should be '%s', but received '%s'.",
              fetch_holder_name,
              op->Output("Out")[0]));
      std::string fetch_target_name = op->Input("X")[0];
      PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
                        fetch_targets.end(),
                        platform::errors::NotFound(
                            "Fetch operator input name '%s' cannot be found in "
                            "'fetch_targets'.",
                            fetch_target_name));
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count,
        fetch_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of fetch operators should match 'fetch_targets', but "
            "received fetch_count: %zu, required fetch_targets.size(): %zu.",
            fetch_count,
            fetch_targets.size()));

    if (!fetch_holder_name.empty()) {
      // When fetch operators are present, the fetch_holder must be present
      // too.
      auto var = block.FindVar(fetch_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable.", fetch_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(),
          proto::VarType::FETCH_LIST,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
              fetch_holder_name,
              DataTypeToString(var->GetType())));
    }
  }

  return fetch_count > 0;
}

void Executor::Run(const ProgramDesc& program,
                   Scope* scope,
                   std::map<std::string, const LoDTensor*>* feed_targets,
                   std::map<std::string, FetchType*>* fetch_targets,
                   bool create_local_scope,
                   bool create_vars,
                   const std::string& feed_holder_name,
                   const std::string& fetch_holder_name) {
  platform::RecordEvent record_run(
      "Executor::Run", platform::TracerEventType::UserDefined, 1);
  platform::RecordBlock b(kProgramId);
  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
  bool has_feed_ops =
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
  bool has_fetch_ops =
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
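  // If the program lacks feed or fetch ops, operate on a copy so the caller's
  // program is left unmodified.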
  if (!has_feed_ops || !has_fetch_ops) {
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
  }
  auto* global_block = copy_program->MutableBlock(0);

  if (!has_feed_ops) {
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
    feed_holder->SetPersistable(true);

    int i = 0;
    for (auto& feed_target : (*feed_targets)) {
      std::string var_name = feed_target.first;
      VLOG(3) << "feed target's name: " << var_name;

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  if (!has_fetch_ops) {
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
    fetch_holder->SetPersistable(true);

    int i = 0;
    for (auto& fetch_target : (*fetch_targets)) {
      std::string var_name = fetch_target.first;
      VLOG(3) << "fetch target's name: " << var_name;

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  auto ctx = Prepare(*copy_program, 0);
  RunPreparedContext(ctx.get(),
                     scope,
                     feed_targets,
                     fetch_targets,
                     create_local_scope,
                     create_vars,
                     feed_holder_name,
                     fetch_holder_name);
}

std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
    const ProgramDesc& program,
    int block_id,
    const std::vector<std::string>& skip_ref_cnt_vars,
    bool force_disable_gc) {
  std::unique_ptr<ExecutorPrepareContext> ctx(
      new ExecutorPrepareContext(program, block_id));
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id),
                    program.Size(),
                    platform::errors::InvalidArgument(
                        "Input block id = %d, but it should be less than "
                        "program.size() which is %d",
                        static_cast<size_t>(block_id),
                        program.Size()));
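  // Instantiate every op in the block once; the returned context caches the
  // created operators so it can be reused across runs.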
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
  ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
  return ctx;
}

std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
    const ProgramDesc& program,
    const std::vector<int>& block_ids,
    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
    bool force_disable_gc) {
  PADDLE_ENFORCE_EQ(
      skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
      true,
      platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
                                        "empty or equals to block number %d",
                                        block_ids.size()));
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
  size_t idx = 0;
  for (auto& bid : block_ids) {
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid),
                      program.Size(),
                      platform::errors::InvalidArgument(
                          "Input block id = %zu, but it should be less than "
                          "program.size() which is %zu",
                          static_cast<size_t>(bid),
                          program.Size()));
    auto* ctx = new ExecutorPrepareContext(program, bid);
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
    if (skip_ref_cnt_vars.empty()) {
      ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
    } else {
      ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
    }
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
    ++idx;
  }
  return result;
}

void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
                                         Scope* scope,
                                         int64_t start_op_index,
                                         int64_t end_op_index,
                                         bool create_local_scope,
                                         bool create_vars,
                                         bool keep_kids) {
  platform::RecordEvent record_run("Executor::RunPartialPreparedContext",
                                   platform::TracerEventType::UserDefined,
                                   1);
  platform::RecordBlock b(kProgramId);
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

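  // Eager deletion: choose a garbage collector that matches the place so
  // tensors can be freed as soon as their last user op has run.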
  int64_t max_memory_size = GetEagerDeletionThreshold();
  std::unique_ptr<GarbageCollector> gc;
  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
    if (platform::is_gpu_place(place_)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(place_, max_memory_size));
      } else {
        gc.reset(new DefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
#endif
    } else if (platform::is_cpu_place(place_)) {
      gc.reset(new CPUGarbageCollector(place_, max_memory_size));
    } else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
      gc.reset(new XPUGarbageCollector(place_, max_memory_size));
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
#endif
    } else if (platform::is_ipu_place(place_)) {
#ifdef PADDLE_WITH_IPU
      gc.reset(new IPUGarbageCollector(place_, max_memory_size));
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No IPU gc found in CPU/IPU paddle"));
#endif
    } else if (platform::is_npu_place(place_)) {
#ifdef PADDLE_WITH_ASCEND_CL
      if (IsFastEagerDeletionModeEnabled()) {
        VLOG(4) << "Use unsafe fast gc for NPU.";
        gc.reset(new NPUUnsafeFastGarbageCollector(place_, max_memory_size));
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Please set FLAGS_fast_eager_deletion_mode=true to use "
            "GarbageCollector on NPU."));
        // TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
        VLOG(4) << "Use default stream gc for NPU.";
        gc.reset(new NPUDefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
#endif
    } else if (platform::is_mlu_place(place_)) {
#ifdef PADDLE_WITH_MLU
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new MLUUnsafeFastGarbageCollector(place_, max_memory_size));
      } else {
        gc.reset(new MLUDefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No MLU gc found in CPU/MLU paddle"));
#endif
    } else if (platform::is_custom_place(place_)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      if (IsFastEagerDeletionModeEnabled()) {
        VLOG(4) << "Use unsafe fast gc for " << place_ << ".";
        gc.reset(new CustomDeviceUnsafeFastGarbageCollector(place_,
                                                            max_memory_size));
      } else {
        VLOG(4) << "Use default stream gc for " << place_ << ".";
        gc.reset(
            new CustomDefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(platform::errors::Unimplemented("No CustomDevice gc found"));
#endif
    }
  }

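  // Run the block's ops in order; after each op, delete tensors whose last
  // use was that op (when a garbage collector is enabled).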
  for (int64_t i = start_op_index; i < end_op_index; ++i) {
    auto& op = ctx->ops_[i];
    op->Run(*local_scope, place_);
    if (gc) {
      platform::RecordEvent record(
          "CheckGC", platform::TracerEventType::UserDefined, 10);
      DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
    }
  }

  auto callback = [scope, local_scope, keep_kids]() {
    if (local_scope != scope) {
      VLOG(4) << "Delete scope: " << local_scope;
      scope->DeleteScope(local_scope);
    } else {
      if (!keep_kids) {
        VLOG(4) << "Drop kids: " << scope;
        // By default, we delete all kid scopes after the executor runs,
        // because some operators may create local scopes while running,
        // such as while_op. But when while_op also creates a local executor
        // to run its sub-block, the sub scopes it created should not be
        // dropped immediately, because while_grad_op will use some variables
        // created during the while_op run; so we keep the kids and wait for
        // the outer executor to drop them.

        scope->DropKids();
      }
      VLOG(4) << "Keep kids: " << scope;
    }
  };

  if (gc) {
    VLOG(4) << "Async deleting scope";
    gc->DirectClearCallback(callback);
  } else {
    VLOG(4) << "Sync deleting scope";
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
    callback();
  }
}

void Executor::RunPreparedContext(ExecutorPrepareContext* ctx,
                                  Scope* scope,
                                  bool create_local_scope,
                                  bool create_vars,
                                  bool keep_kids) {
  int64_t start_op_index = 0;
  int64_t end_op_index = ctx->ops_.size();
  RunPartialPreparedContext(ctx,
                            scope,
                            start_op_index,
                            end_op_index,
                            create_local_scope,
                            create_vars,
                            keep_kids);
}

void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx,
    Scope* scope,
    std::map<std::string, const LoDTensor*>* feed_targets,
    std::map<std::string, FetchType*>* fetch_targets,
    bool create_local_scope,
    bool create_vars,
    const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

  PADDLE_ENFORCE_EQ(
      has_feed_operators(global_block, *feed_targets, feed_holder_name),
      true,
      platform::errors::PreconditionNotMet(
          "Program in ExecutorPrepareContext should has feed_ops."));
  PADDLE_ENFORCE_EQ(
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
      true,
      platform::errors::PreconditionNotMet(
          "Program in the prepared context should has fetch_ops."));

  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
      int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
      SetFeedVariable(
          scope, *(*feed_targets)[feed_target_name], feed_holder_name, idx);
    }
  }

  RunPreparedContext(ctx, scope, create_local_scope, create_vars);

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
      int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
      *(*fetch_targets)[fetch_target_name] =
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  VLOG(3) << "use_mkldnn=True";
  for (size_t bid = 0; bid < program.Size(); ++bid) {
    auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
    for (auto* op : block->AllOps()) {
      if (op->HasAttr("use_mkldnn")) {
        op->SetAttr("use_mkldnn", true);
      }
    }
  }
#else
  LOG(WARNING) << "'MKLDNN' is not supported; please re-compile with the "
                  "WITH_MKLDNN option";
#endif
}
}  // namespace framework
}  // namespace paddle