/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <string>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/type_defs.h"

namespace paddle {
namespace framework {

using paddle::distributed::auto_parallel::TensorDistAttr;

// Convert between std::vector and protobuf repeated fields.
template <typename T>
inline std::vector<T> RepeatedToVector(
    const google::protobuf::RepeatedField<T> &repeated_field) {
  std::vector<T> ret;
  ret.reserve(repeated_field.size());
  std::copy(
      repeated_field.begin(), repeated_field.end(), std::back_inserter(ret));
  return ret;
}

template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
                             RepeatedField *repeated_field) {
  repeated_field->Clear();
  repeated_field->Reserve(vec.size());
  for (const auto &elem : vec) {
    *repeated_field->Add() = elem;
  }
}

// Specialize for std::vector<bool>, which is bit-packed and yields proxy
// objects instead of references, so elements are copied by value.
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
                             RepeatedField *repeated_field) {
  repeated_field->Clear();
  repeated_field->Reserve(vec.size());
  for (auto elem : vec) {
    *repeated_field->Add() = elem;
  }
}
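
// A minimal usage sketch for the helpers above (illustrative only; assumes a
// proto::VarType::TensorDesc *tensor_desc obtained from framework.pb.h):
//   std::vector<int64_t> dims = {2, 3};
//   VectorToRepeated(dims, tensor_desc->mutable_dims());
//   std::vector<int64_t> round_trip = RepeatedToVector(tensor_desc->dims());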

class VarDesc {
 public:
  explicit VarDesc(const std::string &name) {
    desc_.set_name(name);
    // TODO(paddle-dev): Why default to lodtensor.
    desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
    need_updated_ = true;
  }

  explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {
    // need_updated_ = true;
  }

  // Explicitly implement the copy constructor for auto parallel
  VarDesc(const VarDesc &other);

  VarDesc &operator=(const VarDesc &other) {
    desc_ = other.desc_;
    attrs_ = other.attrs_;
    original_id_ = other.original_id_;
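    // Deep-copy the TensorDistAttr so the copied desc shares no state.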
    if (other.dist_attr_) {
      dist_attr_.reset(new TensorDistAttr(*other.dist_attr_));
    }
    need_updated_ = true;
    return *this;
  }

  proto::VarDesc *Proto() {
    need_updated_ = true;
    return &desc_;
  }

  const proto::VarDesc *Proto() const { return &desc_; }

  std::string Name() const { return desc_.name(); }

  void SetName(std::string name) {
    desc_.set_name(name);
    need_updated_ = true;
  }

  void SetTensorDescNum(size_t num);

  size_t GetTensorDescNum() const;

  void SetShape(const std::vector<int64_t> &dims);

  void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);

  std::vector<int64_t> GetShape() const;

  std::vector<std::vector<int64_t>> GetShapes() const;

  void SetDataType(proto::VarType::Type data_type);

  void SetDataTypes(
      const std::vector<proto::VarType::Type> &multiple_data_type);

  proto::VarType::Type GetDataType() const;

  size_t ElementSize() const;

  std::vector<proto::VarType::Type> GetDataTypes() const;

  void SetLoDLevel(int32_t lod_level);

  void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);

  int32_t GetLoDLevel() const;

  std::vector<int32_t> GetLoDLevels() const;

  proto::VarType::Type GetType() const;

  void SetType(proto::VarType::Type type);

  bool Persistable() const { return desc_.persistable(); }

  void SetPersistable(bool persistable) {
    desc_.set_persistable(persistable);
    need_updated_ = true;
  }

  bool IsParameter() const { return desc_.is_parameter(); }

  void SetIsParameter(bool is_parameter) {
    desc_.set_is_parameter(is_parameter);
    need_updated_ = true;
  }

  void ClearIsParameter() {
    desc_.clear_is_parameter();
    need_updated_ = true;
  }

  bool HasIsParameter() const { return desc_.has_is_parameter(); }

  bool StopGradient() const { return desc_.stop_gradient(); }

  void SetStopGradient(bool stop_gradient) {
    desc_.set_stop_gradient(stop_gradient);
    need_updated_ = true;
  }

  void ClearStopGradient() {
    desc_.clear_stop_gradient();
    need_updated_ = true;
  }

  bool HasStopGradient() const { return desc_.has_stop_gradient(); }

  bool NeedCheckFeed() const { return desc_.need_check_feed(); }

  void SetNeedCheckFeed(bool need_check_feed) {
    desc_.set_need_check_feed(need_check_feed);
    need_updated_ = true;
  }

  bool HasAttr(const std::string &name) const {
    return attrs_.find(name) != attrs_.end();
  }

  std::vector<std::string> AttrNames() const;

  void SetAttr(const std::string &name, const Attribute &v);
  void RemoveAttr(const std::string &name);

  Attribute GetAttr(const std::string &name) const;

  bool NeedUpdate() const { return need_updated_; }
  void SetNeedUpdate(bool need) { need_updated_ = need; }

  // The following methods are only used for auto parallel.
  uint64_t Id() const { return id_; }
  uint64_t OriginalId() const { return original_id_; }
  void SetOriginalId(uint64_t original_id) {
    original_id_ = original_id;
    need_updated_ = true;
  }
  TensorDistAttr *MutableDistAttr();
  void SetDistAttr(const TensorDistAttr &dist_attr);

 private:
  const proto::VarType::TensorDesc &tensor_desc() const;
  std::vector<proto::VarType::TensorDesc> tensor_descs() const;
  proto::VarType::TensorDesc *mutable_tensor_desc();
  std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();

  // Is it really needed? Or could we just maintain a ptr from the block?
  proto::VarDesc desc_;
  AttributeMap attrs_;

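  // Set by every mutator; lets callers (e.g. the owning BlockDesc) detect
  // that this desc has changed since it was last read back.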
  bool need_updated_{false};

  // Note: the following members are only used for auto parallel for now
  static uint64_t GenerateId() {
    static std::atomic<std::uint64_t> uid{0};
    return ++uid;
  }
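  // Each VarDesc gets a process-wide unique id; original_id_ starts equal to
  // id_ and can be reset so a derived desc can be traced back to its source.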
  uint64_t id_ = GenerateId();
  uint64_t original_id_ = id_;
  std::unique_ptr<TensorDistAttr> dist_attr_;
};
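
// Illustrative construction of a VarDesc (a sketch, not a prescribed flow;
// the name, shape, and dtype below are arbitrary):
//   VarDesc var("fc_0.w_0");
//   var.SetType(proto::VarType::LOD_TENSOR);
//   var.SetShape({768, 768});
//   var.SetDataType(proto::VarType::FP32);
//   var.SetPersistable(true);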

bool operator==(const VarDesc &left, const VarDesc &right);
}  // namespace framework
}  // namespace paddle