/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <atomic>
18
#include <memory>
A
Abhinav Arora 已提交
19
#include <string>
F
fengjiayi 已提交
20
#include <unordered_map>
H
hong 已提交
21
#include <utility>
F
fengjiayi 已提交
22
#include <vector>
W
wanghuancoder 已提交
23

24
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
Y
Yi Wang 已提交
25 26 27
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_desc.h"
28
#include "paddle/phi/core/macros.h"
F
fengjiayi 已提交
29 30 31 32

namespace paddle {
namespace framework {

33
class VarDesc;
Y
Yu Yang 已提交
34 35
class BlockDesc;
class ProgramDesc;
W
wanghuancoder 已提交
36

37 38
using paddle::distributed::auto_parallel::OperatorDistAttr;

Y
Yu Yang 已提交
39
class OpDesc {
F
fengjiayi 已提交
40
 public:
Y
Yu Yang 已提交
41
  OpDesc() {}
F
fengjiayi 已提交
42

43 44 45 46
  OpDesc(const std::string &type,
         const VariableNameMap &inputs,
         const VariableNameMap &outputs,
         const AttributeMap &attrs);
F
fengjiayi 已提交
47

48 49
  OpDesc(const OpDesc &desc);

F
fengjiayi 已提交
50
  OpDesc(const proto::OpDesc &desc, BlockDesc *block);
51 52 53

  explicit OpDesc(BlockDesc *block) : block_(block) {}

X
Xin Pan 已提交
54
  OpDesc(const OpDesc &other, BlockDesc *block);
55

56 57
  OpDesc &operator=(const OpDesc &other);

58
  void CopyFrom(const OpDesc &op_desc);
F
fengjiayi 已提交
59

60
  proto::OpDesc *Proto();
F
fengjiayi 已提交
61

62
  std::string Type() const { return desc_.type(); }
F
fengjiayi 已提交
63

64
  void SetType(const std::string &type) { desc_.set_type(type); }
F
fengjiayi 已提交
65 66 67

  const std::vector<std::string> &Input(const std::string &name) const;

68 69 70 71
  std::vector<std::string> Input(const std::string &name,
                                 bool with_attr_var) const;

  std::vector<std::string> InputArgumentNames(bool with_attr_var = false) const;
F
Update  
fengjiayi 已提交
72

F
fengjiayi 已提交
73 74 75 76 77
  void SetInput(const std::string &param_name,
                const std::vector<std::string> &args);

  const std::vector<std::string> &Output(const std::string &name) const;

78 79
  bool HasOutput(const std::string &name) const;

80 81
  bool HasInput(const std::string &name) const;

F
Update  
fengjiayi 已提交
82 83
  std::vector<std::string> OutputArgumentNames() const;

F
fengjiayi 已提交
84 85
  void SetOutput(const std::string &param_name,
                 const std::vector<std::string> &args);
86
  void RemoveOutput(const std::string &name);
F
fengjiayi 已提交
87

88 89
  void RemoveInput(const std::string &name);

90
  bool HasAttr(const std::string &name, bool with_attr_var = false) const;
F
fengjiayi 已提交
91

92
  bool HasProtoAttr(const std::string &name) const;
F
fengjiayi 已提交
93

94 95
  proto::AttrType GetAttrType(const std::string &name,
                              bool with_attr_var = false) const;
F
fengjiayi 已提交
96

97
  std::vector<std::string> AttrNames(bool with_attr_var = false) const;
F
fengjiayi 已提交
98 99

  void SetAttr(const std::string &name, const Attribute &v);
100
  void RemoveAttr(const std::string &name);
F
fengjiayi 已提交
101

102 103 104 105 106 107 108 109 110
  // NOTE(chenfeiyu): this template is added to avoid using a variant(Attribute)
  // as a parameter of a function which is bound to python, which causes
  // unexpected type conversion due to the overload resolution mechanism
  // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
  template <typename T>
  void SetPlainAttr(const std::string &name, const T &value) {
    SetAttr(name, value);
  }

111 112 113 114
  void SetVarAttr(const std::string &name, VarDesc *var);

  void SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars);

A
Abhinav Arora 已提交
115
  void SetBlockAttr(const std::string &name, BlockDesc *block);
F
fengjiayi 已提交
116

117 118
  void SetBlocksAttr(const std::string &name, std::vector<BlockDesc *> blocks);

119
  Attribute GetAttr(const std::string &name, bool with_attr_var = false) const;
F
fengjiayi 已提交
120

121 122 123 124
  template <typename T>
  T GetAttrIfExists(const std::string &name) const {
    T result{};
    if (HasAttr(name)) {
R
Ruibiao Chen 已提交
125
      result = PADDLE_GET_CONST(T, GetAttr(name));
126 127 128 129
    }
    return result;
  }

M
minqiyang 已提交
130
  const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
M
minqiyang 已提交
131

Y
yuyang18 已提交
132
  Attribute GetNullableAttr(const std::string &name) const;
Y
Fix bug  
yuyang18 已提交
133

G
gongweibao 已提交
134 135 136
  int GetBlockAttrId(const std::string &name) const;

  std::vector<int> GetBlocksAttrIds(const std::string &name) const;
F
fengjiayi 已提交
137

F
fengjiayi 已提交
138 139
  void Rename(const std::string &old_name, const std::string &new_name);

Y
Yang Yang(Tony) 已提交
140 141 142 143
  void RenameOutput(const std::string &old_name, const std::string &new_name);

  void RenameInput(const std::string &old_name, const std::string &new_name);

F
fengjiayi 已提交
144
  // Only be used in C++
Y
Yu Yang 已提交
145
  const AttributeMap &GetAttrMap() const;
F
fengjiayi 已提交
146

147 148
  // Only be used in C++
  void SetAttrMap(const AttributeMap &attr_map);
149

150 151 152 153
  void SetRuntimeAttrMap(const AttributeMap &attr_map);

  const AttributeMap &GetRuntimeAttrMap() const;

154
  std::vector<std::string> InputNames(bool with_attr_var UNUSED = false) const {
155 156
    return MapKeys(inputs_);
  }
Y
Yu Yang 已提交
157
  std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
158

Y
Yu Yang 已提交
159 160
  const VariableNameMap &Inputs() const { return inputs_; }

161 162
  VariableNameMap Inputs(bool with_attr_var) const;

Y
Yu Yang 已提交
163 164
  const VariableNameMap &Outputs() const { return outputs_; }

165 166 167 168 169 170 171 172 173 174
  VariableNameMap *MutableInputs() {
    this->need_update_ = true;
    return &this->inputs_;
  }

  VariableNameMap *MutableOutputs() {
    this->need_update_ = true;
    return &this->outputs_;
  }

175 176 177 178 179
  AttributeMap *MutableAttrMap() {
    this->need_update_ = true;
    return &this->attrs_;
  }

F
fengjiayi 已提交
180 181
  void CheckAttrs();

H
hong 已提交
182
  void InferShape(const BlockDesc &block);
Y
Yu Yang 已提交
183

Y
Yu Yang 已提交
184
  void InferVarType(BlockDesc *block) const;
Y
Yu Yang 已提交
185

186
  void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); }
187

188 189
  void Flush();

190 191
  BlockDesc *Block() { return this->block_; }

S
sneaxiy 已提交
192 193
  const BlockDesc *Block() const { return this->block_; }

194 195
  void UpdateVarAttr(const std::string &name, const Attribute &attr);

196 197 198
  bool NeedUpdate() const { return need_update_; }

  // The following methods are only used for auto parallel.
199
  uint64_t Id() const { return id_; }
200 201
  uint64_t OriginalId() const { return original_id_; }
  void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
202
  const OperatorDistAttr *DistAttr() const;
203 204
  OperatorDistAttr *MutableDistAttr();
  void SetDistAttr(const OperatorDistAttr &dist_attr);
L
Leo Chen 已提交
205

J
Jiabin Yang 已提交
206 207
  void ResetBlock() { this->block_ = nullptr; }

F
fengjiayi 已提交
208
 private:
209 210 211 212
  friend class ProgramDesc;
  // Find VarDesc from OpDesc located Block into global Block
  VarDesc *FindVarRecursive(const std::string &name);

213 214 215 216 217
  template <typename MapType>
  static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
    std::vector<typename MapType::key_type> ret_val;
    ret_val.reserve(map.size());
    std::transform(
218 219 220
        map.begin(),
        map.end(),
        std::back_inserter(ret_val),
221 222 223 224
        [](const typename MapType::value_type &pair) { return pair.first; });
    return ret_val;
  }

S
Shuangchi He 已提交
225
  // Is it really needed? Or just maintain a ptr from the block?
226
  proto::OpDesc desc_;
227
  BlockDesc *block_{nullptr};  // not_own
228
  // input arg name => input variable names
Y
Yu Yang 已提交
229
  VariableNameMap inputs_;
230
  // output arg name => output variable names
Y
Yu Yang 已提交
231
  VariableNameMap outputs_;
232
  // attribute name => all original attrs
Y
Yu Yang 已提交
233
  AttributeMap attrs_;
234 235 236 237 238 239
  // runtime_attrs_ contains the attributes which used for dispatching kernel
  // (use_mkldnn, use_cudnn, ...) or passing additional configuration for
  // special heterogeneous kernel (workspace_size_MB, ...).
  // The attributes in runtime_attrs_ are setted by framework (such as PASS),
  // and not in the python api.
  AttributeMap runtime_attrs_;
F
fengjiayi 已提交
240 241 242 243

  // need_update_ indicate there some local changes not be synchronized. If
  // local changes should be synchronized, need_update_ should be set to true.
  bool need_update_{false};
244

245 246 247 248 249 250
  // Note: the following members are only used for auto_parallel for now.
  static uint64_t GenerateId() {
    static std::atomic<std::uint64_t> uid{0};
    // Must start from one
    return ++uid;
  }
251
  uint64_t id_ = GenerateId();
252
  uint64_t original_id_ = id_;
253
  std::unique_ptr<OperatorDistAttr> dist_attr_;
F
fengjiayi 已提交
254
};
255 256

// Returns a list of names derived from `attr`.
// NOTE(review): presumably the names of the variables referenced by a
// variable-typed attribute; the definition is not visible here — confirm.
std::vector<std::string> AttrVarNames(const Attribute &attr);
F
fengjiayi 已提交
257 258
}  // namespace framework
}  // namespace paddle