未验证 提交 586d9018 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Attribute system (#51636)

* add Attribute system to new ir

* set StorageType to Storage in Type and Attribute

* refine strAttr

* refine name of StrAttribute

* add DictionaryAttribute

* refine code

* refine dic_attr

* refine code

* Set DictionaryAttribute ParamKey is map

* refine code

* refine code by comment

* refine code

* refine code

* refine code

* refine code

* fix complie bug

* refine code

* add const for Attribute storage
上级 52a31b87
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/attribute.h"
#include "paddle/ir/dialect.h"
namespace ir {
IrContext *Attribute::ir_context() const { return dialect().ir_context(); }
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/attribute_base.h"
#include "paddle/ir/cast_utils.h"
namespace ir {
///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members.
///
class Attribute {
public:
using Storage = AttributeStorage;
constexpr Attribute() = default;
Attribute(const Storage *storage) // NOLINT
: storage_(storage) {}
Attribute(const Attribute &other) = default;
Attribute &operator=(const Attribute &other) = default;
bool operator==(Attribute other) const { return storage_ == other.storage_; }
bool operator!=(Attribute other) const { return storage_ != other.storage_; }
explicit operator bool() const { return storage_; }
bool operator!() const { return storage_ == nullptr; }
///
/// \brief Some Attribute attribute acquisition interfaces.
///
TypeId type_id() { return storage_->abstract_attribute().type_id(); }
const AbstractAttribute &abstract_attribute() {
return storage_->abstract_attribute();
}
const Storage *storage() const { return storage_; }
const Dialect &dialect() const {
return storage_->abstract_attribute().dialect();
}
IrContext *ir_context() const;
///
/// \brief Methods for type judgment and cast.
///
static bool classof(Attribute) { return true; }
template <typename T>
bool isa() const {
return ir::isa<T>(*this);
}
template <typename U>
U dyn_cast() const {
return ir::dyn_cast<U>(*this);
}
friend struct std::hash<Attribute>;
protected:
const Storage *storage_{nullptr};
};
} // namespace ir
namespace std {
template <>
struct hash<ir::Attribute> {
std::size_t operator()(const ir::Attribute &obj) const {
return std::hash<const ir::Attribute::Storage *>()(obj.storage_);
}
};
} // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/ir_context.h"
#include "paddle/ir/storage_manager.h"
#include "paddle/ir/type_id.h"
namespace ir {
class Dialect;
///
/// \brief Abstract the properties and behaviors common to all Attribute classes
/// into an AbstractAttribute class.
///
class AbstractAttribute {
public:
///
/// \brief Construct an AbstractAttribute by TypeId directly.
///
/// \param type_id The id of the AbstractAttribute.
/// \param dialect The Dialect which the attribute registered to.
///
static AbstractAttribute get(TypeId type_id, const Dialect &dialect) {
return AbstractAttribute(type_id, dialect);
}
///
/// \brief Construct an AbstractAttribute by TypeId directly.
///
/// \param dialect The Dialect which the attribute registered to.
///
template <typename T>
static AbstractAttribute get(const Dialect &dialect) {
return AbstractAttribute(TypeId::get<T>(), dialect);
}
///
/// \brief Returns the type id of the AbstractAttribute.
///
/// \return The id of the AbstractAttribute.
///
TypeId type_id() const { return type_id_; }
///
/// \brief Get the dialect this attribute was registered to.
///
/// \return The dialect this attribute was registered to.
///
const Dialect &dialect() const { return dialect_; }
///
/// \brief Find the AbstractAttribute instance whose TypeId is type_id from
/// IrContext.
///
/// \param type_id The type id of the AbstractAttribute.
/// \param ctx The IrContext.
/// \return The AbstractAttribute instance whose TypeId is type_id.
///
static const AbstractAttribute &lookup(TypeId type_id, IrContext *ctx);
private:
///
/// \brief The constructor is set to private and provides the user with the
/// get method to obtain and manage the AbstractAttribute.
///
/// \param type_id The type id of the AbstractAttribute.
/// \param dialect The Dialect which the attribute registered to.
///
explicit AbstractAttribute(TypeId type_id, const Dialect &dialect)
: type_id_(type_id), dialect_(dialect) {}
TypeId type_id_;
const Dialect &dialect_;
};
struct AttributeManager;
///
/// \brief AttributeStorage is used to store all information of a Attribute. A
/// Attribute object contains a AttributeStorage. For non-parameter attribute,
/// the information includes: TypeId, so AttributeStorage only needs to include
/// AbstractAttribute; For parameteric attribute, in addition to
/// AbstractAttribute/TypeId, parameteric information needs to be included. So
/// that, non-parameteric attribute can be constructed by AttributeStorage
/// directly but parameteric attribute should be constructed by Derived
/// AttributeStorage.
///
class AttributeStorage : public StorageManager::StorageBase {
friend StorageManager;
friend AttributeManager;
public:
///
/// \brief Construct a AttributeStorage and initialize abstract_attribute.
///
/// \param abstract_attribute The abstract_attribute of this AttributeStorage.
///
explicit AttributeStorage(AbstractAttribute *abstract_attribute)
: abstract_attribute_(abstract_attribute) {}
AttributeStorage() {}
///
/// \brief Returns the AbstractAttribute of the AttributeStorage.
///
/// \return The AbstractAttribute of the AttributeStorage.
///
const AbstractAttribute &abstract_attribute() const {
return *abstract_attribute_;
}
private:
///
/// \brief Initialize AttributeStorage based on the AbstractAttribute*
/// provided by the user
///
/// \param abstract_attribute AbstractAttribute* provided by the user, the
/// construction method of AbstractAttribute refers to AbstractAttribute::get.
///
void initialize(const AbstractAttribute &abstract_attribute) {
abstract_attribute_ = const_cast<AbstractAttribute *>(&abstract_attribute);
}
AbstractAttribute *abstract_attribute_{nullptr}; // not owned
};
///
/// \brief AttributeManager is a utility class that provides interfaces for get
/// or unique Attribute instances in IrContext.
///
struct AttributeManager {
///
/// \brief Get a unique instance of Attribute T from IrContext. Note: For a
/// parameteric attribute, if not found in IrContext, it will try to create a
/// new instance and register it to IrContext; for a parameterless attribute,
/// only search.
///
/// \param ctx The IrContext instance.
/// \param args Parameters of the wrapped function.
/// \return The unique instance of Attribute T from IrContext.
///
template <typename T, typename... Args>
static T get(IrContext *ctx, Args &&...args) {
return get<T, Args...>(
ctx, ir::TypeId::get<T>(), std::forward<Args>(args)...);
}
///
/// \brief Get a unique instance of parametric Attribute T from IrContext. If
/// not found in IrContext, it will try to create a new instance and register
/// it to IrContext;
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the AbstractAttribute.
/// \param args Parameters of the wrapped function.
/// \return The unique instance of Attribute T from IrContext.
///
template <typename T, typename... Args>
static std::enable_if_t<
!std::is_same<typename T::Storage, AttributeStorage>::value,
T>
get(IrContext *ctx, TypeId type_id, Args &&...args) {
return ctx->attribute_storage_manager()
.GetParametricStorage<typename T::Storage>(
[&, type_id](AttributeStorage *storage) {
storage->initialize(AbstractAttribute::lookup(type_id, ctx));
},
type_id,
std::forward<Args>(args)...);
}
///
/// \brief Get a unique instance of parameterless Attribute T from IrContext.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the AbstractAttribute.
/// \return The unique instance of Attribute T from IrContext.
///
template <typename T>
static std::
enable_if_t<std::is_same<typename T::Storage, AttributeStorage>::value, T>
get(IrContext *ctx, TypeId type_id) {
return ctx->attribute_storage_manager()
.GetParameterlessStorage<typename T::Storage>(type_id);
}
///
/// \brief Register a unique instance of Attribute T to IrContext.
///
/// \param ctx The IrContext instance.
///
template <typename T>
static void RegisterAttribute(IrContext *ctx) {
RegisterAttribute<T>(ctx, ir::TypeId::get<T>());
}
///
/// \brief Register a unique parametric Attribute T to IrContext.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the Attribute T.
///
template <typename T>
static std::enable_if_t<
!std::is_same<typename T::Storage, AttributeStorage>::value>
RegisterAttribute(IrContext *ctx, TypeId type_id) {
ctx->attribute_storage_manager()
.RegisterParametricStorage<typename T::Storage>(type_id);
}
///
/// \brief Register a unique parameterless Attribute T to IrContext.
///
/// \param ctx The IrContext instance.
/// \param type_id The type id of the Attribute T.
///
template <typename T>
static std::enable_if_t<
std::is_same<typename T::Storage, AttributeStorage>::value>
RegisterAttribute(IrContext *ctx, TypeId type_id) {
ctx->attribute_storage_manager()
.RegisterParameterlessStorage<AttributeStorage>(
type_id, [&ctx, type_id](AttributeStorage *storage) {
storage->initialize(AbstractAttribute::lookup(type_id, ctx));
});
}
};
///
/// \brief Add some necessary functions to the custom Attribute class.
///
#define DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(concrete_attribute, storage_type) \
using Storage = storage_type; \
\
const Storage *storage() const { \
return static_cast<const Storage *>(this->storage_); \
} \
\
static ir::TypeId type_id() { \
return ir::TypeId::get<concrete_attribute>(); \
} \
\
template <typename T> \
static bool classof(T val) { \
return val.type_id() == type_id(); \
} \
\
template <typename... Args> \
static concrete_attribute get(ir::IrContext *ctx, Args... args) { \
return ir::AttributeManager::template get<concrete_attribute>(ctx, \
args...); \
}
///
/// \brief This macro definition is used to register custom Attribute class.
///
#define REGISTER_ATTRIBUTE_2_IRCONTEXT(concrete_attribute, dialect) \
ir::AbstractAttribute *abstract_attribute_##concrete_attribute = \
new ir::AbstractAttribute(std::move( \
ir::AbstractAttribute::get<concrete_attribute>(*dialect))); \
\
dialect->ir_context()->RegisterAbstractAttribute( \
ir::TypeId::get<concrete_attribute>(), \
abstract_attribute_##concrete_attribute); \
\
ir::AttributeManager::RegisterAttribute<concrete_attribute>( \
dialect->ir_context());
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/builtin_attribute.h"
namespace ir {
std::string StrAttribute::data() const { return storage()->GetAsKey(); }
uint32_t StrAttribute::size() const { return storage()->GetAsKey().size(); }
NamedAttribute::NamedAttribute(StrAttribute name, Attribute value)
: name_(name), value_(value) {}
bool NamedAttribute::operator<(const NamedAttribute &right) const {
return name() < right.name();
}
bool NamedAttribute::operator==(const NamedAttribute &right) const {
return name() == right.name() && value() == right.value();
}
bool NamedAttribute::operator!=(const NamedAttribute &right) const {
return !(*this == right);
}
Attribute DictionaryAttribute::GetValue(const StrAttribute &name) {
return storage()->GetValue(name);
}
uint32_t DictionaryAttribute::size() const { return storage()->size(); }
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/attribute.h"
#include "paddle/ir/builtin_attribute_storage.h"
namespace ir {
///
/// \brief All built-in attributes.
///
#define GET_BUILT_IN_ATTRIBUTE_LIST ir::StrAttribute, ir::DictionaryAttribute
class StrAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);
bool operator<(const StrAttribute &right) const {
return storage() < right.storage();
}
std::string data() const;
uint32_t size() const;
};
class NamedAttribute {
public:
NamedAttribute(StrAttribute name, Attribute value);
StrAttribute name() const { return name_; }
Attribute value() const { return value_; }
void SetName(StrAttribute name) { name_ = name; }
void SetValue(Attribute value) { value_ = value; }
bool operator<(const NamedAttribute &right) const;
bool operator==(const NamedAttribute &right) const;
bool operator!=(const NamedAttribute &right) const;
friend struct std::hash<NamedAttribute>;
operator std::pair<const StrAttribute, Attribute>() const {
return std::make_pair(name_, value_);
}
private:
StrAttribute name_;
Attribute value_;
};
class DictionaryAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(DictionaryAttribute,
DictionaryAttributeStorage);
Attribute GetValue(const StrAttribute &name);
uint32_t size() const;
};
} // namespace ir
namespace std {
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
template <>
struct hash<ir::NamedAttribute> {
std::size_t operator()(const ir::NamedAttribute &obj) const {
return hash_combine(std::hash<ir::Attribute>()(obj.name_),
std::hash<ir::Attribute>()(obj.value_));
}
};
} // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/builtin_attribute_storage.h"
#include "paddle/ir/builtin_attribute.h"
namespace ir {
DictionaryAttributeStorage::DictionaryAttributeStorage(const ParamKey &key) {
size_ = key.size();
data_ = reinterpret_cast<NamedAttribute *>(
malloc(sizeof(NamedAttribute) * size_));
uint32_t idx = 0;
for (auto iter = key.begin(); iter != key.end(); ++iter) {
data_[idx].SetName(iter->first);
data_[idx].SetValue(iter->second);
idx++;
}
}
std::size_t DictionaryAttributeStorage::HashValue(const ParamKey &key) {
std::size_t hash_value = key.size();
for (auto iter = key.begin(); iter != key.end(); ++iter) {
hash_value = hash_combine(
hash_value,
std::hash<NamedAttribute>()(NamedAttribute(iter->first, iter->second)));
}
return hash_value;
}
bool DictionaryAttributeStorage::operator==(const ParamKey &key) const {
uint32_t size = key.size();
if (size_ != size) return false;
uint32_t idx = 0;
for (auto iter = key.begin(); iter != key.end(); ++iter) {
if (data_[idx] != NamedAttribute(iter->first, iter->second)) {
return false;
}
idx++;
}
return true;
}
DictionaryAttributeStorage::ParamKey DictionaryAttributeStorage::GetAsKey()
const {
return ParamKey(data_, data_ + size_);
}
Attribute DictionaryAttributeStorage::GetValue(const StrAttribute &name) const {
if (size_ > 0) {
size_t left = 0;
size_t right = size_ - 1;
size_t mid = 0;
while (left <= right) {
mid = (left + right) / 2;
if (data_[mid].name() == name) {
return data_[mid].value();
} else if (data_[mid].name() < name) {
left = mid + 1;
} else {
right = mid - 1;
}
}
}
return nullptr;
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <map>
#include <type_traits>
#include "paddle/ir/attribute.h"
namespace ir {
///
/// \brief Define Parameteric AttributeStorage for StrAttribute.
///
struct StrAttributeStorage : public ir::AttributeStorage {
using ParamKey = std::string;
explicit StrAttributeStorage(const ParamKey &key) {
data_ = reinterpret_cast<char *>(malloc(key.size()));
memcpy(data_, const_cast<char *>(key.c_str()), key.size());
size_ = key.size();
}
~StrAttributeStorage() { free(data_); }
static StrAttributeStorage *Construct(ParamKey key) {
return new StrAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
return std::hash<std::string>()(key);
}
bool operator==(const ParamKey &key) const {
return std::equal(data_, data_ + size_, const_cast<char *>(key.c_str()));
}
ParamKey GetAsKey() const { return ParamKey(data_, size_); }
private:
char *data_;
uint32_t size_;
};
///
/// \brief Define Parameteric AttributeStorage for DictionaryAttributeStorage.
///
class StrAttribute;
class NamedAttribute;
struct DictionaryAttributeStorage : public AttributeStorage {
using ParamKey = std::map<StrAttribute, Attribute>;
explicit DictionaryAttributeStorage(const ParamKey &key);
~DictionaryAttributeStorage() { free(data_); }
static DictionaryAttributeStorage *Construct(ParamKey key) {
return new DictionaryAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key);
bool operator==(const ParamKey &key) const;
ParamKey GetAsKey() const;
Attribute GetValue(const StrAttribute &name) const;
NamedAttribute *data() const { return data_; }
uint32_t size() const { return size_; }
private:
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
NamedAttribute *data_;
uint32_t size_;
};
} // namespace ir
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/builtin_type.h"
namespace ir {
......@@ -24,6 +25,7 @@ BuiltinDialect::BuiltinDialect(ir::IrContext *context)
void BuiltinDialect::initialize() {
// Register all built-in types defined in builtin_type.h.
RegisterTypes<GET_BUILT_IN_TYPE_LIST>();
RegisterAttributes<GET_BUILT_IN_ATTRIBUTE_LIST>();
}
} // namespace ir
......@@ -25,4 +25,10 @@ void Dialect::RegisterType(ir::AbstractType &&abstract_type) {
new_abstract_type);
}
void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) {
ir::AbstractAttribute *new_abstract_attribute =
new ir::AbstractAttribute(std::move(abstract_attribute));
this->ir_context()->RegisterAbstractAttribute(
new_abstract_attribute->type_id(), new_abstract_attribute);
}
} // namespace ir
......@@ -14,16 +14,17 @@
#pragma once
#include "paddle/ir/attribute_base.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/type_base.h"
namespace ir {
///
/// \brief Dialect can basically be understood as a namespace. In Dialect, we
/// can define a series of types, operations, etc. An instance of the dialect
/// object will be loaded into the global IrContext. Specific compilers only
/// need to combine existing dialects and add their own extensions or
/// customizations.
/// can define a series of types, attributes, operations, etc. An instance of
/// the dialect object will be loaded into the global IrContext. Specific
/// compilers only need to combine existing dialects and add their own
/// extensions or customizations.
///
class Dialect {
public:
......@@ -67,6 +68,35 @@ class Dialect {
///
void RegisterType(ir::AbstractType &&abstract_type);
///
/// \brief Register all attributes contained in the template parameter Args.
/// To register only one Attribute, you can use the RegisterAttribute template
/// function.
///
template <typename... Args>
void RegisterAttributes() {
(void)std::initializer_list<int>{0, (RegisterAttribute<Args>(), 0)...};
}
///
/// \brief Register attribute of class T.
///
template <typename T>
void RegisterAttribute() {
VLOG(4) << "Attribute registered into Dialect. --->";
ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute(
std::move(ir::AbstractAttribute::get<T>(*this)));
this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get<T>(),
abstract_attribute);
ir::AttributeManager::RegisterAttribute<T>(this->ir_context());
VLOG(4) << "----------------------------------";
}
///
/// \brief Register abstract_attribute into context.
///
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute);
private:
std::string name_;
......
......@@ -16,6 +16,7 @@
#include <unordered_map>
#include "paddle/ir/attribute_base.h"
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
......@@ -24,7 +25,7 @@
namespace ir {
// The implementation class of the IrContext class, cache registered
// AbstractType, TypeStorage, Dialect.
// AbstractType, TypeStorage, AbstractAttribute, AttributeStorage, Dialect.
class IrContextImpl {
public:
IrContextImpl() {}
......@@ -36,6 +37,11 @@ class IrContextImpl {
}
registed_abstract_types_.clear();
for (auto &abstract_attribute_map : registed_abstract_attributes_) {
delete abstract_attribute_map.second;
}
registed_abstract_attributes_.clear();
for (auto &dialect_map : registed_dialect_) {
delete dialect_map.second;
}
......@@ -64,6 +70,29 @@ class IrContextImpl {
return nullptr;
}
void RegisterAbstractAttribute(ir::TypeId type_id,
AbstractAttribute *abstract_attribute) {
std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
VLOG(4) << "Register an abstract_attribute of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id)
<< ", AbstractAttribute_ptr=" << abstract_attribute << "].";
registed_abstract_attributes_.emplace(type_id, abstract_attribute);
}
AbstractAttribute *GetAbstractAttribute(ir::TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
auto iter = registed_abstract_attributes_.find(type_id);
if (iter != registed_abstract_attributes_.end()) {
VLOG(4) << "Fonund a cached abstract_attribute of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id)
<< ", AbstractAttribute_ptr=" << iter->second << "].";
return iter->second;
}
LOG(WARNING) << "No cache found abstract_attribute of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
return nullptr;
}
void RegisterDialect(std::string name, Dialect *dialect) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
VLOG(4) << "Register a dialect of: [name=" << name
......@@ -86,14 +115,8 @@ class IrContextImpl {
// Cached AbstractType instances.
std::unordered_map<TypeId, AbstractType *> registed_abstract_types_;
ir::SpinLock registed_abstract_types_lock_;
// TypeStorage uniquer and cache instances.
StorageManager registed_storage_manager_;
// The dialcet registered in the context.
std::unordered_map<std::string, Dialect *> registed_dialect_;
ir::SpinLock registed_dialect_lock_;
StorageManager registed_type_storage_manager_;
// Cache some built-in type objects.
Float16Type fp16_type;
Float32Type fp32_type;
......@@ -102,6 +125,16 @@ class IrContextImpl {
Int32Type int32_type;
Int64Type int64_type;
// Cached AbstractAttribute instances.
std::unordered_map<TypeId, AbstractAttribute *> registed_abstract_attributes_;
ir::SpinLock registed_abstract_attributes_lock_;
// AttributeStorage uniquer and cache instances.
StorageManager registed_attribute_storage_manager_;
// The dialcet registered in the context.
std::unordered_map<std::string, Dialect *> registed_dialect_;
ir::SpinLock registed_dialect_lock_;
ir::SpinLock destructor_lock_;
};
......@@ -128,8 +161,8 @@ void IrContext::RegisterAbstractType(ir::TypeId type_id,
impl().RegisterAbstractType(type_id, abstract_type);
}
StorageManager &IrContext::storage_manager() {
return impl().registed_storage_manager_;
StorageManager &IrContext::type_storage_manager() {
return impl().registed_type_storage_manager_;
}
std::unordered_map<TypeId, AbstractType *>
......@@ -137,6 +170,20 @@ std::unordered_map<TypeId, AbstractType *>
return impl().registed_abstract_types_;
}
void IrContext::RegisterAbstractAttribute(
ir::TypeId type_id, AbstractAttribute *abstract_attribute) {
impl().RegisterAbstractAttribute(type_id, abstract_attribute);
}
StorageManager &IrContext::attribute_storage_manager() {
return impl().registed_attribute_storage_manager_;
}
std::unordered_map<TypeId, AbstractAttribute *>
&IrContext::registed_abstracted_attribute() {
return impl().registed_abstract_attributes_;
}
Dialect *IrContext::GetOrRegisterDialect(
std::string dialect_name, std::function<Dialect *()> constructor) {
VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
......@@ -179,6 +226,17 @@ const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
}
}
const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id,
IrContext *ctx) {
auto &impl = ctx->impl();
AbstractAttribute *abstract_attribute = impl.GetAbstractAttribute(type_id);
if (abstract_attribute) {
return *abstract_attribute;
} else {
throw("Abstract attribute not found in IrContext.");
}
}
Float16Type Float16Type::get(IrContext *ctx) { return ctx->impl().fp16_type; }
Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; }
......
......@@ -23,12 +23,13 @@ namespace ir {
class IrContextImpl;
class StorageManager;
class AbstractType;
class AbstractAttribute;
class TypeId;
class Dialect;
///
/// \brief IrContext is a global parameterless class used to store and manage
/// Type and its related data structures.
/// Type, Attribute and other related data structures.
///
class IrContext {
public:
......@@ -60,7 +61,7 @@ class IrContext {
/// \return The storage uniquer used for constructing TypeStorage
/// instances.
///
StorageManager &storage_manager();
StorageManager &type_storage_manager();
///
/// \brief Returns the storage uniquer used for constructing TypeStorage
......@@ -71,6 +72,34 @@ class IrContext {
///
std::unordered_map<TypeId, AbstractType *> &registed_abstracted_type();
///
/// \brief Register an AbstractAttribute to IrContext
///
/// \param type_id The type id of the AbstractAttribute.
/// \param abstract_attribute AbstractAttribute* provided by user.
///
void RegisterAbstractAttribute(ir::TypeId type_id,
AbstractAttribute *abstract_attribute);
///
/// \brief Returns the storage uniquer used for constructing AttributeStorage
/// instances.
///
/// \return The storage uniquer used for constructing AttributeStorage
/// instances.
///
StorageManager &attribute_storage_manager();
///
/// \brief Returns the storage uniquer used for constructing AttributeStorage
/// instances.
///
/// \return The storage uniquer used for constructing AttributeStorage
/// instances.
///
std::unordered_map<TypeId, AbstractAttribute *>
&registed_abstracted_attribute();
///
/// \brief Get the dialect of the DialectT class in the context, ff not found,
/// create and register to context.
......
......@@ -66,7 +66,7 @@ StorageManager::StorageManager() {}
StorageManager::~StorageManager() = default;
StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl(
StorageManager::StorageBase *StorageManager::GetParametricStorageImpl(
TypeId type_id,
std::size_t hash_value,
std::function<bool(const StorageBase *)> equal_func,
......@@ -81,7 +81,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageTypeImpl(
return parametric_storage.GetOrCreate(hash_value, equal_func, constructor);
}
StorageManager::StorageBase *StorageManager::GetParameterlessStorageTypeImpl(
StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(parameterless_instance_lock_);
VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
......@@ -92,7 +92,7 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageTypeImpl(
return parameterless_instance;
}
void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) {
void StorageManager::RegisterParametricStorageImpl(TypeId type_id) {
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
VLOG(4) << "Register a parameteric storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
......@@ -100,7 +100,7 @@ void StorageManager::RegisterParametricStorageTypeImpl(TypeId type_id) {
std::make_unique<ParametricStorageManager>());
}
void StorageManager::RegisterParameterlessStorageTypeImpl(
void StorageManager::RegisterParameterlessStorageImpl(
TypeId type_id, std::function<StorageBase *()> constructor) {
std::lock_guard<ir::SpinLock> guard(parameterless_instance_lock_);
VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
......
......@@ -64,7 +64,7 @@ class StorageManager {
/// \return A uniqued instance of Storage.
///
template <typename Storage, typename... Args>
Storage *GetParametricStorageType(std::function<void(Storage *)> init_func,
Storage *GetParametricStorage(std::function<void(Storage *)> init_func,
TypeId type_id,
Args &&...args) {
typename Storage::ParamKey param =
......@@ -78,8 +78,8 @@ class StorageManager {
if (init_func) init_func(storage);
return storage;
};
return static_cast<Storage *>(GetParametricStorageTypeImpl(
type_id, hash_value, equal_func, constructor));
return static_cast<Storage *>(
GetParametricStorageImpl(type_id, hash_value, equal_func, constructor));
}
///
......@@ -89,8 +89,8 @@ class StorageManager {
/// \return A uniqued instance of Storage.
///
template <typename Storage>
Storage *GetParameterlessStorageType(TypeId type_id) {
return static_cast<Storage *>(GetParameterlessStorageTypeImpl(type_id));
Storage *GetParameterlessStorage(TypeId type_id) {
return static_cast<Storage *>(GetParameterlessStorageImpl(type_id));
}
///
......@@ -99,8 +99,8 @@ class StorageManager {
/// \param type_id The type id of the AbstractType.
///
template <typename Storage>
void RegisterParametricStorageType(TypeId type_id) {
return RegisterParametricStorageTypeImpl(type_id);
void RegisterParametricStorage(TypeId type_id) {
return RegisterParametricStorageImpl(type_id);
}
///
......@@ -110,28 +110,28 @@ class StorageManager {
/// \param init_func Used to initialize a newly inserted storage instance.
///
template <typename Storage>
void RegisterParameterlessStorageType(
TypeId type_id, std::function<void(Storage *)> init_func) {
void RegisterParameterlessStorage(TypeId type_id,
std::function<void(Storage *)> init_func) {
auto constructor = [&]() {
auto *storage = new Storage();
if (init_func) init_func(storage);
return storage;
};
RegisterParameterlessStorageTypeImpl(type_id, constructor);
RegisterParameterlessStorageImpl(type_id, constructor);
}
private:
StorageBase *GetParametricStorageTypeImpl(
StorageBase *GetParametricStorageImpl(
TypeId type_id,
std::size_t hash_value,
std::function<bool(const StorageBase *)> equal_func,
std::function<StorageBase *()> constructor);
StorageBase *GetParameterlessStorageTypeImpl(TypeId type_id);
StorageBase *GetParameterlessStorageImpl(TypeId type_id);
void RegisterParametricStorageTypeImpl(TypeId type_id);
void RegisterParametricStorageImpl(TypeId type_id);
void RegisterParameterlessStorageTypeImpl(
void RegisterParameterlessStorageImpl(
TypeId type_id, std::function<StorageBase *()> constructor);
// This map is a mapping between type id and parameteric type storage.
......
cc_test_old(type_test SRCS type_test.cc DEPS new_ir gtest)
cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS new_ir gtest)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <map>
#include "paddle/ir/attribute.h"
#include "paddle/ir/attribute_base.h"
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/ir_context.h"
TEST(attribute_test, attribute_base) {
class AttributeA {};
struct FakeDialect : ir::Dialect {
explicit FakeDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<FakeDialect>()) {}
static const char *name() { return "fake"; }
};
// Test 1: Test the function of IrContext to register Dialect.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *fake_dialect = ctx->GetOrRegisterDialect<FakeDialect>();
// Test 2: Test the get method of AbstractType.
ir::TypeId a_id = ir::TypeId::get<AttributeA>();
ir::AbstractAttribute abstract_attribute_a =
ir::AbstractAttribute::get(a_id, *fake_dialect);
EXPECT_EQ(abstract_attribute_a.type_id(), a_id);
// Test 3: Test the constructor of AbstractStorage.
ir::AttributeStorage storage_a(&abstract_attribute_a);
EXPECT_EQ(storage_a.abstract_attribute().type_id(),
abstract_attribute_a.type_id());
}
TEST(attribute_test, built_in_attribute) {
ir::IrContext *ctx = ir::IrContext::Instance();
// Test 1: Test the parameteric built-in attribute of IrContext.
std::string str_tmp = "string_a";
ir::Attribute string_attr_1 = ir::StrAttribute::get(ctx, str_tmp);
ir::Attribute string_attr_2 = ir::StrAttribute::get(ctx, str_tmp);
EXPECT_EQ(string_attr_1, string_attr_2);
EXPECT_EQ(ir::StrAttribute::classof(string_attr_1), 1);
// Test 2: Test isa and dyn_cast.
EXPECT_EQ(string_attr_1.isa<ir::StrAttribute>(), true);
ir::StrAttribute string_attr_cast_1 =
string_attr_1.dyn_cast<ir::StrAttribute>();
EXPECT_EQ(string_attr_cast_1.isa<ir::StrAttribute>(), true);
EXPECT_EQ(string_attr_cast_1.size() == 8, 1);
}
TEST(attribute_test, dictionary_attribute) {
ir::IrContext *ctx = ir::IrContext::Instance();
std::string str_attr1_name = "attr1_name";
std::string str_attr1_value = "attr1_value";
ir::StrAttribute attr1_name = ir::StrAttribute::get(ctx, str_attr1_name);
ir::Attribute attr1_value = ir::StrAttribute::get(ctx, str_attr1_value);
std::string str_attr2_name = "attr2_name";
std::string str_attr2_value = "attr2_value";
ir::StrAttribute attr2_name = ir::StrAttribute::get(ctx, str_attr2_name);
ir::Attribute attr2_value = ir::StrAttribute::get(ctx, str_attr2_value);
std::map<ir::StrAttribute, ir::Attribute> named_attr1;
named_attr1.insert(
std::pair<ir::StrAttribute, ir::Attribute>(attr1_name, attr1_value));
named_attr1.insert(
std::pair<ir::StrAttribute, ir::Attribute>(attr2_name, attr2_value));
ir::DictionaryAttribute dic_attr1 =
ir::DictionaryAttribute::get(ctx, named_attr1);
std::map<ir::StrAttribute, ir::Attribute> named_attr2;
named_attr2.insert(
std::pair<ir::StrAttribute, ir::Attribute>(attr2_name, attr2_value));
named_attr2.insert(
std::pair<ir::StrAttribute, ir::Attribute>(attr1_name, attr1_value));
ir::DictionaryAttribute dic_attr2 =
ir::DictionaryAttribute::get(ctx, named_attr2);
EXPECT_EQ(dic_attr1, dic_attr2);
EXPECT_EQ(attr1_value, dic_attr1.GetValue(attr1_name));
EXPECT_EQ(attr2_value, dic_attr1.GetValue(attr2_name));
}
......@@ -26,12 +26,12 @@ namespace ir {
///
class Type {
public:
using StorageType = TypeStorage;
using Storage = TypeStorage;
constexpr Type() = default;
Type(const StorageType *storage) // NOLINT
: storage_(const_cast<StorageType *>(storage)) {}
Type(const Storage *storage) // NOLINT
: storage_(const_cast<Storage *>(storage)) {}
Type(const Type &other) = default;
......@@ -55,7 +55,7 @@ class Type {
const AbstractType &abstract_type() { return storage_->abstract_type(); }
StorageType *storage() const { return storage_; }
const Storage *storage() const { return storage_; }
const Dialect &dialect() const { return storage_->abstract_type().dialect(); }
......@@ -82,7 +82,7 @@ class Type {
friend struct std::hash<Type>;
protected:
StorageType *storage_{nullptr};
const Storage *storage_{nullptr};
};
} // namespace ir
......@@ -94,7 +94,7 @@ namespace std {
template <>
struct hash<ir::Type> {
std::size_t operator()(const ir::Type &obj) const {
return std::hash<ir::Type::StorageType *>()(obj.storage_);
return std::hash<const ir::Type::Storage *>()(obj.storage_);
}
};
} // namespace std
......@@ -24,9 +24,9 @@ class Dialect;
///
/// \brief Abstract the properties and behaviors common to all Type classes into
/// an AbstractType class. There are two types in Type system:
/// on-parameter/parameterless type and parameter-type. The common attributes of
/// all types is TypeId (and possibly others). Therefore, construct a class with
/// TypeId as its member.
/// non-parameter/parameterless type and parameteric-type. The common attributes
/// of all types is TypeId (and possibly others). Therefore, construct a class
/// with TypeId as its member.
///
class AbstractType {
public:
......@@ -95,10 +95,10 @@ struct TypeManager;
///
/// \brief TypeStorage is used to store all information of a Type. A Type object
/// contains a TypeStorage. For non-parameter type, the information includes:
/// TypeId, so TypeStorage only needs to include AbstractType; For parameter
/// type, in addition to AbstractType/TypeId, parameter information needs to be
/// included. So that, non-parameter type can be constructed by TypeStorage
/// directly but parameter type should be constructed by Derived TypeStorage.
/// TypeId, so TypeStorage only needs to include AbstractType; For parameteric
/// type, in addition to AbstractType/TypeId, parameteric information needs to
/// be included. So that, non-parameteric type can be constructed by TypeStorage
/// directly but parameteric type should be constructed by Derived TypeStorage.
///
class TypeStorage : public StorageManager::StorageBase {
friend StorageManager;
......@@ -120,7 +120,7 @@ class TypeStorage : public StorageManager::StorageBase {
///
/// \return The AbstractType of the TypeStorage.
///
const AbstractType &abstract_type() { return *abstract_type_; }
const AbstractType &abstract_type() const { return *abstract_type_; }
private:
///
......@@ -170,10 +170,10 @@ struct TypeManager {
///
template <typename T, typename... Args>
static std::
enable_if_t<!std::is_same<typename T::StorageType, TypeStorage>::value, T>
enable_if_t<!std::is_same<typename T::Storage, TypeStorage>::value, T>
get(IrContext *ctx, TypeId type_id, Args &&...args) {
return ctx->storage_manager()
.GetParametricStorageType<typename T::StorageType>(
return ctx->type_storage_manager()
.GetParametricStorage<typename T::Storage>(
[&, type_id](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(type_id, ctx));
},
......@@ -190,11 +190,11 @@ struct TypeManager {
/// \return The unique instance of Type T from IrContext.
///
template <typename T>
static std::
enable_if_t<std::is_same<typename T::StorageType, TypeStorage>::value, T>
static std::enable_if_t<std::is_same<typename T::Storage, TypeStorage>::value,
T>
get(IrContext *ctx, TypeId type_id) {
return ctx->storage_manager()
.GetParameterlessStorageType<typename T::StorageType>(type_id);
return ctx->type_storage_manager()
.GetParameterlessStorage<typename T::Storage>(type_id);
}
///
......@@ -204,8 +204,7 @@ struct TypeManager {
///
template <typename T>
static void RegisterType(IrContext *ctx) {
RegisterType<T>(ctx,
ir::TypeId::get<T>()); // class Type需要提供type_id接口
RegisterType<T>(ctx, ir::TypeId::get<T>());
}
///
......@@ -216,10 +215,10 @@ struct TypeManager {
///
template <typename T>
static std::enable_if_t<
!std::is_same<typename T::StorageType, TypeStorage>::value>
!std::is_same<typename T::Storage, TypeStorage>::value>
RegisterType(IrContext *ctx, TypeId type_id) {
ctx->storage_manager()
.RegisterParametricStorageType<typename T::StorageType>(type_id);
ctx->type_storage_manager().RegisterParametricStorage<typename T::Storage>(
type_id);
}
///
......@@ -229,10 +228,9 @@ struct TypeManager {
/// \param type_id The type id of the Type T.
///
template <typename T>
static std::enable_if_t<
std::is_same<typename T::StorageType, TypeStorage>::value>
static std::enable_if_t<std::is_same<typename T::Storage, TypeStorage>::value>
RegisterType(IrContext *ctx, TypeId type_id) {
ctx->storage_manager().RegisterParameterlessStorageType<TypeStorage>(
ctx->type_storage_manager().RegisterParameterlessStorage<TypeStorage>(
type_id, [&ctx, type_id](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(type_id, ctx));
});
......@@ -244,10 +242,10 @@ struct TypeManager {
/// custom Type class.
///
#define DECLARE_TYPE_UTILITY_FUNCTOR(concrete_type, storage_type) \
using StorageType = storage_type; \
using Storage = storage_type; \
\
StorageType *storage() const { \
return static_cast<StorageType *>(this->storage_); \
const Storage *storage() const { \
return static_cast<const Storage *>(this->storage_); \
} \
\
static ir::TypeId type_id() { return ir::TypeId::get<concrete_type>(); } \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册