// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace ir {
class IrContextImpl;
class StorageManager;
class AbstractType;
class AbstractAttribute;
class TypeId;
class Dialect;
class OpInfo;
class InterfaceValue;
class Type;
class OpResult;
class Attribute;

/// Maps an operation name to its registered OpInfo.
using OpInfoMap = std::unordered_map<std::string, OpInfo>;

///
/// \brief IrContext is a global parameterless class used to store and manage
/// Type, Attribute and other related data structures.
///
/// The constructor is private and copying is disabled; obtain the context
/// through IrContext::Instance().
///
class IrContext {
 public:
  ///
  /// \brief Returns the global IrContext instance.
  ///
  /// \return Pointer to the global IrContext.
  ///
  static IrContext *Instance();

  ///
  /// \brief Get an instance of IrContextImpl, a private member of IrContext.
  /// For the specific definition of IrContextImpl, see ir_context.cc.
  ///
  /// \return The instance of IrContextImpl.
  ///
  IrContextImpl &impl() { return *impl_; }

  ///
  /// \brief Register an AbstractType to IrContext.
  ///
  /// \param type_id The type id of the AbstractType.
  /// \param abstract_type The AbstractType provided by the user, passed as an
  ///        rvalue (presumably moved into the context; see ir_context.cc).
  ///
  void RegisterAbstractType(TypeId type_id, AbstractType &&abstract_type);

  ///
  /// \brief Returns the storage manager used for constructing TypeStorage
  /// instances.
  ///
  /// \return The storage manager used for constructing TypeStorage instances.
  ///
  StorageManager &type_storage_manager();

  ///
  /// \brief Get the AbstractType registered in IrContext for the given id.
  ///
  /// \param id The type id of the AbstractType to look up.
  ///
  /// \return Pointer to the registered AbstractType. The result for an
  ///         unregistered id is defined in ir_context.cc.
  ///
  AbstractType *GetRegisteredAbstractType(TypeId id);

  ///
  /// \brief Register an AbstractAttribute to IrContext.
  ///
  /// \param type_id The type id of the AbstractAttribute.
  /// \param abstract_attribute The AbstractAttribute provided by the user,
  ///        passed as an rvalue.
  ///
  void RegisterAbstractAttribute(TypeId type_id,
                                 AbstractAttribute &&abstract_attribute);

  ///
  /// \brief Returns the storage manager used for constructing
  /// AttributeStorage instances.
  ///
  /// \return The storage manager used for constructing AttributeStorage
  /// instances.
  ///
  StorageManager &attribute_storage_manager();

  ///
  /// \brief Get the AbstractAttribute registered in IrContext for the given
  /// id.
  ///
  /// \param id The type id of the AbstractAttribute to look up.
  ///
  /// \return Pointer to the registered AbstractAttribute.
  ///
  AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id);

  ///
  /// \brief Register an operation's information to IrContext.
  ///
  /// \param dialect The dialect associated with the operation.
  /// \param op_id The type id of the operation.
  /// \param name The operation name.
  /// \param interface_map The interface values of the operation, passed as an
  ///        rvalue.
  /// \param trait_set The type ids of the operation's traits.
  /// \param attributes_num The number of entries in attributes_name
  ///        (assumed; confirm against ir_context.cc).
  /// \param attributes_name The operation's attribute names.
  /// \param verify Callback that receives the operation's inputs, outputs and
  ///        attributes.
  ///
  void RegisterOpInfo(
      Dialect *dialect,
      TypeId op_id,
      const char *name,
      std::vector<InterfaceValue> &&interface_map,
      const std::vector<TypeId> &trait_set,
      size_t attributes_num,
      const char **attributes_name,
      void (*verify)(
          const std::vector<OpResult> &inputs,
          const std::vector<Type> &outputs,
          const std::unordered_map<std::string, Attribute> &attributes));

  ///
  /// \brief Get registered operation information.
  ///
  /// \param name The operation name.
  ///
  /// \return The OpInfo registered under name.
  ///
  OpInfo GetRegisteredOpInfo(const std::string &name);

  ///
  /// \brief Get the map of registered operation information.
  ///
  /// \return The map from operation name to OpInfo.
  ///
  const OpInfoMap &registered_op_info_map();

  ///
  /// \brief Get the dialect of the DialectT class in the context; if not
  /// found, create it and register it to the context.
  ///
  /// \tparam DialectT The Dialect class to find or register. It must provide
  ///         a static name() method and a constructor taking IrContext*.
  ///
  /// \return The dialect of the DialectT class in the context.
  ///
  template <typename DialectT>
  DialectT *GetOrRegisterDialect() {
    // The constructor is only invoked when the dialect is not yet registered.
    // The raw pointer is handed to the context (ownership presumably taken by
    // IrContextImpl; confirm in ir_context.cc).
    return static_cast<DialectT *>(GetOrRegisterDialect(
        DialectT::name(), [this]() -> Dialect * { return new DialectT(this); }));
  }

  ///
  /// \brief Get the dialect named dialect_name in the context; if not found,
  /// create it with constructor and register it to the context.
  ///
  /// \param dialect_name The dialect name.
  /// \param constructor The dialect constructor.
  ///
  /// \return The dialect named dialect_name in the context.
  ///
  Dialect *GetOrRegisterDialect(const std::string &dialect_name,
                                std::function<Dialect *()> constructor);

  ///
  /// \brief Get the list of dialects registered to the context.
  ///
  /// \return The list of dialects registered to the context.
  ///
  std::vector<Dialect *> GetRegisteredDialects();

  ///
  /// \brief Get the dialect named dialect_name from the context.
  ///
  /// \param dialect_name The name of the dialect to obtain.
  ///
  /// \return The dialect named dialect_name from the context.
  ///
  Dialect *GetRegisteredDialect(const std::string &dialect_name);

  ///
  /// \brief Get a registered dialect for the given dialect type T. The
  /// Dialect must provide a static 'name' method.
  ///
  /// \tparam T The dialect type to look up.
  ///
  /// \return The registered dialect for the given dialect type T.
  ///
  template <typename T>
  T *GetRegisteredDialect() {
    return static_cast<T *>(GetRegisteredDialect(T::name()));
  }

  // IrContext is non-copyable.
  IrContext(const IrContext &) = delete;

  IrContext &operator=(const IrContext &) = delete;

 private:
  IrContext();
  const std::unique_ptr<IrContextImpl> impl_;
};

}  // namespace ir