未验证 提交 b0a604cb 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Type system stage3: add class Dialect (#50959)

* add dialect

* add some interface for dialect

* add some dialect interfaces for class Type

* set WITH_NEWIR=OFF

* refine code by comment

* polish code

* refine include style

* refine log for debug
上级 8f156fd7
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h"
namespace ir {
BuiltinDialect::BuiltinDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<BuiltinDialect>()) {
initialize();
}
void BuiltinDialect::initialize() {
// Register all built-in types defined in builtin_type.h.
RegisterTypes<GET_BUILT_IN_TYPE_LIST>();
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/dialect.h"
namespace ir {
///
/// \brief Built-in Dialect: automatically registered into global IrContext,
/// all built-in types defined in builtin_type.h will be registered in this
/// Dialect.
///
class BuiltinDialect : public ir::Dialect {
public:
explicit BuiltinDialect(ir::IrContext *context);
///
/// \brief Each Dialect needs to provide a name function to return the name of
/// the Dialect.
///
/// \return The name of this Dialect.
///
static const char *name() { return "builtin"; }
private:
void initialize();
};
} // namespace ir
...@@ -17,9 +17,18 @@ ...@@ -17,9 +17,18 @@
#include "paddle/ir/type.h" #include "paddle/ir/type.h"
namespace ir { namespace ir {
///
/// \brief This macro is used to get a list of all built-in types in this file.
///
#define GET_BUILT_IN_TYPE_LIST ir::Float32Type, ir::Int32Type
/// ///
/// \brief Definitions of built-in type classes. The built-in type object get /// \brief Definitions of built-in type classes. The built-in type object get
/// method is as follows: Type fp32 = Float32Type::get(ctx); /// method is as follows:
/// \code{cpp}
/// ir::IrContext *ctx = ir::IrContext::Instance();
/// Type fp32 = Float32Type::get(ctx);
/// \endcode
/// ///
class Float32Type : public ir::Type { class Float32Type : public ir::Type {
public: public:
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect.h"
namespace ir {
Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id)
: name_(std::move(name)), context_(context), id_(id) {}
void Dialect::RegisterType(ir::AbstractType &&abstract_type) {
ir::AbstractType *new_abstract_type =
new ir::AbstractType(std::move(abstract_type));
this->ir_context()->RegisterAbstractType(new_abstract_type->type_id(),
new_abstract_type);
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/ir_context.h"
#include "paddle/ir/type_base.h"
namespace ir {
///
/// \brief Dialect can basically be understood as a namespace. In Dialect, we
/// can define a series of types, operations, etc. An instance of the dialect
/// object will be loaded into the global IrContext. Specific compilers only
/// need to combine existing dialects and add their own extensions or
/// customizations.
///
class Dialect {
public:
Dialect(std::string name, ir::IrContext *context, ir::TypeId id);
const std::string &name() const { return name_; }
ir::IrContext *ir_context() const { return context_; }
ir::TypeId id() const { return id_; }
///
/// \brief Register all types contained in the template parameter Args.
/// To register only one Type, you can use the RegisterType template function.
///
template <typename... Args>
void RegisterTypes() {
(void)std::initializer_list<int>{0, (RegisterType<Args>(), 0)...};
}
///
/// \brief Register type of class T.
///
template <typename T>
void RegisterType() {
VLOG(4) << "Type registered into Dialect. --->";
ir::AbstractType *abstract_type =
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
abstract_type);
ir::TypeManager::RegisterType<T>(this->ir_context());
VLOG(4) << "----------------------------------";
}
///
/// \brief Register abstract_type into context.
/// NOTE: It's not recommended to use this interface directly. This interface
/// only registers abstract_type. To register TypeStorage into context, you
/// need to call ir::TypeManager::RegisterType<T>() additionally,
/// RegisterType<T>() is recommended to use.
///
void RegisterType(ir::AbstractType &&abstract_type);
private:
std::string name_;
ir::IrContext *context_; // not owned
ir::TypeId id_;
};
} // namespace ir
...@@ -12,61 +12,95 @@ ...@@ -12,61 +12,95 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/ir/ir_context.h"
#include <unordered_map> #include <unordered_map>
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h" #include "paddle/ir/builtin_type.h"
#include "paddle/ir/ir_context.h" #include "paddle/ir/dialect.h"
#include "paddle/ir/spin_lock.h" #include "paddle/ir/spin_lock.h"
#include "paddle/ir/type_base.h" #include "paddle/ir/type_base.h"
namespace ir { namespace ir {
// The implementation class of the IrContext class // The implementation class of the IrContext class, cache registered
// AbstractType, TypeStorage, Dialect.
class IrContextImpl { class IrContextImpl {
public: public:
IrContextImpl() {} IrContextImpl() {}
~IrContextImpl() { ~IrContextImpl() {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_); std::lock_guard<ir::SpinLock> guard(destructor_lock_);
for (auto abstract_type_map : registed_abstract_types_) { for (auto &abstract_type_map : registed_abstract_types_) {
delete abstract_type_map.second; delete abstract_type_map.second;
} }
registed_abstract_types_.clear(); registed_abstract_types_.clear();
for (auto &dialect_map : registed_dialect_) {
delete dialect_map.second;
}
registed_dialect_.clear();
} }
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) { void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_); std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
VLOG(4) << "IrContext register an abstract_type of: [TypeId_hash=" VLOG(4) << "Register an abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << std::hash<ir::TypeId>()(type_id)
<< ", AbstractType_ptr=" << abstract_type << "]."; << ", AbstractType_ptr=" << abstract_type << "].";
registed_abstract_types_.emplace(type_id, abstract_type); registed_abstract_types_.emplace(type_id, abstract_type);
} }
AbstractType *lookup(ir::TypeId type_id) { AbstractType *GetAbstractType(ir::TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_); std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
auto iter = registed_abstract_types_.find(type_id); auto iter = registed_abstract_types_.find(type_id);
if (iter == registed_abstract_types_.end()) { if (iter != registed_abstract_types_.end()) {
VLOG(4) << "IrContext not fonund cached abstract_type of: [TypeId_hash=" VLOG(4) << "Fonund a cached abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
return nullptr;
} else {
VLOG(4) << "IrContext fonund a cached abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << std::hash<ir::TypeId>()(type_id)
<< ", AbstractType_ptr=" << iter->second << "]."; << ", AbstractType_ptr=" << iter->second << "].";
return iter->second; return iter->second;
} }
LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
return nullptr;
} }
ir::SpinLock registed_abstract_types_lock_; void RegisterDialect(std::string name, Dialect *dialect) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
VLOG(4) << "Register a dialect of: [name=" << name
<< ", dialect_ptr=" << dialect << "].";
registed_dialect_.emplace(name, dialect);
}
Dialect *GetDialect(std::string name) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
auto iter = registed_dialect_.find(name);
if (iter != registed_dialect_.end()) {
VLOG(4) << "Fonund a cached dialect of: [name=" << name
<< ", dialect_ptr=" << iter->second << "].";
return iter->second;
}
LOG(WARNING) << "No cache fonund dialect of: [name=" << name << "].";
return nullptr;
}
// Cached AbstractType instances. // Cached AbstractType instances.
std::unordered_map<TypeId, AbstractType *> registed_abstract_types_; std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
ir::SpinLock registed_abstract_types_lock_;
// TypeStorage uniquer and cache instances. // TypeStorage uniquer and cache instances.
StorageManager registed_storage_manager_; StorageManager registed_storage_manager_;
// Some built-in type. // The dialcet registered in the context.
std::unordered_map<std::string, Dialect *> registed_dialect_;
ir::SpinLock registed_dialect_lock_;
// Some built-in types.
Float32Type fp32_type; Float32Type fp32_type;
Int32Type int32_type; Int32Type int32_type;
ir::SpinLock destructor_lock_;
}; };
IrContext *IrContext::Instance() { IrContext *IrContext::Instance() {
...@@ -75,13 +109,12 @@ IrContext *IrContext::Instance() { ...@@ -75,13 +109,12 @@ IrContext *IrContext::Instance() {
} }
IrContext::IrContext() : impl_(new IrContextImpl()) { IrContext::IrContext() : impl_(new IrContextImpl()) {
VLOG(4) << "IrContext register built-in type..."; VLOG(4) << "BuiltinDialect registered into IrContext. ===>";
REGISTER_TYPE_2_IRCONTEXT(Float32Type, this); GetOrRegisterDialect<BuiltinDialect>();
VLOG(4) << "==============================================";
impl_->fp32_type = TypeManager::get<Float32Type>(this); impl_->fp32_type = TypeManager::get<Float32Type>(this);
VLOG(4) << "Float32Type registration complete";
REGISTER_TYPE_2_IRCONTEXT(Int32Type, this);
impl_->int32_type = TypeManager::get<Int32Type>(this); impl_->int32_type = TypeManager::get<Int32Type>(this);
VLOG(4) << "Int32Type registration complete";
} }
void IrContext::RegisterAbstractType(ir::TypeId type_id, void IrContext::RegisterAbstractType(ir::TypeId type_id,
...@@ -98,12 +131,41 @@ std::unordered_map<TypeId, AbstractType *> ...@@ -98,12 +131,41 @@ std::unordered_map<TypeId, AbstractType *>
return impl().registed_abstract_types_; return impl().registed_abstract_types_;
} }
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { Dialect *IrContext::GetOrRegisterDialect(
VLOG(4) << "Lookup abstract type [TypeId_hash=" std::string dialect_name, std::function<Dialect *()> constructor) {
<< std::hash<ir::TypeId>()(type_id) << "] from IrContext [ptr=" << ctx VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
<< "]."; << "].";
Dialect *dialect = impl().GetDialect(dialect_name);
if (dialect == nullptr) {
VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
<< "].";
dialect = constructor();
impl().RegisterDialect(dialect_name, dialect);
}
return dialect;
}
std::vector<Dialect *> IrContext::GetRegisteredDialects() {
std::vector<Dialect *> result;
for (auto dialect_map : impl().registed_dialect_) {
result.push_back(dialect_map.second);
}
return result;
}
Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
for (auto dialect_map : impl().registed_dialect_) {
if (dialect_map.first == dialect_name) {
return dialect_map.second;
}
}
LOG(WARNING) << "No dialect registered for " << dialect_name;
return nullptr;
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
auto &impl = ctx->impl(); auto &impl = ctx->impl();
AbstractType *abstract_type = impl.lookup(type_id); AbstractType *abstract_type = impl.GetAbstractType(type_id);
if (abstract_type) { if (abstract_type) {
return *abstract_type; return *abstract_type;
} else { } else {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <glog/logging.h> #include <glog/logging.h>
#include <functional>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
...@@ -23,6 +24,7 @@ class IrContextImpl; ...@@ -23,6 +24,7 @@ class IrContextImpl;
class StorageManager; class StorageManager;
class AbstractType; class AbstractType;
class TypeId; class TypeId;
class Dialect;
/// ///
/// \brief IrContext is a global parameterless class used to store and manage /// \brief IrContext is a global parameterless class used to store and manage
...@@ -69,6 +71,63 @@ class IrContext { ...@@ -69,6 +71,63 @@ class IrContext {
/// ///
std::unordered_map<TypeId, AbstractType *> &registed_abstracted_type(); std::unordered_map<TypeId, AbstractType *> &registed_abstracted_type();
///
/// \brief Get the dialect of the DialectT class in the context, ff not found,
/// create and register to context.
///
/// \param DialectT The Dialect class that needs to be found or register.
///
/// \return The dialect of the DialectT class in the context.
///
template <typename DialectT>
DialectT *GetOrRegisterDialect() {
return static_cast<DialectT *>(
GetOrRegisterDialect(DialectT::name(), [this]() {
DialectT *dialect = new DialectT(this);
return dialect;
}));
}
///
/// \brief Get the dialect of the DialectT class in the context, ff not found,
/// create and register to context.
///
/// \param dialect_name The dialect name.
/// \param dialect_id The TypeId of the dialect.
/// \param constructor The dialect constructor.
///
/// \return The dialect named "dialect_name" in the context.
///
Dialect *GetOrRegisterDialect(std::string dialect_name,
std::function<Dialect *()> constructor);
///
/// \brief Get the dialect list registered to the context.
///
/// \return The dialect list registered to the context.
///
std::vector<Dialect *> GetRegisteredDialects();
///
/// \brief Get the dialect named "name" from the context.
///
/// \param name The name of the dialect to be obtained.
///
/// \return The dialect named "name" from the context.
///
Dialect *GetRegisteredDialect(const std::string &dialect_name);
///
/// \brief Get a registered dialect for the given dialect type T. The
/// Dialect must provide a static 'name' method.
///
/// \return The registered dialect for the given dialect type T.
///
template <typename T>
T *GetRegisteredDialect() {
return static_cast<T *>(GetRegisteredDialect(T::name()));
}
IrContext(const IrContext &) = delete; IrContext(const IrContext &) = delete;
void operator=(const IrContext &) = delete; void operator=(const IrContext &) = delete;
......
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/ir/storage_manager.h"
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "paddle/ir/storage_manager.h"
namespace ir { namespace ir {
// This is a structure for creating, caching, and looking up Storage of // This is a structure for creating, caching, and looking up Storage of
// parameteric types. // parameteric types.
...@@ -72,7 +72,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl( ...@@ -72,7 +72,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl(
std::function<bool(const StorageBase *)> equal_func, std::function<bool(const StorageBase *)> equal_func,
std::function<StorageBase *()> constructor) { std::function<StorageBase *()> constructor) {
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_); std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
VLOG(4) << "StorageManager get parameteretric storage of: [TypeId_hash=" VLOG(4) << "Try to get a parameteretric storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value << std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
<< "]."; << "].";
if (parametric_instance_.find(type_id) == parametric_instance_.end()) if (parametric_instance_.find(type_id) == parametric_instance_.end())
...@@ -83,18 +83,18 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl( ...@@ -83,18 +83,18 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl(
StorageManager::StorageBase *StorageManager::GetParameterlessStorageTypeImpl( StorageManager::StorageBase *StorageManager::GetParameterlessStorageTypeImpl(
TypeId type_id) { TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(parameterless_instances_lock_); std::lock_guard<ir::SpinLock> guard(parameterless_instance_lock_);
VLOG(4) << "StorageManager get parameterless storage of: [TypeId_hash=" VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "]."; << std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instances_.find(type_id) == parameterless_instances_.end()) if (parameterless_instance_.find(type_id) == parameterless_instance_.end())
throw("TypeId not found in IrContext."); throw("TypeId not found in IrContext.");
StorageBase *parameterless_instance = parameterless_instances_[type_id]; StorageBase *parameterless_instance = parameterless_instance_[type_id];
return parameterless_instance; return parameterless_instance;
} }
void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) { void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_); std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
VLOG(4) << "StorageManager register parameteric storage of: [TypeId_hash=" VLOG(4) << "Register a parameteric storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "]."; << std::hash<ir::TypeId>()(type_id) << "].";
parametric_instance_.emplace(type_id, parametric_instance_.emplace(type_id,
std::make_unique<ParametricStorageManager>()); std::make_unique<ParametricStorageManager>());
...@@ -102,12 +102,12 @@ void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) { ...@@ -102,12 +102,12 @@ void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) {
void StorageManager::RegisterParameterlessStorageTypeImpl( void StorageManager::RegisterParameterlessStorageTypeImpl(
TypeId type_id, std::function<StorageBase *()> constructor) { TypeId type_id, std::function<StorageBase *()> constructor) {
std::lock_guard<ir::SpinLock> guard(parameterless_instances_lock_); std::lock_guard<ir::SpinLock> guard(parameterless_instance_lock_);
VLOG(4) << "StorageManager register parameterless storage of: [TypeId_hash=" VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "]."; << std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instances_.find(type_id) != parameterless_instances_.end()) if (parameterless_instance_.find(type_id) != parameterless_instance_.end())
throw("storage class already registered"); throw("storage class already registered");
parameterless_instances_.emplace(type_id, constructor()); parameterless_instance_.emplace(type_id, constructor());
} }
} // namespace ir } // namespace ir
...@@ -141,9 +141,9 @@ class StorageManager { ...@@ -141,9 +141,9 @@ class StorageManager {
ir::SpinLock parametric_instance_lock_; ir::SpinLock parametric_instance_lock_;
// This map is a mapping between type id and parameterless type storage. // This map is a mapping between type id and parameterless type storage.
std::unordered_map<TypeId, StorageBase *> parameterless_instances_; std::unordered_map<TypeId, StorageBase *> parameterless_instance_;
ir::SpinLock parameterless_instances_lock_; ir::SpinLock parameterless_instance_lock_;
}; };
} // namespace ir } // namespace ir
...@@ -15,22 +15,27 @@ ...@@ -15,22 +15,27 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <unordered_map> #include <unordered_map>
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h" #include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/ir_context.h" #include "paddle/ir/ir_context.h"
#include "paddle/ir/type.h"
#include "paddle/ir/type_base.h" #include "paddle/ir/type_base.h"
TEST(type_test, type_id) { TEST(type_test, type_id) {
// Define two empty classes, just for testing.
class TypeA {}; class TypeA {};
class TypeB {}; class TypeB {};
// (1) Test construct TypeId by TypeId::Get() // Test 1: Test construct TypeId by TypeId::get<T>() and overloaded operator==
// method.
ir::TypeId a_id = ir::TypeId::get<TypeA>(); ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::TypeId a_other_id = ir::TypeId::get<TypeA>(); ir::TypeId a_other_id = ir::TypeId::get<TypeA>();
ir::TypeId b_id = ir::TypeId::get<TypeB>(); ir::TypeId b_id = ir::TypeId::get<TypeB>();
EXPECT_EQ(a_id, a_other_id); EXPECT_EQ(a_id, a_other_id);
EXPECT_NE(a_id, b_id); EXPECT_NE(a_id, b_id);
// (2) Test TypeId hash // Test 2: Test the hash function of TypeId.
std::unordered_map<ir::TypeId, ir::TypeId *> type_id_register; std::unordered_map<ir::TypeId, ir::TypeId *> type_id_register;
type_id_register.emplace(a_id, &a_id); type_id_register.emplace(a_id, &a_id);
type_id_register.emplace(b_id, &b_id); type_id_register.emplace(b_id, &b_id);
...@@ -39,32 +44,38 @@ TEST(type_test, type_id) { ...@@ -39,32 +44,38 @@ TEST(type_test, type_id) {
} }
} }
TEST(type_test, abstract_type) { TEST(type_test, type_base) {
// Define two empty classes, just for testing.
class TypeA {}; class TypeA {};
ir::TypeId a_id = ir::TypeId::get<TypeA>(); // Define a FakeDialect without registering any types.
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id); struct FakeDialect : ir::Dialect {
explicit FakeDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<FakeDialect>()) {}
static const char *name() { return "fake"; }
};
EXPECT_EQ(abstract_type_a.type_id(), a_id); // Test 1: Test the function of IrContext to register Dialect.
} ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *fake_dialect = ctx->GetOrRegisterDialect<FakeDialect>();
TEST(type_test, type_storage) {
class TypeA {};
// Test 2: Test the get method of AbstractType.
ir::TypeId a_id = ir::TypeId::get<TypeA>(); ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id); ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id, *fake_dialect);
EXPECT_EQ(abstract_type_a.type_id(), a_id);
// Test 3: Test the constructor of TypeStorage.
ir::TypeStorage storage_a(&abstract_type_a); ir::TypeStorage storage_a(&abstract_type_a);
EXPECT_EQ(storage_a.abstract_type().type_id(), abstract_type_a.type_id()); EXPECT_EQ(storage_a.abstract_type().type_id(), abstract_type_a.type_id());
} }
TEST(type_test, built_in_type) { TEST(type_test, built_in_type) {
// Test creation of built-in parameterless type. // Test 1: Test the built-in type of IrContext.
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::Type fp32_1 = ir::Float32Type::get(ctx); ir::Type fp32_1 = ir::Float32Type::get(ctx);
// Test interfaces of class Type // Test 2: Test the interfaces of class Type: judgment, type_id,
// abstract_type, classof.
ir::Type fp32_2 = ir::Float32Type::get(ctx); ir::Type fp32_2 = ir::Float32Type::get(ctx);
EXPECT_EQ(fp32_1 == fp32_2, 1); EXPECT_EQ(fp32_1 == fp32_2, 1);
EXPECT_EQ(fp32_1 != fp32_2, 0); EXPECT_EQ(fp32_1 != fp32_2, 0);
...@@ -84,6 +95,7 @@ TEST(type_test, built_in_type) { ...@@ -84,6 +95,7 @@ TEST(type_test, built_in_type) {
EXPECT_EQ(ir::Int32Type::classof(int32_1), 1); EXPECT_EQ(ir::Int32Type::classof(int32_1), 1);
} }
// Customize a parameterized TypeStorage IntegerTypeStorage.
struct IntegerTypeStorage : public ir::TypeStorage { struct IntegerTypeStorage : public ir::TypeStorage {
IntegerTypeStorage(unsigned width, unsigned signedness) IntegerTypeStorage(unsigned width, unsigned signedness)
: width_(width), signedness_(signedness) {} : width_(width), signedness_(signedness) {}
...@@ -113,19 +125,50 @@ struct IntegerTypeStorage : public ir::TypeStorage { ...@@ -113,19 +125,50 @@ struct IntegerTypeStorage : public ir::TypeStorage {
} }
}; };
// Customize a parameterized type: IntegerType, storage type is
// IntegerTypeStorage.
class IntegerType : public ir::Type { class IntegerType : public ir::Type {
public: public:
using Type::Type; using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(IntegerType, IntegerTypeStorage); DECLARE_TYPE_UTILITY_FUNCTOR(IntegerType, IntegerTypeStorage);
}; };
TEST(type_test, parameteric_type) { // Customize a Dialect IntegerDialect, registration type of IntegerType.
struct IntegerDialect : ir::Dialect {
explicit IntegerDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<IntegerDialect>()) {
RegisterType<IntegerType>();
}
static const char *name() { return "integer"; }
};
TEST(type_test, custom_type_dialect) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
REGISTER_TYPE_2_IRCONTEXT(IntegerType, ctx);
// Test 1: Test the function of IrContext to register Dialect.
ctx->GetOrRegisterDialect<IntegerDialect>();
ir::Type int1_1 = IntegerType::get(ctx, 1, 0); ir::Type int1_1 = IntegerType::get(ctx, 1, 0);
ir::Type int1_2 = IntegerType::get(ctx, 1, 0); ir::Type int1_2 = IntegerType::get(ctx, 1, 0);
EXPECT_EQ(int1_1 == int1_2, 1); EXPECT_EQ(int1_1 == int1_2, 1);
ir::Type int8 = IntegerType::get(ctx, 8, 0); ir::Type int8 = IntegerType::get(ctx, 8, 0);
EXPECT_EQ(int8 == int1_2, 0); EXPECT_EQ(int8 == int1_2, 0);
// Test 2: Test Dialect interfaces
EXPECT_EQ(ctx == int8.ir_context(), 1);
EXPECT_EQ(int8.dialect().id() == ir::TypeId::get<IntegerDialect>(), 1);
std::vector<ir::Dialect *> dialect_list = ctx->GetRegisteredDialects();
EXPECT_EQ(dialect_list.size() == 3, 1); // integer, builtin, fake
ir::Dialect *dialect_builtin1 = ctx->GetRegisteredDialect("builtin");
ir::Dialect *dialect_builtin2 =
ctx->GetRegisteredDialect<ir::BuiltinDialect>();
EXPECT_EQ(dialect_builtin1 == dialect_builtin2, 1);
ir::Dialect *dialect_integer1 = ctx->GetRegisteredDialect("integer");
ir::Dialect *dialect_integer2 = ctx->GetRegisteredDialect<IntegerDialect>();
EXPECT_EQ(dialect_integer1 == dialect_integer2, 1);
} }
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/type.h"
#include "paddle/ir/dialect.h"
namespace ir {
IrContext *Type::ir_context() const { return dialect().ir_context(); }
} // namespace ir
...@@ -52,6 +52,10 @@ class Type { ...@@ -52,6 +52,10 @@ class Type {
StorageType *storage() const { return storage_; } StorageType *storage() const { return storage_; }
const Dialect &dialect() const { return storage_->abstract_type().dialect(); }
IrContext *ir_context() const;
/// ///
/// \brief Enable hashing Type. /// \brief Enable hashing Type.
/// ///
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include "paddle/ir/type_id.h" #include "paddle/ir/type_id.h"
namespace ir { namespace ir {
class Dialect;
/// ///
/// \brief Abstract the properties and behaviors common to all Type classes into /// \brief Abstract the properties and behaviors common to all Type classes into
/// an AbstractType class. There are two types in Type system: /// an AbstractType class. There are two types in Type system:
...@@ -32,8 +34,21 @@ class AbstractType { ...@@ -32,8 +34,21 @@ class AbstractType {
/// \brief Construct an AbstractType by TypeId directly. /// \brief Construct an AbstractType by TypeId directly.
/// ///
/// \param type_id The type id of the AbstractType. /// \param type_id The type id of the AbstractType.
/// \param dialect The Dialect which the type registered to.
///
static AbstractType get(TypeId type_id, const Dialect &dialect) {
return AbstractType(type_id, dialect);
}
/// ///
static AbstractType get(TypeId type_id) { return AbstractType(type_id); } /// \brief Construct an AbstractType by TypeId directly.
///
/// \param dialect The Dialect which the type registered to.
///
template <typename T>
static AbstractType get(const Dialect &dialect) {
return AbstractType(TypeId::get<T>(), dialect);
}
/// ///
/// \brief Returns the type id of the AbstractType. /// \brief Returns the type id of the AbstractType.
...@@ -42,6 +57,13 @@ class AbstractType { ...@@ -42,6 +57,13 @@ class AbstractType {
/// ///
TypeId type_id() const { return type_id_; } TypeId type_id() const { return type_id_; }
///
/// \brief Get the dialect this type was registered to.
///
/// \return The dialect this type was registered to.
///
const Dialect &dialect() const { return dialect_; }
/// ///
/// \brief Find the AbstractType instance whose TypeId is type_id from /// \brief Find the AbstractType instance whose TypeId is type_id from
/// IrContext. /// IrContext.
...@@ -58,10 +80,14 @@ class AbstractType { ...@@ -58,10 +80,14 @@ class AbstractType {
/// get method to obtain and manage the AstractType. /// get method to obtain and manage the AstractType.
/// ///
/// \param type_id The type id of the AbstractType. /// \param type_id The type id of the AbstractType.
/// \param dialect The Dialect which the type registered to.
/// ///
explicit AbstractType(TypeId type_id) : type_id_(type_id) {} explicit AbstractType(TypeId type_id, const Dialect &dialect)
: type_id_(type_id), dialect_(dialect) {}
TypeId type_id_; TypeId type_id_;
const Dialect &dialect_;
}; };
struct TypeManager; struct TypeManager;
...@@ -239,13 +265,13 @@ struct TypeManager { ...@@ -239,13 +265,13 @@ struct TypeManager {
/// ///
/// \brief This macro definition is used to register custom Type class. /// \brief This macro definition is used to register custom Type class.
/// ///
#define REGISTER_TYPE_2_IRCONTEXT(concrete_type, ir_context) \ #define REGISTER_TYPE_2_IRCONTEXT(concrete_type, dialect) \
ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \ ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \
std::move(ir::AbstractType::get(ir::TypeId::get<concrete_type>()))); \ std::move(ir::AbstractType::get<concrete_type>(*dialect))); \
\ \
ir_context->RegisterAbstractType(ir::TypeId::get<concrete_type>(), \ dialect->ir_context()->RegisterAbstractType( \
abstract_type_##concrete_type); \ ir::TypeId::get<concrete_type>(), abstract_type_##concrete_type); \
\ \
ir::TypeManager::RegisterType<concrete_type>(ir_context); ir::TypeManager::RegisterType<concrete_type>(dialect->ir_context());
} // namespace ir } // namespace ir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册