未验证 提交 7d688871 编写于 作者: K kangguangli 提交者: GitHub

[IR] Support custom op printer (#54499)

* adapt_startup_program

* refactor program translator

* polish

* add custom op printer hook

* fix merge conflicts

* fix top level op printer

* adapt full int array op

* modify by reviews

* fix
上级 64ecdc03
......@@ -23,6 +23,7 @@
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
......@@ -198,20 +199,18 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
inline ir::Operation* InsertFullArrayOperationForAttributeInput(
ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) {
std::string constant_op_name(ir::ConstantOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name);
ir::Type null_type = paddle::dialect::DenseTensorType::get(
ctx,
ir::Type(nullptr),
phi::DDim{},
paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED,
phi::LoD{},
0); // TODO(lyk): to be done
ir::Operation* operation =
ir::Operation::Create({}, {{"value", attr}}, {null_type}, op_info);
program->block()->push_back(operation);
return operation;
IR_ENFORCE(attr.isa<paddle::dialect::IntArrayAttribute>(),
"Encounter non IntArray type when trying to insert IntArray "
"mutable attribute");
phi::IntArray int_array =
attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data();
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
paddle::dialect::FullIntArrayOp full_int_array_op =
builder.Build<paddle::dialect::FullIntArrayOp>(
int_array.GetData(), phi::DataType::INT64, phi::CPUPlace());
return full_int_array_op.operation();
}
inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
......
......@@ -14,6 +14,7 @@
#pragma once
#include <functional>
#include <ostream>
#include "paddle/ir/core/attribute.h"
......@@ -25,6 +26,10 @@
#include "paddle/ir/core/type_base.h"
namespace ir {
class Operation;
class IrPrinter;
class DialectInterface;
///
/// \brief Dialect can basically be understood as a namespace. In Dialect, we
......@@ -136,10 +141,13 @@ class Dialect {
IR_THROW("dialect has no registered type printing hook");
}
virtual void PrintAttribute(Attribute type, std::ostream &os) const {
virtual void PrintAttribute(Attribute attr, std::ostream &os) const {
IR_THROW("dialect has no registered attribute printing hook");
}
virtual void PrintOperation(Operation *op,
IrPrinter &printer) const; // NOLINT
private:
Dialect(const Dialect &) = delete;
......
......@@ -21,6 +21,7 @@
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h"
......@@ -32,255 +33,274 @@ namespace {
constexpr char newline[] = "\n";
} // namespace
class BasicIRPrinter {
public:
explicit BasicIRPrinter(std::ostream& os) : os(os) {}
void PrintType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";
return;
}
void BasicIrPrinter::PrintType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";
return;
}
if (type.isa<Float16Type>()) {
os << "f16";
} else if (type.isa<Float32Type>()) {
os << "f32";
} else if (type.isa<Float64Type>()) {
os << "f64";
} else if (type.isa<Int16Type>()) {
os << "i16";
} else if (type.isa<Int32Type>()) {
os << "i32";
} else if (type.isa<Int64Type>()) {
os << "i64";
} else if (type.isa<VectorType>()) {
os << "vec[";
auto inner_types = type.dyn_cast<VectorType>().data();
PrintInterleave(
inner_types.begin(),
inner_types.end(),
[this](Type v) { this->PrintType(v); },
[this]() { this->os << ","; });
os << "]";
} else {
auto& dialect = type.dialect();
dialect.PrintType(type, os);
}
if (type.isa<Float16Type>()) {
os << "f16";
} else if (type.isa<Float32Type>()) {
os << "f32";
} else if (type.isa<Float64Type>()) {
os << "f64";
} else if (type.isa<Int16Type>()) {
os << "i16";
} else if (type.isa<Int32Type>()) {
os << "i32";
} else if (type.isa<Int64Type>()) {
os << "i64";
} else if (type.isa<VectorType>()) {
os << "vec[";
auto inner_types = type.dyn_cast<VectorType>().data();
PrintInterleave(
inner_types.begin(),
inner_types.end(),
[this](Type v) { this->PrintType(v); },
[this]() { this->os << ","; });
os << "]";
} else {
auto& dialect = type.dialect();
dialect.PrintType(type, os);
}
}
void PrintAttribute(const Attribute& attr) {
if (!attr) {
os << "<#AttrNull>";
return;
}
void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
if (!attr) {
os << "<#AttrNull>";
return;
}
if (auto s = attr.dyn_cast<StrAttribute>()) {
os << s.data();
} else if (auto b = attr.dyn_cast<BoolAttribute>()) {
os << b.data();
} else if (auto f = attr.dyn_cast<FloatAttribute>()) {
os << f.data();
} else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
os << d.data();
} else if (auto i = attr.dyn_cast<Int32Attribute>()) {
os << i.data();
} else if (auto i = attr.dyn_cast<Int64Attribute>()) {
os << i.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data();
os << "array[";
PrintInterleave(
vec.begin(),
vec.end(),
[this](Attribute v) { this->PrintAttribute(v); },
[this]() { this->os << ","; });
os << "]";
} else {
auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os);
}
if (auto s = attr.dyn_cast<StrAttribute>()) {
os << s.data();
} else if (auto b = attr.dyn_cast<BoolAttribute>()) {
os << b.data();
} else if (auto f = attr.dyn_cast<FloatAttribute>()) {
os << f.data();
} else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
os << d.data();
} else if (auto i = attr.dyn_cast<Int32Attribute>()) {
os << i.data();
} else if (auto i = attr.dyn_cast<Int64Attribute>()) {
os << i.data();
} else if (auto p = attr.dyn_cast<PointerAttribute>()) {
os << p.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data();
os << "array[";
PrintInterleave(
vec.begin(),
vec.end(),
[this](Attribute v) { this->PrintAttribute(v); },
[this]() { this->os << ","; });
os << "]";
} else {
auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os);
}
}
public:
std::ostream& os;
};
class IRPrinter : public BasicIRPrinter {
public:
explicit IRPrinter(std::ostream& os) : BasicIRPrinter(os) {}
/// @brief print program
/// @param program
/// @example
void PrintProgram(Program* program) { PrintOperation(program->module_op()); }
/// @brief print operation
/// @param op
/// @example
void PrintOperation(Operation* op) {
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) {
auto* op = *it;
// TODO(lyk): add API to get opresults directly
PrintOpResult(op);
os << " =";
os << " \"" << op->name() << "\"";
// TODO(lyk): add API to get operands directly
PrintOpOperands(op);
PrintAttributeMap(op);
os << " :";
// PrintOpSingature
PrintOperandsType(op);
os << " -> ";
// TODO(lyk): add API to get opresults directly
PrintOpReturnType(op);
os << newline;
}
os << "}\n";
void IrPrinter::PrintProgram(Program* program) {
auto top_level_op = program->module_op();
for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
auto& region = top_level_op->GetRegion(i);
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) {
PrintOperation(*it);
os << newline;
}
os << "}\n";
}
}
}
private:
void PrintValue(Value v) {
if (!v) {
os << "<<NULL VALUE>>";
return;
}
const void* key = static_cast<const void*>(v.impl());
auto ret = aliases_.find(key);
if (ret != aliases_.end()) {
os << ret->second;
return;
}
void IrPrinter::PrintOperation(Operation* op) {
if (auto* dialect = op->dialect()) {
dialect->PrintOperation(op, *this);
return;
}
PrintGeneralOperation(op);
}
void IrPrinter::PrintGeneralOperation(Operation* op) {
// TODO(lyk): add API to get opresults directly
PrintOpResult(op);
os << " =";
os << " \"" << op->name() << "\"";
// TODO(lyk): add API to get operands directly
PrintOpOperands(op);
PrintAttributeMap(op);
os << " :";
// PrintOpSingature
PrintOperandsType(op);
os << " -> ";
// TODO(lyk): add API to get opresults directly
PrintOpReturnType(op);
}
void IrPrinter::PrintFullOperation(Operation* op) {
PrintOperation(op);
if (op->num_regions() > 0) {
os << newline;
}
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
PrintRegion(region);
}
}
std::string new_name = "%" + std::to_string(cur_var_number_);
cur_var_number_++;
aliases_[key] = new_name;
os << new_name;
void IrPrinter::PrintRegion(const Region& region) {
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
PrintBlock(block);
}
}
void PrintOpResult(Operation* op) {
os << " (";
auto num_op_result = op->num_results();
std::vector<OpResult> op_results;
op_results.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx));
}
PrintInterleave(
op_results.begin(),
op_results.end(),
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
void IrPrinter::PrintBlock(Block* block) {
os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) {
PrintOperation(*it);
os << newline;
}
os << "}\n";
}
void PrintAttributeMap(Operation* op) {
os << " {";
void IrPrinter::PrintValue(Value v) {
if (!v) {
os << "<<NULL VALUE>>";
return;
}
const void* key = static_cast<const void*>(v.impl());
auto ret = aliases_.find(key);
if (ret != aliases_.end()) {
os << ret->second;
return;
}
PrintInterleave(
op->attributes().begin(),
op->attributes().end(),
[this](std::pair<std::string, Attribute> it) {
this->os << it.first;
this->os << ":";
this->PrintAttribute(it.second);
},
[this]() { this->os << ","; });
std::string new_name = "%" + std::to_string(cur_var_number_);
cur_var_number_++;
aliases_[key] = new_name;
os << new_name;
}
os << "}";
void IrPrinter::PrintOpResult(Operation* op) {
os << " (";
auto num_op_result = op->num_results();
std::vector<OpResult> op_results;
op_results.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx));
}
PrintInterleave(
op_results.begin(),
op_results.end(),
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOpOperands(Operation* op) {
os << " (";
auto num_op_operands = op->num_operands();
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->GetOperandByIndex(idx).source());
}
PrintInterleave(
op_operands.begin(),
op_operands.end(),
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
void IrPrinter::PrintAttributeMap(Operation* op) {
os << " {";
PrintInterleave(
op->attributes().begin(),
op->attributes().end(),
[this](std::pair<std::string, Attribute> it) {
this->os << it.first;
this->os << ":";
this->PrintAttribute(it.second);
},
[this]() { this->os << ","; });
os << "}";
}
void IrPrinter::PrintOpOperands(Operation* op) {
os << " (";
auto num_op_operands = op->num_operands();
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->GetOperandByIndex(idx).source());
}
PrintInterleave(
op_operands.begin(),
op_operands.end(),
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOperandsType(Operation* op) {
auto num_op_operands = op->num_operands();
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->GetOperandByIndex(idx);
if (op_operand) {
op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
} else {
op_operand_types.push_back(Type(nullptr));
}
void IrPrinter::PrintOperandsType(Operation* op) {
auto num_op_operands = op->num_operands();
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->GetOperandByIndex(idx);
if (op_operand) {
op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
} else {
op_operand_types.push_back(Type(nullptr));
}
os << " (";
PrintInterleave(
op_operand_types.begin(),
op_operand_types.end(),
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
os << ")";
}
os << " (";
PrintInterleave(
op_operand_types.begin(),
op_operand_types.end(),
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOpReturnType(Operation* op) {
auto num_op_result = op->num_results();
std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
auto op_result = op->GetResultByIndex(idx);
if (op_result) {
op_result_types.push_back(op_result.type());
} else {
op_result_types.push_back(Type(nullptr));
}
void IrPrinter::PrintOpReturnType(Operation* op) {
auto num_op_result = op->num_results();
std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
auto op_result = op->GetResultByIndex(idx);
if (op_result) {
op_result_types.push_back(op_result.type());
} else {
op_result_types.push_back(Type(nullptr));
}
PrintInterleave(
op_result_types.begin(),
op_result_types.end(),
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
}
PrintInterleave(
op_result_types.begin(),
op_result_types.end(),
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
}
private:
size_t cur_var_number_{0};
std::unordered_map<const void*, std::string> aliases_;
};
void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const {
printer.PrintGeneralOperation(op);
}
void Program::Print(std::ostream& os) {
IRPrinter printer(os);
IrPrinter printer(os);
printer.PrintProgram(this);
}
void Operation::Print(std::ostream& os) {
IRPrinter printer(os);
printer.PrintOperation(this);
IrPrinter printer(os);
printer.PrintFullOperation(this);
}
void Type::Print(std::ostream& os) const {
BasicIRPrinter printer(os);
BasicIrPrinter printer(os);
printer.PrintType(*this);
}
void Attribute::Print(std::ostream& os) const {
BasicIRPrinter printer(os);
BasicIrPrinter printer(os);
printer.PrintAttribute(*this);
}
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ostream>
#include <string>
#include <unordered_map>
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/value.h"
namespace ir {
class BasicIrPrinter {
public:
explicit BasicIrPrinter(std::ostream& os) : os(os) {}
void PrintType(Type type);
void PrintAttribute(const Attribute& attr);
public:
std::ostream& os;
};
class IrPrinter : public BasicIrPrinter {
public:
explicit IrPrinter(std::ostream& os) : BasicIrPrinter(os) {}
/// @brief print program
/// @param program
void PrintProgram(Program* program);
/// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(Operation* op);
/// @brief print operation itself without its regions
void PrintGeneralOperation(Operation* op);
/// @brief print operation and its regions
void PrintFullOperation(Operation* op);
void PrintRegion(const Region& Region);
void PrintBlock(Block* block);
void PrintValue(Value v);
void PrintOpResult(Operation* op);
void PrintAttributeMap(Operation* op);
void PrintOpOperands(Operation* op);
void PrintOperandsType(Operation* op);
void PrintOpReturnType(Operation* op);
private:
size_t cur_var_number_{0};
std::unordered_map<const void*, std::string> aliases_;
};
} // namespace ir
......@@ -29,6 +29,7 @@ bool OpInfo::HasInterface(TypeId interface_id) const {
IrContext *OpInfo::ir_context() const {
return impl_ ? impl_->ir_context() : nullptr;
}
Dialect *OpInfo::dialect() const { return impl_ ? impl_->dialect() : nullptr; }
const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
......
......@@ -15,6 +15,7 @@
#pragma once
#include <functional>
#include <unordered_map>
#include "paddle/ir/core/type_id.h"
namespace ir {
......@@ -23,6 +24,7 @@ class IrContext;
class OpResult;
class Type;
class Attribute;
class Dialect;
class OpInfo {
public:
......@@ -41,6 +43,7 @@ class OpInfo {
bool operator!() const { return impl_ == nullptr; }
IrContext *ir_context() const;
Dialect *dialect() const;
const char *name() const;
......
......@@ -166,6 +166,8 @@ void Operation::Destroy() {
IrContext *Operation::ir_context() const { return info_.ir_context(); }
Dialect *Operation::dialect() const { return info_.dialect(); }
Operation::Operation(const AttributeMap &attributes,
ir::OpInfo op_info,
uint32_t num_results,
......
......@@ -47,7 +47,7 @@ class alignas(8) Operation final {
void Destroy();
IrContext *ir_context() const;
Dialect *dialect() const;
OpResult GetResultByIndex(uint32_t index) const;
OpOperand GetOperandByIndex(uint32_t index) const;
......
......@@ -35,6 +35,8 @@ class Region {
iterator begin() { return blocks_.begin(); }
iterator end() { return blocks_.end(); }
const_iterator begin() const { return blocks_.begin(); }
const_iterator end() const { return blocks_.end(); }
reverse_iterator rbegin() { return blocks_.rbegin(); }
reverse_iterator rend() { return blocks_.rend(); }
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <sstream>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h"
......@@ -22,6 +23,7 @@
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
......@@ -150,6 +152,15 @@ class TestDialect : public ir::Dialect {
}
static const char *name() { return "test"; }
void PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const override {
printer.PrintOpResult(op);
printer.os << " =";
printer.os << " \"" << op->name() << "\"";
printer.PrintOpOperands(op);
}
private:
void initialize() { RegisterOps<Operation1, Operation2>(); }
};
......@@ -222,6 +233,11 @@ TEST(op_test, region_test) {
ir::Region *region = argument.regions.back().get();
EXPECT_EQ(region->empty(), true);
// (3) Test custom operation printer
std::stringstream ss;
op1->Print(ss);
EXPECT_EQ(ss.str(), " (%0) = \"test.operation1\" ()");
region->push_back(new ir::Block());
region->push_front(new ir::Block());
region->insert(region->begin(), new ir::Block());
......@@ -236,7 +252,6 @@ TEST(op_test, module_op_death) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name());
// (3) Test uses for op.
std::vector<ir::OpResult> inputs{ir::OpResult()};
ir::AttributeMap attrs{{"program", ir::Int32Attribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
......
......@@ -16,6 +16,7 @@
#include <chrono>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
......@@ -61,7 +62,9 @@ TEST(PaddleDialectTest, MainProgram) {
EXPECT_EQ(op_size,
p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
program->Print(std::cout);
std::stringstream ss;
program->Print(ss);
EXPECT_GT(ss.str().size(), 0u);
}
TEST(PaddleDialectTest, StartupProgram) {
......@@ -79,5 +82,7 @@ TEST(PaddleDialectTest, StartupProgram) {
// + consant_op for guassian
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53);
program->Print(std::cout);
std::stringstream ss;
program->Print(ss);
EXPECT_GT(ss.str().size(), 0u);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册