program_desc.cc 2.8 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

F
fengjiayi 已提交
15 16
#include "paddle/framework/program_desc.h"
#include "paddle/framework/block_desc.h"
F
fengjiayi 已提交
17 18 19 20

namespace paddle {
namespace framework {

21 22 23
// Operator type name carried by feed operators.
const std::string kFeedOpType("feed");
// Operator type name carried by fetch operators.
const std::string kFetchOpType("fetch");

Y
Yu Yang 已提交
24
// Appends a new block to the program and returns its BlockDesc wrapper.
//
// The new block's parent index is taken from `parent.ID()`, and its own index
// is its position in desc_'s block list. The returned pointer is obtained
// from the entry stored in blocks_.
BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
  auto *child_proto = desc_.add_blocks();
  child_proto->set_parent_idx(parent.ID());

  // The proto was just appended, so it occupies the last position.
  const auto last_position = desc_.blocks_size() - 1;
  child_proto->set_idx(last_position);

  blocks_.emplace_back(new BlockDesc(this, child_proto));
  BlockDesc *appended = blocks_.back().get();
  return appended;
}

Y
Yu Yang 已提交
32
// Calls Flush() on every block held in blocks_, then returns a pointer to the
// desc_ member itself (not a copy).
proto::ProgramDesc *ProgramDesc::Proto() {
  for (auto it = blocks_.begin(); it != blocks_.end(); ++it) {
    (*it)->Flush();
  }
  return &desc_;
}

Y
Yu Yang 已提交
39
// Default constructor: creates a program holding a single block whose index
// is kRootBlockIndex and whose parent index is kNoneBlockIndex.
ProgramDesc::ProgramDesc() {
  auto *root_proto = desc_.add_blocks();
  root_proto->set_parent_idx(kNoneBlockIndex);
  root_proto->set_idx(kRootBlockIndex);
  blocks_.emplace_back(new BlockDesc(this, root_proto));
}
Y
Yu Yang 已提交
45

Y
Yu Yang 已提交
46
// Copy constructor: copies the protobuf message of `o`, then builds one
// BlockDesc per block. Each new BlockDesc is constructed from the BlockDesc at
// the same position in `o`, this program's own copy of the block proto, and
// this program.
ProgramDesc::ProgramDesc(const ProgramDesc &o) {
  desc_ = o.desc_;

  int block_idx = 0;
  while (block_idx < desc_.blocks_size()) {
    auto *own_block_proto = desc_.mutable_blocks(block_idx);
    auto &source_block = *o.blocks_[block_idx];
    blocks_.emplace_back(new BlockDesc(source_block, own_block_proto, this));
    ++block_idx;
  }
}
54

Y
Yu Yang 已提交
55
// Builds a program from a protobuf message. The message is copied into desc_,
// and every BlockDesc wraps a block of that internal copy rather than a block
// of the caller's `desc`.
ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
  desc_ = desc;

  // The number of blocks is fixed before any BlockDesc is constructed.
  const int block_count = desc_.blocks_size();
  for (int i = 0; i < block_count; ++i) {
    blocks_.emplace_back(new BlockDesc(this, desc_.mutable_blocks(i)));
  }
}

Y
Yu Yang 已提交
62
// Builds a program from a serialized protobuf message. Parsing failure is
// reported through PADDLE_ENFORCE; on success one BlockDesc is created for
// each block of the parsed message.
ProgramDesc::ProgramDesc(const std::string &binary_str) {
  PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
                 "Fail to parse program_desc from binary string.");

  // The number of blocks is fixed before any BlockDesc is constructed.
  const int block_count = desc_.blocks_size();
  for (int i = 0; i < block_count; ++i) {
    blocks_.emplace_back(new BlockDesc(this, desc_.mutable_blocks(i)));
  }
}

70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
const std::vector<std::string> ProgramDesc::GetFeedVarNames() {
  BlockDesc *global_block = blocks_[0].get();
  std::vector<std::string> feed_var_names;
  for (auto *op : global_block->AllOps()) {
    if (op->Type() == "feed") {
      feed_var_names.insert(feed_var_names.begin(), op->Output("Out")[0]);
    }
  }
  return feed_var_names;
}

const std::vector<std::string> ProgramDesc::GetFetchVarNames() {
  BlockDesc *global_block = blocks_[0].get();
  std::vector<std::string> fetch_var_names;
  for (auto *op : global_block->AllOps()) {
    if (op->Type() == "fetch") {
      fetch_var_names.push_back(op->Input("X")[0]);
    }
  }
  return fetch_var_names;
}

F
fengjiayi 已提交
92 93
}  // namespace framework
}  // namespace paddle