未验证 提交 d91d758d 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] OpTrait & OpInterface & OpInfo (#52846)

* add OpTrait OpInterface ValueIterator TypeList

* refine code

* refine code

* refine code

* add opinfo

* add typeid copy constructor

* add trait interface construct method for opinfo_impl

* add trait interface construct method for opinfo_impl

* add trait interface construct method for opinfo_impl

* add trait interface construct method for opinfo_impl

* add trait interface construct method for opinfo_impl

* add create

* add member func for opinfo

* fix compile bug

* add op interface in ircontext

* fix compile bug

* fix compile bug

* refine code

* fix compile bug

* add ut

* refine ut

* refine code of opinfo_impl

* delete unused code

* add dyncast for operation

* refine comment

* refine opinfo_impl

* delete unused code

* refine code by comment

* refine code

* refine code

* refine code for registerOp

* refine opfin create

* refine code of search method of ircontext

* refine op attribute

* change opinfo_map key from type_id to string
上级 b7295120
...@@ -59,19 +59,16 @@ DictionaryAttributeStorage::ParamKey DictionaryAttributeStorage::GetAsKey() ...@@ -59,19 +59,16 @@ DictionaryAttributeStorage::ParamKey DictionaryAttributeStorage::GetAsKey()
} }
Attribute DictionaryAttributeStorage::GetValue(const StrAttribute &name) const { Attribute DictionaryAttributeStorage::GetValue(const StrAttribute &name) const {
if (size_ > 0) { size_t left = 0;
size_t left = 0; size_t right = size_;
size_t right = size_ - 1; while (left < right) {
size_t mid = 0; size_t mid = left + (right - left) / 2;
while (left <= right) { if (data_[mid].name() == name) {
mid = (left + right) / 2; return data_[mid].value();
if (data_[mid].name() == name) { } else if (data_[mid].name() < name) {
return data_[mid].value(); left = mid + 1;
} else if (data_[mid].name() < name) { } else {
left = mid + 1; right = mid;
} else {
right = mid - 1;
}
} }
} }
return nullptr; return nullptr;
......
...@@ -31,4 +31,8 @@ void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) { ...@@ -31,4 +31,8 @@ void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) {
this->ir_context()->RegisterAbstractAttribute( this->ir_context()->RegisterAbstractAttribute(
new_abstract_attribute->type_id(), new_abstract_attribute); new_abstract_attribute->type_id(), new_abstract_attribute);
} }
void Dialect::RegisterOp(const std::string &name, OpInfoImpl *op_info) {
this->ir_context()->RegisterOpInfo(name, op_info);
}
} // namespace ir } // namespace ir
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/ir/attribute_base.h" #include "paddle/ir/attribute_base.h"
#include "paddle/ir/ir_context.h" #include "paddle/ir/ir_context.h"
#include "paddle/ir/op_info_impl.h"
#include "paddle/ir/type_base.h" #include "paddle/ir/type_base.h"
namespace ir { namespace ir {
...@@ -45,17 +46,19 @@ class Dialect { ...@@ -45,17 +46,19 @@ class Dialect {
(void)std::initializer_list<int>{0, (RegisterType<Args>(), 0)...}; (void)std::initializer_list<int>{0, (RegisterType<Args>(), 0)...};
} }
///
/// \brief Register type of class T.
///
template <typename T> template <typename T>
void RegisterType() { void RegisterType() {
VLOG(4) << "Type registered into Dialect. --->"; VLOG(4) << "Type registered into Dialect. --->";
ir::AbstractType *abstract_type = // if (this->ir_context()->registed_abstract_type().count(
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this))); // ir::TypeId::get<T>()) == 0) {
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(), if (this->ir_context()->GetRegisteredAbstractType(ir::TypeId::get<T>()) ==
abstract_type); nullptr) {
ir::TypeManager::RegisterType<T>(this->ir_context()); ir::AbstractType *abstract_type =
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
abstract_type);
ir::TypeManager::RegisterType<T>(this->ir_context());
}
VLOG(4) << "----------------------------------"; VLOG(4) << "----------------------------------";
} }
...@@ -78,24 +81,42 @@ class Dialect { ...@@ -78,24 +81,42 @@ class Dialect {
(void)std::initializer_list<int>{0, (RegisterAttribute<Args>(), 0)...}; (void)std::initializer_list<int>{0, (RegisterAttribute<Args>(), 0)...};
} }
///
/// \brief Register attribute of class T.
///
template <typename T> template <typename T>
void RegisterAttribute() { void RegisterAttribute() {
VLOG(4) << "Attribute registered into Dialect. --->"; VLOG(4) << "Attribute registered into Dialect. --->";
ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute( if (this->ir_context()->GetRegisteredAbstractAttribute(
std::move(ir::AbstractAttribute::get<T>(*this))); ir::TypeId::get<T>()) == nullptr) {
this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get<T>(), ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute(
abstract_attribute); std::move(ir::AbstractAttribute::get<T>(*this)));
ir::AttributeManager::RegisterAttribute<T>(this->ir_context()); this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get<T>(),
abstract_attribute);
ir::AttributeManager::RegisterAttribute<T>(this->ir_context());
}
VLOG(4) << "----------------------------------"; VLOG(4) << "----------------------------------";
} }
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute);
/// ///
/// \brief Register abstract_attribute into context. /// \brief Register Operation methods.
/// ///
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute); template <typename... Args>
void RegisterOps() {
(void)std::initializer_list<int>{0, (RegisterOp<Args>(), 0)...};
}
template <typename ConcertOp>
void RegisterOp() {
std::string name = this->name() + "." + std::string(ConcertOp::name());
VLOG(4) << "Op " << name << " registered into Dialect. --->";
if (this->ir_context()->GetRegisteredOpInfo(name) == nullptr) {
ir::OpInfoImpl *op_info = ir::OpInfoImpl::create<ConcertOp>(this);
this->ir_context()->RegisterOpInfo(name, op_info);
}
VLOG(4) << "----------------------------------";
}
void RegisterOp(const std::string &name, OpInfoImpl *op_info);
private: private:
std::string name_; std::string name_;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/ir/builtin_dialect.h" #include "paddle/ir/builtin_dialect.h"
#include "paddle/ir/builtin_type.h" #include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h" #include "paddle/ir/dialect.h"
#include "paddle/ir/op_info_impl.h"
#include "paddle/ir/spin_lock.h" #include "paddle/ir/spin_lock.h"
#include "paddle/ir/type_base.h" #include "paddle/ir/type_base.h"
...@@ -46,6 +47,11 @@ class IrContextImpl { ...@@ -46,6 +47,11 @@ class IrContextImpl {
delete dialect_map.second; delete dialect_map.second;
} }
registed_dialect_.clear(); registed_dialect_.clear();
for (auto &op_map : registed_op_infos_) {
op_map.second->destroy();
}
registed_op_infos_.clear();
} }
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) { void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
...@@ -93,6 +99,25 @@ class IrContextImpl { ...@@ -93,6 +99,25 @@ class IrContextImpl {
return nullptr; return nullptr;
} }
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
VLOG(4) << "Register an operation of: [Name=" << name
<< ", OpInfoImpl ptr=" << opinfo << "].";
registed_op_infos_.emplace(name, opinfo);
}
OpInfoImpl *GetOpInfo(const std::string &name) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
auto iter = registed_op_infos_.find(name);
if (iter != registed_op_infos_.end()) {
VLOG(4) << "Fonund a cached operation of: [name=" << name
<< ", OpInfoImpl ptr=" << iter->second << "].";
return iter->second;
}
LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
return nullptr;
}
void RegisterDialect(std::string name, Dialect *dialect) { void RegisterDialect(std::string name, Dialect *dialect) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_); std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
VLOG(4) << "Register a dialect of: [name=" << name VLOG(4) << "Register a dialect of: [name=" << name
...@@ -135,6 +160,10 @@ class IrContextImpl { ...@@ -135,6 +160,10 @@ class IrContextImpl {
std::unordered_map<std::string, Dialect *> registed_dialect_; std::unordered_map<std::string, Dialect *> registed_dialect_;
ir::SpinLock registed_dialect_lock_; ir::SpinLock registed_dialect_lock_;
// The Op registered in the context.
std::unordered_map<std::string, OpInfoImpl *> registed_op_infos_;
ir::SpinLock registed_op_infos_lock_;
ir::SpinLock destructor_lock_; ir::SpinLock destructor_lock_;
}; };
...@@ -165,9 +194,12 @@ StorageManager &IrContext::type_storage_manager() { ...@@ -165,9 +194,12 @@ StorageManager &IrContext::type_storage_manager() {
return impl().registed_type_storage_manager_; return impl().registed_type_storage_manager_;
} }
std::unordered_map<TypeId, AbstractType *> AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
&IrContext::registed_abstracted_type() { auto search = impl().registed_abstract_types_.find(id);
return impl().registed_abstract_types_; if (search != impl().registed_abstract_types_.end()) {
return search->second;
}
return nullptr;
} }
void IrContext::RegisterAbstractAttribute( void IrContext::RegisterAbstractAttribute(
...@@ -179,9 +211,12 @@ StorageManager &IrContext::attribute_storage_manager() { ...@@ -179,9 +211,12 @@ StorageManager &IrContext::attribute_storage_manager() {
return impl().registed_attribute_storage_manager_; return impl().registed_attribute_storage_manager_;
} }
std::unordered_map<TypeId, AbstractAttribute *> AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) {
&IrContext::registed_abstracted_attribute() { auto search = impl().registed_abstract_attributes_.find(id);
return impl().registed_abstract_attributes_; if (search != impl().registed_abstract_attributes_.end()) {
return search->second;
}
return nullptr;
} }
Dialect *IrContext::GetOrRegisterDialect( Dialect *IrContext::GetOrRegisterDialect(
...@@ -216,6 +251,17 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) { ...@@ -216,6 +251,17 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
return nullptr; return nullptr;
} }
OpInfoImpl *IrContext::GetRegisteredOpInfo(const std::string &name) {
OpInfoImpl *rtn = impl().GetOpInfo(name);
return rtn ? rtn : nullptr;
}
void IrContext::RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
if (impl().GetOpInfo(name) == nullptr) {
impl().RegisterOpInfo(name, opinfo);
}
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
auto &impl = ctx->impl(); auto &impl = ctx->impl();
AbstractType *abstract_type = impl.GetAbstractType(type_id); AbstractType *abstract_type = impl.GetAbstractType(type_id);
......
...@@ -26,6 +26,7 @@ class AbstractType; ...@@ -26,6 +26,7 @@ class AbstractType;
class AbstractAttribute; class AbstractAttribute;
class TypeId; class TypeId;
class Dialect; class Dialect;
class OpInfoImpl;
/// ///
/// \brief IrContext is a global parameterless class used to store and manage /// \brief IrContext is a global parameterless class used to store and manage
...@@ -47,7 +48,7 @@ class IrContext { ...@@ -47,7 +48,7 @@ class IrContext {
IrContextImpl &impl() { return *impl_; } IrContextImpl &impl() { return *impl_; }
/// ///
/// \brief Register an AbstractType to IrContext /// \brief Register an AbstractType to IrContext.
/// ///
/// \param type_id The type id of the AbstractType. /// \param type_id The type id of the AbstractType.
/// \param abstract_type AbstractType* provided by user. /// \param abstract_type AbstractType* provided by user.
...@@ -64,13 +65,9 @@ class IrContext { ...@@ -64,13 +65,9 @@ class IrContext {
StorageManager &type_storage_manager(); StorageManager &type_storage_manager();
/// ///
/// \brief Returns the storage uniquer used for constructing TypeStorage /// \brief Get registered AbstractType from IrContext.
/// instances.
///
/// \return The storage uniquer used for constructing TypeStorage
/// instances.
/// ///
std::unordered_map<TypeId, AbstractType *> &registed_abstracted_type(); AbstractType *GetRegisteredAbstractType(TypeId id);
/// ///
/// \brief Register an AbstractAttribute to IrContext /// \brief Register an AbstractAttribute to IrContext
...@@ -91,14 +88,16 @@ class IrContext { ...@@ -91,14 +88,16 @@ class IrContext {
StorageManager &attribute_storage_manager(); StorageManager &attribute_storage_manager();
/// ///
/// \brief Returns the storage uniquer used for constructing AttributeStorage /// \brief Get registered AbstractAttribute from IrContext.
/// instances.
/// ///
/// \return The storage uniquer used for constructing AttributeStorage AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id);
/// instances.
///
/// \brief Get or register operaiton.
/// ///
std::unordered_map<TypeId, AbstractAttribute *> void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo);
&registed_abstracted_attribute();
OpInfoImpl *GetRegisteredOpInfo(const std::string &name);
/// ///
/// \brief Get the dialect of the DialectT class in the context, ff not found, /// \brief Get the dialect of the DialectT class in the context, ff not found,
......
...@@ -15,23 +15,59 @@ ...@@ -15,23 +15,59 @@
#pragma once #pragma once
#include "paddle/ir/operation.h" #include "paddle/ir/operation.h"
#include "paddle/ir/utils.h"
namespace ir { namespace ir {
class OpBase { class OpBase {
public: public:
Operation *operation() { return operation_; } explicit OpBase(const Operation *operation) : operation_(operation) {}
explicit operator bool() { return operation() != nullptr; } const Operation *operation() const { return operation_; }
operator Operation *() const { return operation_; } explicit operator bool() const { return operation() != nullptr; }
Operation *operator->() const { return operation_; } operator const Operation *() const { return operation_; }
protected: const Operation *operator->() const { return operation_; }
explicit OpBase(Operation *operation) : operation_(operation) {}
private: private:
Operation *operation_; const Operation *operation_; // Not owned
};
///
/// \brief OpTrait
///
template <class ConcreteTrait>
class OpTraitBase : public OpBase {
public:
explicit OpTraitBase(const Operation *op) : OpBase(op) {}
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
};
///
/// \brief OpInterface
///
template <typename ConcreteInterface>
class OpInterfaceBase : public OpBase {
public:
// explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
explicit OpInterfaceBase(const Operation *op) : OpBase(op) {}
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
};
template <typename ConcreteOp, class... TraitOrInterface>
class Op : public OpBase {
public:
using OpBase::OpBase;
using TraitList =
typename Filter<OpTraitBase, std::tuple<TraitOrInterface...>>::Type;
using InterfaceList =
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
}; };
} // namespace ir } // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include "paddle/ir/op_info_impl.h"
namespace ir {
class OpInfo {
public:
constexpr OpInfo() = default;
OpInfo(const OpInfoImpl *impl) : impl_(impl) {} // NOLINT
OpInfo(const OpInfo &other) = default;
OpInfo &operator=(const OpInfo &other) = default;
bool operator==(OpInfo other) const { return impl_ == other.impl_; }
bool operator!=(OpInfo other) const { return impl_ != other.impl_; }
explicit operator bool() const { return impl_; }
bool operator!() const { return impl_ == nullptr; }
const OpInfoImpl *impl() const { return impl_; }
template <typename Trait>
bool HasTrait() const {
return impl_->HasTrait<Trait>();
}
template <typename Interface>
bool HasInterface() const {
return impl_->HasInterface<Interface>();
}
friend struct std::hash<OpInfo>;
private:
const OpInfoImpl *impl_{nullptr}; // not owned
};
} // namespace ir
namespace std {
template <>
struct hash<ir::OpInfo> {
std::size_t operator()(const ir::OpInfo &obj) const {
return std::hash<const ir::OpInfoImpl *>()(obj.impl_);
}
};
} // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstring>
#include <initializer_list>
#include <utility>
#include "paddle/ir/builtin_attribute.h"
// #include "paddle/ir/ir_context.h"
#include "paddle/ir/type.h"
namespace ir {
class Dialect;
///
/// \brief Tool template class for construct interfaces or Traits.
///
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
public:
/// Construct method for interfaces.
static std::pair<TypeId, void *> *interface(
std::pair<TypeId, void *> *p_interface) {
(void)std::initializer_list<int>{
0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
return p_interface;
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
(void)std::initializer_list<int>{
0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
return p_trait;
}
private:
/// Placement new interface.
template <typename T>
static void PlacementConstrctInterface(
std::pair<TypeId, void *> *&p_interface) { // NOLINT
new (&(p_interface->first)) TypeId(ir::TypeId::get<T>());
p_interface->second =
malloc(sizeof(typename T::template Model<ConcreteOp>));
new (p_interface->second) typename T::template Model<ConcreteOp>();
VLOG(4) << "New a interface: id[" << p_interface->first.storage()
<< "], interface[" << p_interface->second << "].";
++p_interface;
}
/// Placement new trait.
template <typename T>
static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT
new (p_trait) TypeId(ir::TypeId::get<T>());
VLOG(4) << "New a trait: id[" << (*p_trait).storage() << "].";
++p_trait;
}
};
/// Specialized for tuple type.
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
public:
/// Construct method for interfaces.
static std::pair<TypeId, void *> *interface(
std::pair<TypeId, void *> *p_interface) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::interface(
p_interface);
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::trait(p_trait);
}
};
///
/// \brief OpInfoImpl class.
///
class OpInfoImpl {
public:
///
/// \brief Construct and Deconstruct OpInfoImpl. The memory layout of
/// OpInfoImpl is: std::pair<TypeId, void *>... | TypeId... | OpInfoImpl
///
template <typename ConcreteOp>
static OpInfoImpl *create(ir::Dialect *dialect) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num =
std::tuple_size<typename ConcreteOp::InterfaceList>::value;
size_t traits_num = std::tuple_size<typename ConcreteOp::TraitList>::value;
size_t attributes_num = ConcreteOp::attributes_num();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(std::pair<ir::TypeId, void *>) * interfaces_num +
sizeof(ir::TypeId) * traits_num + sizeof(OpInfoImpl);
void *base_ptr = malloc(base_size);
VLOG(4) << "Malloc " << base_size << " Bytes at " << base_ptr;
// (2) Construct interfaces and sort by TypeId.
std::pair<ir::TypeId, void *> *p_first_interface = nullptr;
if (interfaces_num > 0) {
p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(base_ptr);
VLOG(4) << "Construct interfaces at " << p_first_interface << " ......";
ConstructInterfacesOrTraits<
ConcreteOp,
typename ConcreteOp::InterfaceList>::interface(p_first_interface);
std::sort(p_first_interface, p_first_interface + interfaces_num);
base_ptr = reinterpret_cast<void *>(p_first_interface + interfaces_num);
}
// (3) Construct traits and sort by TypeId.
ir::TypeId *p_first_trait = nullptr;
if (traits_num > 0) {
p_first_trait = reinterpret_cast<ir::TypeId *>(base_ptr);
VLOG(4) << "Construct traits at " << p_first_trait << " ......";
ConstructInterfacesOrTraits<ConcreteOp, typename ConcreteOp::TraitList>::
trait(p_first_trait);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr = reinterpret_cast<void *>(p_first_trait + traits_num);
}
// (4) Construct opinfo_impl.
OpInfoImpl *p_opinfo_impl = reinterpret_cast<OpInfoImpl *>(base_ptr);
VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......";
OpInfoImpl *op_info =
new (p_opinfo_impl) OpInfoImpl(interfaces_num,
traits_num,
ConcreteOp::attributes_name_,
attributes_num,
ir::TypeId::get<ConcreteOp>(),
ConcreteOp::name(),
dialect);
return op_info;
}
void destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
void *base_ptr = reinterpret_cast<void *>(
reinterpret_cast<char *>(this) - sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
if (num_interfaces_ > 0) {
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
free((p_first_interface + i)->second);
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
///
/// \brief Search methods for Trait or Interface.
///
template <typename Trait>
bool HasTrait() const {
return HasTrait(TypeId::get<Trait>());
}
bool HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
TypeId *p_first_trait = reinterpret_cast<TypeId *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
template <typename Interface>
bool HasInterface() const {
return HasInterface(TypeId::get<Interface>());
}
bool HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
std::make_pair(interface_id, nullptr),
CompareInterface);
}
return false;
}
template <typename Interface>
typename Interface::Concept *GetInterfaceImpl() const {
if (num_interfaces_ > 0) {
ir::TypeId interface_id = ir::TypeId::get<Interface>();
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
size_t left = 0;
size_t right = num_interfaces_;
while (left < right) {
size_t mid = left + (right - left) / 2;
if ((p_first_interface + mid)->first == interface_id) {
return reinterpret_cast<typename Interface::Concept *>(
(p_first_interface + mid)->second);
} else if ((p_first_interface + mid)->first < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
ir::TypeId id() const { return op_id_; }
const char *name() const { return op_name_; }
ir::Dialect *dialect() const { return dialect_; }
private:
OpInfoImpl(uint32_t num_interfaces,
uint32_t num_traits,
const char **p_attributes,
uint32_t num_attributes,
TypeId op_id,
const char *op_name,
ir::Dialect *dialect)
: num_interfaces_(num_interfaces),
num_traits_(num_traits),
p_attributes_(p_attributes),
num_attributes_(num_attributes),
op_id_(op_id),
op_name_(op_name),
dialect_(dialect) {}
static bool CompareInterface(const std::pair<ir::TypeId, void *> &a,
const std::pair<ir::TypeId, void *> &b) {
return a.first < b.first;
}
/// Interface will be recorded by std::pair<TypeId, void*>.
uint32_t num_interfaces_ = 0;
/// Trait will be recorded by TypeId.
uint32_t num_traits_ = 0;
/// Attributes array address.
const char **p_attributes_{nullptr};
/// The number of attributes for this Op.
uint32_t num_attributes_ = 0;
/// The TypeId of this Op.
TypeId op_id_;
/// The name of this Op.
const char *op_name_;
/// The dialect of this Op belong to.
ir::Dialect *dialect_;
};
} // namespace ir
...@@ -21,7 +21,8 @@ namespace ir { ...@@ -21,7 +21,8 @@ namespace ir {
// OpInlineResult, Operation, Operand. // OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs, Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
ir::DictionaryAttribute attribute) { ir::DictionaryAttribute attribute,
ir::OpInfo op_info) {
// 1. Calculate the required memory size for OpResults + Operation + // 1. Calculate the required memory size for OpResults + Operation +
// OpOperands. // OpOperands.
uint32_t num_results = output_types.size(); uint32_t num_results = output_types.size();
...@@ -52,7 +53,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs, ...@@ -52,7 +53,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
} }
// 3.2. Construct Operation. // 3.2. Construct Operation.
Operation *op = Operation *op =
new (base_ptr) Operation(num_results, num_operands, attribute); new (base_ptr) Operation(num_results, num_operands, attribute, op_info);
base_ptr += sizeof(Operation); base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands. // 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) { if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
...@@ -116,13 +117,15 @@ void Operation::destroy() { ...@@ -116,13 +117,15 @@ void Operation::destroy() {
Operation::Operation(uint32_t num_results, Operation::Operation(uint32_t num_results,
uint32_t num_operands, uint32_t num_operands,
ir::DictionaryAttribute attribute) { ir::DictionaryAttribute attribute,
ir::OpInfo op_info) {
if (!attribute) { if (!attribute) {
throw("unexpected null attribute dictionary"); throw("unexpected null attribute dictionary");
} }
num_results_ = num_results; num_results_ = num_results;
num_operands_ = num_operands; num_operands_ = num_operands;
attribute_ = attribute; attribute_ = attribute;
op_info_ = op_info;
} }
ir::OpResult Operation::GetResultByIndex(uint32_t index) { ir::OpResult Operation::GetResultByIndex(uint32_t index) {
......
...@@ -15,10 +15,15 @@ ...@@ -15,10 +15,15 @@
#pragma once #pragma once
#include "paddle/ir/builtin_attribute.h" #include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/op_info.h"
#include "paddle/ir/type.h" #include "paddle/ir/type.h"
#include "paddle/ir/value_impl.h" #include "paddle/ir/value_impl.h"
namespace ir { namespace ir {
template <class ConcreteTrait>
class OpTraitBase;
template <typename ConcreteInterface>
class OpInterfaceBase;
class alignas(8) Operation final { class alignas(8) Operation final {
public: public:
...@@ -28,7 +33,8 @@ class alignas(8) Operation final { ...@@ -28,7 +33,8 @@ class alignas(8) Operation final {
/// ///
static Operation *create(const std::vector<ir::OpResult> &inputs, static Operation *create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
ir::DictionaryAttribute attribute); ir::DictionaryAttribute attribute,
ir::OpInfo op_info);
void destroy(); void destroy();
...@@ -36,19 +42,60 @@ class alignas(8) Operation final { ...@@ -36,19 +42,60 @@ class alignas(8) Operation final {
std::string print(); std::string print();
ir::DictionaryAttribute attribute() { return attribute_; } ir::DictionaryAttribute attribute() const { return attribute_; }
uint32_t num_results() { return num_results_; } ir::OpInfo op_info() const { return op_info_; }
uint32_t num_operands() { return num_operands_; } uint32_t num_results() const { return num_results_; }
uint32_t num_operands() const { return num_operands_; }
template <typename T>
T dyn_cast() const {
return CastUtil<T>::call(this);
}
template <typename Trait>
bool HasTrait() const {
return op_info_.HasTrait<Trait>();
}
template <typename Interface>
bool HasInterface() const {
return op_info_.HasInterface<Interface>();
}
private: private:
Operation(uint32_t num_results, Operation(uint32_t num_results,
uint32_t num_operands, uint32_t num_operands,
ir::DictionaryAttribute attribute); ir::DictionaryAttribute attribute,
ir::OpInfo op_info);
template <typename T, typename Enabler = void>
struct CastUtil {
static T call(const Operation *op) {
throw("Can't dyn_cast to T, T should be a Trait or Interface");
}
};
template <typename T>
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpTraitBase<T>, T>::value>::type> {
static T call(const Operation *op) { return T(op); }
};
template <typename T>
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpInterfaceBase<T>, T>::value>::type> {
static T call(const Operation *op) {
return T(op, op->op_info_.impl()->GetInterfaceImpl<T>());
}
};
ir::DictionaryAttribute attribute_; ir::DictionaryAttribute attribute_;
ir::OpInfo op_info_;
uint32_t num_results_ = 0; uint32_t num_results_ = 0;
uint32_t num_operands_ = 0; uint32_t num_operands_ = 0;
......
...@@ -45,6 +45,12 @@ class TypeId { ...@@ -45,6 +45,12 @@ class TypeId {
return TypeId(&instance); return TypeId(&instance);
} }
TypeId(const TypeId &other) = default;
TypeId &operator=(const TypeId &other) = default;
const Storage *storage() const { return storage_; }
/// ///
/// \brief Comparison operations. /// \brief Comparison operations.
/// ///
...@@ -54,6 +60,9 @@ class TypeId { ...@@ -54,6 +60,9 @@ class TypeId {
inline bool operator!=(const TypeId &other) const { inline bool operator!=(const TypeId &other) const {
return !(*this == other); return !(*this == other);
} }
inline bool operator<(const TypeId &other) const {
return storage_ < other.storage_;
}
/// ///
/// \brief Enable hashing TypeId instances. /// \brief Enable hashing TypeId instances.
......
...@@ -17,12 +17,107 @@ ...@@ -17,12 +17,107 @@
#include <cassert> #include <cassert>
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <tuple>
#include <type_traits>
namespace ir { namespace ir {
///
/// \brief Equivalent to boost::hash_combine.
///
std::size_t hash_combine(std::size_t lhs, std::size_t rhs); std::size_t hash_combine(std::size_t lhs, std::size_t rhs);
///
/// \brief Aligned malloc and free functions.
///
void *aligned_malloc(size_t size, size_t alignment); void *aligned_malloc(size_t size, size_t alignment);
void aligned_free(void *mem_ptr); void aligned_free(void *mem_ptr);
///
/// \brief Some template methods for manipulating std::tuple.
///
/// (1) Pop front element from Tuple
template <typename Tuple>
struct PopFrontT;
template <typename Head, typename... Tail>
struct PopFrontT<std::tuple<Head, Tail...>> {
public:
using Type = std::tuple<Tail...>;
};
template <typename Tuple>
using PopFront = typename PopFrontT<Tuple>::Type;
/// (2) Push front element to Tuple
template <typename NewElement, typename Tuple>
struct PushFrontT;
template <typename NewElement, typename... Elements>
struct PushFrontT<NewElement, std::tuple<Elements...>> {
public:
using Type = std::tuple<NewElement, Elements...>;
};
template <typename NewElement, typename... Elements>
struct PushFrontT<std::tuple<NewElement>, std::tuple<Elements...>> {
public:
using Type = std::tuple<NewElement, Elements...>;
};
template <typename NewElement, typename Tuple>
using PushFront = typename PushFrontT<NewElement, Tuple>::Type;
/// (3) IsEmpty
template <typename Tuple>
struct IsEmpty {
static constexpr bool value = false;
};
template <>
struct IsEmpty<std::tuple<>> {
static constexpr bool value = true;
};
/// (4) IfThenElseT
template <bool COND, typename TrueT, typename FalseT>
struct IfThenElseT {
using Type = TrueT;
};
template <typename TrueT, typename FalseT>
struct IfThenElseT<false, TrueT, FalseT> {
using Type = FalseT;
};
template <bool COND, typename TrueT, typename FalseT>
using IfThenElse = typename IfThenElseT<COND, TrueT, FalseT>::Type;
/// (5) Filter out all types inherited from BaseT from the tuple.
template <template <typename> class BaseT,
typename Tuple,
bool Empty = IsEmpty<Tuple>::value>
struct Filter;
template <template <typename> class BaseT, typename Tuple>
struct Filter<BaseT, Tuple, false> {
private:
using Matched =
IfThenElse<std::is_base_of<BaseT<std::tuple_element_t<0, Tuple>>,
std::tuple_element_t<0, Tuple>>::value,
std::tuple<std::tuple_element_t<0, Tuple>>,
std::tuple<>>;
using Rest = typename Filter<BaseT, PopFront<Tuple>>::Type;
public:
using Type =
IfThenElse<IsEmpty<Matched>::value, Rest, PushFront<Matched, Rest>>;
};
// basis case:
template <template <typename> class BaseT, typename Tuple>
struct Filter<BaseT, Tuple, true> {
using Type = std::tuple<>;
};
} // namespace ir } // namespace ir
...@@ -75,6 +75,12 @@ Operation *Value::GetDefiningOp() const { ...@@ -75,6 +75,12 @@ Operation *Value::GetDefiningOp() const {
std::string Value::print_ud_chain() { return impl_->print_ud_chain(); } std::string Value::print_ud_chain() { return impl_->print_ud_chain(); }
Value::use_iterator Value::begin() const {
return ir::OpOperand(impl_->first_use());
}
Value::use_iterator Value::end() const { return Value::use_iterator(); }
// OpResult // OpResult
bool OpResult::classof(Value value) { bool OpResult::classof(Value value) {
return ir::isa<detail::OpResultImpl>(value.impl()); return ir::isa<detail::OpResultImpl>(value.impl());
......
...@@ -56,6 +56,38 @@ class OpOperand { ...@@ -56,6 +56,38 @@ class OpOperand {
detail::OpOperandImpl *impl_{nullptr}; detail::OpOperandImpl *impl_{nullptr};
}; };
///
/// \brief Value Iterator
///
template <typename OperandType>
class ValueUseIterator {
public:
ValueUseIterator(OperandType use = nullptr) : current_(use) {} // NOLINT
bool operator==(const ValueUseIterator<OperandType> &rhs) const {
return current_ == rhs.current_;
}
ir::Operation *owner() const { return current_.impl()->owner(); }
OperandType get() const { return current_; }
OperandType operator*() const { return get(); }
ValueUseIterator<OperandType> &operator++() {
current_ = current_.impl()->next_use();
return *this;
}
ValueUseIterator<OperandType> operator++(int) {
ValueUseIterator<OperandType> tmp = *this;
++*(this);
return tmp;
}
protected:
OperandType current_;
};
/// ///
/// \brief Value class represents the SSA value in the IR system. This class /// \brief Value class represents the SSA value in the IR system. This class
/// only provides interfaces, for specific implementation, see Impl class. /// only provides interfaces, for specific implementation, see Impl class.
...@@ -96,6 +128,15 @@ class Value { ...@@ -96,6 +128,15 @@ class Value {
std::string print_ud_chain(); std::string print_ud_chain();
///
/// \brief Provide iterator interface to access Value use chain.
///
using use_iterator = ValueUseIterator<OpOperand>;
use_iterator begin() const;
use_iterator end() const;
friend struct std::hash<Value>; friend struct std::hash<Value>;
protected: protected:
......
...@@ -2,4 +2,5 @@ if(WITH_NEWIR) ...@@ -2,4 +2,5 @@ if(WITH_NEWIR)
cc_test_old(type_test SRCS type_test.cc DEPS new_ir gtest) cc_test_old(type_test SRCS type_test.cc DEPS new_ir gtest)
cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS new_ir gtest) cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS new_ir gtest)
cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS new_ir gtest) cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS new_ir gtest)
cc_test_old(ir_op_test SRCS ir_op_test.cc DEPS new_ir gtest)
endif() endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/op_base.h"
/// \brief Define built-in Trait, derived from OpTraitBase.
class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
public:
explicit ReadOnlyTrait(const ir::Operation *op)
: ir::OpTraitBase<ReadOnlyTrait>(op) {}
};
/// \brief Define built-in Interface, derived from OpInterfaceBase. Concepts and
/// Models need to be defined within the class. Concept defines abstract
/// interface functions, and Model is a template class that defines the specific
/// implementation of interface functions based on template parameters.
class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
public:
struct Concept {
explicit Concept(void (*infer_shape)(const ir::Operation *))
: infer_shape_(infer_shape) {}
void (*infer_shape_)(const ir::Operation *);
};
template <class ConcreteOp>
struct Model : public Concept {
static void InferShape(const ir::Operation *op) {
ConcreteOp concret_op = ConcreteOp(op);
if (concret_op == nullptr) throw("concret_op is nullptr");
concret_op.InferShape();
}
Model() : Concept(InferShape) {
if (sizeof(Model) != sizeof(Concept)) {
throw("sizeof(Model) != sizeof(Concept)");
}
}
};
InferShapeInterface(const ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<InferShapeInterface>(op), impl_(impl) {}
void InferShape() { impl_->infer_shape_(operation()); }
private:
Concept *impl_;
};
// Define op1.
class Operation1 : public ir::Op<Operation1> {
public:
using Op::Op;
static const char *name() { return "Operation1"; }
static const char *attributes_name_[];
static uint32_t attributes_num() { return 2; }
};
const char *Operation1::attributes_name_[] = {"op1_attr1", "op1_attr2"};
// Define op2.
class Operation2
: public ir::Op<Operation2, ReadOnlyTrait, InferShapeInterface> {
public:
using Op::Op;
static const char *name() { return "Operation2"; }
static const char *attributes_name_[];
static uint32_t attributes_num() { return 2; }
static void InferShape() {
std::cout << "This is op2's InferShape interface." << std::endl;
}
};
const char *Operation2::attributes_name_[] = {"op2_attr1", "op2_attr2"};
// Define a dialect, op1 and op2 will be registered by this dialect.
class TestDialect : public ir::Dialect {
public:
explicit TestDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize();
}
static const char *name() { return "op_test"; }
private:
void initialize() { RegisterOps<Operation1, Operation2>(); }
};
ir::DictionaryAttribute CreateAttribute(std::string attribute_name,
std::string attribute) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::StrAttribute attr_name = ir::StrAttribute::get(ctx, attribute_name);
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attribute);
std::map<ir::StrAttribute, ir::Attribute> named_attr;
named_attr.insert(
std::pair<ir::StrAttribute, ir::Attribute>(attr_name, attr_value));
return ir::DictionaryAttribute::get(ctx, named_attr);
}
TEST(op_test, op_test) {
// (1) Register Dialect, Operation1, Operation2 into IrContext.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
std::cout << test_dialect << std::endl;
// (2) Get registered operations.
std::string op1_name =
test_dialect->name() + "." + std::string(Operation1::name());
ir::OpInfoImpl *op1_info = ctx->GetRegisteredOpInfo(op1_name);
EXPECT_EQ(op1_info != nullptr, true);
std::string op2_name =
test_dialect->name() + "." + std::string(Operation2::name());
ir::OpInfoImpl *op2_info = ctx->GetRegisteredOpInfo(op2_name);
EXPECT_EQ(op2_info != nullptr, true);
EXPECT_EQ(op1_info->HasTrait<ReadOnlyTrait>(), false);
EXPECT_EQ(op1_info->HasInterface<InferShapeInterface>(), false);
EXPECT_EQ(op2_info->HasTrait<ReadOnlyTrait>(), true);
EXPECT_EQ(op2_info->HasInterface<InferShapeInterface>(), true);
// (3) Test uses for op.
std::vector<ir::OpResult> op_inputs = {};
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op =
ir::Operation::create(op_inputs,
op_output_types,
CreateAttribute("op1_name", "op1_attr"),
op2_info);
if (op->HasTrait<ReadOnlyTrait>()) {
ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();
EXPECT_EQ(trait.operation(), op);
}
if (op->HasInterface<InferShapeInterface>()) {
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
interface.InferShape();
}
op->destroy();
}
...@@ -40,21 +40,30 @@ TEST(value_test, value_test) { ...@@ -40,21 +40,30 @@ TEST(value_test, value_test) {
// 1. Construct OP1: a = OP1() // 1. Construct OP1: a = OP1()
std::vector<ir::OpResult> op1_inputs = {}; std::vector<ir::OpResult> op1_inputs = {};
std::vector<ir::Type> op1_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op1_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op1 = ir::Operation::create( ir::Operation *op1 =
op1_inputs, op1_output_types, CreateAttribute("op1_name", "op1_attr")); ir::Operation::create(op1_inputs,
op1_output_types,
CreateAttribute("op1_name", "op1_attr"),
nullptr);
std::cout << op1->print() << std::endl; std::cout << op1->print() << std::endl;
// 2. Construct OP2: b = OP2(); // 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {}; std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op2 = ir::Operation::create( ir::Operation *op2 =
op2_inputs, op2_output_types, CreateAttribute("op2_name", "op2_attr")); ir::Operation::create(op2_inputs,
op2_output_types,
CreateAttribute("op2_name", "op2_attr"),
nullptr);
std::cout << op2->print() << std::endl; std::cout << op2->print() << std::endl;
// 3. Construct OP3: c = OP3(a, b); // 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
op2->GetResultByIndex(0)}; op2->GetResultByIndex(0)};
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op3 = ir::Operation::create( ir::Operation *op3 =
op3_inputs, op3_output_types, CreateAttribute("op3_name", "op3_attr")); ir::Operation::create(op3_inputs,
op3_output_types,
CreateAttribute("op3_name", "op3_attr"),
nullptr);
std::cout << op3->print() << std::endl; std::cout << op3->print() << std::endl;
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
...@@ -63,8 +72,11 @@ TEST(value_test, value_test) { ...@@ -63,8 +72,11 @@ TEST(value_test, value_test) {
for (size_t i = 0; i < 7; i++) { for (size_t i = 0; i < 7; i++) {
op4_output_types.push_back(ir::Float32Type::get(ctx)); op4_output_types.push_back(ir::Float32Type::get(ctx));
} }
ir::Operation *op4 = ir::Operation::create( ir::Operation *op4 =
op4_inputs, op4_output_types, CreateAttribute("op4_name", "op4_attr")); ir::Operation::create(op4_inputs,
op4_output_types,
CreateAttribute("op4_name", "op4_attr"),
nullptr);
std::cout << op4->print() << std::endl; std::cout << op4->print() << std::endl;
// Test 1: // Test 1:
...@@ -86,6 +98,12 @@ TEST(value_test, value_test) { ...@@ -86,6 +98,12 @@ TEST(value_test, value_test) {
EXPECT_EQ(op4_first_input->next_use(), op3_first_input); EXPECT_EQ(op4_first_input->next_use(), op3_first_input);
EXPECT_EQ(op3_first_input->next_use(), nullptr); EXPECT_EQ(op3_first_input->next_use(), nullptr);
// Test 3: Value iterator
ir::Value::use_iterator iter = op1->GetResultByIndex(0).begin();
EXPECT_EQ(iter.owner(), op4);
++iter;
EXPECT_EQ(iter.owner(), op3);
// destroy // destroy
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl; std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op4->destroy(); op4->destroy();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册