未验证 提交 8dfcf03e 编写于 作者: W winter-wang 提交者: GitHub

[IR] add positon member in new ir operation. (#54483)

上级 b3232936
......@@ -13,22 +13,23 @@
// limitations under the License.
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/region.h"
namespace ir {
Block::~Block() { clear(); }
void Block::push_back(Operation *op) {
op->set_parent(this);
ops_.push_back(op);
}
void Block::push_back(Operation *op) { insert(ops_.end(), op); }
void Block::push_front(Operation *op) { insert(ops_.begin(), op); }
void Block::push_front(Operation *op) {
op->set_parent(this);
ops_.push_front(op);
Operation *Block::GetParentOp() const {
return parent_ ? parent_->GetParent() : nullptr;
}
Block::iterator Block::insert(const_iterator iterator, Operation *op) {
op->set_parent(this);
return ops_.insert(iterator, op);
Block::iterator iter = ops_.insert(iterator, op);
op->SetParent(this, iter);
return iter;
}
void Block::clear() {
......
......@@ -16,10 +16,10 @@
#include <cstddef>
#include <list>
#include "paddle/ir/core/operation.h"
namespace ir {
class Region;
class Operation;
class Block {
public:
......@@ -30,7 +30,9 @@ class Block {
Block() = default;
~Block();
Region *parent() const { return parent_; }
Region *GetParent() const { return parent_; }
Operation *GetParentOp() const;
bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); }
......@@ -46,18 +48,12 @@ class Block {
iterator insert(const_iterator iterator, Operation *op);
void clear();
Region *GetParentRegion() const { return parent_; }
Operation *GetParentOp() const {
return parent_ ? parent_->GetParentOp() : nullptr;
}
private:
Block(Block &) = delete;
Block &operator=(const Block &) = delete;
friend class Region;
void set_parent(Region *parent) { parent_ = parent; }
void SetParent(Region *parent) { parent_ = parent; }
private:
Region *parent_; // not owned
......
......@@ -26,10 +26,10 @@ namespace ir {
///
class Builder {
public:
explicit Builder(IrContext *context,
Block *block,
Block::iterator insert_point)
Builder(IrContext *context, Block *block, Block::iterator insert_point)
: context_(context), block_(block), insert_point_(insert_point) {}
Builder(IrContext *context, Block *block)
: Builder(context, block, block->end()) {}
static Builder AtBlockBegin(IrContext *context, Block *block) {
return Builder(context, block, block->begin());
......
......@@ -49,7 +49,7 @@ class IrContextImpl {
registed_dialect_.clear();
for (auto &op_map : registed_op_infos_) {
op_map.second->destroy();
OpInfoImpl::Destroy(op_map.second);
}
registed_op_infos_.clear();
}
......@@ -103,24 +103,25 @@ class IrContextImpl {
return registed_op_infos_.find(name) != registed_op_infos_.end();
}
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
void RegisterOpInfo(const std::string &name, OpInfo info) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
VLOG(4) << "Register an operation of: [Name=" << name
<< ", OpInfoImpl ptr=" << opinfo << "].";
registed_op_infos_.emplace(name, opinfo);
<< ", OpInfo ptr=" << info.AsOpaquePointer() << "].";
registed_op_infos_.emplace(name, info);
}
OpInfoImpl *GetOpInfo(const std::string &name) {
OpInfo GetOpInfo(const std::string &name) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
auto iter = registed_op_infos_.find(name);
if (iter != registed_op_infos_.end()) {
VLOG(4) << "Found a cached operation of: [name=" << name
<< ", OpInfoImpl ptr=" << iter->second << "].";
VLOG(4) << "Found a cached OpInfo of: [name=" << name
<< ", OpInfo: ptr=" << iter->second.AsOpaquePointer() << "].";
return iter->second;
}
LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
return nullptr;
return OpInfo();
}
const OpInfoMap &registered_op_info_map() { return registed_op_infos_; }
void RegisterDialect(std::string name, Dialect *dialect) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
......@@ -170,7 +171,7 @@ class IrContextImpl {
ir::SpinLock registed_dialect_lock_;
// The Op registered in the context.
std::unordered_map<std::string, OpInfoImpl *> registed_op_infos_;
OpInfoMap registed_op_infos_;
ir::SpinLock registed_op_infos_lock_;
ir::SpinLock destructor_lock_;
......@@ -282,43 +283,39 @@ void IrContext::RegisterOpInfo(Dialect *dialect,
if (impl().IsOpInfoRegistered(name)) {
LOG(WARNING) << name << " op already registered.";
} else {
OpInfoImpl *opinfo = OpInfoImpl::create(dialect,
op_id,
name,
std::move(interface_map),
trait_set,
attributes_num,
attributes_name,
verify);
impl().RegisterOpInfo(name, opinfo);
OpInfo info = OpInfoImpl::Create(dialect,
op_id,
name,
std::move(interface_map),
trait_set,
attributes_num,
attributes_name,
verify);
impl().RegisterOpInfo(name, info);
VLOG(4) << name << " op registered into IrContext. --->";
}
}
OpInfo IrContext::GetRegisteredOpInfo(const std::string &name) {
OpInfoImpl *rtn = impl().GetOpInfo(name);
return rtn ? rtn : nullptr;
return impl().GetOpInfo(name);
}
const OpInfoMap &IrContext::registered_op_info_map() {
return impl().registered_op_info_map();
}
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
auto &impl = ctx->impl();
AbstractType *abstract_type = impl.GetAbstractType(type_id);
if (abstract_type) {
return *abstract_type;
} else {
throw("Abstract type not found in IrContext.");
}
AbstractType *abstract_type = ctx->impl().GetAbstractType(type_id);
IR_ENFORCE(abstract_type, "Abstract type not found in IrContext.");
return *abstract_type;
}
const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id,
IrContext *ctx) {
auto &impl = ctx->impl();
AbstractAttribute *abstract_attribute = impl.GetAbstractAttribute(type_id);
if (abstract_attribute) {
return *abstract_attribute;
} else {
throw("Abstract attribute not found in IrContext.");
}
AbstractAttribute *abstract_attribute =
ctx->impl().GetAbstractAttribute(type_id);
IR_ENFORCE(abstract_attribute, "Abstract attribute not found in IrContext.");
return *abstract_attribute;
}
BFloat16Type BFloat16Type::get(IrContext *ctx) {
......
......@@ -31,6 +31,8 @@ class Type;
class OpResult;
class Attribute;
using OpInfoMap = std::unordered_map<std::string, OpInfo>;
///
/// \brief IrContext is a global parameterless class used to store and manage
/// Type, Attribute and other related data structures.
......@@ -116,6 +118,11 @@ class IrContext {
///
OpInfo GetRegisteredOpInfo(const std::string &name);
///
/// \brief Get registered operaiton infomation map.
///
const OpInfoMap &registered_op_info_map();
///
/// \brief Get the dialect of the DialectT class in the context, ff not found,
/// create and register to context.
......
......@@ -41,123 +41,6 @@ void OpInfo::Verify(const std::vector<OpResult> &inputs,
}
void *OpInfo::GetInterfaceImpl(TypeId interface_id) const {
return impl_ ? impl_->interface_impl(interface_id) : nullptr;
return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr;
}
ir::IrContext *OpInfoImpl::ir_context() const {
return dialect()->ir_context();
}
void *OpInfoImpl::interface_impl(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
size_t left = 0, right = num_interfaces_;
while (left < right) {
size_t mid = (left + right) / 2;
if ((p_first_interface + mid)->type_id() == interface_id) {
return (p_first_interface + mid)->model();
} else if ((p_first_interface + mid)->type_id() < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
bool OpInfoImpl::HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
const TypeId *p_first_trait =
reinterpret_cast<const TypeId *>(reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
bool OpInfoImpl::HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
InterfaceValue(interface_id));
}
return false;
}
OpInfoImpl *OpInfoImpl::create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[],
VerifyPtr verify) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num = interface_map.size();
size_t traits_num = trait_set.size();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(InterfaceValue) * interfaces_num +
sizeof(TypeId) * traits_num + sizeof(OpInfoImpl);
char *base_ptr = static_cast<char *>(::operator new(base_size));
VLOG(4) << "Malloc " << base_size << " Bytes at "
<< static_cast<void *>(base_ptr);
if (interfaces_num > 0) {
std::sort(interface_map.begin(), interface_map.end());
for (size_t index = 0; index < interfaces_num; ++index) {
new (base_ptr + index * sizeof(InterfaceValue))
InterfaceValue(std::move(interface_map[index]));
}
base_ptr += interfaces_num * sizeof(InterfaceValue);
}
if (traits_num > 0) {
auto p_first_trait = reinterpret_cast<TypeId *>(base_ptr);
memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr += traits_num * sizeof(TypeId);
}
// Construct opinfo_impl.
OpInfoImpl *p_opinfo_impl = reinterpret_cast<OpInfoImpl *>(base_ptr);
VLOG(4) << "Construct op_info_impl at " << p_opinfo_impl << " ......";
OpInfoImpl *op_info = new (p_opinfo_impl) OpInfoImpl(dialect,
op_id,
op_name,
interfaces_num,
traits_num,
attributes_num,
attributes_name,
verify
);
return op_info;
}
void OpInfoImpl::destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
char *base_ptr = reinterpret_cast<char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_;
if (num_interfaces_ > 0) {
InterfaceValue *p_interface_val =
reinterpret_cast<InterfaceValue *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
(p_interface_val + i)->~InterfaceValue();
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
} // namespace ir
......@@ -28,8 +28,6 @@ class OpInfo {
public:
constexpr OpInfo() = default;
OpInfo(const OpInfoImpl *impl) : impl_(impl) {} // NOLINT
OpInfo(const OpInfo &other) = default;
OpInfo &operator=(const OpInfo &other) = default;
......@@ -52,8 +50,6 @@ class OpInfo {
const std::vector<Type> &outputs,
const std::unordered_map<std::string, Attribute> &attributes);
const OpInfoImpl *impl() const;
template <typename Trait>
bool HasTrait() const {
return HasTrait(TypeId::get<Trait>());
......@@ -71,13 +67,20 @@ class OpInfo {
template <typename Interface>
typename Interface::Concept *GetInterfaceImpl() const;
void *AsOpaquePointer() const { return impl_; }
static OpInfo RecoverFromOpaquePointer(void *impl) {
return static_cast<OpInfoImpl *>(impl);
}
friend class OpInfoImpl;
friend struct std::hash<OpInfo>;
private:
OpInfo(OpInfoImpl *impl) : impl_(impl) {} // NOLINT
void *GetInterfaceImpl(TypeId interface_id) const;
private:
const OpInfoImpl *impl_{nullptr}; // not owned
OpInfoImpl *impl_{nullptr}; // not owned
};
template <typename Interface>
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/core/op_info_impl.h"
#include "paddle/ir/core/dialect.h"
namespace ir {
OpInfo OpInfoImpl::Create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[],
VerifyPtr verify) {
// (1) Malloc memory for interfaces, traits, opinfo_impl.
size_t interfaces_num = interface_map.size();
size_t traits_num = trait_set.size();
VLOG(4) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, "
<< traits_num << " traits, " << attributes_num << " attributes.";
size_t base_size = sizeof(InterfaceValue) * interfaces_num +
sizeof(TypeId) * traits_num + sizeof(OpInfoImpl);
char *base_ptr = static_cast<char *>(::operator new(base_size));
VLOG(4) << "Malloc " << base_size << " Bytes at "
<< static_cast<void *>(base_ptr);
if (interfaces_num > 0) {
std::sort(interface_map.begin(), interface_map.end());
for (size_t index = 0; index < interfaces_num; ++index) {
new (base_ptr + index * sizeof(InterfaceValue))
InterfaceValue(std::move(interface_map[index]));
}
base_ptr += interfaces_num * sizeof(InterfaceValue);
}
if (traits_num > 0) {
auto p_first_trait = reinterpret_cast<TypeId *>(base_ptr);
memcpy(base_ptr, trait_set.data(), sizeof(TypeId) * traits_num);
std::sort(p_first_trait, p_first_trait + traits_num);
base_ptr += traits_num * sizeof(TypeId);
}
// Construct OpInfoImpl.
VLOG(4) << "Construct OpInfoImpl at " << base_ptr << " ......";
OpInfo op_info = new (base_ptr) OpInfoImpl(dialect,
op_id,
op_name,
interfaces_num,
traits_num,
attributes_num,
attributes_name,
verify);
return op_info;
}
void OpInfoImpl::Destroy(OpInfo info) {
if (info.impl_) {
info.impl_->Destroy();
} else {
LOG(WARNING) << "A nullptr OpInfo is destoryed.";
}
}
ir::IrContext *OpInfoImpl::ir_context() const {
return dialect_ ? dialect_->ir_context() : nullptr;
}
bool OpInfoImpl::HasTrait(TypeId trait_id) const {
if (num_traits_ > 0) {
const TypeId *p_first_trait =
reinterpret_cast<const TypeId *>(reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_);
return std::binary_search(
p_first_trait, p_first_trait + num_traits_, trait_id);
}
return false;
}
bool OpInfoImpl::HasInterface(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
return std::binary_search(p_first_interface,
p_first_interface + num_interfaces_,
InterfaceValue(interface_id));
}
return false;
}
void *OpInfoImpl::GetInterfaceImpl(TypeId interface_id) const {
if (num_interfaces_ > 0) {
const InterfaceValue *p_first_interface =
reinterpret_cast<const InterfaceValue *>(
reinterpret_cast<const char *>(this) -
sizeof(TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_);
size_t left = 0, right = num_interfaces_;
while (left < right) {
size_t mid = (left + right) / 2;
if ((p_first_interface + mid)->type_id() == interface_id) {
return (p_first_interface + mid)->model();
} else if ((p_first_interface + mid)->type_id() < interface_id) {
left = mid + 1;
} else {
right = mid;
}
}
}
return nullptr;
}
void OpInfoImpl::Destroy() {
VLOG(4) << "Destroy op_info impl at " << this;
// (1) free interfaces
char *base_ptr = reinterpret_cast<char *>(this) -
sizeof(ir::TypeId) * num_traits_ -
sizeof(InterfaceValue) * num_interfaces_;
if (num_interfaces_ > 0) {
InterfaceValue *p_interface_val =
reinterpret_cast<InterfaceValue *>(base_ptr);
for (size_t i = 0; i < num_interfaces_; i++) {
(p_interface_val + i)->~InterfaceValue();
}
}
// (2) free memeory
VLOG(4) << "Free base_ptr " << base_ptr;
free(base_ptr);
}
} // namespace ir
......@@ -38,40 +38,39 @@ class OpInfoImpl {
/// \brief Construct and Deconstruct OpInfoImpl. The memory layout of
/// OpInfoImpl is: std::pair<TypeId, void *>... | TypeId... | OpInfoImpl
///
static OpInfoImpl *create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[],
VerifyPtr verify);
static OpInfo Create(Dialect *dialect,
TypeId op_id,
const char *op_name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char *attributes_name[],
VerifyPtr verify);
static void Destroy(OpInfo info);
void destroy();
TypeId id() const { return op_id_; }
ir::IrContext *ir_context() const;
Dialect *dialect() const { return dialect_; }
VerifyPtr verify() const { return verify_; }
IrContext *ir_context() const;
/// \brief Search methods for Trait or Interface.
bool HasTrait(TypeId trait_id) const;
bool HasInterface(TypeId interface_id) const;
ir::TypeId id() const { return op_id_; }
void *interface_impl(TypeId interface_id) const;
void *GetInterfaceImpl(TypeId interface_id) const;
const char *name() const { return op_name_; }
ir::Dialect *dialect() const { return dialect_; }
uint32_t AttributeNum() const { return num_attributes_; }
const char *GetAttributeByIndex(size_t idx) const {
return idx < num_attributes_ ? p_attributes_[idx] : nullptr;
}
VerifyPtr verify() const { return verify_; }
private:
OpInfoImpl(ir::Dialect *dialect,
TypeId op_id,
......@@ -89,9 +88,10 @@ class OpInfoImpl {
num_attributes_(num_attributes),
p_attributes_(p_attributes),
verify_(verify) {}
void Destroy();
/// The dialect of this Op belong to.
ir::Dialect *dialect_;
Dialect *dialect_;
/// The TypeId of this Op.
TypeId op_id_;
......
......@@ -213,7 +213,7 @@ std::string Operation::name() const {
}
Region *Operation::GetParentRegion() const {
return parent_ ? parent_->GetParentRegion() : nullptr;
return parent_ ? parent_->GetParent() : nullptr;
}
Operation *Operation::GetParentOp() const {
......
......@@ -15,6 +15,7 @@
#pragma once
#include <ostream>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/type.h"
......@@ -22,7 +23,6 @@
namespace ir {
class OpBase;
class Program;
class Block;
class OpOperand;
class OpResult;
......@@ -85,7 +85,7 @@ class alignas(8) Operation final {
return info_.HasInterface<Interface>();
}
Block *GetParentBlock() const { return parent_; }
Block *GetParent() const { return parent_; }
Region *GetParentRegion() const;
......@@ -96,6 +96,8 @@ class alignas(8) Operation final {
/// Returns the region held by this operation at position 'index'.
Region &GetRegion(unsigned index);
operator Block::iterator() { return position_; }
private:
Operation(const AttributeMap &attribute,
ir::OpInfo op_info,
......@@ -111,7 +113,10 @@ class alignas(8) Operation final {
};
friend class Block;
void set_parent(Block *parent) { parent_ = parent; }
void SetParent(Block *parent, const Block::iterator &position) {
parent_ = parent;
position_ = position;
}
template <typename T>
struct CastUtil<
......@@ -130,6 +135,7 @@ class alignas(8) Operation final {
Region *regions_{nullptr};
Block *parent_{nullptr};
Block::iterator position_;
};
} // namespace ir
......@@ -19,26 +19,26 @@ namespace ir {
Region::~Region() { clear(); }
void Region::push_back(Block *block) {
block->set_parent(this);
block->SetParent(this);
blocks_.push_back(block);
}
void Region::emplace_back() { push_back(new Block); }
void Region::push_front(Block *block) {
block->set_parent(this);
block->SetParent(this);
blocks_.push_front(block);
}
Region::iterator Region::insert(const_iterator position, Block *block) {
block->set_parent(this);
block->SetParent(this);
return blocks_.insert(position, block);
}
void Region::TakeBody(Region &&other) {
clear();
blocks_.swap(other.blocks_);
for (auto &block : blocks_) {
block->set_parent(this);
block->SetParent(this);
}
}
......
......@@ -48,7 +48,7 @@ class Region {
void TakeBody(Region &&other);
Operation *GetParentOp() const { return parent_; }
Operation *GetParent() const { return parent_; }
private:
Region(Region &) = delete;
......
......@@ -71,3 +71,5 @@ cc_test_old(
gtest
new_ir
pd_dialect)
cc_test_old(ir_op_info_test SRCS op_info_test.cc DEPS gtest new_ir)
......@@ -163,10 +163,10 @@ TEST(op_test, op_test) {
// (2) Get registered operations.
std::string op1_name = Operation1::name();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
EXPECT_EQ(op1_info != nullptr, true);
EXPECT_TRUE(op1_info);
std::string op2_name = Operation2::name();
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
EXPECT_EQ(op2_info != nullptr, true);
EXPECT_TRUE(op2_info);
EXPECT_EQ(op1_info.HasTrait<ReadOnlyTrait>(), false);
EXPECT_EQ(op1_info.HasInterface<InferShapeInterface>(), false);
EXPECT_EQ(op2_info.HasTrait<ReadOnlyTrait>(), true);
......
......@@ -98,7 +98,7 @@ TEST(program_test, program) {
ir::Block *block = program.block();
block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion());
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent());
EXPECT_EQ(program.module_op(), block->GetParentOp());
......@@ -299,7 +299,7 @@ TEST(program_test, builder) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
ir::Builder builder = ir::Builder(ctx, program.block());
paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
......
......@@ -42,7 +42,7 @@ TEST(value_test, value_test) {
ir::Operation::Create(op1_inputs,
CreateAttributeMap("op1_name", "op1_attr"),
op1_output_types,
nullptr);
ir::OpInfo());
op1->Print(std::cout);
// 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {};
......@@ -51,7 +51,7 @@ TEST(value_test, value_test) {
ir::Operation::Create(op2_inputs,
CreateAttributeMap("op2_name", "op2_attr"),
op2_output_types,
nullptr);
ir::OpInfo());
op2->Print(std::cout);
// 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
......@@ -61,7 +61,7 @@ TEST(value_test, value_test) {
ir::Operation::Create(op3_inputs,
CreateAttributeMap("op3_name", "op3_attr"),
op3_output_types,
nullptr);
ir::OpInfo());
op3->Print(std::cout);
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
......@@ -74,7 +74,7 @@ TEST(value_test, value_test) {
ir::Operation::Create(op4_inputs,
CreateAttributeMap("op4_name", "op4_attr"),
op4_output_types,
nullptr);
ir::OpInfo());
op4->Print(std::cout);
// Test 1:
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
TEST(ir_op_info_test, op_op_info_test) {
ir::IrContext* context = ir::IrContext::Instance();
ir::Program program(context);
ir::Block* block = program.block();
ir::Builder builder(context, block);
builder.Build<ir::ConstantOp>(ir::Int32_tAttribute::get(context, 5),
ir::Int32Type::get(context));
ir::Operation* op = block->back();
EXPECT_EQ(block->end(), ++ir::Block::iterator(*op));
auto& info_map = context->registered_op_info_map();
EXPECT_FALSE(info_map.empty());
void* info_1 = op->info().AsOpaquePointer();
auto info_2 = ir::OpInfo::RecoverFromOpaquePointer(info_1);
EXPECT_EQ(op->info(), info_2);
}
......@@ -112,7 +112,7 @@ TEST(pass_manager_test, pass_manager) {
ir::Block *block = program.block();
block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion());
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParent());
EXPECT_EQ(program.module_op(), block->GetParentOp());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册