From 4fa3e149fec20b0bfd3dac6844999e9e6d1e44e2 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 11 Jul 2023 09:12:02 +0800 Subject: [PATCH] [NewIR]Refine IrPrinter and basic Concept Interface for const Object (#55209) * [NewIR]Refine IrPrinter and basic Concept Interface for const Object --- paddle/fluid/pybind/ir.cc | 7 ++++++- paddle/ir/core/builtin_op.cc | 14 +++++++------- paddle/ir/core/builtin_op.h | 14 +++++++------- paddle/ir/core/dialect.h | 2 +- paddle/ir/core/ir_printer.cc | 28 ++++++++++++++-------------- paddle/ir/core/ir_printer.h | 22 +++++++++++----------- paddle/ir/core/operation.cc | 5 +++++ paddle/ir/core/operation.h | 3 ++- paddle/ir/core/program.h | 5 +++-- test/cpp/ir/core/ir_op_test.cc | 2 +- 10 files changed, 57 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index b2634fff1e3..6f96ceaaf41 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -37,7 +37,12 @@ namespace pybind { void BindProgram(py::module *m) { py::class_ program(*m, "Program"); program.def("parameters_num", &Program::parameters_num) - .def("block", &Program::block, return_value_policy::reference) + .def("block", + py::overload_cast<>(&Program::block), + return_value_policy::reference) + .def("block", + py::overload_cast<>(&Program::block, py::const_), + return_value_policy::reference) .def("print", [](Program &self) { std::ostringstream print_stream; self.Print(print_stream); diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index 091f0fdebf2..ec394fd4115 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -52,7 +52,7 @@ void ModuleOp::Destroy() { } } -void ModuleOp::Verify() { +void ModuleOp::Verify() const { VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; // Verify inputs: IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); @@ -79,7 +79,7 @@ void GetParameterOp::Build(Builder &builder, argument.output_types.emplace_back(type); } -void GetParameterOp::Verify() { +void GetParameterOp::Verify() const { VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp."; // Verify inputs: IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); @@ -105,7 +105,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT argument.AddAttribute(attributes_name[0], ir::StrAttribute::get(builder.ir_context(), name)); } -void SetParameterOp::Verify() { +void SetParameterOp::Verify() const { VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; // Verify inputs: IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1."); @@ -132,7 +132,7 @@ void CombineOp::Build(Builder &builder, ir::VectorType::get(builder.ir_context(), inputs_type)); } -void CombineOp::Verify() { +void CombineOp::Verify() const { // outputs.size() == 1 IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1."); @@ -162,7 +162,7 @@ void CombineOp::Verify() { } const char *SliceOp::attributes_name[attributes_num] = {"index"}; -void SliceOp::Verify() { +void SliceOp::Verify() const { // inputs.size() == 1 auto input_size = num_operands(); IR_ENFORCE( @@ -217,13 +217,13 @@ void ConstantOp::Build(Builder &builder, argument.output_types.push_back(output_type); } -void ConstantOp::Verify() { +void ConstantOp::Verify() const { IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0."); IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1."); IR_ENFORCE(attributes().count("value") > 0, "must has value attribute"); } -Attribute ConstantOp::value() { return attributes().at("value"); } +Attribute ConstantOp::value() const { return attributes().at("value"); } } // namespace ir diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index 27f264ff218..0ab058f5aac 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -30,7 +30,7 @@ class IR_API ModuleOp : public ir::Op { static const char *name() { return "builtin.module"; } static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; - void Verify(); + void Verify() const; Program *program(); Block *block(); @@ -55,7 +55,7 @@ class IR_API GetParameterOp : public ir::Op { OperationArgument &argument, // NOLINT const std::string &name, Type type); - void Verify(); + void Verify() const; }; /// @@ -72,7 +72,7 @@ class IR_API SetParameterOp : public ir::Op { OperationArgument &argument, // NOLINT OpResult parameter, const std::string &name); - void Verify(); + void Verify() const; }; /// @@ -92,7 +92,7 @@ class IR_API CombineOp : public ir::Op { OperationArgument &argument, // NOLINT const std::vector &inputs); - void Verify(); + void Verify() const; }; /// @@ -107,7 +107,7 @@ class IR_API SliceOp : public ir::Op { static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; - void Verify(); + void Verify() const; }; class IR_API ConstantLikeTrait : public OpTraitBase { @@ -132,9 +132,9 @@ class IR_API ConstantOp : public Op { Attribute value, Type output_type); - void Verify(); + void Verify() const; - Attribute value(); + Attribute value() const; }; } // namespace ir diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h index be67898dd98..c1cc54a257b 100644 --- a/paddle/ir/core/dialect.h +++ b/paddle/ir/core/dialect.h @@ -145,7 +145,7 @@ class IR_API Dialect { IR_THROW("dialect has no registered attribute printing hook"); } - virtual void PrintOperation(Operation *op, + virtual void PrintOperation(const Operation *op, IrPrinter &printer) const; // NOLINT private: diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index b72081fd0a4..7a9642bd042 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -115,7 +115,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { } } -void IrPrinter::PrintProgram(Program* program) { +void IrPrinter::PrintProgram(const Program* program) { auto top_level_op = program->module_op(); for (size_t i = 0; i < top_level_op->num_regions(); ++i) { auto& region = top_level_op->region(i); @@ -123,7 +123,7 @@ void IrPrinter::PrintProgram(Program* program) { } } -void IrPrinter::PrintOperation(Operation* op) { +void IrPrinter::PrintOperation(const Operation* op) { if (auto* dialect = op->dialect()) { dialect->PrintOperation(op, *this); return; @@ -132,7 +132,7 @@ void IrPrinter::PrintOperation(Operation* op) { PrintGeneralOperation(op); } -void IrPrinter::PrintGeneralOperation(Operation* op) { +void IrPrinter::PrintGeneralOperation(const Operation* op) { // TODO(lyk): add API to get opresults directly PrintOpResult(op); os << " ="; @@ -153,7 +153,7 @@ void IrPrinter::PrintGeneralOperation(Operation* op) { PrintOpReturnType(op); } -void IrPrinter::PrintFullOperation(Operation* op) { +void IrPrinter::PrintFullOperation(const Operation* op) { PrintOperation(op); if (op->num_regions() > 0) { os << newline; @@ -171,7 +171,7 @@ void IrPrinter::PrintRegion(const Region& region) { } } -void IrPrinter::PrintBlock(Block* block) { +void IrPrinter::PrintBlock(const Block* block) { os << "{\n"; for (auto it = block->begin(); it != block->end(); ++it) { PrintOperation(*it); @@ -180,7 +180,7 @@ void IrPrinter::PrintBlock(Block* block) { os << "}\n"; } -void IrPrinter::PrintValue(Value v) { +void IrPrinter::PrintValue(const Value& v) { if (!v) { os << "<>"; return; @@ -198,7 +198,7 @@ void IrPrinter::PrintValue(Value v) { os << new_name; } -void IrPrinter::PrintOpResult(Operation* op) { +void IrPrinter::PrintOpResult(const Operation* op) { os << " ("; auto num_op_result = op->num_results(); std::vector op_results; @@ -214,7 +214,7 @@ void IrPrinter::PrintOpResult(Operation* op) { os << ")"; } -void IrPrinter::PrintAttributeMap(Operation* op) { +void IrPrinter::PrintAttributeMap(const Operation* op) { os << " {"; PrintInterleave( @@ -230,7 +230,7 @@ void IrPrinter::PrintAttributeMap(Operation* op) { os << "}"; } -void IrPrinter::PrintOpOperands(Operation* op) { +void IrPrinter::PrintOpOperands(const Operation* op) { os << " ("; auto num_op_operands = op->num_operands(); std::vector op_operands; @@ -246,7 +246,7 @@ void IrPrinter::PrintOpOperands(Operation* op) { os << ")"; } -void IrPrinter::PrintOperandsType(Operation* op) { +void IrPrinter::PrintOperandsType(const Operation* op) { auto num_op_operands = op->num_operands(); std::vector op_operand_types; op_operand_types.reserve(num_op_operands); @@ -267,7 +267,7 @@ void IrPrinter::PrintOperandsType(Operation* op) { os << ")"; } -void IrPrinter::PrintOpReturnType(Operation* op) { +void IrPrinter::PrintOpReturnType(const Operation* op) { auto num_op_result = op->num_results(); std::vector op_result_types; op_result_types.reserve(num_op_result); @@ -286,16 +286,16 @@ void IrPrinter::PrintOpReturnType(Operation* op) { [this]() { this->os << ", "; }); } -void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const { +void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const { printer.PrintGeneralOperation(op); } -void Program::Print(std::ostream& os) { +void Program::Print(std::ostream& os) const { IrPrinter printer(os); printer.PrintProgram(this); } -void Operation::Print(std::ostream& os) { +void Operation::Print(std::ostream& os) const { IrPrinter printer(os); printer.PrintFullOperation(this); } diff --git a/paddle/ir/core/ir_printer.h b/paddle/ir/core/ir_printer.h index e60e8fa5cfd..d3f868946dd 100644 --- a/paddle/ir/core/ir_printer.h +++ b/paddle/ir/core/ir_printer.h @@ -46,29 +46,29 @@ class IR_API IrPrinter : public BasicIrPrinter { /// @brief print program /// @param program - void PrintProgram(Program* program); + void PrintProgram(const Program* program); /// @brief dispatch to custom printer function or PrintGeneralOperation - void PrintOperation(Operation* op); + void PrintOperation(const Operation* op); /// @brief print operation itself without its regions - void PrintGeneralOperation(Operation* op); + void PrintGeneralOperation(const Operation* op); /// @brief print operation and its regions - void PrintFullOperation(Operation* op); + void PrintFullOperation(const Operation* op); void PrintRegion(const Region& Region); - void PrintBlock(Block* block); + void PrintBlock(const Block* block); - void PrintValue(Value v); + void PrintValue(const Value& v); - void PrintOpResult(Operation* op); + void PrintOpResult(const Operation* op); - void PrintAttributeMap(Operation* op); + void PrintAttributeMap(const Operation* op); - void PrintOpOperands(Operation* op); + void PrintOpOperands(const Operation* op); - void PrintOperandsType(Operation* op); + void PrintOperandsType(const Operation* op); - void PrintOpReturnType(Operation* op); + void PrintOpReturnType(const Operation* op); private: size_t cur_var_number_{0}; diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 3600f9a55dd..fd0f2bf9d99 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -232,6 +232,11 @@ Region &Operation::region(unsigned index) { return regions_[index]; } +const Region &Operation::region(unsigned index) const { + assert(index < num_regions_ && "invalid region index"); + return regions_[index]; +} + void Operation::SetParent(Block *parent, const Block::iterator &position) { parent_ = parent; position_ = position; diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index 711434220ef..634f01eeb9b 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -59,8 +59,9 @@ class IR_API alignas(8) Operation final { /// Returns the region held by this operation at position 'index'. Region ®ion(unsigned index); + const Region ®ion(unsigned index) const; - void Print(std::ostream &os); + void Print(std::ostream &os) const; const AttributeMap &attributes() const { return attributes_; } diff --git a/paddle/ir/core/program.h b/paddle/ir/core/program.h index a65142b2531..0e2ecb58d91 100644 --- a/paddle/ir/core/program.h +++ b/paddle/ir/core/program.h @@ -48,11 +48,12 @@ class IR_API Program { ~Program(); size_t parameters_num() const { return parameters_.size(); } - ModuleOp module_op() { return module_; } + ModuleOp module_op() const { return module_; } - void Print(std::ostream& os); + void Print(std::ostream& os) const; Block* block() { return module_.block(); } + const Block* block() const { return module_op().block(); } Parameter* GetParameter(const std::string& name) const; void SetParameter(const std::string& name, diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index c7f9c5e8af2..0f530e41b8a 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -155,7 +155,7 @@ class TestDialect : public ir::Dialect { } static const char *name() { return "test"; } - void PrintOperation(ir::Operation *op, + void PrintOperation(const ir::Operation *op, ir::IrPrinter &printer) const override { printer.PrintOpResult(op); printer.os << " ="; -- GitLab