/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/executor.h"

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(benchmark);

namespace paddle {
namespace framework {
namespace {
// Block ids start from 0. This id is used to represent the code block
// that wraps the first block 0, i.e. the whole program.
int kProgramId = -1;
}  // namespace

ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog), block_id_(block_id) {}

ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}

Executor::Executor(const platform::Place& place) : place_(place) {}

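// Initialize the payload of `var` according to its declared variable type,
// so that operators can later call GetMutable<T>() on an existing instance.
// Unsupported types fall through to the PADDLE_THROW below.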
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
  if (var_type == proto::VarType::LOD_TENSOR) {
    var->GetMutable<LoDTensor>();
  } else if (var_type == proto::VarType::SELECTED_ROWS) {
    var->GetMutable<SelectedRows>();
  } else if (var_type == proto::VarType::FEED_MINIBATCH) {
    var->GetMutable<FeedFetchList>();
  } else if (var_type == proto::VarType::FETCH_LIST) {
    var->GetMutable<FeedFetchList>();
  } else if (var_type == proto::VarType::STEP_SCOPES) {
    var->GetMutable<std::vector<framework::Scope>>();
  } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
    var->GetMutable<LoDRankTable>();
  } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
    var->GetMutable<LoDTensorArray>();
  } else if (var_type == proto::VarType::PLACE_LIST) {
    var->GetMutable<platform::PlaceList>();
  } else if (var_type == proto::VarType::READER) {
    var->GetMutable<ReaderHolder>();
  } else if (var_type == proto::VarType::CHANNEL) {
    var->GetMutable<ChannelHolder>();
  } else if (var_type == proto::VarType::RAW) {
    // GetMutable will be called in operator
  } else {
    PADDLE_THROW(
        "Variable type %d is not in "
        "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
        var_type);
  }
}

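// Create every variable declared in the given block. Persistable variables
// are created in the outermost (ancestor) scope; all other variables are
// created in the local scope passed in.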
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                               int block_id) {
  auto& global_block = pdesc.Block(block_id);

  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }

  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " global, which pointer is " << ptr;
      } else {
        auto* ptr = scope->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " locally, which pointer is " << ptr;
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
    }
  }
}

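// Run a single block of the program on the place bound to this Executor.
// A minimal usage sketch (the program/scope names are illustrative only):
//
//   Executor exe(platform::CPUPlace());
//   exe.Run(program, &scope, /*block_id=*/0,
//           /*create_local_scope=*/true, /*create_vars=*/true);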
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars) {
  platform::RecordBlock b(block_id);
  auto ctx = Prepare(pdesc, block_id);
  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}

// Check whether the block already has feed operators and a feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise an exception when any mismatch is found.
// Return true if the block has feed operators and a holder with matching info.
static bool has_feed_operators(
    const BlockDesc& block,
    const std::map<std::string, const LoDTensor*>& feed_targets,
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      feed_count++;
      // The input variable of a feed_op should be named feed_holder_name.
      PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
                        "Input to feed op should be '%s'", feed_holder_name);
      std::string feed_target_name = op->Output("Out")[0];
      PADDLE_ENFORCE(
          feed_targets.find(feed_target_name) != feed_targets.end(),
          "Feed operator output name '%s' cannot be found in 'feed_targets'",
          feed_target_name);
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count, feed_targets.size(),
        "The number of feed operators should match 'feed_targets'");

    if (!feed_holder_name.empty()) {
      // When feed operators are present, the feed_holder variable should be too.
      auto var = block.FindVar(feed_holder_name);
      PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
                              feed_holder_name);
      PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
                        "'%s' variable should be 'FEED_MINIBATCH' type",
                        feed_holder_name);
    }
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and a fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise an exception when any mismatch is found.
// Return true if the block has fetch operators and a holder with matching info.
static bool has_fetch_operators(
    const BlockDesc& block,
    const std::map<std::string, LoDTensor*>& fetch_targets,
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      fetch_count++;
      // The output variable of a fetch_op should be named fetch_holder_name.
      PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
                        "Output of fetch op should be '%s'", fetch_holder_name);
      std::string fetch_target_name = op->Input("X")[0];
      PADDLE_ENFORCE(
          fetch_targets.find(fetch_target_name) != fetch_targets.end(),
          "Fetch operator input name '%s' cannot be found in 'fetch_targets'",
          fetch_target_name);
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count, fetch_targets.size(),
        "The number of fetch operators should match 'fetch_targets'");

    if (!fetch_holder_name.empty()) {
      // When fetch operators are present, the fetch_holder variable should be too.
      auto var = block.FindVar(fetch_holder_name);
      PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
                              fetch_holder_name);
      PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
                        "'%s' variable should be 'FETCH_LIST' type",
                        fetch_holder_name);
    }
  }

  return fetch_count > 0;
}

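// Run a whole program with explicit feed/fetch targets. If block 0 does not
// already contain matching feed/fetch operators, they are added to a
// temporary copy of the program before execution.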
void Executor::Run(const ProgramDesc& program, Scope* scope,
                   std::map<std::string, const LoDTensor*>* feed_targets,
                   std::map<std::string, LoDTensor*>* fetch_targets,
                   bool create_local_scope, bool create_vars,
                   const std::string& feed_holder_name,
                   const std::string& fetch_holder_name) {
  platform::RecordBlock b(kProgramId);
  bool has_feed_ops =
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
  bool has_fetch_ops =
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
  if (!has_feed_ops || !has_fetch_ops) {
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
  }

  auto* global_block = copy_program->MutableBlock(0);

  if (!has_feed_ops) {
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
    feed_holder->SetPersistable(true);

    int i = 0;
    for (auto& feed_target : (*feed_targets)) {
      std::string var_name = feed_target.first;
      VLOG(3) << "feed target's name: " << var_name;

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  if (!has_fetch_ops) {
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
    fetch_holder->SetPersistable(true);

    int i = 0;
    for (auto& fetch_target : (*fetch_targets)) {
      std::string var_name = fetch_target.first;
      VLOG(3) << "fetch target's name: " << var_name;

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  auto ctx = Prepare(*copy_program, 0);
  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
                     create_local_scope, create_vars, feed_holder_name,
                     fetch_holder_name);
}

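// Build an ExecutorPrepareContext for one block: all operators of the block
// are instantiated once so that later runs can reuse them.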
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
    const ProgramDesc& program, int block_id) {
  auto* ctx = new ExecutorPrepareContext(program, block_id);
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
  return std::unique_ptr<ExecutorPrepareContext>(ctx);
}

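// Same as Prepare() above, but returns one prepared context per block id.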
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
    const ProgramDesc& program, const std::vector<int>& block_ids) {
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
  for (auto& bid : block_ids) {
    auto* ctx = new ExecutorPrepareContext(program, bid);
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
  }
  return result;
}

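// Execute an already prepared context: optionally create a local scope and
// the block's variables, run every operator in order, then clean up. A
// hedged sketch of how this pairs with Prepare() (names are illustrative):
//
//   auto ctx = executor.Prepare(program, /*block_id=*/0);
//   executor.RunPreparedContext(ctx.get(), &scope,
//                               /*create_local_scope=*/true,
//                               /*create_vars=*/true);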
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                  bool create_local_scope, bool create_vars) {
  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

  for (auto& op : ctx->ops_) {
    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
    op->Run(*local_scope, place_);

    if (FLAGS_benchmark) {
      VLOG(2) << "Memory used after operator " + op->Type() + " running: "
              << memory::memory_usage(place_);
    }
  }
  platform::DeviceContextPool::Instance().Get(place_)->Wait();
  if (create_vars && create_local_scope) {
    scope->DeleteScope(local_scope);
  } else {
    // Delete the local scopes created in operators.
    scope->DropKids();
  }
  if (FLAGS_benchmark) {
    VLOG(2) << "-------------------------------------------------------";
    VLOG(2) << "Memory used after deleting local scope: "
            << memory::memory_usage(place_);
    VLOG(2) << "-------------------------------------------------------";
  }
}

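// Execute a prepared context whose block already contains feed/fetch
// operators: copy feed_targets into the feed_holder variable, run the block,
// then copy the results from the fetch_holder variable into fetch_targets.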
void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx, Scope* scope,
    std::map<std::string, const LoDTensor*>* feed_targets,
    std::map<std::string, LoDTensor*>* fetch_targets, bool create_local_scope,
    bool create_vars, const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

  PADDLE_ENFORCE(
      has_feed_operators(global_block, *feed_targets, feed_holder_name),
      "Program in ExecutorPrepareContext should have feed_ops.");
  PADDLE_ENFORCE(
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
      "Program in the prepared context should have fetch_ops.");

  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
      int idx = boost::get<int>(op->GetAttr("col"));
      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
                      feed_holder_name, idx);
    }
  }

  RunPreparedContext(ctx, scope, create_local_scope, create_vars);

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
      int idx = boost::get<int>(op->GetAttr("col"));
      *(*fetch_targets)[fetch_target_name] =
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

}  // namespace framework
}  // namespace paddle