未验证 提交 a5ba0b65 编写于 作者: J jjyaoao 提交者: GitHub

Provide opoperands(), opresults() methods for the Operation module (#55903)

* Provide opoperands(), opresults() methods for the Operation module
Signed-off-by: Njjyaoao <jjyaoao@126.com>

* Update test_ir_pybind.py

---------
Signed-off-by: Njjyaoao <jjyaoao@126.com>
上级 0a1d8c68
......@@ -110,22 +110,8 @@ void BindOperation(py::module *m) {
.def("operand", &Operation::operand)
.def("result", &Operation::result)
.def("operand_source", &Operation::operand_source)
.def("operands",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_operands(); i++) {
op_list.append(self.operand(i));
}
return op_list;
})
.def("results",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_results(); i++) {
op_list.append(self.result(i));
}
return op_list;
})
.def("operands", &Operation::operands)
.def("results", &Operation::results)
.def("attrs",
[](Operation &self) -> py::dict {
py::dict attrs_dict;
......
......@@ -264,4 +264,20 @@ void Operation::Verify() {
}
}
std::vector<OpOperand> Operation::operands() const {
std::vector<OpOperand> res;
for (uint32_t i = 0; i < num_operands(); ++i) {
res.push_back(operand(i));
}
return res;
}
std::vector<OpResult> Operation::results() const {
std::vector<OpResult> res;
for (uint32_t i = 0; i < num_results(); ++i) {
res.push_back(result(i));
}
return res;
}
} // namespace ir
......@@ -142,6 +142,10 @@ class IR_API alignas(8) Operation final {
void Verify();
std::vector<OpOperand> operands() const;
std::vector<OpResult> results() const;
private:
DISABLE_COPY_AND_ASSIGN(Operation);
Operation(const AttributeMap &attribute,
......
......@@ -164,6 +164,18 @@ class TestPybind(unittest.TestCase):
self.assertEqual(full_attr["dtype"], paddle.fluid.core.DataType.FLOAT32)
self.assertTrue(isinstance(full_attr["place"], paddle.fluid.core.Place))
def test_operands(self):
newir_program = get_ir_program()
matmul_op = newir_program.block().get_ops()[1]
operands = matmul_op.operands()
self.assertEqual(len(operands), 2)
def test_results(self):
newir_program = get_ir_program()
matmul_op = newir_program.block().get_ops()[1]
results = matmul_op.results()
self.assertEqual(len(results), 1)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册