未验证 提交 a5827f0e 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Type system stage2: add class Type, type uniquer utils, class IRContext (#50412)

* add TypeUniquer and IrContext

* refine include code

* add Type, TypeBase

* add built-in type

* add bulit-in Float32Type

* refine ut

* refine code

* refine code

* delete type_base

* rename ImplType to StorageType

* rename ImplType to StorageType

* add macros util for register type

* add macros util for register type

* refine name

* refine name

* change storage manager

* add multi_thread for ir_ctx

* rwlock_2_spinlock, add REGISTER_TYPE_2_IRCONTEXT

* DECLARE_TYPE_UTILITY_FUNCTOR

* refine ircontext singleton

* del destructor for ParametricStorageManager

* refine code

* Add necessary logs for debugging

* refine ir_context instance

* refine type get interface

* refine code by comment
上级 0d12afea
......@@ -2,4 +2,12 @@ if(NOT WITH_NEWIR)
return()
endif()
add_subdirectory(type)
set(NEWIR_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/ir")
set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/ir")
# ir tests
add_subdirectory(tests)
file(GLOB IR_SRCS "*.cc")
cc_library(new_ir SRCS ${IR_SRCS})
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/type.h"
namespace ir {
///
/// \brief Definitions of built-in type classes. The built-in type object get
/// method is as follows: Type fp32 = Float32Type::get(ctx);
///
class Float32Type : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Float32Type, ir::TypeStorage);
static Float32Type get(ir::IrContext *context);
};
class Int32Type : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Int32Type, ir::TypeStorage);
static Int32Type get(ir::IrContext *context);
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unordered_map>
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/spin_lock.h"
#include "paddle/ir/type_base.h"
namespace ir {
// The implementation class of the IrContext class
class IrContextImpl {
public:
IrContextImpl() {}
~IrContextImpl() {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
for (auto abstract_type_map : registed_abstract_types_) {
delete abstract_type_map.second;
}
registed_abstract_types_.clear();
}
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
VLOG(4) << "IrContext register an abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id)
<< ", AbstractType_ptr=" << abstract_type << "].";
registed_abstract_types_.emplace(type_id, abstract_type);
}
AbstractType *lookup(ir::TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
auto iter = registed_abstract_types_.find(type_id);
if (iter == registed_abstract_types_.end()) {
VLOG(4) << "IrContext not fonund cached abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
return nullptr;
} else {
VLOG(4) << "IrContext fonund a cached abstract_type of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id)
<< ", AbstractType_ptr=" << iter->second << "].";
return iter->second;
}
}
ir::SpinLock registed_abstract_types_lock_;
// Cached AbstractType instances.
std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
// TypeStorage uniquer and cache instances.
StorageManager registed_storage_manager_;
// Some built-in type.
Float32Type fp32_type;
Int32Type int32_type;
};
IrContext *IrContext::Instance() {
static IrContext context;
return &context;
}
IrContext::IrContext() : impl_(new IrContextImpl()) {
VLOG(4) << "IrContext register built-in type...";
REGISTER_TYPE_2_IRCONTEXT(Float32Type, this);
impl_->fp32_type = TypeManager::get<Float32Type>(this);
VLOG(4) << "Float32Type registration complete";
REGISTER_TYPE_2_IRCONTEXT(Int32Type, this);
impl_->int32_type = TypeManager::get<Int32Type>(this);
VLOG(4) << "Int32Type registration complete";
}
void IrContext::RegisterAbstractType(ir::TypeId type_id,
AbstractType *abstract_type) {
impl().RegisterAbstractType(type_id, abstract_type);
}
StorageManager &IrContext::storage_manager() {
return impl().registed_storage_manager_;
}
std::unordered_map<TypeId, AbstractType *>
&IrContext::registed_abstracted_type() {
return impl().registed_abstract_types_;
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
VLOG(4) << "Lookup abstract type [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "] from IrContext [ptr=" << ctx
<< "].";
auto &impl = ctx->impl();
AbstractType *abstract_type = impl.lookup(type_id);
if (abstract_type) {
return *abstract_type;
} else {
throw("Abstract type not found in IrContext.");
}
}
Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; }
Int32Type Int32Type::get(IrContext *ctx) { return ctx->impl().int32_type; }
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <memory>
#include <unordered_map>
namespace ir {
class IrContextImpl;
class StorageManager;
class AbstractType;
class TypeId;
///
/// \brief IrContext is a global parameterless class used to store and manage
/// Type and its related data structures.
///
class IrContext {
public:
///
/// \brief Initializes a new instance of IrContext.
///
static IrContext *Instance();
///
/// \brief Get an instance of IrContextImpl, a private member of IrContext.
/// For the specific definition of IrContextImpl, see ir_context.cc.
///
/// \return The instance of IrContextImpl.
///
IrContextImpl &impl() { return *impl_; }
///
/// \brief Register an AbstractType to IrContext
///
/// \param type_id The type id of the AbstractType.
/// \param abstract_type AbstractType* provided by user.
///
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type);
///
/// \brief Returns the storage uniquer used for constructing TypeStorage
/// instances.
///
/// \return The storage uniquer used for constructing TypeStorage
/// instances.
///
StorageManager &storage_manager();
///
/// \brief Returns the storage uniquer used for constructing TypeStorage
/// instances.
///
/// \return The storage uniquer used for constructing TypeStorage
/// instances.
///
std::unordered_map<TypeId, AbstractType *> &registed_abstracted_type();
IrContext(const IrContext &) = delete;
void operator=(const IrContext &) = delete;
private:
IrContext();
const std::unique_ptr<IrContextImpl> impl_;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || \
defined(__i386__)
#define __PADDLE_x86__
#include <immintrin.h>
#endif
#include <mutex>
#include <thread>
namespace ir {
static inline void CpuRelax() {
#if defined(__PADDLE_x86__)
_mm_pause();
#endif
}
class SpinLock {
public:
SpinLock() : mlock_(false) {}
void lock() {
for (;;) {
if (!mlock_.exchange(true, std::memory_order_acquire)) {
break;
}
constexpr int kMaxLoop = 32;
for (int loop = 1; mlock_.load(std::memory_order_relaxed);) {
if (loop <= kMaxLoop) {
for (int i = 1; i <= loop; ++i) {
CpuRelax();
}
loop *= 2;
} else {
std::this_thread::yield();
}
}
}
}
void unlock() { mlock_.store(false, std::memory_order_release); }
private:
SpinLock(const SpinLock&) = delete;
SpinLock(SpinLock&&) = delete;
SpinLock& operator=(const SpinLock&) = delete;
SpinLock& operator=(SpinLock&&) = delete;
std::atomic<bool> mlock_;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <unordered_map>
#include "paddle/ir/storage_manager.h"
namespace ir {
// This is a structure for creating, caching, and looking up Storage of
// parameteric types.
struct ParametricStorageManager {
using StorageBase = StorageManager::StorageBase;
ParametricStorageManager() {}
~ParametricStorageManager() {
for (const auto &instance : parametric_instances_) {
delete instance.second;
}
parametric_instances_.clear();
}
// Get the storage of parametric type, if not in the cache, create and
// insert the cache.
StorageBase *GetOrCreate(std::size_t hash_value,
std::function<bool(const StorageBase *)> equal_func,
std::function<StorageBase *()> constructor) {
if (parametric_instances_.count(hash_value) != 0) {
auto pr = parametric_instances_.equal_range(hash_value);
while (pr.first != pr.second) {
if (equal_func(pr.first->second)) {
VLOG(4) << "Found a cached parameteric storage of: [param_hash="
<< hash_value << ", storage_ptr=" << pr.first->second << "].";
return pr.first->second;
}
++pr.first;
}
}
StorageBase *storage = constructor();
parametric_instances_.emplace(hash_value, storage);
VLOG(4) << "No cache found, construct and cache a new parameteric storage "
"of: [param_hash="
<< hash_value << ", storage_ptr=" << storage << "].";
return storage;
}
private:
// In order to prevent hash conflicts, the unordered_multimap data structure
// is used for storage.
std::unordered_multimap<size_t, StorageBase *> parametric_instances_;
};
StorageManager::StorageManager() {}
StorageManager::~StorageManager() = default;
StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl(
TypeId type_id,
std::size_t hash_value,
std::function<bool(const StorageBase *)> equal_func,
std::function<StorageBase *()> constructor) {
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
VLOG(4) << "StorageManager get parameteretric storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
<< "].";
if (parametric_instance_.find(type_id) == parametric_instance_.end())
throw("The input data pointer is null.");
ParametricStorageManager &parametric_storage = *parametric_instance_[type_id];
return parametric_storage.GetOrCreate(hash_value, equal_func, constructor);
}
StorageManager::StorageBase *StorageManager::GetParameterlessStorageTypeImpl(
TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(parameterless_instances_lock_);
VLOG(4) << "StorageManager get parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instances_.find(type_id) == parameterless_instances_.end())
throw("TypeId not found in IrContext.");
StorageBase *parameterless_instance = parameterless_instances_[type_id];
return parameterless_instance;
}
void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
VLOG(4) << "StorageManager register parameteric storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
parametric_instance_.emplace(type_id,
std::make_unique<ParametricStorageManager>());
}
void StorageManager::RegisterParameterlessStorageTypeImpl(
TypeId type_id, std::function<StorageBase *()> constructor) {
std::lock_guard<ir::SpinLock> guard(parameterless_instances_lock_);
VLOG(4) << "StorageManager register parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instances_.find(type_id) != parameterless_instances_.end())
throw("storage class already registered");
parameterless_instances_.emplace(type_id, constructor());
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "paddle/ir/spin_lock.h"
#include "paddle/ir/type_id.h"
namespace ir {
///
/// \brief The implementation of the class StorageManager.
///
// struct StorageManagerImpl;
struct ParametricStorageManager;
///
/// \brief A utility class for getting or creating Storage class instances.
/// Storage class must be a derived class of StorageManager::StorageBase.
/// There are two types of Storage class:
/// One is a parameterless type, which can directly obtain an instance through
/// the get method; The other is a parameteric type, which needs to comply with
/// the following conditions: (1) Need to define a type alias called ParamKey,
/// it serves as the unique identifier for the Storage class; (2) Need to
/// provide a hash method on the ParamKey for storage and access; (3) Need to
/// provide method 'bool operator==(const ParamKey &) const', used to compare
/// Storage instance and ParamKey instance.
///
class StorageManager {
public:
///
/// \brief This class is the base class of all storage classes,
/// and any type of storage needs to inherit from this class.
///
class StorageBase {
protected:
StorageBase() = default;
};
StorageManager();
~StorageManager();
///
/// \brief Get a unique storage instance of parametric Type.
///
/// \param init_func Used to initialize a newly inserted storage instance.
/// \param type_id The type id of the AbstractType.
/// \param args Parameters of the wrapped function.
/// \return A uniqued instance of Storage.
///
template <typename Storage, typename... Args>
Storage *GetParametricStorageType(std::function<void(Storage *)> init_func,
TypeId type_id,
Args &&...args) {
typename Storage::ParamKey param =
typename Storage::ParamKey(std::forward<Args>(args)...);
std::size_t hash_value = Storage::HashValue(param);
auto equal_func = [&param](const StorageBase *existing) {
return static_cast<const Storage &>(*existing) == param;
};
auto constructor = [&]() {
auto *storage = Storage::Construct(param);
if (init_func) init_func(storage);
return storage;
};
return static_cast<Storage *>(GetParametricStorageTypeImpl(
type_id, hash_value, equal_func, constructor));
}
///
/// \brief Get a unique storage instance of parameterless Type.
///
/// \param type_id The type id of the AbstractType.
/// \return A uniqued instance of Storage.
///
template <typename Storage>
Storage *GetParameterlessStorageType(TypeId type_id) {
return static_cast<Storage *>(GetParameterlessStorageTypeImpl(type_id));
}
///
/// \brief Register a new parametric storage class.
///
/// \param type_id The type id of the AbstractType.
///
template <typename Storage>
void RegisterParametricStorageType(TypeId type_id) {
return RegisterParametricStorageTypeImpl(type_id);
}
///
/// \brief Register a new parameterless storage class.
///
/// \param type_id The type id of the AbstractType.
/// \param init_func Used to initialize a newly inserted storage instance.
///
template <typename Storage>
void RegisterParameterlessStorageType(
TypeId type_id, std::function<void(Storage *)> init_func) {
auto constructor = [&]() {
auto *storage = new Storage();
if (init_func) init_func(storage);
return storage;
};
RegisterParameterlessStorageTypeImpl(type_id, constructor);
}
private:
StorageBase *GetParametricStorageTypeImpl(
TypeId type_id,
std::size_t hash_value,
std::function<bool(const StorageBase *)> equal_func,
std::function<StorageBase *()> constructor);
StorageBase *GetParameterlessStorageTypeImpl(TypeId type_id);
void RegisterParametricStorageTypeImpl(TypeId type_id);
void RegisterParameterlessStorageTypeImpl(
TypeId type_id, std::function<StorageBase *()> constructor);
// This map is a mapping between type id and parameteric type storage.
std::unordered_map<TypeId, std::unique_ptr<ParametricStorageManager>>
parametric_instance_;
ir::SpinLock parametric_instance_lock_;
// This map is a mapping between type id and parameterless type storage.
std::unordered_map<TypeId, StorageBase *> parameterless_instances_;
ir::SpinLock parameterless_instances_lock_;
};
} // namespace ir
cc_test_old(type_test SRCS type_test.cc DEPS new_ir gtest)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <unordered_map>
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/type_base.h"
TEST(type_test, type_id) {
class TypeA {};
class TypeB {};
// (1) Test construct TypeId by TypeId::Get()
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::TypeId a_other_id = ir::TypeId::get<TypeA>();
ir::TypeId b_id = ir::TypeId::get<TypeB>();
EXPECT_EQ(a_id, a_other_id);
EXPECT_NE(a_id, b_id);
// (2) Test TypeId hash
std::unordered_map<ir::TypeId, ir::TypeId *> type_id_register;
type_id_register.emplace(a_id, &a_id);
type_id_register.emplace(b_id, &b_id);
for (auto kv : type_id_register) {
EXPECT_EQ(kv.first, *kv.second);
}
}
TEST(type_test, abstract_type) {
class TypeA {};
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id);
EXPECT_EQ(abstract_type_a.type_id(), a_id);
}
TEST(type_test, type_storage) {
class TypeA {};
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id);
ir::TypeStorage storage_a(&abstract_type_a);
EXPECT_EQ(storage_a.abstract_type().type_id(), abstract_type_a.type_id());
}
TEST(type_test, built_in_type) {
// Test creation of built-in parameterless type.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Type fp32_1 = ir::Float32Type::get(ctx);
// Test interfaces of class Type
ir::Type fp32_2 = ir::Float32Type::get(ctx);
EXPECT_EQ(fp32_1 == fp32_2, 1);
EXPECT_EQ(fp32_1 != fp32_2, 0);
EXPECT_EQ(fp32_1.type_id() == fp32_2.type_id(), 1);
EXPECT_EQ(&fp32_1.abstract_type() ==
&ir::AbstractType::lookup(fp32_1.type_id(), ctx),
1);
EXPECT_EQ(ir::Float32Type::classof(fp32_1), 1);
ir::Type int32_1 = ir::Int32Type::get(ctx);
ir::Type int32_2 = ir::Int32Type::get(ctx);
EXPECT_EQ(int32_1 == int32_2, 1);
EXPECT_EQ(int32_1.type_id() == int32_2.type_id(), 1);
EXPECT_EQ(&int32_1.abstract_type() ==
&ir::AbstractType::lookup(int32_1.type_id(), ctx),
1);
EXPECT_EQ(ir::Int32Type::classof(int32_1), 1);
}
struct IntegerTypeStorage : public ir::TypeStorage {
IntegerTypeStorage(unsigned width, unsigned signedness)
: width_(width), signedness_(signedness) {}
using ParamKey = std::pair<unsigned, unsigned>;
static std::size_t HashValue(const ParamKey &key) {
return hash_combine(std::hash<unsigned>()(std::get<0>(key)),
std::hash<unsigned>()(std::get<1>(key)));
}
bool operator==(const ParamKey &key) const {
return ParamKey(width_, signedness_) == key;
}
static IntegerTypeStorage *Construct(ParamKey key) {
return new IntegerTypeStorage(key.first, key.second);
}
ParamKey GetAsKey() const { return ParamKey(width_, signedness_); }
unsigned width_ : 30;
unsigned signedness_ : 2;
private:
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
};
class IntegerType : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(IntegerType, IntegerTypeStorage);
};
TEST(type_test, parameteric_type) {
ir::IrContext *ctx = ir::IrContext::Instance();
REGISTER_TYPE_2_IRCONTEXT(IntegerType, ctx);
ir::Type int1_1 = IntegerType::get(ctx, 1, 0);
ir::Type int1_2 = IntegerType::get(ctx, 1, 0);
EXPECT_EQ(int1_1 == int1_2, 1);
ir::Type int8 = IntegerType::get(ctx, 8, 0);
EXPECT_EQ(int8 == int1_2, 0);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/type_base.h"
namespace ir {
///
/// \brief Unified interface of the Type class. Derivation of all Type classes
/// only derives interfaces, not members. For example, DenseTensorType,
/// Float32Type, etc. are all derived classes of Type, but no new member
/// variables will be added.
///
class Type {
public:
using StorageType = TypeStorage;
constexpr Type() = default;
Type(const StorageType *storage) // NOLINT
: storage_(const_cast<StorageType *>(storage)) {}
Type(const Type &other) = default;
Type &operator=(const Type &other) = default;
///
/// \brief Comparison operations.
///
bool operator==(Type other) const { return storage_ == other.storage_; }
bool operator!=(Type other) const { return storage_ != other.storage_; }
explicit operator bool() const { return storage_; }
bool operator!() const { return storage_ == nullptr; }
TypeId type_id() { return storage_->abstract_type().type_id(); }
const AbstractType &abstract_type() { return storage_->abstract_type(); }
StorageType *storage() const { return storage_; }
///
/// \brief Enable hashing Type.
///
friend struct std::hash<Type>;
protected:
StorageType *storage_{nullptr};
};
} // namespace ir
namespace std {
///
/// \brief Enable hashing Type.
///
template <>
struct hash<ir::Type> {
std::size_t operator()(const ir::Type &obj) const {
return std::hash<ir::Type::StorageType *>()(obj.storage_);
}
};
} // namespace std
cc_test(
type_support_test
SRCS type_support_test.cc
DEPS gtest)
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/ir/type/type_support.h"
#include <gtest/gtest.h>
#include <unordered_map>
TEST(type_support, type_id) {
class TypeA {};
class TypeB {};
// (1) Test construct TypeId by TypeId::Get()
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::TypeId a_other_id = ir::TypeId::get<TypeA>();
ir::TypeId b_id = ir::TypeId::get<TypeB>();
EXPECT_EQ(a_id, a_other_id);
EXPECT_NE(a_id, b_id);
// (2) Test TypeId hash
std::unordered_map<ir::TypeId, ir::TypeId*> type_id_register;
type_id_register.emplace(a_id, &a_id);
type_id_register.emplace(b_id, &b_id);
for (auto kv : type_id_register) {
EXPECT_EQ(kv.first, *kv.second);
}
}
TEST(type_support, abstract_type) {
class TypeA {};
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id);
EXPECT_EQ(abstract_type_a.type_id(), a_id);
}
TEST(type_support, type_storage) {
class TypeA {};
ir::TypeId a_id = ir::TypeId::get<TypeA>();
ir::AbstractType abstract_type_a = ir::AbstractType::get(a_id);
ir::TypeStorage storage_a(&abstract_type_a);
EXPECT_EQ(storage_a.abstract_type().type_id(), abstract_type_a.type_id());
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/ir_context.h"
#include "paddle/ir/storage_manager.h"
#include "paddle/ir/type_id.h"
namespace ir {
///
/// \brief Abstract the properties and behaviors common to all Type classes into
/// an AbstractType class. There are two types in Type system:
/// on-parameter/parameterless type and parameter-type. The common attributes of
/// all types is TypeId (and possibly others). Therefore, construct a class with
/// TypeId as its member.
///
class AbstractType {
public:
///
/// \brief Construct an AbstractType by TypeId directly.
///
/// \param type_id The type id of the AbstractType.
///
static AbstractType get(TypeId type_id) { return AbstractType(type_id); }
///
/// \brief Returns the type id of the AbstractType.
///
/// \return The type id of the AbstractType.
///
TypeId type_id() const { return type_id_; }
///
/// \brief Find the AbstractType instance whose TypeId is type_id from
/// IrContext.
///
/// \param type_id The type id of the AbstractType.
/// \param ctx The IrContext.
/// \return The AbstractType instance whose TypeId is type_id.
///
static const AbstractType &lookup(TypeId type_id, IrContext *ctx);
private:
///
/// \brief The constructor is set to private and provides the user with the
/// get method to obtain and manage the AstractType.
///
/// \param type_id The type id of the AbstractType.
///
explicit AbstractType(TypeId type_id) : type_id_(type_id) {}
TypeId type_id_;
};
struct TypeManager;
///
/// \brief TypeStorage is used to store all information of a Type. A Type object
/// contains a TypeStorage. For non-parameter type, the information includes:
/// TypeId, so TypeStorage only needs to include AbstractType; For parameter
/// type, in addition to AbstractType/TypeId, parameter information needs to be
/// included. So that, non-parameter type can be constructed by TypeStorage
/// directly but parameter type should be constructed by Derived TypeStorage.
///
class TypeStorage : public StorageManager::StorageBase {
friend StorageManager;
friend TypeManager;
public:
///
/// \brief Construct a TypeStorage and initialize abstract_type.
///
/// \param abstract_type The abstract_type of this TypeStorage.
///
explicit TypeStorage(AbstractType *abstract_type)
: abstract_type_(abstract_type) {}
TypeStorage() {}
///
/// \brief Returns the AbstractType of the TypeStorage.
///
/// \return The AbstractType of the TypeStorage.
///
const AbstractType &abstract_type() { return *abstract_type_; }
private:
///
/// \brief Initialize TypeStorage based on the AbstractType* provided by the
/// user
///
/// \param abstract_type AbstractType* provided by the user, the
/// construction method of AbstractType refers to AbstractType::get.
///
void initialize(const AbstractType &abstract_type) {
abstract_type_ = const_cast<AbstractType *>(&abstract_type);
}
AbstractType *abstract_type_{nullptr}; // not owned
};
///
/// \brief TypeManager is a utility class that provides interfaces for get or
/// unique Type instances in IrContext.
///
struct TypeManager {
///
/// \brief Get a unique instance of Type T from IrContext. Note: For a
/// parameteric_type, if not found in IrContext, it will try to create a new
/// instance and register it to IrContext; for a parameterless type, only
/// search.
///
/// \param ctx The IrContext instance.
/// \param args Parameters of the wrapped function.
/// \return The unique instance of Type T from IrContext.
///
template <typename T, typename... Args>
static T get(IrContext *ctx, Args &&...args) {
return get<T, Args...>(
ctx, ir::TypeId::get<T>(), std::forward<Args>(args)...);
}
///
/// \brief Get a unique instance of parametric Type T from IrContext. If not
/// found in IrContext, it will try to create a new instance and register it
/// to IrContext;
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the AbstractType.
/// \param args Parameters of the wrapped function.
/// \return The unique instance of Type T from IrContext.
///
template <typename T, typename... Args>
static std::
enable_if_t<!std::is_same<typename T::StorageType, TypeStorage>::value, T>
get(IrContext *ctx, TypeId type_id, Args &&...args) {
return ctx->storage_manager()
.GetParametricStorageType<typename T::StorageType>(
[&, type_id](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(type_id, ctx));
},
type_id,
std::forward<Args>(args)...);
}
///
/// \brief Get a unique instance of parameterless Type T from IrContext, only
/// search.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the AbstractType.
/// \return The unique instance of Type T from IrContext.
///
template <typename T>
static std::
enable_if_t<std::is_same<typename T::StorageType, TypeStorage>::value, T>
get(IrContext *ctx, TypeId type_id) {
return ctx->storage_manager()
.GetParameterlessStorageType<typename T::StorageType>(type_id);
}
///
/// \brief Register a unique instance of Type T to IrContext.
///
/// \param ctx The IrContext instance.
///
template <typename T>
static void RegisterType(IrContext *ctx) {
RegisterType<T>(ctx,
ir::TypeId::get<T>()); // class Type需要提供type_id接口
}
///
/// \brief Register a unique instance of parametric Type T to IrContext.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the Type T.
///
template <typename T>
static std::enable_if_t<
!std::is_same<typename T::StorageType, TypeStorage>::value>
RegisterType(IrContext *ctx, TypeId type_id) {
ctx->storage_manager()
.RegisterParametricStorageType<typename T::StorageType>(type_id);
}
///
/// \brief Register a unique instance of parameterless Type T to IrContext.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the Type T.
///
template <typename T>
static std::enable_if_t<
std::is_same<typename T::StorageType, TypeStorage>::value>
RegisterType(IrContext *ctx, TypeId type_id) {
ctx->storage_manager().RegisterParameterlessStorageType<TypeStorage>(
type_id, [&ctx, type_id](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(type_id, ctx));
});
}
};
///
/// \brief This macro definition is used to add some necessary functions to the
/// custom Type class.
///
#define DECLARE_TYPE_UTILITY_FUNCTOR(concrete_type, storage_type) \
using StorageType = storage_type; \
\
StorageType *storage() const { \
return static_cast<StorageType *>(this->storage_); \
} \
\
static ir::TypeId type_id() { return ir::TypeId::get<concrete_type>(); } \
\
template <typename T> \
static bool classof(T val) { \
return val.type_id() == type_id(); \
} \
\
template <typename... Args> \
static concrete_type get(ir::IrContext *ctx, Args... args) { \
return ir::TypeManager::template get<concrete_type>(ctx, args...); \
}
///
/// \brief This macro definition is used to register custom Type class.
///
#define REGISTER_TYPE_2_IRCONTEXT(concrete_type, ir_context) \
ir::AbstractType *abstract_type_##concrete_type = new ir::AbstractType( \
std::move(ir::AbstractType::get(ir::TypeId::get<concrete_type>()))); \
\
ir_context->RegisterAbstractType(ir::TypeId::get<concrete_type>(), \
abstract_type_##concrete_type); \
\
ir::TypeManager::RegisterType<concrete_type>(ir_context);
} // namespace ir
......@@ -14,105 +14,69 @@
#pragma once
#include <glog/logging.h>
#include <functional>
namespace ir {
///
/// \brief TypeId is the unique identification of Type, each Type corresponds to
/// a unique TypeId, the same id indicates the same Type class. TypeId provides
/// an instantiation interface: TypeId::get.
///
/// Example:
/// \code{cpp}
/// class TypeA {};
/// TypeId type_a_id = TypeId::get<TypeA>();
/// \endcode
///
class TypeId {
struct Storage {};
public:
///
/// \brief Returns the unique TypeId of Type T.
///
/// \return The unique TypeId of Type T.
///
template <typename T>
static TypeId get() {
static Storage instance;
return TypeId(&instance);
}
///
/// \brief Comparison operations.
///
inline bool operator==(const TypeId &other) const {
return storage_ == other.storage_;
}
/// \brief Comparison operations.
inline bool operator!=(const TypeId &other) const {
return !(*this == other);
}
///
/// \brief Enable hashing TypeId instances.
///
friend struct std::hash<TypeId>;
private:
///
/// \brief Construct a TypeId and initialize storage.
///
/// \param storage The storage of this TypeId.
///
explicit TypeId(const Storage *storage) : storage_(storage) {}
const Storage *storage_;
};
/// \brief Abstract the properties and behaviors common to all Type classes into
/// an AbstractType class. There are two types in Type system:
/// on-parameter/singleton type and parameter-type. The common attributes of all
/// types is TypeId (and possibly others). Therefore, construct a class with
/// TypeId as its member.
class AbstractType {
public:
/// \brief Construct an AbstractType by TypeId directly.
/// \param type_id The type id of the AbstractType.
static AbstractType get(TypeId type_id) { return AbstractType(type_id); }
/// \brief Returns the type id of the AbstractType.
/// \return The type id of the AbstractType.
TypeId type_id() const { return type_id_; }
/* TODO(zhangbo9674): After the IRContext is designed, AbstractType will be
* cached to IRContext with TypeId as key.
*/
private:
/// \brief The constructor is set to private and provides the user with the
/// get method to obtain and manage the AstractType.
/// \param type_id The type id of the AbstractType.
explicit AbstractType(TypeId type_id) : type_id_(type_id) {}
TypeId type_id_;
};
/// \brief TypeStorage is used to store all information of a Type. A Type object
/// contains a TypeStorage. For non-parameter type, the information includes:
/// TypeId, so TypeStorage only needs to include AbstractType; For parameter
/// type, in addition to AbstractType/TypeId, parameter information needs to be
/// included. So that, non-parameter type can be constructed by TypeStorage
/// directly but parameter type should be constructed by Derived TypeStorage.
class TypeStorage {
public:
/// \brief Construct a TypeStorage and initialize abstract_type.
/// \param abstract_type The abstract_type of this TypeStorage.
explicit TypeStorage(AbstractType *abstract_type)
: abstract_type_(abstract_type) {}
/// \brief Returns the AbstractType of the TypeStorage.
/// \return The AbstractType of the TypeStorage.
const AbstractType &abstract_type() { return *abstract_type_; }
private:
AbstractType *abstract_type_{nullptr};
};
} // namespace ir
// Custom specialization of std::hash can be injected in namespace std.
namespace std {
///
/// \brief Enable hashing TypeId instances.
///
template <>
struct hash<ir::TypeId> {
std::size_t operator()(const ir::TypeId &obj) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册