// ir_context.cc
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/ir_context.h"

#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/spin_lock.h"
#include "paddle/ir/type_base.h"

namespace ir {
26 27
// The implementation class of the IrContext class, cache registered
// AbstractType, TypeStorage, Dialect.
28 29 30 31 32
class IrContextImpl {
 public:
  IrContextImpl() {}

  ~IrContextImpl() {
33 34
    std::lock_guard<ir::SpinLock> guard(destructor_lock_);
    for (auto &abstract_type_map : registed_abstract_types_) {
35 36 37
      delete abstract_type_map.second;
    }
    registed_abstract_types_.clear();
38 39 40 41 42

    for (auto &dialect_map : registed_dialect_) {
      delete dialect_map.second;
    }
    registed_dialect_.clear();
43 44 45 46
  }

  void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
47
    VLOG(4) << "Register an abstract_type of: [TypeId_hash="
48 49 50 51 52
            << std::hash<ir::TypeId>()(type_id)
            << ", AbstractType_ptr=" << abstract_type << "].";
    registed_abstract_types_.emplace(type_id, abstract_type);
  }

53
  AbstractType *GetAbstractType(ir::TypeId type_id) {
54 55
    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
    auto iter = registed_abstract_types_.find(type_id);
56 57
    if (iter != registed_abstract_types_.end()) {
      VLOG(4) << "Fonund a cached abstract_type of: [TypeId_hash="
58 59 60 61
              << std::hash<ir::TypeId>()(type_id)
              << ", AbstractType_ptr=" << iter->second << "].";
      return iter->second;
    }
62 63 64
    LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
                 << std::hash<ir::TypeId>()(type_id) << "].";
    return nullptr;
65 66
  }

67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
  void RegisterDialect(std::string name, Dialect *dialect) {
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    VLOG(4) << "Register a dialect of: [name=" << name
            << ", dialect_ptr=" << dialect << "].";
    registed_dialect_.emplace(name, dialect);
  }

  Dialect *GetDialect(std::string name) {
    std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
    auto iter = registed_dialect_.find(name);
    if (iter != registed_dialect_.end()) {
      VLOG(4) << "Fonund a cached dialect of: [name=" << name
              << ", dialect_ptr=" << iter->second << "].";
      return iter->second;
    }
    LOG(WARNING) << "No cache fonund dialect of: [name=" << name << "].";
    return nullptr;
  }
85 86 87 88

  // Cached AbstractType instances.
  std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;

89 90
  ir::SpinLock registed_abstract_types_lock_;

91 92 93
  // TypeStorage uniquer and cache instances.
  StorageManager registed_storage_manager_;

94 95 96 97 98 99
  // The dialcet registered in the context.
  std::unordered_map<std::string, Dialect *> registed_dialect_;

  ir::SpinLock registed_dialect_lock_;

  // Some built-in types.
100 101
  Float32Type fp32_type;
  Int32Type int32_type;
102 103

  ir::SpinLock destructor_lock_;
104 105 106 107 108 109 110 111
};

// Returns the process-wide IrContext. The function-local static is built on
// first call (initialization is thread-safe since C++11) and is destroyed at
// program exit.
IrContext *IrContext::Instance() {
  static IrContext global_context;
  return &global_context;
}

// Builds a context: registers the builtin dialect, then caches the builtin
// Float32Type / Int32Type instances so that Float32Type::get and
// Int32Type::get can return them directly from the impl.
IrContext::IrContext() : impl_(new IrContextImpl()) {
  VLOG(4) << "BuiltinDialect registered into IrContext. ===>";
  GetOrRegisterDialect<BuiltinDialect>();
  VLOG(4) << "==============================================";

  // Resolved after the builtin dialect is registered above.
  impl_->fp32_type = TypeManager::get<Float32Type>(this);
  impl_->int32_type = TypeManager::get<Int32Type>(this);
}

// Hands `abstract_type` to the implementation, which caches it under
// `type_id`.
void IrContext::RegisterAbstractType(ir::TypeId type_id,
                                     AbstractType *abstract_type) {
  auto &context_impl = impl();
  context_impl.RegisterAbstractType(type_id, abstract_type);
}

// Exposes the context's TypeStorage uniquer/cache by reference.
StorageManager &IrContext::storage_manager() {
  auto &context_impl = impl();
  return context_impl.registed_storage_manager_;
}

// Exposes the AbstractType cache itself (not a copy). Access through the
// returned reference does not take the cache's spin lock.
std::unordered_map<TypeId, AbstractType *>
    &IrContext::registed_abstracted_type() {
  auto &context_impl = impl();
  return context_impl.registed_abstract_types_;
}

134 135 136
Dialect *IrContext::GetOrRegisterDialect(
    std::string dialect_name, std::function<Dialect *()> constructor) {
  VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
137
          << "].";
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
  Dialect *dialect = impl().GetDialect(dialect_name);
  if (dialect == nullptr) {
    VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
            << "].";
    dialect = constructor();
    impl().RegisterDialect(dialect_name, dialect);
  }
  return dialect;
}

std::vector<Dialect *> IrContext::GetRegisteredDialects() {
  std::vector<Dialect *> result;
  for (auto dialect_map : impl().registed_dialect_) {
    result.push_back(dialect_map.second);
  }
  return result;
}

// Returns the dialect registered under `dialect_name`, or nullptr (after
// logging a warning) if there is none.
Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
  auto &context_impl = impl();
  {
    // Same lock RegisterDialect holds while mutating the map.
    std::lock_guard<ir::SpinLock> guard(context_impl.registed_dialect_lock_);
    // Keys are unique, so a hash lookup replaces the former linear scan.
    auto iter = context_impl.registed_dialect_.find(dialect_name);
    if (iter != context_impl.registed_dialect_.end()) {
      return iter->second;
    }
  }
  LOG(WARNING) << "No dialect registered for " << dialect_name;
  return nullptr;
}

// Returns the AbstractType registered in `ctx` for `type_id`.
// If none is registered, throws the C string
// "Abstract type not found in IrContext." (type `const char *`).
// NOTE(review): throwing a raw C string is unusual; callers must catch
// `const char *` to handle it.
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
  AbstractType *abstract_type = ctx->impl().GetAbstractType(type_id);
  if (abstract_type == nullptr) {
    throw("Abstract type not found in IrContext.");
  }
  return *abstract_type;
}

// Returns the Float32Type instance cached in `ctx`'s implementation.
Float32Type Float32Type::get(IrContext *ctx) {
  auto &context_impl = ctx->impl();
  return context_impl.fp32_type;
}

// Returns the Int32Type instance cached in `ctx`'s implementation.
Int32Type Int32Type::get(IrContext *ctx) {
  auto &context_impl = ctx->impl();
  return context_impl.int32_type;
}

}  // namespace ir