ir_context.cc 12.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/ir/core/ir_context.h"
16

17 18
#include <unordered_map>

19 20 21 22 23 24 25
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/op_info_impl.h"
#include "paddle/ir/core/spin_lock.h"
#include "paddle/ir/core/type_base.h"
26 27

namespace ir {
28
// The implementation class of the IrContext class, cache registered
Z
zhangbo9674 已提交
29
// AbstractType, TypeStorage, AbstractAttribute, AttributeStorage, Dialect.
30 31 32 33 34
class IrContextImpl {
 public:
  IrContextImpl() {}

  ~IrContextImpl() {
35 36
    std::lock_guard<ir::SpinLock> guard(destructor_lock_);
    for (auto &abstract_type_map : registed_abstract_types_) {
37 38 39
      delete abstract_type_map.second;
    }
    registed_abstract_types_.clear();
40

Z
zhangbo9674 已提交
41 42 43 44 45
    for (auto &abstract_attribute_map : registed_abstract_attributes_) {
      delete abstract_attribute_map.second;
    }
    registed_abstract_attributes_.clear();

46 47 48 49
    for (auto &dialect_map : registed_dialect_) {
      delete dialect_map.second;
    }
    registed_dialect_.clear();
50 51

    for (auto &op_map : registed_op_infos_) {
52
      OpInfoImpl::Destroy(op_map.second);
53 54
    }
    registed_op_infos_.clear();
55 56 57 58
  }

  void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
59
    VLOG(4) << "Register an abstract_type of: [TypeId_hash="
60 61 62 63 64
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractType_ptr=" << abstract_type << "].";
    registed_abstract_types_.emplace(type_id, abstract_type);
  }

65
  AbstractType *GetAbstractType(ir::TypeId type_id) {
66 67
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
    auto iter = registed_abstract_types_.find(type_id);
68
    if (iter != registed_abstract_types_.end()) {
C
co63oc 已提交
69
      VLOG(4) << "Found a cached abstract_type of: [TypeId_hash="
70 71 72 73
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractType_ptr=" << iter->second << "].";
      return iter->second;
    }
74 75 76
    LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
77 78
  }

Z
zhangbo9674 已提交
79 80 81 82 83 84 85 86 87 88 89 90 91
  void RegisterAbstractAttribute(ir::TypeId type_id,
                                 AbstractAttribute *abstract_attribute) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
    VLOG(4) << "Register an abstract_attribute of: [TypeId_hash="
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractAttribute_ptr=" << abstract_attribute << "].";
    registed_abstract_attributes_.emplace(type_id, abstract_attribute);
  }

  AbstractAttribute *GetAbstractAttribute(ir::TypeId type_id) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
    auto iter = registed_abstract_attributes_.find(type_id);
    if (iter != registed_abstract_attributes_.end()) {
C
co63oc 已提交
92
      VLOG(4) << "Found a cached abstract_attribute of: [TypeId_hash="
Z
zhangbo9674 已提交
93 94 95 96 97 98 99 100 101
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractAttribute_ptr=" << iter->second << "].";
      return iter->second;
    }
    LOG(WARNING) << "No cache found abstract_attribute of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
  }

102 103 104 105
  bool IsOpInfoRegistered(const std::string &name) {
    return registed_op_infos_.find(name) != registed_op_infos_.end();
  }

106
  void RegisterOpInfo(const std::string &name, OpInfo info) {
107 108
    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
    VLOG(4) << "Register an operation of: [Name=" << name
109 110
            << ", OpInfo ptr=" << info.AsOpaquePointer() << "].";
    registed_op_infos_.emplace(name, info);
111 112
  }

113
  OpInfo GetOpInfo(const std::string &name) {
114 115 116
    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
    auto iter = registed_op_infos_.find(name);
    if (iter != registed_op_infos_.end()) {
117 118
      VLOG(4) << "Found a cached OpInfo of: [name=" << name
              << ", OpInfo: ptr=" << iter->second.AsOpaquePointer() << "].";
119 120 121
      return iter->second;
    }
    LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
122
    return OpInfo();
123
  }
124
  const OpInfoMap &registered_op_info_map() { return registed_op_infos_; }
125

126 127 128 129 130 131 132
  void RegisterDialect(std::string name, Dialect *dialect) {
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    VLOG(4) << "Register a dialect of: [name=" << name
            << ", dialect_ptr=" << dialect << "].";
    registed_dialect_.emplace(name, dialect);
  }

133 134 135 136 137
  bool IsDialectRegistered(const std::string &name) {
    return registed_dialect_.find(name) != registed_dialect_.end();
  }

  Dialect *GetDialect(const std::string &name) {
138 139 140
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    auto iter = registed_dialect_.find(name);
    if (iter != registed_dialect_.end()) {
C
co63oc 已提交
141
      VLOG(4) << "Found a cached dialect of: [name=" << name
142 143 144
              << ", dialect_ptr=" << iter->second << "].";
      return iter->second;
    }
C
co63oc 已提交
145
    LOG(WARNING) << "No cache found dialect of: [name=" << name << "].";
146 147
    return nullptr;
  }
148 149 150

  // Cached AbstractType instances.
  std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
151
  ir::SpinLock registed_abstract_types_lock_;
152
  // TypeStorage uniquer and cache instances.
Z
zhangbo9674 已提交
153
  StorageManager registed_type_storage_manager_;
154
  // Cache some built-in type objects.
155
  BFloat16Type bfp16_type;
156
  Float16Type fp16_type;
157
  Float32Type fp32_type;
158 159
  Float64Type fp64_type;
  Int16Type int16_type;
160
  Int32Type int32_type;
161
  Int64Type int64_type;
162

Z
zhangbo9674 已提交
163 164 165 166 167 168
  // Cached AbstractAttribute instances.
  std::unordered_map<TypeId, AbstractAttribute *> registed_abstract_attributes_;
  ir::SpinLock registed_abstract_attributes_lock_;
  // AttributeStorage uniquer and cache instances.
  StorageManager registed_attribute_storage_manager_;

C
co63oc 已提交
169
  // The dialect registered in the context.
Z
zhangbo9674 已提交
170 171 172
  std::unordered_map<std::string, Dialect *> registed_dialect_;
  ir::SpinLock registed_dialect_lock_;

173
  // The Op registered in the context.
174
  OpInfoMap registed_op_infos_;
175 176
  ir::SpinLock registed_op_infos_lock_;

177
  ir::SpinLock destructor_lock_;
178 179 180 181 182 183 184
};

// Returns the process-wide IrContext. The function-local static is built on
// the first call (initialization is thread-safe per C++11) and lives until
// program exit.
IrContext *IrContext::Instance() {
  static IrContext global_context;
  return &global_context;
}

// Releases the implementation object; its destructor frees the registered
// abstract types, abstract attributes, dialects and op infos.
IrContext::~IrContext() {
  delete impl_;
}

// Creates the implementation object, registers the builtin dialect, and then
// caches the builtin scalar types so that `XxxType::get(ctx)` can return them
// without another TypeManager lookup.
IrContext::IrContext() : impl_(new IrContextImpl()) {
  VLOG(4) << "BuiltinDialect registered into IrContext. ===>";
  GetOrRegisterDialect<BuiltinDialect>();
  VLOG(4) << "==============================================";

  // NOTE(review): presumably the builtin dialect registers these types, which
  // is why the dialect is registered first — confirm before reordering.
  IrContextImpl &cache = *impl_;
  cache.bfp16_type = TypeManager::get<BFloat16Type>(this);
  cache.fp16_type = TypeManager::get<Float16Type>(this);
  cache.fp32_type = TypeManager::get<Float32Type>(this);
  cache.fp64_type = TypeManager::get<Float64Type>(this);
  cache.int16_type = TypeManager::get<Int16Type>(this);
  cache.int32_type = TypeManager::get<Int32Type>(this);
  cache.int64_type = TypeManager::get<Int64Type>(this);
}

// Returns the StorageManager that uniques and caches TypeStorage instances.
StorageManager &IrContext::type_storage_manager() {
  auto &context_impl = impl();
  return context_impl.registed_type_storage_manager_;
}

// Returns the AbstractType registered for `id`, or nullptr if there is none.
// Unlike IrContextImpl::GetAbstractType, a miss is not logged as a warning.
AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
  // Take the same lock as IrContextImpl::RegisterAbstractType so a concurrent
  // registration cannot mutate the map during the lookup.
  std::lock_guard<ir::SpinLock> guard(impl().registed_abstract_types_lock_);
  auto search = impl().registed_abstract_types_.find(id);
  if (search != impl().registed_abstract_types_.end()) {
    return search->second;
  }
  return nullptr;
}

void IrContext::RegisterAbstractAttribute(
214 215 216 217 218 219 220 221
    ir::TypeId type_id, AbstractAttribute &&abstract_attribute) {
  if (GetRegisteredAbstractAttribute(type_id) == nullptr) {
    impl().RegisterAbstractAttribute(
        type_id, new AbstractAttribute(std::move(abstract_attribute)));
    VLOG(4) << "<--- Attribute registered into IrContext. --->";
  } else {
    LOG(WARNING) << " Attribute already registered.";
  }
Z
zhangbo9674 已提交
222 223 224 225 226 227
}

// Returns the StorageManager that uniques and caches AttributeStorage
// instances.
StorageManager &IrContext::attribute_storage_manager() {
  auto &context_impl = impl();
  return context_impl.registed_attribute_storage_manager_;
}

// Returns the AbstractAttribute registered for `id`, or nullptr if there is
// none. Unlike IrContextImpl::GetAbstractAttribute, a miss is not logged.
AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) {
  // Take the same lock as IrContextImpl::RegisterAbstractAttribute so a
  // concurrent registration cannot mutate the map during the lookup.
  std::lock_guard<ir::SpinLock> guard(
      impl().registed_abstract_attributes_lock_);
  auto search = impl().registed_abstract_attributes_.find(id);
  if (search != impl().registed_abstract_attributes_.end()) {
    return search->second;
  }
  return nullptr;
}

// Returns the dialect named `dialect_name`. If none is registered yet,
// `constructor` is invoked to build it and the result is registered first.
// NOTE(review): the existence check and the registration are separate locked
// steps, so two threads may both run `constructor`; the first registration
// wins.
Dialect *IrContext::GetOrRegisterDialect(
    const std::string &dialect_name, std::function<Dialect *()> constructor) {
  VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
          << "].";
  const bool already_registered = impl().IsDialectRegistered(dialect_name);
  if (!already_registered) {
    VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
            << "].";
    Dialect *fresh_dialect = constructor();
    impl().RegisterDialect(dialect_name, fresh_dialect);
  }
  return impl().GetDialect(dialect_name);
}

std::vector<Dialect *> IrContext::GetRegisteredDialects() {
  std::vector<Dialect *> result;
  for (auto dialect_map : impl().registed_dialect_) {
    result.push_back(dialect_map.second);
  }
  return result;
}

// Returns the dialect registered under `dialect_name`, or nullptr (after
// logging a warning) when there is none.
Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
  {
    // Hash lookup under the dialect lock replaces the former unlocked O(n)
    // scan that also copied every (name, dialect) pair.
    std::lock_guard<ir::SpinLock> guard(impl().registed_dialect_lock_);
    auto iter = impl().registed_dialect_.find(dialect_name);
    if (iter != impl().registed_dialect_.end()) {
      return iter->second;
    }
  }
  LOG(WARNING) << "No dialect registered for " << dialect_name;
  return nullptr;
}

// Registers `abstract_type` under `type_id` by moving it into a heap object
// that the context then owns. A repeated registration for the same id is
// ignored with a warning.
void IrContext::RegisterAbstractType(ir::TypeId type_id,
                                     AbstractType &&abstract_type) {
  const bool already_known = GetRegisteredAbstractType(type_id) != nullptr;
  if (already_known) {
    LOG(WARNING) << " type already registered.";
    return;
  }
  impl().RegisterAbstractType(type_id,
                              new AbstractType(std::move(abstract_type)));
  VLOG(4) << "<--- Type registered into IrContext. --->";
}

// Builds an OpInfo for op `name` of `dialect` and registers it in the
// context. Registration is first-wins: a second registration under the same
// name is skipped with a warning and nothing is created.
void IrContext::RegisterOpInfo(Dialect *dialect,
                               TypeId op_id,
                               const char *name,
                               std::vector<InterfaceValue> &&interface_map,
                               const std::vector<TypeId> &trait_set,
                               size_t attributes_num,
                               const char **attributes_name,
                               VerifyPtr verify) {
  if (impl().IsOpInfoRegistered(name)) {
    LOG(WARNING) << name << " op already registered.";
    return;
  }

  OpInfo op_info = OpInfoImpl::Create(dialect,
                                      op_id,
                                      name,
                                      std::move(interface_map),
                                      trait_set,
                                      attributes_num,
                                      attributes_name,
                                      verify);
  impl().RegisterOpInfo(name, op_info);
  VLOG(4) << name << " op registered into IrContext. --->";
}

// Looks up the OpInfo registered under `name`. The implementation logs a
// warning and yields a default-constructed OpInfo when the name is unknown.
OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) {
  auto &context_impl = impl();
  return context_impl.GetOpInfo(name);
}

// Read-only view of every registered op, keyed by op name. The reference is
// handed out without holding the op-map lock.
const OpInfoMap &IrContext::registered_op_info_map() {
  auto &context_impl = impl();
  return context_impl.registered_op_info_map();
}

// Fetches the AbstractType registered in `ctx` for `type_id`; raises through
// IR_ENFORCE when it has not been registered.
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
  auto *abstract_type = ctx->impl().GetAbstractType(type_id);
  IR_ENFORCE(abstract_type, "Abstract type not found in IrContext.");
  return *abstract_type;
}

// Fetches the AbstractAttribute registered in `ctx` for `type_id`; raises
// through IR_ENFORCE when it has not been registered.
const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id,
                                                   IrContext *ctx) {
  auto *abstract_attribute = ctx->impl().GetAbstractAttribute(type_id);
  IR_ENFORCE(abstract_attribute, "Abstract attribute not found in IrContext.");
  return *abstract_attribute;
}

BFloat16Type BFloat16Type::get(IrContext *ctx) {
  return ctx->impl().bfp16_type;
}

327 328
Float16Type Float16Type::get(IrContext *ctx) { return ctx->impl().fp16_type; }

329 330
Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; }

331 332 333 334
Float64Type Float64Type::get(IrContext *ctx) { return ctx->impl().fp64_type; }

Int16Type Int16Type::get(IrContext *ctx) { return ctx->impl().int16_type; }

335 336
Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }

337 338
Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; }

339
}  // namespace ir