From 88e43625e058dc277dc7e7f02e9754f48e1a2213 Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Tue, 30 May 2023 13:52:50 +0800 Subject: [PATCH] [IR] add region data structure. (#54185) --- paddle/fluid/translator/op_translator.cc | 10 +-- paddle/fluid/translator/program_translator.cc | 2 +- paddle/ir/core/block.cc | 14 ++++ paddle/ir/core/block.h | 24 ++++--- paddle/ir/core/builder.cc | 10 +-- paddle/ir/core/builder.h | 6 +- paddle/ir/core/builtin_op.h | 1 + paddle/ir/core/operation.cc | 68 +++++++++++++------ paddle/ir/core/operation.h | 34 +++++++--- paddle/ir/core/operation_utils.cc | 16 ++--- paddle/ir/core/operation_utils.h | 36 ++++++---- paddle/ir/core/region.cc | 48 +++++++++++++ paddle/ir/core/region.h | 60 ++++++++++++++++ test/cpp/ir/core/ir_op_test.cc | 45 +++++++++++- test/cpp/ir/core/ir_program_test.cc | 16 ++--- test/cpp/ir/core/ir_value_test.cc | 8 +-- test/cpp/pass/pass_manager_test.cc | 2 +- 17 files changed, 308 insertions(+), 92 deletions(-) create mode 100644 paddle/ir/core/region.cc create mode 100644 paddle/ir/core/region.h diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 222edc45e68..c8ce1ffdcab 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -110,8 +110,8 @@ inline ir::Operation* InsertSliceOperationForTarget( defining_info.value.type().dyn_cast(); ir::Operation* operation = ir::Operation::create({defining_info.value}, - {src_vec_type[defining_info.idx_in_vector]}, op_attribute_map, + {src_vec_type[defining_info.idx_in_vector]}, op_info); program->InsertOp(operation); ir::OpResult target_op_result = operation->GetResultByIndex(0); @@ -136,7 +136,7 @@ inline ir::Operation* InsertCombineOperationForTarget( } ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); ir::Operation* operation = - ir::Operation::create(src_values, {target_vec_type}, {}, op_info); + ir::Operation::create(src_values, {}, {target_vec_type}, op_info); program->InsertOp(operation); return operation; } @@ -281,7 +281,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, {}, op_output_types, op_info); program->InsertOp(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); @@ -299,7 +299,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, {}, op_output_types, op_info); program->InsertOp(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); @@ -315,7 +315,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types = {}; auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, {}, op_output_types, op_info); program->InsertOp(operation); return operation; diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index f3af3a3db54..85a09b2da03 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -79,7 +79,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( }; ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( - {}, {translated_var_type}, op_attribute_map, op_info); + {}, op_attribute_map, {translated_var_type}, op_info); program->InsertOp(operation); param_map[var->Name()] = VariableDefiningInfo(operation->GetResultByIndex(0)); diff --git a/paddle/ir/core/block.cc b/paddle/ir/core/block.cc index 62263a8f1a3..0a11fb81f5f 100644 --- a/paddle/ir/core/block.cc +++ b/paddle/ir/core/block.cc @@ -16,6 +16,20 @@ namespace ir { Block::~Block() { clear(); } +void Block::push_back(Operation *op) { + op->set_parent(this); + ops_.push_back(op); +} + +void Block::push_front(Operation *op) { + op->set_parent(this); + ops_.push_front(op); +} + +Block::iterator Block::insert(const_iterator iterator, Operation *op) { + op->set_parent(this); + return ops_.insert(iterator, op); +} void Block::clear() { while (!empty()) { diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index 3bc4ef82478..09b4b584b66 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -14,18 +14,23 @@ #pragma once +#include #include #include "paddle/ir/core/operation.h" namespace ir { +class Region; + class Block { public: using iterator = std::list::iterator; using reverse_iterator = std::list::reverse_iterator; + using const_iterator = std::list::const_iterator; Block() = default; ~Block(); + Region *parent() const { return parent_; } bool empty() const { return ops_.empty(); } size_t size() const { return ops_.size(); } @@ -34,21 +39,22 @@ class Block { reverse_iterator rbegin() { return ops_.rbegin(); } reverse_iterator rend() { return ops_.rend(); } - Operation *back() { return ops_.back(); } - Operation *front() { return ops_.front(); } - void push_back(Operation *op) { ops_.push_back(op); } - void push_front(Operation *op) { ops_.push_front(op); } - std::list::iterator insert( - std::list::const_iterator iterator, Operation *op) { - return ops_.insert(iterator, op); - } + Operation *back() const { return ops_.back(); } + Operation *front() const { return ops_.front(); } + void push_back(Operation *op); + void push_front(Operation *op); + iterator insert(const_iterator iterator, Operation *op); void clear(); private: Block(Block &) = delete; - void operator=(Block &) = delete; + Block &operator=(const Block &) = delete; + + friend class Region; + void set_parent(Region *parent) { parent_ = parent; } private: + Region *parent_; // not owned std::list ops_; // owned }; } // namespace ir diff --git a/paddle/ir/core/builder.cc b/paddle/ir/core/builder.cc index 107b8467747..13c16db6a6b 100644 --- a/paddle/ir/core/builder.cc +++ b/paddle/ir/core/builder.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/ir/core/builder.h" +#include "paddle/ir/core/region.h" namespace ir { Operation *Builder::insert(Operation *op) { @@ -25,17 +26,16 @@ Operation *Builder::insert(Operation *op) { } /// Create an operation given the fields represented as an OperationState. -Operation *Builder::create(const OperationArgument &argument) { - return insert(Operation::create(argument)); +Operation *Builder::create(OperationArgument &&argument) { + return insert(Operation::create(std::move(argument))); } /// Creates an operation with the given fields. Operation *Builder::create(const std::vector &inputs, - const std::vector &output_types, const AttributeMap &attribute, + const std::vector &output_types, ir::OpInfo op_info) { - OperationArgument argument(op_info, inputs, output_types, attribute); - return create(argument); + return create(OperationArgument(inputs, attribute, output_types, op_info)); } } // namespace ir diff --git a/paddle/ir/core/builder.h b/paddle/ir/core/builder.h index e1f3de2b726..7bd187961c3 100644 --- a/paddle/ir/core/builder.h +++ b/paddle/ir/core/builder.h @@ -47,12 +47,12 @@ class Builder { Operation *insert(Operation *op); /// Creates an operation given the fields represented as an OperationState. - Operation *create(const OperationArgument &argument); + Operation *create(OperationArgument &&argument); /// Creates an operation with the given fields. Operation *create(const std::vector &inputs, - const std::vector &output_types, const AttributeMap &attribute, + const std::vector &output_types, ir::OpInfo op_info); /// Create an operation of specific op type at the current insertion point. @@ -60,7 +60,7 @@ class Builder { OpTy create(Args &&...args) { OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OpTy::build(*this, argument, std::forward(args)...); - Operation *op = create(argument); + Operation *op = create(std::move(argument)); return op->dyn_cast(); } diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index ffe357ed9e9..3a7f77b00aa 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -17,6 +17,7 @@ #include "paddle/ir/core/op_base.h" namespace ir { + /// /// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute, /// StrAttribute}) diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 4c62ed811cc..4f9575c03d3 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -15,23 +15,31 @@ #include "paddle/ir/core/operation.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/program.h" +#include "paddle/ir/core/region.h" #include "paddle/ir/core/utils.h" namespace ir { -Operation *Operation::create(const OperationArgument &argument) { - return create(argument.inputs_, - argument.output_types_, - argument.attribute_, - argument.info_); +Operation *Operation::create(OperationArgument &&argument) { + Operation *op = create(argument.inputs, + argument.attribute, + argument.output_types, + argument.info, + argument.regions.size()); + + for (size_t index = 0; index < argument.regions.size(); ++index) { + op->GetRegion(index).TakeBody(std::move(*argument.regions[index])); + } + return op; } // Allocate the required memory based on the size and number of inputs, outputs, // and operators, and construct it in the order of: OpOutlineResult, // OpInlineResult, Operation, Operand. Operation *Operation::create(const std::vector &inputs, - const std::vector &output_types, const AttributeMap &attribute, - ir::OpInfo op_info) { + const std::vector &output_types, + ir::OpInfo op_info, + size_t num_regions) { // 0. Verify if (op_info) { op_info.verify(inputs, output_types, attribute); @@ -50,7 +58,9 @@ Operation *Operation::create(const std::vector &inputs, : sizeof(detail::OpInlineResultImpl) * num_results; size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands; size_t op_mem_size = sizeof(Operation); - size_t base_size = result_mem_size + op_mem_size + operand_mem_size; + size_t region_mem_size = num_regions * sizeof(Region); + size_t base_size = + result_mem_size + op_mem_size + operand_mem_size + region_mem_size; // 2. Malloc memory. char *base_ptr = reinterpret_cast(aligned_malloc(base_size, 8)); // 3.1. Construct OpResults. @@ -65,8 +75,8 @@ Operation *Operation::create(const std::vector &inputs, } } // 3.2. Construct Operation. - Operation *op = - new (base_ptr) Operation(num_results, num_operands, attribute, op_info); + Operation *op = new (base_ptr) + Operation(attribute, op_info, num_results, num_operands, num_regions); base_ptr += sizeof(Operation); // 3.3. Construct OpOperands. if ((reinterpret_cast(base_ptr) & 0x7) != 0) { @@ -76,13 +86,27 @@ Operation *Operation::create(const std::vector &inputs, new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); base_ptr += sizeof(detail::OpOperandImpl); } - + // 3.4. Construct Regions + if (num_regions > 0) { + op->regions_ = reinterpret_cast(base_ptr); + for (size_t idx = 0; idx < num_regions; idx++) { + new (base_ptr) Region(op); + base_ptr += sizeof(Region); + } + } return op; } // Call destructors for OpResults, Operation, and OpOperands in sequence, and // finally free memory. void Operation::destroy() { + // Deconstruct Regions. + if (num_regions_ > 0) { + for (size_t idx = 0; idx < num_regions_; idx++) { + regions_[idx].~Region(); + } + } + // 1. Get aligned_ptr by result_num. uint32_t max_inline_result_num = detail::OpResultImpl::GetMaxInlineResultIndex() + 1; @@ -136,15 +160,16 @@ void Operation::destroy() { IrContext *Operation::ir_context() const { return op_info_.ir_context(); } -Operation::Operation(uint32_t num_results, +Operation::Operation(const AttributeMap &attribute, + ir::OpInfo op_info, + uint32_t num_results, uint32_t num_operands, - const AttributeMap &attribute, - ir::OpInfo op_info) { - num_results_ = num_results; - num_operands_ = num_operands; - attribute_ = attribute; - op_info_ = op_info; -} + uint32_t num_regions) + : attribute_(attribute), + op_info_(op_info), + num_results_(num_results), + num_operands_(num_operands), + num_regions_(num_regions) {} ir::OpResult Operation::GetResultByIndex(uint32_t index) const { if (index >= num_results_) { @@ -198,4 +223,9 @@ std::string Operation::print() { std::string Operation::op_name() const { return op_info_.name(); } +Region &Operation::GetRegion(unsigned index) { + assert(index < num_regions_ && "invalid region index"); + return regions_[index]; +} + } // namespace ir diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index 7e5993ada74..d5804ee9e43 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -15,7 +15,6 @@ #pragma once #include -#include #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/type.h" @@ -24,6 +23,7 @@ namespace ir { class OpBase; class Program; +class Block; class alignas(8) Operation final { public: @@ -34,16 +34,19 @@ class alignas(8) Operation final { /// used in conjunction. /// static Operation *create(const std::vector &inputs, - const std::vector &output_types, const AttributeMap &attribute, - ir::OpInfo op_info); - static Operation *create(const OperationArgument &op_argument); + const std::vector &output_types, + ir::OpInfo op_info, + size_t num_regions = 0); + static Operation *create(OperationArgument &&op_argument); /// /// \brief Destroy the operation objects and free memory by create(). /// void destroy(); + Block *parent() const { return parent_; } + IrContext *ir_context() const; ir::OpResult GetResultByIndex(uint32_t index) const; @@ -60,6 +63,8 @@ class alignas(8) Operation final { uint32_t num_operands() const { return num_operands_; } + uint32_t num_regions() const { return num_regions_; } + std::string op_name() const; template @@ -83,11 +88,15 @@ class alignas(8) Operation final { parent_program_ = parent_program; } + /// Returns the region held by this operation at position 'index'. + Region &GetRegion(unsigned index); + private: - Operation(uint32_t num_results, + Operation(const AttributeMap &attribute, + ir::OpInfo op_info, + uint32_t num_results, uint32_t num_operands, - const AttributeMap &attribute, - ir::OpInfo op_info); + uint32_t num_regions); template struct CastUtil { @@ -96,6 +105,9 @@ class alignas(8) Operation final { } }; + friend class Block; + void set_parent(Block *parent) { parent_ = parent; } + template struct CastUtil< T, @@ -107,11 +119,13 @@ class alignas(8) Operation final { OpInfo op_info_; - uint32_t num_results_ = 0; - - uint32_t num_operands_ = 0; + const uint32_t num_results_ = 0; + const uint32_t num_operands_ = 0; + const uint32_t num_regions_ = 0; + Region *regions_{nullptr}; Program *parent_program_{nullptr}; + Block *parent_{nullptr}; }; } // namespace ir diff --git a/paddle/ir/core/operation_utils.cc b/paddle/ir/core/operation_utils.cc index e61b368bafa..d68c037000a 100644 --- a/paddle/ir/core/operation_utils.cc +++ b/paddle/ir/core/operation_utils.cc @@ -13,19 +13,11 @@ // limitations under the License. #include "paddle/ir/core/operation_utils.h" +#include "paddle/ir/core/region.h" namespace ir { -OperationArgument::OperationArgument(IrContext* ir_context, std::string name) { - info_ = ir_context->GetRegisteredOpInfo(name); +OperationArgument::OperationArgument(IrContext* ir_context, + const std::string& name) { + info = ir_context->GetRegisteredOpInfo(name); } - -OperationArgument::OperationArgument(OpInfo info, - const std::vector& operands, - const std::vector& types, - const AttributeMap& named_attr) - : info_(info), - inputs_(operands), - output_types_(types), - attribute_(named_attr) {} - } // namespace ir diff --git a/paddle/ir/core/operation_utils.h b/paddle/ir/core/operation_utils.h index 8deeec781e8..fb43e8a1ca0 100644 --- a/paddle/ir/core/operation_utils.h +++ b/paddle/ir/core/operation_utils.h @@ -16,6 +16,7 @@ #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/op_info.h" +#include "paddle/ir/core/region.h" #include "paddle/ir/core/type.h" #include "paddle/ir/core/value_impl.h" @@ -30,18 +31,25 @@ using AttributeMap = std::unordered_map; // This represents an operation arguments in an combined form, suitable for use // with the builder APIs. struct OperationArgument { - OpInfo info_; - std::vector inputs_; - std::vector output_types_; - AttributeMap attribute_; + std::vector inputs; + AttributeMap attribute; + std::vector output_types; + OpInfo info; + std::vector> regions; public: - OperationArgument(IrContext* ir_context, std::string name); - explicit OperationArgument(OpInfo info) : info_(info) {} - OperationArgument(OpInfo info, - const std::vector& operands, + OperationArgument(IrContext* ir_context, const std::string& name); + explicit OperationArgument(OpInfo info) : info(info) {} + OperationArgument(const std::vector& operands, + const AttributeMap& named_attr, const std::vector& types, - const AttributeMap& named_attr = {}); + OpInfo info, + std::vector>&& regions = {}) + : inputs(operands), + attribute(named_attr), + output_types(types), + info(info), + regions(std::move(regions)) {} template void addOperands(InputIt first, InputIt last); @@ -51,31 +59,31 @@ struct OperationArgument { /// Add an attribute with the specified name. void addAttribute(const std::string& name, Attribute attr) { - attribute_[name] = attr; + this->attribute[name] = attr; } /// Add an array of named attributes. template void addAttributes(InputIt first, InputIt last); /// Get the context held by this operation state. - IrContext* getContext() const { return info_.ir_context(); } + IrContext* getContext() const { return info.ir_context(); } }; template void OperationArgument::addOperands(InputIt first, InputIt last) { while (first != last) { - inputs_.emplace_back(*first++); + inputs.emplace_back(*first++); } } template void OperationArgument::addTypes(InputIt first, InputIt last) { while (first != last) { - output_types_.emplace_back(*first++); + output_types.emplace_back(*first++); } } template void OperationArgument::addAttributes(InputIt first, InputIt last) { while (first != last) { - attribute_[first->first] = first->second; + attribute[first->first] = first->second; ++first; } } diff --git a/paddle/ir/core/region.cc b/paddle/ir/core/region.cc new file mode 100644 index 00000000000..905f497c0bc --- /dev/null +++ b/paddle/ir/core/region.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/core/region.h" +#include "paddle/ir/core/block.h" + +namespace ir { +Region::~Region() { clear(); } + +void Region::push_back(Block *block) { + block->set_parent(this); + blocks_.push_back(block); +} +void Region::push_front(Block *block) { + block->set_parent(this); + blocks_.push_front(block); +} + +Region::iterator Region::insert(const_iterator position, Block *block) { + block->set_parent(this); + return blocks_.insert(position, block); +} +void Region::TakeBody(Region &&other) { + clear(); + blocks_.swap(other.blocks_); + for (auto &block : blocks_) { + block->set_parent(this); + } +} + +void Region::clear() { + while (!empty()) { + delete blocks_.back(); + blocks_.pop_back(); + } +} +} // namespace ir diff --git a/paddle/ir/core/region.h b/paddle/ir/core/region.h new file mode 100644 index 00000000000..da84d970f1f --- /dev/null +++ b/paddle/ir/core/region.h @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace ir { + +class Block; +class Operation; + +class Region { + public: + using iterator = std::list::iterator; + using reverse_iterator = std::list::reverse_iterator; + using const_iterator = std::list::const_iterator; + ~Region(); + Region() = default; + + bool empty() const { return blocks_.empty(); } + size_t size() const { return blocks_.size(); } + + iterator begin() { return blocks_.begin(); } + iterator end() { return blocks_.end(); } + reverse_iterator rbegin() { return blocks_.rbegin(); } + reverse_iterator rend() { return blocks_.rend(); } + + Block *back() const { return blocks_.back(); } + Block *front() const { return blocks_.front(); } + void push_back(Block *block); + void push_front(Block *block); + iterator insert(const_iterator position, Block *block); + void clear(); + + void TakeBody(Region &&other); + + private: + Region(Region &) = delete; + Region &operator=(const Region &) = delete; + friend class Operation; + explicit Region(Operation *op) : parent_(op) {} + + private: + Operation *parent_{nullptr}; // not owned + std::list blocks_; // owned +}; +} // namespace ir diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 831a3f39b21..af83acd5bfa 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -14,12 +14,14 @@ #include +#include "paddle/ir/core/block.h" #include "paddle/ir/core/builder.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/op_base.h" +#include "paddle/ir/core/region.h" /// \brief Define built-in Trait, derived from OpTraitBase. class ReadOnlyTrait : public ir::OpTraitBase { @@ -175,9 +177,9 @@ TEST(op_test, op_test) { std::vector op_output_types = {ir::Float32Type::get(ctx)}; ir::Operation *op2 = ir::Operation::create(op_inputs, - op_output_types, CreateAttributeMap({"op2_attr1", "op2_attr2"}, {"op2_attr1", "op2_attr2"}), + op_output_types, op2_info); ReadOnlyTrait trait = op2->dyn_cast(); @@ -188,3 +190,44 @@ TEST(op_test, op_test) { EXPECT_EQ(Op2.operation(), op2); op2->destroy(); } + +TEST(op_test, region_test) { + // (1) Register Dialect, Operation1, Operation2 into IrContext. + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Dialect *test_dialect = ctx->GetOrRegisterDialect(); + EXPECT_EQ(test_dialect != nullptr, true); + + // (2) Get registered operations. + ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(Operation1::name()); + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(Operation2::name()); + + ir::Operation *op1 = + ir::Operation::create({}, + CreateAttributeMap({"op1_attr1", "op1_attr2"}, + {"op1_attr1", "op1_attr2"}), + {ir::Float32Type::get(ctx)}, + op1_info); + ir::Operation *op1_2 = + ir::Operation::create({}, + CreateAttributeMap({"op1_attr1", "op1_attr2"}, + {"op1_attr1", "op1_attr2"}), + {ir::Float32Type::get(ctx)}, + op1_info); + + ir::OperationArgument argument(op2_info); + argument.attribute = CreateAttributeMap({"op2_attr1", "op2_attr2"}, + {"op2_attr1", "op2_attr2"}); + argument.output_types = {ir::Float32Type::get(ctx)}; + argument.regions.emplace_back(std::make_unique()); + ir::Region *region = argument.regions.back().get(); + EXPECT_EQ(region->empty(), true); + + region->push_back(new ir::Block()); + region->push_front(new ir::Block()); + region->insert(region->begin(), new ir::Block()); + ir::Block *block = region->front(); + block->push_front(op1); + block->insert(block->begin(), op1_2); + ir::Operation *op2 = ir::Operation::create(std::move(argument)); + op2->destroy(); +} diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 34a24a0475d..c16858ba1a2 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -91,7 +91,7 @@ TEST(program_test, program) { std::unordered_map op1_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; ir::Operation *op1 = - ir::Operation::create({}, {dense_tensor_dtype}, op1_attribute, op1_info); + ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info); program.InsertOp(op1); @@ -123,7 +123,7 @@ TEST(program_test, program) { std::unordered_map op2_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; ir::Operation *op2 = - ir::Operation::create({}, {dense_tensor_dtype}, op2_attribute, op2_info); + ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info); program.InsertOp(op2); EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(), @@ -153,8 +153,8 @@ TEST(program_test, program) { std::unordered_map op3_attribute; ir::Operation *op3 = ir::Operation::create( {op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, - {dense_tensor_dtype}, op3_attribute, + {dense_tensor_dtype}, op3_info); program.InsertOp(op3); @@ -184,7 +184,7 @@ TEST(program_test, program) { std::unordered_map op4_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "c")}}; ir::Operation *op4 = ir::Operation::create( - {op3->GetResultByIndex(0)}, {}, op4_attribute, op4_info); + {op3->GetResultByIndex(0)}, op4_attribute, {}, op4_info); program.InsertOp(op4); EXPECT_EQ(op4->GetOperandByIndex(0).impl()->source().type().dialect().id(), @@ -230,7 +230,7 @@ TEST(program_test, slice_combine_test) { std::unordered_map op1_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; ir::Operation *op1 = - ir::Operation::create({}, {fp32_dtype}, op1_attribute, op1_info); + ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info); program.InsertOp(op1); // (5) Def b = GetParameterOp("b") @@ -239,7 +239,7 @@ TEST(program_test, slice_combine_test) { std::unordered_map op2_attribute{ {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; ir::Operation *op2 = - ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info); + ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info); program.InsertOp(op2); // (6) Def combine_op = CombineOp("a", "b") @@ -249,8 +249,8 @@ TEST(program_test, slice_combine_test) { ir::VectorType::get(ctx, std::vector({fp32_dtype, fp32_dtype})); ir::Operation *combine_op = ir::Operation::create( {op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, - {output_type}, {}, + {output_type}, combine_op_info); program.InsertOp(combine_op); @@ -260,8 +260,8 @@ TEST(program_test, slice_combine_test) { ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); ir::Operation *slice_op = ir::Operation::create({combine_op->GetResultByIndex(0)}, - {fp32_dtype}, {{"index", index_attr}}, + {fp32_dtype}, slice_op_info); program.InsertOp(slice_op); diff --git a/test/cpp/ir/core/ir_value_test.cc b/test/cpp/ir/core/ir_value_test.cc index 9a7fbc13810..28e340e52a5 100644 --- a/test/cpp/ir/core/ir_value_test.cc +++ b/test/cpp/ir/core/ir_value_test.cc @@ -40,8 +40,8 @@ TEST(value_test, value_test) { std::vector op1_output_types = {ir::Float32Type::get(ctx)}; ir::Operation *op1 = ir::Operation::create(op1_inputs, - op1_output_types, CreateAttributeMap("op1_name", "op1_attr"), + op1_output_types, nullptr); VLOG(0) << op1->print(); // 2. Construct OP2: b = OP2(); @@ -49,8 +49,8 @@ TEST(value_test, value_test) { std::vector op2_output_types = {ir::Float32Type::get(ctx)}; ir::Operation *op2 = ir::Operation::create(op2_inputs, - op2_output_types, CreateAttributeMap("op2_name", "op2_attr"), + op2_output_types, nullptr); VLOG(0) << op2->print() << std::endl; // 3. Construct OP3: c = OP3(a, b); @@ -59,8 +59,8 @@ TEST(value_test, value_test) { std::vector op3_output_types = {ir::Float32Type::get(ctx)}; ir::Operation *op3 = ir::Operation::create(op3_inputs, - op3_output_types, CreateAttributeMap("op3_name", "op3_attr"), + op3_output_types, nullptr); VLOG(0) << op3->print() << std::endl; // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); @@ -72,8 +72,8 @@ TEST(value_test, value_test) { } ir::Operation *op4 = ir::Operation::create(op4_inputs, - op4_output_types, CreateAttributeMap("op4_name", "op4_attr"), + op4_output_types, nullptr); VLOG(0) << op4->print() << std::endl; diff --git a/test/cpp/pass/pass_manager_test.cc b/test/cpp/pass/pass_manager_test.cc index da46ffefa1d..058f3600b03 100644 --- a/test/cpp/pass/pass_manager_test.cc +++ b/test/cpp/pass/pass_manager_test.cc @@ -96,8 +96,8 @@ TEST(pass_manager_test, pass_manager_test) { std::vector op_output_types = {ir::Float32Type::get(ctx)}; ir::Operation *op = ir::Operation::create(op_inputs, - op_output_types, CreateAttributeMap(ctx, "op1_attr1", "op1_attr1"), + op_output_types, op_info); // (4) Test pass manager for op. -- GitLab