ir_context.cc 12.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/ir/core/ir_context.h"
16

17 18
#include <mutex>
#include <unordered_map>

19 20 21 22 23 24 25
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/op_info_impl.h"
#include "paddle/ir/core/spin_lock.h"
#include "paddle/ir/core/type_base.h"
26 27

namespace ir {
28
// The implementation class of the IrContext class, cache registered
Z
zhangbo9674 已提交
29
// AbstractType, TypeStorage, AbstractAttribute, AttributeStorage, Dialect.
30 31 32 33 34
class IrContextImpl {
 public:
  IrContextImpl() {}

  ~IrContextImpl() {
35 36
    std::lock_guard<ir::SpinLock> guard(destructor_lock_);
    for (auto &abstract_type_map : registed_abstract_types_) {
37 38 39
      delete abstract_type_map.second;
    }
    registed_abstract_types_.clear();
40

Z
zhangbo9674 已提交
41 42 43 44 45
    for (auto &abstract_attribute_map : registed_abstract_attributes_) {
      delete abstract_attribute_map.second;
    }
    registed_abstract_attributes_.clear();

46 47 48 49
    for (auto &dialect_map : registed_dialect_) {
      delete dialect_map.second;
    }
    registed_dialect_.clear();
50 51 52 53 54

    for (auto &op_map : registed_op_infos_) {
      op_map.second->destroy();
    }
    registed_op_infos_.clear();
55 56 57 58
  }

  void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
59
    VLOG(4) << "Register an abstract_type of: [TypeId_hash="
60 61 62 63 64
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractType_ptr=" << abstract_type << "].";
    registed_abstract_types_.emplace(type_id, abstract_type);
  }

65
  AbstractType *GetAbstractType(ir::TypeId type_id) {
66 67
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
    auto iter = registed_abstract_types_.find(type_id);
68
    if (iter != registed_abstract_types_.end()) {
C
co63oc 已提交
69
      VLOG(4) << "Found a cached abstract_type of: [TypeId_hash="
70 71 72 73
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractType_ptr=" << iter->second << "].";
      return iter->second;
    }
74 75 76
    LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
77 78
  }

Z
zhangbo9674 已提交
79 80 81 82 83 84 85 86 87 88 89 90 91
  void RegisterAbstractAttribute(ir::TypeId type_id,
                                 AbstractAttribute *abstract_attribute) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
    VLOG(4) << "Register an abstract_attribute of: [TypeId_hash="
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractAttribute_ptr=" << abstract_attribute << "].";
    registed_abstract_attributes_.emplace(type_id, abstract_attribute);
  }

  AbstractAttribute *GetAbstractAttribute(ir::TypeId type_id) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
    auto iter = registed_abstract_attributes_.find(type_id);
    if (iter != registed_abstract_attributes_.end()) {
C
co63oc 已提交
92
      VLOG(4) << "Found a cached abstract_attribute of: [TypeId_hash="
Z
zhangbo9674 已提交
93 94 95 96 97 98 99 100 101
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractAttribute_ptr=" << iter->second << "].";
      return iter->second;
    }
    LOG(WARNING) << "No cache found abstract_attribute of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
  }

102 103 104 105
  bool IsOpInfoRegistered(const std::string &name) {
    return registed_op_infos_.find(name) != registed_op_infos_.end();
  }

106 107 108 109 110 111 112 113 114 115 116
  void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
    VLOG(4) << "Register an operation of: [Name=" << name
            << ", OpInfoImpl ptr=" << opinfo << "].";
    registed_op_infos_.emplace(name, opinfo);
  }

  OpInfoImpl *GetOpInfo(const std::string &name) {
    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
    auto iter = registed_op_infos_.find(name);
    if (iter != registed_op_infos_.end()) {
C
co63oc 已提交
117
      VLOG(4) << "Found a cached operation of: [name=" << name
118 119 120 121 122 123 124
              << ", OpInfoImpl ptr=" << iter->second << "].";
      return iter->second;
    }
    LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
    return nullptr;
  }

125 126 127 128 129 130 131
  void RegisterDialect(std::string name, Dialect *dialect) {
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    VLOG(4) << "Register a dialect of: [name=" << name
            << ", dialect_ptr=" << dialect << "].";
    registed_dialect_.emplace(name, dialect);
  }

132 133 134 135 136
  bool IsDialectRegistered(const std::string &name) {
    return registed_dialect_.find(name) != registed_dialect_.end();
  }

  Dialect *GetDialect(const std::string &name) {
137 138 139
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    auto iter = registed_dialect_.find(name);
    if (iter != registed_dialect_.end()) {
C
co63oc 已提交
140
      VLOG(4) << "Found a cached dialect of: [name=" << name
141 142 143
              << ", dialect_ptr=" << iter->second << "].";
      return iter->second;
    }
C
co63oc 已提交
144
    LOG(WARNING) << "No cache found dialect of: [name=" << name << "].";
145 146
    return nullptr;
  }
147 148 149

  // Cached AbstractType instances.
  std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
150
  ir::SpinLock registed_abstract_types_lock_;
151
  // TypeStorage uniquer and cache instances.
Z
zhangbo9674 已提交
152
  StorageManager registed_type_storage_manager_;
153 154
  // Cache some built-in type objects.
  Float16Type fp16_type;
155
  Float32Type fp32_type;
156 157
  Float64Type fp64_type;
  Int16Type int16_type;
158
  Int32Type int32_type;
159
  Int64Type int64_type;
160

Z
zhangbo9674 已提交
161 162 163 164 165 166
  // Cached AbstractAttribute instances.
  std::unordered_map<TypeId, AbstractAttribute *> registed_abstract_attributes_;
  ir::SpinLock registed_abstract_attributes_lock_;
  // AttributeStorage uniquer and cache instances.
  StorageManager registed_attribute_storage_manager_;

C
co63oc 已提交
167
  // The dialect registered in the context.
Z
zhangbo9674 已提交
168 169 170
  std::unordered_map<std::string, Dialect *> registed_dialect_;
  ir::SpinLock registed_dialect_lock_;

171 172 173 174
  // The Op registered in the context.
  std::unordered_map<std::string, OpInfoImpl *> registed_op_infos_;
  ir::SpinLock registed_op_infos_lock_;

175
  ir::SpinLock destructor_lock_;
176 177 178 179 180 181 182 183
};

// Returns the process-wide IrContext. The function-local static is built on
// the first call (initialization of such statics is thread-safe since C++11)
// and lives until program exit.
IrContext *IrContext::Instance() {
  static IrContext global_context;
  IrContext *instance = &global_context;
  return instance;
}

// Builds the backing IrContextImpl, registers BuiltinDialect, then caches one
// instance of each builtin scalar type so the XxxType::get(ctx) accessors are
// plain field reads.
IrContext::IrContext() : impl_(new IrContextImpl()) {
  // NOTE(review): the builtin dialect is registered before the types are
  // fetched, presumably because TypeManager::get needs the AbstractTypes that
  // BuiltinDialect registers — confirm before reordering.
  VLOG(4) << "BuiltinDialect registered into IrContext. ===>";
  GetOrRegisterDialect<BuiltinDialect>();
  VLOG(4) << "==============================================";

  IrContextImpl &type_cache = *impl_;
  type_cache.fp16_type = TypeManager::get<Float16Type>(this);
  type_cache.fp32_type = TypeManager::get<Float32Type>(this);
  type_cache.fp64_type = TypeManager::get<Float64Type>(this);
  type_cache.int16_type = TypeManager::get<Int16Type>(this);
  type_cache.int32_type = TypeManager::get<Int32Type>(this);
  type_cache.int64_type = TypeManager::get<Int64Type>(this);
}

Z
zhangbo9674 已提交
196 197
// Returns the StorageManager that uniques and caches TypeStorage instances
// for this context.
StorageManager &IrContext::type_storage_manager() {
  auto &ctx_impl = impl();
  return ctx_impl.registed_type_storage_manager_;
}

200 201 202 203 204 205
// Returns the AbstractType registered for `id`, or nullptr when there is
// none. This is a silent probe (no warning on a miss); RegisterAbstractType
// uses it for its existence check. The map is read under
// registed_abstract_types_lock_, the lock IrContextImpl takes when inserting
// into it, so the lookup cannot race with a concurrent registration.
AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
  auto &ctx_impl = impl();
  std::lock_guard<ir::SpinLock> guard(ctx_impl.registed_abstract_types_lock_);
  auto search = ctx_impl.registed_abstract_types_.find(id);
  if (search == ctx_impl.registed_abstract_types_.end()) {
    return nullptr;
  }
  return search->second;
}

Z
zhangbo9674 已提交
208
void IrContext::RegisterAbstractAttribute(
209 210 211 212 213 214 215 216
    ir::TypeId type_id, AbstractAttribute &&abstract_attribute) {
  if (GetRegisteredAbstractAttribute(type_id) == nullptr) {
    impl().RegisterAbstractAttribute(
        type_id, new AbstractAttribute(std::move(abstract_attribute)));
    VLOG(4) << "<--- Attribute registered into IrContext. --->";
  } else {
    LOG(WARNING) << " Attribute already registered.";
  }
Z
zhangbo9674 已提交
217 218 219 220 221 222
}

// Returns the StorageManager that uniques and caches AttributeStorage
// instances for this context.
StorageManager &IrContext::attribute_storage_manager() {
  auto &ctx_impl = impl();
  return ctx_impl.registed_attribute_storage_manager_;
}

223 224 225 226 227 228
// Returns the AbstractAttribute registered for `id`, or nullptr when there is
// none. This is a silent probe (no warning on a miss); RegisterAbstractAttribute
// uses it for its existence check. The map is read under
// registed_abstract_attributes_lock_, the lock IrContextImpl takes when
// inserting into it, so the lookup cannot race with a concurrent registration.
AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) {
  auto &ctx_impl = impl();
  std::lock_guard<ir::SpinLock> guard(
      ctx_impl.registed_abstract_attributes_lock_);
  auto search = ctx_impl.registed_abstract_attributes_.find(id);
  if (search == ctx_impl.registed_abstract_attributes_.end()) {
    return nullptr;
  }
  return search->second;
}

231
// Returns the Dialect registered under `dialect_name`. When none exists yet,
// `constructor` is invoked and its result is registered first.
// NOTE(review): the existence check and the registration are separate calls
// into IrContextImpl. Two threads racing on the same name may therefore both
// invoke `constructor`; emplace keeps only the first registered instance.
Dialect *IrContext::GetOrRegisterDialect(
    const std::string &dialect_name, std::function<Dialect *()> constructor) {
  VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
          << "].";
  const bool needs_registration = !impl().IsDialectRegistered(dialect_name);
  if (needs_registration) {
    VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
            << "].";
    Dialect *fresh_dialect = constructor();
    impl().RegisterDialect(dialect_name, fresh_dialect);
  }
  return impl().GetDialect(dialect_name);
}

std::vector<Dialect *> IrContext::GetRegisteredDialects() {
  std::vector<Dialect *> result;
  for (auto dialect_map : impl().registed_dialect_) {
    result.push_back(dialect_map.second);
  }
  return result;
}

// Returns the Dialect registered under `dialect_name`, or nullptr (with a
// warning) when there is none. Uses the map's hashed find instead of a
// linear scan that copied every entry. The lock is held only for the lookup,
// and the warning is logged after it is released.
Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
  auto &ctx_impl = impl();
  {
    std::lock_guard<ir::SpinLock> guard(ctx_impl.registed_dialect_lock_);
    auto iter = ctx_impl.registed_dialect_.find(dialect_name);
    if (iter != ctx_impl.registed_dialect_.end()) {
      return iter->second;
    }
  }
  LOG(WARNING) << "No dialect registered for " << dialect_name;
  return nullptr;
}

261 262 263 264 265 266 267 268 269
// Registers `abstract_type` under `type_id` (first registration wins).
// A repeated registration for the same id only emits a warning; the moved-in
// value is then discarded.
void IrContext::RegisterAbstractType(ir::TypeId type_id,
                                     AbstractType &&abstract_type) {
  if (GetRegisteredAbstractType(type_id) != nullptr) {
    LOG(WARNING) << " type already registered.";
    return;
  }
  auto *owned_type = new AbstractType(std::move(abstract_type));
  impl().RegisterAbstractType(type_id, owned_type);
  VLOG(4) << "<--- Type registered into IrContext. --->";
}

272 273 274 275 276 277
// Builds an OpInfoImpl for op `name` (owned by `dialect`) and registers it in
// this context. If an op with the same name is already registered, only a
// warning is emitted and nothing is created.
void IrContext::RegisterOpInfo(Dialect *dialect,
                               TypeId op_id,
                               const char *name,
                               std::vector<InterfaceValue> &&interface_map,
                               const std::vector<TypeId> &trait_set,
                               size_t attributes_num,
                               const char **attributes_name,
                               VerifyPtr verify) {
  if (impl().IsOpInfoRegistered(name)) {
    LOG(WARNING) << name << " op already registered.";
    return;
  }
  OpInfoImpl *op_info_impl = OpInfoImpl::create(dialect,
                                                op_id,
                                                name,
                                                std::move(interface_map),
                                                trait_set,
                                                attributes_num,
                                                attributes_name,
                                                verify);
  impl().RegisterOpInfo(name, op_info_impl);
  VLOG(4) << name << " op registered into IrContext. --->";
}

296 297 298 299 300
// Returns the OpInfo registered under `name`. On a miss, IrContextImpl's
// lookup yields nullptr (after logging a warning), and that null pointer is
// what the returned OpInfo wraps.
OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) {
  OpInfoImpl *op_info_impl = impl().GetOpInfo(name);
  return op_info_impl;
}

301
// Returns the AbstractType that `ctx` has registered for `type_id`.
// Throws if the id is unknown to the context.
// NOTE(review): the thrown object is a bare `const char *`, not a
// std::exception, so `catch (const std::exception &)` will not see it.
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
  AbstractType *found = ctx->impl().GetAbstractType(type_id);
  if (found == nullptr) {
    throw("Abstract type not found in IrContext.");
  }
  return *found;
}

Z
zhangbo9674 已提交
311 312 313 314 315 316 317 318 319 320 321
// Returns the AbstractAttribute that `ctx` has registered for `type_id`.
// Throws if the id is unknown to the context.
// NOTE(review): the thrown object is a bare `const char *`, not a
// std::exception, so `catch (const std::exception &)` will not see it.
const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id,
                                                   IrContext *ctx) {
  AbstractAttribute *found = ctx->impl().GetAbstractAttribute(type_id);
  if (found == nullptr) {
    throw("Abstract attribute not found in IrContext.");
  }
  return *found;
}

322 323
// Returns the Float16Type instance cached on `ctx` when it was constructed.
Float16Type Float16Type::get(IrContext *ctx) {
  auto &cache = ctx->impl();
  return cache.fp16_type;
}

324 325
// Returns the Float32Type instance cached on `ctx` when it was constructed.
Float32Type Float32Type::get(IrContext *ctx) {
  auto &cache = ctx->impl();
  return cache.fp32_type;
}

326 327 328 329
// Returns the Float64Type instance cached on `ctx` when it was constructed.
Float64Type Float64Type::get(IrContext *ctx) {
  auto &cache = ctx->impl();
  return cache.fp64_type;
}

// Returns the Int16Type instance cached on `ctx` when it was constructed.
Int16Type Int16Type::get(IrContext *ctx) {
  auto &cache = ctx->impl();
  return cache.int16_type;
}

330 331
// Returns the Int32Type instance cached on `ctx` when it was constructed.
Int32Type Int32Type::get(IrContext *ctx) {
  auto &cache = ctx->impl();
  return cache.int32_type;
}

332 333
// Returns the Int64Type instance cached on `ctx` when it was constructed.
Int64Type Int64Type::get(IrContext *ctx) {
  auto &cache = ctx->impl();
  return cache.int64_type;
}

334
}  // namespace ir