executor.cc 22.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/executor.h"
S
sneaxiy 已提交
16
#include <deque>
17
#include <memory>
18
#include <unordered_map>
19
#include <unordered_set>
S
sneaxiy 已提交
20
#include <utility>
D
dongdaxiang 已提交
21 22 23
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
24
#include "paddle/fluid/framework/data_type.h"
Y
Yi Wang 已提交
25 26 27 28 29
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
D
dongdaxiang 已提交
30 31
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
32
#include "paddle/fluid/framework/transfer_scope_cache.h"
W
Wang Guibao 已提交
33
#include "paddle/fluid/framework/variable_helper.h"
Z
Zeng Jinle 已提交
34
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
35
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
S
sneaxiy 已提交
36
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
Y
Yi Wang 已提交
37
#include "paddle/fluid/platform/place.h"
X
Xin Pan 已提交
38
#include "paddle/fluid/platform/profiler.h"
39 40 41
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
Y
Yang Yu 已提交
42

D
dzhwinter 已提交
43
DECLARE_bool(benchmark);
44
DECLARE_bool(use_mkldnn);
Q
qijun 已提交
45 46 47

namespace paddle {
namespace framework {
X
Xin Pan 已提交
48 49 50 51 52
namespace {
// Block ids start from 0, so -1 is used as the profiling id of the code
// block that wraps block 0, i.e. the program as a whole. It is a true
// compile-time constant; nothing may reassign it.
constexpr int kProgramId = -1;
}  // namespace
Q
qijun 已提交
53

Q
Qiao Longfei 已提交
54
// Stores the program and the id of the block it will execute. The op list is
// populated afterwards by Executor::Prepare.
ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog),
      block_id_(block_id) {}

void ExecutorPrepareContext::PrepareUnusedVars(
    const std::vector<std::string>& keep_vars, bool force_disable_gc) {
Z
Zeng Jinle 已提交
60 61 62
  // If gc is enabled and block size > 1
  if (prog_.Size() > 1) {
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
63 64 65
        prog_, block_id_, ops_);
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
                                                               ops_);
Z
Zeng Jinle 已提交
66
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
67
        prog_, block_id_, ops_);
Z
Zeng Jinle 已提交
68
  }
69 70 71 72 73 74

  force_disable_gc_ = force_disable_gc;
  if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
    return;
  }

S
sneaxiy 已提交
75
  unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
S
sneaxiy 已提交
76
}
Y
Yu Yang 已提交
77

Q
Qiao Longfei 已提交
78
// Emits a level-5 verbose log line; no explicit cleanup is performed here.
ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}
Y
Yu Yang 已提交
81

D
dzhwinter 已提交
82
// Binds the executor to the device place its ops are run on.
Executor::Executor(const platform::Place& place)
    : place_(place) {}
Q
qijun 已提交
83

84 85
Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
  // Release the MKL-DNN cache for this executor's place. The original
  // authors note this is required for the MKL-DNN unit tests to work.
  ClearMKLDNNCache(place_);
#endif
}

Y
Yancey1989 已提交
92
void Executor::Close() {
  // Intentionally a no-op.
  //
  // Under PADDLE_WITH_DISTRIBUTE this previously sent a "complete" message
  // through the distributed RPC client (RPCClient::SendComplete). That call
  // used a hard-coded trainer_id of 0, with a TODO to use the real id. The
  // logic is currently disabled.
}
W
Wu Yi 已提交
101

L
Liu Yiqun 已提交
102 103
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                               int block_id) {
104
  VLOG(3) << "Creating Variables for block " << block_id;
L
Liu Yiqun 已提交
105
  auto& global_block = pdesc.Block(block_id);
106 107 108 109 110 111 112 113 114 115 116 117
  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }
  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
118
        InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
119 120
        VLOG(3) << "Create Variable " << var->Name()
                << " global, which pointer is " << ptr;
121 122
      } else {
        auto* ptr = scope->Var(var->Name());
123
        InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
124 125
        VLOG(3) << "Create Variable " << var->Name()
                << " locally, which pointer is " << ptr;
126 127 128 129 130
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
131
      InitializeVariable(ptr, var->GetType());
M
minqiyang 已提交
132 133
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
134 135 136 137
    }
  }
}

138 139 140
std::shared_ptr<TrainerBase> Executor::InitForDataset(
    const ProgramDesc& main_program, const std::string& trainer_desc_str,
    Scope* scope, Dataset* dataset) {
D
dongdaxiang 已提交
141 142
  VLOG(3) << "Start to RunFromDataset in executor";
  TrainerDesc trainer_desc;
H
hutuxian 已提交
143
  bool success = trainer_desc.ParseFromString(trainer_desc_str);
144 145 146 147
  PADDLE_ENFORCE_EQ(success, true,
                    platform::errors::PreconditionNotMet(
                        "Fail to parse TrainerDesc from string:\n%s",
                        trainer_desc_str.c_str()));
D
dongdaxiang 已提交
148 149 150 151 152 153 154 155
  VLOG(3) << "Going to create trainer, trainer class is "
          << trainer_desc.class_name();
  std::shared_ptr<TrainerBase> trainer;
  trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
  // initialize trainer
  VLOG(3) << "Going to initialize trainer";
  trainer->Initialize(trainer_desc, dataset);
  VLOG(3) << "Set root scope here";
D
dongdaxiang 已提交
156
  trainer->SetScope(scope);
D
dongdaxiang 已提交
157 158 159 160 161
  // prepare training environment and helper environment
  VLOG(3) << "Try to init train environment";
  trainer->InitTrainerEnv(main_program, place_);
  VLOG(3) << "Try to init other environment";
  trainer->InitOtherEnv(main_program);
162 163 164 165
  return trainer;
}

void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
166 167 168
  PADDLE_ENFORCE_NOT_NULL(
      trainer, platform::errors::InvalidArgument(
                   "Trainer is nullptr, invoke InitForDataset first"));
D
dongdaxiang 已提交
169 170 171
  // training and finalize training
  VLOG(3) << "Trainer starts to run";
  trainer->Run();
D
Dong Daxiang 已提交
172 173 174
}

void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
D
dongdaxiang 已提交
175 176 177
  VLOG(3) << "Trainer going to finalize";
  trainer->Finalize();
}
D
dongdaxiang 已提交
178

Y
Yu Yang 已提交
179
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
S
sneaxiy 已提交
180 181
                   bool create_local_scope, bool create_vars,
                   const std::vector<std::string>& skip_ref_cnt_vars,
182
                   bool force_disable_gc, bool keep_kid_scopes) {
X
Xin Pan 已提交
183
  platform::RecordBlock b(block_id);
184
  if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
S
sneaxiy 已提交
185
  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
186 187
  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                     keep_kid_scopes);
Q
qijun 已提交
188 189
}

190 191 192 193 194 195 196
// Check whether the block already has feed operators and feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
197
    const BlockDesc& block,
L
Liu Yiqun 已提交
198
    const std::map<std::string, const LoDTensor*>& feed_targets,
199 200
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
201
  for (auto* op : block.AllOps()) {
202 203
    if (op->Type() == kFeedOpType) {
      feed_count++;
L
Liu Yiqun 已提交
204
      // The input variable's name of feed_op should be feed_holder_name.
205 206 207 208 209
      PADDLE_ENFORCE_EQ(
          op->Input("X")[0], feed_holder_name,
          platform::errors::PreconditionNotMet(
              "Input to feed op should be '%s', but received '%s'.",
              feed_holder_name, op->Input("X")[0]));
210
      std::string feed_target_name = op->Output("Out")[0];
211 212 213 214 215
      PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name), feed_targets.end(),
                        platform::errors::PreconditionNotMet(
                            "Feed operator output name '%s' cannot be found in "
                            "'feed_targets'",
                            feed_target_name));
216 217 218 219 220 221
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count, feed_targets.size(),
222 223 224 225
        platform::errors::PreconditionNotMet(
            "The number of feed operators should match 'feed_targets', but "
            "received feed_count: %zu, required feed_targets.size(): %zu.",
            feed_count, feed_targets.size()));
226

227
    if (!feed_holder_name.empty()) {
L
Liu Yiqun 已提交
228
      // When feed operator are present, so should be feed_holder.
229
      auto var = block.FindVar(feed_holder_name);
230 231 232 233 234 235 236 237 238 239
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable", feed_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FEED_MINIBATCH,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FEED_MINIBATCH' type, but received "
              "'%s'.",
              feed_holder_name, DataTypeToString(var->GetType())));
240
    }
241 242 243 244 245 246 247 248 249 250 251 252
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
L
Liu Yiqun 已提交
253
    const BlockDesc& block,
254
    const std::map<std::string, FetchType*>& fetch_targets,
255 256
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
257
  for (auto* op : block.AllOps()) {
258 259
    if (op->Type() == kFetchOpType) {
      fetch_count++;
L
Liu Yiqun 已提交
260
      // The output variable's name of fetch_op should be fetch_holder_name.
261 262 263 264 265
      PADDLE_ENFORCE_EQ(
          op->Output("Out")[0], fetch_holder_name,
          platform::errors::PreconditionNotMet(
              "Output of fetch op should be '%s', but received '%s'.",
              fetch_holder_name, op->Output("Out")[0]));
266
      std::string fetch_target_name = op->Input("X")[0];
267 268 269 270 271 272
      PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
                        fetch_targets.end(),
                        platform::errors::NotFound(
                            "Fetch operator input name '%s' cannot be found in "
                            "'fetch_targets'.",
                            fetch_target_name));
273 274 275 276 277 278
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count, fetch_targets.size(),
279 280 281 282
        platform::errors::PreconditionNotMet(
            "The number of fetch operators should match 'fetch_targets', but "
            "received fetch_count: %zu, required fetch_targets.size(): %zu.",
            fetch_count, fetch_targets.size()));
283

284
    if (!fetch_holder_name.empty()) {
L
Liu Yiqun 已提交
285
      // When fetch operator are present, so should be fetch_holder.
286
      auto var = block.FindVar(fetch_holder_name);
287 288 289 290 291 292 293 294 295
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::PreconditionNotMet(
              "Block should already have a '%s' variable.", fetch_holder_name));
      PADDLE_ENFORCE_EQ(
          var->GetType(), proto::VarType::FETCH_LIST,
          platform::errors::PreconditionNotMet(
              "'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
              fetch_holder_name, DataTypeToString(var->GetType())));
296
    }
297 298 299 300 301 302
  }

  return fetch_count > 0;
}

void Executor::Run(const ProgramDesc& program, Scope* scope,
303
                   std::map<std::string, const LoDTensor*>* feed_targets,
304
                   std::map<std::string, FetchType*>* fetch_targets,
W
Wu Yi 已提交
305 306
                   bool create_local_scope, bool create_vars,
                   const std::string& feed_holder_name,
307
                   const std::string& fetch_holder_name) {
X
Xin Pan 已提交
308
  platform::RecordBlock b(kProgramId);
309
  if (FLAGS_use_mkldnn) EnableMKLDNN(program);
310
  bool has_feed_ops =
311
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
312
  bool has_fetch_ops =
313
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);
314 315

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
S
sneaxiy 已提交
316
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
317
  if (!has_feed_ops || !has_fetch_ops) {
S
sneaxiy 已提交
318 319
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
320
  }
321 322
  auto* global_block = copy_program->MutableBlock(0);

323
  if (!has_feed_ops) {
324 325
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
326
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
327 328 329
    feed_holder->SetPersistable(true);

    int i = 0;
330
    for (auto& feed_target : (*feed_targets)) {
331
      std::string var_name = feed_target.first;
M
minqiyang 已提交
332
      VLOG(3) << "feed target's name: " << var_name;
333 334 335 336 337 338 339 340 341 342 343 344 345

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

346
  if (!has_fetch_ops) {
347 348
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
349
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
350 351 352
    fetch_holder->SetPersistable(true);

    int i = 0;
353
    for (auto& fetch_target : (*fetch_targets)) {
354
      std::string var_name = fetch_target.first;
M
minqiyang 已提交
355
      VLOG(3) << "fetch target's name: " << var_name;
356 357 358 359 360 361 362 363 364 365 366 367 368

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

369
  auto ctx = Prepare(*copy_program, 0);
W
Wu Yi 已提交
370 371 372
  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
                     create_local_scope, create_vars, feed_holder_name,
                     fetch_holder_name);
373 374
}

Q
Qiao Longfei 已提交
375
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
S
fix bug  
sneaxiy 已提交
376
    const ProgramDesc& program, int block_id,
S
sneaxiy 已提交
377
    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
S
sneaxiy 已提交
378 379
  std::unique_ptr<ExecutorPrepareContext> ctx(
      new ExecutorPrepareContext(program, block_id));
380 381 382 383 384
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size(),
                    platform::errors::InvalidArgument(
                        "Input block id = %d, but it should be less than "
                        "program.size() which is %d",
                        static_cast<size_t>(block_id), program.Size()));
Y
Yu Yang 已提交
385 386 387 388
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
S
sneaxiy 已提交
389
  ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
Q
Qiyang Min 已提交
390
  return ctx;
Y
Yu Yang 已提交
391 392
}

T
refine  
typhoonzero 已提交
393
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
S
fix bug  
sneaxiy 已提交
394
    const ProgramDesc& program, const std::vector<int>& block_ids,
S
sneaxiy 已提交
395 396
    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
    bool force_disable_gc) {
397
  PADDLE_ENFORCE_EQ(
S
fix bug  
sneaxiy 已提交
398
      skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
399 400 401 402
      true,
      platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
                                        "empty or equals to block number %d",
                                        block_ids.size()));
T
typhoonzero 已提交
403
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
S
fix bug  
sneaxiy 已提交
404
  size_t idx = 0;
T
typhoonzero 已提交
405
  for (auto& bid : block_ids) {
406 407 408 409 410
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size(),
                      platform::errors::InvalidArgument(
                          "Input block id = %zu, but it should be less than "
                          "program.size() which is %zu",
                          static_cast<size_t>(bid), program.Size()));
S
sneaxiy 已提交
411
    auto* ctx = new ExecutorPrepareContext(program, bid);
T
typhoonzero 已提交
412 413 414 415
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
S
sneaxiy 已提交
416 417 418 419 420
    if (skip_ref_cnt_vars.empty()) {
      ctx->PrepareUnusedVars(std::vector<std::string>(), force_disable_gc);
    } else {
      ctx->PrepareUnusedVars(skip_ref_cnt_vars[idx], force_disable_gc);
    }
T
typhoonzero 已提交
421
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
S
fix bug  
sneaxiy 已提交
422
    ++idx;
T
typhoonzero 已提交
423 424 425 426
  }
  return result;
}

427 428 429 430 431
// Runs ops [start_op_index, end_op_index) of a prepared context in `scope`.
//
// Variable creation: when `create_vars` is true, the block's variables are
// created first. They go into a new child scope if `create_local_scope` is
// also true, otherwise into `scope` itself.
//
// Eager deletion: it is enabled when the threshold is >= 0 and `ctx` does not
// force-disable GC. In that case a garbage collector for `place_` (GPU, CPU
// or XPU) frees each op's unused tensors right after the op runs.
//
// Cleanup after the run:
// - a child scope created here is deleted;
// - otherwise the kid scopes of `scope` are dropped, unless `keep_kids` is
//   true.
// This cleanup is handed to the garbage collector when one exists.
// Otherwise it runs synchronously after waiting on the device context.
//
// A non-empty op range must lie inside ctx->ops_; otherwise InvalidArgument
// is raised. An empty range (start >= end) runs no ops.
void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
                                         Scope* scope, int64_t start_op_index,
                                         int64_t end_op_index,
                                         bool create_local_scope,
                                         bool create_vars, bool keep_kids) {
  platform::RecordBlock b(kProgramId);
  PADDLE_ENFORCE_NOT_NULL(
      ctx, platform::errors::InvalidArgument(
               "ExecutorPrepareContext shouldn't be null"));
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
  // Indexing ops_ outside its bounds is undefined behavior, so reject such
  // ranges up front. Empty ranges never index ops_ and stay accepted.
  if (start_op_index < end_op_index) {
    PADDLE_ENFORCE_GE(
        start_op_index, static_cast<int64_t>(0),
        platform::errors::InvalidArgument(
            "start_op_index should be non-negative, but received %d.",
            start_op_index));
    PADDLE_ENFORCE_LE(
        end_op_index, static_cast<int64_t>(ctx->ops_.size()),
        platform::errors::InvalidArgument(
            "end_op_index should not exceed the number of ops (%d), but "
            "received %d.",
            ctx->ops_.size(), end_op_index));
  }

  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

  int64_t max_memory_size = GetEagerDeletionThreshold();
  std::unique_ptr<GarbageCollector> gc;
  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
    if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      } else {
        gc.reset(new DefaultStreamGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      }
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
#endif
    } else if (platform::is_cpu_place(place_)) {
      gc.reset(new CPUGarbageCollector(
          BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
    } else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
      gc.reset(new XPUGarbageCollector(
          BOOST_GET_CONST(platform::XPUPlace, place_), max_memory_size));
#else
      PADDLE_THROW(
          platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
#endif
    }
  }

  for (int64_t i = start_op_index; i < end_op_index; ++i) {
    auto& op = ctx->ops_[i];
    op->Run(*local_scope, place_);
    if (gc) {
      DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
    }
  }

  auto callback = [scope, local_scope, keep_kids]() {
    if (local_scope != scope) {
      VLOG(4) << "Delete scope: " << local_scope;
      scope->DeleteScope(local_scope);
    } else if (!keep_kids) {
      VLOG(4) << "Drop kids: " << scope;
      // By default all kid scopes are deleted after the run, because some
      // operators (e.g. while_op) create local scopes while running. When
      // while_op itself runs its sub-block with a nested executor, the
      // sub-scopes it created must not be dropped immediately. while_grad_op
      // later uses variables created during the while_op run, so the nested
      // executor keeps the kids and the outer executor drops them.
      scope->DropKids();
    } else {
      // Only log "Keep kids" when the kids are actually kept.
      VLOG(4) << "Keep kids: " << scope;
    }
  };

  if (gc) {
    VLOG(4) << "Async deleting scope";
    gc->DirectClearCallback(callback);
  } else {
    VLOG(4) << "Sync deleting scope";
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
    callback();
  }
}

513 514 515 516 517 518 519 520 521
// Runs every op of a prepared context. This is RunPartialPreparedContext
// over the full range [0, ctx->ops_.size()).
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                  bool create_local_scope, bool create_vars,
                                  bool keep_kids) {
  RunPartialPreparedContext(
      ctx, scope, /*start_op_index=*/static_cast<int64_t>(0),
      /*end_op_index=*/static_cast<int64_t>(ctx->ops_.size()),
      create_local_scope, create_vars, keep_kids);
}

522 523
void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx, Scope* scope,
524
    std::map<std::string, const LoDTensor*>* feed_targets,
525
    std::map<std::string, FetchType*>* fetch_targets, bool create_local_scope,
W
Wu Yi 已提交
526 527
    bool create_vars, const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
528 529
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

530 531 532 533 534
  PADDLE_ENFORCE_EQ(
      has_feed_operators(global_block, *feed_targets, feed_holder_name), true,
      platform::errors::PreconditionNotMet(
          "Program in ExecutorPrepareContext should has feed_ops."));
  PADDLE_ENFORCE_EQ(
535
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
536 537
      true, platform::errors::PreconditionNotMet(
                "Program in the prepared context should has fetch_ops."));
538

539 540 541 542
  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
543
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
544 545
      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
                      feed_holder_name, idx);
546 547 548
    }
  }

W
Wu Yi 已提交
549
  RunPreparedContext(ctx, scope, create_local_scope, create_vars);
550 551 552 553 554

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
555
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
556
      *(*fetch_targets)[fetch_target_name] =
557 558 559 560 561
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

562 563
// Turns on MKL-DNN for `program` when Paddle is built with MKL-DNN.
//
// With MKL-DNN:
// - every op in every block that has a "use_mkldnn" attribute gets it set to
//   true;
// - this executor is attached to the MKL-DNN key for `place_`.
// Without MKL-DNN only a warning is logged.
void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  VLOG(3) << "use_mkldnn=True";
  // The attributes are rewritten in place even though callers pass the
  // program as const.
  auto& mutable_program = const_cast<ProgramDesc&>(program);
  for (size_t bid = 0; bid < program.Size(); ++bid) {
    for (auto* op : mutable_program.MutableBlock(bid)->AllOps()) {
      if (!op->HasAttr("use_mkldnn")) {
        continue;
      }
      op->SetAttr("use_mkldnn", true);
    }
  }
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#else
  LOG(WARNING)
      << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
Q
qijun 已提交
579 580
}  // namespace framework
}  // namespace paddle