/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/executor.h"

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(benchmark);

namespace paddle {
namespace framework {
namespace {
// Block ids start from 0. This id is used to represent the code block
// that wraps the first block, block 0.
int kProgramId = -1;
}  // namespace

ExecutorPrepareContext::ExecutorPrepareContext(
    const framework::ProgramDesc& prog, size_t block_id)
    : prog_(prog), block_id_(block_id) {}

ExecutorPrepareContext::~ExecutorPrepareContext() {
  VLOG(5) << "destroy ExecutorPrepareContext";
}

Executor::Executor(const platform::Place& place) : place_(place) {}

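// Initializes |var| with an empty value of the type indicated by |var_type|;
// types without a concrete payload here (e.g. RAW) are left for the operator
// itself to initialize via GetMutable.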
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
  if (var_type == proto::VarType::LOD_TENSOR) {
    var->GetMutable<LoDTensor>();
  } else if (var_type == proto::VarType::SELECTED_ROWS) {
    var->GetMutable<SelectedRows>();
  } else if (var_type == proto::VarType::FEED_MINIBATCH) {
    var->GetMutable<FeedFetchList>();
  } else if (var_type == proto::VarType::FETCH_LIST) {
    var->GetMutable<FeedFetchList>();
  } else if (var_type == proto::VarType::STEP_SCOPES) {
    var->GetMutable<std::vector<framework::Scope>>();
  } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
    var->GetMutable<LoDRankTable>();
  } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
    var->GetMutable<LoDTensorArray>();
  } else if (var_type == proto::VarType::PLACE_LIST) {
    var->GetMutable<platform::PlaceList>();
  } else if (var_type == proto::VarType::READER) {
    var->GetMutable<ReaderHolder>();
  } else if (var_type == proto::VarType::CHANNEL) {
    var->GetMutable<ChannelHolder>();
  } else if (var_type == proto::VarType::RAW) {
    // GetMutable will be called in the operator that uses this variable.
  } else {
    PADDLE_THROW(
        "Variable type %d is not in "
        "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
        var_type);
  }
}

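// Creates every variable declared in block |block_id| of |pdesc|. When |scope|
// has an ancestor, persistable variables are created in the outermost ancestor
// scope so they survive across runs, while the rest are created in |scope|;
// otherwise all variables are created directly in |scope|.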
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                               int block_id) {
  auto& global_block = pdesc.Block(block_id);

  const Scope* ancestor_scope = scope;
  while (ancestor_scope->parent()) {
    ancestor_scope = ancestor_scope->parent();
  }

  if (ancestor_scope != scope) {
    for (auto& var : global_block.AllVars()) {
      if (var->Name() == framework::kEmptyVarName) {
        continue;
      }

      if (var->Persistable()) {
        auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " global, which pointer is " << ptr;
      } else {
        auto* ptr = scope->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
        VLOG(3) << "Create Variable " << var->Name()
                << " locally, which pointer is " << ptr;
      }
    }
  } else {
    for (auto& var : global_block.AllVars()) {
      auto* ptr = scope->Var(var->Name());
108
      InitializeVariable(ptr, var->GetType());
109 110 111 112 113 114
      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
    }
  }
}

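// Runs block |block_id| of |pdesc| in |scope| by preparing an execution
// context and executing it; see RunPreparedContext for the meaning of
// |create_local_scope| and |create_vars|.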
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars) {
  platform::RecordBlock b(block_id);
  auto ctx = Prepare(pdesc, block_id);
  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}

// Check whether the block already has feed operators and a feed_holder.
// Return false if the block does not have any feed operators.
// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches feed_targets and
// feed_holder_name, and raise an exception when any mismatch is found.
// Return true if the block has feed operators and a holder with matching info.
static bool has_feed_operators(
    const BlockDesc& block,
    const std::map<std::string, const LoDTensor*>& feed_targets,
    const std::string& feed_holder_name) {
  size_t feed_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      feed_count++;
      // The input variable of feed_op should be named feed_holder_name.
      PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
                        "Input to feed op should be '%s'", feed_holder_name);
      std::string feed_target_name = op->Output("Out")[0];
      PADDLE_ENFORCE(
          feed_targets.find(feed_target_name) != feed_targets.end(),
          "Feed operator output name '%s' cannot be found in 'feed_targets'",
          feed_target_name);
    }
  }

  if (feed_count > 0) {
    PADDLE_ENFORCE_EQ(
        feed_count, feed_targets.size(),
        "The number of feed operators should match 'feed_targets'");

    if (!feed_holder_name.empty()) {
      // When feed operators are present, so should be the feed_holder.
      auto var = block.FindVar(feed_holder_name);
      PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
                              feed_holder_name);
      PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
                        "'%s' variable should be 'FEED_MINIBATCH' type",
                        feed_holder_name);
    }
  }

  return feed_count > 0;
}

// Check whether the block already has fetch operators and a fetch_holder.
// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches fetch_targets and
// fetch_holder_name, and raise an exception when any mismatch is found.
// Return true if the block has fetch operators and a holder with matching info.
static bool has_fetch_operators(
    const BlockDesc& block,
    const std::map<std::string, LoDTensor*>& fetch_targets,
    const std::string& fetch_holder_name) {
  size_t fetch_count = 0;
  for (auto* op : block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      fetch_count++;
      // The output variable of fetch_op should be named fetch_holder_name.
      PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
                        "Output of fetch op should be '%s'", fetch_holder_name);
      std::string fetch_target_name = op->Input("X")[0];
      PADDLE_ENFORCE(
          fetch_targets.find(fetch_target_name) != fetch_targets.end(),
          "Fetch operator input name '%s' cannot be found in 'fetch_targets'",
          fetch_target_name);
    }
  }

  if (fetch_count > 0) {
    PADDLE_ENFORCE_EQ(
        fetch_count, fetch_targets.size(),
        "The number of fetch operators should match 'fetch_targets'");

    if (!fetch_holder_name.empty()) {
      // When fetch operators are present, so should be the fetch_holder.
      auto var = block.FindVar(fetch_holder_name);
      PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
                              fetch_holder_name);
      PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
                        "'%s' variable should be 'FETCH_LIST' type",
                        fetch_holder_name);
    }
  }

  return fetch_count > 0;
}

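// Runs block 0 of |program| with explicit feed/fetch targets. If the block
// does not already contain matching feed/fetch operators, they are added,
// together with the holder variables, to a copy of the program before it is
// prepared and executed.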
void Executor::Run(const ProgramDesc& program, Scope* scope,
                   std::map<std::string, const LoDTensor*>* feed_targets,
                   std::map<std::string, LoDTensor*>* fetch_targets,
                   bool create_local_scope, bool create_vars,
                   const std::string& feed_holder_name,
                   const std::string& fetch_holder_name) {
  platform::RecordBlock b(kProgramId);
  bool has_feed_ops =
      has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
  bool has_fetch_ops =
      has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);

  ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
  std::unique_ptr<ProgramDesc> unique_ptr_of_copy_program;
  if (!has_feed_ops || !has_fetch_ops) {
    // Keep the copy alive for the rest of this function; binding it to a
    // temporary unique_ptr would leave copy_program dangling.
    unique_ptr_of_copy_program.reset(new ProgramDesc(program));
    copy_program = unique_ptr_of_copy_program.get();
  }

  auto* global_block = copy_program->MutableBlock(0);

  if (!has_feed_ops) {
    // create feed_holder variable
    auto* feed_holder = global_block->Var(feed_holder_name);
    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
    feed_holder->SetPersistable(true);

    int i = 0;
    for (auto& feed_target : (*feed_targets)) {
      std::string var_name = feed_target.first;
      VLOG(3) << "feed target's name: " << var_name;

      // prepend feed op
      auto* op = global_block->PrependOp();
      op->SetType(kFeedOpType);
      op->SetInput("X", {feed_holder_name});
      op->SetOutput("Out", {var_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  if (!has_fetch_ops) {
    // create fetch_holder variable
    auto* fetch_holder = global_block->Var(fetch_holder_name);
    fetch_holder->SetType(proto::VarType::FETCH_LIST);
    fetch_holder->SetPersistable(true);

    int i = 0;
    for (auto& fetch_target : (*fetch_targets)) {
      std::string var_name = fetch_target.first;
      VLOG(3) << "fetch target's name: " << var_name;

      // append fetch op
      auto* op = global_block->AppendOp();
      op->SetType(kFetchOpType);
      op->SetInput("X", {var_name});
      op->SetOutput("Out", {fetch_holder_name});
      op->SetAttr("col", {static_cast<int>(i)});
      op->CheckAttrs();

      i++;
    }
  }

  auto ctx = Prepare(*copy_program, 0);
  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
                     create_local_scope, create_vars, feed_holder_name,
                     fetch_holder_name);
}

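// Builds an ExecutorPrepareContext for block |block_id| of |program| by
// instantiating all of the block's operators up front, so that repeated runs
// can reuse them.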
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
    const ProgramDesc& program, int block_id) {
  auto* ctx = new ExecutorPrepareContext(program, block_id);
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
  return std::unique_ptr<ExecutorPrepareContext>(ctx);
}

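// Same as Prepare above, but builds one context per block id in |block_ids|.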
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
    const ProgramDesc& program, const std::vector<int>& block_ids) {
  std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
  for (auto& bid : block_ids) {
    auto* ctx = new ExecutorPrepareContext(program, bid);
    PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
    auto& block = program.Block(bid);
    for (auto& op_desc : block.AllOps()) {
      ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
    }
    result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
  }
  return result;
}

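// Executes the operators of a prepared block in |scope|. When |create_vars|
// is true the block's variables are created first; if |create_local_scope| is
// also true they live in a child scope that is deleted after the run.
// A minimal caller sketch (the names `program` and `scope` are illustrative):
//
//   Executor exe(platform::CPUPlace());
//   auto ctx = exe.Prepare(program, 0 /* block_id */);
//   exe.RunPreparedContext(ctx.get(), &scope, true, true);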
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                  bool create_local_scope, bool create_vars) {
  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
    }
    CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
  }

  for (auto& op : ctx->ops_) {
    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
    op->Run(*local_scope, place_);

    if (FLAGS_benchmark) {
      VLOG(2) << "Memory used after operator " + op->Type() + " running: "
              << memory::memory_usage(place_);
    }
  }
  platform::DeviceContextPool::Instance().Get(place_)->Wait();
  if (create_vars && create_local_scope) {
    scope->DeleteScope(local_scope);
  } else {
    // Delete the local scopes created in operators.
    scope->DropKids();
  }
  if (FLAGS_benchmark) {
    VLOG(2) << "-------------------------------------------------------";
    VLOG(2) << "Memory used after deleting local scope: "
            << memory::memory_usage(place_);
    VLOG(2) << "-------------------------------------------------------";
  }
}

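// Executes a prepared block whose program already contains feed and fetch
// operators: the feed_targets are copied into the feed holder before the run
// and the fetch_targets are filled from the fetch holder afterwards.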
void Executor::RunPreparedContext(
    ExecutorPrepareContext* ctx, Scope* scope,
    std::map<std::string, const LoDTensor*>* feed_targets,
    std::map<std::string, LoDTensor*>* fetch_targets, bool create_local_scope,
    bool create_vars, const std::string& feed_holder_name,
    const std::string& fetch_holder_name) {
  auto& global_block = ctx->prog_.Block(ctx->block_id_);

  PADDLE_ENFORCE(
      has_feed_operators(global_block, *feed_targets, feed_holder_name),
      "Program in ExecutorPrepareContext should has feed_ops.");
  PADDLE_ENFORCE(
      has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
      "Program in the prepared context should has fetch_ops.");

  // map the data of feed_targets to feed_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFeedOpType) {
      std::string feed_target_name = op->Output("Out")[0];
      int idx = boost::get<int>(op->GetAttr("col"));
      SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
                      feed_holder_name, idx);
    }
  }

  RunPreparedContext(ctx, scope, create_local_scope, create_vars);

  // obtain the data of fetch_targets from fetch_holder
  for (auto* op : global_block.AllOps()) {
    if (op->Type() == kFetchOpType) {
      std::string fetch_target_name = op->Input("X")[0];
      int idx = boost::get<int>(op->GetAttr("col"));
      *(*fetch_targets)[fetch_target_name] =
          GetFetchVariable(*scope, fetch_holder_name, idx);
    }
  }
}

}  // namespace framework
}  // namespace paddle