diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 64e7b5c818abf2820827b2d079ffc1d6a6f002cc..b55906b67ed7ff1ac5748976e2ff7b194eb5be7e 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -110,22 +110,8 @@ void BindOperation(py::module *m) { .def("operand", &Operation::operand) .def("result", &Operation::result) .def("operand_source", &Operation::operand_source) - .def("operands", - [](Operation &self) -> py::list { - py::list op_list; - for (uint32_t i = 0; i < self.num_operands(); i++) { - op_list.append(self.operand(i)); - } - return op_list; - }) - .def("results", - [](Operation &self) -> py::list { - py::list op_list; - for (uint32_t i = 0; i < self.num_results(); i++) { - op_list.append(self.result(i)); - } - return op_list; - }) + .def("operands", &Operation::operands) + .def("results", &Operation::results) .def("attrs", [](Operation &self) -> py::dict { py::dict attrs_dict; diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 7968d0c7711078ceec76dd7c18ef3d9250666530..5cdc154a0a5da2d771374eceae8fc35179fd6267 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -264,4 +264,20 @@ void Operation::Verify() { } } +std::vector Operation::operands() const { + std::vector res; + for (uint32_t i = 0; i < num_operands(); ++i) { + res.push_back(operand(i)); + } + return res; +} + +std::vector Operation::results() const { + std::vector res; + for (uint32_t i = 0; i < num_results(); ++i) { + res.push_back(result(i)); + } + return res; +} + } // namespace ir diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index dfc056f9b1cfac9745f5eb4e5e57f98866d7560d..a223c57abdd08f1c91b84fe9ccc307927f2f6022 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -142,6 +142,10 @@ class IR_API alignas(8) Operation final { void Verify(); + std::vector operands() const; + + std::vector results() const; + private: DISABLE_COPY_AND_ASSIGN(Operation); Operation(const AttributeMap &attribute, diff --git a/test/ir/new_ir/test_ir_pybind.py b/test/ir/new_ir/test_ir_pybind.py index 62a4de6289f1dc32b14764a63765a2e539c8503d..feb89cbc88d4220fc168d98f9c47635b1e34815b 100644 --- a/test/ir/new_ir/test_ir_pybind.py +++ b/test/ir/new_ir/test_ir_pybind.py @@ -164,6 +164,18 @@ class TestPybind(unittest.TestCase): self.assertEqual(full_attr["dtype"], paddle.fluid.core.DataType.FLOAT32) self.assertTrue(isinstance(full_attr["place"], paddle.fluid.core.Place)) + def test_operands(self): + newir_program = get_ir_program() + matmul_op = newir_program.block().get_ops()[1] + operands = matmul_op.operands() + self.assertEqual(len(operands), 2) + + def test_results(self): + newir_program = get_ir_program() + matmul_op = newir_program.block().get_ops()[1] + results = matmul_op.results() + self.assertEqual(len(results), 1) + if __name__ == "__main__": unittest.main()