From b0a604cb90702d6a851a6961b58ebdbc628f9ff9 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Thu, 2 Mar 2023 11:21:35 +0800 Subject: [PATCH] [IR] Type system stage3: add class Dialect (#50959) * add dialect * add some interface for dialect * add some dialect interfaces for class Type * set WITH_NEWIR=OFF * refine code by comment * polish code * refine include style * refine log for debug --- paddle/ir/builtin_dialect.cc | 29 ++++++++++ paddle/ir/builtin_dialect.h | 40 +++++++++++++ paddle/ir/builtin_type.h | 11 +++- paddle/ir/dialect.cc | 28 +++++++++ paddle/ir/dialect.h | 77 +++++++++++++++++++++++++ paddle/ir/ir_context.cc | 108 +++++++++++++++++++++++++++-------- paddle/ir/ir_context.h | 59 +++++++++++++++++++ paddle/ir/storage_manager.cc | 24 ++++---- paddle/ir/storage_manager.h | 4 +- paddle/ir/tests/type_test.cc | 75 ++++++++++++++++++------ paddle/ir/type.cc | 21 +++++++ paddle/ir/type.h | 4 ++ paddle/ir/type_base.h | 46 +++++++++++---- 13 files changed, 462 insertions(+), 64 deletions(-) create mode 100644 paddle/ir/builtin_dialect.cc create mode 100644 paddle/ir/builtin_dialect.h create mode 100644 paddle/ir/dialect.cc create mode 100644 paddle/ir/dialect.h create mode 100644 paddle/ir/type.cc diff --git a/paddle/ir/builtin_dialect.cc b/paddle/ir/builtin_dialect.cc new file mode 100644 index 00000000000..5c797c4214c --- /dev/null +++ b/paddle/ir/builtin_dialect.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/builtin_dialect.h" +#include "paddle/ir/builtin_type.h" + +namespace ir { +BuiltinDialect::BuiltinDialect(ir::IrContext *context) + : ir::Dialect(name(), context, ir::TypeId::get()) { + initialize(); +} + +void BuiltinDialect::initialize() { + // Register all built-in types defined in builtin_type.h. + RegisterTypes(); +} + +} // namespace ir diff --git a/paddle/ir/builtin_dialect.h b/paddle/ir/builtin_dialect.h new file mode 100644 index 00000000000..5016c1abea0 --- /dev/null +++ b/paddle/ir/builtin_dialect.h @@ -0,0 +1,40 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/dialect.h" + +namespace ir { +/// +/// \brief Built-in Dialect: automatically registered into global IrContext, +/// all built-in types defined in builtin_type.h will be registered in this +/// Dialect. +/// +class BuiltinDialect : public ir::Dialect { + public: + explicit BuiltinDialect(ir::IrContext *context); + /// + /// \brief Each Dialect needs to provide a name function to return the name of + /// the Dialect. + /// + /// \return The name of this Dialect. + /// + static const char *name() { return "builtin"; } + + private: + void initialize(); +}; + +} // namespace ir diff --git a/paddle/ir/builtin_type.h b/paddle/ir/builtin_type.h index 77159794bf1..8b15ae6eed0 100644 --- a/paddle/ir/builtin_type.h +++ b/paddle/ir/builtin_type.h @@ -17,9 +17,18 @@ #include "paddle/ir/type.h" namespace ir { +/// +/// \brief This macro is used to get a list of all built-in types in this file. +/// +#define GET_BUILT_IN_TYPE_LIST ir::Float32Type, ir::Int32Type + /// /// \brief Definitions of built-in type classes. The built-in type object get -/// method is as follows: Type fp32 = Float32Type::get(ctx); +/// method is as follows: +/// \code{cpp} +/// ir::IrContext *ctx = ir::IrContext::Instance(); +/// Type fp32 = Float32Type::get(ctx); +/// \endcode /// class Float32Type : public ir::Type { public: diff --git a/paddle/ir/dialect.cc b/paddle/ir/dialect.cc new file mode 100644 index 00000000000..5a913fdf4bb --- /dev/null +++ b/paddle/ir/dialect.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/dialect.h" + +namespace ir { +Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id) + : name_(std::move(name)), context_(context), id_(id) {} + +void Dialect::RegisterType(ir::AbstractType &&abstract_type) { + ir::AbstractType *new_abstract_type = + new ir::AbstractType(std::move(abstract_type)); + this->ir_context()->RegisterAbstractType(new_abstract_type->type_id(), + new_abstract_type); +} + +} // namespace ir diff --git a/paddle/ir/dialect.h b/paddle/ir/dialect.h new file mode 100644 index 00000000000..6b5b733f782 --- /dev/null +++ b/paddle/ir/dialect.h @@ -0,0 +1,77 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/ir_context.h" +#include "paddle/ir/type_base.h" + +namespace ir { +/// +/// \brief Dialect can basically be understood as a namespace. In Dialect, we +/// can define a series of types, operations, etc. An instance of the dialect +/// object will be loaded into the global IrContext. Specific compilers only +/// need to combine existing dialects and add their own extensions or +/// customizations. +/// +class Dialect { + public: + Dialect(std::string name, ir::IrContext *context, ir::TypeId id); + + const std::string &name() const { return name_; } + + ir::IrContext *ir_context() const { return context_; } + + ir::TypeId id() const { return id_; } + + /// + /// \brief Register all types contained in the template parameter Args. + /// To register only one Type, you can use the RegisterType template function. + /// + template + void RegisterTypes() { + (void)std::initializer_list{0, (RegisterType(), 0)...}; + } + + /// + /// \brief Register type of class T. + /// + template + void RegisterType() { + VLOG(4) << "Type registered into Dialect. --->"; + ir::AbstractType *abstract_type = + new ir::AbstractType(std::move(ir::AbstractType::get(*this))); + this->ir_context()->RegisterAbstractType(ir::TypeId::get(), + abstract_type); + ir::TypeManager::RegisterType(this->ir_context()); + VLOG(4) << "----------------------------------"; + } + + /// + /// \brief Register abstract_type into context. + /// NOTE: It's not recommended to use this interface directly. This interface + /// only registers abstract_type. To register TypeStorage into context, you + /// need to call ir::TypeManager::RegisterType() additionally, + /// RegisterType() is recommended to use. + /// + void RegisterType(ir::AbstractType &&abstract_type); + + private: + std::string name_; + + ir::IrContext *context_; // not owned + + ir::TypeId id_; +}; +} // namespace ir diff --git a/paddle/ir/ir_context.cc b/paddle/ir/ir_context.cc index 6aed6009034..6fca67c16e2 100644 --- a/paddle/ir/ir_context.cc +++ b/paddle/ir/ir_context.cc @@ -12,61 +12,95 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/ir/ir_context.h" + #include +#include "paddle/ir/builtin_dialect.h" #include "paddle/ir/builtin_type.h" -#include "paddle/ir/ir_context.h" +#include "paddle/ir/dialect.h" #include "paddle/ir/spin_lock.h" #include "paddle/ir/type_base.h" namespace ir { -// The implementation class of the IrContext class +// The implementation class of the IrContext class, cache registered +// AbstractType, TypeStorage, Dialect. class IrContextImpl { public: IrContextImpl() {} ~IrContextImpl() { - std::lock_guard guard(registed_abstract_types_lock_); - for (auto abstract_type_map : registed_abstract_types_) { + std::lock_guard guard(destructor_lock_); + for (auto &abstract_type_map : registed_abstract_types_) { delete abstract_type_map.second; } registed_abstract_types_.clear(); + + for (auto &dialect_map : registed_dialect_) { + delete dialect_map.second; + } + registed_dialect_.clear(); } void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) { std::lock_guard guard(registed_abstract_types_lock_); - VLOG(4) << "IrContext register an abstract_type of: [TypeId_hash=" + VLOG(4) << "Register an abstract_type of: [TypeId_hash=" << std::hash()(type_id) << ", AbstractType_ptr=" << abstract_type << "]."; registed_abstract_types_.emplace(type_id, abstract_type); } - AbstractType *lookup(ir::TypeId type_id) { + AbstractType *GetAbstractType(ir::TypeId type_id) { std::lock_guard guard(registed_abstract_types_lock_); auto iter = registed_abstract_types_.find(type_id); - if (iter == registed_abstract_types_.end()) { - VLOG(4) << "IrContext not fonund cached abstract_type of: [TypeId_hash=" - << std::hash()(type_id) << "]."; - return nullptr; - } else { - VLOG(4) << "IrContext fonund a cached abstract_type of: [TypeId_hash=" + if (iter != registed_abstract_types_.end()) { + VLOG(4) << "Fonund a cached abstract_type of: [TypeId_hash=" << std::hash()(type_id) << ", AbstractType_ptr=" << iter->second << "]."; return iter->second; } + LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash=" + << std::hash()(type_id) << "]."; + return nullptr; } - ir::SpinLock registed_abstract_types_lock_; + void RegisterDialect(std::string name, Dialect *dialect) { + std::lock_guard guard(registed_dialect_lock_); + VLOG(4) << "Register a dialect of: [name=" << name + << ", dialect_ptr=" << dialect << "]."; + registed_dialect_.emplace(name, dialect); + } + + Dialect *GetDialect(std::string name) { + std::lock_guard guard(registed_dialect_lock_); + auto iter = registed_dialect_.find(name); + if (iter != registed_dialect_.end()) { + VLOG(4) << "Fonund a cached dialect of: [name=" << name + << ", dialect_ptr=" << iter->second << "]."; + return iter->second; + } + LOG(WARNING) << "No cache fonund dialect of: [name=" << name << "]."; + return nullptr; + } // Cached AbstractType instances. std::unordered_map registed_abstract_types_; + ir::SpinLock registed_abstract_types_lock_; + // TypeStorage uniquer and cache instances. StorageManager registed_storage_manager_; - // Some built-in type. + // The dialcet registered in the context. + std::unordered_map registed_dialect_; + + ir::SpinLock registed_dialect_lock_; + + // Some built-in types. Float32Type fp32_type; Int32Type int32_type; + + ir::SpinLock destructor_lock_; }; IrContext *IrContext::Instance() { @@ -75,13 +109,12 @@ IrContext *IrContext::Instance() { } IrContext::IrContext() : impl_(new IrContextImpl()) { - VLOG(4) << "IrContext register built-in type..."; - REGISTER_TYPE_2_IRCONTEXT(Float32Type, this); + VLOG(4) << "BuiltinDialect registered into IrContext. ===>"; + GetOrRegisterDialect(); + VLOG(4) << "=============================================="; + impl_->fp32_type = TypeManager::get(this); - VLOG(4) << "Float32Type registration complete"; - REGISTER_TYPE_2_IRCONTEXT(Int32Type, this); impl_->int32_type = TypeManager::get(this); - VLOG(4) << "Int32Type registration complete"; } void IrContext::RegisterAbstractType(ir::TypeId type_id, @@ -98,12 +131,41 @@ std::unordered_map return impl().registed_abstract_types_; } -const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { - VLOG(4) << "Lookup abstract type [TypeId_hash=" - << std::hash()(type_id) << "] from IrContext [ptr=" << ctx +Dialect *IrContext::GetOrRegisterDialect( + std::string dialect_name, std::function constructor) { + VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name << "]."; + Dialect *dialect = impl().GetDialect(dialect_name); + if (dialect == nullptr) { + VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name + << "]."; + dialect = constructor(); + impl().RegisterDialect(dialect_name, dialect); + } + return dialect; +} + +std::vector IrContext::GetRegisteredDialects() { + std::vector result; + for (auto dialect_map : impl().registed_dialect_) { + result.push_back(dialect_map.second); + } + return result; +} + +Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) { + for (auto dialect_map : impl().registed_dialect_) { + if (dialect_map.first == dialect_name) { + return dialect_map.second; + } + } + LOG(WARNING) << "No dialect registered for " << dialect_name; + return nullptr; +} + +const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { auto &impl = ctx->impl(); - AbstractType *abstract_type = impl.lookup(type_id); + AbstractType *abstract_type = impl.GetAbstractType(type_id); if (abstract_type) { return *abstract_type; } else { diff --git a/paddle/ir/ir_context.h b/paddle/ir/ir_context.h index 146497e6c6c..7343c574874 100644 --- a/paddle/ir/ir_context.h +++ b/paddle/ir/ir_context.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include @@ -23,6 +24,7 @@ class IrContextImpl; class StorageManager; class AbstractType; class TypeId; +class Dialect; /// /// \brief IrContext is a global parameterless class used to store and manage @@ -69,6 +71,63 @@ class IrContext { /// std::unordered_map ®isted_abstracted_type(); + /// + /// \brief Get the dialect of the DialectT class in the context, ff not found, + /// create and register to context. + /// + /// \param DialectT The Dialect class that needs to be found or register. + /// + /// \return The dialect of the DialectT class in the context. + /// + template + DialectT *GetOrRegisterDialect() { + return static_cast( + GetOrRegisterDialect(DialectT::name(), [this]() { + DialectT *dialect = new DialectT(this); + return dialect; + })); + } + + /// + /// \brief Get the dialect of the DialectT class in the context, ff not found, + /// create and register to context. + /// + /// \param dialect_name The dialect name. + /// \param dialect_id The TypeId of the dialect. + /// \param constructor The dialect constructor. + /// + /// \return The dialect named "dialect_name" in the context. + /// + Dialect *GetOrRegisterDialect(std::string dialect_name, + std::function constructor); + + /// + /// \brief Get the dialect list registered to the context. + /// + /// \return The dialect list registered to the context. + /// + std::vector GetRegisteredDialects(); + + /// + /// \brief Get the dialect named "name" from the context. + /// + /// \param name The name of the dialect to be obtained. + /// + /// \return The dialect named "name" from the context. + /// + Dialect *GetRegisteredDialect(const std::string &dialect_name); + + /// + /// \brief Get a registered dialect for the given dialect type T. The + /// Dialect must provide a static 'name' method. + /// + /// \return The registered dialect for the given dialect type T. + /// + template + T *GetRegisteredDialect() { + return static_cast(GetRegisteredDialect(T::name())); + } + IrContext(const IrContext &) = delete; void operator=(const IrContext &) = delete; diff --git a/paddle/ir/storage_manager.cc b/paddle/ir/storage_manager.cc index 991077e8777..9cec4a48e14 100644 --- a/paddle/ir/storage_manager.cc +++ b/paddle/ir/storage_manager.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/ir/storage_manager.h" + #include #include -#include "paddle/ir/storage_manager.h" - namespace ir { // This is a structure for creating, caching, and looking up Storage of // parameteric types. @@ -72,7 +72,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl( std::function equal_func, std::function constructor) { std::lock_guard guard(parametric_instance_lock_); - VLOG(4) << "StorageManager get parameteretric storage of: [TypeId_hash=" + VLOG(4) << "Try to get a parameteretric storage of: [TypeId_hash=" << std::hash()(type_id) << ", param_hash=" << hash_value << "]."; if (parametric_instance_.find(type_id) == parametric_instance_.end()) @@ -83,18 +83,18 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl( StorageManager::StorageBase *StorageManager::GetParameterlessStorageTypeImpl( TypeId type_id) { - std::lock_guard guard(parameterless_instances_lock_); - VLOG(4) << "StorageManager get parameterless storage of: [TypeId_hash=" + std::lock_guard guard(parameterless_instance_lock_); + VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash=" << std::hash()(type_id) << "]."; - if (parameterless_instances_.find(type_id) == parameterless_instances_.end()) + if (parameterless_instance_.find(type_id) == parameterless_instance_.end()) throw("TypeId not found in IrContext."); - StorageBase *parameterless_instance = parameterless_instances_[type_id]; + StorageBase *parameterless_instance = parameterless_instance_[type_id]; return parameterless_instance; } void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) { std::lock_guard guard(parametric_instance_lock_); - VLOG(4) << "StorageManager register parameteric storage of: [TypeId_hash=" + VLOG(4) << "Register a parameteric storage of: [TypeId_hash=" << std::hash()(type_id) << "]."; parametric_instance_.emplace(type_id, std::make_unique()); @@ -102,12 +102,12 @@ void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) { void StorageManager::RegisterParameterlessStorageTypeImpl( TypeId type_id, std::function constructor) { - std::lock_guard guard(parameterless_instances_lock_); - VLOG(4) << "StorageManager register parameterless storage of: [TypeId_hash=" + std::lock_guard guard(parameterless_instance_lock_); + VLOG(4) << "Register a parameterless storage of: [TypeId_hash=" << std::hash()(type_id) << "]."; - if (parameterless_instances_.find(type_id) != parameterless_instances_.end()) + if (parameterless_instance_.find(type_id) != parameterless_instance_.end()) throw("storage class already registered"); - parameterless_instances_.emplace(type_id, constructor()); + parameterless_instance_.emplace(type_id, constructor()); } } // namespace ir diff --git a/paddle/ir/storage_manager.h b/paddle/ir/storage_manager.h index f94174586bc..8b6c1a330e4 100644 --- a/paddle/ir/storage_manager.h +++ b/paddle/ir/storage_manager.h @@ -141,9 +141,9 @@ class StorageManager { ir::SpinLock parametric_instance_lock_; // This map is a mapping between type id and parameterless type storage. - std::unordered_map parameterless_instances_; + std::unordered_map parameterless_instance_; - ir::SpinLock parameterless_instances_lock_; + ir::SpinLock parameterless_instance_lock_; }; } // namespace ir diff --git a/paddle/ir/tests/type_test.cc b/paddle/ir/tests/type_test.cc index 85deb51b694..e8901be6c35 100644 --- a/paddle/ir/tests/type_test.cc +++ b/paddle/ir/tests/type_test.cc @@ -15,22 +15,27 @@ #include #include +#include "paddle/ir/builtin_dialect.h" #include "paddle/ir/builtin_type.h" +#include "paddle/ir/dialect.h" #include "paddle/ir/ir_context.h" +#include "paddle/ir/type.h" #include "paddle/ir/type_base.h" TEST(type_test, type_id) { + // Define two empty classes, just for testing. class TypeA {}; class TypeB {}; - // (1) Test construct TypeId by TypeId::Get() + // Test 1: Test construct TypeId by TypeId::get() and overloaded operator== + // method. ir::TypeId a_id = ir::TypeId::get(); ir::TypeId a_other_id = ir::TypeId::get(); ir::TypeId b_id = ir::TypeId::get(); EXPECT_EQ(a_id, a_other_id); EXPECT_NE(a_id, b_id); - // (2) Test TypeId hash + // Test 2: Test the hash function of TypeId. std::unordered_map type_id_register; type_id_register.emplace(a_id, &a_id); type_id_register.emplace(b_id, &b_id); @@ -39,32 +44,38 @@ TEST(type_test, type_id) { } } -TEST(type_test, abstract_type) { +TEST(type_test, type_base) { + // Define two empty classes, just for testing. class TypeA {}; - ir::TypeId a_id = ir::TypeId::get(); - ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id); + // Define a FakeDialect without registering any types. + struct FakeDialect : ir::Dialect { + explicit FakeDialect(ir::IrContext *context) + : ir::Dialect(name(), context, ir::TypeId::get()) {} + static const char *name() { return "fake"; } + }; - EXPECT_EQ(abstract_type_a.type_id(), a_id); -} - -TEST(type_test, type_storage) { - class TypeA {}; + // Test 1: Test the function of IrContext to register Dialect. + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Dialect *fake_dialect = ctx->GetOrRegisterDialect(); + // Test 2: Test the get method of AbstractType. ir::TypeId a_id = ir::TypeId::get(); - ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id); + ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id, *fake_dialect); + EXPECT_EQ(abstract_type_a.type_id(), a_id); + // Test 3: Test the constructor of TypeStorage. ir::TypeStorage storage_a(&abstract_type_a); - EXPECT_EQ(storage_a.abstract_type().type_id(), abstract_type_a.type_id()); } TEST(type_test, built_in_type) { - // Test creation of built-in parameterless type. + // Test 1: Test the built-in type of IrContext. ir::IrContext *ctx = ir::IrContext::Instance(); ir::Type fp32_1 = ir::Float32Type::get(ctx); - // Test interfaces of class Type + // Test 2: Test the interfaces of class Type: judgment, type_id, + // abstract_type, classof. ir::Type fp32_2 = ir::Float32Type::get(ctx); EXPECT_EQ(fp32_1 == fp32_2, 1); EXPECT_EQ(fp32_1 != fp32_2, 0); @@ -84,6 +95,7 @@ TEST(type_test, built_in_type) { EXPECT_EQ(ir::Int32Type::classof(int32_1), 1); } +// Customize a parameterized TypeStorage IntegerTypeStorage. struct IntegerTypeStorage : public ir::TypeStorage { IntegerTypeStorage(unsigned width, unsigned signedness) : width_(width), signedness_(signedness) {} @@ -113,19 +125,50 @@ struct IntegerTypeStorage : public ir::TypeStorage { } }; +// Customize a parameterized type: IntegerType, storage type is +// IntegerTypeStorage. class IntegerType : public ir::Type { public: using Type::Type; DECLARE_TYPE_UTILITY_FUNCTOR(IntegerType, IntegerTypeStorage); }; -TEST(type_test, parameteric_type) { +// Customize a Dialect IntegerDialect, registration type of IntegerType. +struct IntegerDialect : ir::Dialect { + explicit IntegerDialect(ir::IrContext *context) + : ir::Dialect(name(), context, ir::TypeId::get()) { + RegisterType(); + } + static const char *name() { return "integer"; } +}; + +TEST(type_test, custom_type_dialect) { ir::IrContext *ctx = ir::IrContext::Instance(); - REGISTER_TYPE_2_IRCONTEXT(IntegerType, ctx); + + // Test 1: Test the function of IrContext to register Dialect. + ctx->GetOrRegisterDialect(); + ir::Type int1_1 = IntegerType::get(ctx, 1, 0); ir::Type int1_2 = IntegerType::get(ctx, 1, 0); EXPECT_EQ(int1_1 == int1_2, 1); ir::Type int8 = IntegerType::get(ctx, 8, 0); EXPECT_EQ(int8 == int1_2, 0); + + // Test 2: Test Dialect interfaces + EXPECT_EQ(ctx == int8.ir_context(), 1); + + EXPECT_EQ(int8.dialect().id() == ir::TypeId::get(), 1); + + std::vector dialect_list = ctx->GetRegisteredDialects(); + EXPECT_EQ(dialect_list.size() == 3, 1); // integer, builtin, fake + + ir::Dialect *dialect_builtin1 = ctx->GetRegisteredDialect("builtin"); + ir::Dialect *dialect_builtin2 = + ctx->GetRegisteredDialect(); + EXPECT_EQ(dialect_builtin1 == dialect_builtin2, 1); + + ir::Dialect *dialect_integer1 = ctx->GetRegisteredDialect("integer"); + ir::Dialect *dialect_integer2 = ctx->GetRegisteredDialect(); + EXPECT_EQ(dialect_integer1 == dialect_integer2, 1); } diff --git a/paddle/ir/type.cc b/paddle/ir/type.cc new file mode 100644 index 00000000000..bde3194f8fd --- /dev/null +++ b/paddle/ir/type.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/type.h" +#include "paddle/ir/dialect.h" + +namespace ir { +IrContext *Type::ir_context() const { return dialect().ir_context(); } + +} // namespace ir diff --git a/paddle/ir/type.h b/paddle/ir/type.h index c4c5663dda9..64e65e99688 100644 --- a/paddle/ir/type.h +++ b/paddle/ir/type.h @@ -52,6 +52,10 @@ class Type { StorageType *storage() const { return storage_; } + const Dialect &dialect() const { return storage_->abstract_type().dialect(); } + + IrContext *ir_context() const; + /// /// \brief Enable hashing Type. /// diff --git a/paddle/ir/type_base.h b/paddle/ir/type_base.h index aa800498f6e..803ac7eb721 100644 --- a/paddle/ir/type_base.h +++ b/paddle/ir/type_base.h @@ -19,6 +19,8 @@ #include "paddle/ir/type_id.h" namespace ir { +class Dialect; + /// /// \brief Abstract the properties and behaviors common to all Type classes into /// an AbstractType class. There are two types in Type system: @@ -32,8 +34,21 @@ class AbstractType { /// \brief Construct an AbstractType by TypeId directly. /// /// \param type_id The type id of the AbstractType. + /// \param dialect The Dialect which the type registered to. + /// + static AbstractType get(TypeId type_id, const Dialect &dialect) { + return AbstractType(type_id, dialect); + } + /// - static AbstractType get(TypeId type_id) { return AbstractType(type_id); } + /// \brief Construct an AbstractType by TypeId directly. + /// + /// \param dialect The Dialect which the type registered to. + /// + template + static AbstractType get(const Dialect &dialect) { + return AbstractType(TypeId::get(), dialect); + } /// /// \brief Returns the type id of the AbstractType. @@ -42,6 +57,13 @@ class AbstractType { /// TypeId type_id() const { return type_id_; } + /// + /// \brief Get the dialect this type was registered to. + /// + /// \return The dialect this type was registered to. + /// + const Dialect &dialect() const { return dialect_; } + /// /// \brief Find the AbstractType instance whose TypeId is type_id from /// IrContext. @@ -58,10 +80,14 @@ class AbstractType { /// get method to obtain and manage the AstractType. /// /// \param type_id The type id of the AbstractType. + /// \param dialect The Dialect which the type registered to. /// - explicit AbstractType(TypeId type_id) : type_id_(type_id) {} + explicit AbstractType(TypeId type_id, const Dialect &dialect) + : type_id_(type_id), dialect_(dialect) {} TypeId type_id_; + + const Dialect &dialect_; }; struct TypeManager; @@ -239,13 +265,13 @@ struct TypeManager { /// /// \brief This macro definition is used to register custom Type class. /// -#define REGISTER_TYPE_2_IRCONTEXT(concrete_type, ir_context) \ - ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \ - std::move(ir::AbstractType::get(ir::TypeId::get()))); \ - \ - ir_context->RegisterAbstractType(ir::TypeId::get(), \ - abstract_type_##concrete_type); \ - \ - ir::TypeManager::RegisterType(ir_context); +#define REGISTER_TYPE_2_IRCONTEXT(concrete_type, dialect) \ + ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \ + std::move(ir::AbstractType::get(*dialect))); \ + \ + dialect->ir_context()->RegisterAbstractType( \ + ir::TypeId::get(), abstract_type_##concrete_type); \ + \ + ir::TypeManager::RegisterType(dialect->ir_context()); } // namespace ir -- GitLab