未验证 提交 8dfcf03e 编写于 作者: W winter-wang 提交者: GitHub

[IR] add positon member in new ir operation. (#54483)

上级 b3232936
...@@ -13,22 +13,23 @@ ...@@ -13,22 +13,23 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/region.h"
namespace ir { namespace ir {
Block::~Block() { clear(); } Block::~Block() { clear(); }
void Block::push_back(Operation *op) { void Block::push_back(Operation *op) { insert(ops_.end(), op); }
op->set_parent(this);
ops_.push_back(op); void Block::push_front(Operation *op) { insert(ops_.begin(), op); }
}
void Block::push_front(Operation *op) { Operation *Block::GetParentOp() const {
op->set_parent(this); return parent_ ? parent_->GetParent() : nullptr;
ops_.push_front(op);
} }
Block::iterator Block::insert(const_iterator iterator, Operation *op) { Block::iterator Block::insert(const_iterator iterator, Operation *op) {
op->set_parent(this); Block::iterator iter = ops_.insert(iterator, op);
return ops_.insert(iterator, op); op->SetParent(this, iter);
return iter;
} }
void Block::clear() { void Block::clear() {
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#include <cstddef> #include <cstddef>
#include <list> #include <list>
#include "paddle/ir/core/operation.h"
namespace ir { namespace ir {
class Region; class Region;
class Operation;
class Block { class Block {
public: public:
...@@ -30,7 +30,9 @@ class Block { ...@@ -30,7 +30,9 @@ class Block {
Block() = default; Block() = default;
~Block(); ~Block();
Region *parent() const { return parent_; } Region *GetParent() const { return parent_; }
Operation *GetParentOp() const;
bool empty() const { return ops_.empty(); } bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); } size_t size() const { return ops_.size(); }
...@@ -46,18 +48,12 @@ class Block { ...@@ -46,18 +48,12 @@ class Block {
iterator insert(const_iterator iterator, Operation *op); iterator insert(const_iterator iterator, Operation *op);
void clear(); void clear();
Region *GetParentRegion() const { return parent_; }
Operation *GetParentOp() const {
return parent_ ? parent_->GetParentOp() : nullptr;
}
private: private:
Block(Block &) = delete; Block(Block &) = delete;
Block &operator=(const Block &) = delete; Block &operator=(const Block &) = delete;
friend class Region; friend class Region;
void set_parent(Region *parent) { parent_ = parent; } void SetParent(Region *parent) { parent_ = parent; }
private: private:
Region *parent_; // not owned Region *parent_; // not owned
......
...@@ -26,10 +26,10 @@ namespace ir { ...@@ -26,10 +26,10 @@ namespace ir {
/// ///
class Builder { class Builder {
public: public:
explicit Builder(IrContext *context, Builder(IrContext *context, Block *block, Block::iterator insert_point)
Block *block,
Block::iterator insert_point)
: context_(context), block_(block), insert_point_(insert_point) {} : context_(context), block_(block), insert_point_(insert_point) {}
Builder(IrContext *context, Block *block)
: Builder(context, block, block->end()) {}
static Builder AtBlockBegin(IrContext *context, Block *block) { static Builder AtBlockBegin(IrContext *context, Block *block) {
return Builder(context, block, block->begin()); return Builder(context, block, block->begin());
......
...@@ -49,7 +49,7 @@ class IrContextImpl { ...@@ -49,7 +49,7 @@ class IrContextImpl {
registed_dialect_.clear(); registed_dialect_.clear();
for (auto &op_map : registed_op_infos_) { for (auto &op_map : registed_op_infos_) {
op_map.second->destroy(); OpInfoImpl::Destroy(op_map.second);
} }
registed_op_infos_.clear(); registed_op_infos_.clear();
} }
...@@ -103,24 +103,25 @@ class IrContextImpl { ...@@ -103,24 +103,25 @@ class IrContextImpl {
return registed_op_infos_.find(name) != registed_op_infos_.end(); return registed_op_infos_.find(name) != registed_op_infos_.end();
} }
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) { void RegisterOpInfo(const std::string &name, OpInfo info) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_); std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
VLOG(4) << "Register an operation of: [Name=" << name VLOG(4) << "Register an operation of: [Name=" << name
<< ", OpInfoImpl ptr=" << opinfo << "]."; << ", OpInfo ptr=" << info.AsOpaquePointer() << "].";
registed_op_infos_.emplace(name, opinfo); registed_op_infos_.emplace(name, info);
} }
OpInfoImpl *GetOpInfo(const std::string &name) { OpInfo GetOpInfo(const std::string &name) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_); std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
auto iter = registed_op_infos_.find(name); auto iter = registed_op_infos_.find(name);
if (iter != registed_op_infos_.end()) { if (iter != registed_op_infos_.end()) {
VLOG(4) << "Found a cached operation of: [name=" << name VLOG(4) << "Found a cached OpInfo of: [name=" << name
<< ", OpInfoImpl ptr=" << iter->second << "]."; << ", OpInfo: ptr=" << iter->second.AsOpaquePointer() << "].";
return iter->second; return iter->second;
} }
LOG(WARNING) << "No cache found operation of: [Name=" << name << "]."; LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
return nullptr; return OpInfo();
} }
const OpInfoMap &registered_op_info_map() { return registed_op_infos_; }
void RegisterDialect(std::string name, Dialect *dialect) { void RegisterDialect(std::string name, Dialect *dialect) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_); std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
...@@ -170,7 +171,7 @@ class IrContextImpl { ...@@ -170,7 +171,7 @@ class IrContextImpl {
ir::SpinLock registed_dialect_lock_; ir::SpinLock registed_dialect_lock_;
// The Op registered in the context. // The Op registered in the context.
std::unordered_map<std::string, OpInfoImpl *> registed_op_infos_; OpInfoMap registed_op_infos_;
ir::SpinLock registed_op_infos_lock_; ir::SpinLock registed_op_infos_lock_;
ir::SpinLock destructor_lock_; ir::SpinLock destructor_lock_;
...@@ -282,43 +283,39 @@ void IrContext::RegisterOpInfo(Dialect *dialect, ...@@ -282,43 +283,39 @@ void IrContext::RegisterOpInfo(Dialect *dialect,
if (impl().IsOpInfoRegistered(name)) { if (impl().IsOpInfoRegistered(name)) {
LOG(WARNING) << name << " op already registered."; LOG(WARNING) << name << " op already registered.";
} else { } else {
OpInfoImpl *opinfo = OpInfoImpl::create(dialect, OpInfo info = OpInfoImpl::Create(dialect,
op_id, op_id,
name, name,
std::move(interface_map), std::move(interface_map),
trait_set, trait_set,
attributes_num, attributes_num,
attributes_name, attributes_name,
verify); verify);
impl().RegisterOpInfo(name, opinfo); impl().RegisterOpInfo(name, info);
VLOG(4) << name << " op registered into IrContext. --->"; VLOG(4) << name << " op registered into IrContext. --->";
} }
} }
OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) { OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) {
OpInfoImpl *rtn = impl().GetOpInfo(name); return impl().GetOpInfo(name);
return rtn ? rtn : nullptr; }
const OpInfoMap &IrContext::registered_op_info_map() {
return impl().registered_op_info_map();
} }
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) { const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
auto &impl = ctx->impl(); AbstractType *abstract_type = ctx->impl().GetAbstractType(type_id);
AbstractType *abstract_type = impl.GetAbstractType(type_id); IR_ENFORCE(abstract_type, "Abstract type not found in IrContext.");
if (abstract_type) { return *abstract_type;
return *abstract_type;
} else {
throw("Abstract type not found in IrContext.");
}
} }
const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id, const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id,
IrContext *ctx) { IrContext *ctx) {
auto &impl = ctx->impl(); AbstractAttribute *abstract_attribute =
AbstractAttribute *abstract_attribute = impl.GetAbstractAttribute(type_id); ctx->impl().GetAbstractAttribute(type_id);
if (abstract_attribute) { IR_ENFORCE(abstract_attribute, "Abstract attribute not found in IrContext.");
return *abstract_attribute; return *abstract_attribute;
} else {
throw("Abstract attribute not found in IrContext.");
}
} }
BFloat16Type BFloat16Type::get(IrContext *ctx) { BFloat16Type BFloat16Type::get(IrContext *ctx) {
......
...@@ -31,6 +31,8 @@ class Type; ...@@ -31,6 +31,8 @@ class Type;
class OpResult; class OpResult;
class Attribute; class Attribute;
using OpInfoMap = std::unordered_map<std::string, OpInfo>;
/// ///
/// \brief IrContext is a global parameterless class used to store and manage /// \brief IrContext is a global parameterless class used to store and manage
/// Type, Attribute and other related data structures. /// Type, Attribute and other related data structures.
...@@ -116,6 +118,11 @@ class IrContext { ...@@ -116,6 +118,11 @@ class IrContext {
/// ///
OpInfo GetRegisteredOpInfo(const std::string &name); OpInfo GetRegisteredOpInfo(const std::string &name);
///
/// \brief Get registered operaiton infomation map.
///
const OpInfoMap &registered_op_info_map();
/// ///
/// \brief Get the dialect of the DialectT class in the context, ff not found, /// \brief Get the dialect of the DialectT class in the context, ff not found,
/// create and register to context. /// create and register to context.
......
...@@ -41,123 +41,6 @@ void OpInfo::Verify(const std::vector<OpResult> &inputs, ...@@ -41,123 +41,6 @@ void OpInfo::Verify(const std::vector<OpResult> &inputs,
} }
void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { void *OpInfo::GetInterfaceImpl(TypeId interface_id) const {
return impl_ ? impl_->interface_impl(interface_id) : nullptr; return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr;
} }
ir::IrContext *OpInfoImpl::ir_context() const {
return dialect()->ir_context();
}
void *OpInfoImpl::interface_impl(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
size_t left = 0, right = num_interfaces_;
while (left < right) {
size_t mid = (left + right) / 2;
if ((p_first_interface + mid)->type_id() == interface_id) {
return (p_first_interface + mid)->model();
} else if ((p_first_interface + mid)->type_id() < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
bool OpInfoImpl::HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
const TypeId *p_first_trait =
reinterpret_cast<const TypeId *>(reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
bool OpInfoImpl::HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
InterfaceValue(interface_id));
}
return false;
}
OpInfoImpl *OpInfoImpl::create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[],
VerifyPtr verify) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num = interface_map.size();
size_t traits_num = trait_set.size();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(InterfaceValue) * interfaces_num +
sizeof(TypeId) * traits_num + sizeof(OpInfoImpl);
char *base_ptr = static_cast<char *>(::operator new(base_size));
VLOG(4) << "Malloc " << base_size << " Bytes at "
<< static_cast<void *>(base_ptr);
if (interfaces_num > 0) {
std::sort(interface_map.begin(), interface_map.end());
for (size_t index = 0; index < interfaces_num; ++index) {
new (base_ptr + index * sizeof(InterfaceValue))
InterfaceValue(std::move(interface_map[index]));
}
base_ptr += interfaces_num * sizeof(InterfaceValue);
}
if (traits_num > 0) {
auto p_first_trait = reinterpret_cast<TypeId *>(base_ptr);
memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr += traits_num * sizeof(TypeId);
}
// Construct opinfo_impl.
OpInfoImpl *p_opinfo_impl = reinterpret_cast<OpInfoImpl *>(base_ptr);
VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......";
OpInfoImpl *op_info = new (p_opinfo_impl) OpInfoImpl(dialect,
op_id,
op_name,
interfaces_num,
traits_num,
attributes_num,
attributes_name,
verify
);
return op_info;
}
void OpInfoImpl::destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
char *base_ptr = reinterpret_cast<char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_;
if (num_interfaces_ > 0) {
InterfaceValue *p_interface_val =
reinterpret_cast<InterfaceValue *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
(p_interface_val + i)->~InterfaceValue();
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
} // namespace ir } // namespace ir
...@@ -28,8 +28,6 @@ class OpInfo { ...@@ -28,8 +28,6 @@ class OpInfo {
public: public:
constexpr OpInfo() = default; constexpr OpInfo() = default;
OpInfo(const OpInfoImpl *impl) : impl_(impl) {} // NOLINT
OpInfo(const OpInfo &other) = default; OpInfo(const OpInfo &other) = default;
OpInfo &operator=(const OpInfo &other) = default; OpInfo &operator=(const OpInfo &other) = default;
...@@ -52,8 +50,6 @@ class OpInfo { ...@@ -52,8 +50,6 @@ class OpInfo {
const std::vector<Type> &outputs, const std::vector<Type> &outputs,
const std::unordered_map<std::string, Attribute> &attributes); const std::unordered_map<std::string, Attribute> &attributes);
const OpInfoImpl *impl() const;
template <typename Trait> template <typename Trait>
bool HasTrait() const { bool HasTrait() const {
return HasTrait(TypeId::get<Trait>()); return HasTrait(TypeId::get<Trait>());
...@@ -71,13 +67,20 @@ class OpInfo { ...@@ -71,13 +67,20 @@ class OpInfo {
template <typename Interface> template <typename Interface>
typename Interface::Concept *GetInterfaceImpl() const; typename Interface::Concept *GetInterfaceImpl() const;
void *AsOpaquePointer() const { return impl_; }
static OpInfo RecoverFromOpaquePointer(void *impl) {
return static_cast<OpInfoImpl *>(impl);
}
friend class OpInfoImpl;
friend struct std::hash<OpInfo>; friend struct std::hash<OpInfo>;
private: private:
OpInfo(OpInfoImpl *impl) : impl_(impl) {} // NOLINT
void *GetInterfaceImpl(TypeId interface_id) const; void *GetInterfaceImpl(TypeId interface_id) const;
private: private:
const OpInfoImpl *impl_{nullptr}; // not owned OpInfoImpl *impl_{nullptr}; // not owned
}; };
template <typename Interface> template <typename Interface>
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/core/op_info_impl.h"
#include "paddle/ir/core/dialect.h"
namespace ir {
OpInfo OpInfoImpl::Create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[],
VerifyPtr verify) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num = interface_map.size();
size_t traits_num = trait_set.size();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(InterfaceValue) * interfaces_num +
sizeof(TypeId) * traits_num + sizeof(OpInfoImpl);
char *base_ptr = static_cast<char *>(::operator new(base_size));
VLOG(4) << "Malloc " << base_size << " Bytes at "
<< static_cast<void *>(base_ptr);
if (interfaces_num > 0) {
std::sort(interface_map.begin(), interface_map.end());
for (size_t index = 0; index < interfaces_num; ++index) {
new (base_ptr + index * sizeof(InterfaceValue))
InterfaceValue(std::move(interface_map[index]));
}
base_ptr += interfaces_num * sizeof(InterfaceValue);
}
if (traits_num > 0) {
auto p_first_trait = reinterpret_cast<TypeId *>(base_ptr);
memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr += traits_num * sizeof(TypeId);
}
// Construct OpInfoImpl.
VLOG(4) << "Construct OpInfoImpl at " << base_ptr << " ......";
OpInfo op_info = new (base_ptr) OpInfoImpl(dialect,
op_id,
op_name,
interfaces_num,
traits_num,
attributes_num,
attributes_name,
verify);
return op_info;
}
void OpInfoImpl::Destroy(OpInfo info) {
if (info.impl_) {
info.impl_->Destroy();
} else {
LOG(WARNING) << "A nullptr OpInfo is destoryed.";
}
}
ir::IrContext *OpInfoImpl::ir_context() const {
return dialect_ ? dialect_->ir_context() : nullptr;
}
bool OpInfoImpl::HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
const TypeId *p_first_trait =
reinterpret_cast<const TypeId *>(reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
bool OpInfoImpl::HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
InterfaceValue(interface_id));
}
return false;
}
void *OpInfoImpl::GetInterfaceImpl(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
size_t left = 0, right = num_interfaces_;
while (left < right) {
size_t mid = (left + right) / 2;
if ((p_first_interface + mid)->type_id() == interface_id) {
return (p_first_interface + mid)->model();
} else if ((p_first_interface + mid)->type_id() < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
void OpInfoImpl::Destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
char *base_ptr = reinterpret_cast<char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_;
if (num_interfaces_ > 0) {
InterfaceValue *p_interface_val =
reinterpret_cast<InterfaceValue *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
(p_interface_val + i)->~InterfaceValue();
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
} // namespace ir
...@@ -38,40 +38,39 @@ class OpInfoImpl { ...@@ -38,40 +38,39 @@ class OpInfoImpl {
/// \brief Construct and Deconstruct OpInfoImpl. The memory layout of /// \brief Construct and Deconstruct OpInfoImpl. The memory layout of
/// OpInfoImpl is: std::pair<TypeId, void *>... | TypeId... | OpInfoImpl /// OpInfoImpl is: std::pair<TypeId, void *>... | TypeId... | OpInfoImpl
/// ///
static OpInfoImpl *create(Dialect *dialect, static OpInfo Create(Dialect *dialect,
TypeId op_id, TypeId op_id,
const char *op_name, const char *op_name,
std::vector<InterfaceValue> &&interface_map, std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set, const std::vector<TypeId> &trait_set,
size_t attributes_num, size_t attributes_num,
const char *attributes_name[], const char *attributes_name[],
VerifyPtr verify); VerifyPtr verify);
static void Destroy(OpInfo info);
void destroy(); TypeId id() const { return op_id_; }
ir::IrContext *ir_context() const; Dialect *dialect() const { return dialect_; }
VerifyPtr verify() const { return verify_; }
IrContext *ir_context() const;
/// \brief Search methods for Trait or Interface. /// \brief Search methods for Trait or Interface.
bool HasTrait(TypeId trait_id) const; bool HasTrait(TypeId trait_id) const;
bool HasInterface(TypeId interface_id) const; bool HasInterface(TypeId interface_id) const;
ir::TypeId id() const { return op_id_; } void *GetInterfaceImpl(TypeId interface_id) const;
void *interface_impl(TypeId interface_id) const;
const char *name() const { return op_name_; } const char *name() const { return op_name_; }
ir::Dialect *dialect() const { return dialect_; }
uint32_t AttributeNum() const { return num_attributes_; } uint32_t AttributeNum() const { return num_attributes_; }
const char *GetAttributeByIndex(size_t idx) const { const char *GetAttributeByIndex(size_t idx) const {
return idx < num_attributes_ ? p_attributes_[idx] : nullptr; return idx < num_attributes_ ? p_attributes_[idx] : nullptr;
} }
VerifyPtr verify() const { return verify_; }
private: private:
OpInfoImpl(ir::Dialect *dialect, OpInfoImpl(ir::Dialect *dialect,
TypeId op_id, TypeId op_id,
...@@ -89,9 +88,10 @@ class OpInfoImpl { ...@@ -89,9 +88,10 @@ class OpInfoImpl {
num_attributes_(num_attributes), num_attributes_(num_attributes),
p_attributes_(p_attributes), p_attributes_(p_attributes),
verify_(verify) {} verify_(verify) {}
void Destroy();
/// The dialect of this Op belong to. /// The dialect of this Op belong to.
ir::Dialect *dialect_; Dialect *dialect_;
/// The TypeId of this Op. /// The TypeId of this Op.
TypeId op_id_; TypeId op_id_;
......
...@@ -213,7 +213,7 @@ std::string Operation::name() const { ...@@ -213,7 +213,7 @@ std::string Operation::name() const {
} }
Region *Operation::GetParentRegion() const { Region *Operation::GetParentRegion() const {
return parent_ ? parent_->GetParentRegion() : nullptr; return parent_ ? parent_->GetParent() : nullptr;
} }
Operation *Operation::GetParentOp() const { Operation *Operation::GetParentOp() const {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <ostream> #include <ostream>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/op_info.h" #include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/type.h" #include "paddle/ir/core/type.h"
...@@ -22,7 +23,6 @@ ...@@ -22,7 +23,6 @@
namespace ir { namespace ir {
class OpBase; class OpBase;
class Program; class Program;
class Block;
class OpOperand; class OpOperand;
class OpResult; class OpResult;
...@@ -85,7 +85,7 @@ class alignas(8) Operation final { ...@@ -85,7 +85,7 @@ class alignas(8) Operation final {
return info_.HasInterface<Interface>(); return info_.HasInterface<Interface>();
} }
Block *GetParentBlock() const { return parent_; } Block *GetParent() const { return parent_; }
Region *GetParentRegion() const; Region *GetParentRegion() const;
...@@ -96,6 +96,8 @@ class alignas(8) Operation final { ...@@ -96,6 +96,8 @@ class alignas(8) Operation final {
/// Returns the region held by this operation at position 'index'. /// Returns the region held by this operation at position 'index'.
Region &GetRegion(unsigned index); Region &GetRegion(unsigned index);
operator Block::iterator() { return position_; }
private: private:
Operation(const AttributeMap &attribute, Operation(const AttributeMap &attribute,
ir::OpInfo op_info, ir::OpInfo op_info,
...@@ -111,7 +113,10 @@ class alignas(8) Operation final { ...@@ -111,7 +113,10 @@ class alignas(8) Operation final {
}; };
friend class Block; friend class Block;
void set_parent(Block *parent) { parent_ = parent; } void SetParent(Block *parent, const Block::iterator &position) {
parent_ = parent;
position_ = position;
}
template <typename T> template <typename T>
struct CastUtil< struct CastUtil<
...@@ -130,6 +135,7 @@ class alignas(8) Operation final { ...@@ -130,6 +135,7 @@ class alignas(8) Operation final {
Region *regions_{nullptr}; Region *regions_{nullptr};
Block *parent_{nullptr}; Block *parent_{nullptr};
Block::iterator position_;
}; };
} // namespace ir } // namespace ir
...@@ -19,26 +19,26 @@ namespace ir { ...@@ -19,26 +19,26 @@ namespace ir {
Region::~Region() { clear(); } Region::~Region() { clear(); }
void Region::push_back(Block *block) { void Region::push_back(Block *block) {
block->set_parent(this); block->SetParent(this);
blocks_.push_back(block); blocks_.push_back(block);
} }
void Region::emplace_back() { push_back(new Block); } void Region::emplace_back() { push_back(new Block); }
void Region::push_front(Block *block) { void Region::push_front(Block *block) {
block->set_parent(this); block->SetParent(this);
blocks_.push_front(block); blocks_.push_front(block);
} }
Region::iterator Region::insert(const_iterator position, Block *block) { Region::iterator Region::insert(const_iterator position, Block *block) {
block->set_parent(this); block->SetParent(this);
return blocks_.insert(position, block); return blocks_.insert(position, block);
} }
void Region::TakeBody(Region &&other) { void Region::TakeBody(Region &&other) {
clear(); clear();
blocks_.swap(other.blocks_); blocks_.swap(other.blocks_);
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->set_parent(this); block->SetParent(this);
} }
} }
......
...@@ -48,7 +48,7 @@ class Region { ...@@ -48,7 +48,7 @@ class Region {
void TakeBody(Region &&other); void TakeBody(Region &&other);
Operation *GetParentOp() const { return parent_; } Operation *GetParent() const { return parent_; }
private: private:
Region(Region &) = delete; Region(Region &) = delete;
......
...@@ -71,3 +71,5 @@ cc_test_old( ...@@ -71,3 +71,5 @@ cc_test_old(
gtest gtest
new_ir new_ir
pd_dialect) pd_dialect)
cc_test_old(ir_op_info_test SRCS op_info_test.cc DEPS gtest new_ir)
...@@ -163,10 +163,10 @@ TEST(op_test, op_test) { ...@@ -163,10 +163,10 @@ TEST(op_test, op_test) {
// (2) Get registered operations. // (2) Get registered operations.
std::string op1_name = Operation1::name(); std::string op1_name = Operation1::name();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
EXPECT_EQ(op1_info != nullptr, true); EXPECT_TRUE(op1_info);
std::string op2_name = Operation2::name(); std::string op2_name = Operation2::name();
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
EXPECT_EQ(op2_info != nullptr, true); EXPECT_TRUE(op2_info);
EXPECT_EQ(op1_info.HasTrait<ReadOnlyTrait>(), false); EXPECT_EQ(op1_info.HasTrait<ReadOnlyTrait>(), false);
EXPECT_EQ(op1_info.HasInterface<InferShapeInterface>(), false); EXPECT_EQ(op1_info.HasInterface<InferShapeInterface>(), false);
EXPECT_EQ(op2_info.HasTrait<ReadOnlyTrait>(), true); EXPECT_EQ(op2_info.HasTrait<ReadOnlyTrait>(), true);
......
...@@ -98,7 +98,7 @@ TEST(program_test, program) { ...@@ -98,7 +98,7 @@ TEST(program_test, program) {
ir::Block *block = program.block(); ir::Block *block = program.block();
block->push_back(op1); block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion()); EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent());
EXPECT_EQ(program.module_op(), block->GetParentOp()); EXPECT_EQ(program.module_op(), block->GetParentOp());
...@@ -299,7 +299,7 @@ TEST(program_test, builder) { ...@@ -299,7 +299,7 @@ TEST(program_test, builder) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx); ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block()); ir::Builder builder = ir::Builder(ctx, program.block());
paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>( paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
......
...@@ -42,7 +42,7 @@ TEST(value_test, value_test) { ...@@ -42,7 +42,7 @@ TEST(value_test, value_test) {
ir::Operation::Create(op1_inputs, ir::Operation::Create(op1_inputs,
CreateAttributeMap("op1_name", "op1_attr"), CreateAttributeMap("op1_name", "op1_attr"),
op1_output_types, op1_output_types,
nullptr); ir::OpInfo());
op1->Print(std::cout); op1->Print(std::cout);
// 2. Construct OP2: b = OP2(); // 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {}; std::vector<ir::OpResult> op2_inputs = {};
...@@ -51,7 +51,7 @@ TEST(value_test, value_test) { ...@@ -51,7 +51,7 @@ TEST(value_test, value_test) {
ir::Operation::Create(op2_inputs, ir::Operation::Create(op2_inputs,
CreateAttributeMap("op2_name", "op2_attr"), CreateAttributeMap("op2_name", "op2_attr"),
op2_output_types, op2_output_types,
nullptr); ir::OpInfo());
op2->Print(std::cout); op2->Print(std::cout);
// 3. Construct OP3: c = OP3(a, b); // 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
...@@ -61,7 +61,7 @@ TEST(value_test, value_test) { ...@@ -61,7 +61,7 @@ TEST(value_test, value_test) {
ir::Operation::Create(op3_inputs, ir::Operation::Create(op3_inputs,
CreateAttributeMap("op3_name", "op3_attr"), CreateAttributeMap("op3_name", "op3_attr"),
op3_output_types, op3_output_types,
nullptr); ir::OpInfo());
op3->Print(std::cout); op3->Print(std::cout);
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
...@@ -74,7 +74,7 @@ TEST(value_test, value_test) { ...@@ -74,7 +74,7 @@ TEST(value_test, value_test) {
ir::Operation::Create(op4_inputs, ir::Operation::Create(op4_inputs,
CreateAttributeMap("op4_name", "op4_attr"), CreateAttributeMap("op4_name", "op4_attr"),
op4_output_types, op4_output_types,
nullptr); ir::OpInfo());
op4->Print(std::cout); op4->Print(std::cout);
// Test 1: // Test 1:
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
TEST(ir_op_info_test, op_op_info_test) {
ir::IrContext* context = ir::IrContext::Instance();
ir::Program program(context);
ir::Block* block = program.block();
ir::Builder builder(context, block);
builder.Build<ir::ConstantOp>(ir::Int32_tAttribute::get(context, 5),
ir::Int32Type::get(context));
ir::Operation* op = block->back();
EXPECT_EQ(block->end(), ++ir::Block::iterator(*op));
auto& info_map = context->registered_op_info_map();
EXPECT_FALSE(info_map.empty());
void* info_1 = op->info().AsOpaquePointer();
auto info_2 = ir::OpInfo::RecoverFromOpaquePointer(info_1);
EXPECT_EQ(op->info(), info_2);
}
...@@ -112,7 +112,7 @@ TEST(pass_manager_test, pass_manager) { ...@@ -112,7 +112,7 @@ TEST(pass_manager_test, pass_manager) {
ir::Block *block = program.block(); ir::Block *block = program.block();
block->push_back(op1); block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion()); EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent());
EXPECT_EQ(program.module_op(), block->GetParentOp()); EXPECT_EQ(program.module_op(), block->GetParentOp());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册