未验证 提交 9d9f0ce5 编写于 作者: 王明冬 提交者: GitHub

[IR] fine-tune the implementation of ir component. (#53894)

上级 d29c1f8e
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include "paddle/ir/operation.h"
namespace ir {
///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members.
///
class Builder {
public:
explicit Builder(IrContext *context) : context_(context) {}
explicit Builder(Operation *op) : Builder(op->ir_context()) {}
/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy create(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...);
Operation *op = Operation::create(argument);
return dyn_cast<OpTy>(op);
}
private:
IrContext *context_;
// The current op list this builder is inserting into.
// After the design of the block data structure is completed,
// this member will be replaced by the block.
std::list<Operation *> *op_list_ = nullptr;
// The insertion point within the list that this builder is inserting before.
std::list<Operation *>::iterator insertPoint;
};
} // namespace ir
......@@ -15,8 +15,10 @@
#include "paddle/ir/builtin_op.h"
namespace ir {
const char *GetParameterOp::attributes_name_[] = {"parameter_name"};
const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};
const char *SetParameterOp::attributes_name_[] = {"parameter_name"};
const char *SetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};
} // namespace ir
......@@ -32,11 +32,11 @@ class GetParameterOp : public ir::Op<GetParameterOp> {
public:
using Op::Op;
static const char* name() { return "GetParameterOp"; }
static const char* name() { return "builtin.get_parameter"; }
static uint32_t attributes_num() { return 1; }
static constexpr uint32_t attributes_num = 1;
static const char* attributes_name_[];
static const char* attributes_name[attributes_num];
};
///
......@@ -47,11 +47,11 @@ class SetParameterOp : public ir::Op<SetParameterOp> {
public:
using Op::Op;
static const char* name() { return "SetParameterOp"; }
static const char* name() { return "builtin.set_parameter"; }
static uint32_t attributes_num() { return 1; }
static constexpr uint32_t attributes_num = 1;
static const char* attributes_name_[];
static const char* attributes_name[attributes_num];
};
} // namespace ir
......@@ -20,24 +20,6 @@ Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id)
Dialect::~Dialect() = default;
void Dialect::RegisterType(ir::AbstractType &&abstract_type) {
ir::AbstractType *new_abstract_type =
new ir::AbstractType(std::move(abstract_type));
this->ir_context()->RegisterAbstractType(new_abstract_type->type_id(),
new_abstract_type);
}
void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) {
ir::AbstractAttribute *new_abstract_attribute =
new ir::AbstractAttribute(std::move(abstract_attribute));
this->ir_context()->RegisterAbstractAttribute(
new_abstract_attribute->type_id(), new_abstract_attribute);
}
void Dialect::RegisterOp(const std::string &name, OpInfoImpl *op_info) {
this->ir_context()->RegisterOpInfo(name, op_info);
}
void Dialect::RegisterInterface(std::unique_ptr<DialectInterface> interface) {
VLOG(4) << "Register interface into dialect" << std::endl;
auto it = registered_interfaces_.emplace(interface->interface_id(),
......
......@@ -17,7 +17,7 @@
#include "paddle/ir/attribute_base.h"
#include "paddle/ir/dialect_interface.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/op_info_impl.h"
#include "paddle/ir/op_base.h"
#include "paddle/ir/type_base.h"
namespace ir {
......@@ -52,26 +52,10 @@ class Dialect {
template <typename T>
void RegisterType() {
VLOG(4) << "Type registered into Dialect. --->";
if (this->ir_context()->GetRegisteredAbstractType(ir::TypeId::get<T>()) ==
nullptr) {
ir::AbstractType *abstract_type =
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
abstract_type);
ir::TypeManager::RegisterType<T>(this->ir_context());
ir_context()->RegisterAbstractType(TypeId::get<T>(),
AbstractType::get<T>(*this));
TypeManager::RegisterType<T>(ir_context());
}
VLOG(4) << "----------------------------------";
}
///
/// \brief Register abstract_type into context.
/// NOTE: It's not recommended to use this interface directly. This interface
/// only registers abstract_type. To register TypeStorage into context, you
/// need to call ir::TypeManager::RegisterType<T>() additionally,
/// RegisterType<T>() is recommended to use.
///
void RegisterType(ir::AbstractType &&abstract_type);
///
/// \brief Register all attributes contained in the template parameter Args.
......@@ -85,37 +69,28 @@ class Dialect {
template <typename T>
void RegisterAttribute() {
VLOG(4) << "Attribute registered into Dialect. --->";
if (this->ir_context()->GetRegisteredAbstractAttribute(
ir::TypeId::get<T>()) == nullptr) {
ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute(
std::move(ir::AbstractAttribute::get<T>(*this)));
this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get<T>(),
abstract_attribute);
ir::AttributeManager::RegisterAttribute<T>(this->ir_context());
ir_context()->RegisterAbstractAttribute(TypeId::get<T>(),
AbstractAttribute::get<T>(*this));
AttributeManager::RegisterAttribute<T>(ir_context());
}
VLOG(4) << "----------------------------------";
}
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute);
///
/// \brief Register Operation methods.
/// \brief Register Ops.
///
template <typename... Args>
void RegisterOps() {
(void)std::initializer_list<int>{0, (RegisterOp<Args>(), 0)...};
}
template <typename ConcertOp>
template <typename ConcreteOp>
void RegisterOp() {
std::string name = this->name() + "." + std::string(ConcertOp::name());
VLOG(4) << "Op " << name << " registered into Dialect. --->";
if (this->ir_context()->GetRegisteredOpInfo(name) == nullptr) {
ir::OpInfoImpl *op_info = ir::OpInfoImpl::create<ConcertOp>(this);
this->ir_context()->RegisterOpInfo(name, op_info);
}
VLOG(4) << "----------------------------------";
ir_context()->RegisterOpInfo(this,
TypeId::get<ConcreteOp>(),
ConcreteOp::name(),
ConcreteOp::GetInterfaceMap(),
ConcreteOp::GetTraitSet(),
ConcreteOp::attributes_num,
ConcreteOp::attributes_name);
}
void RegisterOp(const std::string &name, OpInfoImpl *op_info);
......
......@@ -185,11 +185,6 @@ IrContext::IrContext() : impl_(new IrContextImpl()) {
impl_->int64_type = TypeManager::get<Int64Type>(this);
}
void IrContext::RegisterAbstractType(ir::TypeId type_id,
AbstractType *abstract_type) {
impl().RegisterAbstractType(type_id, abstract_type);
}
StorageManager &IrContext::type_storage_manager() {
return impl().registed_type_storage_manager_;
}
......@@ -203,8 +198,14 @@ AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
}
void IrContext::RegisterAbstractAttribute(
ir::TypeId type_id, AbstractAttribute *abstract_attribute) {
impl().RegisterAbstractAttribute(type_id, abstract_attribute);
ir::TypeId type_id, AbstractAttribute &&abstract_attribute) {
if (GetRegisteredAbstractAttribute(type_id) == nullptr) {
impl().RegisterAbstractAttribute(
type_id, new AbstractAttribute(std::move(abstract_attribute)));
VLOG(4) << "<--- Attribute registered into IrContext. --->";
} else {
LOG(WARNING) << " Attribute already registered.";
}
}
StorageManager &IrContext::attribute_storage_manager() {
......@@ -251,17 +252,44 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
return nullptr;
}
OpInfoImpl *IrContext::GetRegisteredOpInfo(const std::string &name) {
OpInfoImpl *rtn = impl().GetOpInfo(name);
return rtn ? rtn : nullptr;
void IrContext::RegisterAbstractType(ir::TypeId type_id,
AbstractType &&abstract_type) {
if (GetRegisteredAbstractType(type_id) == nullptr) {
impl().RegisterAbstractType(type_id,
new AbstractType(std::move(abstract_type)));
VLOG(4) << "<--- Type registered into IrContext. --->";
} else {
LOG(WARNING) << " type already registered.";
}
}
void IrContext::RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
if (impl().GetOpInfo(name) == nullptr) {
void IrContext::RegisterOpInfo(Dialect *dialect,
TypeId op_id,
const char *name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char **attributes_name) {
if (GetRegisteredOpInfo(name) == nullptr) {
OpInfoImpl *opinfo = OpInfoImpl::create(dialect,
op_id,
name,
std::move(interface_map),
trait_set,
attributes_num,
attributes_name);
impl().RegisterOpInfo(name, opinfo);
VLOG(4) << "Op " << name << " registered into IrContext. --->";
} else {
LOG(WARNING) << name << " op already registered.";
}
}
OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) {
OpInfoImpl *rtn = impl().GetOpInfo(name);
return rtn ? rtn : nullptr;
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
auto &impl = ctx->impl();
AbstractType *abstract_type = impl.GetAbstractType(type_id);
......
......@@ -13,11 +13,9 @@
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>
namespace ir {
class IrContextImpl;
......@@ -26,8 +24,8 @@ class AbstractType;
class AbstractAttribute;
class TypeId;
class Dialect;
class OpInfoImpl;
class OpInfo;
class InterfaceValue;
///
/// \brief IrContext is a global parameterless class used to store and manage
/// Type, Attribute and other related data structures.
......@@ -53,7 +51,7 @@ class IrContext {
/// \param type_id The type id of the AbstractType.
/// \param abstract_type AbstractType* provided by user.
///
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type);
void RegisterAbstractType(TypeId type_id, AbstractType &&abstract_type);
///
/// \brief Returns the storage uniquer used for constructing TypeStorage
......@@ -73,10 +71,10 @@ class IrContext {
/// \brief Register an AbstractAttribute to IrContext
///
/// \param type_id The type id of the AbstractAttribute.
/// \param abstract_attribute AbstractAttribute* provided by user.
/// \param abstract_attribute AbstractAttribute provided by user.
///
void RegisterAbstractAttribute(ir::TypeId type_id,
AbstractAttribute *abstract_attribute);
AbstractAttribute &&abstract_attribute);
///
/// \brief Returns the storage uniquer used for constructing AttributeStorage
......@@ -93,11 +91,20 @@ class IrContext {
AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id);
///
/// \brief Get or register operaiton.
/// \brief Register an op infomation to IrContext
///
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo);
void RegisterOpInfo(Dialect *dialect,
TypeId op_id,
const char *name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char **attributes_name);
OpInfoImpl *GetRegisteredOpInfo(const std::string &name);
///
/// \brief Get registered operaiton infomation.
///
OpInfo GetRegisteredOpInfo(const std::string &name);
///
/// \brief Get the dialect of the DialectT class in the context, ff not found,
......@@ -162,7 +169,6 @@ class IrContext {
private:
IrContext();
const std::unique_ptr<IrContextImpl> impl_;
};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/op_base.h"
namespace ir {
InterfaceValue::~InterfaceValue() {
if (model_) free(model_);
}
InterfaceValue::InterfaceValue(InterfaceValue&& val) {
type_id_ = val.type_id_;
model_ = val.model_;
val.model_ = nullptr;
}
InterfaceValue& InterfaceValue::operator=(InterfaceValue&& val) {
swap(std::move(val));
return *this;
}
} // namespace ir
......@@ -13,11 +13,57 @@
// limitations under the License.
#pragma once
#include <type_traits>
#include "paddle/ir/operation.h"
#include "paddle/ir/utils.h"
namespace ir {
class InterfaceValue {
public:
template <typename ConcreteOp, typename T>
static InterfaceValue get() {
InterfaceValue val;
val.type_id_ = TypeId::get<T>();
val.model_ = malloc(sizeof(typename T::template Model<ConcreteOp>));
if (val.model_ == nullptr) {
throw("Alloc memory for interface failed.");
}
static_assert(std::is_trivially_destructible<
typename T::template Model<ConcreteOp>>::value,
"interface models must be trivially destructible");
new (val.model_) typename T::template Model<ConcreteOp>();
return val;
}
TypeId type_id() const { return type_id_; }
void *model() const { return model_; }
InterfaceValue() = default;
explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {}
InterfaceValue(const InterfaceValue &) = delete;
InterfaceValue(InterfaceValue &&);
InterfaceValue &operator=(const InterfaceValue &) = delete;
InterfaceValue &operator=(InterfaceValue &&);
~InterfaceValue();
void swap(InterfaceValue &&val) {
using std::swap;
swap(type_id_, val.type_id_);
swap(model_, val.model_);
}
///
/// \brief Comparison operations.
///
inline bool operator<(const InterfaceValue &other) const {
return type_id_ < other.type_id_;
}
private:
TypeId type_id_;
void *model_{nullptr};
};
class OpBase {
public:
explicit OpBase(const Operation *operation) : operation_(operation) {}
......@@ -58,6 +104,59 @@ class OpInterfaceBase : public OpBase {
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
};
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
public:
/// Construct method for interfaces.
static InterfaceValue *interface(InterfaceValue *p_interface) {
(void)std::initializer_list<int>{
0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
return p_interface;
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
(void)std::initializer_list<int>{
0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
return p_trait;
}
private:
/// Placement new interface.
template <typename T>
static void PlacementConstrctInterface(
InterfaceValue *&p_interface) { // NOLINT
p_interface->swap(InterfaceValue::get<ConcreteOp, T>());
VLOG(4) << "New a interface: id[" << (p_interface->type_id()).storage()
<< "].";
++p_interface;
}
/// Placement new trait.
template <typename T>
static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT
*p_trait = TypeId::get<T>();
VLOG(4) << "New a trait: id[" << p_trait->storage() << "].";
++p_trait;
}
};
/// Specialized for tuple type.
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
public:
/// Construct method for interfaces.
static InterfaceValue *interface(InterfaceValue *p_interface) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::interface(
p_interface);
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::trait(p_trait);
}
};
template <typename ConcreteOp, class... TraitOrInterface>
class Op : public OpBase {
public:
......@@ -68,6 +167,21 @@ class Op : public OpBase {
using InterfaceList =
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
};
static std::vector<InterfaceValue> GetInterfaceMap() {
constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
std::vector<InterfaceValue> interfaces_map(interfaces_num);
ConstructInterfacesOrTraits<ConcreteOp, InterfaceList>::interface(
interfaces_map.data());
return interfaces_map;
}
static std::vector<TypeId> GetTraitSet() {
constexpr size_t traits_num = std::tuple_size<TraitList>::value;
std::vector<TypeId> trait_set(traits_num);
auto p_first_trait = trait_set.data();
ConstructInterfacesOrTraits<ConcreteOp, TraitList>::trait(p_first_trait);
return trait_set;
}
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/op_info.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/op_info_impl.h"
namespace ir {
bool OpInfo::HasTrait(TypeId trait_id) const {
return impl_ && impl_->HasTrait(trait_id);
}
bool OpInfo::HasInterface(TypeId interface_id) const {
return impl_ && impl_->HasInterface(interface_id);
}
IrContext *OpInfo::ir_context() const {
return impl_ ? impl_->ir_context() : nullptr;
}
const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
void *OpInfo::GetInterfaceImpl(TypeId interface_id) const {
return impl_ ? impl_->interface_impl(interface_id) : nullptr;
}
ir::IrContext *OpInfoImpl::ir_context() const {
return dialect()->ir_context();
}
void *OpInfoImpl::interface_impl(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
size_t left = 0, right = num_interfaces_;
while (left < right) {
size_t mid = (left + right) / 2;
if ((p_first_interface + mid)->type_id() == interface_id) {
return (p_first_interface + mid)->model();
} else if ((p_first_interface + mid)->type_id() < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
bool OpInfoImpl::HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
const TypeId *p_first_trait =
reinterpret_cast<const TypeId *>(reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
bool OpInfoImpl::HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
InterfaceValue(interface_id));
}
return false;
}
OpInfoImpl *OpInfoImpl::create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[]) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num = interface_map.size();
size_t traits_num = trait_set.size();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(InterfaceValue) * interfaces_num +
sizeof(TypeId) * traits_num + sizeof(OpInfoImpl);
char *base_ptr = static_cast<char *>(::operator new(base_size));
VLOG(4) << "Malloc " << base_size << " Bytes at "
<< static_cast<void *>(base_ptr);
if (interfaces_num > 0) {
std::sort(interface_map.begin(), interface_map.end());
for (size_t index = 0; index < interfaces_num; ++index) {
new (base_ptr + index * sizeof(InterfaceValue))
InterfaceValue(std::move(interface_map[index]));
}
base_ptr += interfaces_num * sizeof(InterfaceValue);
}
if (traits_num > 0) {
auto p_first_trait = reinterpret_cast<TypeId *>(base_ptr);
memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr += traits_num * sizeof(TypeId);
}
// Construct opinfo_impl.
OpInfoImpl *p_opinfo_impl = reinterpret_cast<OpInfoImpl *>(base_ptr);
VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......";
OpInfoImpl *op_info = new (p_opinfo_impl) OpInfoImpl(dialect,
op_id,
op_name,
interfaces_num,
traits_num,
attributes_num,
attributes_name
);
return op_info;
}
void OpInfoImpl::destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
char *base_ptr = reinterpret_cast<char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_;
if (num_interfaces_ > 0) {
InterfaceValue *p_interface_val =
reinterpret_cast<InterfaceValue *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
(p_interface_val + i)->~InterfaceValue();
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
} // namespace ir
......@@ -13,12 +13,13 @@
// limitations under the License.
#pragma once
#include <functional>
#include "paddle/ir/op_info_impl.h"
#include "paddle/ir/type_id.h"
namespace ir {
class OpInfoImpl;
class IrContext;
class OpInfo {
public:
constexpr OpInfo() = default;
......@@ -37,24 +38,42 @@ class OpInfo {
bool operator!() const { return impl_ == nullptr; }
const OpInfoImpl *impl() const { return impl_; }
IrContext *ir_context() const;
const char *name() const;
template <typename Trait>
bool HasTrait() const {
return impl_->HasTrait<Trait>();
return HasTrait(TypeId::get<Trait>());
}
bool HasTrait(TypeId trait_id) const;
template <typename Interface>
bool HasInterface() const {
return impl_->HasInterface<Interface>();
return HasInterface(TypeId::get<Interface>());
}
bool HasInterface(TypeId interface_id) const;
template <typename Interface>
typename Interface::Concept *GetInterfaceImpl() const;
friend struct std::hash<OpInfo>;
private:
void *GetInterfaceImpl(TypeId interface_id) const;
private:
const OpInfoImpl *impl_{nullptr}; // not owned
};
template <typename Interface>
typename Interface::Concept *OpInfo::GetInterfaceImpl() const {
void *model = GetInterfaceImpl(TypeId::get<Interface>());
return reinterpret_cast<typename Interface::Concept *>(model);
}
} // namespace ir
namespace std {
......
......@@ -15,77 +15,16 @@
#pragma once
#include <algorithm>
#include <cstring>
#include <initializer_list>
#include <string>
#include <utility>
#include "paddle/ir/builtin_attribute.h"
// #include "paddle/ir/ir_context.h"
#include "paddle/ir/op_base.h"
#include "paddle/ir/type.h"
namespace ir {
class Dialect;
///
/// \brief Tool template class for construct interfaces or Traits.
///
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits {
public:
/// Construct method for interfaces.
static std::pair<TypeId, void *> *interface(
std::pair<TypeId, void *> *p_interface) {
(void)std::initializer_list<int>{
0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
return p_interface;
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
(void)std::initializer_list<int>{
0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
return p_trait;
}
private:
/// Placement new interface.
template <typename T>
static void PlacementConstrctInterface(
std::pair<TypeId, void *> *&p_interface) { // NOLINT
new (&(p_interface->first)) TypeId(ir::TypeId::get<T>());
p_interface->second =
malloc(sizeof(typename T::template Model<ConcreteOp>));
new (p_interface->second) typename T::template Model<ConcreteOp>();
VLOG(4) << "New a interface: id[" << p_interface->first.storage()
<< "], interface[" << p_interface->second << "].";
++p_interface;
}
/// Placement new trait.
template <typename T>
static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT
new (p_trait) TypeId(ir::TypeId::get<T>());
VLOG(4) << "New a trait: id[" << (*p_trait).storage() << "].";
++p_trait;
}
};
/// Specialized for tuple type.
template <typename ConcreteOp, typename... Args>
class ConstructInterfacesOrTraits<ConcreteOp, std::tuple<Args...>> {
public:
/// Construct method for interfaces.
static std::pair<TypeId, void *> *interface(
std::pair<TypeId, void *> *p_interface) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::interface(
p_interface);
}
/// Construct method for traits.
static TypeId *trait(TypeId *p_trait) {
return ConstructInterfacesOrTraits<ConcreteOp, Args...>::trait(p_trait);
}
};
///
/// \brief OpInfoImpl class.
///
......@@ -95,143 +34,27 @@ class OpInfoImpl {
/// \brief Construct and Deconstruct OpInfoImpl. The memory layout of
/// OpInfoImpl is: std::pair<TypeId, void *>... | TypeId... | OpInfoImpl
///
template <typename ConcreteOp>
static OpInfoImpl *create(ir::Dialect *dialect) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num =
std::tuple_size<typename ConcreteOp::InterfaceList>::value;
size_t traits_num = std::tuple_size<typename ConcreteOp::TraitList>::value;
size_t attributes_num = ConcreteOp::attributes_num();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(std::pair<ir::TypeId, void *>) * interfaces_num +
sizeof(ir::TypeId) * traits_num + sizeof(OpInfoImpl);
void *base_ptr = malloc(base_size);
VLOG(4) << "Malloc " << base_size << " Bytes at " << base_ptr;
// (2) Construct interfaces and sort by TypeId.
std::pair<ir::TypeId, void *> *p_first_interface = nullptr;
if (interfaces_num > 0) {
p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(base_ptr);
VLOG(4) << "Construct interfaces at " << p_first_interface << " ......";
ConstructInterfacesOrTraits<
ConcreteOp,
typename ConcreteOp::InterfaceList>::interface(p_first_interface);
std::sort(p_first_interface, p_first_interface + interfaces_num);
base_ptr = reinterpret_cast<void *>(p_first_interface + interfaces_num);
}
// (3) Construct traits and sort by TypeId.
ir::TypeId *p_first_trait = nullptr;
if (traits_num > 0) {
p_first_trait = reinterpret_cast<ir::TypeId *>(base_ptr);
VLOG(4) << "Construct traits at " << p_first_trait << " ......";
ConstructInterfacesOrTraits<ConcreteOp, typename ConcreteOp::TraitList>::
trait(p_first_trait);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr = reinterpret_cast<void *>(p_first_trait + traits_num);
}
static OpInfoImpl *create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[]);
// (4) Construct opinfo_impl.
OpInfoImpl *p_opinfo_impl = reinterpret_cast<OpInfoImpl *>(base_ptr);
VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......";
OpInfoImpl *op_info =
new (p_opinfo_impl) OpInfoImpl(interfaces_num,
traits_num,
ConcreteOp::attributes_name_,
attributes_num,
ir::TypeId::get<ConcreteOp>(),
ConcreteOp::name(),
dialect);
return op_info;
}
void destroy();
void destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
void *base_ptr = reinterpret_cast<void *>(
reinterpret_cast<char *>(this) - sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
if (num_interfaces_ > 0) {
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
free((p_first_interface + i)->second);
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
ir::IrContext *ir_context() const;
///
/// \brief Search methods for Trait or Interface.
///
template <typename Trait>
bool HasTrait() const {
return HasTrait(TypeId::get<Trait>());
}
bool HasTrait(TypeId trait_id) const;
bool HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
TypeId *p_first_trait = reinterpret_cast<TypeId *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
template <typename Interface>
bool HasInterface() const {
return HasInterface(TypeId::get<Interface>());
}
bool HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
std::make_pair(interface_id, nullptr),
CompareInterface);
}
return false;
}
template <typename Interface>
typename Interface::Concept *GetInterfaceImpl() const {
if (num_interfaces_ > 0) {
ir::TypeId interface_id = ir::TypeId::get<Interface>();
std::pair<ir::TypeId, void *> *p_first_interface =
reinterpret_cast<std::pair<ir::TypeId, void *> *>(
reinterpret_cast<char *>(const_cast<OpInfoImpl *>(this)) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(std::pair<ir::TypeId, void *>) * num_interfaces_);
size_t left = 0;
size_t right = num_interfaces_;
while (left < right) {
size_t mid = left + (right - left) / 2;
if ((p_first_interface + mid)->first == interface_id) {
return reinterpret_cast<typename Interface::Concept *>(
(p_first_interface + mid)->second);
} else if ((p_first_interface + mid)->first < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
bool HasInterface(TypeId interface_id) const;
ir::TypeId id() const { return op_id_; }
void *interface_impl(TypeId interface_id) const;
const char *name() const { return op_name_; }
ir::Dialect *dialect() const { return dialect_; }
......@@ -243,25 +66,29 @@ class OpInfoImpl {
}
private:
OpInfoImpl(uint32_t num_interfaces,
uint32_t num_traits,
const char **p_attributes,
uint32_t num_attributes,
OpInfoImpl(ir::Dialect *dialect,
TypeId op_id,
const char *op_name,
ir::Dialect *dialect)
: num_interfaces_(num_interfaces),
num_traits_(num_traits),
p_attributes_(p_attributes),
num_attributes_(num_attributes),
uint32_t num_interfaces,
uint32_t num_traits,
uint32_t num_attributes,
const char **p_attributes)
: dialect_(dialect),
op_id_(op_id),
op_name_(op_name),
dialect_(dialect) {}
num_interfaces_(num_interfaces),
num_traits_(num_traits),
num_attributes_(num_attributes),
p_attributes_(p_attributes) {}
static bool CompareInterface(const std::pair<ir::TypeId, void *> &a,
const std::pair<ir::TypeId, void *> &b) {
return a.first < b.first;
}
/// The dialect of this Op belong to.
ir::Dialect *dialect_;
/// The TypeId of this Op.
TypeId op_id_;
/// The name of this Op.
const char *op_name_;
/// Interface will be recorded by std::pair<TypeId, void*>.
uint32_t num_interfaces_ = 0;
......@@ -269,20 +96,11 @@ class OpInfoImpl {
/// Trait will be recorded by TypeId.
uint32_t num_traits_ = 0;
/// Attributes array address.
const char **p_attributes_{nullptr};
/// The number of attributes for this Op.
uint32_t num_attributes_ = 0;
/// The TypeId of this Op.
TypeId op_id_;
/// The name of this Op.
const char *op_name_;
/// The dialect of this Op belong to.
ir::Dialect *dialect_;
/// Attributes array address.
const char **p_attributes_{nullptr};
};
} // namespace ir
......@@ -18,6 +18,13 @@
#include "paddle/ir/utils.h"
namespace ir {
Operation *Operation::create(const OperationArgument &argument) {
return create(argument.inputs_,
argument.output_types_,
argument.attribute_,
argument.info_);
}
// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
......@@ -126,6 +133,8 @@ void Operation::destroy() {
aligned_free(reinterpret_cast<void *>(aligned_ptr));
}
IrContext *Operation::ir_context() const { return op_info_.ir_context(); }
Operation::Operation(uint32_t num_results,
uint32_t num_operands,
const AttributeMap &attribute,
......@@ -190,9 +199,6 @@ std::string Operation::print() {
return result.str();
}
std::string Operation::op_name() const {
return op_info_.impl()->dialect()->name() + "." +
std::string(op_info_.impl()->name());
}
std::string Operation::op_name() const { return op_info_.name(); }
} // namespace ir
......@@ -16,6 +16,7 @@
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/op_info.h"
#include "paddle/ir/operation_utils.h"
#include "paddle/ir/type.h"
#include "paddle/ir/value_impl.h"
......@@ -26,8 +27,6 @@ template <typename ConcreteInterface>
class OpInterfaceBase;
class Program;
using AttributeMap = std::unordered_map<std::string, Attribute>;
class alignas(8) Operation final {
public:
///
......@@ -40,12 +39,15 @@ class alignas(8) Operation final {
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute,
ir::OpInfo op_info);
static Operation *create(const OperationArgument &op_argument);
///
/// \brief Destroy the operation objects and free memeory by create().
///
void destroy();
IrContext *ir_context() const;
ir::OpResult GetResultByIndex(uint32_t index);
ir::OpOperand GetOperandByIndex(uint32_t index);
......@@ -99,14 +101,17 @@ class alignas(8) Operation final {
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpTraitBase<T>, T>::value>::type> {
static T call(const Operation *op) { return T(op); }
static T call(const Operation *op) {
return T(op->HasTrait<T>() ? op : nullptr);
}
};
template <typename T>
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpInterfaceBase<T>, T>::value>::type> {
static T call(const Operation *op) {
return T(op, op->op_info_.impl()->GetInterfaceImpl<T>());
typename T::Concept *interface_impl = op->op_info().GetInterfaceImpl<T>();
return interface_impl ? T(op, interface_impl) : T(nullptr, nullptr);
}
};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/operation_utils.h"
namespace ir {
OperationArgument::OperationArgument(IrContext* ir_context, std::string name) {
info_ = ir_context->GetRegisteredOpInfo(name);
}
OperationArgument::OperationArgument(OpInfo info,
const std::vector<OpResult>& operands,
const std::vector<Type>& types,
const AttributeMap& named_attr)
: info_(info),
inputs_(operands),
output_types_(types),
attribute_(named_attr) {}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/op_info.h"
#include "paddle/ir/type.h"
#include "paddle/ir/value_impl.h"
namespace ir {
using AttributeMap = std::unordered_map<std::string, Attribute>;
//===----------------------------------------------------------------------===//
// OperationArgument
//===----------------------------------------------------------------------===//
// This represents an operation arguments in an combined form, suitable for use
// with the builder APIs.
struct OperationArgument {
OpInfo info_;
std::vector<OpResult> inputs_;
std::vector<Type> output_types_;
AttributeMap attribute_;
public:
OperationArgument(IrContext* ir_context, std::string name);
explicit OperationArgument(OpInfo info) : info_(info) {}
OperationArgument(OpInfo info,
const std::vector<OpResult>& operands,
const std::vector<Type>& types,
const AttributeMap& named_attr = {});
template <class InputIt>
void addOperands(InputIt first, InputIt last);
template <class InputIt>
void addTypes(InputIt first, InputIt last);
/// Add an attribute with the specified name.
void addAttribute(const std::string& name, Attribute attr) {
attribute_[name] = attr;
}
/// Add an array of named attributes.
template <class InputIt>
void addAttributes(InputIt first, InputIt last);
/// Get the context held by this operation state.
IrContext* getContext() const { return info_.ir_context(); }
};
template <class InputIt>
void OperationArgument::addOperands(InputIt first, InputIt last) {
while (first != last) {
inputs_.emplace_back(*first++);
}
}
template <class InputIt>
void OperationArgument::addTypes(InputIt first, InputIt last) {
while (first != last) {
output_types_.emplace_back(*first++);
}
}
template <class InputIt>
void OperationArgument::addAttributes(InputIt first, InputIt last) {
while (first != last) {
attribute_[first->first] = first->second;
++first;
}
}
} // namespace ir
......@@ -45,6 +45,8 @@ class TypeId {
return TypeId(&instance);
}
TypeId() = default;
TypeId(const TypeId &other) = default;
TypeId &operator=(const TypeId &other) = default;
......@@ -77,9 +79,8 @@ class TypeId {
///
explicit TypeId(const Storage *storage) : storage_(storage) {}
const Storage *storage_;
const Storage *storage_{nullptr};
};
} // namespace ir
namespace std {
......
......@@ -47,9 +47,8 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
}
Model() : Concept(InferShape) {
if (sizeof(Model) != sizeof(Concept)) {
throw("sizeof(Model) != sizeof(Concept)");
}
static_assert(sizeof(Model) == sizeof(Concept),
"sizeof(Model) != sizeof(Concept)");
}
};
......@@ -66,25 +65,27 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
class Operation1 : public ir::Op<Operation1> {
public:
using Op::Op;
static const char *name() { return "Operation1"; }
static const char *attributes_name_[];
static uint32_t attributes_num() { return 2; }
static const char *name() { return "test.operation1"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
};
const char *Operation1::attributes_name_[] = {"op1_attr1", "op1_attr2"};
const char *Operation1::attributes_name[attributes_num] = {"op1_attr1",
"op1_attr2"};
// Define op2.
class Operation2
: public ir::Op<Operation2, ReadOnlyTrait, InferShapeInterface> {
public:
using Op::Op;
static const char *name() { return "Operation2"; }
static const char *attributes_name_[];
static uint32_t attributes_num() { return 2; }
static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void InferShape() {
std::cout << "This is op2's InferShape interface." << std::endl;
}
};
const char *Operation2::attributes_name_[] = {"op2_attr1", "op2_attr2"};
const char *Operation2::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"};
// Define a dialect, op1 and op2 will be registered by this dialect.
class TestDialect : public ir::Dialect {
......@@ -93,7 +94,7 @@ class TestDialect : public ir::Dialect {
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize();
}
static const char *name() { return "op_test"; }
static const char *name() { return "test"; }
private:
void initialize() { RegisterOps<Operation1, Operation2>(); }
......@@ -116,19 +117,17 @@ TEST(op_test, op_test) {
std::cout << test_dialect << std::endl;
// (2) Get registered operations.
std::string op1_name =
test_dialect->name() + "." + std::string(Operation1::name());
ir::OpInfoImpl *op1_info = ctx->GetRegisteredOpInfo(op1_name);
std::string op1_name = Operation1::name();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
EXPECT_EQ(op1_info != nullptr, true);
std::string op2_name =
test_dialect->name() + "." + std::string(Operation2::name());
ir::OpInfoImpl *op2_info = ctx->GetRegisteredOpInfo(op2_name);
std::string op2_name = Operation2::name();
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
EXPECT_EQ(op2_info != nullptr, true);
EXPECT_EQ(op1_info->HasTrait<ReadOnlyTrait>(), false);
EXPECT_EQ(op1_info->HasInterface<InferShapeInterface>(), false);
EXPECT_EQ(op2_info->HasTrait<ReadOnlyTrait>(), true);
EXPECT_EQ(op2_info->HasInterface<InferShapeInterface>(), true);
EXPECT_EQ(op1_info.HasTrait<ReadOnlyTrait>(), false);
EXPECT_EQ(op1_info.HasInterface<InferShapeInterface>(), false);
EXPECT_EQ(op2_info.HasTrait<ReadOnlyTrait>(), true);
EXPECT_EQ(op2_info.HasInterface<InferShapeInterface>(), true);
// (3) Test uses for op.
std::vector<ir::OpResult> op_inputs = {};
......
......@@ -31,11 +31,10 @@
class AddOp : public ir::Op<AddOp> {
public:
using Op::Op;
static const char *name() { return "Add"; }
static const char **attributes_name_;
static uint32_t attributes_num() { return 0; }
static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
};
const char **AddOp::attributes_name_ = nullptr;
TEST(program_test, program) {
// (1) Init environment.
......@@ -78,9 +77,8 @@ TEST(program_test, program) {
EXPECT_EQ(program.parameters_num() == 2, true);
// (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
std::string op1_name =
builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name());
ir::OpInfoImpl *op1_info = ctx->GetRegisteredOpInfo(op1_name);
std::string op1_name = ir::GetParameterOp::name();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 =
......@@ -112,7 +110,7 @@ TEST(program_test, program) {
// (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
std::string op2_name =
builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name());
ir::OpInfoImpl *op2_info = ctx->GetRegisteredOpInfo(op2_name);
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 =
......@@ -142,7 +140,7 @@ TEST(program_test, program) {
// (6) Def c = AddOp(a, b), execute this op.
std::string op3_name =
builtin_dialect->name() + "." + std::string(AddOp::name());
ir::OpInfoImpl *op3_info = ctx->GetRegisteredOpInfo(op3_name);
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
......@@ -173,7 +171,7 @@ TEST(program_test, program) {
// (7) Def SetParameterOp(c, "c")
std::string op4_name =
builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name());
ir::OpInfoImpl *op4_info = ctx->GetRegisteredOpInfo(op4_name);
ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);
std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::Operation *op4 = ir::Operation::create(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册