diff --git a/paddle/ir/CMakeLists.txt b/paddle/ir/CMakeLists.txt index d104ae066b59cd1ad353972ea9ea2a42524efd2d..5e6af70335a1077b296bd81066923734c20b57e2 100644 --- a/paddle/ir/CMakeLists.txt +++ b/paddle/ir/CMakeLists.txt @@ -2,4 +2,12 @@ if(NOT WITH_NEWIR) return() endif() -add_subdirectory(type) +set(NEWIR_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/ir") +set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/ir") + +# ir tests +add_subdirectory(tests) + +file(GLOB IR_SRCS "*.cc") + +cc_library(new_ir SRCS ${IR_SRCS}) diff --git a/paddle/ir/builtin_type.h b/paddle/ir/builtin_type.h new file mode 100644 index 0000000000000000000000000000000000000000..77159794bf11ed5ce977d89defeaa733236e8af4 --- /dev/null +++ b/paddle/ir/builtin_type.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/type.h" + +namespace ir { +/// +/// \brief Definitions of built-in type classes. The built-in type object get +/// method is as follows: Type fp32 = Float32Type::get(ctx); +/// +class Float32Type : public ir::Type { + public: + using Type::Type; + + DECLARE_TYPE_UTILITY_FUNCTOR(Float32Type, ir::TypeStorage); + + static Float32Type get(ir::IrContext *context); +}; + +class Int32Type : public ir::Type { + public: + using Type::Type; + + DECLARE_TYPE_UTILITY_FUNCTOR(Int32Type, ir::TypeStorage); + + static Int32Type get(ir::IrContext *context); +}; + +} // namespace ir diff --git a/paddle/ir/ir_context.cc b/paddle/ir/ir_context.cc new file mode 100644 index 0000000000000000000000000000000000000000..6aed600903438eabe7542498083de8d0008c1225 --- /dev/null +++ b/paddle/ir/ir_context.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/ir/builtin_type.h" +#include "paddle/ir/ir_context.h" +#include "paddle/ir/spin_lock.h" +#include "paddle/ir/type_base.h" + +namespace ir { +// The implementation class of the IrContext class +class IrContextImpl { + public: + IrContextImpl() {} + + ~IrContextImpl() { + std::lock_guard guard(registed_abstract_types_lock_); + for (auto abstract_type_map : registed_abstract_types_) { + delete abstract_type_map.second; + } + registed_abstract_types_.clear(); + } + + void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) { + std::lock_guard guard(registed_abstract_types_lock_); + VLOG(4) << "IrContext register an abstract_type of: [TypeId_hash=" + << std::hash()(type_id) + << ", AbstractType_ptr=" << abstract_type << "]."; + registed_abstract_types_.emplace(type_id, abstract_type); + } + + AbstractType *lookup(ir::TypeId type_id) { + std::lock_guard guard(registed_abstract_types_lock_); + auto iter = registed_abstract_types_.find(type_id); + if (iter == registed_abstract_types_.end()) { + VLOG(4) << "IrContext not fonund cached abstract_type of: [TypeId_hash=" + << std::hash()(type_id) << "]."; + return nullptr; + } else { + VLOG(4) << "IrContext fonund a cached abstract_type of: [TypeId_hash=" + << std::hash()(type_id) + << ", AbstractType_ptr=" << iter->second << "]."; + return iter->second; + } + } + + ir::SpinLock registed_abstract_types_lock_; + + // Cached AbstractType instances. + std::unordered_map registed_abstract_types_; + + // TypeStorage uniquer and cache instances. + StorageManager registed_storage_manager_; + + // Some built-in type. + Float32Type fp32_type; + Int32Type int32_type; +}; + +IrContext *IrContext::Instance() { + static IrContext context; + return &context; +} + +IrContext::IrContext() : impl_(new IrContextImpl()) { + VLOG(4) << "IrContext register built-in type..."; + REGISTER_TYPE_2_IRCONTEXT(Float32Type, this); + impl_->fp32_type = TypeManager::get(this); + VLOG(4) << "Float32Type registration complete"; + REGISTER_TYPE_2_IRCONTEXT(Int32Type, this); + impl_->int32_type = TypeManager::get(this); + VLOG(4) << "Int32Type registration complete"; +} + +void IrContext::RegisterAbstractType(ir::TypeId type_id, + AbstractType *abstract_type) { + impl().RegisterAbstractType(type_id, abstract_type); +} + +StorageManager &IrContext::storage_manager() { + return impl().registed_storage_manager_; +} + +std::unordered_map + &IrContext::registed_abstracted_type() { + return impl().registed_abstract_types_; +} + +const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { + VLOG(4) << "Lookup abstract type [TypeId_hash=" + << std::hash()(type_id) << "] from IrContext [ptr=" << ctx + << "]."; + auto &impl = ctx->impl(); + AbstractType *abstract_type = impl.lookup(type_id); + if (abstract_type) { + return *abstract_type; + } else { + throw("Abstract type not found in IrContext."); + } +} + +Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; } + +Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; } + +} // namespace ir diff --git a/paddle/ir/ir_context.h b/paddle/ir/ir_context.h new file mode 100644 index 0000000000000000000000000000000000000000..146497e6c6c94974a26ffc1c1238aff9e30b7a2f --- /dev/null +++ b/paddle/ir/ir_context.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace ir { +class IrContextImpl; +class StorageManager; +class AbstractType; +class TypeId; + +/// +/// \brief IrContext is a global parameterless class used to store and manage +/// Type and its related data structures. +/// +class IrContext { + public: + /// + /// \brief Initializes a new instance of IrContext. + /// + static IrContext *Instance(); + + /// + /// \brief Get an instance of IrContextImpl, a private member of IrContext. + /// For the specific definition of IrContextImpl, see ir_context.cc. + /// + /// \return The instance of IrContextImpl. + /// + IrContextImpl &impl() { return *impl_; } + + /// + /// \brief Register an AbstractType to IrContext + /// + /// \param type_id The type id of the AbstractType. + /// \param abstract_type AbstractType* provided by user. + /// + void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type); + + /// + /// \brief Returns the storage uniquer used for constructing TypeStorage + /// instances. + /// + /// \return The storage uniquer used for constructing TypeStorage + /// instances. + /// + StorageManager &storage_manager(); + + /// + /// \brief Returns the storage uniquer used for constructing TypeStorage + /// instances. + /// + /// \return The storage uniquer used for constructing TypeStorage + /// instances. + /// + std::unordered_map ®isted_abstracted_type(); + + IrContext(const IrContext &) = delete; + + void operator=(const IrContext &) = delete; + + private: + IrContext(); + + const std::unique_ptr impl_; +}; + +} // namespace ir diff --git a/paddle/ir/spin_lock.h b/paddle/ir/spin_lock.h new file mode 100644 index 0000000000000000000000000000000000000000..4150f419c31598b73b199382467dde0a99dce0f3 --- /dev/null +++ b/paddle/ir/spin_lock.h @@ -0,0 +1,66 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || \ + defined(__i386__) +#define __PADDLE_x86__ +#include +#endif +#include +#include + +namespace ir { +static inline void CpuRelax() { +#if defined(__PADDLE_x86__) + _mm_pause(); +#endif +} + +class SpinLock { + public: + SpinLock() : mlock_(false) {} + + void lock() { + for (;;) { + if (!mlock_.exchange(true, std::memory_order_acquire)) { + break; + } + constexpr int kMaxLoop = 32; + for (int loop = 1; mlock_.load(std::memory_order_relaxed);) { + if (loop <= kMaxLoop) { + for (int i = 1; i <= loop; ++i) { + CpuRelax(); + } + loop *= 2; + } else { + std::this_thread::yield(); + } + } + } + } + + void unlock() { mlock_.store(false, std::memory_order_release); } + + private: + SpinLock(const SpinLock&) = delete; + SpinLock(SpinLock&&) = delete; + SpinLock& operator=(const SpinLock&) = delete; + SpinLock& operator=(SpinLock&&) = delete; + std::atomic mlock_; +}; + +} // namespace ir diff --git a/paddle/ir/storage_manager.cc b/paddle/ir/storage_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..991077e8777c4632bdf3bb2f77476d9c5909884f --- /dev/null +++ b/paddle/ir/storage_manager.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/ir/storage_manager.h" + +namespace ir { +// This is a structure for creating, caching, and looking up Storage of +// parameteric types. +struct ParametricStorageManager { + using StorageBase = StorageManager::StorageBase; + + ParametricStorageManager() {} + + ~ParametricStorageManager() { + for (const auto &instance : parametric_instances_) { + delete instance.second; + } + parametric_instances_.clear(); + } + + // Get the storage of parametric type, if not in the cache, create and + // insert the cache. + StorageBase *GetOrCreate(std::size_t hash_value, + std::function equal_func, + std::function constructor) { + if (parametric_instances_.count(hash_value) != 0) { + auto pr = parametric_instances_.equal_range(hash_value); + while (pr.first != pr.second) { + if (equal_func(pr.first->second)) { + VLOG(4) << "Found a cached parameteric storage of: [param_hash=" + << hash_value << ", storage_ptr=" << pr.first->second << "]."; + return pr.first->second; + } + ++pr.first; + } + } + StorageBase *storage = constructor(); + parametric_instances_.emplace(hash_value, storage); + VLOG(4) << "No cache found, construct and cache a new parameteric storage " + "of: [param_hash=" + << hash_value << ", storage_ptr=" << storage << "]."; + return storage; + } + + private: + // In order to prevent hash conflicts, the unordered_multimap data structure + // is used for storage. + std::unordered_multimap parametric_instances_; +}; + +StorageManager::StorageManager() {} + +StorageManager::~StorageManager() = default; + +StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl( + TypeId type_id, + std::size_t hash_value, + std::function equal_func, + std::function constructor) { + std::lock_guard guard(parametric_instance_lock_); + VLOG(4) << "StorageManager get parameteretric storage of: [TypeId_hash=" + << std::hash()(type_id) << ", param_hash=" << hash_value + << "]."; + if (parametric_instance_.find(type_id) == parametric_instance_.end()) + throw("The input data pointer is null."); + ParametricStorageManager ¶metric_storage = *parametric_instance_[type_id]; + return parametric_storage.GetOrCreate(hash_value, equal_func, constructor); +} + +StorageManager::StorageBase *StorageManager::GetParameterlessStorageTypeImpl( + TypeId type_id) { + std::lock_guard guard(parameterless_instances_lock_); + VLOG(4) << "StorageManager get parameterless storage of: [TypeId_hash=" + << std::hash()(type_id) << "]."; + if (parameterless_instances_.find(type_id) == parameterless_instances_.end()) + throw("TypeId not found in IrContext."); + StorageBase *parameterless_instance = parameterless_instances_[type_id]; + return parameterless_instance; +} + +void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) { + std::lock_guard guard(parametric_instance_lock_); + VLOG(4) << "StorageManager register parameteric storage of: [TypeId_hash=" + << std::hash()(type_id) << "]."; + parametric_instance_.emplace(type_id, + std::make_unique()); +} + +void StorageManager::RegisterParameterlessStorageTypeImpl( + TypeId type_id, std::function constructor) { + std::lock_guard guard(parameterless_instances_lock_); + VLOG(4) << "StorageManager register parameterless storage of: [TypeId_hash=" + << std::hash()(type_id) << "]."; + if (parameterless_instances_.find(type_id) != parameterless_instances_.end()) + throw("storage class already registered"); + parameterless_instances_.emplace(type_id, constructor()); +} + +} // namespace ir diff --git a/paddle/ir/storage_manager.h b/paddle/ir/storage_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..f94174586bc97e83cc52986b47237d006e5dd690 --- /dev/null +++ b/paddle/ir/storage_manager.h @@ -0,0 +1,149 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/ir/spin_lock.h" +#include "paddle/ir/type_id.h" + +namespace ir { +/// +/// \brief The implementation of the class StorageManager. +/// +// struct StorageManagerImpl; +struct ParametricStorageManager; + +/// +/// \brief A utility class for getting or creating Storage class instances. +/// Storage class must be a derived class of StorageManager::StorageBase. +/// There are two types of Storage class: +/// One is a parameterless type, which can directly obtain an instance through +/// the get method; The other is a parameteric type, which needs to comply with +/// the following conditions: (1) Need to define a type alias called ParamKey, +/// it serves as the unique identifier for the Storage class; (2) Need to +/// provide a hash method on the ParamKey for storage and access; (3) Need to +/// provide method 'bool operator==(const ParamKey &) const', used to compare +/// Storage instance and ParamKey instance. +/// +class StorageManager { + public: + /// + /// \brief This class is the base class of all storage classes, + /// and any type of storage needs to inherit from this class. + /// + class StorageBase { + protected: + StorageBase() = default; + }; + + StorageManager(); + + ~StorageManager(); + + /// + /// \brief Get a unique storage instance of parametric Type. + /// + /// \param init_func Used to initialize a newly inserted storage instance. + /// \param type_id The type id of the AbstractType. + /// \param args Parameters of the wrapped function. + /// \return A uniqued instance of Storage. + /// + template + Storage *GetParametricStorageType(std::function init_func, + TypeId type_id, + Args &&...args) { + typename Storage::ParamKey param = + typename Storage::ParamKey(std::forward(args)...); + std::size_t hash_value = Storage::HashValue(param); + auto equal_func = [¶m](const StorageBase *existing) { + return static_cast(*existing) == param; + }; + auto constructor = [&]() { + auto *storage = Storage::Construct(param); + if (init_func) init_func(storage); + return storage; + }; + return static_cast(GetParametricStorageTypeImpl( + type_id, hash_value, equal_func, constructor)); + } + + /// + /// \brief Get a unique storage instance of parameterless Type. + /// + /// \param type_id The type id of the AbstractType. + /// \return A uniqued instance of Storage. + /// + template + Storage *GetParameterlessStorageType(TypeId type_id) { + return static_cast(GetParameterlessStorageTypeImpl(type_id)); + } + + /// + /// \brief Register a new parametric storage class. + /// + /// \param type_id The type id of the AbstractType. + /// + template + void RegisterParametricStorageType(TypeId type_id) { + return RegisterParametricStorageTypeImpl(type_id); + } + + /// + /// \brief Register a new parameterless storage class. + /// + /// \param type_id The type id of the AbstractType. + /// \param init_func Used to initialize a newly inserted storage instance. + /// + template + void RegisterParameterlessStorageType( + TypeId type_id, std::function init_func) { + auto constructor = [&]() { + auto *storage = new Storage(); + if (init_func) init_func(storage); + return storage; + }; + RegisterParameterlessStorageTypeImpl(type_id, constructor); + } + + private: + StorageBase *GetParametricStorageTypeImpl( + TypeId type_id, + std::size_t hash_value, + std::function equal_func, + std::function constructor); + + StorageBase *GetParameterlessStorageTypeImpl(TypeId type_id); + + void RegisterParametricStorageTypeImpl(TypeId type_id); + + void RegisterParameterlessStorageTypeImpl( + TypeId type_id, std::function constructor); + + // This map is a mapping between type id and parameteric type storage. + std::unordered_map> + parametric_instance_; + + ir::SpinLock parametric_instance_lock_; + + // This map is a mapping between type id and parameterless type storage. + std::unordered_map parameterless_instances_; + + ir::SpinLock parameterless_instances_lock_; +}; + +} // namespace ir diff --git a/paddle/ir/tests/CMakeLists.txt b/paddle/ir/tests/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a1200519faff6fe9d7b44061b57bd9aca7590429 --- /dev/null +++ b/paddle/ir/tests/CMakeLists.txt @@ -0,0 +1 @@ +cc_test_old(type_test SRCS type_test.cc DEPS new_ir gtest) diff --git a/paddle/ir/tests/type_test.cc b/paddle/ir/tests/type_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..85deb51b694d586d233b0f4ffe57256263e3a5d0 --- /dev/null +++ b/paddle/ir/tests/type_test.cc @@ -0,0 +1,131 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/ir/builtin_type.h" +#include "paddle/ir/ir_context.h" +#include "paddle/ir/type_base.h" + +TEST(type_test, type_id) { + class TypeA {}; + class TypeB {}; + + // (1) Test construct TypeId by TypeId::Get() + ir::TypeId a_id = ir::TypeId::get(); + ir::TypeId a_other_id = ir::TypeId::get(); + ir::TypeId b_id = ir::TypeId::get(); + EXPECT_EQ(a_id, a_other_id); + EXPECT_NE(a_id, b_id); + + // (2) Test TypeId hash + std::unordered_map type_id_register; + type_id_register.emplace(a_id, &a_id); + type_id_register.emplace(b_id, &b_id); + for (auto kv : type_id_register) { + EXPECT_EQ(kv.first, *kv.second); + } +} + +TEST(type_test, abstract_type) { + class TypeA {}; + + ir::TypeId a_id = ir::TypeId::get(); + ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id); + + EXPECT_EQ(abstract_type_a.type_id(), a_id); +} + +TEST(type_test, type_storage) { + class TypeA {}; + + ir::TypeId a_id = ir::TypeId::get(); + ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id); + + ir::TypeStorage storage_a(&abstract_type_a); + + EXPECT_EQ(storage_a.abstract_type().type_id(), abstract_type_a.type_id()); +} + +TEST(type_test, built_in_type) { + // Test creation of built-in parameterless type. + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Type fp32_1 = ir::Float32Type::get(ctx); + + // Test interfaces of class Type + ir::Type fp32_2 = ir::Float32Type::get(ctx); + EXPECT_EQ(fp32_1 == fp32_2, 1); + EXPECT_EQ(fp32_1 != fp32_2, 0); + EXPECT_EQ(fp32_1.type_id() == fp32_2.type_id(), 1); + EXPECT_EQ(&fp32_1.abstract_type() == + &ir::AbstractType::lookup(fp32_1.type_id(), ctx), + 1); + EXPECT_EQ(ir::Float32Type::classof(fp32_1), 1); + + ir::Type int32_1 = ir::Int32Type::get(ctx); + ir::Type int32_2 = ir::Int32Type::get(ctx); + EXPECT_EQ(int32_1 == int32_2, 1); + EXPECT_EQ(int32_1.type_id() == int32_2.type_id(), 1); + EXPECT_EQ(&int32_1.abstract_type() == + &ir::AbstractType::lookup(int32_1.type_id(), ctx), + 1); + EXPECT_EQ(ir::Int32Type::classof(int32_1), 1); +} + +struct IntegerTypeStorage : public ir::TypeStorage { + IntegerTypeStorage(unsigned width, unsigned signedness) + : width_(width), signedness_(signedness) {} + using ParamKey = std::pair; + + static std::size_t HashValue(const ParamKey &key) { + return hash_combine(std::hash()(std::get<0>(key)), + std::hash()(std::get<1>(key))); + } + + bool operator==(const ParamKey &key) const { + return ParamKey(width_, signedness_) == key; + } + + static IntegerTypeStorage *Construct(ParamKey key) { + return new IntegerTypeStorage(key.first, key.second); + } + + ParamKey GetAsKey() const { return ParamKey(width_, signedness_); } + + unsigned width_ : 30; + unsigned signedness_ : 2; + + private: + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + } +}; + +class IntegerType : public ir::Type { + public: + using Type::Type; + DECLARE_TYPE_UTILITY_FUNCTOR(IntegerType, IntegerTypeStorage); +}; + +TEST(type_test, parameteric_type) { + ir::IrContext *ctx = ir::IrContext::Instance(); + REGISTER_TYPE_2_IRCONTEXT(IntegerType, ctx); + ir::Type int1_1 = IntegerType::get(ctx, 1, 0); + ir::Type int1_2 = IntegerType::get(ctx, 1, 0); + EXPECT_EQ(int1_1 == int1_2, 1); + + ir::Type int8 = IntegerType::get(ctx, 8, 0); + EXPECT_EQ(int8 == int1_2, 0); +} diff --git a/paddle/ir/type.h b/paddle/ir/type.h new file mode 100644 index 0000000000000000000000000000000000000000..c4c5663dda9db4baee548bf605529477b9c2b159 --- /dev/null +++ b/paddle/ir/type.h @@ -0,0 +1,76 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/type_base.h" + +namespace ir { +/// +/// \brief Unified interface of the Type class. Derivation of all Type classes +/// only derives interfaces, not members. For example, DenseTensorType, +/// Float32Type, etc. are all derived classes of Type, but no new member +/// variables will be added. +/// +class Type { + public: + using StorageType = TypeStorage; + + constexpr Type() = default; + + Type(const StorageType *storage) // NOLINT + : storage_(const_cast(storage)) {} + + Type(const Type &other) = default; + + Type &operator=(const Type &other) = default; + + /// + /// \brief Comparison operations. + /// + bool operator==(Type other) const { return storage_ == other.storage_; } + bool operator!=(Type other) const { return storage_ != other.storage_; } + + explicit operator bool() const { return storage_; } + + bool operator!() const { return storage_ == nullptr; } + + TypeId type_id() { return storage_->abstract_type().type_id(); } + + const AbstractType &abstract_type() { return storage_->abstract_type(); } + + StorageType *storage() const { return storage_; } + + /// + /// \brief Enable hashing Type. + /// + friend struct std::hash; + + protected: + StorageType *storage_{nullptr}; +}; + +} // namespace ir + +namespace std { +/// +/// \brief Enable hashing Type. +/// +template <> +struct hash { + std::size_t operator()(const ir::Type &obj) const { + return std::hash()(obj.storage_); + } +}; +} // namespace std diff --git a/paddle/ir/type/CMakeLists.txt b/paddle/ir/type/CMakeLists.txt deleted file mode 100644 index 649ac7d74b40768b7c65fa5b9421fce072d784dd..0000000000000000000000000000000000000000 --- a/paddle/ir/type/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -cc_test( - type_support_test - SRCS type_support_test.cc - DEPS gtest) diff --git a/paddle/ir/type/type_support.h b/paddle/ir/type/type_support.h deleted file mode 100644 index fc2a397768563bda0525168b02d1216cb612d6d6..0000000000000000000000000000000000000000 --- a/paddle/ir/type/type_support.h +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -namespace ir { - -/// \brief TypeId is the unique identification of Type, each Type corresponds to -/// a unique TypeId, the same id indicates the same Type class. TypeId provides -/// an instantiation interface: TypeId::get. -/// Example: -/// \code{cpp} -/// class TypeA {}; -/// TypeId type_a_id = TypeId::get(); -/// \endcode -class TypeId { - struct Storage {}; - - public: - /// \brief Returns the unique TypeId of Type T. - /// \return The unique TypeId of Type T. - template - static TypeId get() { - static Storage instance; - return TypeId(&instance); - } - - /// \brief Comparison operations. - inline bool operator==(const TypeId &other) const { - return storage_ == other.storage_; - } - - /// \brief Comparison operations. - inline bool operator!=(const TypeId &other) const { - return !(*this == other); - } - - /// \brief Enable hashing TypeId instances. - friend struct std::hash; - - private: - /// \brief Construct a TypeId and initialize storage. - /// \param storage The storage of this TypeId. - explicit TypeId(const Storage *storage) : storage_(storage) {} - - const Storage *storage_; -}; - -/// \brief Abstract the properties and behaviors common to all Type classes into -/// an AbstractType class. There are two types in Type system: -/// on-parameter/singleton type and parameter-type. The common attributes of all -/// types is TypeId (and possibly others). Therefore, construct a class with -/// TypeId as its member. -class AbstractType { - public: - /// \brief Construct an AbstractType by TypeId directly. - /// \param type_id The type id of the AbstractType. - static AbstractType get(TypeId type_id) { return AbstractType(type_id); } - - /// \brief Returns the type id of the AbstractType. - /// \return The type id of the AbstractType. - TypeId type_id() const { return type_id_; } - - /* TODO(zhangbo9674): After the IRContext is designed, AbstractType will be - * cached to IRContext with TypeId as key. - */ - - private: - /// \brief The constructor is set to private and provides the user with the - /// get method to obtain and manage the AstractType. - /// \param type_id The type id of the AbstractType. - explicit AbstractType(TypeId type_id) : type_id_(type_id) {} - - TypeId type_id_; -}; - -/// \brief TypeStorage is used to store all information of a Type. A Type object -/// contains a TypeStorage. For non-parameter type, the information includes: -/// TypeId, so TypeStorage only needs to include AbstractType; For parameter -/// type, in addition to AbstractType/TypeId, parameter information needs to be -/// included. So that, non-parameter type can be constructed by TypeStorage -/// directly but parameter type should be constructed by Derived TypeStorage. -class TypeStorage { - public: - /// \brief Construct a TypeStorage and initialize abstract_type. - /// \param abstract_type The abstract_type of this TypeStorage. - explicit TypeStorage(AbstractType *abstract_type) - : abstract_type_(abstract_type) {} - - /// \brief Returns the AbstractType of the TypeStorage. - /// \return The AbstractType of the TypeStorage. - const AbstractType &abstract_type() { return *abstract_type_; } - - private: - AbstractType *abstract_type_{nullptr}; -}; - -} // namespace ir - -// Custom specialization of std::hash can be injected in namespace std. -namespace std { -/// \brief Enable hashing TypeId instances. -template <> -struct hash { - std::size_t operator()(const ir::TypeId &obj) const { - return std::hash()(obj.storage_); - } -}; -} // namespace std diff --git a/paddle/ir/type/type_support_test.cc b/paddle/ir/type/type_support_test.cc deleted file mode 100644 index 061ef8e31c8b7cad899f83cc1977aa8083afd1b5..0000000000000000000000000000000000000000 --- a/paddle/ir/type/type_support_test.cc +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/ir/type/type_support.h" -#include -#include - -TEST(type_support, type_id) { - class TypeA {}; - class TypeB {}; - - // (1) Test construct TypeId by TypeId::Get() - ir::TypeId a_id = ir::TypeId::get(); - ir::TypeId a_other_id = ir::TypeId::get(); - ir::TypeId b_id = ir::TypeId::get(); - EXPECT_EQ(a_id, a_other_id); - EXPECT_NE(a_id, b_id); - - // (2) Test TypeId hash - std::unordered_map type_id_register; - type_id_register.emplace(a_id, &a_id); - type_id_register.emplace(b_id, &b_id); - for (auto kv : type_id_register) { - EXPECT_EQ(kv.first, *kv.second); - } -} - -TEST(type_support, abstract_type) { - class TypeA {}; - - ir::TypeId a_id = ir::TypeId::get(); - ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id); - - EXPECT_EQ(abstract_type_a.type_id(), a_id); -} - -TEST(type_support, type_storage) { - class TypeA {}; - - ir::TypeId a_id = ir::TypeId::get(); - ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id); - - ir::TypeStorage storage_a(&abstract_type_a); - - EXPECT_EQ(storage_a.abstract_type().type_id(), abstract_type_a.type_id()); -} diff --git a/paddle/ir/type_base.h b/paddle/ir/type_base.h new file mode 100644 index 0000000000000000000000000000000000000000..aa800498f6eadbf450d218c6fad2522e21c0360e --- /dev/null +++ b/paddle/ir/type_base.h @@ -0,0 +1,251 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/ir_context.h" +#include "paddle/ir/storage_manager.h" +#include "paddle/ir/type_id.h" + +namespace ir { +/// +/// \brief Abstract the properties and behaviors common to all Type classes into +/// an AbstractType class. There are two types in Type system: +/// on-parameter/parameterless type and parameter-type. The common attributes of +/// all types is TypeId (and possibly others). Therefore, construct a class with +/// TypeId as its member. +/// +class AbstractType { + public: + /// + /// \brief Construct an AbstractType by TypeId directly. + /// + /// \param type_id The type id of the AbstractType. + /// + static AbstractType get(TypeId type_id) { return AbstractType(type_id); } + + /// + /// \brief Returns the type id of the AbstractType. + /// + /// \return The type id of the AbstractType. + /// + TypeId type_id() const { return type_id_; } + + /// + /// \brief Find the AbstractType instance whose TypeId is type_id from + /// IrContext. + /// + /// \param type_id The type id of the AbstractType. + /// \param ctx The IrContext. + /// \return The AbstractType instance whose TypeId is type_id. + /// + static const AbstractType &lookup(TypeId type_id, IrContext *ctx); + + private: + /// + /// \brief The constructor is set to private and provides the user with the + /// get method to obtain and manage the AstractType. + /// + /// \param type_id The type id of the AbstractType. + /// + explicit AbstractType(TypeId type_id) : type_id_(type_id) {} + + TypeId type_id_; +}; + +struct TypeManager; + +/// +/// \brief TypeStorage is used to store all information of a Type. A Type object +/// contains a TypeStorage. For non-parameter type, the information includes: +/// TypeId, so TypeStorage only needs to include AbstractType; For parameter +/// type, in addition to AbstractType/TypeId, parameter information needs to be +/// included. So that, non-parameter type can be constructed by TypeStorage +/// directly but parameter type should be constructed by Derived TypeStorage. +/// +class TypeStorage : public StorageManager::StorageBase { + friend StorageManager; + friend TypeManager; + + public: + /// + /// \brief Construct a TypeStorage and initialize abstract_type. + /// + /// \param abstract_type The abstract_type of this TypeStorage. + /// + explicit TypeStorage(AbstractType *abstract_type) + : abstract_type_(abstract_type) {} + + TypeStorage() {} + + /// + /// \brief Returns the AbstractType of the TypeStorage. + /// + /// \return The AbstractType of the TypeStorage. + /// + const AbstractType &abstract_type() { return *abstract_type_; } + + private: + /// + /// \brief Initialize TypeStorage based on the AbstractType* provided by the + /// user + /// + /// \param abstract_type AbstractType* provided by the user, the + /// construction method of AbstractType refers to AbstractType::get. + /// + void initialize(const AbstractType &abstract_type) { + abstract_type_ = const_cast(&abstract_type); + } + + AbstractType *abstract_type_{nullptr}; // not owned +}; + +/// +/// \brief TypeManager is a utility class that provides interfaces for get or +/// unique Type instances in IrContext. +/// +struct TypeManager { + /// + /// \brief Get a unique instance of Type T from IrContext. Note: For a + /// parameteric_type, if not found in IrContext, it will try to create a new + /// instance and register it to IrContext; for a parameterless type, only + /// search. + /// + /// \param ctx The IrContext instance. + /// \param args Parameters of the wrapped function. + /// \return The unique instance of Type T from IrContext. + /// + template + static T get(IrContext *ctx, Args &&...args) { + return get( + ctx, ir::TypeId::get(), std::forward(args)...); + } + + /// + /// \brief Get a unique instance of parametric Type T from IrContext. If not + /// found in IrContext, it will try to create a new instance and register it + /// to IrContext; + /// + /// \param ctx The IrContext instance. + /// \param type_id The type id of the AbstractType. + /// \param args Parameters of the wrapped function. + /// \return The unique instance of Type T from IrContext. + /// + template + static std:: + enable_if_t::value, T> + get(IrContext *ctx, TypeId type_id, Args &&...args) { + return ctx->storage_manager() + .GetParametricStorageType( + [&, type_id](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(type_id, ctx)); + }, + type_id, + std::forward(args)...); + } + + /// + /// \brief Get a unique instance of parameterless Type T from IrContext, only + /// search. + /// + /// \param ctx The IrContext instance. + /// \param type_id The type id of the AbstractType. + /// \return The unique instance of Type T from IrContext. + /// + template + static std:: + enable_if_t::value, T> + get(IrContext *ctx, TypeId type_id) { + return ctx->storage_manager() + .GetParameterlessStorageType(type_id); + } + + /// + /// \brief Register a unique instance of Type T to IrContext. + /// + /// \param ctx The IrContext instance. + /// + template + static void RegisterType(IrContext *ctx) { + RegisterType(ctx, + ir::TypeId::get()); // class Type需要提供type_id接口 + } + + /// + /// \brief Register a unique instance of parametric Type T to IrContext. + /// + /// \param ctx The IrContext instance. + /// \param type_id The type id of the Type T. + /// + template + static std::enable_if_t< + !std::is_same::value> + RegisterType(IrContext *ctx, TypeId type_id) { + ctx->storage_manager() + .RegisterParametricStorageType(type_id); + } + + /// + /// \brief Register a unique instance of parameterless Type T to IrContext. + /// + /// \param ctx The IrContext instance. + /// \param type_id The type id of the Type T. + /// + template + static std::enable_if_t< + std::is_same::value> + RegisterType(IrContext *ctx, TypeId type_id) { + ctx->storage_manager().RegisterParameterlessStorageType( + type_id, [&ctx, type_id](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(type_id, ctx)); + }); + } +}; + +/// +/// \brief This macro definition is used to add some necessary functions to the +/// custom Type class. +/// +#define DECLARE_TYPE_UTILITY_FUNCTOR(concrete_type, storage_type) \ + using StorageType = storage_type; \ + \ + StorageType *storage() const { \ + return static_cast(this->storage_); \ + } \ + \ + static ir::TypeId type_id() { return ir::TypeId::get(); } \ + \ + template \ + static bool classof(T val) { \ + return val.type_id() == type_id(); \ + } \ + \ + template \ + static concrete_type get(ir::IrContext *ctx, Args... args) { \ + return ir::TypeManager::template get(ctx, args...); \ + } + +/// +/// \brief This macro definition is used to register custom Type class. +/// +#define REGISTER_TYPE_2_IRCONTEXT(concrete_type, ir_context) \ + ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \ + std::move(ir::AbstractType::get(ir::TypeId::get()))); \ + \ + ir_context->RegisterAbstractType(ir::TypeId::get(), \ + abstract_type_##concrete_type); \ + \ + ir::TypeManager::RegisterType(ir_context); + +} // namespace ir diff --git a/paddle/ir/type_id.h b/paddle/ir/type_id.h new file mode 100644 index 0000000000000000000000000000000000000000..b7a2dcd362d012eb200a3f909eb78f60a43e8cea --- /dev/null +++ b/paddle/ir/type_id.h @@ -0,0 +1,86 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace ir { + +/// +/// \brief TypeId is the unique identification of Type, each Type corresponds to +/// a unique TypeId, the same id indicates the same Type class. TypeId provides +/// an instantiation interface: TypeId::get. +/// +/// Example: +/// \code{cpp} +/// class TypeA {}; +/// TypeId type_a_id = TypeId::get(); +/// \endcode +/// +class TypeId { + struct Storage {}; + + public: + /// + /// \brief Returns the unique TypeId of Type T. + /// + /// \return The unique TypeId of Type T. + /// + template + static TypeId get() { + static Storage instance; + return TypeId(&instance); + } + + /// + /// \brief Comparison operations. + /// + inline bool operator==(const TypeId &other) const { + return storage_ == other.storage_; + } + inline bool operator!=(const TypeId &other) const { + return !(*this == other); + } + + /// + /// \brief Enable hashing TypeId instances. + /// + friend struct std::hash; + + private: + /// + /// \brief Construct a TypeId and initialize storage. + /// + /// \param storage The storage of this TypeId. + /// + explicit TypeId(const Storage *storage) : storage_(storage) {} + + const Storage *storage_; +}; + +} // namespace ir + +namespace std { +/// +/// \brief Enable hashing TypeId instances. +/// +template <> +struct hash { + std::size_t operator()(const ir::TypeId &obj) const { + return std::hash()(obj.storage_); + } +}; +} // namespace std