From ef29468e683cfd2e8aaf1d3841c0cb7c4bc41274 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Wed, 2 Aug 2023 09:09:46 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90new=20ir=E3=80=91add=20ir=20pybind=20a?= =?UTF-8?q?pi=20=20(#55745)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add ir core * add test * modify name * merge * add test for __eq__ * shield test for __eq__ * --amend * Update new_ir_compiler.cc --- paddle/cinn/hlir/framework/new_ir_compiler.cc | 6 +-- .../instruction/phi_kernel_instruction.cc | 4 +- .../op_generator/op_member_func_gen.py | 2 +- .../ir/dialect/op_generator/op_verify_gen.py | 10 ++--- .../ir/phi_kernel_adaptor/phi_kernel_util.cc | 16 ++++---- .../ir/phi_kernel_adaptor/phi_kernel_util.h | 4 +- .../ir/transforms/constant_folding_pass.cc | 9 +++-- .../ir/transforms/pd_op_to_kernel_pass.cc | 8 ++-- .../transforms/transform_general_functions.cc | 2 +- paddle/fluid/pybind/ir.cc | 26 +++++++++--- paddle/ir/core/ir_printer.cc | 4 +- paddle/ir/core/op_base.h | 4 +- paddle/ir/core/operation.cc | 10 ++--- paddle/ir/core/operation.h | 4 +- .../pattern_rewrite/pattern_rewrite_driver.cc | 2 +- test/cpp/ir/core/ir_program_test.cc | 4 +- test/cpp/ir/core/ir_value_test.cc | 8 ++-- .../pattern_rewrite/pattern_rewrite_test.cc | 40 +++++++++++-------- test/ir/new_ir/test_ir_pybind.py | 23 +++++++++-- 19 files changed, 114 insertions(+), 72 deletions(-) diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.cc b/paddle/cinn/hlir/framework/new_ir_compiler.cc index 9c3806b9a35..34786807c95 100644 --- a/paddle/cinn/hlir/framework/new_ir_compiler.cc +++ b/paddle/cinn/hlir/framework/new_ir_compiler.cc @@ -79,7 +79,7 @@ std::vector NewIRCompiler::GetOpFunc(const ::ir::Operation& op, VLOG(4) << "GetOpFunc for op: " << op_name; // step 1: Deal with Oprands for (int i = 0; i < op.num_operands(); ++i) { - auto in_value = op.operand(i); + auto in_value = op.operand_source(i); // TODO(Aurelius84): For now, use addr as name but it's not wise. std::string input_id = CompatibleInfo::kInputPrefix + std::to_string(std::hash<::ir::Value>()(in_value)); @@ -215,7 +215,7 @@ std::vector NewIRCompiler::OpGetInputNames( std::vector names; std::unordered_set repeat; for (int i = 0; i < op.num_operands(); ++i) { - auto value = op.operand(i); + auto value = op.operand_source(i); std::string name = CompatibleInfo::kInputPrefix + std::to_string(std::hash<::ir::Value>()(value)); if (repeat.count(name)) { @@ -264,7 +264,7 @@ std::shared_ptr BuildScope(const Target& target, for (auto it = program.block()->begin(); it != program.block()->end(); ++it) { for (auto i = 0; i < (*it)->num_operands(); ++i) { - auto in_value = (*it)->operand(i); + auto in_value = (*it)->operand_source(i); create_var(CompatibleInfo::kInputPrefix, in_value); } diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index f1e2f894afc..4121cded089 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -266,7 +266,7 @@ PhiKernelInstruction::PhiKernelInstruction( auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); std::unordered_set<::ir::Value> no_need_buffer_values; for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { - no_need_buffer_values.insert(op->operand(no_need_buffer_ids[id])); + no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id])); } SetNoNeedBuffer(no_need_buffer_values); VLOG(6) << "finish process no need buffer"; @@ -302,7 +302,7 @@ void PhiKernelInstruction::InitInputsOutputsIds( variable_2_var_name) { std::unordered_map> inputs; for (size_t i = 0; i < op->num_operands(); i++) { - ir::Value value = op->operand(i); + ir::Value value = op->operand_source(i); if (value) { PADDLE_ENFORCE_NE( value_2_var_name.find(value), diff --git a/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py b/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py index 30d35a5f6e7..9bc2c75ccf8 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py @@ -14,7 +14,7 @@ # generator op member function -OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_index}); }} +OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand_source({input_index}); }} """ OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }} """ diff --git a/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py b/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py index f5f0711534f..917728f2c8b 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py @@ -40,26 +40,26 @@ void {op_name}::Verify() {{}} """ INPUT_TYPE_CHECK_TEMPLATE = """ - PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(), + PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));""" INPUT_VECTORTYPE_CHECK_TEMPLATE = """ - if (auto vec_type = (*this)->operand({index}).type().dyn_cast()) {{ + if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); }} }} else {{ - PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(), + PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); }}""" INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ - if (auto val = (*this)->op_operand({index})) {{ + if (auto val = (*this)->operand({index})) {{ PADDLE_ENFORCE(val.type().isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); }}""" INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ - if (auto val = (*this)->op_operand({index})) {{ + if (auto val = (*this)->operand({index})) {{ if (auto vec_type = val.type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc index 5b621b9da30..7c2a4335762 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc @@ -140,7 +140,7 @@ void CheckInputVars( size_t input_num = op->num_operands(); if (input_num > 0) { for (size_t i = 0; i < input_num; ++i) { - auto value = op->operand(i); + auto value = op->operand_source(i); if (value) { PADDLE_ENFORCE_NE( value_2_var_name.find(value), @@ -298,7 +298,7 @@ void HandleForSpecialOp( tensor_array->clear(); size_t input_num = op->num_operands(); for (size_t i = 0; i < input_num; ++i) { - auto value = op->operand(i); + auto value = op->operand_source(i); PADDLE_ENFORCE_EQ( value_2_var_name->count(value), true, @@ -315,7 +315,7 @@ void HandleForSpecialOp( .dyn_cast() .AsString(); - auto value = op->operand(0); + auto value = op->operand_source(0); // change opreand name to param_name auto orig_name = value_2_var_name->at(value); @@ -336,7 +336,7 @@ void HandleForSpecialOp( auto var_name = op->attributes().at("name").dyn_cast().AsString(); - auto value = op->operand(0); + auto value = op->operand_source(0); // change opreand name to param_name auto orig_name = value_2_var_name->at(value); @@ -372,7 +372,7 @@ void HandleForSpecialOp( if (op_name == "builtin.slice") { VLOG(6) << "Handle for builtin.slice"; auto out_value = op->result(0); - auto in_value = op->operand(0); + auto in_value = op->operand_source(0); PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value), true, phi::errors::PreconditionNotMet( @@ -426,7 +426,7 @@ void HandleForInplaceOp( if (yaml_parser.HasInplace(value_name)) { std::string inplace_name = yaml_parser.InplaceName(value_name); ir::Value inplace_value = - op->operand(yaml_parser.InputName2Id().at(inplace_name)); + op->operand_source(yaml_parser.InputName2Id().at(inplace_name)); std::string var_name = value_2_var_name->at(inplace_value); VLOG(4) << "inplace: " << value_name << " -> " << inplace_name << " (var: " << var_name << ")"; @@ -547,7 +547,7 @@ void BuildRuntimeContext( true, phi::errors::NotFound("param [%s] MUST in name2id map", name)); auto index = op_yaml_info.InputName2Id().at(name); - ir::Value ptr = op->operand(index); + ir::Value ptr = op->operand_source(index); auto in_var_name = name_map.at(ptr); VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name; @@ -603,7 +603,7 @@ std::shared_ptr BuildOperatorBase( true, phi::errors::NotFound("param [%s] MUST in name2id map", name)); auto index = op_yaml_info.InputName2Id().at(name); - ir::Value ptr = op->operand(index); + ir::Value ptr = op->operand_source(index); auto in_var_name = name_map.at(ptr); diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h index 91997dd341c..2b7e61ecce7 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h @@ -92,7 +92,7 @@ void BuildPhiContext(ir::Operation* op, true, phi::errors::NotFound("param [%s] MUST in name2id map", t)); auto index = op_yaml_info.InputName2Id().at(t); - ir::Value ptr = op->operand(index); + ir::Value ptr = op->operand_source(index); if (!ptr) { phi::DenseTensor* ptr = nullptr; OutType in_ptr(ptr); @@ -128,7 +128,7 @@ void BuildPhiContext(ir::Operation* op, for (auto& t : vec_kernel_fn_attr_params) { if (name2id.count(t)) { // tensor attribute, get information from input - ir::Value ptr = op->operand(name2id.at(t)); + ir::Value ptr = op->operand_source(name2id.at(t)); auto in_var_name = name_map.at(ptr); diff --git a/paddle/fluid/ir/transforms/constant_folding_pass.cc b/paddle/fluid/ir/transforms/constant_folding_pass.cc index 3fcdee6748b..3f13621b61d 100644 --- a/paddle/fluid/ir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/ir/transforms/constant_folding_pass.cc @@ -114,12 +114,13 @@ class ConstantFoldingPattern : public ir::RewritePattern { std::vector op_inputs; for (uint32_t i = 0; i < op->num_operands(); i++) { PADDLE_ENFORCE_EQ( - op->operand(i).type().isa(), + op->operand_source(i).type().isa(), true, phi::errors::InvalidArgument( "Op's input must be a dense tensor type.")); - auto [param_name, param] = ir::GetParameterFromValue(op->operand(i)); + auto [param_name, param] = + ir::GetParameterFromValue(op->operand_source(i)); program->SetParameter(param_name, std::make_unique(*param)); @@ -128,8 +129,8 @@ class ConstantFoldingPattern : public ir::RewritePattern { param_var, phi::errors::InvalidArgument("Parameter var not in scope.")); - auto get_parameter_op = - builder.Build(param_name, op->operand(i).type()); + auto get_parameter_op = builder.Build( + param_name, op->operand_source(i).type()); op_inputs.push_back(get_parameter_op->result(0)); } diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc index 1d479884d85..d2f7a58ecb3 100644 --- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc @@ -97,7 +97,7 @@ phi::KernelKey GetKernelKey( } else if (input_map.count(slot_name)) { // parse from input int in_index = input_map.at(slot_name); - auto type = map_value_pair.at(op->operand(in_index)).type(); + auto type = map_value_pair.at(op->operand_source(in_index)).type(); if (type.isa()) { kernel_data_type = TransToPhiDataType( @@ -151,7 +151,7 @@ phi::KernelKey GetKernelKey( if (op->name() == "pd.uniform") { // try to process uniform, use shape to determin backend // TODO(phlrain): shuold support other initilize op - auto define_op = op->operand(0).GetDefiningOp(); + auto define_op = op->operand_source(0).GetDefiningOp(); if (define_op->name() == "pd.full_int_array") { auto shape = define_op->attributes() .at("value") @@ -183,7 +183,7 @@ phi::KernelKey GetKernelKey( if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) { continue; } - auto input_tmp = op->operand(i); + auto input_tmp = op->operand_source(i); // NOTE: if not input_tmp, it's an optional input if (!input_tmp) { continue; @@ -341,7 +341,7 @@ std::unique_ptr PdOpLowerToKernelPass(ir::Program* prog, if ((*it)->num_operands() > 0) { for (size_t i = 0; i < (*it)->num_operands(); ++i) { - auto cur_in = (*it)->operand(i); + auto cur_in = (*it)->operand_source(i); if (!cur_in) { vec_inputs.push_back(ir::OpResult()); continue; diff --git a/paddle/fluid/ir/transforms/transform_general_functions.cc b/paddle/fluid/ir/transforms/transform_general_functions.cc index 966e4035fc3..2937e55065e 100644 --- a/paddle/fluid/ir/transforms/transform_general_functions.cc +++ b/paddle/fluid/ir/transforms/transform_general_functions.cc @@ -64,7 +64,7 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) { index < op->num_operands(), true, phi::errors::InvalidArgument("Intput operand's index must be valid.")); - return op->operand(index).GetDefiningOp(); + return op->operand_source(index).GetDefiningOp(); } Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) { diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index e895275c6ae..8f805bf06ec 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -77,6 +77,8 @@ void BindProgram(py::module *m) { void BindBlock(py::module *m) { py::class_ block(*m, "Block"); block.def("front", &Block::front, return_value_policy::reference) + .def("get_parent_program", + [](Block &self) { return self.GetParentOp()->GetParentProgram(); }) .def("get_ops", [](Block &self) -> py::list { py::list op_list; @@ -94,19 +96,22 @@ void BindBlock(py::module *m) { void BindOperation(py::module *m) { py::class_ op(*m, "Operation"); op.def("name", &Operation::name) - .def("get_parent", + .def("get_parent_block", py::overload_cast<>(&Operation::GetParent), return_value_policy::reference) - .def("get_parent", + .def("get_parent_block", py::overload_cast<>(&Operation::GetParent, py::const_), return_value_policy::reference) + .def("num_operands", &Operation::num_operands) .def("num_results", &Operation::num_results) + .def("operand", &Operation::operand) .def("result", &Operation::result) + .def("operand_source", &Operation::operand_source) .def("operands", [](Operation &self) -> py::list { py::list op_list; for (uint32_t i = 0; i < self.num_operands(); i++) { - op_list.append(self.op_operand(i)); + op_list.append(self.operand(i)); } return op_list; }) @@ -118,6 +123,14 @@ void BindOperation(py::module *m) { } return op_list; }) + .def("operands_source", + [](Operation &self) -> py::list { + py::list op_list; + for (uint32_t i = 0; i < self.num_operands(); i++) { + op_list.append(self.operand_source(i)); + } + return op_list; + }) .def("get_input_names", [](Operation &self) -> py::list { py::list op_list; @@ -159,8 +172,11 @@ void BindOperation(py::module *m) { void BindValue(py::module *m) { py::class_ value(*m, "Value"); - value.def( - "get_defining_op", &Value::GetDefiningOp, return_value_policy::reference); + value + .def("get_defining_op", + &Value::GetDefiningOp, + return_value_policy::reference) + .def("__eq__", &Value::operator==); } void BindOpOperand(py::module *m) { diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 88903645655..545cb63fa03 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -239,7 +239,7 @@ void IrPrinter::PrintOpOperands(const Operation* op) { std::vector op_operands; op_operands.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { - op_operands.push_back(op->operand(idx)); + op_operands.push_back(op->operand_source(idx)); } PrintInterleave( op_operands.begin(), @@ -254,7 +254,7 @@ void IrPrinter::PrintOperandsType(const Operation* op) { std::vector op_operand_types; op_operand_types.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { - auto op_operand = op->op_operand(idx); + auto op_operand = op->operand(idx); if (op_operand) { op_operand_types.push_back(op_operand.type()); } else { diff --git a/paddle/ir/core/op_base.h b/paddle/ir/core/op_base.h index f1a984a85f7..bfa7705e6ff 100644 --- a/paddle/ir/core/op_base.h +++ b/paddle/ir/core/op_base.h @@ -88,7 +88,9 @@ class IR_API OpBase { const AttributeMap &attributes() const { return operation()->attributes(); } - Value operand(uint32_t index) const { return operation()->operand(index); } + Value operand_source(uint32_t index) const { + return operation()->operand_source(index); + } OpResult result(uint32_t index) const { return operation()->result(index); } diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 7e91c4c0700..7968d0c7711 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -40,7 +40,7 @@ Operation *Operation::Create(OperationArgument &&argument) { // Allocate the required memory based on the size and number of inputs, outputs, // and operators, and construct it in the order of: OpOutlineResult, -// OpInlineResult, Operation, Operand. +// OpInlineResult, Operation, operand. Operation *Operation::Create(const std::vector &inputs, const AttributeMap &attributes, const std::vector &output_types, @@ -132,7 +132,7 @@ void Operation::Destroy() { // 4. Deconstruct OpOperand. for (size_t idx = 0; idx < num_operands_; idx++) { - op_operand(idx).impl()->~OpOperandImpl(); + operand(idx).impl()->~OpOperandImpl(); } // 5. Free memory. uint32_t max_inline_result_num = @@ -186,7 +186,7 @@ ir::OpResult Operation::result(uint32_t index) const { } } -OpOperand Operation::op_operand(uint32_t index) const { +OpOperand Operation::operand(uint32_t index) const { if (index >= num_operands_) { IR_THROW("index exceeds OP input range."); } @@ -195,8 +195,8 @@ OpOperand Operation::op_operand(uint32_t index) const { return OpOperand(reinterpret_cast(ptr)); } -Value Operation::operand(uint32_t index) const { - OpOperand val = op_operand(index); +Value Operation::operand_source(uint32_t index) const { + OpOperand val = operand(index); return val ? val.source() : Value(); } diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index d47a99486c7..dfc056f9b1c 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -55,9 +55,9 @@ class IR_API alignas(8) Operation final { OpResult result(uint32_t index) const; - OpOperand op_operand(uint32_t index) const; + OpOperand operand(uint32_t index) const; - Value operand(uint32_t index) const; + Value operand_source(uint32_t index) const; /// Returns the region held by this operation at position 'index'. Region ®ion(unsigned index); diff --git a/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc index 8ee6c8886f6..4185cdade1c 100644 --- a/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc @@ -131,7 +131,7 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter { void NotifyOperationRemoved(ir::Operation* op) override { for (uint32_t i = 0; i < op->num_operands(); ++i) { - AddOperandToWorklist(op->operand(i)); + AddOperandToWorklist(op->operand_source(i)); } for (uint32_t i = 0; i < op->num_regions(); ++i) { auto& region = op->region(i); diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 2af40feaeec..ee5afc66863 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -174,9 +174,9 @@ TEST(program_test, program) { // (8) Def SetParameterOp(c, "c") auto op4 = builder.Build(op3->result(0), "c"); - EXPECT_EQ(op4->op_operand(0).type().dialect().id(), paddle_dialect->id()); + EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id()); Interface *c_interface = - op4->op_operand(0).type().dialect().GetRegisteredInterface(); + op4->operand(0).type().dialect().GetRegisteredInterface(); // ir::Parameter *parameter_c = // c_interface->VariableToParameter(variable_c.get()); std::unique_ptr parameter_c = diff --git a/test/cpp/ir/core/ir_value_test.cc b/test/cpp/ir/core/ir_value_test.cc index 3ad5c501464..3f90e3a4fd6 100644 --- a/test/cpp/ir/core/ir_value_test.cc +++ b/test/cpp/ir/core/ir_value_test.cc @@ -91,10 +91,10 @@ TEST(value_test, value_test) { // Test 2: op1_first_output -> op4_first_input ir::OpResult op1_first_output = op1->result(0); - ir::OpOperand op4_first_input = op4->op_operand(0); + ir::OpOperand op4_first_input = op4->operand(0); EXPECT_EQ(op1_first_output.first_use(), op4_first_input); - ir::OpOperand op3_first_input = op3->op_operand(0); + ir::OpOperand op3_first_input = op3->operand(0); EXPECT_EQ(op4_first_input.next_use(), op3_first_input); EXPECT_EQ(op3_first_input.next_use(), nullptr); @@ -110,11 +110,11 @@ TEST(value_test, value_test) { // a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c); // c.ReplaceUsesWithIf(b, [](ir::OpOperand) { return true; }); - EXPECT_EQ(op4->operand(1), b); + EXPECT_EQ(op4->operand_source(1), b); EXPECT_TRUE(c.use_empty()); b.ReplaceAllUsesWith(a); - EXPECT_EQ(op4->operand(1), a); + EXPECT_EQ(op4->operand_source(1), a); EXPECT_TRUE(b.use_empty()); // destroy diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index ebb6144753e..9495a605523 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -386,10 +386,10 @@ class Conv2dFusionOpTest : public ir::Opoperand(0).type().isa(), - phi::errors::PreconditionNotMet( - "Type validation failed for the 0th input.")); - PADDLE_ENFORCE( - (*this)->operand(1).type().isa(), - phi::errors::PreconditionNotMet( - "Type validation failed for the 1th input.")); - PADDLE_ENFORCE( - (*this)->operand(2).type().isa(), - phi::errors::PreconditionNotMet( - "Type validation failed for the 2th input.")); - if (auto val = (*this)->op_operand(3)) { + PADDLE_ENFORCE((*this) + ->operand_source(0) + .type() + .isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th input.")); + PADDLE_ENFORCE((*this) + ->operand_source(1) + .type() + .isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 1th input.")); + PADDLE_ENFORCE((*this) + ->operand_source(2) + .type() + .isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 2th input.")); + if (auto val = (*this)->operand(3)) { PADDLE_ENFORCE(val.type().isa(), phi::errors::PreconditionNotMet( "Type validation failed for the 3th input.")); diff --git a/test/ir/new_ir/test_ir_pybind.py b/test/ir/new_ir/test_ir_pybind.py index 90bf976ac3f..4d32b39ce21 100644 --- a/test/ir/new_ir/test_ir_pybind.py +++ b/test/ir/new_ir/test_ir_pybind.py @@ -30,8 +30,8 @@ def get_ir_program(): x_s = paddle.static.data('x', [4, 4], x.dtype) x_s.stop_gradient = False y_s = paddle.matmul(x_s, x_s) - y_s = paddle.add(x_s, y_s) - y_s = paddle.tanh(y_s) + z_s = paddle.add(y_s, y_s) + k_s = paddle.tanh(z_s) newir_program = ir.translate_to_new_ir(main_program.desc) return newir_program @@ -41,6 +41,11 @@ class TestPybind(unittest.TestCase): newir_program = get_ir_program() print(newir_program) + block = newir_program.block() + program = block.get_parent_program() + + self.assertEqual(newir_program, program) + def test_block(self): newir_program = get_ir_program() block = newir_program.block() @@ -57,7 +62,7 @@ class TestPybind(unittest.TestCase): matmul_op = newir_program.block().get_ops()[1] add_op = newir_program.block().get_ops()[2] tanh_op = newir_program.block().get_ops()[3] - parent_block = tanh_op.get_parent() + parent_block = tanh_op.get_parent_block() parent_ops_num = len(parent_block.get_ops()) self.assertEqual(parent_ops_num, 4) self.assertEqual(tanh_op.num_results(), 1) @@ -79,6 +84,13 @@ class TestPybind(unittest.TestCase): matmul_op.result(0).set_stop_gradient(True) self.assertEqual(matmul_op.result(0).get_stop_gradient(), True) + result_set = set() + for opresult in matmul_op.results(): + result_set.add(opresult) + + # self.assertTrue(add_op.operands()[0].source() in result_set) + # self.assertEqual(add_op.operands_source()[0] , matmul_op.results()[0],) + self.assertEqual( tanh_op.operands()[0].source().get_defining_op().name(), "pd.add" ) @@ -87,6 +99,11 @@ class TestPybind(unittest.TestCase): self.assertEqual( tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul" ) + + self.assertEqual( + tanh_op.operands()[0].source().get_defining_op(), + tanh_op.operands_source()[0].get_defining_op(), + ) self.assertEqual(add_op.result(0).use_empty(), True) def test_type(self): -- GitLab