未验证 提交 d91d758d 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] OpTrait & OpInterface & OpInfo (#52846)

* add OpTrait OpInterface ValueIterator TypeList

* refine code

* refine code

* refine code

* add opinfo

* add typeid copy constructor

* add trait interface construct method for opinfo_impl

* add trait interface construct method for opinfo_impl

* add trait interface construct method for opinfo_impl

* add trait interface construct method for opinfo_impl

* add trait interface construct method for opinfo_impl

* add create

* add member func for opinfo

* fix compile bug

* add op interface in ircontext

* fix compile bug

* fix compile bug

* refine code

* fix compile bug

* add ut

* refine ut

* refine code of opinfo_impl

* delete unused code

* add dyncast for operation

* refine comment

* refine opinfo_impl

* delete unused code

* refine code by comment

* refine code

* refine code

* refine code for registerOp

* refine opfin create

* refine code of search method of ircontext

* refine op attribute

* change opinfo_map key from type_id to string
上级 b7295120
......@@ -59,19 +59,16 @@ DictionaryAttributeStorage::ParamKey DictionaryAttributeStorage::GetAsKey()
}
Attribute DictionaryAttributeStorage::GetValue(const StrAttribute &name) const {
if (size_ > 0) {
size_t left = 0;
size_t right = size_ - 1;
size_t mid = 0;
while (left <= right) {
mid = (left + right) / 2;
size_t right = size_;
while (left < right) {
size_t mid = left + (right - left) / 2;
if (data_[mid].name() == name) {
return data_[mid].value();
} else if (data_[mid].name() < name) {
left = mid + 1;
} else {
right = mid - 1;
}
right = mid;
}
}
return nullptr;
......
......@@ -31,4 +31,8 @@ void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) {
this->ir_context()->RegisterAbstractAttribute(
new_abstract_attribute->type_id(), new_abstract_attribute);
}
void Dialect::RegisterOp(const std::string &name, OpInfoImpl *op_info) {
this->ir_context()->RegisterOpInfo(name, op_info);
}
} // namespace ir
......@@ -16,6 +16,7 @@
#include "paddle/ir/attribute_base.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/op_info_impl.h"
#include "paddle/ir/type_base.h"
namespace ir {
......@@ -45,17 +46,19 @@ class Dialect {
(void)std::initializer_list<int>{0, (RegisterType<Args>(), 0)...};
}
///
/// \brief Register type of class T.
///
template <typename T>
void RegisterType() {
VLOG(4) << "Type registered into Dialect. --->";
// if (this->ir_context()->registed_abstract_type().count(
// ir::TypeId::get<T>()) == 0) {
if (this->ir_context()->GetRegisteredAbstractType(ir::TypeId::get<T>()) ==
nullptr) {
ir::AbstractType *abstract_type =
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
abstract_type);
ir::TypeManager::RegisterType<T>(this->ir_context());
}
VLOG(4) << "----------------------------------";
}
......@@ -78,24 +81,42 @@ class Dialect {
(void)std::initializer_list<int>{0, (RegisterAttribute<Args>(), 0)...};
}
///
/// \brief Register attribute of class T.
///
template <typename T>
void RegisterAttribute() {
VLOG(4) << "Attribute registered into Dialect. --->";
if (this->ir_context()->GetRegisteredAbstractAttribute(
ir::TypeId::get<T>()) == nullptr) {
ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute(
std::move(ir::AbstractAttribute::get<T>(*this)));
this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get<T>(),
abstract_attribute);
ir::AttributeManager::RegisterAttribute<T>(this->ir_context());
}
VLOG(4) << "----------------------------------";
}
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute);
///
/// \brief Register abstract_attribute into context.
/// \brief Register Operation methods.
///
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute);
template <typename... Args>
void RegisterOps() {
(void)std::initializer_list<int>{0, (RegisterOp<Args>(), 0)...};
}
template <typename ConcertOp>
void RegisterOp() {
std::string name = this->name() + "." + std::string(ConcertOp::name());
VLOG(4) << "Op " << name << " registered into Dialect. --->";
if (this->ir_context()->GetRegisteredOpInfo(name) == nullptr) {
ir::OpInfoImpl *op_info = ir::OpInfoImpl::create<ConcertOp>(this);
this->ir_context()->RegisterOpInfo(name, op_info);
}
VLOG(4) << "----------------------------------";
}
void RegisterOp(const std::string &name, OpInfoImpl *op_info);
private:
std::string name_;
......
......@@ -20,6 +20,7 @@
#include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/op_info_impl.h"
#include "paddle/ir/spin_lock.h"
#include "paddle/ir/type_base.h"
......@@ -46,6 +47,11 @@ class IrContextImpl {
delete dialect_map.second;
}
registed_dialect_.clear();
for (auto &op_map : registed_op_infos_) {
op_map.second->destroy();
}
registed_op_infos_.clear();
}
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
......@@ -93,6 +99,25 @@ class IrContextImpl {
return nullptr;
}
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
VLOG(4) << "Register an operation of: [Name=" << name
<< ", OpInfoImpl ptr=" << opinfo << "].";
registed_op_infos_.emplace(name, opinfo);
}
OpInfoImpl *GetOpInfo(const std::string &name) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
auto iter = registed_op_infos_.find(name);
if (iter != registed_op_infos_.end()) {
VLOG(4) << "Fonund a cached operation of: [name=" << name
<< ", OpInfoImpl ptr=" << iter->second << "].";
return iter->second;
}
LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
return nullptr;
}
void RegisterDialect(std::string name, Dialect *dialect) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
VLOG(4) << "Register a dialect of: [name=" << name
......@@ -135,6 +160,10 @@ class IrContextImpl {
std::unordered_map<std::string, Dialect *> registed_dialect_;
ir::SpinLock registed_dialect_lock_;
// The Op registered in the context.
std::unordered_map<std::string, OpInfoImpl *> registed_op_infos_;
ir::SpinLock registed_op_infos_lock_;
ir::SpinLock destructor_lock_;
};
......@@ -165,9 +194,12 @@ StorageManager &IrContext::type_storage_manager() {
return impl().registed_type_storage_manager_;
}
std::unordered_map<TypeId, AbstractType *>
&IrContext::registed_abstracted_type() {
return impl().registed_abstract_types_;
AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
auto search = impl().registed_abstract_types_.find(id);
if (search != impl().registed_abstract_types_.end()) {
return search->second;
}
return nullptr;
}
void IrContext::RegisterAbstractAttribute(
......@@ -179,9 +211,12 @@ StorageManager &IrContext::attribute_storage_manager() {
return impl().registed_attribute_storage_manager_;
}
std::unordered_map<TypeId, AbstractAttribute *>
&IrContext::registed_abstracted_attribute() {
return impl().registed_abstract_attributes_;
AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) {
auto search = impl().registed_abstract_attributes_.find(id);
if (search != impl().registed_abstract_attributes_.end()) {
return search->second;
}
return nullptr;
}
Dialect *IrContext::GetOrRegisterDialect(
......@@ -216,6 +251,17 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
return nullptr;
}
OpInfoImpl *IrContext::GetRegisteredOpInfo(const std::string &name) {
OpInfoImpl *rtn = impl().GetOpInfo(name);
return rtn ? rtn : nullptr;
}
void IrContext::RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
if (impl().GetOpInfo(name) == nullptr) {
impl().RegisterOpInfo(name, opinfo);
}
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
auto &impl = ctx->impl();
AbstractType *abstract_type = impl.GetAbstractType(type_id);
......
......@@ -26,6 +26,7 @@ class AbstractType;
class AbstractAttribute;
class TypeId;
class Dialect;
class OpInfoImpl;
///
/// \brief IrContext is a global parameterless class used to store and manage
......@@ -47,7 +48,7 @@ class IrContext {
IrContextImpl &impl() { return *impl_; }
///
/// \brief Register an AbstractType to IrContext
/// \brief Register an AbstractType to IrContext.
///
/// \param type_id The type id of the AbstractType.
/// \param abstract_type AbstractType* provided by user.
......@@ -64,13 +65,9 @@ class IrContext {
StorageManager &type_storage_manager();
///
/// \brief Returns the storage uniquer used for constructing TypeStorage
/// instances.
///
/// \return The storage uniquer used for constructing TypeStorage
/// instances.
/// \brief Get registered AbstractType from IrContext.
///
std::unordered_map<TypeId, AbstractType *> &registed_abstracted_type();
AbstractType *GetRegisteredAbstractType(TypeId id);
///
/// \brief Register an AbstractAttribute to IrContext
......@@ -91,14 +88,16 @@ class IrContext {
StorageManager &attribute_storage_manager();
///
/// \brief Returns the storage uniquer used for constructing AttributeStorage
/// instances.
/// \brief Get registered AbstractAttribute from IrContext.
///
/// \return The storage uniquer used for constructing AttributeStorage
/// instances.
AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id);
///
/// \brief Get or register operaiton.
///
std::unordered_map<TypeId, AbstractAttribute *>
&registed_abstracted_attribute();
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo);
OpInfoImpl *GetRegisteredOpInfo(const std::string &name);
///
/// \brief Get the dialect of the DialectT class in the context, ff not found,
......
......@@ -15,23 +15,59 @@
#pragma once
#include "paddle/ir/operation.h"
#include "paddle/ir/utils.h"
namespace ir {
class OpBase {
public:
Operation *operation() { return operation_; }
explicit OpBase(const Operation *operation) : operation_(operation) {}
explicit operator bool() { return operation() != nullptr; }
const Operation *operation() const { return operation_; }
operator Operation *() const { return operation_; }
explicit operator bool() const { return operation() != nullptr; }
Operation *operator->() const { return operation_; }
operator const Operation *() const { return operation_; }
protected:
explicit OpBase(Operation *operation) : operation_(operation) {}
const Operation *operator->() const { return operation_; }
private:
Operation *operation_;
const Operation *operation_; // Not owned
};
///
/// \brief OpTrait
///
template <class ConcreteTrait>
class OpTraitBase : public OpBase {
public:
explicit OpTraitBase(const Operation *op) : OpBase(op) {}
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
};
///
/// \brief OpInterface
///
template <typename ConcreteInterface>
class OpInterfaceBase : public OpBase {
public:
// explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
explicit OpInterfaceBase(const Operation *op) : OpBase(op) {}
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
};
template <typename ConcreteOp, class... TraitOrInterface>
class Op : public OpBase {
public:
using OpBase::OpBase;
using TraitList =
typename Filter<OpTraitBase, std::tuple<TraitOrInterface...>>::Type;
using InterfaceList =
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include "paddle/ir/op_info_impl.h"
namespace ir {
class OpInfo {
public:
constexpr OpInfo() = default;
OpInfo(const OpInfoImpl *impl) : impl_(impl) {} // NOLINT
OpInfo(const OpInfo &other) = default;
OpInfo &operator=(const OpInfo &other) = default;
bool operator==(OpInfo other) const { return impl_ == other.impl_; }
bool operator!=(OpInfo other) const { return impl_ != other.impl_; }
explicit operator bool() const { return impl_; }
bool operator!() const { return impl_ == nullptr; }
const OpInfoImpl *impl() const { return impl_; }
template <typename Trait>
bool HasTrait() const {
return impl_->HasTrait<Trait>();
}
template <typename Interface>
bool HasInterface() const {
return impl_->HasInterface<Interface>();
}
friend struct std::hash<OpInfo>;
private:
const OpInfoImpl *impl_{nullptr}; // not owned
};
} // namespace ir
namespace std {
template <>
struct hash<ir::OpInfo> {
std::size_t operator()(const ir::OpInfo &obj) const {
return std::hash<const ir::OpInfoImpl *>()(obj.impl_);
}
};
} // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstring>
#include <initializer_list>
#include <utility>
#include "paddle/ir/builtin_attribute.h"
// #include "paddle/ir/ir_context.h"
#include "paddle/ir/type.h"
namespace ir {
class Dialect;
///
/// \brief Tool template class for construct interfaces or Traits.
///
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
public:
/// Construct method for interfaces.
static std::pair<TypeId, void *> *interface(
std::pair<TypeId, void *> *p_interface) {
(void)std::initializer_list<int>{
0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
return p_interface;
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
(void)std::initializer_list<int>{
0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
return p_trait;
}
private:
/// Placement new interface.
template <typename T>
static void PlacementConstrctInterface(
std::pair<TypeId, void *> *&p_interface) { // NOLINT
new (&(p_interface->first)) TypeId(ir::TypeId::get<T>());
p_interface->second =
malloc(sizeof(typename T::template Model<ConcreteOp>));
new (p_interface->second) typename T::template Model<ConcreteOp>();
VLOG(4) << "New a interface: id[" << p_interface->first.storage()
<< "], interface[" << p_interface->second << "].";
++p_interface;
}
/// Placement new trait.
template <typename T>
static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT
new (p_trait) TypeId(ir::TypeId::get<T>());
VLOG(4) << "New a trait: id[" << (*p_trait).storage() << "].";
++p_trait;
}
};
/// Specialized for tuple type.
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
public:
/// Construct method for interfaces.
static std::pair<TypeId, void *> *interface(
std::pair<TypeId, void *> *p_interface) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::interface(
p_interface);
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::trait(p_trait);
}
};
///
/// \brief OpInfoImpl class.
///
class OpInfoImpl {
public:
///
/// \brief Construct and Deconstruct OpInfoImpl. The memory layout of
/// OpInfoImpl is: std::pair<TypeId, void *>... | TypeId... | OpInfoImpl
///
template <typename ConcreteOp>
static OpInfoImpl *create(ir::Dialect *dialect) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num =
std::tuple_size<typename ConcreteOp::InterfaceList>::value;
size_t traits_num = std::tuple_size<typename ConcreteOp::TraitList>::value;
size_t attributes_num = ConcreteOp::attributes_num();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(std::pair<ir::TypeId, void *>) * interfaces_num +
sizeof(ir::TypeId) * traits_num + sizeof(OpInfoImpl);
void *base_ptr = malloc(base_size);
VLOG(4) << "Malloc " << base_size << " Bytes at " << base_ptr;
// (2) Construct interfaces and sort by TypeId.
std::pair<ir::TypeId, void *> *p_first_interface = nullptr;
if (interfaces_num > 0) {
p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(base_ptr);
VLOG(4) << "Construct interfaces at " << p_first_interface << " ......";
ConstructInterfacesOrTraits<
ConcreteOp,
typename ConcreteOp::InterfaceList>::interface(p_first_interface);
std::sort(p_first_interface, p_first_interface + interfaces_num);
base_ptr = reinterpret_cast<void *>(p_first_interface + interfaces_num);
}
// (3) Construct traits and sort by TypeId.
ir::TypeId *p_first_trait = nullptr;
if (traits_num > 0) {
p_first_trait = reinterpret_cast<ir::TypeId *>(base_ptr);
VLOG(4) << "Construct traits at " << p_first_trait << " ......";
ConstructInterfacesOrTraits<ConcreteOp, typename ConcreteOp::TraitList>::
trait(p_first_trait);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr = reinterpret_cast<void *>(p_first_trait + traits_num);
}
// (4) Construct opinfo_impl.
OpInfoImpl *p_opinfo_impl = reinterpret_cast<OpInfoImpl *>(base_ptr);
VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......";
OpInfoImpl *op_info =
new (p_opinfo_impl) OpInfoImpl(interfaces_num,
traits_num,
ConcreteOp::attributes_name_,
attributes_num,
ir::TypeId::get<ConcreteOp>(),
ConcreteOp::name(),
dialect);
return op_info;
}
void destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
void *base_ptr = reinterpret_cast<void *>(
reinterpret_cast<char *>(this) - sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
if (num_interfaces_ > 0) {
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
free((p_first_interface + i)->second);
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
///
/// \brief Search methods for Trait or Interface.
///
template <typename Trait>
bool HasTrait() const {
return HasTrait(TypeId::get<Trait>());
}
bool HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
TypeId *p_first_trait = reinterpret_cast<TypeId *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
template <typename Interface>
bool HasInterface() const {
return HasInterface(TypeId::get<Interface>());
}
bool HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
std::make_pair(interface_id, nullptr),
CompareInterface);
}
return false;
}
template <typename Interface>
typename Interface::Concept *GetInterfaceImpl() const {
if (num_interfaces_ > 0) {
ir::TypeId interface_id = ir::TypeId::get<Interface>();
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
size_t left = 0;
size_t right = num_interfaces_;
while (left < right) {
size_t mid = left + (right - left) / 2;
if ((p_first_interface + mid)->first == interface_id) {
return reinterpret_cast<typename Interface::Concept *>(
(p_first_interface + mid)->second);
} else if ((p_first_interface + mid)->first < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
ir::TypeId id() const { return op_id_; }
const char *name() const { return op_name_; }
ir::Dialect *dialect() const { return dialect_; }
private:
OpInfoImpl(uint32_t num_interfaces,
uint32_t num_traits,
const char **p_attributes,
uint32_t num_attributes,
TypeId op_id,
const char *op_name,
ir::Dialect *dialect)
: num_interfaces_(num_interfaces),
num_traits_(num_traits),
p_attributes_(p_attributes),
num_attributes_(num_attributes),
op_id_(op_id),
op_name_(op_name),
dialect_(dialect) {}
static bool CompareInterface(const std::pair<ir::TypeId, void *> &a,
const std::pair<ir::TypeId, void *> &b) {
return a.first < b.first;
}
/// Interface will be recorded by std::pair<TypeId, void*>.
uint32_t num_interfaces_ = 0;
/// Trait will be recorded by TypeId.
uint32_t num_traits_ = 0;
/// Attributes array address.
const char **p_attributes_{nullptr};
/// The number of attributes for this Op.
uint32_t num_attributes_ = 0;
/// The TypeId of this Op.
TypeId op_id_;
/// The name of this Op.
const char *op_name_;
/// The dialect of this Op belong to.
ir::Dialect *dialect_;
};
} // namespace ir
......@@ -21,7 +21,8 @@ namespace ir {
// OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
ir::DictionaryAttribute attribute) {
ir::DictionaryAttribute attribute,
ir::OpInfo op_info) {
// 1. Calculate the required memory size for OpResults + Operation +
// OpOperands.
uint32_t num_results = output_types.size();
......@@ -52,7 +53,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
}
// 3.2. Construct Operation.
Operation *op =
new (base_ptr) Operation(num_results, num_operands, attribute);
new (base_ptr) Operation(num_results, num_operands, attribute, op_info);
base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
......@@ -116,13 +117,15 @@ void Operation::destroy() {
Operation::Operation(uint32_t num_results,
uint32_t num_operands,
ir::DictionaryAttribute attribute) {
ir::DictionaryAttribute attribute,
ir::OpInfo op_info) {
if (!attribute) {
throw("unexpected null attribute dictionary");
}
num_results_ = num_results;
num_operands_ = num_operands;
attribute_ = attribute;
op_info_ = op_info;
}
ir::OpResult Operation::GetResultByIndex(uint32_t index) {
......
......@@ -15,10 +15,15 @@
#pragma once
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/op_info.h"
#include "paddle/ir/type.h"
#include "paddle/ir/value_impl.h"
namespace ir {
template <class ConcreteTrait>
class OpTraitBase;
template <typename ConcreteInterface>
class OpInterfaceBase;
class alignas(8) Operation final {
public:
......@@ -28,7 +33,8 @@ class alignas(8) Operation final {
///
static Operation *create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
ir::DictionaryAttribute attribute);
ir::DictionaryAttribute attribute,
ir::OpInfo op_info);
void destroy();
......@@ -36,19 +42,60 @@ class alignas(8) Operation final {
std::string print();
ir::DictionaryAttribute attribute() { return attribute_; }
ir::DictionaryAttribute attribute() const { return attribute_; }
uint32_t num_results() { return num_results_; }
ir::OpInfo op_info() const { return op_info_; }
uint32_t num_operands() { return num_operands_; }
uint32_t num_results() const { return num_results_; }
uint32_t num_operands() const { return num_operands_; }
template <typename T>
T dyn_cast() const {
return CastUtil<T>::call(this);
}
template <typename Trait>
bool HasTrait() const {
return op_info_.HasTrait<Trait>();
}
template <typename Interface>
bool HasInterface() const {
return op_info_.HasInterface<Interface>();
}
private:
Operation(uint32_t num_results,
uint32_t num_operands,
ir::DictionaryAttribute attribute);
ir::DictionaryAttribute attribute,
ir::OpInfo op_info);
template <typename T, typename Enabler = void>
struct CastUtil {
static T call(const Operation *op) {
throw("Can't dyn_cast to T, T should be a Trait or Interface");
}
};
template <typename T>
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpTraitBase<T>, T>::value>::type> {
static T call(const Operation *op) { return T(op); }
};
template <typename T>
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpInterfaceBase<T>, T>::value>::type> {
static T call(const Operation *op) {
return T(op, op->op_info_.impl()->GetInterfaceImpl<T>());
}
};
ir::DictionaryAttribute attribute_;
ir::OpInfo op_info_;
uint32_t num_results_ = 0;
uint32_t num_operands_ = 0;
......
......@@ -45,6 +45,12 @@ class TypeId {
return TypeId(&instance);
}
TypeId(const TypeId &other) = default;
TypeId &operator=(const TypeId &other) = default;
const Storage *storage() const { return storage_; }
///
/// \brief Comparison operations.
///
......@@ -54,6 +60,9 @@ class TypeId {
inline bool operator!=(const TypeId &other) const {
return !(*this == other);
}
inline bool operator<(const TypeId &other) const {
return storage_ < other.storage_;
}
///
/// \brief Enable hashing TypeId instances.
......
......@@ -17,12 +17,107 @@
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <tuple>
#include <type_traits>
namespace ir {
///
/// \brief Equivalent to boost::hash_combine.
///
std::size_t hash_combine(std::size_t lhs, std::size_t rhs);
///
/// \brief Aligned malloc and free functions.
///
void *aligned_malloc(size_t size, size_t alignment);
void aligned_free(void *mem_ptr);
///
/// \brief Some template methods for manipulating std::tuple.
///
/// (1) Pop front element from Tuple
template <typename Tuple>
struct PopFrontT;
template <typename Head, typename... Tail>
struct PopFrontT<std::tuple<Head, Tail...>> {
public:
using Type = std::tuple<Tail...>;
};
template <typename Tuple>
using PopFront = typename PopFrontT<Tuple>::Type;
/// (2) Push front element to Tuple
template <typename NewElement, typename Tuple>
struct PushFrontT;
template <typename NewElement, typename... Elements>
struct PushFrontT<NewElement, std::tuple<Elements...>> {
public:
using Type = std::tuple<NewElement, Elements...>;
};
template <typename NewElement, typename... Elements>
struct PushFrontT<std::tuple<NewElement>, std::tuple<Elements...>> {
public:
using Type = std::tuple<NewElement, Elements...>;
};
template <typename NewElement, typename Tuple>
using PushFront = typename PushFrontT<NewElement, Tuple>::Type;
/// (3) IsEmpty
template <typename Tuple>
struct IsEmpty {
static constexpr bool value = false;
};
template <>
struct IsEmpty<std::tuple<>> {
static constexpr bool value = true;
};
/// (4) IfThenElseT
template <bool COND, typename TrueT, typename FalseT>
struct IfThenElseT {
using Type = TrueT;
};
template <typename TrueT, typename FalseT>
struct IfThenElseT<false, TrueT, FalseT> {
using Type = FalseT;
};
template <bool COND, typename TrueT, typename FalseT>
using IfThenElse = typename IfThenElseT<COND, TrueT, FalseT>::Type;
/// (5) Filter out all types inherited from BaseT from the tuple.
template <template <typename> class BaseT,
typename Tuple,
bool Empty = IsEmpty<Tuple>::value>
struct Filter;
template <template <typename> class BaseT, typename Tuple>
struct Filter<BaseT, Tuple, false> {
private:
using Matched =
IfThenElse<std::is_base_of<BaseT<std::tuple_element_t<0, Tuple>>,
std::tuple_element_t<0, Tuple>>::value,
std::tuple<std::tuple_element_t<0, Tuple>>,
std::tuple<>>;
using Rest = typename Filter<BaseT, PopFront<Tuple>>::Type;
public:
using Type =
IfThenElse<IsEmpty<Matched>::value, Rest, PushFront<Matched, Rest>>;
};
// basis case:
template <template <typename> class BaseT, typename Tuple>
struct Filter<BaseT, Tuple, true> {
using Type = std::tuple<>;
};
} // namespace ir
......@@ -75,6 +75,12 @@ Operation *Value::GetDefiningOp() const {
std::string Value::print_ud_chain() { return impl_->print_ud_chain(); }
Value::use_iterator Value::begin() const {
return ir::OpOperand(impl_->first_use());
}
Value::use_iterator Value::end() const { return Value::use_iterator(); }
// OpResult
bool OpResult::classof(Value value) {
return ir::isa<detail::OpResultImpl>(value.impl());
......
......@@ -56,6 +56,38 @@ class OpOperand {
detail::OpOperandImpl *impl_{nullptr};
};
///
/// \brief Value Iterator
///
template <typename OperandType>
class ValueUseIterator {
public:
ValueUseIterator(OperandType use = nullptr) : current_(use) {} // NOLINT
bool operator==(const ValueUseIterator<OperandType> &rhs) const {
return current_ == rhs.current_;
}
ir::Operation *owner() const { return current_.impl()->owner(); }
OperandType get() const { return current_; }
OperandType operator*() const { return get(); }
ValueUseIterator<OperandType> &operator++() {
current_ = current_.impl()->next_use();
return *this;
}
ValueUseIterator<OperandType> operator++(int) {
ValueUseIterator<OperandType> tmp = *this;
++*(this);
return tmp;
}
protected:
OperandType current_;
};
///
/// \brief Value class represents the SSA value in the IR system. This class
/// only provides interfaces, for specific implementation, see Impl class.
......@@ -96,6 +128,15 @@ class Value {
std::string print_ud_chain();
///
/// \brief Provide iterator interface to access Value use chain.
///
using use_iterator = ValueUseIterator<OpOperand>;
use_iterator begin() const;
use_iterator end() const;
friend struct std::hash<Value>;
protected:
......
......@@ -2,4 +2,5 @@ if(WITH_NEWIR)
cc_test_old(type_test SRCS type_test.cc DEPS new_ir gtest)
cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS new_ir gtest)
cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS new_ir gtest)
cc_test_old(ir_op_test SRCS ir_op_test.cc DEPS new_ir gtest)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/op_base.h"
/// \brief Define built-in Trait, derived from OpTraitBase.
class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
public:
explicit ReadOnlyTrait(const ir::Operation *op)
: ir::OpTraitBase<ReadOnlyTrait>(op) {}
};
/// \brief Define built-in Interface, derived from OpInterfaceBase. Concepts and
/// Models need to be defined within the class. Concept defines abstract
/// interface functions, and Model is a template class that defines the specific
/// implementation of interface functions based on template parameters.
class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
public:
struct Concept {
explicit Concept(void (*infer_shape)(const ir::Operation *))
: infer_shape_(infer_shape) {}
void (*infer_shape_)(const ir::Operation *);
};
template <class ConcreteOp>
struct Model : public Concept {
static void InferShape(const ir::Operation *op) {
ConcreteOp concret_op = ConcreteOp(op);
if (concret_op == nullptr) throw("concret_op is nullptr");
concret_op.InferShape();
}
Model() : Concept(InferShape) {
if (sizeof(Model) != sizeof(Concept)) {
throw("sizeof(Model) != sizeof(Concept)");
}
}
};
InferShapeInterface(const ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<InferShapeInterface>(op), impl_(impl) {}
void InferShape() { impl_->infer_shape_(operation()); }
private:
Concept *impl_;
};
// Define op1.
class Operation1 : public ir::Op<Operation1> {
public:
using Op::Op;
static const char *name() { return "Operation1"; }
static const char *attributes_name_[];
static uint32_t attributes_num() { return 2; }
};
const char *Operation1::attributes_name_[] = {"op1_attr1", "op1_attr2"};
// Define op2.
class Operation2
: public ir::Op<Operation2, ReadOnlyTrait, InferShapeInterface> {
public:
using Op::Op;
static const char *name() { return "Operation2"; }
static const char *attributes_name_[];
static uint32_t attributes_num() { return 2; }
static void InferShape() {
std::cout << "This is op2's InferShape interface." << std::endl;
}
};
const char *Operation2::attributes_name_[] = {"op2_attr1", "op2_attr2"};
// Define a dialect, op1 and op2 will be registered by this dialect.
class TestDialect : public ir::Dialect {
public:
explicit TestDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize();
}
static const char *name() { return "op_test"; }
private:
void initialize() { RegisterOps<Operation1, Operation2>(); }
};
ir::DictionaryAttribute CreateAttribute(std::string attribute_name,
std::string attribute) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::StrAttribute attr_name = ir::StrAttribute::get(ctx, attribute_name);
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attribute);
std::map<ir::StrAttribute, ir::Attribute> named_attr;
named_attr.insert(
std::pair<ir::StrAttribute, ir::Attribute>(attr_name, attr_value));
return ir::DictionaryAttribute::get(ctx, named_attr);
}
TEST(op_test, op_test) {
// (1) Register Dialect, Operation1, Operation2 into IrContext.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
std::cout << test_dialect << std::endl;
// (2) Get registered operations.
std::string op1_name =
test_dialect->name() + "." + std::string(Operation1::name());
ir::OpInfoImpl *op1_info = ctx->GetRegisteredOpInfo(op1_name);
EXPECT_EQ(op1_info != nullptr, true);
std::string op2_name =
test_dialect->name() + "." + std::string(Operation2::name());
ir::OpInfoImpl *op2_info = ctx->GetRegisteredOpInfo(op2_name);
EXPECT_EQ(op2_info != nullptr, true);
EXPECT_EQ(op1_info->HasTrait<ReadOnlyTrait>(), false);
EXPECT_EQ(op1_info->HasInterface<InferShapeInterface>(), false);
EXPECT_EQ(op2_info->HasTrait<ReadOnlyTrait>(), true);
EXPECT_EQ(op2_info->HasInterface<InferShapeInterface>(), true);
// (3) Test uses for op.
std::vector<ir::OpResult> op_inputs = {};
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op =
ir::Operation::create(op_inputs,
op_output_types,
CreateAttribute("op1_name", "op1_attr"),
op2_info);
if (op->HasTrait<ReadOnlyTrait>()) {
ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();
EXPECT_EQ(trait.operation(), op);
}
if (op->HasInterface<InferShapeInterface>()) {
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
interface.InferShape();
}
op->destroy();
}
......@@ -40,21 +40,30 @@ TEST(value_test, value_test) {
// 1. Construct OP1: a = OP1()
std::vector<ir::OpResult> op1_inputs = {};
std::vector<ir::Type> op1_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op1 = ir::Operation::create(
op1_inputs, op1_output_types, CreateAttribute("op1_name", "op1_attr"));
ir::Operation *op1 =
ir::Operation::create(op1_inputs,
op1_output_types,
CreateAttribute("op1_name", "op1_attr"),
nullptr);
std::cout << op1->print() << std::endl;
// 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op2 = ir::Operation::create(
op2_inputs, op2_output_types, CreateAttribute("op2_name", "op2_attr"));
ir::Operation *op2 =
ir::Operation::create(op2_inputs,
op2_output_types,
CreateAttribute("op2_name", "op2_attr"),
nullptr);
std::cout << op2->print() << std::endl;
// 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
op2->GetResultByIndex(0)};
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op3 = ir::Operation::create(
op3_inputs, op3_output_types, CreateAttribute("op3_name", "op3_attr"));
ir::Operation *op3 =
ir::Operation::create(op3_inputs,
op3_output_types,
CreateAttribute("op3_name", "op3_attr"),
nullptr);
std::cout << op3->print() << std::endl;
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
......@@ -63,8 +72,11 @@ TEST(value_test, value_test) {
for (size_t i = 0; i < 7; i++) {
op4_output_types.push_back(ir::Float32Type::get(ctx));
}
ir::Operation *op4 = ir::Operation::create(
op4_inputs, op4_output_types, CreateAttribute("op4_name", "op4_attr"));
ir::Operation *op4 =
ir::Operation::create(op4_inputs,
op4_output_types,
CreateAttribute("op4_name", "op4_attr"),
nullptr);
std::cout << op4->print() << std::endl;
// Test 1:
......@@ -86,6 +98,12 @@ TEST(value_test, value_test) {
EXPECT_EQ(op4_first_input->next_use(), op3_first_input);
EXPECT_EQ(op3_first_input->next_use(), nullptr);
// Test 3: Value iterator
ir::Value::use_iterator iter = op1->GetResultByIndex(0).begin();
EXPECT_EQ(iter.owner(), op4);
++iter;
EXPECT_EQ(iter.owner(), op3);
// destroy
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op4->destroy();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册