未验证 提交 ea89e968 编写于 作者: W winter-wang 提交者: GitHub

[IR] format the use of new ir api (#54356)

上级 2e0c8678
...@@ -647,10 +647,10 @@ def GenBuildInputArgsStr( ...@@ -647,10 +647,10 @@ def GenBuildInputArgsStr(
for_func_define=True, for_func_define=True,
): ):
''' '''
Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={} Example: ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
''' '''
build_args_str = "ir::Builder &builder, ir::OperationArgument &argument"
# add inputs # add inputs
build_args_str = "ir::OperationArgument &argument"
if len(op_input_name_list) > 0: if len(op_input_name_list) > 0:
for input_name in op_input_name_list: for input_name in op_input_name_list:
build_args_str += ", ir::OpResult " + input_name + "_" build_args_str += ", ir::OpResult " + input_name + "_"
......
...@@ -126,7 +126,7 @@ inline ir::Operation* InsertSliceOperationForTarget( ...@@ -126,7 +126,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
ir::VectorType src_vec_type = ir::VectorType src_vec_type =
defining_info.value.type().dyn_cast<ir::VectorType>(); defining_info.value.type().dyn_cast<ir::VectorType>();
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create({defining_info.value}, ir::Operation::Create({defining_info.value},
op_attribute_map, op_attribute_map,
{src_vec_type[defining_info.idx_in_vector]}, {src_vec_type[defining_info.idx_in_vector]},
op_info); op_info);
...@@ -153,7 +153,7 @@ inline ir::Operation* InsertCombineOperationForTarget( ...@@ -153,7 +153,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
} }
ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(src_values, {}, {target_vec_type}, op_info); ir::Operation::Create(src_values, {}, {target_vec_type}, op_info);
program->block()->push_back(operation); program->block()->push_back(operation);
return operation; return operation;
} }
...@@ -165,7 +165,7 @@ inline ir::Operation* InsertConstantOperationForOptionalArg( ...@@ -165,7 +165,7 @@ inline ir::Operation* InsertConstantOperationForOptionalArg(
ir::Type null_type = ir::Type(nullptr); ir::Type null_type = ir::Type(nullptr);
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create({}, {}, {null_type}, op_info); ir::Operation::Create({}, {}, {null_type}, op_info);
program->block()->push_back(operation); program->block()->push_back(operation);
return operation; return operation;
} }
...@@ -401,7 +401,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, ...@@ -401,7 +401,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end.";
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); ir::Operation::Create(op_inputs, attribute_map, op_output_types, op_info);
VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end."; VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end.";
program->block()->push_back(operation); program->block()->push_back(operation);
...@@ -436,7 +436,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, ...@@ -436,7 +436,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
}; };
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); ir::Operation::Create(op_inputs, attribute_map, op_output_types, op_info);
program->block()->push_back(operation); program->block()->push_back(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
...@@ -466,7 +466,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, ...@@ -466,7 +466,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
}; };
ir::Operation* operation = ir::Operation* operation =
ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); ir::Operation::Create(op_inputs, attribute_map, op_output_types, op_info);
program->block()->push_back(operation); program->block()->push_back(operation);
return operation; return operation;
......
...@@ -79,7 +79,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( ...@@ -79,7 +79,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, {"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
}; };
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::create( ir::Operation* operation = ir::Operation::Create(
{}, op_attribute_map, {translated_var_type}, op_info); {}, op_attribute_map, {translated_var_type}, op_info);
program->block()->push_back(operation); program->block()->push_back(operation);
param_map[var->Name()] = param_map[var->Name()] =
......
...@@ -33,7 +33,7 @@ Block::iterator Block::insert(const_iterator iterator, Operation *op) { ...@@ -33,7 +33,7 @@ Block::iterator Block::insert(const_iterator iterator, Operation *op) {
void Block::clear() { void Block::clear() {
while (!empty()) { while (!empty()) {
ops_.back()->destroy(); ops_.back()->Destroy();
ops_.pop_back(); ops_.pop_back();
} }
} }
......
...@@ -17,7 +17,20 @@ ...@@ -17,7 +17,20 @@
#include "paddle/ir/core/value.h" #include "paddle/ir/core/value.h"
namespace ir { namespace ir {
Operation *Builder::insert(Operation *op) { /// Create an operation given the fields represented as an OperationState.
Operation *Builder::Build(OperationArgument &&argument) {
return Insert(Operation::Create(std::move(argument)));
}
/// Creates an operation with the given fields.
Operation *Builder::Build(const std::vector<OpResult> &inputs,
const AttributeMap &attribute,
const std::vector<Type> &output_types,
OpInfo op_info) {
return Build(OperationArgument(inputs, attribute, output_types, op_info));
}
Operation *Builder::Insert(Operation *op) {
if (block_) { if (block_) {
block_->insert(insert_point_, op); block_->insert(insert_point_, op);
} else { } else {
...@@ -26,17 +39,4 @@ Operation *Builder::insert(Operation *op) { ...@@ -26,17 +39,4 @@ Operation *Builder::insert(Operation *op) {
return op; return op;
} }
/// Create an operation given the fields represented as an OperationState.
Operation *Builder::create(OperationArgument &&argument) {
return insert(Operation::create(std::move(argument)));
}
/// Creates an operation with the given fields.
Operation *Builder::create(const std::vector<OpResult> &inputs,
const AttributeMap &attribute,
const std::vector<Type> &output_types,
OpInfo op_info) {
return create(OperationArgument(inputs, attribute, output_types, op_info));
}
} // namespace ir } // namespace ir
...@@ -43,27 +43,27 @@ class Builder { ...@@ -43,27 +43,27 @@ class Builder {
Block *block() const { return block_; } Block *block() const { return block_; }
Operation *insert(Operation *op);
/// Creates an operation given the fields represented as an OperationState. /// Creates an operation given the fields represented as an OperationState.
Operation *create(OperationArgument &&argument); Operation *Build(OperationArgument &&argument);
/// Creates an operation with the given fields. /// Creates an operation with the given fields.
Operation *create(const std::vector<ir::OpResult> &inputs, Operation *Build(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attribute, const AttributeMap &attribute,
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
ir::OpInfo op_info); ir::OpInfo op_info);
/// Create an operation of specific op type at the current insertion point. /// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args> template <typename OpTy, typename... Args>
OpTy create(Args &&...args) { OpTy Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(*this, argument, std::forward<Args>(args)...); OpTy::Build(argument, std::forward<Args>(args)...);
Operation *op = create(std::move(argument)); Operation *op = Build(std::move(argument));
return op->dyn_cast<OpTy>(); return op->dyn_cast<OpTy>();
} }
private: private:
Operation *Insert(Operation *op);
IrContext *context_; IrContext *context_;
Block *block_ = nullptr; Block *block_ = nullptr;
// The insertion point within the list that this builder is inserting before. // The insertion point within the list that this builder is inserting before.
......
...@@ -37,17 +37,17 @@ Block *ModuleOp::block() { ...@@ -37,17 +37,17 @@ Block *ModuleOp::block() {
return operation()->GetRegion(0).front(); return operation()->GetRegion(0).front();
} }
ModuleOp ModuleOp::create(IrContext *context, Program *pointer) { ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
ir::OpInfo info = context->GetRegisteredOpInfo(name()); ir::OpInfo info = context->GetRegisteredOpInfo(name());
OperationArgument argument(info); OperationArgument argument(info);
argument.AddRegion()->emplace_back(); argument.AddRegion()->emplace_back();
argument.AddAttribute("program", PointerAttribute::get(context, pointer)); argument.AddAttribute("program", PointerAttribute::get(context, pointer));
return ModuleOp(Operation::create(std::move(argument))); return ModuleOp(Operation::Create(std::move(argument)));
} }
void ModuleOp::destroy() { void ModuleOp::Destroy() {
if (operation()) { if (operation()) {
operation()->destroy(); operation()->Destroy();
*this = ModuleOp(nullptr); *this = ModuleOp(nullptr);
} }
} }
...@@ -216,8 +216,7 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -216,8 +216,7 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
const char *ConstantOp::attributes_name[attributes_num] = {"value"}; const char *ConstantOp::attributes_name[attributes_num] = {"value"};
void ConstantOp::Build(Builder &builder, void ConstantOp::Build(OperationArgument &argument,
OperationArgument &argument,
Attribute value, Attribute value,
Type output_type) { Type output_type) {
argument.AddAttribute("value", value); argument.AddAttribute("value", value);
......
...@@ -40,8 +40,8 @@ class ModuleOp : public ir::Op<ModuleOp> { ...@@ -40,8 +40,8 @@ class ModuleOp : public ir::Op<ModuleOp> {
// //
// As the top operation, ModuleOp only support create&destroye through // As the top operation, ModuleOp only support create&destroye through
// below interface: "create"&"destroy". // below interface: "create"&"destroy".
static ModuleOp create(IrContext *context, Program *pointer); static ModuleOp Create(IrContext *context, Program *pointer);
void destroy(); void Destroy();
}; };
/// ///
...@@ -125,8 +125,7 @@ class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> { ...@@ -125,8 +125,7 @@ class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Build(Builder &builder, // NOLINT static void Build(OperationArgument &argument, // NOLINT
OperationArgument &argument, // NOLINT
Attribute value, Attribute value,
Type output_type); Type output_type);
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
#include "paddle/ir/core/value_impl.h" #include "paddle/ir/core/value_impl.h"
namespace ir { namespace ir {
Operation *Operation::create(OperationArgument &&argument) { Operation *Operation::Create(OperationArgument &&argument) {
Operation *op = create(argument.inputs, Operation *op = Create(argument.inputs,
argument.attributes, argument.attributes,
argument.output_types, argument.output_types,
argument.info, argument.info,
...@@ -40,7 +40,7 @@ Operation *Operation::create(OperationArgument &&argument) { ...@@ -40,7 +40,7 @@ Operation *Operation::create(OperationArgument &&argument) {
// Allocate the required memory based on the size and number of inputs, outputs, // Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult, // and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand. // OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs, Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attributes, const AttributeMap &attributes,
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
ir::OpInfo op_info, ir::OpInfo op_info,
...@@ -104,7 +104,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs, ...@@ -104,7 +104,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
// Call destructors for OpResults, Operation, and OpOperands in sequence, and // Call destructors for OpResults, Operation, and OpOperands in sequence, and
// finally free memory. // finally free memory.
void Operation::destroy() { void Operation::Destroy() {
// Deconstruct Regions. // Deconstruct Regions.
if (num_regions_ > 0) { if (num_regions_ > 0) {
for (size_t idx = 0; idx < num_regions_; idx++) { for (size_t idx = 0; idx < num_regions_; idx++) {
......
...@@ -34,17 +34,17 @@ class alignas(8) Operation final { ...@@ -34,17 +34,17 @@ class alignas(8) Operation final {
/// NOTE: Similar to new and delete, the destroy() and the create() need to be /// NOTE: Similar to new and delete, the destroy() and the create() need to be
/// used in conjunction. /// used in conjunction.
/// ///
static Operation *create(const std::vector<ir::OpResult> &inputs, static Operation *Create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attributes, const AttributeMap &attributes,
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
ir::OpInfo op_info, ir::OpInfo op_info,
size_t num_regions = 0); size_t num_regions = 0);
static Operation *create(OperationArgument &&op_argument); static Operation *Create(OperationArgument &&op_argument);
/// ///
/// \brief Destroy the operation objects and free memory by create(). /// \brief Destroy the operation objects and free memory by create().
/// ///
void destroy(); void Destroy();
IrContext *ir_context() const; IrContext *ir_context() const;
......
...@@ -18,12 +18,12 @@ ...@@ -18,12 +18,12 @@
namespace ir { namespace ir {
Program::Program(IrContext* context) { Program::Program(IrContext* context) {
module_ = ModuleOp::create(context, this); module_ = ModuleOp::Create(context, this);
} }
Program::~Program() { Program::~Program() {
if (module_) { if (module_) {
module_.destroy(); module_.Destroy();
} }
} }
......
...@@ -89,7 +89,7 @@ TEST(program_test, program) { ...@@ -89,7 +89,7 @@ TEST(program_test, program) {
{"seed", seed}, {"seed", seed},
{"place", uni_place}}; {"place", uni_place}};
ir::Operation* op1 = ir::Operation* op1 =
ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info); ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
block->push_back(op1); block->push_back(op1);
...@@ -99,13 +99,13 @@ TEST(program_test, program) { ...@@ -99,13 +99,13 @@ TEST(program_test, program) {
ir::Attribute ten2 = ir::Int32_tAttribute::get(ctx, 3); ir::Attribute ten2 = ir::Int32_tAttribute::get(ctx, 3);
std::unordered_map<std::string, ir::Attribute> op2_attribute{{"shape", ten2}}; std::unordered_map<std::string, ir::Attribute> op2_attribute{{"shape", ten2}};
ir::Operation* op2 = ir::Operation* op2 =
ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op2_info); ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2); block->push_back(op2);
// (3) Def out = AddOp(a, b) // (3) Def out = AddOp(a, b)
std::string add_op_name = std::string(paddle::dialect::AddOp::name()); std::string add_op_name = std::string(paddle::dialect::AddOp::name());
ir::OpInfo add_op_info = ctx->GetRegisteredOpInfo(add_op_name); ir::OpInfo add_op_info = ctx->GetRegisteredOpInfo(add_op_name);
ir::Operation* add_op = ir::Operation::create( ir::Operation* add_op = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, {op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{}, {},
{dense_tensor_dtype}, {dense_tensor_dtype},
......
...@@ -82,7 +82,7 @@ TEST(infershape_test, infershape_test) { ...@@ -82,7 +82,7 @@ TEST(infershape_test, infershape_test) {
std::vector<ir::OpResult> op_inputs = {}; std::vector<ir::OpResult> op_inputs = {};
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op = ir::Operation *op =
ir::Operation::create(op_inputs, {}, op_output_types, op_info); ir::Operation::Create(op_inputs, {}, op_output_types, op_info);
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>(); InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
phi::InferMetaContext infer_meta_ctx; phi::InferMetaContext infer_meta_ctx;
......
...@@ -175,7 +175,7 @@ TEST(op_test, op_test) { ...@@ -175,7 +175,7 @@ TEST(op_test, op_test) {
std::vector<ir::OpResult> op_inputs = {}; std::vector<ir::OpResult> op_inputs = {};
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create(op_inputs, ir::Operation::Create(op_inputs,
CreateAttributeMap({"op2_attr1", "op2_attr2"}, CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"}), {"op2_attr1", "op2_attr2"}),
op_output_types, op_output_types,
...@@ -187,7 +187,7 @@ TEST(op_test, op_test) { ...@@ -187,7 +187,7 @@ TEST(op_test, op_test) {
interface.InferShape(); interface.InferShape();
Operation2 Op2 = op2->dyn_cast<Operation2>(); Operation2 Op2 = op2->dyn_cast<Operation2>();
EXPECT_EQ(Op2.operation(), op2); EXPECT_EQ(Op2.operation(), op2);
op2->destroy(); op2->Destroy();
} }
TEST(op_test, region_test) { TEST(op_test, region_test) {
...@@ -201,13 +201,13 @@ TEST(op_test, region_test) { ...@@ -201,13 +201,13 @@ TEST(op_test, region_test) {
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(Operation2::name()); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(Operation2::name());
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create({}, ir::Operation::Create({},
CreateAttributeMap({"op1_attr1", "op1_attr2"}, CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"}), {"op1_attr1", "op1_attr2"}),
{ir::Float32Type::get(ctx)}, {ir::Float32Type::get(ctx)},
op1_info); op1_info);
ir::Operation *op1_2 = ir::Operation *op1_2 =
ir::Operation::create({}, ir::Operation::Create({},
CreateAttributeMap({"op1_attr1", "op1_attr2"}, CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"}), {"op1_attr1", "op1_attr2"}),
{ir::Float32Type::get(ctx)}, {ir::Float32Type::get(ctx)},
...@@ -227,8 +227,8 @@ TEST(op_test, region_test) { ...@@ -227,8 +227,8 @@ TEST(op_test, region_test) {
ir::Block *block = region->front(); ir::Block *block = region->front();
block->push_front(op1); block->push_front(op1);
block->insert(block->begin(), op1_2); block->insert(block->begin(), op1_2);
ir::Operation *op2 = ir::Operation::create(std::move(argument)); ir::Operation *op2 = ir::Operation::Create(std::move(argument));
op2->destroy(); op2->Destroy();
} }
TEST(op_test, module_op_death) { TEST(op_test, module_op_death) {
...@@ -240,9 +240,9 @@ TEST(op_test, module_op_death) { ...@@ -240,9 +240,9 @@ TEST(op_test, module_op_death) {
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}}; ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::create(inputs, {}, {}, op_info), const char *); EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::create({}, attrs, {}, op_info), const char *); EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::create({}, {}, output_types, op_info), EXPECT_THROW(ir::Operation::Create({}, {}, output_types, op_info),
const char *); const char *);
ir::Program program(ctx); ir::Program program(ctx);
......
...@@ -64,10 +64,9 @@ TEST(program_test, program) { ...@@ -64,10 +64,9 @@ TEST(program_test, program) {
// (3) Create a float32 DenseTensor Parameter and save into Program // (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx); ir::Type fp32_dtype = ir::Float32Type::get(ctx);
paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2}; phi::DDim dims = {2, 2};
paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout = phi::DataLayout data_layout = phi::DataLayout::NCHW;
paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW; phi::LoD lod = {{0, 1, 2}};
paddle::dialect::DenseTensorTypeStorage::LoD lod = {{0, 1, 2}};
size_t offset = 0; size_t offset = 0;
ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get( ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset); ctx, fp32_dtype, dims, data_layout, lod, offset);
...@@ -94,7 +93,7 @@ TEST(program_test, program) { ...@@ -94,7 +93,7 @@ TEST(program_test, program) {
std::unordered_map<std::string, ir::Attribute> op1_attribute{ std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}}; {"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info); ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
ir::Block *block = program.block(); ir::Block *block = program.block();
block->push_back(op1); block->push_back(op1);
...@@ -132,7 +131,7 @@ TEST(program_test, program) { ...@@ -132,7 +131,7 @@ TEST(program_test, program) {
std::unordered_map<std::string, ir::Attribute> op2_attribute{ std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}}; {"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info); ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2); block->push_back(op2);
EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(),
...@@ -159,7 +158,7 @@ TEST(program_test, program) { ...@@ -159,7 +158,7 @@ TEST(program_test, program) {
builtin_dialect->name() + "." + std::string(AddOp::name()); builtin_dialect->name() + "." + std::string(AddOp::name());
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name); ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute; std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::create( ir::Operation *op3 = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, {op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
op3_attribute, op3_attribute,
{dense_tensor_dtype}, {dense_tensor_dtype},
...@@ -194,7 +193,7 @@ TEST(program_test, program) { ...@@ -194,7 +193,7 @@ TEST(program_test, program) {
abs_argument.AddOperands(operands.begin(), operands.end()); abs_argument.AddOperands(operands.begin(), operands.end());
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end()); abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument)); ir::Operation *abs_op = ir::Operation::Create(std::move(abs_argument));
paddle::dialect::OpYamlInfoInterface interface = paddle::dialect::OpYamlInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>(); abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
...@@ -209,7 +208,7 @@ TEST(program_test, program) { ...@@ -209,7 +208,7 @@ TEST(program_test, program) {
ir::OperationArgument op4_argument( ir::OperationArgument op4_argument(
{op3->GetResultByIndex(0)}, {}, {}, op4_info); {op3->GetResultByIndex(0)}, {}, {}, op4_info);
op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end()); op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::create(std::move(op4_argument)); ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument));
block->push_back(op4); block->push_back(op4);
EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(), EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(),
...@@ -256,7 +255,7 @@ TEST(program_test, slice_combine_test) { ...@@ -256,7 +255,7 @@ TEST(program_test, slice_combine_test) {
std::unordered_map<std::string, ir::Attribute> op1_attribute{ std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}}; {"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info); ir::Operation::Create({}, op1_attribute, {fp32_dtype}, op1_info);
program.block()->push_back(op1); program.block()->push_back(op1);
// (5) Def b = Constant("b") // (5) Def b = Constant("b")
...@@ -266,7 +265,7 @@ TEST(program_test, slice_combine_test) { ...@@ -266,7 +265,7 @@ TEST(program_test, slice_combine_test) {
attr_map.insert(std::pair<std::string, ir::Attribute>( attr_map.insert(std::pair<std::string, ir::Attribute>(
"value", ir::FloatAttribute::get(ctx, 2.0))); "value", ir::FloatAttribute::get(ctx, 2.0)));
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create({}, attr_map, {fp32_dtype}, op2_info); ir::Operation::Create({}, attr_map, {fp32_dtype}, op2_info);
program.block()->push_back(op2); program.block()->push_back(op2);
// (6) Def combine_op = CombineOp("a", "b") // (6) Def combine_op = CombineOp("a", "b")
...@@ -274,7 +273,7 @@ TEST(program_test, slice_combine_test) { ...@@ -274,7 +273,7 @@ TEST(program_test, slice_combine_test) {
ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name); ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name);
ir::Type output_type = ir::Type output_type =
ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype})); ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype}));
ir::Operation *combine_op = ir::Operation::create( ir::Operation *combine_op = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, {op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{}, {},
{output_type}, {output_type},
...@@ -286,7 +285,7 @@ TEST(program_test, slice_combine_test) { ...@@ -286,7 +285,7 @@ TEST(program_test, slice_combine_test) {
ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name);
ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0);
ir::Operation *slice_op = ir::Operation *slice_op =
ir::Operation::create({combine_op->GetResultByIndex(0)}, ir::Operation::Create({combine_op->GetResultByIndex(0)},
{{"index", index_attr}}, {{"index", index_attr}},
{fp32_dtype}, {fp32_dtype},
slice_op_info); slice_op_info);
...@@ -302,14 +301,14 @@ TEST(program_test, builder) { ...@@ -302,14 +301,14 @@ TEST(program_test, builder) {
ir::Program program(ctx); ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block()); ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
paddle::dialect::FullOp full_op = builder.create<paddle::dialect::FullOp>( paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
ir::Type full_op_output = full_op->GetResultByIndex(0).type(); ir::Type full_op_output = full_op->GetResultByIndex(0).type();
EXPECT_EQ(program.block()->size() == 1, true); EXPECT_EQ(program.block()->size(), 1u);
EXPECT_EQ(program.block()->back(), full_op.operation()); EXPECT_EQ(program.block()->back(), full_op.operation());
EXPECT_EQ(full_op->num_operands() == 0, true); EXPECT_EQ(full_op->num_operands(), 0u);
EXPECT_EQ(full_op->num_results() == 1, true); EXPECT_EQ(full_op->num_results(), 1u);
EXPECT_EQ(full_op->attributes().size() == 4, true); EXPECT_EQ(full_op->attributes().size(), 4u);
EXPECT_EQ( EXPECT_EQ(
full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0, full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0,
true); true);
...@@ -319,7 +318,7 @@ TEST(program_test, builder) { ...@@ -319,7 +318,7 @@ TEST(program_test, builder) {
EXPECT_EQ(dim == 2, true); EXPECT_EQ(dim == 2, true);
} }
ir::ConstantOp constant = builder.create<ir::ConstantOp>( ir::ConstantOp constant = builder.Build<ir::ConstantOp>(
ir::Int32_tAttribute::get(ctx, 2), ir::Int32Type::get(ctx)); ir::Int32_tAttribute::get(ctx, 2), ir::Int32Type::get(ctx));
EXPECT_EQ(program.block()->size() == 2, true); EXPECT_EQ(program.block()->size() == 2, true);
EXPECT_EQ(constant.value().dyn_cast<ir::Int32_tAttribute>().data() == 2, EXPECT_EQ(constant.value().dyn_cast<ir::Int32_tAttribute>().data() == 2,
......
...@@ -39,7 +39,7 @@ TEST(value_test, value_test) { ...@@ -39,7 +39,7 @@ TEST(value_test, value_test) {
std::vector<ir::OpResult> op1_inputs = {}; std::vector<ir::OpResult> op1_inputs = {};
std::vector<ir::Type> op1_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op1_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create(op1_inputs, ir::Operation::Create(op1_inputs,
CreateAttributeMap("op1_name", "op1_attr"), CreateAttributeMap("op1_name", "op1_attr"),
op1_output_types, op1_output_types,
nullptr); nullptr);
...@@ -48,7 +48,7 @@ TEST(value_test, value_test) { ...@@ -48,7 +48,7 @@ TEST(value_test, value_test) {
std::vector<ir::OpResult> op2_inputs = {}; std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create(op2_inputs, ir::Operation::Create(op2_inputs,
CreateAttributeMap("op2_name", "op2_attr"), CreateAttributeMap("op2_name", "op2_attr"),
op2_output_types, op2_output_types,
nullptr); nullptr);
...@@ -58,7 +58,7 @@ TEST(value_test, value_test) { ...@@ -58,7 +58,7 @@ TEST(value_test, value_test) {
op2->GetResultByIndex(0)}; op2->GetResultByIndex(0)};
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op3 = ir::Operation *op3 =
ir::Operation::create(op3_inputs, ir::Operation::Create(op3_inputs,
CreateAttributeMap("op3_name", "op3_attr"), CreateAttributeMap("op3_name", "op3_attr"),
op3_output_types, op3_output_types,
nullptr); nullptr);
...@@ -71,7 +71,7 @@ TEST(value_test, value_test) { ...@@ -71,7 +71,7 @@ TEST(value_test, value_test) {
op4_output_types.push_back(ir::Float32Type::get(ctx)); op4_output_types.push_back(ir::Float32Type::get(ctx));
} }
ir::Operation *op4 = ir::Operation *op4 =
ir::Operation::create(op4_inputs, ir::Operation::Create(op4_inputs,
CreateAttributeMap("op4_name", "op4_attr"), CreateAttributeMap("op4_name", "op4_attr"),
op4_output_types, op4_output_types,
nullptr); nullptr);
...@@ -101,11 +101,11 @@ TEST(value_test, value_test) { ...@@ -101,11 +101,11 @@ TEST(value_test, value_test) {
// destroy // destroy
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op4->destroy(); op4->Destroy();
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op3->destroy(); op3->Destroy();
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op2->destroy(); op2->Destroy();
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op1->destroy(); op1->Destroy();
} }
...@@ -107,7 +107,7 @@ TEST(pass_manager_test, pass_manager) { ...@@ -107,7 +107,7 @@ TEST(pass_manager_test, pass_manager) {
std::unordered_map<std::string, ir::Attribute> op1_attribute{ std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}}; {"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 = ir::Operation *op1 =
ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info); ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
ir::Block *block = program.block(); ir::Block *block = program.block();
block->push_back(op1); block->push_back(op1);
...@@ -145,7 +145,7 @@ TEST(pass_manager_test, pass_manager) { ...@@ -145,7 +145,7 @@ TEST(pass_manager_test, pass_manager) {
std::unordered_map<std::string, ir::Attribute> op2_attribute{ std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}}; {"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 = ir::Operation *op2 =
ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info); ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2); block->push_back(op2);
EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(),
...@@ -172,7 +172,7 @@ TEST(pass_manager_test, pass_manager) { ...@@ -172,7 +172,7 @@ TEST(pass_manager_test, pass_manager) {
builtin_dialect->name() + "." + std::string(AddOp::name()); builtin_dialect->name() + "." + std::string(AddOp::name());
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name); ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute; std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::create( ir::Operation *op3 = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, {op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
op3_attribute, op3_attribute,
{dense_tensor_dtype}, {dense_tensor_dtype},
...@@ -207,7 +207,7 @@ TEST(pass_manager_test, pass_manager) { ...@@ -207,7 +207,7 @@ TEST(pass_manager_test, pass_manager) {
abs_argument.AddOperands(operands.begin(), operands.end()); abs_argument.AddOperands(operands.begin(), operands.end());
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end()); abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument)); ir::Operation *abs_op = ir::Operation::Create(std::move(abs_argument));
paddle::dialect::OpYamlInfoInterface interface = paddle::dialect::OpYamlInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>(); abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
...@@ -222,7 +222,7 @@ TEST(pass_manager_test, pass_manager) { ...@@ -222,7 +222,7 @@ TEST(pass_manager_test, pass_manager) {
ir::OperationArgument op4_argument( ir::OperationArgument op4_argument(
{op3->GetResultByIndex(0)}, {}, {}, op4_info); {op3->GetResultByIndex(0)}, {}, {}, op4_info);
op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end()); op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::create(std::move(op4_argument)); ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument));
block->push_back(op4); block->push_back(op4);
EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(), EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册