/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/executor.h"
#include <memory>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/executor_gc_helper.h"

DECLARE_bool(benchmark);
DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace {
// Block ids start from 0. kProgramId (-1) is used to represent the pseudo
// code block that wraps the whole program (block 0).
int kProgramId = -1;
}  // namespace

ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog), block_id_(block_id) {}

void ExecutorPrepareContext::PrepareUnusedVars(
    const std::vector<std::string>& keep_vars, bool force_disable_gc) {
  // If the program has more than one block, prepare the control-flow ops
  // (conditional_block, while, recurrent) and their grad ops for safe eager
  // deletion.
  if (prog_.Size() > 1) {
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
        prog_, block_id_, ops_);
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
                                                               ops_);
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
        prog_, block_id_, ops_);
  }

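  // When eager deletion is disabled (negative threshold) or explicitly forced
  // off, there is no need to compute the unused-variable map below.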
  force_disable_gc_ = force_disable_gc;
  if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
    return;
  }

  unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
}

ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}

Executor::Executor(const platform::Place& place) : place_(place) {}

Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
  // Clear the MKL-DNN cache; this is needed for the MKL-DNN unit tests to
  // work correctly.
  ClearMKLDNNCache(place_, this);
#endif
}

void Executor::Close() {
  // #ifdef PADDLE_WITH_DISTRIBUTE
  //   // TODO(typhoonzero): complete message will need to use real trainer_id,
  //   // except 0.
  //   auto client =
  //       paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  //   client->SendComplete();
  // #endif
}

void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                               int block_id) {
  VLOG(3) << "Creating Variables for block " << block_id;
  auto& global_block = pdesc.Block(block_id);
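  // Persistable variables are created in the outermost (root) ancestor scope
  // so that they outlive local scopes; all other variables are created in the
  // current scope.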
  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }
  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " global, which pointer is " << ptr;
      } else {
        auto* ptr = scope->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " locally, which pointer is " << ptr;
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
    }
  }
}

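// Dataset-based training is driven by three calls: InitForDataset() parses a
// serialized TrainerDesc and builds an initialized trainer, RunFromDataset()
// runs it, and ReleaseTrainer() finalizes it.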
std::shared_ptr<TrainerBase> Executor::InitForDataset(
    const ProgramDesc& main_program, const std::string& trainer_desc_str,
    Scope* scope, Dataset* dataset) {
  VLOG(3) << "Start to RunFromDataset in executor";
  TrainerDesc trainer_desc;
  bool success = trainer_desc.ParseFromString(trainer_desc_str);
  PADDLE_ENFORCE_EQ(success, true,
                    platform::errors::PreconditionNotMet(
                        "Fail to parse TrainerDesc from string:\n%s",
                        trainer_desc_str.c_str()));
  VLOG(3) << "Going to create trainer, trainer class is "
          << trainer_desc.class_name();
  std::shared_ptr<TrainerBase> trainer;
  trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  // initialize trainer
  VLOG(3) << "Going to initialize trainer";
  trainer->Initialize(trainer_desc, dataset);
  VLOG(3) << "Set root scope here";
  trainer->SetScope(scope);
  // prepare training environment and helper environment
  VLOG(3) << "Try to init train environment";
  trainer->InitTrainerEnv(main_program, place_);
  VLOG(3) << "Try to init other environment";
  trainer->InitOtherEnv(main_program);
  return trainer;
}

void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
  PADDLE_ENFORCE_NOT_NULL(
      trainer, platform::errors::InvalidArgument(
                   "Trainer is nullptr, invoke InitForDataset first"));
  // training and finalize training
  VLOG(3) << "Trainer starts to run";
  trainer->Run();
}

void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
  VLOG(3) << "Trainer going to finalize";
  trainer->Finalize();
}

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars,
                   const std::vector<std::string>& skip_ref_cnt_vars,
                   bool force_disable_gc, bool keep_kid_scopes) {
  platform::RecordBlock b(block_id);
  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                     keep_kid_scopes);
}

// Check whether the block already has feed operators and feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
    const BlockDesc& block,
    const std::map<std::string, const LoDTensor*>& feed_targets,
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      feed_count++;
      // The name of the feed op's input variable should be feed_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Input("X")[0], feed_holder_name,
          platform::errors::PreconditionNotMet(
              "Input to feed op should be '%s', but received '%s'.",
              feed_holder_name, op->Input("X")[0]));
      std::string feed_target_name = op->Output("Out")[0];
      PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name), feed_targets.end(),
                        platform::errors::PreconditionNotMet(
                            "Feed operator output name '%s' cannot be found in "
                            "'feed_targets'",
                            feed_target_name));
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count, feed_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of feed operators should match 'feed_targets', but "
            "received feed_count: %zu, required feed_targets.size(): %zu.",
            feed_count, feed_targets.size()));

    if (!feed_holder_name.empty()) {
      // When feed operators are present, the feed_holder variable should be
      // present as well.
      auto var = block.FindVar(feed_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable", feed_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FEED_MINIBATCH,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FEED_MINIBATCH' type, but received "
              "'%s'.",
              feed_holder_name, DataTypeToString(var->GetType())));
    }
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
    const BlockDesc& block,
    const std::map<std::string, FetchType*>& fetch_targets,
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      fetch_count++;
      // The name of the fetch op's output variable should be
      // fetch_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Output("Out")[0], fetch_holder_name,
          platform::errors::PreconditionNotMet(
              "Output of fetch op should be '%s', but received '%s'.",
              fetch_holder_name, op->Output("Out")[0]));
      std::string fetch_target_name = op->Input("X")[0];
      PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
                        fetch_targets.end(),
                        platform::errors::NotFound(
                            "Fetch operator input name '%s' cannot be found in "
                            "'fetch_targets'.",
                            fetch_target_name));
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count, fetch_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of fetch operators should match 'fetch_targets', but "
            "received fetch_count: %zu, required fetch_targets.size(): %zu.",
            fetch_count, fetch_targets.size()));

    if (!fetch_holder_name.empty()) {
      // When fetch operators are present, the fetch_holder variable should be
      // present as well.
      auto var = block.FindVar(fetch_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable.", fetch_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FETCH_LIST,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
              fetch_holder_name, DataTypeToString(var->GetType())));
    }
  }

  return fetch_count > 0;
}

void Executor::Run(const ProgramDesc& program, Scope* scope,
                   std::map<std::string, const LoDTensor*>* feed_targets,
                   std::map<std::string, FetchType*>* fetch_targets,
                   bool create_local_scope, bool create_vars,
                   const std::string& feed_holder_name,
                   const std::string& fetch_holder_name) {
  platform::RecordBlock b(kProgramId);
  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
  bool has_feed_ops =
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
  bool has_fetch_ops =
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);

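  // If the program lacks feed or fetch ops they must be added, but the
  // caller's program is const, so the ops are inserted into a copy instead.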
  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
  if (!has_feed_ops || !has_fetch_ops) {
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
  }
  auto* global_block = copy_program->MutableBlock(0);

  if (!has_feed_ops) {
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
    feed_holder->SetPersistable(true);

    int i = 0;
    for (auto& feed_target : (*feed_targets)) {
      std::string var_name = feed_target.first;
      VLOG(3) << "feed target's name: " << var_name;

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  if (!has_fetch_ops) {
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
    fetch_holder->SetPersistable(true);

    int i = 0;
    for (auto& fetch_target : (*fetch_targets)) {
      std::string var_name = fetch_target.first;
      VLOG(3) << "fetch target's name: " << var_name;

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  auto ctx = Prepare(*copy_program, 0);
  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
                     create_local_scope, create_vars, feed_holder_name,
                     fetch_holder_name);
}

std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
    const ProgramDesc& program, int block_id,
    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
  std::unique_ptr<ExecutorPrepareContext> ctx(
      new ExecutorPrepareContext(program, block_id));
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size(),
                    platform::errors::InvalidArgument(
                        "Input block id = %d, but it should be less than "
                        "program.size() which is %d",
                        static_cast<size_t>(block_id), program.Size()));
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
  ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
  return ctx;
}

std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
    const ProgramDesc& program, const std::vector<int>& block_ids,
    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
    bool force_disable_gc) {
  PADDLE_ENFORCE_EQ(
      skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
      true,
      platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
                                        "empty or equal to block number %d",
                                        block_ids.size()));
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
  size_t idx = 0;
  for (auto& bid : block_ids) {
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size(),
                      platform::errors::InvalidArgument(
                          "Input block id = %zu, but it should be less than "
                          "program.size() which is %zu",
                          static_cast<size_t>(bid), program.Size()));
    auto* ctx = new ExecutorPrepareContext(program, bid);
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
    if (skip_ref_cnt_vars.empty()) {
      ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
    } else {
      ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
    }
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
    ++idx;
  }
  return result;
}

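// Runs the ops in [start_op_index, end_op_index) of the prepared block inside
// `scope` (or a freshly created local scope), eagerly deleting unused
// variables when a garbage collector is available.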
void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
                                         Scope* scope, int64_t start_op_index,
                                         int64_t end_op_index,
                                         bool create_local_scope,
                                         bool create_vars, bool keep_kids) {
  platform::RecordBlock b(kProgramId);
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

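  // Pick a garbage collector matching the place this executor runs on; eager
  // deletion is only enabled when the threshold is non-negative and GC has not
  // been force-disabled for this context.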
  int64_t max_memory_size = GetEagerDeletionThreshold();
  std::unique_ptr<GarbageCollector> gc;
  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
    if (platform::is_gpu_place(place_)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      } else {
        gc.reset(new DefaultStreamGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
#endif
    } else if (platform::is_cpu_place(place_)) {
      gc.reset(new CPUGarbageCollector(
          BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
    } else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
      gc.reset(new XPUGarbageCollector(
          BOOST_GET_CONST(platform::XPUPlace, place_), max_memory_size));
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
462 463 464
#endif
    } else if (platform::is_npu_place(place_)) {
#ifdef PADDLE_WITH_ASCEND_CL
465 466 467 468 469 470 471 472 473 474 475 476 477
      if (IsFastEagerDeletionModeEnabled()) {
        VLOG(4) << "Use unsafe fast gc for NPU.";
        gc.reset(new NPUUnsafeFastGarbageCollector(
            BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Please set FLAGS_fast_eager_deletion_mode=true to use "
            "GarbageCollector on NPU."));
        // TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
        VLOG(4) << "Use default stream gc for NPU.";
        gc.reset(new NPUDefaultStreamGarbageCollector(
            BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
#endif
    }
  }

  for (int64_t i = start_op_index; i < end_op_index; ++i) {
    auto& op = ctx->ops_[i];
    op->Run(*local_scope, place_);
    if (gc) {
      DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
    }
  }

  auto callback = [scope, local_scope, keep_kids]() {
    if (local_scope != scope) {
      VLOG(4) << "Delete scope: " << local_scope;
      scope->DeleteScope(local_scope);
    } else {
      if (!keep_kids) {
        VLOG(4) << "Drop kids: " << scope;
        // By default, we delete all kid scopes after the executor runs,
        // because some operators (such as while_op) may create local scopes
        // while running. But when while_op creates a local executor to run its
        // sub-block, the sub scopes it created should not be dropped
        // immediately, because while_grad_op will use variables created during
        // the while_op run; so we keep the kids and wait for the outer
        // executor to drop them.

        scope->DropKids();
      }
      VLOG(4) << "Keep kids: " << scope;
    }
  };
  if (gc) {
    VLOG(4) << "Async deleting scope";
    gc->DirectClearCallback(callback);
  } else {
    VLOG(4) << "Sync deleting scope";
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
    callback();
  }
}

void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                  bool create_local_scope, bool create_vars,
                                  bool keep_kids) {
  int64_t start_op_index = 0;
  int64_t end_op_index = ctx->ops_.size();
  RunPartialPreparedContext(ctx, scope, start_op_index, end_op_index,
                            create_local_scope, create_vars, keep_kids);
}

void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx, Scope* scope,
    std::map<std::string, const LoDTensor*>* feed_targets,
    std::map<std::string, FetchType*>* fetch_targets, bool create_local_scope,
    bool create_vars, const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

  PADDLE_ENFORCE_EQ(
      has_feed_operators(global_block, *feed_targets, feed_holder_name), true,
      platform::errors::PreconditionNotMet(
          "Program in ExecutorPrepareContext should have feed_ops."));
  PADDLE_ENFORCE_EQ(
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
      true, platform::errors::PreconditionNotMet(
                "Program in the prepared context should have fetch_ops."));

  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
                      feed_holder_name, idx);
    }
  }

  RunPreparedContext(ctx, scope, create_local_scope, create_vars);

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      *(*fetch_targets)[fetch_target_name] =
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

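// Turns on the use_mkldnn attribute for every op in the program that exposes
// it; logs a warning when Paddle is built without MKL-DNN support.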
void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  VLOG(3) << "use_mkldnn=True";
  for (size_t bid = 0; bid < program.Size(); ++bid) {
    auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
    for (auto* op : block->AllOps()) {
      if (op->HasAttr("use_mkldnn")) {
        op->SetAttr("use_mkldnn", true);
      }
    }
  }
#else
  LOG(WARNING)
      << "'MKLDNN' is not supported. Please re-compile with the WITH_MKLDNN "
         "option.";
#endif
}
}  // namespace framework
}  // namespace paddle