block_desc.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"

namespace paddle {
namespace framework {

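// Returns the VarDesc named `name`, creating it in this block first if it
// does not yet exist.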
VarDesc *BlockDesc::Var(const std::string &name) {
  auto it = vars_.find(name);
  if (it != vars_.end()) {
    return it->second.get();
  }
  need_update_ = true;
  auto *var = new VarDesc(name);
  vars_[name].reset(var);
  return var;
}

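// Looks `name` up in this block only; returns nullptr if it is not found.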
VarDesc *BlockDesc::FindVar(const std::string &name) const {
  auto it = vars_.find(name);
  if (it == vars_.end()) {
    return nullptr;
  }
  return it->second.get();
}

bool BlockDesc::HasVar(const std::string &name) const {
  return vars_.find(name) != vars_.end();
}

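// Searches this block and then its ancestor blocks for `name`; returns
// nullptr if no block in the chain declares it.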
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
  if (name == kEmptyVarName) return nullptr;

  auto it = vars_.find(name);
  if (it == vars_.end()) {
    return Parent() == kNoneBlockIndex ? nullptr
                                       : ParentBlock()->FindVarRecursive(name);
  }
  return it->second.get();
}

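// Like FindVarRecursive, but creates the variable in this block when it is
// not visible from here.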
VarDesc *BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
  VarDesc *res = FindVarRecursive(name_bytes);
  if (res == nullptr) {
    res = Var(name_bytes);
  }
  return res;
}

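// Returns true if `name` is declared in this block or in any ancestor block.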
bool BlockDesc::HasVarRecursive(const std::string &name) const {
  return FindVarRecursive(name) != nullptr;
}

std::vector<VarDesc *> BlockDesc::AllVars() const {
  std::vector<VarDesc *> res;
  for (const auto &p : vars_) {
    res.push_back(p.second.get());
  }
  return res;
}

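// Appends a new, empty OpDesc to the end of this block and returns it.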
OpDesc *BlockDesc::AppendOp() {
  need_update_ = true;
  ops_.emplace_back(new OpDesc());
  return ops_.back().get();
}

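// Takes ownership of an already-constructed OpDesc and appends it.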
void BlockDesc::AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc) {
  need_update_ = true;
  ops_.emplace_back(std::move(op_desc));
}

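// Inserts a new, empty OpDesc at the front of this block and returns it.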
OpDesc *BlockDesc::PrependOp() {
  need_update_ = true;
  ops_.emplace_front(new OpDesc());
  return ops_.front().get();
}

std::vector<OpDesc *> BlockDesc::AllOps() const {
  std::vector<OpDesc *> res;
  for (const auto &op : ops_) {
    res.push_back(op.get());
  }
  return res;
}

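// Writes the in-memory op and var descriptors back into the underlying
// protobuf message. The repeated fields are only rebuilt when the block has
// been modified (need_update_).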
void BlockDesc::Flush() {
  for (auto &op_desc : ops_) {
    op_desc->Flush();
  }

  if (need_update_) {
    auto &op_field = *this->desc_->mutable_ops();
    this->ClearPBOps();
    op_field.Reserve(static_cast<int>(ops_.size()));
    for (auto &op_desc : ops_) {
      op_field.AddAllocated(op_desc->Proto());
    }
    auto &var_field = *this->desc_->mutable_vars();
    this->ClearPBVars();
    var_field.Reserve(static_cast<int>(vars_.size()));
    for (auto &var_desc : vars_) {
      var_field.AddAllocated(var_desc.second->Proto());
    }
    need_update_ = false;
  }
}

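// Returns the parent block in the program, or nullptr for the root block.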
BlockDesc *BlockDesc::ParentBlock() const {
  if (this->desc_->parent_idx() == kNoneBlockIndex) {
    return nullptr;
  }
  return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
}

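// Returns the underlying protobuf descriptor, flushing pending changes first.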
proto::BlockDesc *BlockDesc::Proto() {
  Flush();
  return desc_;
}

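// Wraps an existing proto::BlockDesc, building C++ descriptors for the vars
// and ops it already contains.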
BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
    : prog_(prog), desc_(desc), need_update_(false) {
  for (const proto::VarDesc &var_desc : desc_->vars()) {
    vars_[var_desc.name()].reset(new VarDesc(var_desc));
  }
  for (const proto::OpDesc &op_desc : desc_->ops()) {
    ops_.emplace_back(new OpDesc(op_desc, prog));
  }
}

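// Copy-constructs a block by deep-copying `other`'s ops and vars; the new
// block is marked dirty so the next Flush() rewrites `desc`.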
BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
                     ProgramDesc *prog)
    : prog_(prog), desc_(desc) {
  need_update_ = true;
  for (auto &op : other.ops_) {
    ops_.emplace_back(new OpDesc(*op));
  }

  for (auto &it : other.vars_) {
    auto *var = new VarDesc(*it.second);
    vars_[it.first].reset(var);
  }
}

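// Removes all ops from the protobuf message without deleting them; the
// repeated field does not own these pointers (they were AddAllocated).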
void BlockDesc::ClearPBOps() {
  auto ops = this->desc_->mutable_ops();
  while (!ops->empty()) {
    // we do not own the OpDesc, so release the ownership.
    ops->ReleaseLast();
  }
}

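// Same as ClearPBOps, but for the vars repeated field.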
void BlockDesc::ClearPBVars() {
  auto vars = this->desc_->mutable_vars();
  while (!vars->empty()) {
    // we do not own the VarDesc, so release the ownership.
    vars->ReleaseLast();
  }
}

}  // namespace framework
}  // namespace paddle