未验证 提交 9d9f0ce5 编写于 作者: 王明冬 提交者: GitHub

[IR] fine-tune the implementation of ir component. (#53894)

上级 d29c1f8e
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include "paddle/ir/operation.h"
namespace ir {
///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members.
///
class Builder {
public:
explicit Builder(IrContext *context) : context_(context) {}
explicit Builder(Operation *op) : Builder(op->ir_context()) {}
/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy create(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...);
Operation *op = Operation::create(argument);
return dyn_cast<OpTy>(op);
}
private:
IrContext *context_;
// The current op list this builder is inserting into.
// After the design of the block data structure is completed,
// this member will be replaced by the block.
std::list<Operation *> *op_list_ = nullptr;
// The insertion point within the list that this builder is inserting before.
std::list<Operation *>::iterator insertPoint;
};
} // namespace ir
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
#include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_op.h"
namespace ir { namespace ir {
const char *GetParameterOp::attributes_name_[] = {"parameter_name"}; const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};
const char *SetParameterOp::attributes_name_[] = {"parameter_name"}; const char *SetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};
} // namespace ir } // namespace ir
...@@ -32,11 +32,11 @@ class GetParameterOp : public ir::Op<GetParameterOp> { ...@@ -32,11 +32,11 @@ class GetParameterOp : public ir::Op<GetParameterOp> {
public: public:
using Op::Op; using Op::Op;
static const char* name() { return "GetParameterOp"; } static const char* name() { return "builtin.get_parameter"; }
static uint32_t attributes_num() { return 1; } static constexpr uint32_t attributes_num = 1;
static const char* attributes_name_[]; static const char* attributes_name[attributes_num];
}; };
/// ///
...@@ -47,11 +47,11 @@ class SetParameterOp : public ir::Op<SetParameterOp> { ...@@ -47,11 +47,11 @@ class SetParameterOp : public ir::Op<SetParameterOp> {
public: public:
using Op::Op; using Op::Op;
static const char* name() { return "SetParameterOp"; } static const char* name() { return "builtin.set_parameter"; }
static uint32_t attributes_num() { return 1; } static constexpr uint32_t attributes_num = 1;
static const char* attributes_name_[]; static const char* attributes_name[attributes_num];
}; };
} // namespace ir } // namespace ir
...@@ -20,24 +20,6 @@ Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id) ...@@ -20,24 +20,6 @@ Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id)
Dialect::~Dialect() = default; Dialect::~Dialect() = default;
void Dialect::RegisterType(ir::AbstractType &&abstract_type) {
ir::AbstractType *new_abstract_type =
new ir::AbstractType(std::move(abstract_type));
this->ir_context()->RegisterAbstractType(new_abstract_type->type_id(),
new_abstract_type);
}
void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) {
ir::AbstractAttribute *new_abstract_attribute =
new ir::AbstractAttribute(std::move(abstract_attribute));
this->ir_context()->RegisterAbstractAttribute(
new_abstract_attribute->type_id(), new_abstract_attribute);
}
void Dialect::RegisterOp(const std::string &name, OpInfoImpl *op_info) {
this->ir_context()->RegisterOpInfo(name, op_info);
}
void Dialect::RegisterInterface(std::unique_ptr<DialectInterface> interface) { void Dialect::RegisterInterface(std::unique_ptr<DialectInterface> interface) {
VLOG(4) << "Register interface into dialect" << std::endl; VLOG(4) << "Register interface into dialect" << std::endl;
auto it = registered_interfaces_.emplace(interface->interface_id(), auto it = registered_interfaces_.emplace(interface->interface_id(),
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "paddle/ir/attribute_base.h" #include "paddle/ir/attribute_base.h"
#include "paddle/ir/dialect_interface.h" #include "paddle/ir/dialect_interface.h"
#include "paddle/ir/ir_context.h" #include "paddle/ir/ir_context.h"
#include "paddle/ir/op_info_impl.h" #include "paddle/ir/op_base.h"
#include "paddle/ir/type_base.h" #include "paddle/ir/type_base.h"
namespace ir { namespace ir {
...@@ -52,26 +52,10 @@ class Dialect { ...@@ -52,26 +52,10 @@ class Dialect {
template <typename T> template <typename T>
void RegisterType() { void RegisterType() {
VLOG(4) << "Type registered into Dialect. --->"; ir_context()->RegisterAbstractType(TypeId::get<T>(),
if (this->ir_context()->GetRegisteredAbstractType(ir::TypeId::get<T>()) == AbstractType::get<T>(*this));
nullptr) { TypeManager::RegisterType<T>(ir_context());
ir::AbstractType *abstract_type =
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
abstract_type);
ir::TypeManager::RegisterType<T>(this->ir_context());
} }
VLOG(4) << "----------------------------------";
}
///
/// \brief Register abstract_type into context.
/// NOTE: It's not recommended to use this interface directly. This interface
/// only registers abstract_type. To register TypeStorage into context, you
/// need to call ir::TypeManager::RegisterType<T>() additionally,
/// RegisterType<T>() is recommended to use.
///
void RegisterType(ir::AbstractType &&abstract_type);
/// ///
/// \brief Register all attributes contained in the template parameter Args. /// \brief Register all attributes contained in the template parameter Args.
...@@ -85,37 +69,28 @@ class Dialect { ...@@ -85,37 +69,28 @@ class Dialect {
template <typename T> template <typename T>
void RegisterAttribute() { void RegisterAttribute() {
VLOG(4) << "Attribute registered into Dialect. --->"; ir_context()->RegisterAbstractAttribute(TypeId::get<T>(),
if (this->ir_context()->GetRegisteredAbstractAttribute( AbstractAttribute::get<T>(*this));
ir::TypeId::get<T>()) == nullptr) { AttributeManager::RegisterAttribute<T>(ir_context());
ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute(
std::move(ir::AbstractAttribute::get<T>(*this)));
this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get<T>(),
abstract_attribute);
ir::AttributeManager::RegisterAttribute<T>(this->ir_context());
} }
VLOG(4) << "----------------------------------";
}
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute);
/// ///
/// \brief Register Operation methods. /// \brief Register Ops.
/// ///
template <typename... Args> template <typename... Args>
void RegisterOps() { void RegisterOps() {
(void)std::initializer_list<int>{0, (RegisterOp<Args>(), 0)...}; (void)std::initializer_list<int>{0, (RegisterOp<Args>(), 0)...};
} }
template <typename ConcertOp> template <typename ConcreteOp>
void RegisterOp() { void RegisterOp() {
std::string name = this->name() + "." + std::string(ConcertOp::name()); ir_context()->RegisterOpInfo(this,
VLOG(4) << "Op " << name << " registered into Dialect. --->"; TypeId::get<ConcreteOp>(),
if (this->ir_context()->GetRegisteredOpInfo(name) == nullptr) { ConcreteOp::name(),
ir::OpInfoImpl *op_info = ir::OpInfoImpl::create<ConcertOp>(this); ConcreteOp::GetInterfaceMap(),
this->ir_context()->RegisterOpInfo(name, op_info); ConcreteOp::GetTraitSet(),
} ConcreteOp::attributes_num,
VLOG(4) << "----------------------------------"; ConcreteOp::attributes_name);
} }
void RegisterOp(const std::string &name, OpInfoImpl *op_info); void RegisterOp(const std::string &name, OpInfoImpl *op_info);
......
...@@ -185,11 +185,6 @@ IrContext::IrContext() : impl_(new IrContextImpl()) { ...@@ -185,11 +185,6 @@ IrContext::IrContext() : impl_(new IrContextImpl()) {
impl_->int64_type = TypeManager::get<Int64Type>(this); impl_->int64_type = TypeManager::get<Int64Type>(this);
} }
void IrContext::RegisterAbstractType(ir::TypeId type_id,
AbstractType *abstract_type) {
impl().RegisterAbstractType(type_id, abstract_type);
}
StorageManager &IrContext::type_storage_manager() { StorageManager &IrContext::type_storage_manager() {
return impl().registed_type_storage_manager_; return impl().registed_type_storage_manager_;
} }
...@@ -203,8 +198,14 @@ AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) { ...@@ -203,8 +198,14 @@ AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
} }
void IrContext::RegisterAbstractAttribute( void IrContext::RegisterAbstractAttribute(
ir::TypeId type_id, AbstractAttribute *abstract_attribute) { ir::TypeId type_id, AbstractAttribute &&abstract_attribute) {
impl().RegisterAbstractAttribute(type_id, abstract_attribute); if (GetRegisteredAbstractAttribute(type_id) == nullptr) {
impl().RegisterAbstractAttribute(
type_id, new AbstractAttribute(std::move(abstract_attribute)));
VLOG(4) << "<--- Attribute registered into IrContext. --->";
} else {
LOG(WARNING) << " Attribute already registered.";
}
} }
StorageManager &IrContext::attribute_storage_manager() { StorageManager &IrContext::attribute_storage_manager() {
...@@ -251,17 +252,44 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) { ...@@ -251,17 +252,44 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
return nullptr; return nullptr;
} }
OpInfoImpl *IrContext::GetRegisteredOpInfo(const std::string &name) { void IrContext::RegisterAbstractType(ir::TypeId type_id,
OpInfoImpl *rtn = impl().GetOpInfo(name); AbstractType &&abstract_type) {
return rtn ? rtn : nullptr; if (GetRegisteredAbstractType(type_id) == nullptr) {
impl().RegisterAbstractType(type_id,
new AbstractType(std::move(abstract_type)));
VLOG(4) << "<--- Type registered into IrContext. --->";
} else {
LOG(WARNING) << " type already registered.";
}
} }
void IrContext::RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) { void IrContext::RegisterOpInfo(Dialect *dialect,
if (impl().GetOpInfo(name) == nullptr) { TypeId op_id,
const char *name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char **attributes_name) {
if (GetRegisteredOpInfo(name) == nullptr) {
OpInfoImpl *opinfo = OpInfoImpl::create(dialect,
op_id,
name,
std::move(interface_map),
trait_set,
attributes_num,
attributes_name);
impl().RegisterOpInfo(name, opinfo); impl().RegisterOpInfo(name, opinfo);
VLOG(4) << "Op " << name << " registered into IrContext. --->";
} else {
LOG(WARNING) << name << " op already registered.";
} }
} }
OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) {
OpInfoImpl *rtn = impl().GetOpInfo(name);
return rtn ? rtn : nullptr;
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
auto &impl = ctx->impl(); auto &impl = ctx->impl();
AbstractType *abstract_type = impl.GetAbstractType(type_id); AbstractType *abstract_type = impl.GetAbstractType(type_id);
......
...@@ -13,11 +13,9 @@ ...@@ -13,11 +13,9 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <glog/logging.h>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <unordered_map> #include <vector>
namespace ir { namespace ir {
class IrContextImpl; class IrContextImpl;
...@@ -26,8 +24,8 @@ class AbstractType; ...@@ -26,8 +24,8 @@ class AbstractType;
class AbstractAttribute; class AbstractAttribute;
class TypeId; class TypeId;
class Dialect; class Dialect;
class OpInfoImpl; class OpInfo;
class InterfaceValue;
/// ///
/// \brief IrContext is a global parameterless class used to store and manage /// \brief IrContext is a global parameterless class used to store and manage
/// Type, Attribute and other related data structures. /// Type, Attribute and other related data structures.
...@@ -53,7 +51,7 @@ class IrContext { ...@@ -53,7 +51,7 @@ class IrContext {
/// \param type_id The type id of the AbstractType. /// \param type_id The type id of the AbstractType.
/// \param abstract_type AbstractType* provided by user. /// \param abstract_type AbstractType* provided by user.
/// ///
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type); void RegisterAbstractType(TypeId type_id, AbstractType &&abstract_type);
/// ///
/// \brief Returns the storage uniquer used for constructing TypeStorage /// \brief Returns the storage uniquer used for constructing TypeStorage
...@@ -73,10 +71,10 @@ class IrContext { ...@@ -73,10 +71,10 @@ class IrContext {
/// \brief Register an AbstractAttribute to IrContext /// \brief Register an AbstractAttribute to IrContext
/// ///
/// \param type_id The type id of the AbstractAttribute. /// \param type_id The type id of the AbstractAttribute.
/// \param abstract_attribute AbstractAttribute* provided by user. /// \param abstract_attribute AbstractAttribute provided by user.
/// ///
void RegisterAbstractAttribute(ir::TypeId type_id, void RegisterAbstractAttribute(ir::TypeId type_id,
AbstractAttribute *abstract_attribute); AbstractAttribute &&abstract_attribute);
/// ///
/// \brief Returns the storage uniquer used for constructing AttributeStorage /// \brief Returns the storage uniquer used for constructing AttributeStorage
...@@ -93,11 +91,20 @@ class IrContext { ...@@ -93,11 +91,20 @@ class IrContext {
AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id); AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id);
/// ///
/// \brief Get or register operaiton. /// \brief Register an op infomation to IrContext
/// ///
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo); void RegisterOpInfo(Dialect *dialect,
TypeId op_id,
const char *name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char **attributes_name);
OpInfoImpl *GetRegisteredOpInfo(const std::string &name); ///
/// \brief Get registered operaiton infomation.
///
OpInfo GetRegisteredOpInfo(const std::string &name);
/// ///
/// \brief Get the dialect of the DialectT class in the context, ff not found, /// \brief Get the dialect of the DialectT class in the context, ff not found,
...@@ -162,7 +169,6 @@ class IrContext { ...@@ -162,7 +169,6 @@ class IrContext {
private: private:
IrContext(); IrContext();
const std::unique_ptr<IrContextImpl> impl_; const std::unique_ptr<IrContextImpl> impl_;
}; };
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/op_base.h"
namespace ir {
InterfaceValue::~InterfaceValue() {
if (model_) free(model_);
}
InterfaceValue::InterfaceValue(InterfaceValue&& val) {
type_id_ = val.type_id_;
model_ = val.model_;
val.model_ = nullptr;
}
InterfaceValue& InterfaceValue::operator=(InterfaceValue&& val) {
swap(std::move(val));
return *this;
}
} // namespace ir
...@@ -13,11 +13,57 @@ ...@@ -13,11 +13,57 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <type_traits>
#include "paddle/ir/operation.h" #include "paddle/ir/operation.h"
#include "paddle/ir/utils.h" #include "paddle/ir/utils.h"
namespace ir { namespace ir {
class InterfaceValue {
public:
template <typename ConcreteOp, typename T>
static InterfaceValue get() {
InterfaceValue val;
val.type_id_ = TypeId::get<T>();
val.model_ = malloc(sizeof(typename T::template Model<ConcreteOp>));
if (val.model_ == nullptr) {
throw("Alloc memory for interface failed.");
}
static_assert(std::is_trivially_destructible<
typename T::template Model<ConcreteOp>>::value,
"interface models must be trivially destructible");
new (val.model_) typename T::template Model<ConcreteOp>();
return val;
}
TypeId type_id() const { return type_id_; }
void *model() const { return model_; }
InterfaceValue() = default;
explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {}
InterfaceValue(const InterfaceValue &) = delete;
InterfaceValue(InterfaceValue &&);
InterfaceValue &operator=(const InterfaceValue &) = delete;
InterfaceValue &operator=(InterfaceValue &&);
~InterfaceValue();
void swap(InterfaceValue &&val) {
using std::swap;
swap(type_id_, val.type_id_);
swap(model_, val.model_);
}
///
/// \brief Comparison operations.
///
inline bool operator<(const InterfaceValue &other) const {
return type_id_ < other.type_id_;
}
private:
TypeId type_id_;
void *model_{nullptr};
};
class OpBase { class OpBase {
public: public:
explicit OpBase(const Operation *operation) : operation_(operation) {} explicit OpBase(const Operation *operation) : operation_(operation) {}
...@@ -58,6 +104,59 @@ class OpInterfaceBase : public OpBase { ...@@ -58,6 +104,59 @@ class OpInterfaceBase : public OpBase {
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); } static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
}; };
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
public:
/// Construct method for interfaces.
static InterfaceValue *interface(InterfaceValue *p_interface) {
(void)std::initializer_list<int>{
0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
return p_interface;
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
(void)std::initializer_list<int>{
0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
return p_trait;
}
private:
/// Placement new interface.
template <typename T>
static void PlacementConstrctInterface(
InterfaceValue *&p_interface) { // NOLINT
p_interface->swap(InterfaceValue::get<ConcreteOp, T>());
VLOG(4) << "New a interface: id[" << (p_interface->type_id()).storage()
<< "].";
++p_interface;
}
/// Placement new trait.
template <typename T>
static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT
*p_trait = TypeId::get<T>();
VLOG(4) << "New a trait: id[" << p_trait->storage() << "].";
++p_trait;
}
};
/// Specialized for tuple type.
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
public:
/// Construct method for interfaces.
static InterfaceValue *interface(InterfaceValue *p_interface) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::interface(
p_interface);
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::trait(p_trait);
}
};
template <typename ConcreteOp, class... TraitOrInterface> template <typename ConcreteOp, class... TraitOrInterface>
class Op : public OpBase { class Op : public OpBase {
public: public:
...@@ -68,6 +167,21 @@ class Op : public OpBase { ...@@ -68,6 +167,21 @@ class Op : public OpBase {
using InterfaceList = using InterfaceList =
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type; typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
};
static std::vector<InterfaceValue> GetInterfaceMap() {
constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
std::vector<InterfaceValue> interfaces_map(interfaces_num);
ConstructInterfacesOrTraits<ConcreteOp, InterfaceList>::interface(
interfaces_map.data());
return interfaces_map;
}
static std::vector<TypeId> GetTraitSet() {
constexpr size_t traits_num = std::tuple_size<TraitList>::value;
std::vector<TypeId> trait_set(traits_num);
auto p_first_trait = trait_set.data();
ConstructInterfacesOrTraits<ConcreteOp, TraitList>::trait(p_first_trait);
return trait_set;
}
};
} // namespace ir } // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/op_info.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/op_info_impl.h"
namespace ir {
bool OpInfo::HasTrait(TypeId trait_id) const {
return impl_ && impl_->HasTrait(trait_id);
}
bool OpInfo::HasInterface(TypeId interface_id) const {
return impl_ && impl_->HasInterface(interface_id);
}
IrContext *OpInfo::ir_context() const {
return impl_ ? impl_->ir_context() : nullptr;
}
const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
void *OpInfo::GetInterfaceImpl(TypeId interface_id) const {
return impl_ ? impl_->interface_impl(interface_id) : nullptr;
}
ir::IrContext *OpInfoImpl::ir_context() const {
return dialect()->ir_context();
}
void *OpInfoImpl::interface_impl(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
size_t left = 0, right = num_interfaces_;
while (left < right) {
size_t mid = (left + right) / 2;
if ((p_first_interface + mid)->type_id() == interface_id) {
return (p_first_interface + mid)->model();
} else if ((p_first_interface + mid)->type_id() < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
bool OpInfoImpl::HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
const TypeId *p_first_trait =
reinterpret_cast<const TypeId *>(reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
bool OpInfoImpl::HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
InterfaceValue(interface_id));
}
return false;
}
OpInfoImpl *OpInfoImpl::create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[]) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num = interface_map.size();
size_t traits_num = trait_set.size();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(InterfaceValue) * interfaces_num +
sizeof(TypeId) * traits_num + sizeof(OpInfoImpl);
char *base_ptr = static_cast<char *>(::operator new(base_size));
VLOG(4) << "Malloc " << base_size << " Bytes at "
<< static_cast<void *>(base_ptr);
if (interfaces_num > 0) {
std::sort(interface_map.begin(), interface_map.end());
for (size_t index = 0; index < interfaces_num; ++index) {
new (base_ptr + index * sizeof(InterfaceValue))
InterfaceValue(std::move(interface_map[index]));
}
base_ptr += interfaces_num * sizeof(InterfaceValue);
}
if (traits_num > 0) {
auto p_first_trait = reinterpret_cast<TypeId *>(base_ptr);
memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr += traits_num * sizeof(TypeId);
}
// Construct opinfo_impl.
OpInfoImpl *p_opinfo_impl = reinterpret_cast<OpInfoImpl *>(base_ptr);
VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......";
OpInfoImpl *op_info = new (p_opinfo_impl) OpInfoImpl(dialect,
op_id,
op_name,
interfaces_num,
traits_num,
attributes_num,
attributes_name
);
return op_info;
}
void OpInfoImpl::destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
char *base_ptr = reinterpret_cast<char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_;
if (num_interfaces_ > 0) {
InterfaceValue *p_interface_val =
reinterpret_cast<InterfaceValue *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
(p_interface_val + i)->~InterfaceValue();
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
} // namespace ir
...@@ -13,12 +13,13 @@ ...@@ -13,12 +13,13 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <functional> #include <functional>
#include "paddle/ir/type_id.h"
#include "paddle/ir/op_info_impl.h"
namespace ir { namespace ir {
class OpInfoImpl;
class IrContext;
class OpInfo { class OpInfo {
public: public:
constexpr OpInfo() = default; constexpr OpInfo() = default;
...@@ -37,24 +38,42 @@ class OpInfo { ...@@ -37,24 +38,42 @@ class OpInfo {
bool operator!() const { return impl_ == nullptr; } bool operator!() const { return impl_ == nullptr; }
const OpInfoImpl *impl() const { return impl_; } IrContext *ir_context() const;
const char *name() const;
template <typename Trait> template <typename Trait>
bool HasTrait() const { bool HasTrait() const {
return impl_->HasTrait<Trait>(); return HasTrait(TypeId::get<Trait>());
} }
bool HasTrait(TypeId trait_id) const;
template <typename Interface> template <typename Interface>
bool HasInterface() const { bool HasInterface() const {
return impl_->HasInterface<Interface>(); return HasInterface(TypeId::get<Interface>());
} }
bool HasInterface(TypeId interface_id) const;
template <typename Interface>
typename Interface::Concept *GetInterfaceImpl() const;
friend struct std::hash<OpInfo>; friend struct std::hash<OpInfo>;
private:
void *GetInterfaceImpl(TypeId interface_id) const;
private: private:
const OpInfoImpl *impl_{nullptr}; // not owned const OpInfoImpl *impl_{nullptr}; // not owned
}; };
template <typename Interface>
typename Interface::Concept *OpInfo::GetInterfaceImpl() const {
void *model = GetInterfaceImpl(TypeId::get<Interface>());
return reinterpret_cast<typename Interface::Concept *>(model);
}
} // namespace ir } // namespace ir
namespace std { namespace std {
......
...@@ -15,77 +15,16 @@ ...@@ -15,77 +15,16 @@
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <cstring>
#include <initializer_list> #include <initializer_list>
#include <string>
#include <utility> #include <utility>
#include "paddle/ir/builtin_attribute.h" #include "paddle/ir/builtin_attribute.h"
// #include "paddle/ir/ir_context.h" #include "paddle/ir/op_base.h"
#include "paddle/ir/type.h" #include "paddle/ir/type.h"
namespace ir { namespace ir {
class Dialect; class Dialect;
///
/// \brief Tool template class for construct interfaces or Traits.
///
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
public:
/// Construct method for interfaces.
static std::pair<TypeId, void *> *interface(
std::pair<TypeId, void *> *p_interface) {
(void)std::initializer_list<int>{
0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
return p_interface;
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
(void)std::initializer_list<int>{
0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
return p_trait;
}
private:
/// Placement new interface.
template <typename T>
static void PlacementConstrctInterface(
std::pair<TypeId, void *> *&p_interface) { // NOLINT
new (&(p_interface->first)) TypeId(ir::TypeId::get<T>());
p_interface->second =
malloc(sizeof(typename T::template Model<ConcreteOp>));
new (p_interface->second) typename T::template Model<ConcreteOp>();
VLOG(4) << "New a interface: id[" << p_interface->first.storage()
<< "], interface[" << p_interface->second << "].";
++p_interface;
}
/// Placement new trait.
template <typename T>
static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT
new (p_trait) TypeId(ir::TypeId::get<T>());
VLOG(4) << "New a trait: id[" << (*p_trait).storage() << "].";
++p_trait;
}
};
/// Specialized for tuple type.
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
public:
/// Construct method for interfaces.
static std::pair<TypeId, void *> *interface(
std::pair<TypeId, void *> *p_interface) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::interface(
p_interface);
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::trait(p_trait);
}
};
/// ///
/// \brief OpInfoImpl class. /// \brief OpInfoImpl class.
/// ///
...@@ -95,143 +34,27 @@ class OpInfoImpl { ...@@ -95,143 +34,27 @@ class OpInfoImpl {
/// \brief Construct and Deconstruct OpInfoImpl. The memory layout of /// \brief Construct and Deconstruct OpInfoImpl. The memory layout of
/// OpInfoImpl is: std::pair<TypeId, void *>... | TypeId... | OpInfoImpl /// OpInfoImpl is: std::pair<TypeId, void *>... | TypeId... | OpInfoImpl
/// ///
template <typename ConcreteOp> static OpInfoImpl *create(Dialect *dialect,
static OpInfoImpl *create(ir::Dialect *dialect) { TypeId op_id,
// (1) Malloc memory for interfaces, traits, opinfo_impl. const char *op_name,
size_t interfaces_num = std::vector<InterfaceValue> &&interface_map,
std::tuple_size<typename ConcreteOp::InterfaceList>::value; const std::vector<TypeId> &trait_set,
size_t traits_num = std::tuple_size<typename ConcreteOp::TraitList>::value; size_t attributes_num,
size_t attributes_num = ConcreteOp::attributes_num(); const char *attributes_name[]);
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(std::pair<ir::TypeId, void *>) * interfaces_num +
sizeof(ir::TypeId) * traits_num + sizeof(OpInfoImpl);
void *base_ptr = malloc(base_size);
VLOG(4) << "Malloc " << base_size << " Bytes at " << base_ptr;
// (2) Construct interfaces and sort by TypeId.
std::pair<ir::TypeId, void *> *p_first_interface = nullptr;
if (interfaces_num > 0) {
p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(base_ptr);
VLOG(4) << "Construct interfaces at " << p_first_interface << " ......";
ConstructInterfacesOrTraits<
ConcreteOp,
typename ConcreteOp::InterfaceList>::interface(p_first_interface);
std::sort(p_first_interface, p_first_interface + interfaces_num);
base_ptr = reinterpret_cast<void *>(p_first_interface + interfaces_num);
}
// (3) Construct traits and sort by TypeId.
ir::TypeId *p_first_trait = nullptr;
if (traits_num > 0) {
p_first_trait = reinterpret_cast<ir::TypeId *>(base_ptr);
VLOG(4) << "Construct traits at " << p_first_trait << " ......";
ConstructInterfacesOrTraits<ConcreteOp, typename ConcreteOp::TraitList>::
trait(p_first_trait);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr = reinterpret_cast<void *>(p_first_trait + traits_num);
}
// (4) Construct opinfo_impl. void destroy();
OpInfoImpl *p_opinfo_impl = reinterpret_cast<OpInfoImpl *>(base_ptr);
VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......";
OpInfoImpl *op_info =
new (p_opinfo_impl) OpInfoImpl(interfaces_num,
traits_num,
ConcreteOp::attributes_name_,
attributes_num,
ir::TypeId::get<ConcreteOp>(),
ConcreteOp::name(),
dialect);
return op_info;
}
void destroy() { ir::IrContext *ir_context() const;
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
void *base_ptr = reinterpret_cast<void *>(
reinterpret_cast<char *>(this) - sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
if (num_interfaces_ > 0) {
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
free((p_first_interface + i)->second);
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
///
/// \brief Search methods for Trait or Interface. /// \brief Search methods for Trait or Interface.
/// bool HasTrait(TypeId trait_id) const;
template <typename Trait>
bool HasTrait() const {
return HasTrait(TypeId::get<Trait>());
}
bool HasTrait(TypeId trait_id) const { bool HasInterface(TypeId interface_id) const;
if (num_traits_ > 0) {
TypeId *p_first_trait = reinterpret_cast<TypeId *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
template <typename Interface>
bool HasInterface() const {
return HasInterface(TypeId::get<Interface>());
}
bool HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
std::make_pair(interface_id, nullptr),
CompareInterface);
}
return false;
}
template <typename Interface>
typename Interface::Concept *GetInterfaceImpl() const {
if (num_interfaces_ > 0) {
ir::TypeId interface_id = ir::TypeId::get<Interface>();
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
size_t left = 0;
size_t right = num_interfaces_;
while (left < right) {
size_t mid = left + (right - left) / 2;
if ((p_first_interface + mid)->first == interface_id) {
return reinterpret_cast<typename Interface::Concept *>(
(p_first_interface + mid)->second);
} else if ((p_first_interface + mid)->first < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
ir::TypeId id() const { return op_id_; } ir::TypeId id() const { return op_id_; }
void *interface_impl(TypeId interface_id) const;
const char *name() const { return op_name_; } const char *name() const { return op_name_; }
ir::Dialect *dialect() const { return dialect_; } ir::Dialect *dialect() const { return dialect_; }
...@@ -243,25 +66,29 @@ class OpInfoImpl { ...@@ -243,25 +66,29 @@ class OpInfoImpl {
} }
private: private:
OpInfoImpl(uint32_t num_interfaces, OpInfoImpl(ir::Dialect *dialect,
uint32_t num_traits,
const char **p_attributes,
uint32_t num_attributes,
TypeId op_id, TypeId op_id,
const char *op_name, const char *op_name,
ir::Dialect *dialect) uint32_t num_interfaces,
: num_interfaces_(num_interfaces), uint32_t num_traits,
num_traits_(num_traits), uint32_t num_attributes,
p_attributes_(p_attributes), const char **p_attributes)
num_attributes_(num_attributes), : dialect_(dialect),
op_id_(op_id), op_id_(op_id),
op_name_(op_name), op_name_(op_name),
dialect_(dialect) {} num_interfaces_(num_interfaces),
num_traits_(num_traits),
num_attributes_(num_attributes),
p_attributes_(p_attributes) {}
static bool CompareInterface(const std::pair<ir::TypeId, void *> &a, /// The dialect of this Op belong to.
const std::pair<ir::TypeId, void *> &b) { ir::Dialect *dialect_;
return a.first < b.first;
} /// The TypeId of this Op.
TypeId op_id_;
/// The name of this Op.
const char *op_name_;
/// Interface will be recorded by std::pair<TypeId, void*>. /// Interface will be recorded by std::pair<TypeId, void*>.
uint32_t num_interfaces_ = 0; uint32_t num_interfaces_ = 0;
...@@ -269,20 +96,11 @@ class OpInfoImpl { ...@@ -269,20 +96,11 @@ class OpInfoImpl {
/// Trait will be recorded by TypeId. /// Trait will be recorded by TypeId.
uint32_t num_traits_ = 0; uint32_t num_traits_ = 0;
/// Attributes array address.
const char **p_attributes_{nullptr};
/// The number of attributes for this Op. /// The number of attributes for this Op.
uint32_t num_attributes_ = 0; uint32_t num_attributes_ = 0;
/// The TypeId of this Op. /// Attributes array address.
TypeId op_id_; const char **p_attributes_{nullptr};
/// The name of this Op.
const char *op_name_;
/// The dialect of this Op belong to.
ir::Dialect *dialect_;
}; };
} // namespace ir } // namespace ir
...@@ -18,6 +18,13 @@ ...@@ -18,6 +18,13 @@
#include "paddle/ir/utils.h" #include "paddle/ir/utils.h"
namespace ir { namespace ir {
Operation *Operation::create(const OperationArgument &argument) {
return create(argument.inputs_,
argument.output_types_,
argument.attribute_,
argument.info_);
}
// Allocate the required memory based on the size and number of inputs, outputs, // Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult, // and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand. // OpInlineResult, Operation, Operand.
...@@ -126,6 +133,8 @@ void Operation::destroy() { ...@@ -126,6 +133,8 @@ void Operation::destroy() {
aligned_free(reinterpret_cast<void *>(aligned_ptr)); aligned_free(reinterpret_cast<void *>(aligned_ptr));
} }
IrContext *Operation::ir_context() const { return op_info_.ir_context(); }
Operation::Operation(uint32_t num_results, Operation::Operation(uint32_t num_results,
uint32_t num_operands, uint32_t num_operands,
const AttributeMap &attribute, const AttributeMap &attribute,
...@@ -190,9 +199,6 @@ std::string Operation::print() { ...@@ -190,9 +199,6 @@ std::string Operation::print() {
return result.str(); return result.str();
} }
std::string Operation::op_name() const { std::string Operation::op_name() const { return op_info_.name(); }
return op_info_.impl()->dialect()->name() + "." +
std::string(op_info_.impl()->name());
}
} // namespace ir } // namespace ir
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/ir/builtin_attribute.h" #include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/op_info.h" #include "paddle/ir/op_info.h"
#include "paddle/ir/operation_utils.h"
#include "paddle/ir/type.h" #include "paddle/ir/type.h"
#include "paddle/ir/value_impl.h" #include "paddle/ir/value_impl.h"
...@@ -26,8 +27,6 @@ template <typename ConcreteInterface> ...@@ -26,8 +27,6 @@ template <typename ConcreteInterface>
class OpInterfaceBase; class OpInterfaceBase;
class Program; class Program;
using AttributeMap = std::unordered_map<std::string, Attribute>;
class alignas(8) Operation final { class alignas(8) Operation final {
public: public:
/// ///
...@@ -40,12 +39,15 @@ class alignas(8) Operation final { ...@@ -40,12 +39,15 @@ class alignas(8) Operation final {
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
const AttributeMap &attribute, const AttributeMap &attribute,
ir::OpInfo op_info); ir::OpInfo op_info);
static Operation *create(const OperationArgument &op_argument);
/// ///
/// \brief Destroy the operation objects and free memeory by create(). /// \brief Destroy the operation objects and free memeory by create().
/// ///
void destroy(); void destroy();
IrContext *ir_context() const;
ir::OpResult GetResultByIndex(uint32_t index); ir::OpResult GetResultByIndex(uint32_t index);
ir::OpOperand GetOperandByIndex(uint32_t index); ir::OpOperand GetOperandByIndex(uint32_t index);
...@@ -99,14 +101,17 @@ class alignas(8) Operation final { ...@@ -99,14 +101,17 @@ class alignas(8) Operation final {
struct CastUtil<T, struct CastUtil<T,
typename std::enable_if< typename std::enable_if<
std::is_base_of<OpTraitBase<T>, T>::value>::type> { std::is_base_of<OpTraitBase<T>, T>::value>::type> {
static T call(const Operation *op) { return T(op); } static T call(const Operation *op) {
return T(op->HasTrait<T>() ? op : nullptr);
}
}; };
template <typename T> template <typename T>
struct CastUtil<T, struct CastUtil<T,
typename std::enable_if< typename std::enable_if<
std::is_base_of<OpInterfaceBase<T>, T>::value>::type> { std::is_base_of<OpInterfaceBase<T>, T>::value>::type> {
static T call(const Operation *op) { static T call(const Operation *op) {
return T(op, op->op_info_.impl()->GetInterfaceImpl<T>()); typename T::Concept *interface_impl = op->op_info().GetInterfaceImpl<T>();
return interface_impl ? T(op, interface_impl) : T(nullptr, nullptr);
} }
}; };
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/operation_utils.h"
namespace ir {
OperationArgument::OperationArgument(IrContext* ir_context, std::string name) {
info_ = ir_context->GetRegisteredOpInfo(name);
}
OperationArgument::OperationArgument(OpInfo info,
const std::vector<OpResult>& operands,
const std::vector<Type>& types,
const AttributeMap& named_attr)
: info_(info),
inputs_(operands),
output_types_(types),
attribute_(named_attr) {}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/op_info.h"
#include "paddle/ir/type.h"
#include "paddle/ir/value_impl.h"
namespace ir {
using AttributeMap = std::unordered_map<std::string, Attribute>;
//===----------------------------------------------------------------------===//
// OperationArgument
//===----------------------------------------------------------------------===//
// This represents an operation arguments in an combined form, suitable for use
// with the builder APIs.
struct OperationArgument {
OpInfo info_;
std::vector<OpResult> inputs_;
std::vector<Type> output_types_;
AttributeMap attribute_;
public:
OperationArgument(IrContext* ir_context, std::string name);
explicit OperationArgument(OpInfo info) : info_(info) {}
OperationArgument(OpInfo info,
const std::vector<OpResult>& operands,
const std::vector<Type>& types,
const AttributeMap& named_attr = {});
template <class InputIt>
void addOperands(InputIt first, InputIt last);
template <class InputIt>
void addTypes(InputIt first, InputIt last);
/// Add an attribute with the specified name.
void addAttribute(const std::string& name, Attribute attr) {
attribute_[name] = attr;
}
/// Add an array of named attributes.
template <class InputIt>
void addAttributes(InputIt first, InputIt last);
/// Get the context held by this operation state.
IrContext* getContext() const { return info_.ir_context(); }
};
template <class InputIt>
void OperationArgument::addOperands(InputIt first, InputIt last) {
while (first != last) {
inputs_.emplace_back(*first++);
}
}
template <class InputIt>
void OperationArgument::addTypes(InputIt first, InputIt last) {
while (first != last) {
output_types_.emplace_back(*first++);
}
}
template <class InputIt>
void OperationArgument::addAttributes(InputIt first, InputIt last) {
while (first != last) {
attribute_[first->first] = first->second;
++first;
}
}
} // namespace ir
...@@ -45,6 +45,8 @@ class TypeId { ...@@ -45,6 +45,8 @@ class TypeId {
return TypeId(&instance); return TypeId(&instance);
} }
TypeId() = default;
TypeId(const TypeId &other) = default; TypeId(const TypeId &other) = default;
TypeId &operator=(const TypeId &other) = default; TypeId &operator=(const TypeId &other) = default;
...@@ -77,9 +79,8 @@ class TypeId { ...@@ -77,9 +79,8 @@ class TypeId {
/// ///
explicit TypeId(const Storage *storage) : storage_(storage) {} explicit TypeId(const Storage *storage) : storage_(storage) {}
const Storage *storage_; const Storage *storage_{nullptr};
}; };
} // namespace ir } // namespace ir
namespace std { namespace std {
......
...@@ -47,9 +47,8 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> { ...@@ -47,9 +47,8 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
} }
Model() : Concept(InferShape) { Model() : Concept(InferShape) {
if (sizeof(Model) != sizeof(Concept)) { static_assert(sizeof(Model) == sizeof(Concept),
throw("sizeof(Model) != sizeof(Concept)"); "sizeof(Model) != sizeof(Concept)");
}
} }
}; };
...@@ -66,25 +65,27 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> { ...@@ -66,25 +65,27 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
class Operation1 : public ir::Op<Operation1> { class Operation1 : public ir::Op<Operation1> {
public: public:
using Op::Op; using Op::Op;
static const char *name() { return "Operation1"; } static const char *name() { return "test.operation1"; }
static const char *attributes_name_[]; static constexpr uint32_t attributes_num = 2;
static uint32_t attributes_num() { return 2; } static const char *attributes_name[attributes_num];
}; };
const char *Operation1::attributes_name_[] = {"op1_attr1", "op1_attr2"}; const char *Operation1::attributes_name[attributes_num] = {"op1_attr1",
"op1_attr2"};
// Define op2. // Define op2.
class Operation2 class Operation2
: public ir::Op<Operation2, ReadOnlyTrait, InferShapeInterface> { : public ir::Op<Operation2, ReadOnlyTrait, InferShapeInterface> {
public: public:
using Op::Op; using Op::Op;
static const char *name() { return "Operation2"; } static const char *name() { return "test.operation2"; }
static const char *attributes_name_[]; static constexpr uint32_t attributes_num = 2;
static uint32_t attributes_num() { return 2; } static const char *attributes_name[attributes_num];
static void InferShape() { static void InferShape() {
std::cout << "This is op2's InferShape interface." << std::endl; std::cout << "This is op2's InferShape interface." << std::endl;
} }
}; };
const char *Operation2::attributes_name_[] = {"op2_attr1", "op2_attr2"}; const char *Operation2::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"};
// Define a dialect, op1 and op2 will be registered by this dialect. // Define a dialect, op1 and op2 will be registered by this dialect.
class TestDialect : public ir::Dialect { class TestDialect : public ir::Dialect {
...@@ -93,7 +94,7 @@ class TestDialect : public ir::Dialect { ...@@ -93,7 +94,7 @@ class TestDialect : public ir::Dialect {
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) { : ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize(); initialize();
} }
static const char *name() { return "op_test"; } static const char *name() { return "test"; }
private: private:
void initialize() { RegisterOps<Operation1, Operation2>(); } void initialize() { RegisterOps<Operation1, Operation2>(); }
...@@ -116,19 +117,17 @@ TEST(op_test, op_test) { ...@@ -116,19 +117,17 @@ TEST(op_test, op_test) {
std::cout << test_dialect << std::endl; std::cout << test_dialect << std::endl;
// (2) Get registered operations. // (2) Get registered operations.
std::string op1_name = std::string op1_name = Operation1::name();
test_dialect->name() + "." + std::string(Operation1::name()); ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
ir::OpInfoImpl *op1_info = ctx->GetRegisteredOpInfo(op1_name);
EXPECT_EQ(op1_info != nullptr, true); EXPECT_EQ(op1_info != nullptr, true);
std::string op2_name = std::string op2_name = Operation2::name();
test_dialect->name() + "." + std::string(Operation2::name()); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
ir::OpInfoImpl *op2_info = ctx->GetRegisteredOpInfo(op2_name);
EXPECT_EQ(op2_info != nullptr, true); EXPECT_EQ(op2_info != nullptr, true);
EXPECT_EQ(op1_info->HasTrait<ReadOnlyTrait>(), false); EXPECT_EQ(op1_info.HasTrait<ReadOnlyTrait>(), false);
EXPECT_EQ(op1_info->HasInterface<InferShapeInterface>(), false); EXPECT_EQ(op1_info.HasInterface<InferShapeInterface>(), false);
EXPECT_EQ(op2_info->HasTrait<ReadOnlyTrait>(), true); EXPECT_EQ(op2_info.HasTrait<ReadOnlyTrait>(), true);
EXPECT_EQ(op2_info->HasInterface<InferShapeInterface>(), true); EXPECT_EQ(op2_info.HasInterface<InferShapeInterface>(), true);
// (3) Test uses for op. // (3) Test uses for op.
std::vector<ir::OpResult> op_inputs = {}; std::vector<ir::OpResult> op_inputs = {};
......
...@@ -31,11 +31,10 @@ ...@@ -31,11 +31,10 @@
class AddOp : public ir::Op<AddOp> { class AddOp : public ir::Op<AddOp> {
public: public:
using Op::Op; using Op::Op;
static const char *name() { return "Add"; } static const char *name() { return "test.add"; }
static const char **attributes_name_; static constexpr const char **attributes_name = nullptr;
static uint32_t attributes_num() { return 0; } static constexpr uint32_t attributes_num = 0;
}; };
const char **AddOp::attributes_name_ = nullptr;
TEST(program_test, program) { TEST(program_test, program) {
// (1) Init environment. // (1) Init environment.
...@@ -78,9 +77,8 @@ TEST(program_test, program) { ...@@ -78,9 +77,8 @@ TEST(program_test, program) {
EXPECT_EQ(program.parameters_num() == 2, true); EXPECT_EQ(program.parameters_num() == 2, true);
// (4) Def a = GetParameterOp("a"), and create DenseTensor for a. // (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
std::string op1_name = std::string op1_name = ir::GetParameterOp::name();
builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name()); ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
ir::OpInfoImpl *op1_info = ctx->GetRegisteredOpInfo(op1_name);
std::unordered_map<std::string, ir::Attribute> op1_attribute{ std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}}; {"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 = ir::Operation *op1 =
...@@ -112,7 +110,7 @@ TEST(program_test, program) { ...@@ -112,7 +110,7 @@ TEST(program_test, program) {
// (5) Def b = GetParameterOp("b"), and create DenseTensor for b. // (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
std::string op2_name = std::string op2_name =
builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name()); builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name());
ir::OpInfoImpl *op2_info = ctx->GetRegisteredOpInfo(op2_name); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
std::unordered_map<std::string, ir::Attribute> op2_attribute{ std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}}; {"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 = ir::Operation *op2 =
...@@ -142,7 +140,7 @@ TEST(program_test, program) { ...@@ -142,7 +140,7 @@ TEST(program_test, program) {
// (6) Def c = AddOp(a, b), execute this op. // (6) Def c = AddOp(a, b), execute this op.
std::string op3_name = std::string op3_name =
builtin_dialect->name() + "." + std::string(AddOp::name()); builtin_dialect->name() + "." + std::string(AddOp::name());
ir::OpInfoImpl *op3_info = ctx->GetRegisteredOpInfo(op3_name); ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute; std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::create( ir::Operation *op3 = ir::Operation::create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, {op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
...@@ -173,7 +171,7 @@ TEST(program_test, program) { ...@@ -173,7 +171,7 @@ TEST(program_test, program) {
// (7) Def SetParameterOp(c, "c") // (7) Def SetParameterOp(c, "c")
std::string op4_name = std::string op4_name =
builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name()); builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name());
ir::OpInfoImpl *op4_info = ctx->GetRegisteredOpInfo(op4_name); ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);
std::unordered_map<std::string, ir::Attribute> op4_attribute{ std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}}; {"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::Operation *op4 = ir::Operation::create( ir::Operation *op4 = ir::Operation::create(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册