program_desc.cc 9.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/program_desc.h"
16

17 18 19 20
extern "C" {
#include <xxhash.h>
}

21
#include <algorithm>
Y
Yi Wang 已提交
22
#include "paddle/fluid/framework/feed_fetch_type.h"
23 24 25
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_converter.h"
X
version  
Xin Pan 已提交
26
#include "paddle/fluid/framework/version.h"
F
fengjiayi 已提交
27 28 29 30

namespace paddle {
namespace framework {

Y
Yu Yang 已提交
31
// Creates a new block whose parent is `parent` and appends it to the program.
// The new block's index is its position in desc_.blocks.
// Returns a non-owning pointer; the BlockDesc stays owned by this program.
BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
  auto *new_proto_block = desc_.add_blocks();
  const int new_idx = desc_.blocks_size() - 1;
  new_proto_block->set_parent_idx(parent.ID());
  new_proto_block->set_idx(new_idx);
  // Wrap the freshly added proto message in a BlockDesc owned by blocks_.
  blocks_.emplace_back(new BlockDesc(this, new_proto_block));
  return blocks_.back().get();
}

39
void ProgramDesc::Flush() {
F
fengjiayi 已提交
40
  for (auto &block : blocks_) {
41
    block->Flush();
F
fengjiayi 已提交
42
  }
43 44 45 46
}

// Returns a mutable pointer to the underlying proto message.
// All blocks are flushed first, so pending block edits are written into
// desc_ before the caller sees it.
proto::ProgramDesc *ProgramDesc::Proto() {
  this->Flush();
  return &this->desc_;
}

50 51
// Returns a mutable pointer to the program's op_version_map field.
// Uses the protobuf mutable_ accessor, which creates the field if it is
// not yet set.
proto::OpVersionMap *ProgramDesc::OpVersionMap() {
  auto *version_map = desc_.mutable_op_version_map();
  return version_map;
}

54 55
// True when the op_version_map field is present in the proto message.
bool ProgramDesc::HasOpVersionMap() const {
  const bool present = desc_.has_op_version_map();
  return present;
}

X
clean  
Xin Pan 已提交
56
// Returns the program version number stored in desc_.version.
int64_t ProgramDesc::Version() const {
  const auto &version_msg = desc_.version();
  return version_msg.version();
}
X
version  
Xin Pan 已提交
57

58 59
// True when the proto message carries a version field.
bool ProgramDesc::HasVersion() const {
  const bool present = desc_.has_version();
  return present;
}

60 61 62 63
// Stores `version` in desc_.version, creating the version sub-message if
// it is absent.
void ProgramDesc::SetVersion(const int64_t version) {
  auto *version_msg = desc_.mutable_version();
  version_msg->set_version(version);
}

Y
Yu Yang 已提交
64
// Builds an empty program with two properties:
//  - it is stamped with kCurProgramVersion;
//  - it has exactly one block, with index kRootBlockIndex and parent
//    kNoneBlockIndex.
ProgramDesc::ProgramDesc() {
  SetVersion(kCurProgramVersion);
  auto *root = desc_.add_blocks();
  root->set_idx(kRootBlockIndex);
  root->set_parent_idx(kNoneBlockIndex);
  blocks_.emplace_back(new BlockDesc(this, root));
}
Y
Yu Yang 已提交
71

Y
Yu Yang 已提交
72
// Copy constructor.
// Copies the proto of `o` and builds one BlockDesc per block from the
// corresponding block of `o`. Op attributes that hold raw pointers are then
// re-pointed so the copy refers to its own descriptors:
//  - BLOCK and BLOCKS attributes: to blocks owned by *this* program;
//  - VAR and VARS attributes: to VarDescs found in *this* program's blocks.
// As a result the copy does not keep pointers into `o`.
ProgramDesc::ProgramDesc(const ProgramDesc &o) {
  desc_ = o.desc_;
  // Raw pointers to every block owned by the origin program. Used to decide
  // whether a BLOCK attribute points into `o` or somewhere else.
  std::vector<framework::BlockDesc *> old_block_desc;
  for (int i = 0; i < desc_.blocks_size(); ++i) {
    auto *block = desc_.mutable_blocks(i);
    blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
    // record all block desc's ptr from origin program
    old_block_desc.emplace_back(o.blocks_[i].get());
  }
  for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) {
    // The copied block. Variable lookups below resolve here rather than in
    // `o`, so var attributes end up pointing at VarDescs this program owns.
    BlockDesc *new_block = blocks_[block_id].get();
    auto all_ops = new_block->AllOps();
    for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) {
      auto *op = all_ops[op_id];

      for (const std::string &attr_name : op->AttrNames()) {
        if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
          framework::BlockDesc *block_desc =
              PADDLE_GET_CONST(framework::BlockDesc *, op->GetAttr(attr_name));
          if (std::find(old_block_desc.begin(),
                        old_block_desc.end(),
                        block_desc) != old_block_desc.end()) {
            // The block is owned by the origin program. Just use id to get
            // the corresponding block.
            int sub_block_id =
                o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name);
            op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
          } else {
            // The block is not owned by the origin program. Should copy
            // the real block desc instead of logical block in the program.
            VLOG(3) << "Set op's block attr with the original block";
            op->SetBlockAttr(attr_name, block_desc);
          }
        } else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) {
          std::vector<int> sub_block_ids =
              o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name);
          std::vector<BlockDesc *> block_descs;
          block_descs.reserve(sub_block_ids.size());
          // Named sub_block_id so it does not shadow the outer block_id.
          for (int sub_block_id : sub_block_ids) {
            block_descs.push_back(MutableBlock(sub_block_id));
          }
          op->SetBlocksAttr(attr_name, block_descs);
        } else if (op->GetAttrType(attr_name, true) == proto::AttrType::VAR) {
          VarDesc *var_desc =
              PADDLE_GET_CONST(VarDesc *, op->GetAttr(attr_name, true));
          op->SetVarAttr(attr_name,
                         new_block->FindVarRecursive(var_desc->Name()));
        } else if (op->GetAttrType(attr_name, true) == proto::AttrType::VARS) {
          std::vector<VarDesc *> vars_desc = PADDLE_GET_CONST(
              std::vector<VarDesc *>, op->GetAttr(attr_name, true));
          std::vector<VarDesc *> new_vars_desc;
          new_vars_desc.reserve(vars_desc.size());
          std::transform(vars_desc.begin(),
                         vars_desc.end(),
                         std::back_inserter(new_vars_desc),
                         [&](VarDesc *var_desc) {
                           return new_block->FindVarRecursive(
                               var_desc->Name());
                         });
          op->SetVarsAttr(attr_name, new_vars_desc);
        }
      }
    }
  }
}
134

Y
Yu Yang 已提交
135
// Builds a program from a proto message.
// The message is copied into desc_, then BlockDesc wrappers are created
// and op attributes are resolved by InitFromProto().
ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) : desc_(desc) {
  InitFromProto();
}

X
Xin Pan 已提交
140 141 142 143 144 145
// Replaces the whole program with the contents of `desc`.
// The steps run in this order:
//  1. discard the existing BlockDesc wrappers, which point into the old
//     desc_;
//  2. overwrite desc_ with `desc`;
//  3. rebuild the wrappers via InitFromProto().
void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) {
  blocks_.clear();
  desc_.CopyFrom(desc);
  InitFromProto();
}

Y
Yu Yang 已提交
146
// Builds a program from a serialized proto::ProgramDesc.
// Throws InvalidArgument (via PADDLE_ENFORCE_EQ) when `binary_str` cannot
// be parsed. After parsing, the BlockDesc wrappers are built and
// scalar::ConvertProgram is applied to the result.
// NOTE(review): ConvertProgram presumably normalizes legacy attribute
// representations; see program_converter.h to confirm.
ProgramDesc::ProgramDesc(const std::string &binary_str) {
  PADDLE_ENFORCE_EQ(
      desc_.ParseFromString(binary_str),
      true,
      platform::errors::InvalidArgument(
          "Failed to parse program_desc from binary string."));
  InitFromProto();
  scalar::ConvertProgram(this);
}

void ProgramDesc::InitFromProto() {
156
  for (auto &block_desc : *desc_.mutable_blocks()) {
Y
Yu Yang 已提交
157
    blocks_.emplace_back(new BlockDesc(this, &block_desc));
158
  }
F
fengjiayi 已提交
159 160 161
  for (auto &block : blocks_) {
    for (auto *op : block->AllOps()) {
      for (const auto &attr : op->Proto()->attrs()) {
162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
        if (attr.type() == proto::AttrType::VAR) {
          std::string var_name = attr.var_name();
          VLOG(3) << "InitFromProto: SetVarAttr " << attr.name() << " from "
                  << var_name;
          op->SetVarAttr(attr.name(), op->FindVarRecursive(var_name));
        } else if (attr.type() == proto::AttrType::VARS) {
          auto vars_name = attr.vars_name();
          std::vector<VarDesc *> vars_desc;
          for (auto &var_name : vars_name) {
            VLOG(3) << "InitFromProto: SetVarsAttr " << attr.name() << " from "
                    << var_name;
            vars_desc.emplace_back(op->FindVarRecursive(var_name));
          }
          op->SetVarsAttr(attr.name(), vars_desc);
        } else if (attr.type() == proto::AttrType::BLOCK) {
F
fengjiayi 已提交
177 178
          size_t blk_idx = attr.block_idx();
          op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
X
Xin Pan 已提交
179 180 181 182 183 184 185
        } else if (attr.type() == proto::AttrType::BLOCKS) {
          auto blks_idx = attr.blocks_idx();
          std::vector<BlockDesc *> block_descs;
          for (int blk_idx : blks_idx) {
            block_descs.push_back(this->MutableBlock(blk_idx));
          }
          op->SetBlocksAttr(attr.name(), block_descs);
F
fengjiayi 已提交
186 187 188 189
        }
      }
    }
  }
190 191
}

192
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
193
  auto &global_block = Block(0);
194 195
  // The order of feed_target_names must follow the index specified in `col`.
  // since feed operator's order doesn't necessary follow 'col'.
196
  std::vector<std::string> feed_target_names;
197
  for (auto *op : global_block.AllOps()) {
198
    if (op->Type() == kFeedOpType) {
R
Ruibiao Chen 已提交
199
      size_t col = PADDLE_GET_CONST(int, op->GetAttr("col"));
200 201 202 203
      if (col >= feed_target_names.size()) {
        feed_target_names.resize(col + 1);
      }
      feed_target_names[col] = op->Output("Out")[0];
204 205
    }
  }
206
  return feed_target_names;
207 208
}

209
const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
210
  auto &global_block = Block(0);
211 212
  // The order of fetch_target_names must follow the index specified in `col`.
  // since fetch operator's order doesn't necessary follow 'col'.
213
  std::vector<std::string> fetch_target_names;
214
  for (auto *op : global_block.AllOps()) {
215
    if (op->Type() == kFetchOpType) {
R
Ruibiao Chen 已提交
216
      size_t col = PADDLE_GET_CONST(int, op->GetAttr("col"));
217 218 219 220
      if (col >= fetch_target_names.size()) {
        fetch_target_names.resize(col + 1);
      }
      fetch_target_names[col] = op->Input("X")[0];
221 222
    }
  }
223
  return fetch_target_names;
224 225
}

226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
// Points every feed op in block 0 at one shared holder variable.
// For each feed op, in op order:
//  - its previous "X" input variable is removed from block 0;
//  - its "X" input is set to `feed_holder_name`;
//  - its `col` attribute is renumbered 0, 1, 2, ...
// Finally the holder variable is created or fetched in block 0 and marked
// as a persistable FEED_MINIBATCH.
void ProgramDesc::SetFeedHolderName(const std::string &feed_holder_name) {
  auto *global_block = MutableBlock(0);
  int next_col = 0;
  for (auto *op : global_block->AllOps()) {
    if (op->Type() != kFeedOpType) {
      continue;
    }
    // Unify the input's name of all feed_ops to feed_holder_name
    global_block->RemoveVar(op->Input("X")[0]);
    op->SetInput("X", {feed_holder_name});
    op->SetAttr("col", {next_col});
    op->CheckAttrs();
    ++next_col;
  }

  auto *holder_var = global_block->Var(feed_holder_name);
  holder_var->SetType(proto::VarType::FEED_MINIBATCH);
  holder_var->SetPersistable(true);
}

// Points every fetch op in block 0 at one shared holder variable.
// For each fetch op, in op order:
//  - its previous "Out" output variable is removed from block 0;
//  - its "Out" output is set to `fetch_holder_name`;
//  - its `col` attribute is renumbered 0, 1, 2, ...
// Finally the holder variable is created or fetched in block 0 and marked
// as a persistable FETCH_LIST.
void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) {
  auto *global_block = MutableBlock(0);
  int next_col = 0;
  for (auto *op : global_block->AllOps()) {
    if (op->Type() != kFetchOpType) {
      continue;
    }
    // Unify the output's name of all fetch_ops to fetch_holder_name
    global_block->RemoveVar(op->Output("Out")[0]);
    op->SetOutput("Out", {fetch_holder_name});
    op->SetAttr("col", {next_col});
    op->CheckAttrs();
    ++next_col;
  }

  auto *holder_var = global_block->Var(fetch_holder_name);
  holder_var->SetType(proto::VarType::FETCH_LIST);
  holder_var->SetPersistable(true);
}

L
Leo Chen 已提交
264 265 266 267 268
std::string ProgramDesc::CachedHashString() {
  std::string serialize_str;
  if (cached_hash_str_.size() == 0 || NeedUpdate()) {
    Flush();
    desc_.SerializePartialToString(&serialize_str);
269 270 271
    // non-cryptographic is enough
    cached_hash_str_ =
        std::to_string(XXH64(serialize_str.c_str(), serialize_str.size(), 1));
L
Leo Chen 已提交
272 273 274 275
  }
  return cached_hash_str_;
}

L
Leo Chen 已提交
276 277 278 279 280 281 282 283 284 285 286
bool ProgramDesc::NeedUpdate() const {
  bool need = false;
  for (auto &block : blocks_) {
    if (block->NeedUpdate()) {
      need = true;
      break;
    }
  }
  return need;
}

F
fengjiayi 已提交
287 288
}  // namespace framework
}  // namespace paddle