/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/executor.h"
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");

namespace paddle {
namespace framework {
namespace {
// Block ids start from 0. kProgramId (-1) is used to represent the pseudo
// code block that wraps block 0, i.e. the whole program.
int kProgramId = -1;
}  // namespace

ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog), block_id_(block_id) {}

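// Register control-flow ops for safe eager deletion and record which
// variables become unused after each op runs, unless GC is disabled.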
void ExecutorPrepareContext::PrepareUnusedVars(
    const std::vector<std::string>& keep_vars, bool force_disable_gc) {
  // If the program has more than one block, prepare control-flow ops
  // (conditional_block, while, recurrent) for safe eager deletion.
  if (prog_.Size() > 1) {
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
        prog_, block_id_, ops_);
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
                                                               ops_);
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
        prog_, block_id_, ops_);
  }

  force_disable_gc_ = force_disable_gc;
  if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
    return;
  }

  unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
}

ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}

Executor::Executor(const platform::Place& place) : place_(place) {}

Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
  // Clear the MKL-DNN cache, unless explicitly marked not to do so
  // (as set in the constructor); this is needed for the MKL-DNN unit
  // tests to work.
  if (platform::is_cpu_place(place_)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place_);
    dev_ctx->ResetBlobMap();
    platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW);
  }
#endif
}

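// Notify the distributed runtime (e.g. parameter servers) that this trainer
// has finished; a no-op unless built with PADDLE_WITH_DISTRIBUTE.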
void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
  // TODO(typhoonzero): the complete message will need to use the real
  // trainer_id for trainers other than 0.
  auto client =
      paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  client->SendComplete();
#endif
}

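// Create every variable declared in the given block: persistable variables
// are created in the outermost ancestor scope, the rest in the given scope.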
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                               int block_id) {
  VLOG(3) << "Creating Variables for block " << block_id;
  auto& global_block = pdesc.Block(block_id);
  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }
  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " global, which pointer is " << ptr;
      } else {
        auto* ptr = scope->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " locally, which pointer is " << ptr;
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
    }
  }
}

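// Parse the serialized TrainerDesc, create the corresponding trainer, and set
// up its scope and training/other environments for dataset-based training.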
std::shared_ptr<TrainerBase> Executor::InitForDataset(
    const ProgramDesc& main_program, const std::string& trainer_desc_str,
    Scope* scope, Dataset* dataset) {
  VLOG(3) << "Start to RunFromDataset in executor";
  TrainerDesc trainer_desc;
  bool success = trainer_desc.ParseFromString(trainer_desc_str);
  PADDLE_ENFORCE_EQ(success, true,
                    platform::errors::PreconditionNotMet(
                        "Fail to parse TrainerDesc from string:\n%s",
                        trainer_desc_str.c_str()));
  VLOG(3) << "Going to create trainer, trainer class is "
          << trainer_desc.class_name();
  std::shared_ptr<TrainerBase> trainer;
  trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  // initialize trainer
  VLOG(3) << "Going to initialize trainer";
  trainer->Initialize(trainer_desc, dataset);
  VLOG(3) << "Set root scope here";
  trainer->SetScope(scope);
  // prepare training environment and helper environment
  VLOG(3) << "Try to init train environment";
  trainer->InitTrainerEnv(main_program, place_);
  VLOG(3) << "Try to init other environment";
  trainer->InitOtherEnv(main_program);
  return trainer;
}

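// Run the training loop of a trainer previously created by InitForDataset.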
void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
  PADDLE_ENFORCE_NOT_NULL(
      trainer, platform::errors::InvalidArgument(
                   "Trainer is nullptr, invoke InitForDataset first"));
  // training and finalize training
  VLOG(3) << "Trainer starts to run";
  trainer->Run();
}

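// Finalize a trainer once its run has completed.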
void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
  VLOG(3) << "Trainer going to finalize";
  trainer->Finalize();
}

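// Run a single block of the program: prepare an execution context for it and
// execute that context in the given scope.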
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars,
                   const std::vector<std::string>& skip_ref_cnt_vars,
                   bool force_disable_gc, bool keep_kid_scopes) {
  platform::RecordBlock b(block_id);
  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                     keep_kid_scopes);
}

// Check whether the block already has feed operators and feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
    const BlockDesc& block,
    const std::map<std::string, const LoDTensor*>& feed_targets,
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      feed_count++;
      // The name of the feed_op's input variable should be feed_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Input("X")[0], feed_holder_name,
          platform::errors::PreconditionNotMet(
              "Input to feed op should be '%s', but received '%s'.",
              feed_holder_name, op->Input("X")[0]));
      std::string feed_target_name = op->Output("Out")[0];
      PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name), feed_targets.end(),
                        platform::errors::PreconditionNotMet(
                            "Feed operator output name '%s' cannot be found in "
                            "'feed_targets'",
                            feed_target_name));
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count, feed_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of feed operators should match 'feed_targets', but "
            "received feed_count: %zu, required feed_targets.size(): %zu.",
            feed_count, feed_targets.size()));

    if (!feed_holder_name.empty()) {
      // When feed operators are present, the feed_holder variable should be too.
      auto var = block.FindVar(feed_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable", feed_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FEED_MINIBATCH,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FEED_MINIBATCH' type, but received "
              "'%s'.",
              feed_holder_name, DataTypeToString(var->GetType())));
    }
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
    const BlockDesc& block,
    const std::map<std::string, LoDTensor*>& fetch_targets,
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      fetch_count++;
      // The name of the fetch_op's output variable should be fetch_holder_name.
      PADDLE_ENFORCE_EQ(
          op->Output("Out")[0], fetch_holder_name,
          platform::errors::PreconditionNotMet(
              "Output of fetch op should be '%s', but received '%s'.",
              fetch_holder_name, op->Output("Out")[0]));
      std::string fetch_target_name = op->Input("X")[0];
      PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
                        fetch_targets.end(),
                        platform::errors::NotFound(
                            "Fetch operator input name '%s' cannot be found in "
                            "'fetch_targets'.",
                            fetch_target_name));
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count, fetch_targets.size(),
        platform::errors::PreconditionNotMet(
            "The number of fetch operators should match 'fetch_targets', but "
            "received fetch_count: %zu, required fetch_targets.size(): %zu.",
            fetch_count, fetch_targets.size()));

    if (!fetch_holder_name.empty()) {
      // When fetch operators are present, the fetch_holder variable should be too.
      auto var = block.FindVar(fetch_holder_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable.", fetch_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FETCH_LIST,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
              fetch_holder_name, DataTypeToString(var->GetType())));
    }
  }

  return fetch_count > 0;
}

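// Run the program with explicit feed/fetch targets. If block 0 lacks feed or
// fetch operators, they are added to a copy of the program before running.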
void Executor::Run(const ProgramDesc& program, Scope* scope,
                   std::map<std::string, const LoDTensor*>* feed_targets,
                   std::map<std::string, LoDTensor*>* fetch_targets,
                   bool create_local_scope, bool create_vars,
                   const std::string& feed_holder_name,
                   const std::string& fetch_holder_name) {
  platform::RecordBlock b(kProgramId);
  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
  bool has_feed_ops =
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
  bool has_fetch_ops =
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
  if (!has_feed_ops || !has_fetch_ops) {
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
  }
  auto* global_block = copy_program->MutableBlock(0);

  if (!has_feed_ops) {
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
    feed_holder->SetPersistable(true);

    int i = 0;
    for (auto& feed_target : (*feed_targets)) {
      std::string var_name = feed_target.first;
      VLOG(3) << "feed target's name: " << var_name;

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  if (!has_fetch_ops) {
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
    fetch_holder->SetPersistable(true);

    int i = 0;
    for (auto& fetch_target : (*fetch_targets)) {
      std::string var_name = fetch_target.first;
      VLOG(3) << "fetch target's name: " << var_name;

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  auto ctx = Prepare(*copy_program, 0);
  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
                     create_local_scope, create_vars, feed_holder_name,
                     fetch_holder_name);
}

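// Build an ExecutorPrepareContext for one block: instantiate its operators
// and precompute the unused-variable information used for eager deletion.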
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
    const ProgramDesc& program, int block_id,
    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
  std::unique_ptr<ExecutorPrepareContext> ctx(
      new ExecutorPrepareContext(program, block_id));
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size(),
                    platform::errors::InvalidArgument(
                        "Input block id = %d, but it should be less than "
                        "program.size() which is %d",
                        static_cast<size_t>(block_id), program.Size()));
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
  ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
  return ctx;
}

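// Same as Prepare(), but for several blocks at once; skip_ref_cnt_vars must
// be empty or provide one entry per block id.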
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
    const ProgramDesc& program, const std::vector<int>& block_ids,
    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
    bool force_disable_gc) {
  PADDLE_ENFORCE_EQ(
      skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
      true,
      platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
                                        "empty or equals to block number %d",
                                        block_ids.size()));
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
  size_t idx = 0;
  for (auto& bid : block_ids) {
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size(),
                      platform::errors::InvalidArgument(
                          "Input block id = %zu, but it should be less than "
                          "program.size() which is %zu",
                          static_cast<size_t>(bid), program.Size()));
    auto* ctx = new ExecutorPrepareContext(program, bid);
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
    if (skip_ref_cnt_vars.empty()) {
      ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
    } else {
      ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
    }
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
    ++idx;
  }
  return result;
}

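// Execute ops_[start_op_index, end_op_index) from a prepared context,
// optionally inside a freshly created local scope, deleting unused variables
// eagerly when garbage collection is enabled.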
void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
                                         Scope* scope, int64_t start_op_index,
                                         int64_t end_op_index,
                                         bool create_local_scope,
                                         bool create_vars, bool keep_kids) {
  platform::RecordBlock b(kProgramId);
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

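  // Pick a garbage collector that matches place_ when eager deletion is
  // enabled (i.e. not force-disabled and the threshold is non-negative).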
  int64_t max_memory_size = GetEagerDeletionThreshold();
  std::unique_ptr<GarbageCollector> gc;
  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
#ifdef PADDLE_WITH_CUDA
    if (platform::is_gpu_place(place_)) {
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            boost::get<platform::CUDAPlace>(place_), max_memory_size));
      } else {
        gc.reset(new DefaultStreamGarbageCollector(
            boost::get<platform::CUDAPlace>(place_), max_memory_size));
      }
    } else if (platform::is_cpu_place(place_)) {
#endif
      gc.reset(new CPUGarbageCollector(boost::get<platform::CPUPlace>(place_),
                                       max_memory_size));
#ifdef PADDLE_WITH_CUDA
    }
#endif
  }

  for (int64_t i = start_op_index; i < end_op_index; ++i) {
    auto& op = ctx->ops_[i];
    op->Run(*local_scope, place_);
    if (gc) {
      DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
    }
  }

  platform::DeviceContextPool::Instance().Get(place_)->Wait();

  if (local_scope != scope) {
    scope->DeleteScope(local_scope);
  } else {
    if (!keep_kids) {
      // By default, we delete all kid scopes after the executor runs, because
      // some operators (such as while_op) may create local scopes while running.
      // But when while_op also creates a local executor to run its sub-block,
      // the sub-scopes it created should not be dropped immediately, because
      // while_grad_op will use some variables created during the while_op run,
      // so we need to keep the kids and wait for the outer executor to drop them.

      scope->DropKids();
    }
  }
}

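// Run all ops of a prepared context (a convenience wrapper around
// RunPartialPreparedContext).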
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                  bool create_local_scope, bool create_vars,
                                  bool keep_kids) {
  int64_t start_op_index = 0;
  int64_t end_op_index = ctx->ops_.size();
  RunPartialPreparedContext(ctx, scope, start_op_index, end_op_index,
                            create_local_scope, create_vars, keep_kids);
}

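// Run a prepared context with feed/fetch maps: copy the fed LoDTensors into
// the feed_holder variable, run the block, then read the results back from
// the fetch_holder variable into fetch_targets.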
void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx, Scope* scope,
    std::map<std::string, const LoDTensor*>* feed_targets,
    std::map<std::string, LoDTensor*>* fetch_targets, bool create_local_scope,
    bool create_vars, const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

  PADDLE_ENFORCE_EQ(
      has_feed_operators(global_block, *feed_targets, feed_holder_name), true,
      platform::errors::PreconditionNotMet(
          "Program in ExecutorPrepareContext should has feed_ops."));
  PADDLE_ENFORCE_EQ(
517
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
      true, platform::errors::PreconditionNotMet(
                "Program in the prepared context should has fetch_ops."));

  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
      int idx = boost::get<int>(op->GetAttr("col"));
      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
                      feed_holder_name, idx);
    }
  }

  RunPreparedContext(ctx, scope, create_local_scope, create_vars);

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
      int idx = boost::get<int>(op->GetAttr("col"));
      *(*fetch_targets)[fetch_target_name] =
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

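// Set the use_mkldnn attribute to true on every op that supports it; only
// effective when Paddle is built with PADDLE_WITH_MKLDNN.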
void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  VLOG(3) << "use_mkldnn=True";
  for (size_t bid = 0; bid < program.Size(); ++bid) {
    auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
    for (auto* op : block->AllOps()) {
      if (op->HasAttr("use_mkldnn")) {
        op->SetAttr("use_mkldnn", true);
      }
    }
  }
#else
  LOG(WARNING)
      << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
}  // namespace framework
}  // namespace paddle