operation.h 5.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

17
#include <ostream>
18
#include <vector>
19
#include "paddle/ir/core/block.h"
20
#include "paddle/ir/core/enforce.h"
21
#include "paddle/ir/core/macros.h"
22 23 24
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/type.h"
25 26

namespace ir {
27
class OpBase;
class Program;
class OpOperand;
class OpResult;

// Forward declaration of the successor-operand implementation type that
// Operation stores as `detail::BlockOperandImpl *block_operands_`. The
// namespace name must be `detail` so that this declaration matches that use.
namespace detail {
class BlockOperandImpl;
}  // namespace detail

36
class IR_API alignas(8) Operation final {
 public:
  ///
  /// \brief Malloc memory and construct objects in the following order:
  /// OpResultImpls|Operation|OpOperandImpls.
  /// NOTE: Similar to new and delete, the destroy() and the create() need to be
  /// used in conjunction.
  ///
  static Operation *Create(const std::vector<ir::OpResult> &inputs,
                           const AttributeMap &attributes,
                           const std::vector<ir::Type> &output_types,
                           ir::OpInfo op_info,
                           size_t num_regions = 0,
                           const std::vector<Block *> &successors = {});

  /// \brief Overload of Create() taking all construction data bundled in an
  /// OperationArgument (moved from).
  static Operation *Create(OperationArgument &&op_argument);

  ///
  /// \brief Destroy the operation objects and free memory by create().
  ///
  void Destroy();

  IrContext *ir_context() const;

  Dialect *dialect() const;

  /// Returns the result at position 'index'.
  OpResult result(uint32_t index) const;

  /// Returns the operand at position 'index'.
  OpOperand operand(uint32_t index) const;

  /// Returns the value feeding the operand at position 'index'.
  Value operand_source(uint32_t index) const;

  /// Number of successor blocks; fixed at construction time.
  uint32_t num_successors() const { return num_successors_; }
  BlockOperand block_operand(uint32_t index) const;
  Block *successor(uint32_t index) const;
  void set_successor(Block *block, unsigned index);
  /// Returns true if this operation has at least one successor block.
  bool HasSuccessors() const { return num_successors_ != 0; }

  /// Returns the region held by this operation at position 'index'.
  Region &region(unsigned index);
  const Region &region(unsigned index) const;
  /// Number of regions; fixed at construction time.
  uint32_t num_regions() const { return num_regions_; }

  void Print(std::ostream &os);

  const AttributeMap &attributes() const { return attributes_; }

  /// Returns the attribute named 'name' cast to 'T'.
  /// Fails through IR_ENFORCE when the stored attribute is not a 'T'.
  /// NOTE(review): behavior for a missing key depends on the out-of-line
  /// attribute(key) overload, which is not visible here.
  template <typename T>
  T attribute(const std::string &name) const {
    Attribute attr = attribute(name);
    IR_ENFORCE(attr.isa<T>(), "Attribute (%s) type is not right.", name);
    return attr.dyn_cast<T>();
  }

  /// Inserts or overwrites the attribute stored under 'key'.
  void set_attribute(const std::string &key, Attribute value) {
    attributes_[key] = value;
  }

  Attribute attribute(const std::string &key) const;

  /// Returns true if an attribute is stored under 'key'.
  bool HasAttribute(const std::string &key) const {
    return attributes_.find(key) != attributes_.end();
  }

  ir::OpInfo info() const { return info_; }

  /// Number of results; fixed at construction time.
  uint32_t num_results() const { return num_results_; }

  /// Number of operands; fixed at construction time.
  uint32_t num_operands() const { return num_operands_; }

  std::string name() const;

  /// Casts this operation to 'T'. Dispatches through CastUtil, which is
  /// specialized for types derived from OpBase.
  template <typename T>
  T dyn_cast() {
    return CastUtil<T>::call(this);
  }

  /// Returns true if T::classof accepts this operation.
  template <typename T>
  bool isa() const {
    return T::classof(this);
  }

  /// Returns true if this operation's OpInfo carries trait 'Trait'.
  template <typename Trait>
  bool HasTrait() const {
    return info_.HasTrait<Trait>();
  }

  /// Returns true if this operation's OpInfo carries interface 'Interface'.
  template <typename Interface>
  bool HasInterface() const {
    return info_.HasInterface<Interface>();
  }

  /// Returns the block containing this operation, or nullptr if SetParent has
  /// never been called.
  const Block *GetParent() const { return parent_; }

  Block *GetParent() {
    // Reuse the const overload; 'this' is non-const here, so casting the
    // result back is safe.
    return const_cast<Block *>(
        const_cast<const Operation *>(this)->GetParent());
  }

  Region *GetParentRegion();

  Operation *GetParentOp() const;

  const Program *GetParentProgram() const;

  Program *GetParentProgram() {
    // Same const-delegation pattern as GetParent().
    return const_cast<Program *>(
        const_cast<const Operation *>(this)->GetParentProgram());
  }

  /// Implicit conversion to this operation's position in its parent block, as
  /// recorded by SetParent().
  operator Block::iterator() { return position_; }

  operator Block::const_iterator() const { return position_; }

  /// Replace all uses of results of this operation with the provided 'values'.
  void ReplaceAllUsesWith(const std::vector<Value> &values);

  void ReplaceAllUsesWith(const std::vector<OpResult> &op_results);

  /// Convenience overload wrapping a single 'value' into a vector; presumably
  /// intended for single-result operations (TODO confirm in operation.cc).
  inline void ReplaceAllUsesWith(Value value) {
    ReplaceAllUsesWith(std::vector<Value>{value});
  }

  void Verify();

  std::vector<OpOperand> operands() const;

  std::vector<OpResult> results() const;

 private:
  DISABLE_COPY_AND_ASSIGN(Operation);
  Operation(const AttributeMap &attribute,
            ir::OpInfo op_info,
            uint32_t num_results,
            uint32_t num_operands,
            uint32_t num_regions,
            uint32_t num_successors);

  // Fallback used by dyn_cast<T>() when T is not derived from OpBase.
  // NOTE(review): this throws a string literal (const char *), not a
  // std::exception-derived object; kept as-is because callers may rely on it.
  template <typename T, typename Enabler = void>
  struct CastUtil {
    static T call(Operation *op) {
      throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
    }
  };

  // Op types (derived from OpBase) provide their own static dyn_cast.
  template <typename T>
  struct CastUtil<
      T,
      typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
    static T call(Operation *op) { return T::dyn_cast(op); }
  };

  // Allow access to 'SetParent'.
  friend class Block;
  void SetParent(Block *parent, const Block::iterator &position);

  AttributeMap attributes_;

  OpInfo info_;

  const uint32_t num_results_ = 0;
  const uint32_t num_operands_ = 0;
  const uint32_t num_regions_ = 0;
  const uint32_t num_successors_ = 0;

  // Presumably an array of num_successors_ entries (TODO confirm in .cc).
  detail::BlockOperandImpl *block_operands_{nullptr};
  // Presumably an array of num_regions_ entries (TODO confirm in .cc).
  Region *regions_{nullptr};
  // Owning block and position inside it; set by Block through SetParent().
  Block *parent_{nullptr};
  Block::iterator position_;
};

}  // namespace ir