/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/type_defs.h"

namespace paddle {
namespace framework {

31 32
using paddle::distributed::auto_parallel::TensorDistAttr;

F
fengjiayi 已提交
33 34 35 36 37 38
// Helpers to convert between std::vector and protobuf repeated fields.

// Returns a std::vector holding a copy of every element of `repeated_field`,
// in the same order.
//
// Uses the iterator-range constructor: it needs nothing beyond <vector>
// (std::back_inserter would require <iterator>), and for forward iterators
// it sizes the buffer once instead of growing it element by element.
template <typename T>
inline std::vector<T> RepeatedToVector(
    const google::protobuf::RepeatedField<T> &repeated_field) {
  return std::vector<T>(repeated_field.begin(), repeated_field.end());
}

// Replaces the contents of *repeated_field with the elements of `vec`,
// keeping their order. The destination is cleared first and room for
// vec.size() elements is reserved before appending. RepeatedField must
// provide Clear(), Reserve(n) and an Add() whose result can be dereferenced
// and assigned to (as protobuf repeated fields do).
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
                             RepeatedField *repeated_field) {
  repeated_field->Clear();
  const auto count = vec.size();
  repeated_field->Reserve(count);
  for (auto it = vec.cbegin(); it != vec.cend(); ++it) {
    auto slot = repeated_field->Add();
    *slot = *it;
  }
}

// Overload (not a specialization) for std::vector<bool>. That container's
// iterators yield proxy objects rather than bool&, so each element is read
// into a plain bool before it is stored into a freshly appended slot.
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
                             RepeatedField *repeated_field) {
  repeated_field->Clear();
  repeated_field->Reserve(vec.size());
  for (std::vector<bool>::size_type i = 0; i < vec.size(); ++i) {
    const bool value = vec[i];
    *repeated_field->Add() = value;
  }
}

Y
Yu Yang 已提交
65
class VarDesc {
F
fengjiayi 已提交
66
 public:
Y
Yu Yang 已提交
67
  explicit VarDesc(const std::string &name) {
68
    desc_.set_name(name);
X
Xin Pan 已提交
69
    // TODO(paddle-dev): Why default to lodtensor.
70
    desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
L
Leo Chen 已提交
71
    need_updated_ = true;
72
  }
F
fengjiayi 已提交
73

74
  explicit VarDesc(const proto::VarDesc &desc);
75

76
  // Explicitly implement the copy constructor for auto parallel
77 78
  VarDesc(const VarDesc &other);

79 80 81 82
  VarDesc &operator=(const VarDesc &other) {
    desc_ = other.desc_;
    attrs_ = other.attrs_;
    original_id_ = other.original_id_;
83 84 85
    if (other.dist_attr_) {
      dist_attr_.reset(new TensorDistAttr(*other.dist_attr_));
    }
L
Leo Chen 已提交
86
    need_updated_ = true;
87 88
    return *this;
  }
89

L
Leo Chen 已提交
90
  proto::VarDesc *Proto() {
91
    Flush();  // Only flush attrs for auto parallel
92
    return &desc_;
L
Leo Chen 已提交
93
  }
F
fengjiayi 已提交
94

95 96
  const proto::VarDesc *Proto() const { return &desc_; }

F
fengjiayi 已提交
97 98
  std::string Name() const { return desc_.name(); }

L
Leo Chen 已提交
99 100 101 102
  void SetName(std::string name) {
    desc_.set_name(name);
    need_updated_ = true;
  }
103

F
fengjiayi 已提交
104 105 106 107
  void SetTensorDescNum(size_t num);

  size_t GetTensorDescNum() const;

F
fengjiayi 已提交
108 109
  void SetShape(const std::vector<int64_t> &dims);

F
fengjiayi 已提交
110
  void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);
F
fengjiayi 已提交
111 112 113 114 115

  std::vector<int64_t> GetShape() const;

  std::vector<std::vector<int64_t>> GetShapes() const;

116
  void SetDataType(proto::VarType::Type data_type);
F
fengjiayi 已提交
117

118 119
  void SetDataTypes(
      const std::vector<proto::VarType::Type> &multiple_data_type);
F
fengjiayi 已提交
120

121
  proto::VarType::Type GetDataType() const;
F
fengjiayi 已提交
122

123 124
  size_t ElementSize() const;

125
  std::vector<proto::VarType::Type> GetDataTypes() const;
F
fengjiayi 已提交
126

Y
Stash  
Yu Yang 已提交
127 128
  void SetLoDLevel(int32_t lod_level);

F
fengjiayi 已提交
129 130
  void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);

131
  int32_t GetLoDLevel() const;
Y
Stash  
Yu Yang 已提交
132

F
fengjiayi 已提交
133 134
  std::vector<int32_t> GetLoDLevels() const;

135
  proto::VarType::Type GetType() const;
Y
Yu Yang 已提交
136

137
  void SetType(proto::VarType::Type type);
Y
Yu Yang 已提交
138

139 140
  bool Persistable() const { return desc_.persistable(); }

L
Leo Chen 已提交
141 142 143 144
  void SetPersistable(bool persistable) {
    desc_.set_persistable(persistable);
    need_updated_ = true;
  }
145

146 147 148 149
  bool IsParameter() const { return desc_.is_parameter(); }

  void SetIsParameter(bool is_parameter) {
    desc_.set_is_parameter(is_parameter);
L
Leo Chen 已提交
150
    need_updated_ = true;
151 152
  }

L
Leo Chen 已提交
153 154 155 156
  void ClearIsParameter() {
    desc_.clear_is_parameter();
    need_updated_ = true;
  }
157 158 159 160 161 162 163

  bool HasIsParameter() const { return desc_.has_is_parameter(); }

  bool StopGradient() const { return desc_.stop_gradient(); }

  void SetStopGradient(bool stop_gradient) {
    desc_.set_stop_gradient(stop_gradient);
L
Leo Chen 已提交
164
    need_updated_ = true;
165 166
  }

L
Leo Chen 已提交
167 168 169 170
  void ClearStopGradient() {
    desc_.clear_stop_gradient();
    need_updated_ = true;
  }
171 172 173

  bool HasStopGradient() const { return desc_.has_stop_gradient(); }

H
Huihuang Zheng 已提交
174 175 176 177
  bool NeedCheckFeed() const { return desc_.need_check_feed(); }

  void SetNeedCheckFeed(bool need_check_feed) {
    desc_.set_need_check_feed(need_check_feed);
L
Leo Chen 已提交
178
    need_updated_ = true;
H
Huihuang Zheng 已提交
179 180
  }

181 182 183 184 185 186 187 188 189 190 191
  bool HasAttr(const std::string &name) const {
    return attrs_.find(name) != attrs_.end();
  }

  std::vector<std::string> AttrNames() const;

  void SetAttr(const std::string &name, const Attribute &v);
  void RemoveAttr(const std::string &name);

  Attribute GetAttr(const std::string &name) const;

192 193 194
  bool NeedUpdate() const { return need_updated_; }
  void SetNeedUpdate(bool need) { need_updated_ = need; }

195 196
  void Flush();

197
  // The following methods are only used for auto parallel.
198
  uint64_t Id() const { return id_; }
199
  uint64_t OriginalId() const { return original_id_; }
L
Leo Chen 已提交
200 201 202 203
  void SetOriginalId(uint64_t original_id) {
    original_id_ = original_id;
    need_updated_ = true;
  }
204 205
  TensorDistAttr *MutableDistAttr();
  void SetDistAttr(const TensorDistAttr &dist_attr);
206

F
fengjiayi 已提交
207
 private:
208 209 210 211
  const proto::VarType::TensorDesc &tensor_desc() const;
  std::vector<proto::VarType::TensorDesc> tensor_descs() const;
  proto::VarType::TensorDesc *mutable_tensor_desc();
  std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
Y
Yu Yang 已提交
212

S
Shuangchi He 已提交
213
  // Is it really needed? Or just mantain a ptr from the block?
214
  proto::VarDesc desc_;
215
  AttributeMap attrs_;
216

L
Leo Chen 已提交
217 218
  bool need_updated_{false};

219 220 221 222 223
  // Note: the following members are only used for auto parallel for now
  static uint64_t GenerateId() {
    static std::atomic<std::uint64_t> uid{0};
    return ++uid;
  }
224
  uint64_t id_ = GenerateId();
225
  uint64_t original_id_ = id_;
226
  std::unique_ptr<TensorDistAttr> dist_attr_;
F
fengjiayi 已提交
227
};
228 229

bool operator==(const VarDesc &left, const VarDesc &right);
F
fengjiayi 已提交
230 231
}  // namespace framework
}  // namespace paddle