/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/executor.h"
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

DECLARE_bool(benchmark);
DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace {
// Block ids start from 0. kProgramId (-1) is used to represent the pseudo
// code block that wraps the first block 0.
int kProgramId = -1;
}  // namespace

ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog), block_id_(block_id) {}

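// Work out which variables of the block can be deleted eagerly after their
// last use. Variables in `keep_vars` are never collected, and the analysis is
// skipped when `force_disable_gc` is set or the eager deletion threshold is
// negative.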
void ExecutorPrepareContext::PrepareUnusedVars(
    const std::vector<std::string>& keep_vars, bool force_disable_gc) {
  // If the program has more than one block, prepare safe eager deletion for
  // the control-flow ops (conditional_block, while, recurrent) and their
  // gradient ops.
  if (prog_.Size() > 1) {
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
        prog_, block_id_, ops_);
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
                                                               ops_);
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
        prog_, block_id_, ops_);
  }

  force_disable_gc_ = force_disable_gc;
  if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
    return;
  }

  unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
}

ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}

Executor::Executor(const platform::Place& place) : place_(place) {}

Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
  // Clear the MKL-DNN cache; this is needed for MKL-DNN unit tests to work.
  ClearMKLDNNCache(place_);
#endif
}

void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
  // TODO(typhoonzero): the complete message will need to use the real
  // trainer_id instead of 0.
  auto client =
      paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  client->SendComplete();
#endif
}

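// Create all variables declared in block `block_id` of `pdesc`. When `scope`
// has an ancestor scope, persistable variables are created in the outermost
// ancestor so they survive this run, while the rest are created locally.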
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                               int block_id) {
  VLOG(3) << "Creating Variables for block " << block_id;
  auto& global_block = pdesc.Block(block_id);
  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }
  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " global, which pointer is " << ptr;
      } else {
        auto* ptr = scope->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " locally, which pointer is " << ptr;
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
    }
  }
}

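// Parse `trainer_desc_str` into a TrainerDesc, create the corresponding
// trainer through TrainerFactory, bind it to `scope`, and initialize its
// training and helper environments from `main_program` and `dataset`.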
std::shared_ptr<TrainerBase> Executor::InitForDataset(
    const ProgramDesc& main_program, const std::string& trainer_desc_str,
    Scope* scope, Dataset* dataset) {
  VLOG(3) << "Start to RunFromDataset in executor";
  TrainerDesc trainer_desc;
  bool success = trainer_desc.ParseFromString(trainer_desc_str);
  PADDLE_ENFORCE_EQ(success, true,
                    platform::errors::PreconditionNotMet(
                        "Fail to parse TrainerDesc from string:\n%s",
                        trainer_desc_str.c_str()));
  VLOG(3) << "Going to create trainer, trainer class is "
          << trainer_desc.class_name();
  std::shared_ptr<TrainerBase> trainer;
  trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  // initialize trainer
  VLOG(3) << "Going to initialize trainer";
  trainer->Initialize(trainer_desc, dataset);
  VLOG(3) << "Set root scope here";
  trainer->SetScope(scope);
  // prepare training environment and helper environment
  VLOG(3) << "Try to init train environment";
  trainer->InitTrainerEnv(main_program, place_);
  VLOG(3) << "Try to init other environment";
  trainer->InitOtherEnv(main_program);
  return trainer;
}

void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
  PADDLE_ENFORCE_NOT_NULL(
      trainer, platform::errors::InvalidArgument(
                   "Trainer is nullptr, invoke InitForDataset first"));
  // training and finalize training
  VLOG(3) << "Trainer starts to run";
  trainer->Run();
}

void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
  VLOG(3) << "Trainer going to finalize";
  trainer->Finalize();
}

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars,
                   const std::vector<std::string>& skip_ref_cnt_vars,
                   bool force_disable_gc, bool keep_kid_scopes) {
  platform::RecordBlock b(block_id);
  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                     keep_kid_scopes);
}
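
// A minimal usage sketch of the Run() overload above (for illustration only;
// `program` is assumed to be a ProgramDesc built elsewhere, and the default
// arguments declared in executor.h are assumed for the remaining parameters):
//
//   paddle::framework::Scope scope;
//   paddle::framework::Executor exe(paddle::platform::CPUPlace());
//   exe.Run(program, &scope, /*block_id=*/0);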

// Check whether the block already has feed operators and feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
    const BlockDesc& block,
    const std::map<std::string, const LoDTensor*>& feed_targets,
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      feed_count++;
      // The input variable name of the feed op should be feed_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Input("X")[0], feed_holder_name,
          platform::errors::PreconditionNotMet(
              "Input to feed op should be '%s', but received '%s'.",
              feed_holder_name, op->Input("X")[0]));
      std::string feed_target_name = op->Output("Out")[0];
      PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name), feed_targets.end(),
                        platform::errors::PreconditionNotMet(
                            "Feed operator output name '%s' cannot be found in "
                            "'feed_targets'",
                            feed_target_name));
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count, feed_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of feed operators should match 'feed_targets', but "
            "received feed_count: %zu, required feed_targets.size(): %zu.",
            feed_count, feed_targets.size()));

    if (!feed_holder_name.empty()) {
      // When feed operators are present, the feed_holder variable should be too.
      auto var = block.FindVar(feed_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable", feed_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FEED_MINIBATCH,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FEED_MINIBATCH' type, but received "
              "'%s'.",
              feed_holder_name, DataTypeToString(var->GetType())));
    }
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
    const BlockDesc& block,
    const std::map<std::string, FetchType*>& fetch_targets,
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      fetch_count++;
      // The output variable name of the fetch op should be fetch_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Output("Out")[0], fetch_holder_name,
          platform::errors::PreconditionNotMet(
              "Output of fetch op should be '%s', but received '%s'.",
              fetch_holder_name, op->Output("Out")[0]));
      std::string fetch_target_name = op->Input("X")[0];
      PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
                        fetch_targets.end(),
                        platform::errors::NotFound(
                            "Fetch operator input name '%s' cannot be found in "
                            "'fetch_targets'.",
                            fetch_target_name));
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count, fetch_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of fetch operators should match 'fetch_targets', but "
            "received fetch_count: %zu, required fetch_targets.size(): %zu.",
            fetch_count, fetch_targets.size()));

    if (!fetch_holder_name.empty()) {
      // When fetch operators are present, the fetch_holder variable should be too.
      auto var = block.FindVar(fetch_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable.", fetch_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FETCH_LIST,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
              fetch_holder_name, DataTypeToString(var->GetType())));
    }
  }

  return fetch_count > 0;
}

void Executor::Run(const ProgramDesc& program, Scope* scope,
                   std::map<std::string, const LoDTensor*>* feed_targets,
                   std::map<std::string, FetchType*>* fetch_targets,
                   bool create_local_scope, bool create_vars,
                   const std::string& feed_holder_name,
                   const std::string& fetch_holder_name) {
  platform::RecordBlock b(kProgramId);
  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
  bool has_feed_ops =
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
  bool has_fetch_ops =
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
  if (!has_feed_ops || !has_fetch_ops) {
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
  }
  auto* global_block = copy_program->MutableBlock(0);

  if (!has_feed_ops) {
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
    feed_holder->SetPersistable(true);

    int i = 0;
    for (auto& feed_target : (*feed_targets)) {
      std::string var_name = feed_target.first;
      VLOG(3) << "feed target's name: " << var_name;

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  if (!has_fetch_ops) {
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
    fetch_holder->SetPersistable(true);

    int i = 0;
    for (auto& fetch_target : (*fetch_targets)) {
      std::string var_name = fetch_target.first;
      VLOG(3) << "fetch target's name: " << var_name;

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  auto ctx = Prepare(*copy_program, 0);
  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
                     create_local_scope, create_vars, feed_holder_name,
                     fetch_holder_name);
}

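// Build an ExecutorPrepareContext for `block_id`: instantiate every operator
// in the block and pre-compute which variables can be garbage-collected.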
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
    const ProgramDesc& program, int block_id,
    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
  std::unique_ptr<ExecutorPrepareContext> ctx(
      new ExecutorPrepareContext(program, block_id));
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size(),
                    platform::errors::InvalidArgument(
                        "Input block id = %d, but it should be less than "
                        "program.size() which is %d",
                        static_cast<size_t>(block_id), program.Size()));
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
  ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
  return ctx;
}

std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
    const ProgramDesc& program, const std::vector<int>& block_ids,
    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
    bool force_disable_gc) {
  PADDLE_ENFORCE_EQ(
      skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
      true,
      platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
                                        "empty or equals to block number %d",
                                        block_ids.size()));
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
  size_t idx = 0;
  for (auto& bid : block_ids) {
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size(),
                      platform::errors::InvalidArgument(
                          "Input block id = %zu, but it should be less than "
                          "program.size() which is %zu",
                          static_cast<size_t>(bid), program.Size()));
    auto* ctx = new ExecutorPrepareContext(program, bid);
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
    if (skip_ref_cnt_vars.empty()) {
      ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
    } else {
      ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
    }
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
    ++idx;
  }
  return result;
}

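// Run the prepared operators in the index range [start_op_index, end_op_index)
// of the block, optionally inside a freshly created local scope, and delete
// unused variables eagerly when garbage collection is enabled.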
void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
                                         Scope* scope, int64_t start_op_index,
                                         int64_t end_op_index,
                                         bool create_local_scope,
                                         bool create_vars, bool keep_kids) {
  platform::RecordBlock b(kProgramId);
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

  int64_t max_memory_size = GetEagerDeletionThreshold();
  std::unique_ptr<GarbageCollector> gc;
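  // Pick a garbage collector that matches place_: an unsafe fast collector or
  // the default-stream collector on CUDA places, and dedicated collectors for
  // CPU and XPU places.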
  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
    if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      } else {
        gc.reset(new DefaultStreamGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
#endif
    } else if (platform::is_cpu_place(place_)) {
      gc.reset(new CPUGarbageCollector(
          BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
    } else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
      gc.reset(new XPUGarbageCollector(
          BOOST_GET_CONST(platform::XPUPlace, place_), max_memory_size));
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
#endif
    }
  }

  for (int64_t i = start_op_index; i < end_op_index; ++i) {
    auto& op = ctx->ops_[i];
    op->Run(*local_scope, place_);
    if (gc) {
      DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
    }
  }

  auto callback = [scope, local_scope, keep_kids]() {
    if (local_scope != scope) {
      VLOG(4) << "Delete scope: " << local_scope;
      scope->DeleteScope(local_scope);
    } else {
      if (!keep_kids) {
        VLOG(4) << "Drop kids: " << scope;
        // By default, we delete all kid scopes after the executor runs,
        // because some operators (such as while_op) may create local scopes
        // while running. But when while_op itself creates a local executor to
        // run its sub-block, the sub scopes it created should not be dropped
        // immediately, because while_grad_op will use variables created during
        // the while_op run; in that case we keep the kids and wait for the
        // outer executor to drop them.

        scope->DropKids();
      }
      VLOG(4) << "Keep kids: " << scope;
    }
  };

  if (gc) {
    VLOG(4) << "Async deleting scope";
    gc->DirectClearCallback(callback);
  } else {
    VLOG(4) << "Sync deleting scope";
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
    callback();
  }
}

void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                  bool create_local_scope, bool create_vars,
                                  bool keep_kids) {
  int64_t start_op_index = 0;
  int64_t end_op_index = ctx->ops_.size();
  RunPartialPreparedContext(ctx, scope, start_op_index, end_op_index,
                            create_local_scope, create_vars, keep_kids);
}

void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx, Scope* scope,
    std::map<std::string, const LoDTensor*>* feed_targets,
    std::map<std::string, FetchType*>* fetch_targets, bool create_local_scope,
    bool create_vars, const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

  PADDLE_ENFORCE_EQ(
      has_feed_operators(global_block, *feed_targets, feed_holder_name), true,
      platform::errors::PreconditionNotMet(
          "Program in ExecutorPrepareContext should has feed_ops."));
  PADDLE_ENFORCE_EQ(
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
      true, platform::errors::PreconditionNotMet(
                "Program in the prepared context should has fetch_ops."));

  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
                      feed_holder_name, idx);
    }
  }

  RunPreparedContext(ctx, scope, create_local_scope, create_vars);

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      *(*fetch_targets)[fetch_target_name] =
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

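// Turn on the `use_mkldnn` attribute for every operator that supports it.
// Only effective when Paddle is compiled with PADDLE_WITH_MKLDNN.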
void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  VLOG(3) << "use_mkldnn=True";
  for (size_t bid = 0; bid < program.Size(); ++bid) {
    auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
    for (auto* op : block->AllOps()) {
      if (op->HasAttr("use_mkldnn")) {
        op->SetAttr("use_mkldnn", true);
      }
    }
  }
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#else
  LOG(WARNING)
      << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
}  // namespace framework
}  // namespace paddle