/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_desc.h"

namespace paddle {
namespace framework {

32
class VarDesc;
Y
Yu Yang 已提交
33 34
class BlockDesc;
class ProgramDesc;
W
wanghuancoder 已提交
35

36 37
using paddle::distributed::auto_parallel::OperatorDistAttr;

Y
Yu Yang 已提交
38
class OpDesc {
F
fengjiayi 已提交
39
 public:
Y
Yu Yang 已提交
40
  OpDesc() {}
F
fengjiayi 已提交
41

42 43 44 45
  OpDesc(const std::string &type,
         const VariableNameMap &inputs,
         const VariableNameMap &outputs,
         const AttributeMap &attrs);
F
fengjiayi 已提交
46

47 48
  OpDesc(const OpDesc &desc);

F
fengjiayi 已提交
49
  OpDesc(const proto::OpDesc &desc, BlockDesc *block);
50 51 52

  explicit OpDesc(BlockDesc *block) : block_(block) {}

X
Xin Pan 已提交
53
  OpDesc(const OpDesc &other, BlockDesc *block);
54

55 56
  OpDesc &operator=(const OpDesc &other);

57
  void CopyFrom(const OpDesc &op_desc);
F
fengjiayi 已提交
58

59
  proto::OpDesc *Proto();
F
fengjiayi 已提交
60

61
  std::string Type() const { return desc_.type(); }
F
fengjiayi 已提交
62

63
  void SetType(const std::string &type) { desc_.set_type(type); }
F
fengjiayi 已提交
64 65 66

  const std::vector<std::string> &Input(const std::string &name) const;

67 68 69 70
  std::vector<std::string> Input(const std::string &name,
                                 bool with_attr_var) const;

  std::vector<std::string> InputArgumentNames(bool with_attr_var = false) const;
F
Update  
fengjiayi 已提交
71

F
fengjiayi 已提交
72 73 74 75 76
  void SetInput(const std::string &param_name,
                const std::vector<std::string> &args);

  const std::vector<std::string> &Output(const std::string &name) const;

77 78
  bool HasOutput(const std::string &name) const;

F
Update  
fengjiayi 已提交
79 80
  std::vector<std::string> OutputArgumentNames() const;

F
fengjiayi 已提交
81 82
  void SetOutput(const std::string &param_name,
                 const std::vector<std::string> &args);
83
  void RemoveOutput(const std::string &name);
F
fengjiayi 已提交
84

85 86
  void RemoveInput(const std::string &name);

87
  bool HasAttr(const std::string &name, bool with_attr_var = false) const;
F
fengjiayi 已提交
88

89
  bool HasProtoAttr(const std::string &name) const;
F
fengjiayi 已提交
90

91 92
  proto::AttrType GetAttrType(const std::string &name,
                              bool with_attr_var = false) const;
F
fengjiayi 已提交
93

94
  std::vector<std::string> AttrNames(bool with_attr_var = false) const;
F
fengjiayi 已提交
95 96

  void SetAttr(const std::string &name, const Attribute &v);
97
  void RemoveAttr(const std::string &name);
F
fengjiayi 已提交
98

99 100 101 102 103 104 105 106 107
  // NOTE(chenfeiyu): this template is added to avoid using a variant(Attribute)
  // as a parameter of a function which is bound to python, which causes
  // unexpected type conversion due to the overload resolution mechanism
  // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
  template <typename T>
  void SetPlainAttr(const std::string &name, const T &value) {
    SetAttr(name, value);
  }

108 109 110 111
  void SetVarAttr(const std::string &name, VarDesc *var);

  void SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars);

A
Abhinav Arora 已提交
112
  void SetBlockAttr(const std::string &name, BlockDesc *block);
F
fengjiayi 已提交
113

114 115
  void SetBlocksAttr(const std::string &name, std::vector<BlockDesc *> blocks);

116
  Attribute GetAttr(const std::string &name, bool with_attr_var = false) const;
F
fengjiayi 已提交
117

118 119 120 121
  template <typename T>
  T GetAttrIfExists(const std::string &name) const {
    T result{};
    if (HasAttr(name)) {
R
Ruibiao Chen 已提交
122
      result = PADDLE_GET_CONST(T, GetAttr(name));
123 124 125 126
    }
    return result;
  }

M
minqiyang 已提交
127
  const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
M
minqiyang 已提交
128

Y
yuyang18 已提交
129
  Attribute GetNullableAttr(const std::string &name) const;
Y
Fix bug  
yuyang18 已提交
130

G
gongweibao 已提交
131 132 133
  int GetBlockAttrId(const std::string &name) const;

  std::vector<int> GetBlocksAttrIds(const std::string &name) const;
F
fengjiayi 已提交
134

F
fengjiayi 已提交
135 136
  void Rename(const std::string &old_name, const std::string &new_name);

Y
Yang Yang(Tony) 已提交
137 138 139 140
  void RenameOutput(const std::string &old_name, const std::string &new_name);

  void RenameInput(const std::string &old_name, const std::string &new_name);

F
fengjiayi 已提交
141
  // Only be used in C++
Y
Yu Yang 已提交
142
  const AttributeMap &GetAttrMap() const;
F
fengjiayi 已提交
143

144 145
  // Only be used in C++
  void SetAttrMap(const AttributeMap &attr_map);
146

147 148 149 150
  void SetRuntimeAttrMap(const AttributeMap &attr_map);

  const AttributeMap &GetRuntimeAttrMap() const;

151 152 153
  std::vector<std::string> InputNames(bool with_attr_var = false) const {
    return MapKeys(inputs_);
  }
Y
Yu Yang 已提交
154
  std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
155

Y
Yu Yang 已提交
156 157
  const VariableNameMap &Inputs() const { return inputs_; }

158 159
  VariableNameMap Inputs(bool with_attr_var) const;

Y
Yu Yang 已提交
160 161
  const VariableNameMap &Outputs() const { return outputs_; }

162 163 164 165 166 167 168 169 170 171
  VariableNameMap *MutableInputs() {
    this->need_update_ = true;
    return &this->inputs_;
  }

  VariableNameMap *MutableOutputs() {
    this->need_update_ = true;
    return &this->outputs_;
  }

172 173 174 175 176
  AttributeMap *MutableAttrMap() {
    this->need_update_ = true;
    return &this->attrs_;
  }

F
fengjiayi 已提交
177 178
  void CheckAttrs();

H
hong 已提交
179
  void InferShape(const BlockDesc &block);
Y
Yu Yang 已提交
180

Y
Yu Yang 已提交
181
  void InferVarType(BlockDesc *block) const;
Y
Yu Yang 已提交
182

183
  void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); }
184

185 186
  void Flush();

187 188
  BlockDesc *Block() { return this->block_; }

S
sneaxiy 已提交
189 190
  const BlockDesc *Block() const { return this->block_; }

191 192
  void UpdateVarAttr(const std::string &name, const Attribute &attr);

193 194 195
  bool NeedUpdate() const { return need_update_; }

  // The following methods are only used for auto parallel.
196
  uint64_t Id() const { return id_; }
197 198
  uint64_t OriginalId() const { return original_id_; }
  void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
199
  const OperatorDistAttr *DistAttr() const;
200 201
  OperatorDistAttr *MutableDistAttr();
  void SetDistAttr(const OperatorDistAttr &dist_attr);
L
Leo Chen 已提交
202

F
fengjiayi 已提交
203
 private:
204 205 206 207
  friend class ProgramDesc;
  // Find VarDesc from OpDesc located Block into global Block
  VarDesc *FindVarRecursive(const std::string &name);

208 209 210 211 212
  template <typename MapType>
  static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
    std::vector<typename MapType::key_type> ret_val;
    ret_val.reserve(map.size());
    std::transform(
213 214 215
        map.begin(),
        map.end(),
        std::back_inserter(ret_val),
216 217 218 219
        [](const typename MapType::value_type &pair) { return pair.first; });
    return ret_val;
  }

H
HongyuJia 已提交
220
  // it it really needed? or just maintain a ptr from block?
221
  proto::OpDesc desc_;
222
  BlockDesc *block_{nullptr};  // not_own
223
  // input arg name => input variable names
Y
Yu Yang 已提交
224
  VariableNameMap inputs_;
225
  // output arg name => output variable names
Y
Yu Yang 已提交
226
  VariableNameMap outputs_;
227
  // attribute name => all original attrs
Y
Yu Yang 已提交
228
  AttributeMap attrs_;
229 230 231 232 233 234
  // runtime_attrs_ contains the attributes which used for dispatching kernel
  // (use_mkldnn, use_cudnn, ...) or passing additional configuration for
  // special heterogeneous kernel (workspace_size_MB, ...).
  // The attributes in runtime_attrs_ are setted by framework (such as PASS),
  // and not in the python api.
  AttributeMap runtime_attrs_;
F
fengjiayi 已提交
235 236 237 238

  // need_update_ indicate there some local changes not be synchronized. If
  // local changes should be synchronized, need_update_ should be set to true.
  bool need_update_{false};
239

240 241 242 243 244 245
  // Note: the following members are only used for auto_parallel for now.
  static uint64_t GenerateId() {
    static std::atomic<std::uint64_t> uid{0};
    // Must start from one
    return ++uid;
  }
246
  uint64_t id_ = GenerateId();
247
  uint64_t original_id_ = id_;
248
  std::unique_ptr<OperatorDistAttr> dist_attr_;
F
fengjiayi 已提交
249
};
250 251

std::vector<std::string> AttrVarNames(const Attribute &attr);
F
fengjiayi 已提交
252 253
}  // namespace framework
}  // namespace paddle