未验证 提交 b0a604cb 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Type system stage3: add class Dialect (#50959)

* add dialect

* add some interface for dialect

* add some dialect interfaces for class Type

* set WITH_NEWIR=OFF

* refine code by comment

* polish code

* refine include style

* refine log for debug
上级 8f156fd7
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h"
namespace ir {
BuiltinDialect::BuiltinDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<BuiltinDialect>()) {
initialize();
}
void BuiltinDialect::initialize() {
// Register all built-in types defined in builtin_type.h.
RegisterTypes<GET_BUILT_IN_TYPE_LIST>();
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/dialect.h"
namespace ir {
///
/// \brief Built-in Dialect: automatically registered into global IrContext,
/// all built-in types defined in builtin_type.h will be registered in this
/// Dialect.
///
class BuiltinDialect : public ir::Dialect {
public:
explicit BuiltinDialect(ir::IrContext *context);
///
/// \brief Each Dialect needs to provide a name function to return the name of
/// the Dialect.
///
/// \return The name of this Dialect.
///
static const char *name() { return "builtin"; }
private:
void initialize();
};
} // namespace ir
......@@ -17,9 +17,18 @@
#include "paddle/ir/type.h"
namespace ir {
///
/// \brief This macro is used to get a list of all built-in types in this file.
///
#define GET_BUILT_IN_TYPE_LIST ir::Float32Type, ir::Int32Type
///
/// \brief Definitions of built-in type classes. The built-in type object get
/// method is as follows: Type fp32 = Float32Type::get(ctx);
/// method is as follows:
/// \code{cpp}
/// ir::IrContext *ctx = ir::IrContext::Instance();
/// Type fp32 = Float32Type::get(ctx);
/// \endcode
///
class Float32Type : public ir::Type {
public:
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect.h"
namespace ir {
Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id)
: name_(std::move(name)), context_(context), id_(id) {}
void Dialect::RegisterType(ir::AbstractType &&abstract_type) {
ir::AbstractType *new_abstract_type =
new ir::AbstractType(std::move(abstract_type));
this->ir_context()->RegisterAbstractType(new_abstract_type->type_id(),
new_abstract_type);
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/ir_context.h"
#include "paddle/ir/type_base.h"
namespace ir {
///
/// \brief Dialect can basically be understood as a namespace. In Dialect, we
/// can define a series of types, operations, etc. An instance of the dialect
/// object will be loaded into the global IrContext. Specific compilers only
/// need to combine existing dialects and add their own extensions or
/// customizations.
///
class Dialect {
public:
Dialect(std::string name, ir::IrContext *context, ir::TypeId id);
const std::string &name() const { return name_; }
ir::IrContext *ir_context() const { return context_; }
ir::TypeId id() const { return id_; }
///
/// \brief Register all types contained in the template parameter Args.
/// To register only one Type, you can use the RegisterType template function.
///
template <typename... Args>
void RegisterTypes() {
(void)std::initializer_list<int>{0, (RegisterType<Args>(), 0)...};
}
///
/// \brief Register type of class T.
///
template <typename T>
void RegisterType() {
VLOG(4) << "Type registered into Dialect. --->";
ir::AbstractType *abstract_type =
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
abstract_type);
ir::TypeManager::RegisterType<T>(this->ir_context());
VLOG(4) << "----------------------------------";
}
///
/// \brief Register abstract_type into context.
/// NOTE: It's not recommended to use this interface directly. This interface
/// only registers abstract_type. To register TypeStorage into context, you
/// need to call ir::TypeManager::RegisterType<T>() additionally,
/// RegisterType<T>() is recommended to use.
///
void RegisterType(ir::AbstractType &&abstract_type);
private:
std::string name_;
ir::IrContext *context_; // not owned
ir::TypeId id_;
};
} // namespace ir
......@@ -12,61 +12,95 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/ir_context.h"
#include <unordered_map>
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/spin_lock.h"
#include "paddle/ir/type_base.h"
namespace ir {
// The implementation class of the IrContext class
// The implementation class of the IrContext class, cache registered
// AbstractType, TypeStorage, Dialect.
class IrContextImpl {
public:
IrContextImpl() {}
~IrContextImpl() {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
for (auto abstract_type_map : registed_abstract_types_) {
std::lock_guard<ir::SpinLock> guard(destructor_lock_);
for (auto &abstract_type_map : registed_abstract_types_) {
delete abstract_type_map.second;
}
registed_abstract_types_.clear();
for (auto &dialect_map : registed_dialect_) {
delete dialect_map.second;
}
registed_dialect_.clear();
}
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
VLOG(4) << "IrContext register an abstract_type of: [TypeId_hash="
VLOG(4) << "Register an abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id)
<< ", AbstractType_ptr=" << abstract_type << "].";
registed_abstract_types_.emplace(type_id, abstract_type);
}
AbstractType *lookup(ir::TypeId type_id) {
AbstractType *GetAbstractType(ir::TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
auto iter = registed_abstract_types_.find(type_id);
if (iter == registed_abstract_types_.end()) {
VLOG(4) << "IrContext not fonund cached abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
return nullptr;
} else {
VLOG(4) << "IrContext fonund a cached abstract_type of: [TypeId_hash="
if (iter != registed_abstract_types_.end()) {
VLOG(4) << "Fonund a cached abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id)
<< ", AbstractType_ptr=" << iter->second << "].";
return iter->second;
}
LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
return nullptr;
}
ir::SpinLock registed_abstract_types_lock_;
void RegisterDialect(std::string name, Dialect *dialect) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
VLOG(4) << "Register a dialect of: [name=" << name
<< ", dialect_ptr=" << dialect << "].";
registed_dialect_.emplace(name, dialect);
}
Dialect *GetDialect(std::string name) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
auto iter = registed_dialect_.find(name);
if (iter != registed_dialect_.end()) {
VLOG(4) << "Fonund a cached dialect of: [name=" << name
<< ", dialect_ptr=" << iter->second << "].";
return iter->second;
}
LOG(WARNING) << "No cache fonund dialect of: [name=" << name << "].";
return nullptr;
}
// Cached AbstractType instances.
std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
ir::SpinLock registed_abstract_types_lock_;
// TypeStorage uniquer and cache instances.
StorageManager registed_storage_manager_;
// Some built-in type.
// The dialcet registered in the context.
std::unordered_map<std::string, Dialect *> registed_dialect_;
ir::SpinLock registed_dialect_lock_;
// Some built-in types.
Float32Type fp32_type;
Int32Type int32_type;
ir::SpinLock destructor_lock_;
};
IrContext *IrContext::Instance() {
......@@ -75,13 +109,12 @@ IrContext *IrContext::Instance() {
}
IrContext::IrContext() : impl_(new IrContextImpl()) {
VLOG(4) << "IrContext register built-in type...";
REGISTER_TYPE_2_IRCONTEXT(Float32Type, this);
VLOG(4) << "BuiltinDialect registered into IrContext. ===>";
GetOrRegisterDialect<BuiltinDialect>();
VLOG(4) << "==============================================";
impl_->fp32_type = TypeManager::get<Float32Type>(this);
VLOG(4) << "Float32Type registration complete";
REGISTER_TYPE_2_IRCONTEXT(Int32Type, this);
impl_->int32_type = TypeManager::get<Int32Type>(this);
VLOG(4) << "Int32Type registration complete";
}
void IrContext::RegisterAbstractType(ir::TypeId type_id,
......@@ -98,12 +131,41 @@ std::unordered_map<TypeId, AbstractType *>
return impl().registed_abstract_types_;
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
VLOG(4) << "Lookup abstract type [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "] from IrContext [ptr=" << ctx
Dialect *IrContext::GetOrRegisterDialect(
std::string dialect_name, std::function<Dialect *()> constructor) {
VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
<< "].";
Dialect *dialect = impl().GetDialect(dialect_name);
if (dialect == nullptr) {
VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
<< "].";
dialect = constructor();
impl().RegisterDialect(dialect_name, dialect);
}
return dialect;
}
std::vector<Dialect *> IrContext::GetRegisteredDialects() {
std::vector<Dialect *> result;
for (auto dialect_map : impl().registed_dialect_) {
result.push_back(dialect_map.second);
}
return result;
}
Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
for (auto dialect_map : impl().registed_dialect_) {
if (dialect_map.first == dialect_name) {
return dialect_map.second;
}
}
LOG(WARNING) << "No dialect registered for " << dialect_name;
return nullptr;
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
auto &impl = ctx->impl();
AbstractType *abstract_type = impl.lookup(type_id);
AbstractType *abstract_type = impl.GetAbstractType(type_id);
if (abstract_type) {
return *abstract_type;
} else {
......
......@@ -15,6 +15,7 @@
#pragma once
#include <glog/logging.h>
#include <functional>
#include <memory>
#include <unordered_map>
......@@ -23,6 +24,7 @@ class IrContextImpl;
class StorageManager;
class AbstractType;
class TypeId;
class Dialect;
///
/// \brief IrContext is a global parameterless class used to store and manage
......@@ -69,6 +71,63 @@ class IrContext {
///
std::unordered_map<TypeId, AbstractType *> &registed_abstracted_type();
///
/// \brief Get the dialect of the DialectT class in the context, ff not found,
/// create and register to context.
///
/// \param DialectT The Dialect class that needs to be found or register.
///
/// \return The dialect of the DialectT class in the context.
///
template <typename DialectT>
DialectT *GetOrRegisterDialect() {
return static_cast<DialectT *>(
GetOrRegisterDialect(DialectT::name(), [this]() {
DialectT *dialect = new DialectT(this);
return dialect;
}));
}
///
/// \brief Get the dialect of the DialectT class in the context, ff not found,
/// create and register to context.
///
/// \param dialect_name The dialect name.
/// \param dialect_id The TypeId of the dialect.
/// \param constructor The dialect constructor.
///
/// \return The dialect named "dialect_name" in the context.
///
Dialect *GetOrRegisterDialect(std::string dialect_name,
std::function<Dialect *()> constructor);
///
/// \brief Get the dialect list registered to the context.
///
/// \return The dialect list registered to the context.
///
std::vector<Dialect *> GetRegisteredDialects();
///
/// \brief Get the dialect named "name" from the context.
///
/// \param name The name of the dialect to be obtained.
///
/// \return The dialect named "name" from the context.
///
Dialect *GetRegisteredDialect(const std::string &dialect_name);
///
/// \brief Get a registered dialect for the given dialect type T. The
/// Dialect must provide a static 'name' method.
///
/// \return The registered dialect for the given dialect type T.
///
template <typename T>
T *GetRegisteredDialect() {
return static_cast<T *>(GetRegisteredDialect(T::name()));
}
IrContext(const IrContext &) = delete;
void operator=(const IrContext &) = delete;
......
......@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/storage_manager.h"
#include <memory>
#include <unordered_map>
#include "paddle/ir/storage_manager.h"
namespace ir {
// This is a structure for creating, caching, and looking up Storage of
// parameteric types.
......@@ -72,7 +72,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl(
std::function<bool(const StorageBase *)> equal_func,
std::function<StorageBase *()> constructor) {
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
VLOG(4) << "StorageManager get parameteretric storage of: [TypeId_hash="
VLOG(4) << "Try to get a parameteretric storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
<< "].";
if (parametric_instance_.find(type_id) == parametric_instance_.end())
......@@ -83,18 +83,18 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl(
StorageManager::StorageBase *StorageManager::GetParameterlessStorageTypeImpl(
TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(parameterless_instances_lock_);
VLOG(4) << "StorageManager get parameterless storage of: [TypeId_hash="
std::lock_guard<ir::SpinLock> guard(parameterless_instance_lock_);
VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instances_.find(type_id) == parameterless_instances_.end())
if (parameterless_instance_.find(type_id) == parameterless_instance_.end())
throw("TypeId not found in IrContext.");
StorageBase *parameterless_instance = parameterless_instances_[type_id];
StorageBase *parameterless_instance = parameterless_instance_[type_id];
return parameterless_instance;
}
void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
VLOG(4) << "StorageManager register parameteric storage of: [TypeId_hash="
VLOG(4) << "Register a parameteric storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
parametric_instance_.emplace(type_id,
std::make_unique<ParametricStorageManager>());
......@@ -102,12 +102,12 @@ void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) {
void StorageManager::RegisterParameterlessStorageTypeImpl(
TypeId type_id, std::function<StorageBase *()> constructor) {
std::lock_guard<ir::SpinLock> guard(parameterless_instances_lock_);
VLOG(4) << "StorageManager register parameterless storage of: [TypeId_hash="
std::lock_guard<ir::SpinLock> guard(parameterless_instance_lock_);
VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instances_.find(type_id) != parameterless_instances_.end())
if (parameterless_instance_.find(type_id) != parameterless_instance_.end())
throw("storage class already registered");
parameterless_instances_.emplace(type_id, constructor());
parameterless_instance_.emplace(type_id, constructor());
}
} // namespace ir
......@@ -141,9 +141,9 @@ class StorageManager {
ir::SpinLock parametric_instance_lock_;
// This map is a mapping between type id and parameterless type storage.
std::unordered_map<TypeId, StorageBase *> parameterless_instances_;
std::unordered_map<TypeId, StorageBase *> parameterless_instance_;
ir::SpinLock parameterless_instances_lock_;
ir::SpinLock parameterless_instance_lock_;
};
} // namespace ir
......@@ -15,22 +15,27 @@
#include <gtest/gtest.h>
#include <unordered_map>
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/type.h"
#include "paddle/ir/type_base.h"
TEST(type_test, type_id) {
// Define two empty classes, just for testing.
class TypeA {};
class TypeB {};
// (1) Test construct TypeId by TypeId::Get()
// Test 1: Test construct TypeId by TypeId::get<T>() and overloaded operator==
// method.
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::TypeId a_other_id = ir::TypeId::get<TypeA>();
ir::TypeId b_id = ir::TypeId::get<TypeB>();
EXPECT_EQ(a_id, a_other_id);
EXPECT_NE(a_id, b_id);
// (2) Test TypeId hash
// Test 2: Test the hash function of TypeId.
std::unordered_map<ir::TypeId, ir::TypeId *> type_id_register;
type_id_register.emplace(a_id, &a_id);
type_id_register.emplace(b_id, &b_id);
......@@ -39,32 +44,38 @@ TEST(type_test, type_id) {
}
}
TEST(type_test, abstract_type) {
TEST(type_test, type_base) {
// Define two empty classes, just for testing.
class TypeA {};
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id);
// Define a FakeDialect without registering any types.
struct FakeDialect : ir::Dialect {
explicit FakeDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<FakeDialect>()) {}
static const char *name() { return "fake"; }
};
EXPECT_EQ(abstract_type_a.type_id(), a_id);
}
TEST(type_test, type_storage) {
class TypeA {};
// Test 1: Test the function of IrContext to register Dialect.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *fake_dialect = ctx->GetOrRegisterDialect<FakeDialect>();
// Test 2: Test the get method of AbstractType.
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id);
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id, *fake_dialect);
EXPECT_EQ(abstract_type_a.type_id(), a_id);
// Test 3: Test the constructor of TypeStorage.
ir::TypeStorage storage_a(&abstract_type_a);
EXPECT_EQ(storage_a.abstract_type().type_id(), abstract_type_a.type_id());
}
TEST(type_test, built_in_type) {
// Test creation of built-in parameterless type.
// Test 1: Test the built-in type of IrContext.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Type fp32_1 = ir::Float32Type::get(ctx);
// Test interfaces of class Type
// Test 2: Test the interfaces of class Type: judgment, type_id,
// abstract_type, classof.
ir::Type fp32_2 = ir::Float32Type::get(ctx);
EXPECT_EQ(fp32_1 == fp32_2, 1);
EXPECT_EQ(fp32_1 != fp32_2, 0);
......@@ -84,6 +95,7 @@ TEST(type_test, built_in_type) {
EXPECT_EQ(ir::Int32Type::classof(int32_1), 1);
}
// Customize a parameterized TypeStorage IntegerTypeStorage.
struct IntegerTypeStorage : public ir::TypeStorage {
IntegerTypeStorage(unsigned width, unsigned signedness)
: width_(width), signedness_(signedness) {}
......@@ -113,19 +125,50 @@ struct IntegerTypeStorage : public ir::TypeStorage {
}
};
// Customize a parameterized type: IntegerType, storage type is
// IntegerTypeStorage.
class IntegerType : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(IntegerType, IntegerTypeStorage);
};
TEST(type_test, parameteric_type) {
// Customize a Dialect IntegerDialect, registration type of IntegerType.
struct IntegerDialect : ir::Dialect {
explicit IntegerDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<IntegerDialect>()) {
RegisterType<IntegerType>();
}
static const char *name() { return "integer"; }
};
TEST(type_test, custom_type_dialect) {
ir::IrContext *ctx = ir::IrContext::Instance();
REGISTER_TYPE_2_IRCONTEXT(IntegerType, ctx);
// Test 1: Test the function of IrContext to register Dialect.
ctx->GetOrRegisterDialect<IntegerDialect>();
ir::Type int1_1 = IntegerType::get(ctx, 1, 0);
ir::Type int1_2 = IntegerType::get(ctx, 1, 0);
EXPECT_EQ(int1_1 == int1_2, 1);
ir::Type int8 = IntegerType::get(ctx, 8, 0);
EXPECT_EQ(int8 == int1_2, 0);
// Test 2: Test Dialect interfaces
EXPECT_EQ(ctx == int8.ir_context(), 1);
EXPECT_EQ(int8.dialect().id() == ir::TypeId::get<IntegerDialect>(), 1);
std::vector<ir::Dialect *> dialect_list = ctx->GetRegisteredDialects();
EXPECT_EQ(dialect_list.size() == 3, 1); // integer, builtin, fake
ir::Dialect *dialect_builtin1 = ctx->GetRegisteredDialect("builtin");
ir::Dialect *dialect_builtin2 =
ctx->GetRegisteredDialect<ir::BuiltinDialect>();
EXPECT_EQ(dialect_builtin1 == dialect_builtin2, 1);
ir::Dialect *dialect_integer1 = ctx->GetRegisteredDialect("integer");
ir::Dialect *dialect_integer2 = ctx->GetRegisteredDialect<IntegerDialect>();
EXPECT_EQ(dialect_integer1 == dialect_integer2, 1);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/type.h"
#include "paddle/ir/dialect.h"
namespace ir {
IrContext *Type::ir_context() const { return dialect().ir_context(); }
} // namespace ir
......@@ -52,6 +52,10 @@ class Type {
StorageType *storage() const { return storage_; }
const Dialect &dialect() const { return storage_->abstract_type().dialect(); }
IrContext *ir_context() const;
///
/// \brief Enable hashing Type.
///
......
......@@ -19,6 +19,8 @@
#include "paddle/ir/type_id.h"
namespace ir {
class Dialect;
///
/// \brief Abstract the properties and behaviors common to all Type classes into
/// an AbstractType class. There are two types in Type system:
......@@ -32,8 +34,21 @@ class AbstractType {
/// \brief Construct an AbstractType by TypeId directly.
///
/// \param type_id The type id of the AbstractType.
/// \param dialect The Dialect which the type registered to.
///
static AbstractType get(TypeId type_id, const Dialect &dialect) {
return AbstractType(type_id, dialect);
}
///
static AbstractType get(TypeId type_id) { return AbstractType(type_id); }
/// \brief Construct an AbstractType by TypeId directly.
///
/// \param dialect The Dialect which the type registered to.
///
template <typename T>
static AbstractType get(const Dialect &dialect) {
return AbstractType(TypeId::get<T>(), dialect);
}
///
/// \brief Returns the type id of the AbstractType.
......@@ -42,6 +57,13 @@ class AbstractType {
///
TypeId type_id() const { return type_id_; }
///
/// \brief Get the dialect this type was registered to.
///
/// \return The dialect this type was registered to.
///
const Dialect &dialect() const { return dialect_; }
///
/// \brief Find the AbstractType instance whose TypeId is type_id from
/// IrContext.
......@@ -58,10 +80,14 @@ class AbstractType {
/// get method to obtain and manage the AstractType.
///
/// \param type_id The type id of the AbstractType.
/// \param dialect The Dialect which the type registered to.
///
explicit AbstractType(TypeId type_id) : type_id_(type_id) {}
explicit AbstractType(TypeId type_id, const Dialect &dialect)
: type_id_(type_id), dialect_(dialect) {}
TypeId type_id_;
const Dialect &dialect_;
};
struct TypeManager;
......@@ -239,13 +265,13 @@ struct TypeManager {
///
/// \brief This macro definition is used to register custom Type class.
///
#define REGISTER_TYPE_2_IRCONTEXT(concrete_type, ir_context) \
ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \
std::move(ir::AbstractType::get(ir::TypeId::get<concrete_type>()))); \
\
ir_context->RegisterAbstractType(ir::TypeId::get<concrete_type>(), \
abstract_type_##concrete_type); \
\
ir::TypeManager::RegisterType<concrete_type>(ir_context);
#define REGISTER_TYPE_2_IRCONTEXT(concrete_type, dialect) \
ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \
std::move(ir::AbstractType::get<concrete_type>(*dialect))); \
\
dialect->ir_context()->RegisterAbstractType( \
ir::TypeId::get<concrete_type>(), abstract_type_##concrete_type); \
\
ir::TypeManager::RegisterType<concrete_type>(dialect->ir_context());
} // namespace ir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册