block_desc.cc 5.1 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/block_desc.h"
16
#include "paddle/framework/operator.h"
F
fengjiayi 已提交
17
#include "paddle/framework/program_desc.h"
F
fengjiayi 已提交
18 19 20 21

namespace paddle {
namespace framework {

Y
Yu Yang 已提交
22
// Returns the variable called `name` that lives directly in this block,
// creating it when it does not exist yet. Only the creation path marks the
// block dirty (need_update_), since lookups leave the proto unchanged.
VarDesc *BlockDesc::Var(const std::string &name) {
  auto found = vars_.find(name);
  if (found == vars_.end()) {
    need_update_ = true;
    auto *created = new VarDesc(name);
    vars_[name].reset(created);  // vars_ takes ownership of the new VarDesc.
    return created;
  }
  return found->second.get();
}

Y
Yu Yang 已提交
33
// Looks up `name` in this block only (parents are not searched).
// Returns a non-owning pointer, or nullptr when the variable is absent.
VarDesc *BlockDesc::FindVar(const std::string &name) const {
  const auto found = vars_.find(name);
  return found == vars_.end() ? nullptr : found->second.get();
}

Y
Yu Yang 已提交
41
bool BlockDesc::HasVar(const std::string &name) const {
Q
qiaolongfei 已提交
42
  return vars_.find(name) != vars_.end();
Q
qiaolongfei 已提交
43 44
}

Y
Yu Yang 已提交
45
// Looks up `name` in this block and, failing that, in each ancestor block.
// Returns nullptr for kEmptyVarName or when no block in the chain declares it.
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
  if (name == kEmptyVarName) return nullptr;

  auto found = vars_.find(name);
  if (found != vars_.end()) return found->second.get();

  // Not declared here: stop at the root, otherwise defer to the parent.
  if (Parent() == kNoneBlockIndex) return nullptr;
  return ParentBlock()->FindVarRecursive(name);
}

Y
Yang Yu 已提交
56
// Returns the variable `name_bytes` visible from this block (searching
// ancestors too); if no block in the chain has it, creates it in THIS block.
// Note: FindVarRecursive never matches kEmptyVarName, so that name always
// falls through to Var() here.
VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
  if (VarDesc *existing = FindVarRecursive(name_bytes)) {
    return *existing;
  }
  return *Var(name_bytes);
}

Y
Yu Yang 已提交
64
bool BlockDesc::HasVarRecursive(const std::string &name) const {
65 66 67
  return FindVarRecursive(name) != nullptr;
}

Y
Yu Yang 已提交
68 69
std::vector<VarDesc *> BlockDesc::AllVars() const {
  std::vector<VarDesc *> res;
F
fengjiayi 已提交
70 71 72 73 74 75
  for (const auto &p : vars_) {
    res.push_back(p.second.get());
  }
  return res;
}

Y
Yu Yang 已提交
76
// Creates an empty OpDesc bound to this block, appends it after the existing
// ops, marks the block dirty and returns a non-owning pointer to the new op.
OpDesc *BlockDesc::AppendOp() {
  need_update_ = true;
  auto *op = new OpDesc(this);
  ops_.emplace_back(op);  // ops_ takes ownership.
  return op;
}

Y
Yu Yang 已提交
82
void BlockDesc::AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc) {
83 84 85 86
  need_update_ = true;
  ops_.emplace_back(std::move(op_desc));
}

Y
Yu Yang 已提交
87
// Creates an empty OpDesc bound to this block, inserts it before all existing
// ops, marks the block dirty and returns a non-owning pointer to the new op.
OpDesc *BlockDesc::PrependOp() {
  need_update_ = true;
  auto *op = new OpDesc(this);
  ops_.emplace_front(op);  // ops_ takes ownership.
  return op;
}

T
typhoonzero 已提交
93
// Removes the ops at indices in the half-open range [s, e) from this block
// and marks the block dirty.
//
// The call is a no-op when the range is empty or inverted (s >= e) or when it
// extends past the current op list (e > ops_.size()). A range ending exactly
// at ops_.size() is valid and removes through the last op.
void BlockDesc::RemoveOp(size_t s, size_t e) {
  // Validate with index comparisons rather than iterator arithmetic: forming
  // ops_.begin() + k for k > size() is undefined behavior, and an inverted
  // range would make both the loop below and erase() invalid.
  if (s >= e || e > ops_.size()) {
    return;
  }
  need_update_ = true;
  for (auto it = ops_.begin() + s; it != ops_.begin() + e; ++it) {
    for (const auto &n : (*it)->InputArgumentNames()) {
      // TODO(typhoonzero): delete vars if no other op use it.
      VLOG(3) << "deleting var " << n;
    }
  }
  ops_.erase(ops_.begin() + s, ops_.begin() + e);
}

Y
Yu Yang 已提交
108 109
std::vector<OpDesc *> BlockDesc::AllOps() const {
  std::vector<OpDesc *> res;
F
fengjiayi 已提交
110 111 112 113 114 115
  for (const auto &op : ops_) {
    res.push_back(op.get());
  }
  return res;
}

Y
Yu Yang 已提交
116
void BlockDesc::Flush() {
117 118 119 120
  for (auto &op_desc : ops_) {
    op_desc->Flush();
  }

F
fengjiayi 已提交
121 122
  if (need_update_) {
    auto &op_field = *this->desc_->mutable_ops();
123
    this->ClearPBOps();
F
fengjiayi 已提交
124 125 126 127
    op_field.Reserve(static_cast<int>(ops_.size()));
    for (auto &op_desc : ops_) {
      op_field.AddAllocated(op_desc->Proto());
    }
F
Fix bug  
fengjiayi 已提交
128
    auto &var_field = *this->desc_->mutable_vars();
129
    this->ClearPBVars();
F
Fix bug  
fengjiayi 已提交
130 131 132 133
    var_field.Reserve(static_cast<int>(vars_.size()));
    for (auto &var_desc : vars_) {
      var_field.AddAllocated(var_desc.second->Proto());
    }
F
fengjiayi 已提交
134 135 136 137
    need_update_ = false;
  }
}

Y
Yu Yang 已提交
138
// Returns the enclosing block within the owning program, or nullptr when
// desc_ records no parent (parent_idx() == kNoneBlockIndex).
BlockDesc *BlockDesc::ParentBlock() const {
  const auto parent = this->desc_->parent_idx();
  if (parent != kNoneBlockIndex) {
    return prog_->MutableBlock(static_cast<size_t>(parent));
  }
  return nullptr;
}

Y
Yu Yang 已提交
145
// Returns the backing protobuf after flushing pending op/var changes into it,
// so callers always observe an up-to-date message.
proto::BlockDesc *BlockDesc::Proto() {
  this->Flush();
  return this->desc_;
}
149

Y
Yu Yang 已提交
150
// Wraps an existing protobuf block `desc` that belongs to program `prog`.
// A VarDesc is created for every proto var and an OpDesc for every proto op
// (in proto order). The block starts clean (need_update_ == false) because
// its in-memory state was just built from `desc`.
BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
    : prog_(prog), desc_(desc), need_update_(false) {
  for (const proto::VarDesc &var_desc : desc_->vars()) {
    vars_[var_desc.name()].reset(new VarDesc(var_desc));
  }
  for (const proto::OpDesc &op_desc : desc_->ops()) {
    ops_.emplace_back(new OpDesc(op_desc, prog, this));
  }
  // Diagnostic trace: use verbose logging (as RemoveOp does) instead of
  // unconditionally writing to stdout and flushing it on every construction.
  VLOG(3) << "Constructed block idx " << desc->idx() << " from protobuf str";
}

Y
Yu Yang 已提交
162 163
// Deep-copies `other` into a new block that writes to protobuf `desc` within
// program `prog`. Each op is re-created with this block as its owner (ops are
// copied before vars, as before), and each VarDesc is cloned. The block starts
// dirty so the next Flush() populates `desc` with the copied ops and vars.
BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
                     ProgramDesc *prog)
    : prog_(prog), desc_(desc), need_update_(true) {
  for (const auto &src_op : other.ops_) {
    ops_.emplace_back(new OpDesc(*src_op, this));
  }
  for (const auto &entry : other.vars_) {
    vars_[entry.first].reset(new VarDesc(*entry.second));
  }
}
175

Y
Yu Yang 已提交
176
void BlockDesc::ClearPBOps() {
177 178 179 180 181 182 183
  auto ops = this->desc_->mutable_ops();
  while (!ops->empty()) {
    // we do not own the OpDesc, so release the ownership.
    ops->ReleaseLast();
  }
}

Y
Yu Yang 已提交
184
void BlockDesc::ClearPBVars() {
185 186 187 188 189 190 191
  auto vars = this->desc_->mutable_vars();
  while (!vars->empty()) {
    // we do not own the VarDesc, so release the ownership.
    vars->ReleaseLast();
  }
}

F
fengjiayi 已提交
192 193
}  // namespace framework
}  // namespace paddle