未验证 提交 4fa3e149 编写于 作者: A Aurelius84 提交者: GitHub

[NewIR]Refine IrPrinter and basic Concept Interface for const Object (#55209)

* [NewIR]Refine IrPrinter and basic Concept Interface for const Object
上级 b20d22df
......@@ -37,7 +37,12 @@ namespace pybind {
void BindProgram(py::module *m) {
py::class_<Program> program(*m, "Program");
program.def("parameters_num", &Program::parameters_num)
.def("block", &Program::block, return_value_policy::reference)
.def("block",
py::overload_cast<>(&Program::block),
return_value_policy::reference)
.def("block",
py::overload_cast<>(&Program::block, py::const_),
return_value_policy::reference)
.def("print", [](Program &self) {
std::ostringstream print_stream;
self.Print(print_stream);
......
......@@ -52,7 +52,7 @@ void ModuleOp::Destroy() {
}
}
void ModuleOp::Verify() {
void ModuleOp::Verify() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
......@@ -79,7 +79,7 @@ void GetParameterOp::Build(Builder &builder,
argument.output_types.emplace_back(type);
}
void GetParameterOp::Verify() {
void GetParameterOp::Verify() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
......@@ -105,7 +105,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
argument.AddAttribute(attributes_name[0],
ir::StrAttribute::get(builder.ir_context(), name));
}
void SetParameterOp::Verify() {
void SetParameterOp::Verify() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
......@@ -132,7 +132,7 @@ void CombineOp::Build(Builder &builder,
ir::VectorType::get(builder.ir_context(), inputs_type));
}
void CombineOp::Verify() {
void CombineOp::Verify() const {
// outputs.size() == 1
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
......@@ -162,7 +162,7 @@ void CombineOp::Verify() {
}
const char *SliceOp::attributes_name[attributes_num] = {"index"};
void SliceOp::Verify() {
void SliceOp::Verify() const {
// inputs.size() == 1
auto input_size = num_operands();
IR_ENFORCE(
......@@ -217,13 +217,13 @@ void ConstantOp::Build(Builder &builder,
argument.output_types.push_back(output_type);
}
void ConstantOp::Verify() {
void ConstantOp::Verify() const {
IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0.");
IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(attributes().count("value") > 0, "must has value attribute");
}
Attribute ConstantOp::value() { return attributes().at("value"); }
Attribute ConstantOp::value() const { return attributes().at("value"); }
} // namespace ir
......
......@@ -30,7 +30,7 @@ class IR_API ModuleOp : public ir::Op<ModuleOp> {
static const char *name() { return "builtin.module"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
void Verify();
void Verify() const;
Program *program();
Block *block();
......@@ -55,7 +55,7 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
OperationArgument &argument, // NOLINT
const std::string &name,
Type type);
void Verify();
void Verify() const;
};
///
......@@ -72,7 +72,7 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name);
void Verify();
void Verify() const;
};
///
......@@ -92,7 +92,7 @@ class IR_API CombineOp : public ir::Op<CombineOp> {
OperationArgument &argument, // NOLINT
const std::vector<ir::OpResult> &inputs);
void Verify();
void Verify() const;
};
///
......@@ -107,7 +107,7 @@ class IR_API SliceOp : public ir::Op<SliceOp> {
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
void Verify();
void Verify() const;
};
class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
......@@ -132,9 +132,9 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
Attribute value,
Type output_type);
void Verify();
void Verify() const;
Attribute value();
Attribute value() const;
};
} // namespace ir
......
......@@ -145,7 +145,7 @@ class IR_API Dialect {
IR_THROW("dialect has no registered attribute printing hook");
}
virtual void PrintOperation(Operation *op,
virtual void PrintOperation(const Operation *op,
IrPrinter &printer) const; // NOLINT
private:
......
......@@ -115,7 +115,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
}
}
void IrPrinter::PrintProgram(Program* program) {
void IrPrinter::PrintProgram(const Program* program) {
auto top_level_op = program->module_op();
for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
auto& region = top_level_op->region(i);
......@@ -123,7 +123,7 @@ void IrPrinter::PrintProgram(Program* program) {
}
}
void IrPrinter::PrintOperation(Operation* op) {
void IrPrinter::PrintOperation(const Operation* op) {
if (auto* dialect = op->dialect()) {
dialect->PrintOperation(op, *this);
return;
......@@ -132,7 +132,7 @@ void IrPrinter::PrintOperation(Operation* op) {
PrintGeneralOperation(op);
}
void IrPrinter::PrintGeneralOperation(Operation* op) {
void IrPrinter::PrintGeneralOperation(const Operation* op) {
// TODO(lyk): add API to get opresults directly
PrintOpResult(op);
os << " =";
......@@ -153,7 +153,7 @@ void IrPrinter::PrintGeneralOperation(Operation* op) {
PrintOpReturnType(op);
}
void IrPrinter::PrintFullOperation(Operation* op) {
void IrPrinter::PrintFullOperation(const Operation* op) {
PrintOperation(op);
if (op->num_regions() > 0) {
os << newline;
......@@ -171,7 +171,7 @@ void IrPrinter::PrintRegion(const Region& region) {
}
}
void IrPrinter::PrintBlock(Block* block) {
void IrPrinter::PrintBlock(const Block* block) {
os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) {
PrintOperation(*it);
......@@ -180,7 +180,7 @@ void IrPrinter::PrintBlock(Block* block) {
os << "}\n";
}
void IrPrinter::PrintValue(Value v) {
void IrPrinter::PrintValue(const Value& v) {
if (!v) {
os << "<<NULL VALUE>>";
return;
......@@ -198,7 +198,7 @@ void IrPrinter::PrintValue(Value v) {
os << new_name;
}
void IrPrinter::PrintOpResult(Operation* op) {
void IrPrinter::PrintOpResult(const Operation* op) {
os << " (";
auto num_op_result = op->num_results();
std::vector<OpResult> op_results;
......@@ -214,7 +214,7 @@ void IrPrinter::PrintOpResult(Operation* op) {
os << ")";
}
void IrPrinter::PrintAttributeMap(Operation* op) {
void IrPrinter::PrintAttributeMap(const Operation* op) {
os << " {";
PrintInterleave(
......@@ -230,7 +230,7 @@ void IrPrinter::PrintAttributeMap(Operation* op) {
os << "}";
}
void IrPrinter::PrintOpOperands(Operation* op) {
void IrPrinter::PrintOpOperands(const Operation* op) {
os << " (";
auto num_op_operands = op->num_operands();
std::vector<Value> op_operands;
......@@ -246,7 +246,7 @@ void IrPrinter::PrintOpOperands(Operation* op) {
os << ")";
}
void IrPrinter::PrintOperandsType(Operation* op) {
void IrPrinter::PrintOperandsType(const Operation* op) {
auto num_op_operands = op->num_operands();
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
......@@ -267,7 +267,7 @@ void IrPrinter::PrintOperandsType(Operation* op) {
os << ")";
}
void IrPrinter::PrintOpReturnType(Operation* op) {
void IrPrinter::PrintOpReturnType(const Operation* op) {
auto num_op_result = op->num_results();
std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result);
......@@ -286,16 +286,16 @@ void IrPrinter::PrintOpReturnType(Operation* op) {
[this]() { this->os << ", "; });
}
void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const {
void Dialect::PrintOperation(const Operation* op, IrPrinter& printer) const {
printer.PrintGeneralOperation(op);
}
void Program::Print(std::ostream& os) {
void Program::Print(std::ostream& os) const {
IrPrinter printer(os);
printer.PrintProgram(this);
}
void Operation::Print(std::ostream& os) {
void Operation::Print(std::ostream& os) const {
IrPrinter printer(os);
printer.PrintFullOperation(this);
}
......
......@@ -46,29 +46,29 @@ class IR_API IrPrinter : public BasicIrPrinter {
/// @brief print program
/// @param program
void PrintProgram(Program* program);
void PrintProgram(const Program* program);
/// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(Operation* op);
void PrintOperation(const Operation* op);
/// @brief print operation itself without its regions
void PrintGeneralOperation(Operation* op);
void PrintGeneralOperation(const Operation* op);
/// @brief print operation and its regions
void PrintFullOperation(Operation* op);
void PrintFullOperation(const Operation* op);
void PrintRegion(const Region& Region);
void PrintBlock(Block* block);
void PrintBlock(const Block* block);
void PrintValue(Value v);
void PrintValue(const Value& v);
void PrintOpResult(Operation* op);
void PrintOpResult(const Operation* op);
void PrintAttributeMap(Operation* op);
void PrintAttributeMap(const Operation* op);
void PrintOpOperands(Operation* op);
void PrintOpOperands(const Operation* op);
void PrintOperandsType(Operation* op);
void PrintOperandsType(const Operation* op);
void PrintOpReturnType(Operation* op);
void PrintOpReturnType(const Operation* op);
private:
size_t cur_var_number_{0};
......
......@@ -232,6 +232,11 @@ Region &Operation::region(unsigned index) {
return regions_[index];
}
const Region &Operation::region(unsigned index) const {
assert(index < num_regions_ && "invalid region index");
return regions_[index];
}
void Operation::SetParent(Block *parent, const Block::iterator &position) {
parent_ = parent;
position_ = position;
......
......@@ -59,8 +59,9 @@ class IR_API alignas(8) Operation final {
/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);
const Region &region(unsigned index) const;
void Print(std::ostream &os);
void Print(std::ostream &os) const;
const AttributeMap &attributes() const { return attributes_; }
......
......@@ -48,11 +48,12 @@ class IR_API Program {
~Program();
size_t parameters_num() const { return parameters_.size(); }
ModuleOp module_op() { return module_; }
ModuleOp module_op() const { return module_; }
void Print(std::ostream& os);
void Print(std::ostream& os) const;
Block* block() { return module_.block(); }
const Block* block() const { return module_op().block(); }
Parameter* GetParameter(const std::string& name) const;
void SetParameter(const std::string& name,
......
......@@ -155,7 +155,7 @@ class TestDialect : public ir::Dialect {
}
static const char *name() { return "test"; }
void PrintOperation(ir::Operation *op,
void PrintOperation(const ir::Operation *op,
ir::IrPrinter &printer) const override {
printer.PrintOpResult(op);
printer.os << " =";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册