/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <string>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/type_defs.h"

namespace paddle {
namespace framework {

// Convert between std::vector and protobuf repeated fields.
template <typename T>
inline std::vector<T> RepeatedToVector(
    const google::protobuf::RepeatedField<T> &repeated_field) {
  std::vector<T> ret;
  ret.reserve(repeated_field.size());
  std::copy(
      repeated_field.begin(), repeated_field.end(), std::back_inserter(ret));
  return ret;
}

template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
                             RepeatedField *repeated_field) {
  repeated_field->Clear();
  repeated_field->Reserve(vec.size());
  for (const auto &elem : vec) {
    *repeated_field->Add() = elem;
  }
}

// Specialize for std::vector<bool>, whose elements are bit proxies rather
// than addressable bools.
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
                             RepeatedField *repeated_field) {
  repeated_field->Clear();
  repeated_field->Reserve(vec.size());
  for (auto elem : vec) {
    *repeated_field->Add() = elem;
  }
}
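
// Example (illustrative sketch, assuming a proto message with a repeated
// int64 field `dims`, as in proto::VarType::TensorDesc):
//
//   proto::VarType::TensorDesc tensor_desc;
//   VectorToRepeated(std::vector<int64_t>{1, 3, 224, 224},
//                    tensor_desc.mutable_dims());
//   std::vector<int64_t> dims = RepeatedToVector(tensor_desc.dims());
//   // dims == {1, 3, 224, 224}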

class VarDesc {
 public:
  explicit VarDesc(const std::string &name) {
    desc_.set_name(name);
    // TODO(paddle-dev): Why default to LOD_TENSOR?
    desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
    need_updated_ = true;
  }

  explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {
    // need_updated_ = true;
  }

  // Explicitly implement the copy constructor for auto parallel
  VarDesc(const VarDesc &other)
      : desc_(other.desc_),
        attrs_(other.attrs_),
        original_id_(other.original_id_) {}
  VarDesc &operator=(const VarDesc &other) {
    desc_ = other.desc_;
    attrs_ = other.attrs_;
    original_id_ = other.original_id_;
    need_updated_ = true;
    return *this;
  }

  proto::VarDesc *Proto() {
    // The caller may mutate desc_ through this pointer, so mark it dirty.
    need_updated_ = true;
    return &desc_;
  }

  const proto::VarDesc *Proto() const { return &desc_; }

  std::string Name() const { return desc_.name(); }

  void SetName(std::string name) {
    desc_.set_name(name);
    need_updated_ = true;
  }

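  // A few variable kinds (e.g. readers) describe several tensors at once;
  // the plural Set*/Get* variants below address every tensor description.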
  void SetTensorDescNum(size_t num);

  size_t GetTensorDescNum() const;

  void SetShape(const std::vector<int64_t> &dims);

  void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);

  std::vector<int64_t> GetShape() const;

  std::vector<std::vector<int64_t>> GetShapes() const;

  void SetDataType(proto::VarType::Type data_type);

  void SetDataTypes(
      const std::vector<proto::VarType::Type> &multiple_data_type);

  proto::VarType::Type GetDataType() const;

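  // Size in bytes of one element of GetDataType() (e.g. 4 for FP32).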
  size_t ElementSize() const;

  std::vector<proto::VarType::Type> GetDataTypes() const;

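  // LoD level: nesting depth of the sequence information carried by a
  // LoDTensor (0 means a plain dense tensor).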
  void SetLoDLevel(int32_t lod_level);

  void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);

  int32_t GetLoDLevel() const;

  std::vector<int32_t> GetLoDLevels() const;

  proto::VarType::Type GetType() const;

  void SetType(proto::VarType::Type type);

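  // Persistable variables (e.g. parameters) outlive a single run instead of
  // being recycled with their scope.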
  bool Persistable() const { return desc_.persistable(); }

  void SetPersistable(bool persistable) {
    desc_.set_persistable(persistable);
    need_updated_ = true;
  }

  bool IsParameter() const { return desc_.is_parameter(); }

  void SetIsParameter(bool is_parameter) {
    desc_.set_is_parameter(is_parameter);
    need_updated_ = true;
  }

  void ClearIsParameter() {
    desc_.clear_is_parameter();
    need_updated_ = true;
  }

  bool HasIsParameter() const { return desc_.has_is_parameter(); }

  bool StopGradient() const { return desc_.stop_gradient(); }

  void SetStopGradient(bool stop_gradient) {
    desc_.set_stop_gradient(stop_gradient);
    need_updated_ = true;
  }

  void ClearStopGradient() {
    desc_.clear_stop_gradient();
    need_updated_ = true;
  }

  bool HasStopGradient() const { return desc_.has_stop_gradient(); }

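  // Whether data fed into this variable must be checked for matching shape
  // and dtype (used for feed variables).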
  bool NeedCheckFeed() const { return desc_.need_check_feed(); }

  void SetNeedCheckFeed(bool need_check_feed) {
    desc_.set_need_check_feed(need_check_feed);
    need_updated_ = true;
  }

  bool HasAttr(const std::string &name) const {
    return attrs_.find(name) != attrs_.end();
  }

  std::vector<std::string> AttrNames() const;

  void SetAttr(const std::string &name, const Attribute &v);
  void RemoveAttr(const std::string &name);

  Attribute GetAttr(const std::string &name) const;

  // The Id() and OriginalId() are only used for auto parallel.
  uint64_t Id() const { return id_; }
  uint64_t OriginalId() const { return original_id_; }
  void SetOriginalId(uint64_t original_id) {
    original_id_ = original_id;
    need_updated_ = true;
  }
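
  // Illustrative sketch (auto parallel): a copied VarDesc gets a fresh Id(),
  // and SetOriginalId() lets it be traced back to its source. The names
  // serial_var and dist_var are hypothetical:
  //
  //   VarDesc dist_var(serial_var);             // copy: new Id()
  //   dist_var.SetOriginalId(serial_var.Id());  // point back at the source
  //   // dist_var.OriginalId() == serial_var.Id()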

  // Dirty flag: setters flip this so the owning block/program knows desc_
  // must be re-synchronized before serialization.
  bool NeedUpdate() const { return need_updated_; }
  void SetNeedUpdate(bool need) { need_updated_ = need; }

 private:
  const proto::VarType::TensorDesc &tensor_desc() const;
  std::vector<proto::VarType::TensorDesc> tensor_descs() const;
  proto::VarType::TensorDesc *mutable_tensor_desc();
  std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();

  // This thread-safe implementation seems to be redundant since neural
  // networks are usually constructed in a single thread.
  static uint64_t GenerateId() {
    static std::atomic<std::uint64_t> uid{0};
    return ++uid;
  }

  // Is it really needed? Or could we just maintain a pointer from the block?
  proto::VarDesc desc_;
  AttributeMap attrs_;

  bool need_updated_{false};

  // Note: the id_ is unique across all VarDesc instances (only for auto
  // parallel).
  uint64_t id_ = GenerateId();
  // Note: the original_id_ refers to the original VarDesc that the current
  // VarDesc was built from (only for auto parallel). The default
  // original_id_ is the same as id_, meaning the current VarDesc was not
  // built from another one.
  uint64_t original_id_ = id_;
};

bool operator==(const VarDesc &left, const VarDesc &right);
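
// Typical usage (illustrative sketch; the variable name is hypothetical):
//
//   VarDesc var("fc_0.w_0");
//   var.SetType(proto::VarType::LOD_TENSOR);
//   var.SetDataType(proto::VarType::FP32);
//   var.SetShape({768, 1024});
//   var.SetPersistable(true);  // e.g. a parameter that must outlive a run
//   CHECK(var.NeedUpdate());   // setters mark the desc as dirty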
}  // namespace framework
}  // namespace paddle