/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_desc.h"

namespace paddle {
namespace framework {

30
class VarDesc;
Y
Yu Yang 已提交
31 32
class BlockDesc;
class ProgramDesc;
W
wanghuancoder 已提交
33

Y
Yu Yang 已提交
34
class OpDesc {
F
fengjiayi 已提交
35
 public:
Y
Yu Yang 已提交
36
  OpDesc() {}
F
fengjiayi 已提交
37

38 39 40 41
  OpDesc(const std::string &type,
         const VariableNameMap &inputs,
         const VariableNameMap &outputs,
         const AttributeMap &attrs);
F
fengjiayi 已提交
42

F
fengjiayi 已提交
43
  OpDesc(const proto::OpDesc &desc, BlockDesc *block);
44 45 46

  explicit OpDesc(BlockDesc *block) : block_(block) {}

X
Xin Pan 已提交
47
  OpDesc(const OpDesc &other, BlockDesc *block);
48

49
  void CopyFrom(const OpDesc &op_desc);
F
fengjiayi 已提交
50

51
  proto::OpDesc *Proto();
F
fengjiayi 已提交
52

53
  std::string Type() const { return desc_.type(); }
F
fengjiayi 已提交
54

55
  void SetType(const std::string &type) { desc_.set_type(type); }
F
fengjiayi 已提交
56 57 58

  const std::vector<std::string> &Input(const std::string &name) const;

59 60 61 62
  std::vector<std::string> Input(const std::string &name,
                                 bool with_attr_var) const;

  std::vector<std::string> InputArgumentNames(bool with_attr_var = false) const;
F
Update  
fengjiayi 已提交
63

F
fengjiayi 已提交
64 65 66 67 68
  void SetInput(const std::string &param_name,
                const std::vector<std::string> &args);

  const std::vector<std::string> &Output(const std::string &name) const;

69 70
  bool HasOutput(const std::string &name) const;

F
Update  
fengjiayi 已提交
71 72
  std::vector<std::string> OutputArgumentNames() const;

F
fengjiayi 已提交
73 74
  void SetOutput(const std::string &param_name,
                 const std::vector<std::string> &args);
75
  void RemoveOutput(const std::string &name);
F
fengjiayi 已提交
76

77 78
  void RemoveInput(const std::string &name);

79
  bool HasAttr(const std::string &name, bool with_attr_var = false) const;
F
fengjiayi 已提交
80

81
  bool HasProtoAttr(const std::string &name) const;
F
fengjiayi 已提交
82

83 84
  proto::AttrType GetAttrType(const std::string &name,
                              bool with_attr_var = false) const;
F
fengjiayi 已提交
85

86
  std::vector<std::string> AttrNames(bool with_attr_var = false) const;
F
fengjiayi 已提交
87 88

  void SetAttr(const std::string &name, const Attribute &v);
89
  void RemoveAttr(const std::string &name);
F
fengjiayi 已提交
90

91 92 93 94
  void SetVarAttr(const std::string &name, VarDesc *var);

  void SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars);

A
Abhinav Arora 已提交
95
  void SetBlockAttr(const std::string &name, BlockDesc *block);
F
fengjiayi 已提交
96

97 98
  void SetBlocksAttr(const std::string &name, std::vector<BlockDesc *> blocks);

99
  Attribute GetAttr(const std::string &name, bool with_attr_var = false) const;
F
fengjiayi 已提交
100

101 102 103 104
  template <typename T>
  T GetAttrIfExists(const std::string &name) const {
    T result{};
    if (HasAttr(name)) {
R
Ruibiao Chen 已提交
105
      result = PADDLE_GET_CONST(T, GetAttr(name));
106 107 108 109
    }
    return result;
  }

M
minqiyang 已提交
110
  const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
M
minqiyang 已提交
111

Y
yuyang18 已提交
112
  Attribute GetNullableAttr(const std::string &name) const;
Y
Fix bug  
yuyang18 已提交
113

G
gongweibao 已提交
114 115 116
  int GetBlockAttrId(const std::string &name) const;

  std::vector<int> GetBlocksAttrIds(const std::string &name) const;
F
fengjiayi 已提交
117

F
fengjiayi 已提交
118 119
  void Rename(const std::string &old_name, const std::string &new_name);

Y
Yang Yang(Tony) 已提交
120 121 122 123
  void RenameOutput(const std::string &old_name, const std::string &new_name);

  void RenameInput(const std::string &old_name, const std::string &new_name);

F
fengjiayi 已提交
124
  // Only be used in C++
Y
Yu Yang 已提交
125
  const AttributeMap &GetAttrMap() const;
F
fengjiayi 已提交
126

127 128
  // Only be used in C++
  void SetAttrMap(const AttributeMap &attr_map);
129

130 131 132
  std::vector<std::string> InputNames(bool with_attr_var = false) const {
    return MapKeys(inputs_);
  }
Y
Yu Yang 已提交
133
  std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
134

Y
Yu Yang 已提交
135 136
  const VariableNameMap &Inputs() const { return inputs_; }

137 138
  VariableNameMap Inputs(bool with_attr_var) const;

Y
Yu Yang 已提交
139 140
  const VariableNameMap &Outputs() const { return outputs_; }

141 142 143 144 145 146 147 148 149 150
  VariableNameMap *MutableInputs() {
    this->need_update_ = true;
    return &this->inputs_;
  }

  VariableNameMap *MutableOutputs() {
    this->need_update_ = true;
    return &this->outputs_;
  }

151 152 153 154 155
  AttributeMap *MutableAttrMap() {
    this->need_update_ = true;
    return &this->attrs_;
  }

F
fengjiayi 已提交
156 157
  void CheckAttrs();

H
hong 已提交
158
  void InferShape(const BlockDesc &block);
Y
Yu Yang 已提交
159

Y
Yu Yang 已提交
160
  void InferVarType(BlockDesc *block) const;
Y
Yu Yang 已提交
161

162
  void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); }
163

164 165
  void Flush();

166 167
  BlockDesc *Block() { return this->block_; }

S
sneaxiy 已提交
168 169
  const BlockDesc *Block() const { return this->block_; }

170 171
  void UpdateVarAttr(const std::string &name, const Attribute &attr);

172
  // The Id() and OrignalId() are only used for auto parallel.
173
  uint64_t Id() const { return id_; }
174 175
  uint64_t OriginalId() const { return original_id_; }
  void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
176

F
fengjiayi 已提交
177
 private:
178 179 180 181
  friend class ProgramDesc;
  // Find VarDesc from OpDesc located Block into global Block
  VarDesc *FindVarRecursive(const std::string &name);

182 183 184 185 186
  template <typename MapType>
  static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
    std::vector<typename MapType::key_type> ret_val;
    ret_val.reserve(map.size());
    std::transform(
187 188 189
        map.begin(),
        map.end(),
        std::back_inserter(ret_val),
190 191 192 193
        [](const typename MapType::value_type &pair) { return pair.first; });
    return ret_val;
  }

194 195 196 197 198 199 200 201
  // This thread-safe implementation seems to be redudent since the neural
  // networks are usually constructed in a single thread
  static uint64_t GenerateId() {
    static std::atomic<std::uint64_t> uid{0};
    // Must start from one
    return ++uid;
  }

202
  proto::OpDesc desc_;
203
  BlockDesc *block_{nullptr};  // not_own
204
  // input arg name => input variable names
Y
Yu Yang 已提交
205
  VariableNameMap inputs_;
206
  // output arg name => output variable names
Y
Yu Yang 已提交
207
  VariableNameMap outputs_;
208
  // attribute name => all original attrs
Y
Yu Yang 已提交
209
  AttributeMap attrs_;
F
fengjiayi 已提交
210 211 212 213

  // need_update_ indicate there some local changes not be synchronized. If
  // local changes should be synchronized, need_update_ should be set to true.
  bool need_update_{false};
214

215
  // Note: the id_ is unique (only for auto parallel).
216
  uint64_t id_ = GenerateId();
217 218 219 220 221
  // Note: the orignal_id_ is used for referring to the original OpDesc
  // that the current OpDesc is built from (only for auto parallel).
  // The default original_id_ is same as the id_, which means the
  // current OpDesc is not built from the other one.
  uint64_t original_id_ = id_;
F
fengjiayi 已提交
222
};
223 224

std::vector<std::string> AttrVarNames(const Attribute &attr);
F
fengjiayi 已提交
225 226
}  // namespace framework
}  // namespace paddle