diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index ead5e6889d9632f826218a99c664850f5120e663..e895275c6aef711d1c95000dea1594bdc89a0df1 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -94,7 +94,12 @@ void BindBlock(py::module *m) { void BindOperation(py::module *m) { py::class_ op(*m, "Operation"); op.def("name", &Operation::name) - .def("get_parent", &Operation::GetParent, return_value_policy::reference) + .def("get_parent", + py::overload_cast<>(&Operation::GetParent), + return_value_policy::reference) + .def("get_parent", + py::overload_cast<>(&Operation::GetParent, py::const_), + return_value_policy::reference) .def("num_results", &Operation::num_results) .def("result", &Operation::result) .def("operands", diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index ffecb9922ad055369d430b593c641519b21e9dc2..7e91c4c0700d441688a71d8b3aa2d31dbcc8afb9 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -210,7 +210,7 @@ Attribute Operation::attribute(const std::string &key) const { return attributes_.at(key); } -Region *Operation::GetParentRegion() const { +Region *Operation::GetParentRegion() { return parent_ ? parent_->GetParent() : nullptr; } @@ -218,8 +218,8 @@ Operation *Operation::GetParentOp() const { return parent_ ? parent_->GetParentOp() : nullptr; } -Program *Operation::GetParentProgram() { - Operation *op = this; +const Program *Operation::GetParentProgram() const { + Operation *op = const_cast(this); while (Operation *parent_op = op->GetParentOp()) { op = parent_op; } diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index 8518853160c8fc094ceae93b981579e7ab8181dc..d47a99486c7e07bd964816b19481ca39c259dff5 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -109,13 +109,23 @@ class IR_API alignas(8) Operation final { return info_.HasInterface(); } - Block *GetParent() const { return parent_; } + const Block *GetParent() const { return parent_; } - Region *GetParentRegion() const; + Block *GetParent() { + return const_cast( + const_cast(this)->GetParent()); + } + + Region *GetParentRegion(); Operation *GetParentOp() const; - Program *GetParentProgram(); + const Program *GetParentProgram() const; + + Program *GetParentProgram() { + return const_cast( + const_cast(this)->GetParentProgram()); + } operator Block::iterator() { return position_; }