executor.cc 24.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/executor.h"
16

17
#include <memory>
18

Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/feed_fetch_method.h"
D
dongdaxiang 已提交
20 21
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
Z
Zeng Jinle 已提交
22
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
23
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
S
sneaxiy 已提交
24
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
Y
Yi Wang 已提交
25
#include "paddle/fluid/platform/place.h"
X
Xin Pan 已提交
26
#include "paddle/fluid/platform/profiler.h"
27
#include "paddle/fluid/platform/profiler/event_tracing.h"
28 29 30
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
31
#include "paddle/fluid/framework/executor_gc_helper.h"
Y
Yang Yu 已提交
32

D
dzhwinter 已提交
33
DECLARE_bool(benchmark);
34
DECLARE_bool(use_mkldnn);
Q
qijun 已提交
35 36 37

namespace paddle {
namespace framework {
X
Xin Pan 已提交
38 39 40 41 42
namespace {
// Block ids start from 0, so -1 is used to identify the code block that
// wraps the whole program (i.e. the one enclosing block 0) in profiler
// block records (platform::RecordBlock).
// constexpr: this is a constant, not mutable file-level state.
constexpr int kProgramId = -1;
}  // namespace
Q
qijun 已提交
43

Q
Qiao Longfei 已提交
44
// Binds the context to block `block_id` of `prog`. The operator list and the
// unused-variable table are filled in afterwards (see Executor::Prepare).
ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog),
      block_id_(block_id) {}

void ExecutorPrepareContext::PrepareUnusedVars(
    const std::vector<std::string>& keep_vars, bool force_disable_gc) {
Z
Zeng Jinle 已提交
50 51 52
  // If gc is enabled and block size > 1
  if (prog_.Size() > 1) {
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
53 54 55
        prog_, block_id_, ops_);
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
                                                               ops_);
Z
Zeng Jinle 已提交
56
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
57
        prog_, block_id_, ops_);
Z
Zeng Jinle 已提交
58
  }
59 60 61 62 63 64

  force_disable_gc_ = force_disable_gc;
  if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
    return;
  }

S
sneaxiy 已提交
65
  unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
S
sneaxiy 已提交
66
}
Y
Yu Yang 已提交
67

Q
Qiao Longfei 已提交
68
// Only emits a verbose log line; members are released by their own
// destructors.
ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}
Y
Yu Yang 已提交
71

D
dzhwinter 已提交
72
// Creates an executor whose operators run on `place`.
Executor::Executor(const platform::Place& place)
    : place_(place) {}
Q
qijun 已提交
73

74 75
// Destroys the executor. In MKLDNN builds this also clears the mkl-dnn cache
// associated with this executor on `place_`; that is required for the
// mkl-dnn unit tests to work.
Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
  platform::ClearMKLDNNCache(place_, this);
#endif
}

Y
Yancey1989 已提交
82
// Currently a no-op. The distributed "send complete" notification below is
// disabled and kept only for reference.
void Executor::Close() {
  // #ifdef PADDLE_WITH_DISTRIBUTE
  //   // TODO(typhoonzero): complete message will need to use real trainer_id,
  //   // except 0.
  //   auto client =
  //       paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  //   client->SendComplete();
  // #endif
}
W
Wu Yi 已提交
91

L
Liu Yiqun 已提交
92 93
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                               int block_id) {
94
  VLOG(3) << "Creating Variables for block " << block_id;
L
Liu Yiqun 已提交
95
  auto& global_block = pdesc.Block(block_id);
96 97 98 99 100 101 102 103 104 105 106 107
  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }
  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
S
Steffy-zxf 已提交
108 109

        VLOG(3) << "Initialize Variable " << var->Name();
110
        InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
111
        VLOG(3) << "Create Variable " << var->Name()
S
Steffy-zxf 已提交
112 113
                << " global, which pointer is " << ptr << " type is "
                << static_cast<int>(var->GetType());
114 115
      } else {
        auto* ptr = scope->Var(var->Name());
116
        InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
117
        VLOG(3) << "Create Variable " << var->Name()
S
Steffy-zxf 已提交
118 119
                << " locally, which pointer is " << ptr << "Variable Type "
                << static_cast<int>(var->GetType());
120 121 122 123 124
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
125
      InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
126 127
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
128 129 130 131
    }
  }
}

132 133 134
std::shared_ptr<TrainerBase> Executor::InitForDataset(
    const ProgramDesc& main_program, const std::string& trainer_desc_str,
    Scope* scope, Dataset* dataset) {
135
  VLOG(3) << "Start to InitForDataset in executor";
D
dongdaxiang 已提交
136
  TrainerDesc trainer_desc;
H
hutuxian 已提交
137
  bool success = trainer_desc.ParseFromString(trainer_desc_str);
138 139 140 141
  PADDLE_ENFORCE_EQ(success, true,
                    platform::errors::PreconditionNotMet(
                        "Fail to parse TrainerDesc from string:\n%s",
                        trainer_desc_str.c_str()));
D
dongdaxiang 已提交
142 143 144 145 146 147 148 149
  VLOG(3) << "Going to create trainer, trainer class is "
          << trainer_desc.class_name();
  std::shared_ptr<TrainerBase> trainer;
  trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  // initialize trainer
  VLOG(3) << "Going to initialize trainer";
  trainer->Initialize(trainer_desc, dataset);
  VLOG(3) << "Set root scope here";
D
dongdaxiang 已提交
150
  trainer->SetScope(scope);
D
dongdaxiang 已提交
151 152 153 154 155
  // prepare training environment and helper environment
  VLOG(3) << "Try to init train environment";
  trainer->InitTrainerEnv(main_program, place_);
  VLOG(3) << "Try to init other environment";
  trainer->InitOtherEnv(main_program);
156 157 158 159
  return trainer;
}

void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
160 161 162
  PADDLE_ENFORCE_NOT_NULL(
      trainer, platform::errors::InvalidArgument(
                   "Trainer is nullptr, invoke InitForDataset first"));
D
dongdaxiang 已提交
163 164 165
  // training and finalize training
  VLOG(3) << "Trainer starts to run";
  trainer->Run();
D
Dong Daxiang 已提交
166 167 168
}

void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
D
dongdaxiang 已提交
169 170 171
  VLOG(3) << "Trainer going to finalize";
  trainer->Finalize();
}
D
dongdaxiang 已提交
172

Y
Yu Yang 已提交
173
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
S
sneaxiy 已提交
174 175
                   bool create_local_scope, bool create_vars,
                   const std::vector<std::string>& skip_ref_cnt_vars,
176
                   bool force_disable_gc, bool keep_kid_scopes) {
L
liutiexing 已提交
177 178
  platform::RecordEvent record_run("Executor::Run",
                                   platform::TracerEventType::UserDefined, 1);
X
Xin Pan 已提交
179
  platform::RecordBlock b(block_id);
180
  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
J
Jacek Czaja 已提交
181
  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
182 183
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
J
Jacek Czaja 已提交
184
  platform::RegisterModelLayout(ctx->ops_, place_);
185
#endif
186 187
  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                     keep_kid_scopes);
Q
qijun 已提交
188 189
}

190 191 192 193 194 195 196
// Check whether the block already has feed operators and feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
197
    const BlockDesc& block,
L
Liu Yiqun 已提交
198
    const std::map<std::string, const LoDTensor*>& feed_targets,
199 200
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
201
  for (auto* op : block.AllOps()) {
202 203
    if (op->Type() == kFeedOpType) {
      feed_count++;
L
Liu Yiqun 已提交
204
      // The input variable's name of feed_op should be feed_holder_name.
205 206 207 208 209
      PADDLE_ENFORCE_EQ(
          op->Input("X")[0], feed_holder_name,
          platform::errors::PreconditionNotMet(
              "Input to feed op should be '%s', but received '%s'.",
              feed_holder_name, op->Input("X")[0]));
210
      std::string feed_target_name = op->Output("Out")[0];
211 212 213 214 215
      PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name), feed_targets.end(),
                        platform::errors::PreconditionNotMet(
                            "Feed operator output name '%s' cannot be found in "
                            "'feed_targets'",
                            feed_target_name));
216 217 218 219 220 221
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count, feed_targets.size(),
222 223 224 225
        platform::errors::PreconditionNotMet(
            "The number of feed operators should match 'feed_targets', but "
            "received feed_count: %zu, required feed_targets.size(): %zu.",
            feed_count, feed_targets.size()));
226

227
    if (!feed_holder_name.empty()) {
L
Liu Yiqun 已提交
228
      // When feed operator are present, so should be feed_holder.
229
      auto var = block.FindVar(feed_holder_name);
230 231 232 233 234 235 236 237 238 239
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable", feed_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FEED_MINIBATCH,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FEED_MINIBATCH' type, but received "
              "'%s'.",
              feed_holder_name, DataTypeToString(var->GetType())));
240
    }
241 242 243 244 245 246 247 248 249 250 251 252
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
L
Liu Yiqun 已提交
253
    const BlockDesc& block,
254
    const std::map<std::string, FetchType*>& fetch_targets,
255 256
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
257
  for (auto* op : block.AllOps()) {
258 259
    if (op->Type() == kFetchOpType) {
      fetch_count++;
L
Liu Yiqun 已提交
260
      // The output variable's name of fetch_op should be fetch_holder_name.
261 262 263 264 265
      PADDLE_ENFORCE_EQ(
          op->Output("Out")[0], fetch_holder_name,
          platform::errors::PreconditionNotMet(
              "Output of fetch op should be '%s', but received '%s'.",
              fetch_holder_name, op->Output("Out")[0]));
266
      std::string fetch_target_name = op->Input("X")[0];
267 268 269 270 271 272
      PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
                        fetch_targets.end(),
                        platform::errors::NotFound(
                            "Fetch operator input name '%s' cannot be found in "
                            "'fetch_targets'.",
                            fetch_target_name));
273 274 275 276 277 278
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count, fetch_targets.size(),
279 280 281 282
        platform::errors::PreconditionNotMet(
            "The number of fetch operators should match 'fetch_targets', but "
            "received fetch_count: %zu, required fetch_targets.size(): %zu.",
            fetch_count, fetch_targets.size()));
283

284
    if (!fetch_holder_name.empty()) {
L
Liu Yiqun 已提交
285
      // When fetch operator are present, so should be fetch_holder.
286
      auto var = block.FindVar(fetch_holder_name);
287 288 289 290 291 292 293 294 295
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable.", fetch_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FETCH_LIST,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
              fetch_holder_name, DataTypeToString(var->GetType())));
296
    }
297 298 299 300 301 302
  }

  return fetch_count > 0;
}

void Executor::Run(const ProgramDesc& program, Scope* scope,
303
                   std::map<std::string, const LoDTensor*>* feed_targets,
304
                   std::map<std::string, FetchType*>* fetch_targets,
W
Wu Yi 已提交
305 306
                   bool create_local_scope, bool create_vars,
                   const std::string& feed_holder_name,
307
                   const std::string& fetch_holder_name) {
L
liutiexing 已提交
308 309
  platform::RecordEvent record_run("Executor::Run",
                                   platform::TracerEventType::UserDefined, 1);
X
Xin Pan 已提交
310
  platform::RecordBlock b(kProgramId);
311
  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
312 313 314
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
315
  bool has_feed_ops =
316
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
317
  bool has_fetch_ops =
318
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);
319 320

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
S
sneaxiy 已提交
321
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
322
  if (!has_feed_ops || !has_fetch_ops) {
S
sneaxiy 已提交
323 324
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
325
  }
326 327
  auto* global_block = copy_program->MutableBlock(0);

328
  if (!has_feed_ops) {
329 330
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
331
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
332 333 334
    feed_holder->SetPersistable(true);

    int i = 0;
335
    for (auto& feed_target : (*feed_targets)) {
336
      std::string var_name = feed_target.first;
M
minqiyang 已提交
337
      VLOG(3) << "feed target's name: " << var_name;
338 339 340 341 342 343 344 345 346 347 348 349 350

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

351
  if (!has_fetch_ops) {
352 353
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
354
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
355 356 357
    fetch_holder->SetPersistable(true);

    int i = 0;
358
    for (auto& fetch_target : (*fetch_targets)) {
359
      std::string var_name = fetch_target.first;
M
minqiyang 已提交
360
      VLOG(3) << "fetch target's name: " << var_name;
361 362 363 364 365 366 367 368 369 370 371 372 373

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

374
  auto ctx = Prepare(*copy_program, 0);
W
Wu Yi 已提交
375 376 377
  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
                     create_local_scope, create_vars, feed_holder_name,
                     fetch_holder_name);
378 379
}

Q
Qiao Longfei 已提交
380
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
S
fix bug  
sneaxiy 已提交
381
    const ProgramDesc& program, int block_id,
S
sneaxiy 已提交
382
    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
S
sneaxiy 已提交
383 384
  std::unique_ptr<ExecutorPrepareContext> ctx(
      new ExecutorPrepareContext(program, block_id));
385 386 387 388 389
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size(),
                    platform::errors::InvalidArgument(
                        "Input block id = %d, but it should be less than "
                        "program.size() which is %d",
                        static_cast<size_t>(block_id), program.Size()));
Y
Yu Yang 已提交
390 391 392 393
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
S
sneaxiy 已提交
394
  ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
Q
Qiyang Min 已提交
395
  return ctx;
Y
Yu Yang 已提交
396 397
}

T
refine  
typhoonzero 已提交
398
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
S
fix bug  
sneaxiy 已提交
399
    const ProgramDesc& program, const std::vector<int>& block_ids,
S
sneaxiy 已提交
400 401
    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
    bool force_disable_gc) {
402
  PADDLE_ENFORCE_EQ(
S
fix bug  
sneaxiy 已提交
403
      skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
404 405 406 407
      true,
      platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
                                        "empty or equals to block number %d",
                                        block_ids.size()));
T
typhoonzero 已提交
408
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
S
fix bug  
sneaxiy 已提交
409
  size_t idx = 0;
T
typhoonzero 已提交
410
  for (auto& bid : block_ids) {
411 412 413 414 415
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size(),
                      platform::errors::InvalidArgument(
                          "Input block id = %zu, but it should be less than "
                          "program.size() which is %zu",
                          static_cast<size_t>(bid), program.Size()));
S
sneaxiy 已提交
416
    auto* ctx = new ExecutorPrepareContext(program, bid);
T
typhoonzero 已提交
417 418 419 420
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
S
sneaxiy 已提交
421 422 423 424 425
    if (skip_ref_cnt_vars.empty()) {
      ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
    } else {
      ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
    }
T
typhoonzero 已提交
426
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
S
fix bug  
sneaxiy 已提交
427
    ++idx;
T
typhoonzero 已提交
428 429 430 431
  }
  return result;
}

432 433 434 435 436
// Runs operators [start_op_index, end_op_index) of a prepared context.
//
// If `create_vars` is set, the block's variables are created first. They go
// into a fresh child scope when `create_local_scope` is also set, otherwise
// into `scope`.
//
// When eager deletion is enabled (threshold >= 0 and GC not force-disabled),
// a place-specific garbage collector frees unused tensors after each op.
//
// Afterwards a created child scope is deleted. Otherwise, unless
// `keep_kids` is set, `scope`'s kid scopes are dropped. With a GC this is
// scheduled through the GC; without one it happens after waiting on the
// place's device context.
//
// Raises InvalidArgument if `scope` or `ctx` is null, or if a non-empty range
// falls outside [0, ctx->ops_.size()].
void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
                                         Scope* scope, int64_t start_op_index,
                                         int64_t end_op_index,
                                         bool create_local_scope,
                                         bool create_vars, bool keep_kids) {
  platform::RecordEvent record_run("Executor::RunPartialPreparedContext",
                                   platform::TracerEventType::UserDefined, 1);
  platform::RecordBlock b(kProgramId);
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
  PADDLE_ENFORCE_NOT_NULL(ctx, platform::errors::InvalidArgument(
                                   "ExecutorPrepareContext shouldn't be null"));

  // Reject ranges that would index outside ctx->ops_. An empty range
  // (start >= end) runs nothing and is still accepted. The check is done
  // before any child scope is created, so a bad range cannot leak one.
  if (start_op_index < end_op_index) {
    const int64_t num_ops = static_cast<int64_t>(ctx->ops_.size());
    PADDLE_ENFORCE_GE(
        start_op_index, 0,
        platform::errors::InvalidArgument(
            "start_op_index should be non-negative, but received %d.",
            start_op_index));
    PADDLE_ENFORCE_LE(
        end_op_index, num_ops,
        platform::errors::InvalidArgument(
            "end_op_index should not exceed the number of operators %d, but "
            "received %d.",
            num_ops, end_op_index));
  }

  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

  int64_t max_memory_size = GetEagerDeletionThreshold();
  std::unique_ptr<GarbageCollector> gc;
  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
    if (platform::is_gpu_place(place_)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(place_, max_memory_size));
      } else {
        gc.reset(new DefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
#endif
    } else if (platform::is_cpu_place(place_)) {
      gc.reset(new CPUGarbageCollector(place_, max_memory_size));
    } else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
      gc.reset(new XPUGarbageCollector(place_, max_memory_size));
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
#endif
    } else if (platform::is_ipu_place(place_)) {
#ifdef PADDLE_WITH_IPU
      gc.reset(new IPUGarbageCollector(place_, max_memory_size));
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No IPU gc found in CPU/IPU paddle"));
#endif
    } else if (platform::is_npu_place(place_)) {
#ifdef PADDLE_WITH_ASCEND_CL
      if (IsFastEagerDeletionModeEnabled()) {
        VLOG(4) << "Use unsafe fast gc for NPU.";
        gc.reset(new NPUUnsafeFastGarbageCollector(place_, max_memory_size));
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Please set FLAGS_fast_eager_deletion_mode=true to use "
            "GarbageCollector on NPU."));
        // TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
        // NOTE: unreachable until the throw above is removed.
        VLOG(4) << "Use default stream gc for NPU.";
        gc.reset(new NPUDefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
#endif
    } else if (platform::is_mlu_place(place_)) {
#ifdef PADDLE_WITH_MLU
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new MLUUnsafeFastGarbageCollector(place_, max_memory_size));
      } else {
        gc.reset(new MLUDefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No MLU gc found in CPU/MLU paddle"));
#endif
    } else if (platform::is_custom_place(place_)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      if (IsFastEagerDeletionModeEnabled()) {
        VLOG(4) << "Use unsafe fast gc for " << place_ << ".";
        gc.reset(new CustomDeviceUnsafeFastGarbageCollector(place_,
                                                            max_memory_size));
      } else {
        VLOG(4) << "Use default stream gc for " << place_ << ".";
        gc.reset(
            new CustomDefaultStreamGarbageCollector(place_, max_memory_size));
      }
#else
      PADDLE_THROW(platform::errors::Unimplemented("No CustomDevice gc found"));
#endif
    }
  }

  for (int64_t i = start_op_index; i < end_op_index; ++i) {
    auto& op = ctx->ops_[i];
    op->Run(*local_scope, place_);
    if (gc) {
      platform::RecordEvent record("CheckGC",
                                   platform::TracerEventType::UserDefined, 10);
      DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
    }
  }

  // Captures by value: with a GC this callback may run asynchronously.
  auto callback = [scope, local_scope, keep_kids]() {
    if (local_scope != scope) {
      VLOG(4) << "Delete scope: " << local_scope;
      scope->DeleteScope(local_scope);
    } else if (!keep_kids) {
      // By default all kid scopes are deleted after the executor runs,
      // because some operators (e.g. while_op) create local scopes while
      // running. When while_op runs its sub block through a nested
      // executor, those sub scopes must survive: while_grad_op reads
      // variables created during while_op, so the outer executor drops
      // them later instead.
      VLOG(4) << "Drop kids: " << scope;
      scope->DropKids();
    } else {
      // Logged only when the kids are really kept (previously this line
      // was also printed right after DropKids()).
      VLOG(4) << "Keep kids: " << scope;
    }
  };

  if (gc) {
    VLOG(4) << "Async deleting scope";
    gc->DirectClearCallback(callback);
  } else {
    VLOG(4) << "Sync deleting scope";
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
    callback();
  }
}

567 568 569 570 571 572 573 574 575
// Runs every operator of a prepared context, i.e. the full range
// [0, ctx->ops_.size()), via RunPartialPreparedContext.
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                  bool create_local_scope, bool create_vars,
                                  bool keep_kids) {
  const int64_t op_count = ctx->ops_.size();
  RunPartialPreparedContext(ctx, scope, /*start_op_index=*/0,
                            /*end_op_index=*/op_count, create_local_scope,
                            create_vars, keep_kids);
}

576 577
void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx, Scope* scope,
578
    std::map<std::string, const LoDTensor*>* feed_targets,
579
    std::map<std::string, FetchType*>* fetch_targets, bool create_local_scope,
W
Wu Yi 已提交
580 581
    bool create_vars, const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
582 583
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

584 585 586 587 588
  PADDLE_ENFORCE_EQ(
      has_feed_operators(global_block, *feed_targets, feed_holder_name), true,
      platform::errors::PreconditionNotMet(
          "Program in ExecutorPrepareContext should has feed_ops."));
  PADDLE_ENFORCE_EQ(
589
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
590 591 592
      true,
      platform::errors::PreconditionNotMet(
          "Program in the prepared context should has fetch_ops."));
593

594 595 596 597
  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
598
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
599 600
      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
                      feed_holder_name, idx);
601 602 603
    }
  }

W
Wu Yi 已提交
604
  RunPreparedContext(ctx, scope, create_local_scope, create_vars);
605 606 607 608 609

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
610
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
611
      *(*fetch_targets)[fetch_target_name] =
612 613 614 615 616
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

617 618
// Sets the "use_mkldnn" attribute to true on every op, in every block of
// `program`, that declares that attribute. `program` is modified in place
// despite being passed by const reference. Builds without MKLDNN only log a
// warning.
void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  VLOG(3) << "use_mkldnn=True";
  auto& mutable_program = const_cast<ProgramDesc&>(program);
  const size_t block_count = program.Size();
  for (size_t block_idx = 0; block_idx < block_count; ++block_idx) {
    for (auto* op : mutable_program.MutableBlock(block_idx)->AllOps()) {
      if (!op->HasAttr("use_mkldnn")) {
        continue;
      }
      op->SetAttr("use_mkldnn", true);
    }
  }
#else
  LOG(WARNING)
      << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
Q
qijun 已提交
633 634
}  // namespace framework
}  // namespace paddle