// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>

#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/value_impl.h"

namespace ir {
// Forward declarations for types this header only names through pointers,
// references, or as template arguments, so their full definitions are not
// required here.
class OpBase;
class Program;
class Block;
class Region;
class IrContext;

class alignas(8) Operation final {
 public:
  ///
  /// \brief Malloc memory and construct objects in the following order:
  /// OpResultImpls|Operation|OpOperandImpls.
33 34
  /// NOTE: Similar to new and delete, the destroy() and the create() need to be
  /// used in conjunction.
35 36
  ///
  static Operation *create(const std::vector<ir::OpResult> &inputs,
37
                           const AttributeMap &attribute,
38 39 40 41
                           const std::vector<ir::Type> &output_types,
                           ir::OpInfo op_info,
                           size_t num_regions = 0);
  static Operation *create(OperationArgument &&op_argument);
42

43
  ///
C
co63oc 已提交
44
  /// \brief Destroy the operation objects and free memory by create().
45
  ///
46 47
  void destroy();

48 49
  Block *parent() const { return parent_; }

50 51
  IrContext *ir_context() const;

52
  ir::OpResult GetResultByIndex(uint32_t index) const;
53

54
  ir::OpOperand GetOperandByIndex(uint32_t index) const;
55

56 57
  std::string print();

58
  const AttributeMap &attribute() const { return attribute_; }
59

60
  ir::OpInfo op_info() const { return op_info_; }
61

62 63 64 65
  uint32_t num_results() const { return num_results_; }

  uint32_t num_operands() const { return num_operands_; }

66 67
  uint32_t num_regions() const { return num_regions_; }

68 69
  std::string op_name() const;

70
  template <typename T>
Z
zhangbo9674 已提交
71
  T dyn_cast() {
72 73 74 75 76 77 78 79 80 81 82 83
    return CastUtil<T>::call(this);
  }

  template <typename Trait>
  bool HasTrait() const {
    return op_info_.HasTrait<Trait>();
  }

  template <typename Interface>
  bool HasInterface() const {
    return op_info_.HasInterface<Interface>();
  }
84

85 86 87 88 89 90
  Program *parent_program() const { return parent_program_; }

  void set_parent_program(Program *parent_program) {
    parent_program_ = parent_program;
  }

91 92 93
  /// Returns the region held by this operation at position 'index'.
  Region &GetRegion(unsigned index);

94
 private:
95 96 97
  Operation(const AttributeMap &attribute,
            ir::OpInfo op_info,
            uint32_t num_results,
98
            uint32_t num_operands,
99
            uint32_t num_regions);
100 101 102

  template <typename T, typename Enabler = void>
  struct CastUtil {
Z
zhangbo9674 已提交
103
    static T call(Operation *op) {
104
      throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
105
    }
106
  };
107

108 109 110
  friend class Block;
  void set_parent(Block *parent) { parent_ = parent; }

111
  template <typename T>
112 113 114
  struct CastUtil<
      T,
      typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
Z
zhangbo9674 已提交
115
    static T call(Operation *op) { return T::dyn_cast(op); }
116
  };
117

118
  AttributeMap attribute_;
119

120
  OpInfo op_info_;
121

122 123 124
  const uint32_t num_results_ = 0;
  const uint32_t num_operands_ = 0;
  const uint32_t num_regions_ = 0;
125

126
  Region *regions_{nullptr};
127
  Program *parent_program_{nullptr};
128
  Block *parent_{nullptr};
129 130 131
};

}  // namespace ir