// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <ostream>
#include <string>
#include <type_traits>
#include <vector>

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/type.h"

namespace ir {
25
class OpBase;
26
class Program;
27 28
class OpOperand;
class OpResult;
29

30
class IR_API alignas(8) Operation final {
31 32 33 34
 public:
  ///
  /// \brief Malloc memory and construct objects in the following order:
  /// OpResultImpls|Operation|OpOperandImpls.
35 36
  /// NOTE: Similar to new and delete, the destroy() and the create() need to be
  /// used in conjunction.
37
  ///
38
  static Operation *Create(const std::vector<ir::OpResult> &inputs,
39
                           const AttributeMap &attributes,
40 41 42
                           const std::vector<ir::Type> &output_types,
                           ir::OpInfo op_info,
                           size_t num_regions = 0);
43
  static Operation *Create(OperationArgument &&op_argument);
44

45
  ///
C
co63oc 已提交
46
  /// \brief Destroy the operation objects and free memory by create().
47
  ///
48
  void Destroy();
49

50
  IrContext *ir_context() const;
51

52
  Dialect *dialect() const;
53

54 55
  OpResult result(uint32_t index) const;

56 57 58
  OpOperand op_operand(uint32_t index) const;

  Value operand(uint32_t index) const;
59

60 61
  /// Returns the region held by this operation at position 'index'.
  Region &region(unsigned index);
62
  const Region &region(unsigned index) const;
63

64
  void Print(std::ostream &os) const;
65

66 67
  const AttributeMap &attributes() const { return attributes_; }

68
  void set_attribute(const std::string &key, Attribute value) {
69 70
    attributes_[key] = value;
  }
71

72 73 74 75 76 77
  Attribute attribute(const std::string &key) const;

  bool HasAttribute(const std::string &key) const {
    return attributes_.find(key) != attributes_.end();
  }

78
  ir::OpInfo info() const { return info_; }
79

80 81 82 83
  uint32_t num_results() const { return num_results_; }

  uint32_t num_operands() const { return num_operands_; }

84 85
  uint32_t num_regions() const { return num_regions_; }

86
  std::string name() const;
87

88
  template <typename T>
Z
zhangbo9674 已提交
89
  T dyn_cast() {
90 91 92 93 94
    return CastUtil<T>::call(this);
  }

  template <typename Trait>
  bool HasTrait() const {
95
    return info_.HasTrait<Trait>();
96 97 98 99
  }

  template <typename Interface>
  bool HasInterface() const {
100
    return info_.HasInterface<Interface>();
101
  }
102

103
  Block *GetParent() const { return parent_; }
104

105 106 107 108 109
  Region *GetParentRegion() const;

  Operation *GetParentOp() const;

  Program *GetParentProgram();
110

111 112
  operator Block::iterator() { return position_; }

113 114
  operator Block::const_iterator() const { return position_; }

115 116 117 118 119 120 121
  /// Replace all uses of results of this operation with the provided 'values'.
  void ReplaceAllUsesWith(const std::vector<Value> &values);

  inline void ReplaceAllUsesWith(Value value) {
    ReplaceAllUsesWith(std::vector<Value>{value});
  }

122 123
  void Verify();

124
 private:
125 126 127
  Operation(const AttributeMap &attribute,
            ir::OpInfo op_info,
            uint32_t num_results,
128
            uint32_t num_operands,
129
            uint32_t num_regions);
130 131 132

  template <typename T, typename Enabler = void>
  struct CastUtil {
Z
zhangbo9674 已提交
133
    static T call(Operation *op) {
134
      throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
135
    }
136
  };
137

138
  // Allow access to 'SetParent'.
139
  friend class Block;
140
  void SetParent(Block *parent, const Block::iterator &position);
141

142
  template <typename T>
143 144 145
  struct CastUtil<
      T,
      typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
Z
zhangbo9674 已提交
146
    static T call(Operation *op) { return T::dyn_cast(op); }
147
  };
148

149
  AttributeMap attributes_;
150

151
  OpInfo info_;
152

153 154 155
  const uint32_t num_results_ = 0;
  const uint32_t num_operands_ = 0;
  const uint32_t num_regions_ = 0;
156

157 158
  Region *regions_{nullptr};
  Block *parent_{nullptr};
159
  Block::iterator position_;
160 161 162
};

}  // namespace ir