未验证 提交 eac99c5b 编写于 作者: W winter-wang 提交者: GitHub

[IR] polish the new ir api name. (#54562)

上级 53f24669
...@@ -70,9 +70,9 @@ op_n_attribute_declare_str = ( ...@@ -70,9 +70,9 @@ op_n_attribute_declare_str = (
"static const char *attributes_name[{attribute_num}];" "static const char *attributes_name[{attribute_num}];"
) )
OP_GET_INPUT_TEMPLATE = """ ir::OpOperand {input_name}() {{ return operation()->GetOperandByIndex({input_index}); }} OP_GET_INPUT_TEMPLATE = """ ir::OpOperand {input_name}() {{ return operation()->operand({input_index}); }}
""" """
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return operation()->GetResultByIndex({output_index}); }} OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return operation()->result({output_index}); }}
""" """
# ===================================== # =====================================
...@@ -817,11 +817,11 @@ def GenBuildInserFullForMutableAttribute( ...@@ -817,11 +817,11 @@ def GenBuildInserFullForMutableAttribute(
build_mutable_attribute = "" build_mutable_attribute = ""
BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """ // Generate int_array mutable attribute: {attr_name} BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """ // Generate int_array mutable attribute: {attr_name}
paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build<paddle::dialect::FullIntArrayOp>({attr_name}, {phi_dtype}, phi::CPUPlace()); paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build<paddle::dialect::FullIntArrayOp>({attr_name}, {phi_dtype}, phi::CPUPlace());
ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0); ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0);
""" """
BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """ // Generate scalar mutable attribute: {attr_name} BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """ // Generate scalar mutable attribute: {attr_name}
paddle::dialect::FullOp full_{attr_name}_op = builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace()); paddle::dialect::FullOp full_{attr_name}_op = builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace());
ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0); ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0);
""" """
for idx in range(len(op_mutable_attribute_name_list)): for idx in range(len(op_mutable_attribute_name_list)):
attr_name = op_mutable_attribute_name_list[idx] attr_name = op_mutable_attribute_name_list[idx]
......
...@@ -142,7 +142,7 @@ inline ir::Operation* InsertSliceOperationForTarget( ...@@ -142,7 +142,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
{src_vec_type[defining_info.idx_in_vector]}, {src_vec_type[defining_info.idx_in_vector]},
op_info); op_info);
program->block()->push_back(operation); program->block()->push_back(operation);
ir::OpResult target_op_result = operation->GetResultByIndex(0); ir::OpResult target_op_result = operation->result(0);
(*param_map)[arg_name] = VariableDefiningInfo(target_op_result); (*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
return operation; return operation;
} }
...@@ -190,7 +190,7 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, ...@@ -190,7 +190,7 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data()); data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
dtype = phi::DataType::BOOL; dtype = phi::DataType::BOOL;
} }
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block()); ir::Builder builder(ctx, program->block());
paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>( paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, data, dtype, phi::CPUPlace()); std::vector<int64_t>{1}, data, dtype, phi::CPUPlace());
...@@ -206,7 +206,7 @@ inline ir::Operation* InsertFullArrayOperationForAttributeInput( ...@@ -206,7 +206,7 @@ inline ir::Operation* InsertFullArrayOperationForAttributeInput(
phi::IntArray int_array = phi::IntArray int_array =
attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data(); attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data();
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block()); ir::Builder builder(ctx, program->block());
paddle::dialect::FullIntArrayOp full_int_array_op = paddle::dialect::FullIntArrayOp full_int_array_op =
builder.Build<paddle::dialect::FullIntArrayOp>( builder.Build<paddle::dialect::FullIntArrayOp>(
int_array.GetData(), phi::DataType::INT64, phi::CPUPlace()); int_array.GetData(), phi::DataType::INT64, phi::CPUPlace());
...@@ -244,7 +244,7 @@ inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx, ...@@ -244,7 +244,7 @@ inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr); defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr);
} }
return defining_op->GetResultByIndex(0); return defining_op->result(0);
} }
inline std::vector<ir::OpResult> GenerateOperationInput( inline std::vector<ir::OpResult> GenerateOperationInput(
...@@ -340,7 +340,7 @@ inline std::vector<ir::OpResult> GenerateOperationInput( ...@@ -340,7 +340,7 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
} else { } else {
auto* combine_op = InsertCombineOperationForTarget( auto* combine_op = InsertCombineOperationForTarget(
ctx, param_map, program, legacy_input_vars); ctx, param_map, program, legacy_input_vars);
op_inputs.push_back(combine_op->GetResultByIndex(0)); op_inputs.push_back(combine_op->result(0));
} }
} }
...@@ -472,7 +472,7 @@ inline void RecordOpResultMapping(TranslationContext* param_map, ...@@ -472,7 +472,7 @@ inline void RecordOpResultMapping(TranslationContext* param_map,
VLOG(10) << "[output recording]" VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << arg_name << " " << idx; << "[" << op_desc.Type() << "]" << arg_name << " " << idx;
ir::OpResult value = operation->GetResultByIndex(idx); ir::OpResult value = operation->result(idx);
bool generated_by_vector = value.type().isa<ir::VectorType>(); bool generated_by_vector = value.type().isa<ir::VectorType>();
(*param_map)[arg_name] = VariableDefiningInfo( (*param_map)[arg_name] = VariableDefiningInfo(
value, generated_by_vector, generated_by_vector ? idx_in_vector : -1); value, generated_by_vector, generated_by_vector ? idx_in_vector : -1);
......
...@@ -122,7 +122,7 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { ...@@ -122,7 +122,7 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
ir::Operation* op = ir::Operation* op =
InsertGetParamaterOp(ctx, parameter_name_mappings[var_name]); InsertGetParamaterOp(ctx, parameter_name_mappings[var_name]);
program->block()->push_back(op); program->block()->push_back(op);
param_map[var_name] = VariableDefiningInfo(op->GetResultByIndex(0)); param_map[var_name] = VariableDefiningInfo(op->result(0));
VLOG(10) << "[op translated][get parameter]" << op; VLOG(10) << "[op translated][get parameter]" << op;
program->SetParameter(var_name, nullptr); program->SetParameter(var_name, nullptr);
......
...@@ -31,14 +31,6 @@ class Builder { ...@@ -31,14 +31,6 @@ class Builder {
Builder(IrContext *context, Block *block) Builder(IrContext *context, Block *block)
: Builder(context, block, block->end()) {} : Builder(context, block, block->end()) {}
static Builder AtBlockBegin(IrContext *context, Block *block) {
return Builder(context, block, block->begin());
}
static Builder AtBlockEnd(IrContext *context, Block *block) {
return Builder(context, block, block->end());
}
IrContext *context() const { return context_; } IrContext *context() const { return context_; }
Block *block() const { return block_; } Block *block() const { return block_; }
......
...@@ -198,7 +198,7 @@ void IrPrinter::PrintOpResult(Operation* op) { ...@@ -198,7 +198,7 @@ void IrPrinter::PrintOpResult(Operation* op) {
std::vector<OpResult> op_results; std::vector<OpResult> op_results;
op_results.reserve(num_op_result); op_results.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) { for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx)); op_results.push_back(op->result(idx));
} }
PrintInterleave( PrintInterleave(
op_results.begin(), op_results.begin(),
...@@ -230,7 +230,7 @@ void IrPrinter::PrintOpOperands(Operation* op) { ...@@ -230,7 +230,7 @@ void IrPrinter::PrintOpOperands(Operation* op) {
std::vector<Value> op_operands; std::vector<Value> op_operands;
op_operands.reserve(num_op_operands); op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) { for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->GetOperandByIndex(idx).source()); op_operands.push_back(op->operand(idx).source());
} }
PrintInterleave( PrintInterleave(
op_operands.begin(), op_operands.begin(),
...@@ -245,9 +245,9 @@ void IrPrinter::PrintOperandsType(Operation* op) { ...@@ -245,9 +245,9 @@ void IrPrinter::PrintOperandsType(Operation* op) {
std::vector<Type> op_operand_types; std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands); op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) { for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->GetOperandByIndex(idx); auto op_operand = op->operand(idx);
if (op_operand) { if (op_operand) {
op_operand_types.push_back(op->GetOperandByIndex(idx).source().type()); op_operand_types.push_back(op->operand(idx).source().type());
} else { } else {
op_operand_types.push_back(Type(nullptr)); op_operand_types.push_back(Type(nullptr));
} }
...@@ -266,7 +266,7 @@ void IrPrinter::PrintOpReturnType(Operation* op) { ...@@ -266,7 +266,7 @@ void IrPrinter::PrintOpReturnType(Operation* op) {
std::vector<Type> op_result_types; std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result); op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) { for (size_t idx = 0; idx < num_op_result; idx++) {
auto op_result = op->GetResultByIndex(idx); auto op_result = op->result(idx);
if (op_result) { if (op_result) {
op_result_types.push_back(op_result.type()); op_result_types.push_back(op_result.type());
} else { } else {
......
...@@ -179,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes, ...@@ -179,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes,
num_operands_(num_operands), num_operands_(num_operands),
num_regions_(num_regions) {} num_regions_(num_regions) {}
ir::OpResult Operation::GetResultByIndex(uint32_t index) const { ir::OpResult Operation::result(uint32_t index) const {
if (index >= num_results_) { if (index >= num_results_) {
IR_THROW("index exceeds OP output range."); IR_THROW("index exceeds OP output range.");
} }
...@@ -200,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const { ...@@ -200,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
} }
} }
ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const { ir::OpOperand Operation::operand(uint32_t index) const {
if (index >= num_operands_) { if (index >= num_operands_) {
IR_THROW("index exceeds OP input range."); IR_THROW("index exceeds OP input range.");
} }
......
...@@ -47,10 +47,12 @@ class alignas(8) Operation final { ...@@ -47,10 +47,12 @@ class alignas(8) Operation final {
void Destroy(); void Destroy();
IrContext *ir_context() const; IrContext *ir_context() const;
Dialect *dialect() const; Dialect *dialect() const;
OpResult GetResultByIndex(uint32_t index) const;
OpOperand GetOperandByIndex(uint32_t index) const; OpResult result(uint32_t index) const;
OpOperand operand(uint32_t index) const;
void Print(std::ostream &os); void Print(std::ostream &os);
......
...@@ -88,7 +88,7 @@ RewriterBase::~RewriterBase() = default; ...@@ -88,7 +88,7 @@ RewriterBase::~RewriterBase() = default;
// // assert(op->num_results() == new_values.size() && "incorrect number of // // assert(op->num_results() == new_values.size() && "incorrect number of
// values to replace operation"); NotifyRootReplaced(op, new_values); bool // values to replace operation"); NotifyRootReplaced(op, new_values); bool
// replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) { // replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) {
// // op->GetResultByIndex(0) // // op->result(0)
// } // }
// } // }
// void RewriterBase::ReplaceOpWithIf(Operation* op, // void RewriterBase::ReplaceOpWithIf(Operation* op,
...@@ -138,7 +138,7 @@ void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op, ...@@ -138,7 +138,7 @@ void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op,
"replacement op doesn't match results of original op"); "replacement op doesn't match results of original op");
// TODO(wilber): Op support results method. // TODO(wilber): Op support results method.
// if (op->num_results() == 1) return ReplaceOp(op, // if (op->num_results() == 1) return ReplaceOp(op,
// new_op->GetResultByIndex(0)); return ReplaceOp(op, new_op->GetResults()); // new_op->result(0)); return ReplaceOp(op, new_op->GetResults());
} }
} // namespace ir } // namespace ir
...@@ -56,7 +56,7 @@ TEST(program_test, program) { ...@@ -56,7 +56,7 @@ TEST(program_test, program) {
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx); ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block()); ir::Builder builder(ctx, program.block());
ir::Block* block = program.block(); ir::Block* block = program.block();
// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape, // Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
...@@ -68,9 +68,7 @@ TEST(program_test, program) { ...@@ -68,9 +68,7 @@ TEST(program_test, program) {
1.0, 1.0,
2, 2,
phi::CPUPlace()); phi::CPUPlace());
EXPECT_EQ(uniform1->GetResultByIndex(0) EXPECT_EQ(uniform1->result(0).type().isa<paddle::dialect::DenseTensorType>(),
.type()
.isa<paddle::dialect::DenseTensorType>(),
true); true);
EXPECT_EQ(block->size(), 4u); EXPECT_EQ(block->size(), 4u);
...@@ -82,18 +80,15 @@ TEST(program_test, program) { ...@@ -82,18 +80,15 @@ TEST(program_test, program) {
1.0, 1.0,
2, 2,
phi::CPUPlace()); phi::CPUPlace());
EXPECT_EQ(uniform2->GetResultByIndex(0) EXPECT_EQ(uniform2->result(0).type().isa<paddle::dialect::DenseTensorType>(),
.type()
.isa<paddle::dialect::DenseTensorType>(),
true); true);
EXPECT_EQ(block->size(), 8u); EXPECT_EQ(block->size(), 8u);
// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_) // Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>( paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->GetResultByIndex(0), uniform2->GetResultByIndex(0)); uniform1->result(0), uniform2->result(0));
EXPECT_EQ( EXPECT_EQ(add->result(0).type().isa<paddle::dialect::DenseTensorType>(),
add->GetResultByIndex(0).type().isa<paddle::dialect::DenseTensorType>(), true);
true);
EXPECT_EQ(block->size(), 9u); EXPECT_EQ(block->size(), 9u);
// Execute program // Execute program
......
...@@ -104,13 +104,10 @@ TEST(program_test, program) { ...@@ -104,13 +104,10 @@ TEST(program_test, program) {
EXPECT_EQ(&program, op1->GetParentProgram()); EXPECT_EQ(&program, op1->GetParentProgram());
EXPECT_EQ(op1->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
paddle_dialect->id());
using Interface = paddle::dialect::ParameterConvertInterface; using Interface = paddle::dialect::ParameterConvertInterface;
Interface *a_interface = op1->GetResultByIndex(0) Interface *a_interface =
.type() op1->result(0).type().dialect().GetRegisteredInterface<Interface>();
.dialect()
.GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> a_var = std::shared_ptr<paddle::framework::Variable> a_var =
a_interface->ParameterToVariable(program.GetParameter("a")); a_interface->ParameterToVariable(program.GetParameter("a"));
const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>(); const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
...@@ -134,12 +131,9 @@ TEST(program_test, program) { ...@@ -134,12 +131,9 @@ TEST(program_test, program) {
ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info); ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2); block->push_back(op2);
EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
paddle_dialect->id()); Interface *b_interface =
Interface *b_interface = op2->GetResultByIndex(0) op2->result(0).type().dialect().GetRegisteredInterface<Interface>();
.type()
.dialect()
.GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> b_var = std::shared_ptr<paddle::framework::Variable> b_var =
b_interface->ParameterToVariable(program.GetParameter("b")); b_interface->ParameterToVariable(program.GetParameter("b"));
const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>(); const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
...@@ -158,11 +152,10 @@ TEST(program_test, program) { ...@@ -158,11 +152,10 @@ TEST(program_test, program) {
builtin_dialect->name() + "." + std::string(AddOp::name()); builtin_dialect->name() + "." + std::string(AddOp::name());
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name); ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute; std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::Create( ir::Operation *op3 = ir::Operation::Create({op1->result(0), op2->result(0)},
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, op3_attribute,
op3_attribute, {dense_tensor_dtype},
{dense_tensor_dtype}, op3_info);
op3_info);
block->push_back(op3); block->push_back(op3);
phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>( phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
...@@ -186,7 +179,7 @@ TEST(program_test, program) { ...@@ -186,7 +179,7 @@ TEST(program_test, program) {
// (7) Def AbsOp(b) // (7) Def AbsOp(b)
ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs");
std::vector<ir::OpResult> operands = {op1->GetResultByIndex(0)}; std::vector<ir::OpResult> operands = {op1->result(0)};
std::unordered_map<std::string, ir::Attribute> abs_op_attribute; std::unordered_map<std::string, ir::Attribute> abs_op_attribute;
std::vector<ir::Type> output_types = {dense_tensor_dtype}; std::vector<ir::Type> output_types = {dense_tensor_dtype};
ir::OperationArgument abs_argument(abs_info); ir::OperationArgument abs_argument(abs_info);
...@@ -205,15 +198,14 @@ TEST(program_test, program) { ...@@ -205,15 +198,14 @@ TEST(program_test, program) {
std::unordered_map<std::string, ir::Attribute> op4_attribute{ std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}}; {"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::OperationArgument op4_argument( ir::OperationArgument op4_argument({op3->result(0)}, {}, {}, op4_info);
{op3->GetResultByIndex(0)}, {}, {}, op4_info);
op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end()); op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument)); ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument));
block->push_back(op4); block->push_back(op4);
EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(), EXPECT_EQ(op4->operand(0).source().type().dialect().id(),
paddle_dialect->id()); paddle_dialect->id());
Interface *c_interface = op4->GetOperandByIndex(0) Interface *c_interface = op4->operand(0)
.source() .source()
.type() .type()
.dialect() .dialect()
...@@ -274,21 +266,17 @@ TEST(program_test, slice_combine_test) { ...@@ -274,21 +266,17 @@ TEST(program_test, slice_combine_test) {
ir::Type output_type = ir::Type output_type =
ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype})); ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype}));
ir::Operation *combine_op = ir::Operation::Create( ir::Operation *combine_op = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, {op1->result(0), op2->result(0)}, {}, {output_type}, combine_op_info);
{},
{output_type},
combine_op_info);
program.block()->push_back(combine_op); program.block()->push_back(combine_op);
// (7) Def slice_op = SliceOp(combine_op, 0) // (7) Def slice_op = SliceOp(combine_op, 0)
std::string slice_op_name = std::string(ir::SliceOp::name()); std::string slice_op_name = std::string(ir::SliceOp::name());
ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name);
ir::Attribute index_attr = ir::Int32Attribute::get(ctx, 0); ir::Attribute index_attr = ir::Int32Attribute::get(ctx, 0);
ir::Operation *slice_op = ir::Operation *slice_op = ir::Operation::Create({combine_op->result(0)},
ir::Operation::Create({combine_op->GetResultByIndex(0)}, {{"index", index_attr}},
{{"index", index_attr}}, {fp32_dtype},
{fp32_dtype}, slice_op_info);
slice_op_info);
program.block()->push_back(slice_op); program.block()->push_back(slice_op);
// (8) Traverse Program // (8) Traverse Program
...@@ -303,7 +291,7 @@ TEST(program_test, builder) { ...@@ -303,7 +291,7 @@ TEST(program_test, builder) {
paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>( paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
ir::Type full_op_output = full_op->GetResultByIndex(0).type(); ir::Type full_op_output = full_op->result(0).type();
EXPECT_EQ(program.block()->size(), 1u); EXPECT_EQ(program.block()->size(), 1u);
EXPECT_EQ(program.block()->back(), full_op.operation()); EXPECT_EQ(program.block()->back(), full_op.operation());
EXPECT_EQ(full_op->num_operands(), 0u); EXPECT_EQ(full_op->num_operands(), 0u);
......
...@@ -54,8 +54,7 @@ TEST(value_test, value_test) { ...@@ -54,8 +54,7 @@ TEST(value_test, value_test) {
ir::OpInfo()); ir::OpInfo());
op2->Print(std::cout); op2->Print(std::cout);
// 3. Construct OP3: c = OP3(a, b); // 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op3_inputs = {op1->result(0), op2->result(0)};
op2->GetResultByIndex(0)};
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op3 = ir::Operation *op3 =
ir::Operation::Create(op3_inputs, ir::Operation::Create(op3_inputs,
...@@ -64,8 +63,7 @@ TEST(value_test, value_test) { ...@@ -64,8 +63,7 @@ TEST(value_test, value_test) {
ir::OpInfo()); ir::OpInfo());
op3->Print(std::cout); op3->Print(std::cout);
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op4_inputs = {op1->result(0), op3->result(0)};
op3->GetResultByIndex(0)};
std::vector<ir::Type> op4_output_types; std::vector<ir::Type> op4_output_types;
for (size_t i = 0; i < 7; i++) { for (size_t i = 0; i < 7; i++) {
op4_output_types.push_back(ir::Float32Type::get(ctx)); op4_output_types.push_back(ir::Float32Type::get(ctx));
...@@ -78,34 +76,34 @@ TEST(value_test, value_test) { ...@@ -78,34 +76,34 @@ TEST(value_test, value_test) {
op4->Print(std::cout); op4->Print(std::cout);
// Test 1: // Test 1:
EXPECT_EQ(op1->GetResultByIndex(0).GetDefiningOp(), op1); EXPECT_EQ(op1->result(0).GetDefiningOp(), op1);
EXPECT_EQ(op2->GetResultByIndex(0).GetDefiningOp(), op2); EXPECT_EQ(op2->result(0).GetDefiningOp(), op2);
EXPECT_EQ(op3->GetResultByIndex(0).GetDefiningOp(), op3); EXPECT_EQ(op3->result(0).GetDefiningOp(), op3);
EXPECT_EQ(op4->GetResultByIndex(6).GetDefiningOp(), op4); EXPECT_EQ(op4->result(6).GetDefiningOp(), op4);
// Test 2: op1_first_output -> op4_first_input // Test 2: op1_first_output -> op4_first_input
ir::OpResult op1_first_output = op1->GetResultByIndex(0); ir::OpResult op1_first_output = op1->result(0);
ir::OpOperand op4_first_input = op4->GetOperandByIndex(0); ir::OpOperand op4_first_input = op4->operand(0);
EXPECT_EQ(op1_first_output.first_use(), op4_first_input); EXPECT_EQ(op1_first_output.first_use(), op4_first_input);
ir::OpOperand op3_first_input = op3->GetOperandByIndex(0); ir::OpOperand op3_first_input = op3->operand(0);
EXPECT_EQ(op4_first_input.next_use(), op3_first_input); EXPECT_EQ(op4_first_input.next_use(), op3_first_input);
EXPECT_EQ(op3_first_input.next_use(), nullptr); EXPECT_EQ(op3_first_input.next_use(), nullptr);
// Test 3: Value iterator // Test 3: Value iterator
ir::Value::use_iterator iter = op1->GetResultByIndex(0).begin(); ir::Value::use_iterator iter = op1->result(0).begin();
EXPECT_EQ(iter.owner(), op4); EXPECT_EQ(iter.owner(), op4);
++iter; ++iter;
EXPECT_EQ(iter.owner(), op3); EXPECT_EQ(iter.owner(), op3);
// destroy // destroy
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->result(0).print_ud_chain() << std::endl;
op4->Destroy(); op4->Destroy();
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->result(0).print_ud_chain() << std::endl;
op3->Destroy(); op3->Destroy();
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->result(0).print_ud_chain() << std::endl;
op2->Destroy(); op2->Destroy();
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->result(0).print_ud_chain() << std::endl;
op1->Destroy(); op1->Destroy();
} }
...@@ -54,7 +54,7 @@ void build_scope(ir::Block* block, ...@@ -54,7 +54,7 @@ void build_scope(ir::Block* block,
int input = (*it)->num_operands(); int input = (*it)->num_operands();
if (input > 0) { if (input > 0) {
for (int i = 0; i < input; ++i) { for (int i = 0; i < input; ++i) {
auto ptr = (*it)->GetOperandByIndex(i).source(); auto ptr = (*it)->operand(i).source();
std::string name; std::string name;
if (name_map->find(ptr) != name_map->end()) { if (name_map->find(ptr) != name_map->end()) {
name = name_map->at(ptr); name = name_map->at(ptr);
...@@ -72,7 +72,7 @@ void build_scope(ir::Block* block, ...@@ -72,7 +72,7 @@ void build_scope(ir::Block* block,
if (out_num > 0) { if (out_num > 0) {
for (int i = 0; i < out_num; ++i) { for (int i = 0; i < out_num; ++i) {
ir::Value ptr = (*it)->GetResultByIndex(i); ir::Value ptr = (*it)->result(i);
std::string name; std::string name;
if (name_map->find(ptr) != name_map->end()) { if (name_map->find(ptr) != name_map->end()) {
name = name_map->at(ptr); name = name_map->at(ptr);
...@@ -131,7 +131,7 @@ void build_context(ir::Operation* op, ...@@ -131,7 +131,7 @@ void build_context(ir::Operation* op,
for (auto& t : vec_param_list) { for (auto& t : vec_param_list) {
if (input_index_map.count(t)) { if (input_index_map.count(t)) {
// get information from input // get information from input
ir::Value ptr = op->GetOperandByIndex(input_index_map[t]).source(); ir::Value ptr = op->operand(input_index_map[t]).source();
auto in_var_name = name_map.at(ptr); auto in_var_name = name_map.at(ptr);
if (mutable_attr_type_map.count(t)) { if (mutable_attr_type_map.count(t)) {
...@@ -180,7 +180,7 @@ void build_context(ir::Operation* op, ...@@ -180,7 +180,7 @@ void build_context(ir::Operation* op,
} }
} }
ir::Value out_ptr = op->GetResultByIndex(0); ir::Value out_ptr = op->result(0);
auto name = name_map.at(out_ptr); auto name = name_map.at(out_ptr);
ctx->EmplaceBackOutput(scope->Var(name)->GetMutable<phi::DenseTensor>()); ctx->EmplaceBackOutput(scope->Var(name)->GetMutable<phi::DenseTensor>());
...@@ -239,7 +239,7 @@ class PhiKernelAdaptor { ...@@ -239,7 +239,7 @@ class PhiKernelAdaptor {
(*it), name_map, scope_, &kernel_ctx, false); (*it), name_map, scope_, &kernel_ctx, false);
found_it->second(&kernel_ctx); found_it->second(&kernel_ctx);
auto out_value = (*it)->GetResultByIndex(0); auto out_value = (*it)->result(0);
out_name = name_map[out_value]; out_name = name_map[out_value];
} }
} }
......
...@@ -118,13 +118,10 @@ TEST(pass_manager_test, pass_manager) { ...@@ -118,13 +118,10 @@ TEST(pass_manager_test, pass_manager) {
EXPECT_EQ(&program, op1->GetParentProgram()); EXPECT_EQ(&program, op1->GetParentProgram());
EXPECT_EQ(op1->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
paddle_dialect->id());
using Interface = paddle::dialect::ParameterConvertInterface; using Interface = paddle::dialect::ParameterConvertInterface;
Interface *a_interface = op1->GetResultByIndex(0) Interface *a_interface =
.type() op1->result(0).type().dialect().GetRegisteredInterface<Interface>();
.dialect()
.GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> a_var = std::shared_ptr<paddle::framework::Variable> a_var =
a_interface->ParameterToVariable(program.GetParameter("a")); a_interface->ParameterToVariable(program.GetParameter("a"));
const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>(); const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
...@@ -148,12 +145,9 @@ TEST(pass_manager_test, pass_manager) { ...@@ -148,12 +145,9 @@ TEST(pass_manager_test, pass_manager) {
ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info); ir::Operation::Create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2); block->push_back(op2);
EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(), EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
paddle_dialect->id()); Interface *b_interface =
Interface *b_interface = op2->GetResultByIndex(0) op2->result(0).type().dialect().GetRegisteredInterface<Interface>();
.type()
.dialect()
.GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> b_var = std::shared_ptr<paddle::framework::Variable> b_var =
b_interface->ParameterToVariable(program.GetParameter("b")); b_interface->ParameterToVariable(program.GetParameter("b"));
const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>(); const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
...@@ -172,11 +166,10 @@ TEST(pass_manager_test, pass_manager) { ...@@ -172,11 +166,10 @@ TEST(pass_manager_test, pass_manager) {
builtin_dialect->name() + "." + std::string(AddOp::name()); builtin_dialect->name() + "." + std::string(AddOp::name());
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name); ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute; std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::Create( ir::Operation *op3 = ir::Operation::Create({op1->result(0), op2->result(0)},
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, op3_attribute,
op3_attribute, {dense_tensor_dtype},
{dense_tensor_dtype}, op3_info);
op3_info);
block->push_back(op3); block->push_back(op3);
phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>( phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
...@@ -200,7 +193,7 @@ TEST(pass_manager_test, pass_manager) { ...@@ -200,7 +193,7 @@ TEST(pass_manager_test, pass_manager) {
// (7) Def AbsOp(b) // (7) Def AbsOp(b)
ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs");
std::vector<ir::OpResult> operands = {op1->GetResultByIndex(0)}; std::vector<ir::OpResult> operands = {op1->result(0)};
std::unordered_map<std::string, ir::Attribute> abs_op_attribute; std::unordered_map<std::string, ir::Attribute> abs_op_attribute;
std::vector<ir::Type> output_types = {dense_tensor_dtype}; std::vector<ir::Type> output_types = {dense_tensor_dtype};
ir::OperationArgument abs_argument(abs_info); ir::OperationArgument abs_argument(abs_info);
...@@ -219,15 +212,14 @@ TEST(pass_manager_test, pass_manager) { ...@@ -219,15 +212,14 @@ TEST(pass_manager_test, pass_manager) {
std::unordered_map<std::string, ir::Attribute> op4_attribute{ std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}}; {"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::OperationArgument op4_argument( ir::OperationArgument op4_argument({op3->result(0)}, {}, {}, op4_info);
{op3->GetResultByIndex(0)}, {}, {}, op4_info);
op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end()); op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument)); ir::Operation *op4 = ir::Operation::Create(std::move(op4_argument));
block->push_back(op4); block->push_back(op4);
EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(), EXPECT_EQ(op4->operand(0).source().type().dialect().id(),
paddle_dialect->id()); paddle_dialect->id());
Interface *c_interface = op4->GetOperandByIndex(0) Interface *c_interface = op4->operand(0)
.source() .source()
.type() .type()
.dialect() .dialect()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册