/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/executor.h"
#include <memory>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/executor_gc_helper.h"

DECLARE_bool(benchmark);
DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace {
// Block ids start from 0. This id (-1) represents the pseudo code block that
// wraps the first block, block 0.
int kProgramId = -1;
}  // namespace

ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog), block_id_(block_id) {}

void ExecutorPrepareContext::PrepareUnusedVars(
    const std::vector<std::string>& keep_vars, bool force_disable_gc) {
  // If the program has more than one block, prepare safe eager deletion for
  // the control-flow ops (conditional_block, while, recurrent) and their
  // gradient ops.
  if (prog_.Size() > 1) {
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
        prog_, block_id_, ops_);
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
                                                               ops_);
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
        prog_, block_id_, ops_);
  }

  force_disable_gc_ = force_disable_gc;
  if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
    return;
  }

  unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
}

ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}

Executor::Executor(const platform::Place& place) : place_(place) {}

Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
  // Clear the MKL-DNN cache; this is needed for the MKL-DNN unit tests to work.
  ClearMKLDNNCache(place_);
#endif
}

void Executor::Close() {
  // #ifdef PADDLE_WITH_DISTRIBUTE
  //   // TODO(typhoonzero): complete message will need to use real trainer_id,
  //   // except 0.
  //   auto client =
  //       paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  //   client->SendComplete();
  // #endif
}

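// Create all variables of the given block in `scope`. Persistable variables
// are created in the outermost ancestor scope so that they survive across
// runs, while non-persistable variables are created in `scope` itself.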
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                               int block_id) {
  VLOG(3) << "Creating Variables for block " << block_id;
  auto& global_block = pdesc.Block(block_id);
  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }
  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " global, which pointer is " << ptr;
      } else {
        auto* ptr = scope->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " locally, which pointer is " << ptr;
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
    }
  }
}

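// Parse the serialized TrainerDesc, create the corresponding trainer through
// TrainerFactory, bind it to `scope` and `dataset`, and initialize its
// training environment. The returned trainer is later consumed by
// RunFromDataset() and ReleaseTrainer().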
std::shared_ptr<TrainerBase> Executor::InitForDataset(
    const ProgramDesc& main_program, const std::string& trainer_desc_str,
    Scope* scope, Dataset* dataset) {
  VLOG(3) << "Start to RunFromDataset in executor";
  TrainerDesc trainer_desc;
  bool success = trainer_desc.ParseFromString(trainer_desc_str);
  PADDLE_ENFORCE_EQ(success, true,
                    platform::errors::PreconditionNotMet(
                        "Fail to parse TrainerDesc from string:\n%s",
                        trainer_desc_str.c_str()));
  VLOG(3) << "Going to create trainer, trainer class is "
          << trainer_desc.class_name();
  std::shared_ptr<TrainerBase> trainer;
  trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  // initialize trainer
  VLOG(3) << "Going to initialize trainer";
  trainer->Initialize(trainer_desc, dataset);
  VLOG(3) << "Set root scope here";
  trainer->SetScope(scope);
  // prepare training environment and helper environment
  VLOG(3) << "Try to init train environment";
  trainer->InitTrainerEnv(main_program, place_);
  VLOG(3) << "Try to init other environment";
  trainer->InitOtherEnv(main_program);
  return trainer;
}

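// Run the trainer created by InitForDataset(); the trainer must not be null.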
void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
  PADDLE_ENFORCE_NOT_NULL(
      trainer, platform::errors::InvalidArgument(
                   "Trainer is nullptr, invoke InitForDataset first"));
  // training and finalize training
  VLOG(3) << "Trainer starts to run";
  trainer->Run();
}

void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
  VLOG(3) << "Trainer going to finalize";
  trainer->Finalize();
}

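// Run one block of the program: prepare an ExecutorPrepareContext for
// `block_id` and execute it in `scope`.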
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars,
                   const std::vector<std::string>& skip_ref_cnt_vars,
                   bool force_disable_gc, bool keep_kid_scopes) {
  platform::RecordBlock b(block_id);
  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                     keep_kid_scopes);
}

// Check whether the block already has feed operators and feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
    const BlockDesc& block,
    const std::map<std::string, const LoDTensor*>& feed_targets,
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      feed_count++;
      // The name of the feed_op's input variable should be feed_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Input("X")[0], feed_holder_name,
          platform::errors::PreconditionNotMet(
              "Input to feed op should be '%s', but received '%s'.",
              feed_holder_name, op->Input("X")[0]));
      std::string feed_target_name = op->Output("Out")[0];
      PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name), feed_targets.end(),
                        platform::errors::PreconditionNotMet(
                            "Feed operator output name '%s' cannot be found in "
                            "'feed_targets'",
                            feed_target_name));
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count, feed_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of feed operators should match 'feed_targets', but "
            "received feed_count: %zu, required feed_targets.size(): %zu.",
            feed_count, feed_targets.size()));

    if (!feed_holder_name.empty()) {
      // When feed operators are present, the feed_holder variable should be too.
      auto var = block.FindVar(feed_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable", feed_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FEED_MINIBATCH,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FEED_MINIBATCH' type, but received "
              "'%s'.",
              feed_holder_name, DataTypeToString(var->GetType())));
    }
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
    const BlockDesc& block,
    const std::map<std::string, FetchType*>& fetch_targets,
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      fetch_count++;
      // The name of the fetch_op's output variable should be fetch_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Output("Out")[0], fetch_holder_name,
          platform::errors::PreconditionNotMet(
              "Output of fetch op should be '%s', but received '%s'.",
              fetch_holder_name, op->Output("Out")[0]));
      std::string fetch_target_name = op->Input("X")[0];
      PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
                        fetch_targets.end(),
                        platform::errors::NotFound(
                            "Fetch operator input name '%s' cannot be found in "
                            "'fetch_targets'.",
                            fetch_target_name));
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count, fetch_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of fetch operators should match 'fetch_targets', but "
            "received fetch_count: %zu, required fetch_targets.size(): %zu.",
            fetch_count, fetch_targets.size()));

    if (!fetch_holder_name.empty()) {
      // When fetch operators are present, the fetch_holder variable should be too.
      auto var = block.FindVar(fetch_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable.", fetch_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FETCH_LIST,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
              fetch_holder_name, DataTypeToString(var->GetType())));
    }
  }

  return fetch_count > 0;
}

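// Run a program with explicit feed/fetch targets. If block 0 does not already
// contain matching feed/fetch operators, a copy of the program is made and the
// missing feed/fetch ops and holder variables are added to it before running.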
void Executor::Run(const ProgramDesc& program, Scope* scope,
                   std::map<std::string, const LoDTensor*>* feed_targets,
                   std::map<std::string, FetchType*>* fetch_targets,
                   bool create_local_scope, bool create_vars,
                   const std::string& feed_holder_name,
                   const std::string& fetch_holder_name) {
  platform::RecordBlock b(kProgramId);
  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
  bool has_feed_ops =
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
  bool has_fetch_ops =
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
  if (!has_feed_ops || !has_fetch_ops) {
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
  }
  auto* global_block = copy_program->MutableBlock(0);

  if (!has_feed_ops) {
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
    feed_holder->SetPersistable(true);

    int i = 0;
    for (auto& feed_target : (*feed_targets)) {
      std::string var_name = feed_target.first;
      VLOG(3) << "feed target's name: " << var_name;

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  if (!has_fetch_ops) {
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
    fetch_holder->SetPersistable(true);

    int i = 0;
    for (auto& fetch_target : (*fetch_targets)) {
      std::string var_name = fetch_target.first;
      VLOG(3) << "fetch target's name: " << var_name;

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  auto ctx = Prepare(*copy_program, 0);
  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
                     create_local_scope, create_vars, feed_holder_name,
                     fetch_holder_name);
}

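// Create the operators of one block ahead of time and record which variables
// become unused, so that the block can be executed repeatedly via
// RunPreparedContext() without re-creating the operators.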
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
    const ProgramDesc& program, int block_id,
    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
  std::unique_ptr<ExecutorPrepareContext> ctx(
      new ExecutorPrepareContext(program, block_id));
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size(),
                    platform::errors::InvalidArgument(
                        "Input block id = %d, but it should be less than "
                        "program.size() which is %d",
                        static_cast<size_t>(block_id), program.Size()));
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
  ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
  return ctx;
}

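// Same as Prepare() above, but for several blocks at once. `skip_ref_cnt_vars`
// must be either empty or contain one entry per block id.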
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
    const ProgramDesc& program, const std::vector<int>& block_ids,
    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
    bool force_disable_gc) {
  PADDLE_ENFORCE_EQ(
      skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
      true,
      platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
                                        "empty or equals to block number %d",
                                        block_ids.size()));
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
  size_t idx = 0;
  for (auto& bid : block_ids) {
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size(),
                      platform::errors::InvalidArgument(
                          "Input block id = %zu, but it should be less than "
                          "program.size() which is %zu",
                          static_cast<size_t>(bid), program.Size()));
    auto* ctx = new ExecutorPrepareContext(program, bid);
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
    if (skip_ref_cnt_vars.empty()) {
      ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
    } else {
      ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
    }
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
    ++idx;
  }
  return result;
}

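// Execute the operators of a prepared context whose indices lie in
// [start_op_index, end_op_index). Optionally creates a local scope and the
// block's variables, and deletes unused variables through a garbage collector
// when eager deletion is enabled.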
void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
                                         Scope* scope, int64_t start_op_index,
                                         int64_t end_op_index,
                                         bool create_local_scope,
                                         bool create_vars, bool keep_kids) {
  platform::RecordBlock b(kProgramId);
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

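  // Create a garbage collector that matches place_ (GPU, CPU or XPU) when
  // eager deletion is enabled and not forcibly disabled for this context.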
  int64_t max_memory_size = GetEagerDeletionThreshold();
  std::unique_ptr<GarbageCollector> gc;
  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
    if (platform::is_gpu_place(place_)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      } else {
        gc.reset(new DefaultStreamGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
#endif
    } else if (platform::is_cpu_place(place_)) {
      gc.reset(new CPUGarbageCollector(
          BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
    } else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
      gc.reset(new XPUGarbageCollector(
          BOOST_GET_CONST(platform::XPUPlace, place_), max_memory_size));
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
#endif
    }
  }

  for (int64_t i = start_op_index; i < end_op_index; ++i) {
    auto& op = ctx->ops_[i];
    op->Run(*local_scope, place_);
    if (gc) {
      DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
    }
  }

  auto callback = [scope, local_scope, keep_kids]() {
    if (local_scope != scope) {
      VLOG(4) << "Delete scope: " << local_scope;
      scope->DeleteScope(local_scope);
    } else {
      if (!keep_kids) {
        VLOG(4) << "Drop kids: " << scope;
        // By default, we delete all kid scopes after the executor runs, because
        // some operators may create local scopes while running, such as while_op.
        // But when while_op also creates a local executor to run its sub-block,
        // the sub-scopes it creates should not be dropped immediately, because
        // while_grad_op will use some variables created during the while_op run.
        // In that case we keep the kids and wait for the outer executor to drop
        // them.

        scope->DropKids();
      }
      VLOG(4) << "Keep kids: " << scope;
    }
  };

  if (gc) {
    VLOG(4) << "Async deleting scope";
    gc->DirectClearCallback(callback);
  } else {
    VLOG(4) << "Sync deleting scope";
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
    callback();
  }
}

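// Run all operators of a prepared context from beginning to end.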
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                  bool create_local_scope, bool create_vars,
                                  bool keep_kids) {
  int64_t start_op_index = 0;
  int64_t end_op_index = ctx->ops_.size();
  RunPartialPreparedContext(ctx, scope, start_op_index, end_op_index,
                            create_local_scope, create_vars, keep_kids);
}

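// Run a prepared context that already contains feed/fetch operators: copy the
// feed targets into the feed_holder variable, run the block, then read the
// fetch targets back from the fetch_holder variable.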
void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx, Scope* scope,
    std::map<std::string, const LoDTensor*>* feed_targets,
    std::map<std::string, FetchType*>* fetch_targets, bool create_local_scope,
    bool create_vars, const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

  PADDLE_ENFORCE_EQ(
      has_feed_operators(global_block, *feed_targets, feed_holder_name), true,
      platform::errors::PreconditionNotMet(
          "Program in ExecutorPrepareContext should has feed_ops."));
  PADDLE_ENFORCE_EQ(
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
      true, platform::errors::PreconditionNotMet(
                "Program in the prepared context should has fetch_ops."));

  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
                      feed_holder_name, idx);
    }
  }

  RunPreparedContext(ctx, scope, create_local_scope, create_vars);

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      *(*fetch_targets)[fetch_target_name] =
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

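// Turn on the use_mkldnn attribute for every operator that supports it in all
// blocks of the program; only effective when built with PADDLE_WITH_MKLDNN.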
void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  VLOG(3) << "use_mkldnn=True";
  for (size_t bid = 0; bid < program.Size(); ++bid) {
    auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
    for (auto* op : block->AllOps()) {
      if (op->HasAttr("use_mkldnn")) {
        op->SetAttr("use_mkldnn", true);
      }
    }
  }
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#else
  LOG(WARNING)
      << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
}  // namespace framework
}  // namespace paddle