// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/core/ir_context.h"

#include <unordered_map>

#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/op_info_impl.h"
#include "paddle/ir/core/spin_lock.h"
#include "paddle/ir/core/type_base.h"

namespace ir {
// The implementation class of the IrContext class, cache registered
Z
zhangbo9674 已提交
29
// AbstractType, TypeStorage, AbstractAttribute, AttributeStorage, Dialect.
30 31
class IrContextImpl {
 public:
32
  IrContextImpl() = default;
33 34

  ~IrContextImpl() {
35 36
    std::lock_guard<ir::SpinLock> guard(destructor_lock_);
    for (auto &abstract_type_map : registed_abstract_types_) {
37 38 39
      delete abstract_type_map.second;
    }
    registed_abstract_types_.clear();
40

Z
zhangbo9674 已提交
41 42 43 44 45
    for (auto &abstract_attribute_map : registed_abstract_attributes_) {
      delete abstract_attribute_map.second;
    }
    registed_abstract_attributes_.clear();

46 47 48 49
    for (auto &dialect_map : registed_dialect_) {
      delete dialect_map.second;
    }
    registed_dialect_.clear();
50 51

    for (auto &op_map : registed_op_infos_) {
52
      OpInfoImpl::Destroy(op_map.second);
53 54
    }
    registed_op_infos_.clear();
55 56 57 58
  }

  void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
59
    VLOG(6) << "Register an abstract_type of: [TypeId_hash="
60 61 62 63 64
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractType_ptr=" << abstract_type << "].";
    registed_abstract_types_.emplace(type_id, abstract_type);
  }

65
  AbstractType *GetAbstractType(ir::TypeId type_id) {
66 67
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
    auto iter = registed_abstract_types_.find(type_id);
68
    if (iter != registed_abstract_types_.end()) {
69
      VLOG(6) << "Found a cached abstract_type of: [TypeId_hash="
70 71 72 73
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractType_ptr=" << iter->second << "].";
      return iter->second;
    }
74 75 76
    LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
77 78
  }

Z
zhangbo9674 已提交
79 80 81
  void RegisterAbstractAttribute(ir::TypeId type_id,
                                 AbstractAttribute *abstract_attribute) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
82
    VLOG(6) << "Register an abstract_attribute of: [TypeId_hash="
Z
zhangbo9674 已提交
83 84 85 86 87 88 89 90 91
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractAttribute_ptr=" << abstract_attribute << "].";
    registed_abstract_attributes_.emplace(type_id, abstract_attribute);
  }

  AbstractAttribute *GetAbstractAttribute(ir::TypeId type_id) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
    auto iter = registed_abstract_attributes_.find(type_id);
    if (iter != registed_abstract_attributes_.end()) {
C
co63oc 已提交
92
      VLOG(4) << "Found a cached abstract_attribute of: [TypeId_hash="
Z
zhangbo9674 已提交
93 94 95 96 97 98 99 100 101
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractAttribute_ptr=" << iter->second << "].";
      return iter->second;
    }
    LOG(WARNING) << "No cache found abstract_attribute of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
  }

102 103 104 105
  bool IsOpInfoRegistered(const std::string &name) {
    return registed_op_infos_.find(name) != registed_op_infos_.end();
  }

106
  void RegisterOpInfo(const std::string &name, OpInfo info) {
107
    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
108
    VLOG(6) << "Register an operation of: [Name=" << name
109 110
            << ", OpInfo ptr=" << info.AsOpaquePointer() << "].";
    registed_op_infos_.emplace(name, info);
111 112
  }

113
  OpInfo GetOpInfo(const std::string &name) {
114 115 116
    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
    auto iter = registed_op_infos_.find(name);
    if (iter != registed_op_infos_.end()) {
117
      VLOG(8) << "Found a cached OpInfo of: [name=" << name
118
              << ", OpInfo: ptr=" << iter->second.AsOpaquePointer() << "].";
119 120 121
      return iter->second;
    }
    LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
122
    return OpInfo();
123
  }
124
  const OpInfoMap &registered_op_info_map() { return registed_op_infos_; }
125

126 127
  void RegisterDialect(std::string name, Dialect *dialect) {
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
128
    VLOG(6) << "Register a dialect of: [name=" << name
129 130 131 132
            << ", dialect_ptr=" << dialect << "].";
    registed_dialect_.emplace(name, dialect);
  }

133 134 135 136 137
  bool IsDialectRegistered(const std::string &name) {
    return registed_dialect_.find(name) != registed_dialect_.end();
  }

  Dialect *GetDialect(const std::string &name) {
138 139 140
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    auto iter = registed_dialect_.find(name);
    if (iter != registed_dialect_.end()) {
141
      VLOG(6) << "Found a cached dialect of: [name=" << name
142 143 144
              << ", dialect_ptr=" << iter->second << "].";
      return iter->second;
    }
C
co63oc 已提交
145
    LOG(WARNING) << "No cache found dialect of: [name=" << name << "].";
146 147
    return nullptr;
  }
148 149 150

  // Cached AbstractType instances.
  std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
151
  ir::SpinLock registed_abstract_types_lock_;
152
  // TypeStorage uniquer and cache instances.
Z
zhangbo9674 已提交
153
  StorageManager registed_type_storage_manager_;
154
  // Cache some built-in type objects.
155
  BFloat16Type bfp16_type;
156
  Float16Type fp16_type;
157
  Float32Type fp32_type;
158
  Float64Type fp64_type;
K
kangguangli 已提交
159 160
  UInt8Type uint8_type;
  Int8Type int8_type;
161
  Int16Type int16_type;
162
  Int32Type int32_type;
163
  Int64Type int64_type;
K
kangguangli 已提交
164 165 166
  BoolType bool_type;
  Complex64Type complex64_type;
  Complex128Type complex128_type;
167

Z
zhangbo9674 已提交
168 169 170 171 172 173
  // Cached AbstractAttribute instances.
  std::unordered_map<TypeId, AbstractAttribute *> registed_abstract_attributes_;
  ir::SpinLock registed_abstract_attributes_lock_;
  // AttributeStorage uniquer and cache instances.
  StorageManager registed_attribute_storage_manager_;

C
co63oc 已提交
174
  // The dialect registered in the context.
Z
zhangbo9674 已提交
175 176 177
  std::unordered_map<std::string, Dialect *> registed_dialect_;
  ir::SpinLock registed_dialect_lock_;

178
  // The Op registered in the context.
179
  OpInfoMap registed_op_infos_;
180 181
  ir::SpinLock registed_op_infos_lock_;

182
  ir::SpinLock destructor_lock_;
183 184 185 186 187 188 189
};

// Returns the process-wide IrContext. The function-local static is built on
// the first call (initialization is thread-safe since C++11) and destroyed at
// program exit.
IrContext *IrContext::Instance() {
  static IrContext global_context;
  return &global_context;
}

// Destroys the implementation object that the constructor allocated; its
// destructor frees all registered types, attributes, dialects and ops.
IrContext::~IrContext() {
  delete impl_;
}

// Creates the implementation object and registers the builtin dialect. It
// then caches one instance of every builtin scalar type, so that the
// `XxxType::get(IrContext *)` accessors in this file are plain member reads.
IrContext::IrContext() : impl_(new IrContextImpl()) {
  VLOG(4) << "BuiltinDialect registered into IrContext. ===>";
  GetOrRegisterDialect<BuiltinDialect>();
  VLOG(4) << "==============================================";

  // The type instances below are fetched only after the builtin dialect has
  // been registered above.
  IrContextImpl &cache = *impl_;
  cache.bfp16_type = TypeManager::get<BFloat16Type>(this);
  cache.fp16_type = TypeManager::get<Float16Type>(this);
  cache.fp32_type = TypeManager::get<Float32Type>(this);
  cache.fp64_type = TypeManager::get<Float64Type>(this);
  cache.uint8_type = TypeManager::get<UInt8Type>(this);
  cache.int8_type = TypeManager::get<Int8Type>(this);
  cache.int16_type = TypeManager::get<Int16Type>(this);
  cache.int32_type = TypeManager::get<Int32Type>(this);
  cache.int64_type = TypeManager::get<Int64Type>(this);
  cache.bool_type = TypeManager::get<BoolType>(this);
  cache.complex64_type = TypeManager::get<Complex64Type>(this);
  cache.complex128_type = TypeManager::get<Complex128Type>(this);
}

// Returns the uniquer/cache of TypeStorage instances owned by this context.
StorageManager &IrContext::type_storage_manager() {
  auto &context_impl = impl();
  return context_impl.registed_type_storage_manager_;
}

// Returns the AbstractType registered for `id`, or nullptr if there is none.
// Unlike IrContextImpl::GetAbstractType, a miss is silent; this matters
// because RegisterAbstractType calls this function expecting a miss.
AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
  // Take the same lock IrContextImpl::RegisterAbstractType holds while it
  // inserts; reading the map without it is a data race.
  std::lock_guard<ir::SpinLock> guard(impl().registed_abstract_types_lock_);
  auto &types = impl().registed_abstract_types_;
  auto search = types.find(id);
  return search != types.end() ? search->second : nullptr;
}

void IrContext::RegisterAbstractAttribute(
224 225 226 227 228 229 230
    ir::TypeId type_id, AbstractAttribute &&abstract_attribute) {
  if (GetRegisteredAbstractAttribute(type_id) == nullptr) {
    impl().RegisterAbstractAttribute(
        type_id, new AbstractAttribute(std::move(abstract_attribute)));
  } else {
    LOG(WARNING) << " Attribute already registered.";
  }
Z
zhangbo9674 已提交
231 232 233 234 235 236
}

// Returns the uniquer/cache of AttributeStorage instances owned by this
// context.
StorageManager &IrContext::attribute_storage_manager() {
  auto &context_impl = impl();
  return context_impl.registed_attribute_storage_manager_;
}

// Returns the AbstractAttribute registered for `id`, or nullptr if there is
// none. Unlike IrContextImpl::GetAbstractAttribute, a miss is silent;
// RegisterAbstractAttribute calls this function expecting a miss.
AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) {
  // Take the same lock IrContextImpl::RegisterAbstractAttribute holds while
  // it inserts; reading the map without it is a data race.
  std::lock_guard<ir::SpinLock> guard(
      impl().registed_abstract_attributes_lock_);
  auto &attributes = impl().registed_abstract_attributes_;
  auto search = attributes.find(id);
  return search != attributes.end() ? search->second : nullptr;
}

// Returns the dialect registered under `dialect_name`. If there is none yet,
// `constructor` is invoked once to create it and the result is registered
// (and owned) by the context first.
Dialect *IrContext::GetOrRegisterDialect(
    const std::string &dialect_name, std::function<Dialect *()> constructor) {
  VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
          << "].";
  const bool already_registered = impl().IsDialectRegistered(dialect_name);
  if (!already_registered) {
    VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
            << "].";
    Dialect *created = constructor();
    impl().RegisterDialect(dialect_name, created);
  }
  return impl().GetDialect(dialect_name);
}

std::vector<Dialect *> IrContext::GetRegisteredDialects() {
  std::vector<Dialect *> result;
  for (auto dialect_map : impl().registed_dialect_) {
    result.push_back(dialect_map.second);
  }
  return result;
}

// Returns the dialect registered under `dialect_name`, or nullptr (after
// logging a warning) if there is none.
Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
  {
    // Hash lookup under the dialect lock; the previous linear scan was O(n)
    // and copied every key/value pair.
    std::lock_guard<ir::SpinLock> guard(impl().registed_dialect_lock_);
    const auto &dialects = impl().registed_dialect_;
    auto iter = dialects.find(dialect_name);
    if (iter != dialects.end()) {
      return iter->second;
    }
  }
  // Log outside the critical section.
  LOG(WARNING) << "No dialect registered for " << dialect_name;
  return nullptr;
}

// Registers `abstract_type` for `type_id` by moving it into a heap copy that
// the context owns. The first registration for an id wins; later attempts
// only log a warning.
void IrContext::RegisterAbstractType(ir::TypeId type_id,
                                     AbstractType &&abstract_type) {
  if (GetRegisteredAbstractType(type_id) != nullptr) {
    LOG(WARNING) << " type already registered.";
    return;
  }
  impl().RegisterAbstractType(type_id,
                              new AbstractType(std::move(abstract_type)));
}

// Builds an OpInfo through OpInfoImpl::Create and registers it under `name`.
// The first registration for a name wins; later attempts only log a warning
// and do not call OpInfoImpl::Create.
void IrContext::RegisterOpInfo(Dialect *dialect,
                               TypeId op_id,
                               const char *name,
                               std::vector<InterfaceValue> &&interface_map,
                               const std::vector<TypeId> &trait_set,
                               size_t attributes_num,
                               const char **attributes_name,
                               VerifyPtr verify) {
  if (!impl().IsOpInfoRegistered(name)) {
    OpInfo info = OpInfoImpl::Create(dialect,
                                     op_id,
                                     name,
                                     std::move(interface_map),
                                     trait_set,
                                     attributes_num,
                                     attributes_name,
                                     verify);
    impl().RegisterOpInfo(name, info);
    return;
  }
  LOG(WARNING) << name << " op already registered.";
}

// Looks up the OpInfo registered under `name`. A default-constructed OpInfo
// is returned (and a warning logged by the impl) when there is none.
OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) {
  OpInfo found = impl().GetOpInfo(name);
  return found;
}

// Read-only access to every registered operation, keyed by op name. The
// reference is handed out without holding the op registry lock.
const OpInfoMap &IrContext::registered_op_info_map() {
  auto &context_impl = impl();
  return context_impl.registered_op_info_map();
}

// Returns the AbstractType registered in `ctx` for `type_id`; IR_ENFORCE
// fails when no such type has been registered.
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
  AbstractType *found = ctx->impl().GetAbstractType(type_id);
  IR_ENFORCE(found, "Abstract type not found in IrContext.");
  return *found;
}

// Returns the AbstractAttribute registered in `ctx` for `type_id`; IR_ENFORCE
// fails when no such attribute has been registered.
const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id,
                                                   IrContext *ctx) {
  AbstractAttribute *found = ctx->impl().GetAbstractAttribute(type_id);
  IR_ENFORCE(found, "Abstract attribute not found in IrContext.");
  return *found;
}

BFloat16Type BFloat16Type::get(IrContext *ctx) {
  return ctx->impl().bfp16_type;
}

334 335
Float16Type Float16Type::get(IrContext *ctx) { return ctx->impl().fp16_type; }

336 337
Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; }

338 339 340 341
Float64Type Float64Type::get(IrContext *ctx) { return ctx->impl().fp64_type; }

Int16Type Int16Type::get(IrContext *ctx) { return ctx->impl().int16_type; }

342 343
Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }

344 345
Int64Type Int64Type::get(IrContext *ctx) { return ctx->impl().int64_type; }

K
kangguangli 已提交
346 347 348 349 350 351 352 353 354 355 356 357 358 359
Int8Type Int8Type::get(IrContext *ctx) { return ctx->impl().int8_type; }

UInt8Type UInt8Type::get(IrContext *ctx) { return ctx->impl().uint8_type; }

BoolType BoolType::get(IrContext *ctx) { return ctx->impl().bool_type; }

Complex64Type Complex64Type::get(IrContext *ctx) {
  return ctx->impl().complex64_type;
}

Complex128Type Complex128Type::get(IrContext *ctx) {
  return ctx->impl().complex128_type;
}

}  // namespace ir