/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <string>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/type_defs.h"

namespace paddle {
namespace framework {

// Convert between std::vector and protobuf repeated fields.
template <typename T>
inline std::vector<T> RepeatedToVector(
    const google::protobuf::RepeatedField<T> &repeated_field) {
  std::vector<T> ret;
  ret.reserve(repeated_field.size());
  std::copy(repeated_field.begin(), repeated_field.end(),
            std::back_inserter(ret));
  return ret;
}

template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
                             RepeatedField *repeated_field) {
  repeated_field->Clear();
  repeated_field->Reserve(vec.size());
  for (const auto &elem : vec) {
    *repeated_field->Add() = elem;
  }
}

// Specialize VectorToRepeated for vector<bool>: its iterators yield proxy
// objects rather than bool&, so elements are taken by value.
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
                             RepeatedField *repeated_field) {
  repeated_field->Clear();
  repeated_field->Reserve(vec.size());
  for (auto elem : vec) {
    *repeated_field->Add() = elem;
  }
}
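
// A round-trip sketch for the two helpers above (`msg` and its repeated
// int64 field `dims` are hypothetical names assumed for illustration, not
// part of this header):
//
//   std::vector<int64_t> dims = RepeatedToVector(msg.dims());
//   dims.push_back(1);
//   VectorToRepeated(dims, msg.mutable_dims());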

class VarDesc {
 public:
  explicit VarDesc(const std::string &name) {
    desc_.set_name(name);
    // TODO(paddle-dev): Why default to lodtensor.
    desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
  }

  explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}

  // Explicitly implement the copy constructor for auto parallel. Note that
  // id_ is deliberately absent from the initializer list: the copy keeps
  // other's original_id_ but still receives a fresh id_ from the default
  // member initializer below.
  VarDesc(const VarDesc &other)
      : desc_(other.desc_),
        attrs_(other.attrs_),
        original_id_(other.original_id_) {}

  proto::VarDesc *Proto() { return &desc_; }

  const proto::VarDesc *Proto() const { return &desc_; }

  std::string Name() const { return desc_.name(); }

  void SetName(std::string name) { desc_.set_name(name); }

  void SetTensorDescNum(size_t num);

  size_t GetTensorDescNum() const;

  void SetShape(const std::vector<int64_t> &dims);

  void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);

  std::vector<int64_t> GetShape() const;

  std::vector<std::vector<int64_t>> GetShapes() const;

  void SetDataType(proto::VarType::Type data_type);

  void SetDataTypes(
      const std::vector<proto::VarType::Type> &multiple_data_type);

  proto::VarType::Type GetDataType() const;

  size_t ElementSize() const;

  std::vector<proto::VarType::Type> GetDataTypes() const;

  void SetLoDLevel(int32_t lod_level);

  void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);

  int32_t GetLoDLevel() const;

  std::vector<int32_t> GetLoDLevels() const;

  proto::VarType::Type GetType() const;

  void SetType(proto::VarType::Type type);

  bool Persistable() const { return desc_.persistable(); }

  void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }

  bool IsParameter() const { return desc_.is_parameter(); }

  void SetIsParameter(bool is_parameter) {
    desc_.set_is_parameter(is_parameter);
  }

  void ClearIsParameter() { desc_.clear_is_parameter(); }

  bool HasIsParameter() const { return desc_.has_is_parameter(); }

  bool StopGradient() const { return desc_.stop_gradient(); }

  void SetStopGradient(bool stop_gradient) {
    desc_.set_stop_gradient(stop_gradient);
  }

  void ClearStopGradient() { desc_.clear_stop_gradient(); }

  bool HasStopGradient() const { return desc_.has_stop_gradient(); }

  bool NeedCheckFeed() const { return desc_.need_check_feed(); }

  void SetNeedCheckFeed(bool need_check_feed) {
    desc_.set_need_check_feed(need_check_feed);
  }

  bool HasAttr(const std::string &name) const {
    return attrs_.find(name) != attrs_.end();
  }

  std::vector<std::string> AttrNames() const;

  void SetAttr(const std::string &name, const Attribute &v);
  void RemoveAttr(const std::string &name);

  Attribute GetAttr(const std::string &name) const;
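
  // A usage sketch for the attribute interface above (Attribute is the
  // variant type from type_defs.h; "trainable" is a made-up attribute name,
  // not one this class defines):
  //
  //   VarDesc var("x");
  //   var.SetAttr("trainable", Attribute(true));
  //   if (var.HasAttr("trainable")) {
  //     Attribute attr = var.GetAttr("trainable");
  //   }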

  // The Id() and OriginalId() are only used for auto parallel.
  uint64_t Id() const { return id_; }
  uint64_t OriginalId() const { return original_id_; }
  void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }

 private:
  const proto::VarType::TensorDesc &tensor_desc() const;
  std::vector<proto::VarType::TensorDesc> tensor_descs() const;
  proto::VarType::TensorDesc *mutable_tensor_desc();
  std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();

  // This thread-safe implementation seems to be redundant since the neural
  // networks are usually constructed in a single thread.
  static uint64_t GenerateId() {
    static std::atomic<std::uint64_t> uid{0};
    return ++uid;
  }

  proto::VarDesc desc_;
  AttributeMap attrs_;

  // Note: the id_ is unique for all VarDesc (only for auto parallel).
  uint64_t id_ = GenerateId();
  // Note: the original_id_ is used for referring to the original VarDesc
  // that the current VarDesc is built from (only for auto parallel).
  // The default original_id_ is the same as the id_, which means the
  // current VarDesc is not built from the other one.
  uint64_t original_id_ = id_;
};
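
// A minimal construction sketch (an illustrative example, not code from this
// repository; the setters used are declared above and defined in var_desc.cc):
//
//   VarDesc w("fc_0.w_0");
//   w.SetType(proto::VarType::LOD_TENSOR);
//   w.SetShape({768, 768});
//   w.SetDataType(proto::VarType::FP32);
//   w.SetPersistable(true);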

bool operator==(const VarDesc &left, const VarDesc &right);
}  // namespace framework
}  // namespace paddle