/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_desc.h"

namespace paddle {
namespace framework {

Y
Yu Yang 已提交
30 31
class BlockDesc;
class ProgramDesc;
W
wanghuancoder 已提交
32

Y
Yu Yang 已提交
33
class OpDesc {
F
fengjiayi 已提交
34
 public:
Y
Yu Yang 已提交
35
  OpDesc() {}
F
fengjiayi 已提交
36

37 38 39 40
  OpDesc(const std::string &type,
         const VariableNameMap &inputs,
         const VariableNameMap &outputs,
         const AttributeMap &attrs);
F
fengjiayi 已提交
41

F
fengjiayi 已提交
42
  OpDesc(const proto::OpDesc &desc, BlockDesc *block);
43 44 45

  explicit OpDesc(BlockDesc *block) : block_(block) {}

X
Xin Pan 已提交
46
  OpDesc(const OpDesc &other, BlockDesc *block);
47

48
  void CopyFrom(const OpDesc &op_desc);
F
fengjiayi 已提交
49

50
  proto::OpDesc *Proto();
F
fengjiayi 已提交
51

52
  std::string Type() const { return desc_.type(); }
F
fengjiayi 已提交
53

54
  void SetType(const std::string &type) { desc_.set_type(type); }
F
fengjiayi 已提交
55 56 57

  const std::vector<std::string> &Input(const std::string &name) const;

F
Update  
fengjiayi 已提交
58 59
  std::vector<std::string> InputArgumentNames() const;

F
fengjiayi 已提交
60 61 62 63 64
  void SetInput(const std::string &param_name,
                const std::vector<std::string> &args);

  const std::vector<std::string> &Output(const std::string &name) const;

65 66
  bool HasOutput(const std::string &name) const;

F
Update  
fengjiayi 已提交
67 68
  std::vector<std::string> OutputArgumentNames() const;

F
fengjiayi 已提交
69 70
  void SetOutput(const std::string &param_name,
                 const std::vector<std::string> &args);
71
  void RemoveOutput(const std::string &name);
F
fengjiayi 已提交
72

73 74
  void RemoveInput(const std::string &name);

F
fengjiayi 已提交
75 76 77 78
  bool HasAttr(const std::string &name) const {
    return attrs_.find(name) != attrs_.end();
  }

79
  bool HasProtoAttr(const std::string &name) const;
F
fengjiayi 已提交
80

81
  proto::AttrType GetAttrType(const std::string &name) const;
F
fengjiayi 已提交
82 83 84 85

  std::vector<std::string> AttrNames() const;

  void SetAttr(const std::string &name, const Attribute &v);
86
  void RemoveAttr(const std::string &name);
F
fengjiayi 已提交
87

A
Abhinav Arora 已提交
88
  void SetBlockAttr(const std::string &name, BlockDesc *block);
F
fengjiayi 已提交
89

90 91
  void SetBlocksAttr(const std::string &name, std::vector<BlockDesc *> blocks);

F
fengjiayi 已提交
92 93
  Attribute GetAttr(const std::string &name) const;

94 95 96 97
  template <typename T>
  T GetAttrIfExists(const std::string &name) const {
    T result{};
    if (HasAttr(name)) {
R
Ruibiao Chen 已提交
98
      result = PADDLE_GET_CONST(T, GetAttr(name));
99 100 101 102
    }
    return result;
  }

M
minqiyang 已提交
103
  const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
M
minqiyang 已提交
104

Y
yuyang18 已提交
105
  Attribute GetNullableAttr(const std::string &name) const;
Y
Fix bug  
yuyang18 已提交
106

G
gongweibao 已提交
107 108 109
  int GetBlockAttrId(const std::string &name) const;

  std::vector<int> GetBlocksAttrIds(const std::string &name) const;
F
fengjiayi 已提交
110

F
fengjiayi 已提交
111 112
  void Rename(const std::string &old_name, const std::string &new_name);

Y
Yang Yang(Tony) 已提交
113 114 115 116
  void RenameOutput(const std::string &old_name, const std::string &new_name);

  void RenameInput(const std::string &old_name, const std::string &new_name);

F
fengjiayi 已提交
117
  // Only be used in C++
Y
Yu Yang 已提交
118
  const AttributeMap &GetAttrMap() const;
F
fengjiayi 已提交
119

120 121
  // Only be used in C++
  void SetAttrMap(const AttributeMap &attr_map);
122

Y
Yu Yang 已提交
123 124
  std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
  std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
125

Y
Yu Yang 已提交
126 127 128 129
  const VariableNameMap &Inputs() const { return inputs_; }

  const VariableNameMap &Outputs() const { return outputs_; }

130 131 132 133 134 135 136 137 138 139
  VariableNameMap *MutableInputs() {
    this->need_update_ = true;
    return &this->inputs_;
  }

  VariableNameMap *MutableOutputs() {
    this->need_update_ = true;
    return &this->outputs_;
  }

140 141 142 143 144
  AttributeMap *MutableAttrMap() {
    this->need_update_ = true;
    return &this->attrs_;
  }

F
fengjiayi 已提交
145 146
  void CheckAttrs();

H
hong 已提交
147
  void InferShape(const BlockDesc &block);
Y
Yu Yang 已提交
148

Y
Yu Yang 已提交
149
  void InferVarType(BlockDesc *block) const;
Y
Yu Yang 已提交
150

151
  void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); }
152

153 154
  void Flush();

155 156
  BlockDesc *Block() { return this->block_; }

S
sneaxiy 已提交
157 158
  const BlockDesc *Block() const { return this->block_; }

159
  // The Id() and OrignalId() are only used for auto parallel.
160
  uint64_t Id() const { return id_; }
161 162
  uint64_t OriginalId() const { return original_id_; }
  void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
163

F
fengjiayi 已提交
164
 private:
165 166 167 168 169
  template <typename MapType>
  static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
    std::vector<typename MapType::key_type> ret_val;
    ret_val.reserve(map.size());
    std::transform(
170 171 172
        map.begin(),
        map.end(),
        std::back_inserter(ret_val),
173 174 175 176
        [](const typename MapType::value_type &pair) { return pair.first; });
    return ret_val;
  }

177 178 179 180 181 182 183 184
  // This thread-safe implementation seems to be redudent since the neural
  // networks are usually constructed in a single thread
  static uint64_t GenerateId() {
    static std::atomic<std::uint64_t> uid{0};
    // Must start from one
    return ++uid;
  }

185
  proto::OpDesc desc_;
186
  BlockDesc *block_{nullptr};  // not_own
187
  // input arg name => input variable names
Y
Yu Yang 已提交
188
  VariableNameMap inputs_;
189
  // output arg name => output variable names
Y
Yu Yang 已提交
190 191
  VariableNameMap outputs_;
  AttributeMap attrs_;
F
fengjiayi 已提交
192 193 194 195

  // need_update_ indicate there some local changes not be synchronized. If
  // local changes should be synchronized, need_update_ should be set to true.
  bool need_update_{false};
196

197
  // Note: the id_ is unique (only for auto parallel).
198
  uint64_t id_ = GenerateId();
199 200 201 202 203
  // Note: the orignal_id_ is used for referring to the original OpDesc
  // that the current OpDesc is built from (only for auto parallel).
  // The default original_id_ is same as the id_, which means the
  // current OpDesc is not built from the other one.
  uint64_t original_id_ = id_;
F
fengjiayi 已提交
204
};
F
fengjiayi 已提交
205 206
}  // namespace framework
}  // namespace paddle