未验证 提交 4fa3e149 编写于 作者: A Aurelius84 提交者: GitHub

[NewIR]Refine IrPrinter and basic Concept Interface for const Object (#55209)

* [NewIR]Refine IrPrinter and basic Concept Interface for const Object
上级 b20d22df
...@@ -37,7 +37,12 @@ namespace pybind { ...@@ -37,7 +37,12 @@ namespace pybind {
void BindProgram(py::module *m) { void BindProgram(py::module *m) {
py::class_<Program> program(*m, "Program"); py::class_<Program> program(*m, "Program");
program.def("parameters_num", &Program::parameters_num) program.def("parameters_num", &Program::parameters_num)
.def("block", &Program::block, return_value_policy::reference) .def("block",
py::overload_cast<>(&Program::block),
return_value_policy::reference)
.def("block",
py::overload_cast<>(&Program::block, py::const_),
return_value_policy::reference)
.def("print", [](Program &self) { .def("print", [](Program &self) {
std::ostringstream print_stream; std::ostringstream print_stream;
self.Print(print_stream); self.Print(print_stream);
......
...@@ -52,7 +52,7 @@ void ModuleOp::Destroy() { ...@@ -52,7 +52,7 @@ void ModuleOp::Destroy() {
} }
} }
void ModuleOp::Verify() { void ModuleOp::Verify() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs: // Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
...@@ -79,7 +79,7 @@ void GetParameterOp::Build(Builder &builder, ...@@ -79,7 +79,7 @@ void GetParameterOp::Build(Builder &builder,
argument.output_types.emplace_back(type); argument.output_types.emplace_back(type);
} }
void GetParameterOp::Verify() { void GetParameterOp::Verify() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs: // Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
...@@ -105,7 +105,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT ...@@ -105,7 +105,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
argument.AddAttribute(attributes_name[0], argument.AddAttribute(attributes_name[0],
ir::StrAttribute::get(builder.ir_context(), name)); ir::StrAttribute::get(builder.ir_context(), name));
} }
void SetParameterOp::Verify() { void SetParameterOp::Verify() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs: // Verify inputs:
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1."); IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
...@@ -132,7 +132,7 @@ void CombineOp::Build(Builder &builder, ...@@ -132,7 +132,7 @@ void CombineOp::Build(Builder &builder,
ir::VectorType::get(builder.ir_context(), inputs_type)); ir::VectorType::get(builder.ir_context(), inputs_type));
} }
void CombineOp::Verify() { void CombineOp::Verify() const {
// outputs.size() == 1 // outputs.size() == 1
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1."); IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
...@@ -162,7 +162,7 @@ void CombineOp::Verify() { ...@@ -162,7 +162,7 @@ void CombineOp::Verify() {
} }
const char *SliceOp::attributes_name[attributes_num] = {"index"}; const char *SliceOp::attributes_name[attributes_num] = {"index"};
void SliceOp::Verify() { void SliceOp::Verify() const {
// inputs.size() == 1 // inputs.size() == 1
auto input_size = num_operands(); auto input_size = num_operands();
IR_ENFORCE( IR_ENFORCE(
...@@ -217,13 +217,13 @@ void ConstantOp::Build(Builder &builder, ...@@ -217,13 +217,13 @@ void ConstantOp::Build(Builder &builder,
argument.output_types.push_back(output_type); argument.output_types.push_back(output_type);
} }
void ConstantOp::Verify() { void ConstantOp::Verify() const {
IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0."); IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0.");
IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1."); IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(attributes().count("value") > 0, "must has value attribute"); IR_ENFORCE(attributes().count("value") > 0, "must has value attribute");
} }
Attribute ConstantOp::value() { return attributes().at("value"); } Attribute ConstantOp::value() const { return attributes().at("value"); }
} // namespace ir } // namespace ir
......
...@@ -30,7 +30,7 @@ class IR_API ModuleOp : public ir::Op<ModuleOp> { ...@@ -30,7 +30,7 @@ class IR_API ModuleOp : public ir::Op<ModuleOp> {
static const char *name() { return "builtin.module"; } static const char *name() { return "builtin.module"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
void Verify(); void Verify() const;
Program *program(); Program *program();
Block *block(); Block *block();
...@@ -55,7 +55,7 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> { ...@@ -55,7 +55,7 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
OperationArgument &argument, // NOLINT OperationArgument &argument, // NOLINT
const std::string &name, const std::string &name,
Type type); Type type);
void Verify(); void Verify() const;
}; };
/// ///
...@@ -72,7 +72,7 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> { ...@@ -72,7 +72,7 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
OperationArgument &argument, // NOLINT OperationArgument &argument, // NOLINT
OpResult parameter, OpResult parameter,
const std::string &name); const std::string &name);
void Verify(); void Verify() const;
}; };
/// ///
...@@ -92,7 +92,7 @@ class IR_API CombineOp : public ir::Op<CombineOp> { ...@@ -92,7 +92,7 @@ class IR_API CombineOp : public ir::Op<CombineOp> {
OperationArgument &argument, // NOLINT OperationArgument &argument, // NOLINT
const std::vector<ir::OpResult> &inputs); const std::vector<ir::OpResult> &inputs);
void Verify(); void Verify() const;
}; };
/// ///
...@@ -107,7 +107,7 @@ class IR_API SliceOp : public ir::Op<SliceOp> { ...@@ -107,7 +107,7 @@ class IR_API SliceOp : public ir::Op<SliceOp> {
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
void Verify(); void Verify() const;
}; };
class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> { class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
...@@ -132,9 +132,9 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> { ...@@ -132,9 +132,9 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
Attribute value, Attribute value,
Type output_type); Type output_type);
void Verify(); void Verify() const;
Attribute value(); Attribute value() const;
}; };
} // namespace ir } // namespace ir
......
...@@ -145,7 +145,7 @@ class IR_API Dialect { ...@@ -145,7 +145,7 @@ class IR_API Dialect {
IR_THROW("dialect has no registered attribute printing hook"); IR_THROW("dialect has no registered attribute printing hook");
} }
virtual void PrintOperation(Operation *op, virtual void PrintOperation(const Operation *op,
IrPrinter &printer) const; // NOLINT IrPrinter &printer) const; // NOLINT
private: private:
......
...@@ -115,7 +115,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { ...@@ -115,7 +115,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
} }
} }
void IrPrinter::PrintProgram(Program* program) { void IrPrinter::PrintProgram(const Program* program) {
auto top_level_op = program->module_op(); auto top_level_op = program->module_op();
for (size_t i = 0; i < top_level_op->num_regions(); ++i) { for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
auto& region = top_level_op->region(i); auto& region = top_level_op->region(i);
...@@ -123,7 +123,7 @@ void IrPrinter::PrintProgram(Program* program) { ...@@ -123,7 +123,7 @@ void IrPrinter::PrintProgram(Program* program) {
} }
} }
void IrPrinter::PrintOperation(Operation* op) { void IrPrinter::PrintOperation(const Operation* op) {
if (auto* dialect = op->dialect()) { if (auto* dialect = op->dialect()) {
dialect->PrintOperation(op, *this); dialect->PrintOperation(op, *this);
return; return;
...@@ -132,7 +132,7 @@ void IrPrinter::PrintOperation(Operation* op) { ...@@ -132,7 +132,7 @@ void IrPrinter::PrintOperation(Operation* op) {
PrintGeneralOperation(op); PrintGeneralOperation(op);
} }
void IrPrinter::PrintGeneralOperation(Operation* op) { void IrPrinter::PrintGeneralOperation(const Operation* op) {
// TODO(lyk): add API to get opresults directly // TODO(lyk): add API to get opresults directly
PrintOpResult(op); PrintOpResult(op);
os << " ="; os << " =";
...@@ -153,7 +153,7 @@ void IrPrinter::PrintGeneralOperation(Operation* op) { ...@@ -153,7 +153,7 @@ void IrPrinter::PrintGeneralOperation(Operation* op) {
PrintOpReturnType(op); PrintOpReturnType(op);
} }
void IrPrinter::PrintFullOperation(Operation* op) { void IrPrinter::PrintFullOperation(const Operation* op) {
PrintOperation(op); PrintOperation(op);
if (op->num_regions() > 0) { if (op->num_regions() > 0) {
os << newline; os << newline;
...@@ -171,7 +171,7 @@ void IrPrinter::PrintRegion(const Region& region) { ...@@ -171,7 +171,7 @@ void IrPrinter::PrintRegion(const Region& region) {
} }
} }
void IrPrinter::PrintBlock(Block* block) { void IrPrinter::PrintBlock(const Block* block) {
os << "{\n"; os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) { for (auto it = block->begin(); it != block->end(); ++it) {
PrintOperation(*it); PrintOperation(*it);
...@@ -180,7 +180,7 @@ void IrPrinter::PrintBlock(Block* block) { ...@@ -180,7 +180,7 @@ void IrPrinter::PrintBlock(Block* block) {
os << "}\n"; os << "}\n";
} }
void IrPrinter::PrintValue(Value v) { void IrPrinter::PrintValue(const Value& v) {
if (!v) { if (!v) {
os << "<<NULL VALUE>>"; os << "<<NULL VALUE>>";
return; return;
...@@ -198,7 +198,7 @@ void IrPrinter::PrintValue(Value v) { ...@@ -198,7 +198,7 @@ void IrPrinter::PrintValue(Value v) {
os << new_name; os << new_name;
} }
void IrPrinter::PrintOpResult(Operation* op) { void IrPrinter::PrintOpResult(const Operation* op) {
os << " ("; os << " (";
auto num_op_result = op->num_results(); auto num_op_result = op->num_results();
std::vector<OpResult> op_results; std::vector<OpResult> op_results;
...@@ -214,7 +214,7 @@ void IrPrinter::PrintOpResult(Operation* op) { ...@@ -214,7 +214,7 @@ void IrPrinter::PrintOpResult(Operation* op) {
os << ")"; os << ")";
} }
void IrPrinter::PrintAttributeMap(Operation* op) { void IrPrinter::PrintAttributeMap(const Operation* op) {
os << " {"; os << " {";
PrintInterleave( PrintInterleave(
...@@ -230,7 +230,7 @@ void IrPrinter::PrintAttributeMap(Operation* op) { ...@@ -230,7 +230,7 @@ void IrPrinter::PrintAttributeMap(Operation* op) {
os << "}"; os << "}";
} }
void IrPrinter::PrintOpOperands(Operation* op) { void IrPrinter::PrintOpOperands(const Operation* op) {
os << " ("; os << " (";
auto num_op_operands = op->num_operands(); auto num_op_operands = op->num_operands();
std::vector<Value> op_operands; std::vector<Value> op_operands;
...@@ -246,7 +246,7 @@ void IrPrinter::PrintOpOperands(Operation* op) { ...@@ -246,7 +246,7 @@ void IrPrinter::PrintOpOperands(Operation* op) {
os << ")"; os << ")";
} }
void IrPrinter::PrintOperandsType(Operation* op) { void IrPrinter::PrintOperandsType(const Operation* op) {
auto num_op_operands = op->num_operands(); auto num_op_operands = op->num_operands();
std::vector<Type> op_operand_types; std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands); op_operand_types.reserve(num_op_operands);
...@@ -267,7 +267,7 @@ void IrPrinter::PrintOperandsType(Operation* op) { ...@@ -267,7 +267,7 @@ void IrPrinter::PrintOperandsType(Operation* op) {
os << ")"; os << ")";
} }
void IrPrinter::PrintOpReturnType(Operation* op) { void IrPrinter::PrintOpReturnType(const Operation* op) {
auto num_op_result = op->num_results(); auto num_op_result = op->num_results();
std::vector<Type> op_result_types; std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result); op_result_types.reserve(num_op_result);
...@@ -286,16 +286,16 @@ void IrPrinter::PrintOpReturnType(Operation* op) { ...@@ -286,16 +286,16 @@ void IrPrinter::PrintOpReturnType(Operation* op) {
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
} }
void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const { void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const {
printer.PrintGeneralOperation(op); printer.PrintGeneralOperation(op);
} }
void Program::Print(std::ostream& os) { void Program::Print(std::ostream& os) const {
IrPrinter printer(os); IrPrinter printer(os);
printer.PrintProgram(this); printer.PrintProgram(this);
} }
void Operation::Print(std::ostream& os) { void Operation::Print(std::ostream& os) const {
IrPrinter printer(os); IrPrinter printer(os);
printer.PrintFullOperation(this); printer.PrintFullOperation(this);
} }
......
...@@ -46,29 +46,29 @@ class IR_API IrPrinter : public BasicIrPrinter { ...@@ -46,29 +46,29 @@ class IR_API IrPrinter : public BasicIrPrinter {
/// @brief print program /// @brief print program
/// @param program /// @param program
void PrintProgram(Program* program); void PrintProgram(const Program* program);
/// @brief dispatch to custom printer function or PrintGeneralOperation /// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(Operation* op); void PrintOperation(const Operation* op);
/// @brief print operation itself without its regions /// @brief print operation itself without its regions
void PrintGeneralOperation(Operation* op); void PrintGeneralOperation(const Operation* op);
/// @brief print operation and its regions /// @brief print operation and its regions
void PrintFullOperation(Operation* op); void PrintFullOperation(const Operation* op);
void PrintRegion(const Region& Region); void PrintRegion(const Region& Region);
void PrintBlock(Block* block); void PrintBlock(const Block* block);
void PrintValue(Value v); void PrintValue(const Value& v);
void PrintOpResult(Operation* op); void PrintOpResult(const Operation* op);
void PrintAttributeMap(Operation* op); void PrintAttributeMap(const Operation* op);
void PrintOpOperands(Operation* op); void PrintOpOperands(const Operation* op);
void PrintOperandsType(Operation* op); void PrintOperandsType(const Operation* op);
void PrintOpReturnType(Operation* op); void PrintOpReturnType(const Operation* op);
private: private:
size_t cur_var_number_{0}; size_t cur_var_number_{0};
......
...@@ -232,6 +232,11 @@ Region &Operation::region(unsigned index) { ...@@ -232,6 +232,11 @@ Region &Operation::region(unsigned index) {
return regions_[index]; return regions_[index];
} }
const Region &Operation::region(unsigned index) const {
assert(index < num_regions_ && "invalid region index");
return regions_[index];
}
void Operation::SetParent(Block *parent, const Block::iterator &position) { void Operation::SetParent(Block *parent, const Block::iterator &position) {
parent_ = parent; parent_ = parent;
position_ = position; position_ = position;
......
...@@ -59,8 +59,9 @@ class IR_API alignas(8) Operation final { ...@@ -59,8 +59,9 @@ class IR_API alignas(8) Operation final {
/// Returns the region held by this operation at position 'index'. /// Returns the region held by this operation at position 'index'.
Region &region(unsigned index); Region &region(unsigned index);
const Region &region(unsigned index) const;
void Print(std::ostream &os); void Print(std::ostream &os) const;
const AttributeMap &attributes() const { return attributes_; } const AttributeMap &attributes() const { return attributes_; }
......
...@@ -48,11 +48,12 @@ class IR_API Program { ...@@ -48,11 +48,12 @@ class IR_API Program {
~Program(); ~Program();
size_t parameters_num() const { return parameters_.size(); } size_t parameters_num() const { return parameters_.size(); }
ModuleOp module_op() { return module_; } ModuleOp module_op() const { return module_; }
void Print(std::ostream& os); void Print(std::ostream& os) const;
Block* block() { return module_.block(); } Block* block() { return module_.block(); }
const Block* block() const { return module_op().block(); }
Parameter* GetParameter(const std::string& name) const; Parameter* GetParameter(const std::string& name) const;
void SetParameter(const std::string& name, void SetParameter(const std::string& name,
......
...@@ -155,7 +155,7 @@ class TestDialect : public ir::Dialect { ...@@ -155,7 +155,7 @@ class TestDialect : public ir::Dialect {
} }
static const char *name() { return "test"; } static const char *name() { return "test"; }
void PrintOperation(ir::Operation *op, void PrintOperation(const ir::Operation *op,
ir::IrPrinter &printer) const override { ir::IrPrinter &printer) const override {
printer.PrintOpResult(op); printer.PrintOpResult(op);
printer.os << " ="; printer.os << " =";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册