未验证 提交 7d688871 编写于 作者: K kangguangli 提交者: GitHub

[IR] Support custom op printer (#54499)

* adapt_startup_program

* refactor program translator

* polish

* add custom op printer hook

* fix merge conflicts

* fix top level op printer

* adapt full int array op

* modify by reviews

* fix
上级 64ecdc03
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
...@@ -198,20 +199,18 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, ...@@ -198,20 +199,18 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
inline ir::Operation* InsertFullArrayOperationForAttributeInput( inline ir::Operation* InsertFullArrayOperationForAttributeInput(
ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) { ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) {
std::string constant_op_name(ir::ConstantOp::name()); IR_ENFORCE(attr.isa<paddle::dialect::IntArrayAttribute>(),
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); "Encounter non IntArray type when trying to insert IntArray "
"mutable attribute");
ir::Type null_type = paddle::dialect::DenseTensorType::get(
ctx, phi::IntArray int_array =
ir::Type(nullptr), attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data();
phi::DDim{},
paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED, ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
phi::LoD{}, paddle::dialect::FullIntArrayOp full_int_array_op =
0); // TODO(lyk): to be done builder.Build<paddle::dialect::FullIntArrayOp>(
ir::Operation* operation = int_array.GetData(), phi::DataType::INT64, phi::CPUPlace());
ir::Operation::Create({}, {{"value", attr}}, {null_type}, op_info); return full_int_array_op.operation();
program->block()->push_back(operation);
return operation;
} }
inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx, inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <functional>
#include <ostream> #include <ostream>
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
...@@ -25,6 +26,10 @@ ...@@ -25,6 +26,10 @@
#include "paddle/ir/core/type_base.h" #include "paddle/ir/core/type_base.h"
namespace ir { namespace ir {
class Operation;
class IrPrinter;
class DialectInterface; class DialectInterface;
/// ///
/// \brief Dialect can basically be understood as a namespace. In Dialect, we /// \brief Dialect can basically be understood as a namespace. In Dialect, we
...@@ -136,10 +141,13 @@ class Dialect { ...@@ -136,10 +141,13 @@ class Dialect {
IR_THROW("dialect has no registered type printing hook"); IR_THROW("dialect has no registered type printing hook");
} }
virtual void PrintAttribute(Attribute type, std::ostream &os) const { virtual void PrintAttribute(Attribute attr, std::ostream &os) const {
IR_THROW("dialect has no registered attribute printing hook"); IR_THROW("dialect has no registered attribute printing hook");
} }
virtual void PrintOperation(Operation *op,
IrPrinter &printer) const; // NOLINT
private: private:
Dialect(const Dialect &) = delete; Dialect(const Dialect &) = delete;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
...@@ -32,11 +33,7 @@ namespace { ...@@ -32,11 +33,7 @@ namespace {
constexpr char newline[] = "\n"; constexpr char newline[] = "\n";
} // namespace } // namespace
class BasicIRPrinter { void BasicIrPrinter::PrintType(Type type) {
public:
explicit BasicIRPrinter(std::ostream& os) : os(os) {}
void PrintType(Type type) {
if (!type) { if (!type) {
os << "<<NULL TYPE>>"; os << "<<NULL TYPE>>";
return; return;
...@@ -67,9 +64,9 @@ class BasicIRPrinter { ...@@ -67,9 +64,9 @@ class BasicIRPrinter {
auto& dialect = type.dialect(); auto& dialect = type.dialect();
dialect.PrintType(type, os); dialect.PrintType(type, os);
} }
} }
void PrintAttribute(const Attribute& attr) { void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
if (!attr) { if (!attr) {
os << "<#AttrNull>"; os << "<#AttrNull>";
return; return;
...@@ -87,6 +84,8 @@ class BasicIRPrinter { ...@@ -87,6 +84,8 @@ class BasicIRPrinter {
os << i.data(); os << i.data();
} else if (auto i = attr.dyn_cast<Int64Attribute>()) { } else if (auto i = attr.dyn_cast<Int64Attribute>()) {
os << i.data(); os << i.data();
} else if (auto p = attr.dyn_cast<PointerAttribute>()) {
os << p.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) { } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data(); const auto& vec = arr.data();
os << "array["; os << "array[";
...@@ -100,32 +99,34 @@ class BasicIRPrinter { ...@@ -100,32 +99,34 @@ class BasicIRPrinter {
auto& dialect = attr.dialect(); auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os); dialect.PrintAttribute(attr, os);
} }
} }
public:
std::ostream& os;
};
class IRPrinter : public BasicIRPrinter {
public:
explicit IRPrinter(std::ostream& os) : BasicIRPrinter(os) {}
/// @brief print program
/// @param program
/// @example
void PrintProgram(Program* program) { PrintOperation(program->module_op()); }
/// @brief print operation void IrPrinter::PrintProgram(Program* program) {
/// @param op auto top_level_op = program->module_op();
/// @example for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
void PrintOperation(Operation* op) { auto& region = top_level_op->GetRegion(i);
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
for (auto it = region.begin(); it != region.end(); ++it) { for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it; auto* block = *it;
os << "{\n"; os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) { for (auto it = block->begin(); it != block->end(); ++it) {
auto* op = *it; PrintOperation(*it);
os << newline;
}
os << "}\n";
}
}
}
void IrPrinter::PrintOperation(Operation* op) {
if (auto* dialect = op->dialect()) {
dialect->PrintOperation(op, *this);
return;
}
PrintGeneralOperation(op);
}
void IrPrinter::PrintGeneralOperation(Operation* op) {
// TODO(lyk): add API to get opresults directly // TODO(lyk): add API to get opresults directly
PrintOpResult(op); PrintOpResult(op);
os << " ="; os << " =";
...@@ -144,16 +145,36 @@ class IRPrinter : public BasicIRPrinter { ...@@ -144,16 +145,36 @@ class IRPrinter : public BasicIRPrinter {
// TODO(lyk): add API to get opresults directly // TODO(lyk): add API to get opresults directly
PrintOpReturnType(op); PrintOpReturnType(op);
}
void IrPrinter::PrintFullOperation(Operation* op) {
PrintOperation(op);
if (op->num_regions() > 0) {
os << newline; os << newline;
} }
os << "}\n"; for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
PrintRegion(region);
} }
}
void IrPrinter::PrintRegion(const Region& region) {
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
PrintBlock(block);
} }
}
void IrPrinter::PrintBlock(Block* block) {
os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) {
PrintOperation(*it);
os << newline;
} }
os << "}\n";
}
private: void IrPrinter::PrintValue(Value v) {
void PrintValue(Value v) {
if (!v) { if (!v) {
os << "<<NULL VALUE>>"; os << "<<NULL VALUE>>";
return; return;
...@@ -169,9 +190,9 @@ class IRPrinter : public BasicIRPrinter { ...@@ -169,9 +190,9 @@ class IRPrinter : public BasicIRPrinter {
cur_var_number_++; cur_var_number_++;
aliases_[key] = new_name; aliases_[key] = new_name;
os << new_name; os << new_name;
} }
void PrintOpResult(Operation* op) { void IrPrinter::PrintOpResult(Operation* op) {
os << " ("; os << " (";
auto num_op_result = op->num_results(); auto num_op_result = op->num_results();
std::vector<OpResult> op_results; std::vector<OpResult> op_results;
...@@ -185,9 +206,9 @@ class IRPrinter : public BasicIRPrinter { ...@@ -185,9 +206,9 @@ class IRPrinter : public BasicIRPrinter {
[this](Value v) { this->PrintValue(v); }, [this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
os << ")"; os << ")";
} }
void PrintAttributeMap(Operation* op) { void IrPrinter::PrintAttributeMap(Operation* op) {
os << " {"; os << " {";
PrintInterleave( PrintInterleave(
...@@ -201,9 +222,9 @@ class IRPrinter : public BasicIRPrinter { ...@@ -201,9 +222,9 @@ class IRPrinter : public BasicIRPrinter {
[this]() { this->os << ","; }); [this]() { this->os << ","; });
os << "}"; os << "}";
} }
void PrintOpOperands(Operation* op) { void IrPrinter::PrintOpOperands(Operation* op) {
os << " ("; os << " (";
auto num_op_operands = op->num_operands(); auto num_op_operands = op->num_operands();
std::vector<Value> op_operands; std::vector<Value> op_operands;
...@@ -217,9 +238,9 @@ class IRPrinter : public BasicIRPrinter { ...@@ -217,9 +238,9 @@ class IRPrinter : public BasicIRPrinter {
[this](Value v) { this->PrintValue(v); }, [this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
os << ")"; os << ")";
} }
void PrintOperandsType(Operation* op) { void IrPrinter::PrintOperandsType(Operation* op) {
auto num_op_operands = op->num_operands(); auto num_op_operands = op->num_operands();
std::vector<Type> op_operand_types; std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands); op_operand_types.reserve(num_op_operands);
...@@ -238,9 +259,9 @@ class IRPrinter : public BasicIRPrinter { ...@@ -238,9 +259,9 @@ class IRPrinter : public BasicIRPrinter {
[this](Type t) { this->PrintType(t); }, [this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
os << ")"; os << ")";
} }
void PrintOpReturnType(Operation* op) { void IrPrinter::PrintOpReturnType(Operation* op) {
auto num_op_result = op->num_results(); auto num_op_result = op->num_results();
std::vector<Type> op_result_types; std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result); op_result_types.reserve(num_op_result);
...@@ -257,30 +278,29 @@ class IRPrinter : public BasicIRPrinter { ...@@ -257,30 +278,29 @@ class IRPrinter : public BasicIRPrinter {
op_result_types.end(), op_result_types.end(),
[this](Type t) { this->PrintType(t); }, [this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
} }
private: void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const {
size_t cur_var_number_{0}; printer.PrintGeneralOperation(op);
std::unordered_map<const void*, std::string> aliases_; }
};
void Program::Print(std::ostream& os) { void Program::Print(std::ostream& os) {
IRPrinter printer(os); IrPrinter printer(os);
printer.PrintProgram(this); printer.PrintProgram(this);
} }
void Operation::Print(std::ostream& os) { void Operation::Print(std::ostream& os) {
IRPrinter printer(os); IrPrinter printer(os);
printer.PrintOperation(this); printer.PrintFullOperation(this);
} }
void Type::Print(std::ostream& os) const { void Type::Print(std::ostream& os) const {
BasicIRPrinter printer(os); BasicIrPrinter printer(os);
printer.PrintType(*this); printer.PrintType(*this);
} }
void Attribute::Print(std::ostream& os) const { void Attribute::Print(std::ostream& os) const {
BasicIRPrinter printer(os); BasicIrPrinter printer(os);
printer.PrintAttribute(*this); printer.PrintAttribute(*this);
} }
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ostream>
#include <string>
#include <unordered_map>
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/value.h"
namespace ir {
class BasicIrPrinter {
public:
explicit BasicIrPrinter(std::ostream& os) : os(os) {}
void PrintType(Type type);
void PrintAttribute(const Attribute& attr);
public:
std::ostream& os;
};
class IrPrinter : public BasicIrPrinter {
public:
explicit IrPrinter(std::ostream& os) : BasicIrPrinter(os) {}
/// @brief print program
/// @param program
void PrintProgram(Program* program);
/// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(Operation* op);
/// @brief print operation itself without its regions
void PrintGeneralOperation(Operation* op);
/// @brief print operation and its regions
void PrintFullOperation(Operation* op);
void PrintRegion(const Region& Region);
void PrintBlock(Block* block);
void PrintValue(Value v);
void PrintOpResult(Operation* op);
void PrintAttributeMap(Operation* op);
void PrintOpOperands(Operation* op);
void PrintOperandsType(Operation* op);
void PrintOpReturnType(Operation* op);
private:
size_t cur_var_number_{0};
std::unordered_map<const void*, std::string> aliases_;
};
} // namespace ir
...@@ -29,6 +29,7 @@ bool OpInfo::HasInterface(TypeId interface_id) const { ...@@ -29,6 +29,7 @@ bool OpInfo::HasInterface(TypeId interface_id) const {
IrContext *OpInfo::ir_context() const { IrContext *OpInfo::ir_context() const {
return impl_ ? impl_->ir_context() : nullptr; return impl_ ? impl_->ir_context() : nullptr;
} }
Dialect *OpInfo::dialect() const { return impl_ ? impl_->dialect() : nullptr; }
const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <unordered_map> #include <unordered_map>
#include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_id.h"
namespace ir { namespace ir {
...@@ -23,6 +24,7 @@ class IrContext; ...@@ -23,6 +24,7 @@ class IrContext;
class OpResult; class OpResult;
class Type; class Type;
class Attribute; class Attribute;
class Dialect;
class OpInfo { class OpInfo {
public: public:
...@@ -41,6 +43,7 @@ class OpInfo { ...@@ -41,6 +43,7 @@ class OpInfo {
bool operator!() const { return impl_ == nullptr; } bool operator!() const { return impl_ == nullptr; }
IrContext *ir_context() const; IrContext *ir_context() const;
Dialect *dialect() const;
const char *name() const; const char *name() const;
......
...@@ -166,6 +166,8 @@ void Operation::Destroy() { ...@@ -166,6 +166,8 @@ void Operation::Destroy() {
IrContext *Operation::ir_context() const { return info_.ir_context(); } IrContext *Operation::ir_context() const { return info_.ir_context(); }
Dialect *Operation::dialect() const { return info_.dialect(); }
Operation::Operation(const AttributeMap &attributes, Operation::Operation(const AttributeMap &attributes,
ir::OpInfo op_info, ir::OpInfo op_info,
uint32_t num_results, uint32_t num_results,
......
...@@ -47,7 +47,7 @@ class alignas(8) Operation final { ...@@ -47,7 +47,7 @@ class alignas(8) Operation final {
void Destroy(); void Destroy();
IrContext *ir_context() const; IrContext *ir_context() const;
Dialect *dialect() const;
OpResult GetResultByIndex(uint32_t index) const; OpResult GetResultByIndex(uint32_t index) const;
OpOperand GetOperandByIndex(uint32_t index) const; OpOperand GetOperandByIndex(uint32_t index) const;
......
...@@ -35,6 +35,8 @@ class Region { ...@@ -35,6 +35,8 @@ class Region {
iterator begin() { return blocks_.begin(); } iterator begin() { return blocks_.begin(); }
iterator end() { return blocks_.end(); } iterator end() { return blocks_.end(); }
const_iterator begin() const { return blocks_.begin(); }
const_iterator end() const { return blocks_.end(); }
reverse_iterator rbegin() { return blocks_.rbegin(); } reverse_iterator rbegin() { return blocks_.rbegin(); }
reverse_iterator rend() { return blocks_.rend(); } reverse_iterator rend() { return blocks_.rend(); }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <sstream>
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h" #include "paddle/ir/core/builder.h"
...@@ -22,6 +23,7 @@ ...@@ -22,6 +23,7 @@
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h" #include "paddle/ir/core/region.h"
...@@ -150,6 +152,15 @@ class TestDialect : public ir::Dialect { ...@@ -150,6 +152,15 @@ class TestDialect : public ir::Dialect {
} }
static const char *name() { return "test"; } static const char *name() { return "test"; }
void PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const override {
printer.PrintOpResult(op);
printer.os << " =";
printer.os << " \"" << op->name() << "\"";
printer.PrintOpOperands(op);
}
private: private:
void initialize() { RegisterOps<Operation1, Operation2>(); } void initialize() { RegisterOps<Operation1, Operation2>(); }
}; };
...@@ -222,6 +233,11 @@ TEST(op_test, region_test) { ...@@ -222,6 +233,11 @@ TEST(op_test, region_test) {
ir::Region *region = argument.regions.back().get(); ir::Region *region = argument.regions.back().get();
EXPECT_EQ(region->empty(), true); EXPECT_EQ(region->empty(), true);
// (3) Test custom operation printer
std::stringstream ss;
op1->Print(ss);
EXPECT_EQ(ss.str(), " (%0) = \"test.operation1\" ()");
region->push_back(new ir::Block()); region->push_back(new ir::Block());
region->push_front(new ir::Block()); region->push_front(new ir::Block());
region->insert(region->begin(), new ir::Block()); region->insert(region->begin(), new ir::Block());
...@@ -236,7 +252,6 @@ TEST(op_test, module_op_death) { ...@@ -236,7 +252,6 @@ TEST(op_test, module_op_death) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name());
// (3) Test uses for op.
std::vector<ir::OpResult> inputs{ir::OpResult()}; std::vector<ir::OpResult> inputs{ir::OpResult()};
ir::AttributeMap attrs{{"program", ir::Int32Attribute::get(ctx, 1)}}; ir::AttributeMap attrs{{"program", ir::Int32Attribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <chrono> #include <chrono>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <string> #include <string>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
...@@ -61,7 +62,9 @@ TEST(PaddleDialectTest, MainProgram) { ...@@ -61,7 +62,9 @@ TEST(PaddleDialectTest, MainProgram) {
EXPECT_EQ(op_size, EXPECT_EQ(op_size,
p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8); p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
program->Print(std::cout); std::stringstream ss;
program->Print(ss);
EXPECT_GT(ss.str().size(), 0u);
} }
TEST(PaddleDialectTest, StartupProgram) { TEST(PaddleDialectTest, StartupProgram) {
...@@ -79,5 +82,7 @@ TEST(PaddleDialectTest, StartupProgram) { ...@@ -79,5 +82,7 @@ TEST(PaddleDialectTest, StartupProgram) {
// + consant_op for guassian // + consant_op for guassian
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53); EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53);
program->Print(std::cout); std::stringstream ss;
program->Print(ss);
EXPECT_GT(ss.str().size(), 0u);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册