ir_context.cc 12.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/core/ir_context.h"

#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/op_info_impl.h"
#include "paddle/ir/core/spin_lock.h"
#include "paddle/ir/core/type_base.h"

namespace ir {
28
// The implementation class of the IrContext class, cache registered
Z
zhangbo9674 已提交
29
// AbstractType, TypeStorage, AbstractAttribute, AttributeStorage, Dialect.
30 31 32 33 34
class IrContextImpl {
 public:
  IrContextImpl() {}

  ~IrContextImpl() {
35 36
    std::lock_guard<ir::SpinLock> guard(destructor_lock_);
    for (auto &abstract_type_map : registed_abstract_types_) {
37 38 39
      delete abstract_type_map.second;
    }
    registed_abstract_types_.clear();
40

Z
zhangbo9674 已提交
41 42 43 44 45
    for (auto &abstract_attribute_map : registed_abstract_attributes_) {
      delete abstract_attribute_map.second;
    }
    registed_abstract_attributes_.clear();

46 47 48 49
    for (auto &dialect_map : registed_dialect_) {
      delete dialect_map.second;
    }
    registed_dialect_.clear();
50 51 52 53 54

    for (auto &op_map : registed_op_infos_) {
      op_map.second->destroy();
    }
    registed_op_infos_.clear();
55 56 57 58
  }

  void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
59
    VLOG(4) << "Register an abstract_type of: [TypeId_hash="
60 61 62 63 64
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractType_ptr=" << abstract_type << "].";
    registed_abstract_types_.emplace(type_id, abstract_type);
  }

65
  AbstractType *GetAbstractType(ir::TypeId type_id) {
66 67
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
    auto iter = registed_abstract_types_.find(type_id);
68
    if (iter != registed_abstract_types_.end()) {
C
co63oc 已提交
69
      VLOG(4) << "Found a cached abstract_type of: [TypeId_hash="
70 71 72 73
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractType_ptr=" << iter->second << "].";
      return iter->second;
    }
74 75 76
    LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
77 78
  }

Z
zhangbo9674 已提交
79 80 81 82 83 84 85 86 87 88 89 90 91
  void RegisterAbstractAttribute(ir::TypeId type_id,
                                 AbstractAttribute *abstract_attribute) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
    VLOG(4) << "Register an abstract_attribute of: [TypeId_hash="
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractAttribute_ptr=" << abstract_attribute << "].";
    registed_abstract_attributes_.emplace(type_id, abstract_attribute);
  }

  AbstractAttribute *GetAbstractAttribute(ir::TypeId type_id) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
    auto iter = registed_abstract_attributes_.find(type_id);
    if (iter != registed_abstract_attributes_.end()) {
C
co63oc 已提交
92
      VLOG(4) << "Found a cached abstract_attribute of: [TypeId_hash="
Z
zhangbo9674 已提交
93 94 95 96 97 98 99 100 101
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractAttribute_ptr=" << iter->second << "].";
      return iter->second;
    }
    LOG(WARNING) << "No cache found abstract_attribute of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
  }

102 103 104 105
  bool IsOpInfoRegistered(const std::string &name) {
    return registed_op_infos_.find(name) != registed_op_infos_.end();
  }

106 107 108 109 110 111 112 113 114 115 116
  void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
    VLOG(4) << "Register an operation of: [Name=" << name
            << ", OpInfoImpl ptr=" << opinfo << "].";
    registed_op_infos_.emplace(name, opinfo);
  }

  OpInfoImpl *GetOpInfo(const std::string &name) {
    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
    auto iter = registed_op_infos_.find(name);
    if (iter != registed_op_infos_.end()) {
C
co63oc 已提交
117
      VLOG(4) << "Found a cached operation of: [name=" << name
118 119 120 121 122 123 124
              << ", OpInfoImpl ptr=" << iter->second << "].";
      return iter->second;
    }
    LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
    return nullptr;
  }

125 126 127 128 129 130 131
  void RegisterDialect(std::string name, Dialect *dialect) {
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    VLOG(4) << "Register a dialect of: [name=" << name
            << ", dialect_ptr=" << dialect << "].";
    registed_dialect_.emplace(name, dialect);
  }

132 133 134 135 136
  bool IsDialectRegistered(const std::string &name) {
    return registed_dialect_.find(name) != registed_dialect_.end();
  }

  Dialect *GetDialect(const std::string &name) {
137 138 139
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    auto iter = registed_dialect_.find(name);
    if (iter != registed_dialect_.end()) {
C
co63oc 已提交
140
      VLOG(4) << "Found a cached dialect of: [name=" << name
141 142 143
              << ", dialect_ptr=" << iter->second << "].";
      return iter->second;
    }
C
co63oc 已提交
144
    LOG(WARNING) << "No cache found dialect of: [name=" << name << "].";
145 146
    return nullptr;
  }
147 148 149

  // Cached AbstractType instances.
  std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
150
  ir::SpinLock registed_abstract_types_lock_;
151
  // TypeStorage uniquer and cache instances.
Z
zhangbo9674 已提交
152
  StorageManager registed_type_storage_manager_;
153
  // Cache some built-in type objects.
154
  BFloat16Type bfp16_type;
155
  Float16Type fp16_type;
156
  Float32Type fp32_type;
157 158
  Float64Type fp64_type;
  Int16Type int16_type;
159
  Int32Type int32_type;
160
  Int64Type int64_type;
161

Z
zhangbo9674 已提交
162 163 164 165 166 167
  // Cached AbstractAttribute instances.
  std::unordered_map<TypeId, AbstractAttribute *> registed_abstract_attributes_;
  ir::SpinLock registed_abstract_attributes_lock_;
  // AttributeStorage uniquer and cache instances.
  StorageManager registed_attribute_storage_manager_;

C
co63oc 已提交
168
  // The dialect registered in the context.
Z
zhangbo9674 已提交
169 170 171
  std::unordered_map<std::string, Dialect *> registed_dialect_;
  ir::SpinLock registed_dialect_lock_;

172 173 174 175
  // The Op registered in the context.
  std::unordered_map<std::string, OpInfoImpl *> registed_op_infos_;
  ir::SpinLock registed_op_infos_lock_;

176
  ir::SpinLock destructor_lock_;
177 178 179 180 181 182 183 184
};

// Returns the process-wide IrContext. It is a function-local static, so it
// is constructed on the first call.
IrContext *IrContext::Instance() {
  static IrContext global_context;
  IrContext *instance = &global_context;
  return instance;
}

// Builds a context: the BuiltinDialect is registered first, and then the
// built-in scalar types are fetched once and cached in the impl. This lets
// the XxxType::get(ctx) accessors return them directly.
IrContext::IrContext() : impl_(new IrContextImpl()) {
  VLOG(4) << "BuiltinDialect registered into IrContext. ===>";
  GetOrRegisterDialect<BuiltinDialect>();
  VLOG(4) << "==============================================";

  IrContextImpl &type_cache = *impl_;
  type_cache.bfp16_type = TypeManager::get<BFloat16Type>(this);
  type_cache.fp16_type = TypeManager::get<Float16Type>(this);
  type_cache.fp32_type = TypeManager::get<Float32Type>(this);
  type_cache.fp64_type = TypeManager::get<Float64Type>(this);
  type_cache.int16_type = TypeManager::get<Int16Type>(this);
  type_cache.int32_type = TypeManager::get<Int32Type>(this);
  type_cache.int64_type = TypeManager::get<Int64Type>(this);
}

Z
zhangbo9674 已提交
198 199
StorageManager &IrContext::type_storage_manager() {
  return impl().registed_type_storage_manager_;
200 201
}

202 203 204 205 206 207
AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
  auto search = impl().registed_abstract_types_.find(id);
  if (search != impl().registed_abstract_types_.end()) {
    return search->second;
  }
  return nullptr;
208 209
}

Z
zhangbo9674 已提交
210
void IrContext::RegisterAbstractAttribute(
211 212 213 214 215 216 217 218
    ir::TypeId type_id, AbstractAttribute &&abstract_attribute) {
  if (GetRegisteredAbstractAttribute(type_id) == nullptr) {
    impl().RegisterAbstractAttribute(
        type_id, new AbstractAttribute(std::move(abstract_attribute)));
    VLOG(4) << "<--- Attribute registered into IrContext. --->";
  } else {
    LOG(WARNING) << " Attribute already registered.";
  }
Z
zhangbo9674 已提交
219 220 221 222 223 224
}

// Returns the StorageManager that uniques and caches AttributeStorage
// instances.
StorageManager &IrContext::attribute_storage_manager() {
  IrContextImpl &ctx_impl = impl();
  return ctx_impl.registed_attribute_storage_manager_;
}

225 226 227 228 229 230
AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) {
  auto search = impl().registed_abstract_attributes_.find(id);
  if (search != impl().registed_abstract_attributes_.end()) {
    return search->second;
  }
  return nullptr;
Z
zhangbo9674 已提交
231 232
}

233
Dialect *IrContext::GetOrRegisterDialect(
234
    const std::string &dialect_name, std::function<Dialect *()> constructor) {
235
  VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
236
          << "].";
237
  if (!impl().IsDialectRegistered(dialect_name)) {
238 239
    VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
            << "].";
240
    impl().RegisterDialect(dialect_name, constructor());
241
  }
242
  return impl().GetDialect(dialect_name);
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
}

std::vector<Dialect *> IrContext::GetRegisteredDialects() {
  std::vector<Dialect *> result;
  for (auto dialect_map : impl().registed_dialect_) {
    result.push_back(dialect_map.second);
  }
  return result;
}

// Returns the dialect registered as `dialect_name`, or nullptr (with a
// warning) if there is none. Uses the map's hashed lookup instead of scanning
// every entry, and holds the dialect lock only for the lookup itself.
Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
  IrContextImpl &ctx_impl = impl();
  {
    std::lock_guard<ir::SpinLock> guard(ctx_impl.registed_dialect_lock_);
    auto iter = ctx_impl.registed_dialect_.find(dialect_name);
    if (iter != ctx_impl.registed_dialect_.end()) {
      return iter->second;
    }
  }
  LOG(WARNING) << "No dialect registered for " << dialect_name;
  return nullptr;
}

263 264 265 266 267 268 269 270 271
void IrContext::RegisterAbstractType(ir::TypeId type_id,
                                     AbstractType &&abstract_type) {
  if (GetRegisteredAbstractType(type_id) == nullptr) {
    impl().RegisterAbstractType(type_id,
                                new AbstractType(std::move(abstract_type)));
    VLOG(4) << "<--- Type registered into IrContext. --->";
  } else {
    LOG(WARNING) << " type already registered.";
  }
272 273
}

274 275 276 277 278 279
void IrContext::RegisterOpInfo(Dialect *dialect,
                               TypeId op_id,
                               const char *name,
                               std::vector<InterfaceValue> &&interface_map,
                               const std::vector<TypeId> &trait_set,
                               size_t attributes_num,
280 281
                               const char **attributes_name,
                               VerifyPtr verify) {
282 283 284
  if (impl().IsOpInfoRegistered(name)) {
    LOG(WARNING) << name << " op already registered.";
  } else {
285 286 287 288 289 290
    OpInfoImpl *opinfo = OpInfoImpl::create(dialect,
                                            op_id,
                                            name,
                                            std::move(interface_map),
                                            trait_set,
                                            attributes_num,
291 292
                                            attributes_name,
                                            verify);
293
    impl().RegisterOpInfo(name, opinfo);
294
    VLOG(4) << name << " op registered into IrContext. --->";
295 296 297
  }
}

298 299 300 301 302
OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) {
  OpInfoImpl *rtn = impl().GetOpInfo(name);
  return rtn ? rtn : nullptr;
}

303
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
304
  auto &impl = ctx->impl();
305
  AbstractType *abstract_type = impl.GetAbstractType(type_id);
306 307 308 309 310 311 312
  if (abstract_type) {
    return *abstract_type;
  } else {
    throw("Abstract type not found in IrContext.");
  }
}

Z
zhangbo9674 已提交
313 314 315 316 317 318 319 320 321 322 323
const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id,
                                                   IrContext *ctx) {
  auto &impl = ctx->impl();
  AbstractAttribute *abstract_attribute = impl.GetAbstractAttribute(type_id);
  if (abstract_attribute) {
    return *abstract_attribute;
  } else {
    throw("Abstract attribute not found in IrContext.");
  }
}

324 325 326 327
BFloat16Type BFloat16Type::get(IrContext *ctx) {
  return ctx->impl().bfp16_type;
}

328 329
Float16Type Float16Type::get(IrContext *ctx) { return ctx->impl().fp16_type; }

330 331
Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; }

332 333 334 335
Float64Type Float64Type::get(IrContext *ctx) { return ctx->impl().fp64_type; }

Int16Type Int16Type::get(IrContext *ctx) { return ctx->impl().int16_type; }

336 337
Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }

338 339
Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; }

340
}  // namespace ir