// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "paddle/ir/core/block.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/ir_printer.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/program.h" #include "paddle/ir/core/utils.h" #include "paddle/ir/core/value.h" namespace ir { namespace { constexpr char newline[] = "\n"; } // namespace void BasicIrPrinter::PrintType(Type type) { if (!type) { os << "<>"; return; } if (type.isa()) { os << "bf16"; } else if (type.isa()) { os << "f16"; } else if (type.isa()) { os << "f32"; } else if (type.isa()) { os << "f64"; } else if (type.isa()) { os << "b"; } else if (type.isa()) { os << "i8"; } else if (type.isa()) { os << "u8"; } else if (type.isa()) { os << "i16"; } else if (type.isa()) { os << "i32"; } else if (type.isa()) { os << "i64"; } else if (type.isa()) { os << "c64"; } else if (type.isa()) { os << "c128"; } else if (type.isa()) { os << "vec["; auto inner_types = type.dyn_cast().data(); PrintInterleave( inner_types.begin(), inner_types.end(), [this](Type v) { this->PrintType(v); }, [this]() { this->os << ","; }); os << "]"; } else { auto& dialect = type.dialect(); dialect.PrintType(type, os); } } void BasicIrPrinter::PrintAttribute(const Attribute& attr) { if (!attr) { os << "<#AttrNull>"; return; } if (auto s = attr.dyn_cast()) { os << s.data(); } else if (auto b = attr.dyn_cast()) { os << b.data(); } else if (auto f = attr.dyn_cast()) { os << f.data(); } else if (auto d = attr.dyn_cast()) { os << d.data(); } else if (auto i = attr.dyn_cast()) { os << i.data(); } else if (auto i = attr.dyn_cast()) { os << i.data(); } else if (auto p = attr.dyn_cast()) { os << p.data(); } else if (auto arr = attr.dyn_cast()) { const auto& vec = arr.data(); os << "array["; PrintInterleave( vec.begin(), vec.end(), [this](Attribute v) { this->PrintAttribute(v); }, [this]() { this->os << ","; }); os << "]"; } else if (auto type = attr.dyn_cast()) { os << type.data(); } else { auto& dialect = attr.dialect(); dialect.PrintAttribute(attr, os); } } void IrPrinter::PrintProgram(Program* program) { auto top_level_op = program->module_op(); for (size_t i = 0; i < top_level_op->num_regions(); ++i) { auto& region = top_level_op->region(i); for (auto it = region.begin(); it != region.end(); ++it) { auto* block = *it; os << "{\n"; for (auto it = block->begin(); it != block->end(); ++it) { PrintOperation(*it); os << newline; } os << "}\n"; } } } void IrPrinter::PrintOperation(Operation* op) { if (auto* dialect = op->dialect()) { dialect->PrintOperation(op, *this); return; } PrintGeneralOperation(op); } void IrPrinter::PrintGeneralOperation(Operation* op) { // TODO(lyk): add API to get opresults directly PrintOpResult(op); os << " ="; os << " \"" << op->name() << "\""; // TODO(lyk): add API to get operands directly PrintOpOperands(op); PrintAttributeMap(op); os << " :"; // PrintOpSingature PrintOperandsType(op); os << " -> "; // TODO(lyk): add API to get opresults directly PrintOpReturnType(op); } void IrPrinter::PrintFullOperation(Operation* op) { PrintOperation(op); if (op->num_regions() > 0) { os << newline; } for (size_t i = 0; i < op->num_regions(); ++i) { auto& region = op->region(i); PrintRegion(region); } } void IrPrinter::PrintRegion(const Region& region) { for (auto it = region.begin(); it != region.end(); ++it) { auto* block = *it; PrintBlock(block); } } void IrPrinter::PrintBlock(Block* block) { os << "{\n"; for (auto it = block->begin(); it != block->end(); ++it) { PrintOperation(*it); os << newline; } os << "}\n"; } void IrPrinter::PrintValue(Value v) { if (!v) { os << "<>"; return; } const void* key = static_cast(v.impl()); auto ret = aliases_.find(key); if (ret != aliases_.end()) { os << ret->second; return; } std::string new_name = "%" + std::to_string(cur_var_number_); cur_var_number_++; aliases_[key] = new_name; os << new_name; } void IrPrinter::PrintOpResult(Operation* op) { os << " ("; auto num_op_result = op->num_results(); std::vector op_results; op_results.reserve(num_op_result); for (size_t idx = 0; idx < num_op_result; idx++) { op_results.push_back(op->result(idx)); } PrintInterleave( op_results.begin(), op_results.end(), [this](Value v) { this->PrintValue(v); }, [this]() { this->os << ", "; }); os << ")"; } void IrPrinter::PrintAttributeMap(Operation* op) { os << " {"; PrintInterleave( op->attributes().begin(), op->attributes().end(), [this](std::pair it) { this->os << it.first; this->os << ":"; this->PrintAttribute(it.second); }, [this]() { this->os << ","; }); os << "}"; } void IrPrinter::PrintOpOperands(Operation* op) { os << " ("; auto num_op_operands = op->num_operands(); std::vector op_operands; op_operands.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { op_operands.push_back(op->operand(idx)); } PrintInterleave( op_operands.begin(), op_operands.end(), [this](Value v) { this->PrintValue(v); }, [this]() { this->os << ", "; }); os << ")"; } void IrPrinter::PrintOperandsType(Operation* op) { auto num_op_operands = op->num_operands(); std::vector op_operand_types; op_operand_types.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { auto op_operand = op->op_operand(idx); if (op_operand) { op_operand_types.push_back(op_operand.type()); } else { op_operand_types.push_back(Type()); } } os << " ("; PrintInterleave( op_operand_types.begin(), op_operand_types.end(), [this](Type t) { this->PrintType(t); }, [this]() { this->os << ", "; }); os << ")"; } void IrPrinter::PrintOpReturnType(Operation* op) { auto num_op_result = op->num_results(); std::vector op_result_types; op_result_types.reserve(num_op_result); for (size_t idx = 0; idx < num_op_result; idx++) { auto op_result = op->result(idx); if (op_result) { op_result_types.push_back(op_result.type()); } else { op_result_types.push_back(Type(nullptr)); } } PrintInterleave( op_result_types.begin(), op_result_types.end(), [this](Type t) { this->PrintType(t); }, [this]() { this->os << ", "; }); } void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const { printer.PrintGeneralOperation(op); } void Program::Print(std::ostream& os) { IrPrinter printer(os); printer.PrintProgram(this); } void Operation::Print(std::ostream& os) { IrPrinter printer(os); printer.PrintFullOperation(this); } void Type::Print(std::ostream& os) const { BasicIrPrinter printer(os); printer.PrintType(*this); } void Attribute::Print(std::ostream& os) const { BasicIrPrinter printer(os); printer.PrintAttribute(*this); } std::ostream& operator<<(std::ostream& os, Type type) { type.Print(os); return os; } std::ostream& operator<<(std::ostream& os, Attribute attr) { attr.Print(os); return os; } } // namespace ir