未验证 提交 7d688871 编写于 作者: K kangguangli 提交者: GitHub

[IR] Support custom op printer (#54499)

* adapt_startup_program

* refactor program translator

* polish

* add custom op printer hook

* fix merge conflicts

* fix top level op printer

* adapt full int array op

* modify by reviews

* fix
上级 64ecdc03
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
...@@ -198,20 +199,18 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, ...@@ -198,20 +199,18 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
inline ir::Operation* InsertFullArrayOperationForAttributeInput( inline ir::Operation* InsertFullArrayOperationForAttributeInput(
ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) { ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) {
std::string constant_op_name(ir::ConstantOp::name()); IR_ENFORCE(attr.isa<paddle::dialect::IntArrayAttribute>(),
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); "Encounter non IntArray type when trying to insert IntArray "
"mutable attribute");
ir::Type null_type = paddle::dialect::DenseTensorType::get(
ctx, phi::IntArray int_array =
ir::Type(nullptr), attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data();
phi::DDim{},
paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED, ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
phi::LoD{}, paddle::dialect::FullIntArrayOp full_int_array_op =
0); // TODO(lyk): to be done builder.Build<paddle::dialect::FullIntArrayOp>(
ir::Operation* operation = int_array.GetData(), phi::DataType::INT64, phi::CPUPlace());
ir::Operation::Create({}, {{"value", attr}}, {null_type}, op_info); return full_int_array_op.operation();
program->block()->push_back(operation);
return operation;
} }
inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx, inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <functional>
#include <ostream> #include <ostream>
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
...@@ -25,6 +26,10 @@ ...@@ -25,6 +26,10 @@
#include "paddle/ir/core/type_base.h" #include "paddle/ir/core/type_base.h"
namespace ir { namespace ir {
class Operation;
class IrPrinter;
class DialectInterface; class DialectInterface;
/// ///
/// \brief Dialect can basically be understood as a namespace. In Dialect, we /// \brief Dialect can basically be understood as a namespace. In Dialect, we
...@@ -136,10 +141,13 @@ class Dialect { ...@@ -136,10 +141,13 @@ class Dialect {
IR_THROW("dialect has no registered type printing hook"); IR_THROW("dialect has no registered type printing hook");
} }
virtual void PrintAttribute(Attribute type, std::ostream &os) const { virtual void PrintAttribute(Attribute attr, std::ostream &os) const {
IR_THROW("dialect has no registered attribute printing hook"); IR_THROW("dialect has no registered attribute printing hook");
} }
virtual void PrintOperation(Operation *op,
IrPrinter &printer) const; // NOLINT
private: private:
Dialect(const Dialect &) = delete; Dialect(const Dialect &) = delete;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
...@@ -32,255 +33,274 @@ namespace { ...@@ -32,255 +33,274 @@ namespace {
constexpr char newline[] = "\n"; constexpr char newline[] = "\n";
} // namespace } // namespace
class BasicIRPrinter { void BasicIrPrinter::PrintType(Type type) {
public: if (!type) {
explicit BasicIRPrinter(std::ostream& os) : os(os) {} os << "<<NULL TYPE>>";
return;
void PrintType(Type type) { }
if (!type) {
os << "<<NULL TYPE>>";
return;
}
if (type.isa<Float16Type>()) { if (type.isa<Float16Type>()) {
os << "f16"; os << "f16";
} else if (type.isa<Float32Type>()) { } else if (type.isa<Float32Type>()) {
os << "f32"; os << "f32";
} else if (type.isa<Float64Type>()) { } else if (type.isa<Float64Type>()) {
os << "f64"; os << "f64";
} else if (type.isa<Int16Type>()) { } else if (type.isa<Int16Type>()) {
os << "i16"; os << "i16";
} else if (type.isa<Int32Type>()) { } else if (type.isa<Int32Type>()) {
os << "i32"; os << "i32";
} else if (type.isa<Int64Type>()) { } else if (type.isa<Int64Type>()) {
os << "i64"; os << "i64";
} else if (type.isa<VectorType>()) { } else if (type.isa<VectorType>()) {
os << "vec["; os << "vec[";
auto inner_types = type.dyn_cast<VectorType>().data(); auto inner_types = type.dyn_cast<VectorType>().data();
PrintInterleave( PrintInterleave(
inner_types.begin(), inner_types.begin(),
inner_types.end(), inner_types.end(),
[this](Type v) { this->PrintType(v); }, [this](Type v) { this->PrintType(v); },
[this]() { this->os << ","; }); [this]() { this->os << ","; });
os << "]"; os << "]";
} else { } else {
auto& dialect = type.dialect(); auto& dialect = type.dialect();
dialect.PrintType(type, os); dialect.PrintType(type, os);
}
} }
}
void PrintAttribute(const Attribute& attr) { void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
if (!attr) { if (!attr) {
os << "<#AttrNull>"; os << "<#AttrNull>";
return; return;
} }
if (auto s = attr.dyn_cast<StrAttribute>()) { if (auto s = attr.dyn_cast<StrAttribute>()) {
os << s.data(); os << s.data();
} else if (auto b = attr.dyn_cast<BoolAttribute>()) { } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
os << b.data(); os << b.data();
} else if (auto f = attr.dyn_cast<FloatAttribute>()) { } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
os << f.data(); os << f.data();
} else if (auto d = attr.dyn_cast<DoubleAttribute>()) { } else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
os << d.data(); os << d.data();
} else if (auto i = attr.dyn_cast<Int32Attribute>()) { } else if (auto i = attr.dyn_cast<Int32Attribute>()) {
os << i.data(); os << i.data();
} else if (auto i = attr.dyn_cast<Int64Attribute>()) { } else if (auto i = attr.dyn_cast<Int64Attribute>()) {
os << i.data(); os << i.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) { } else if (auto p = attr.dyn_cast<PointerAttribute>()) {
const auto& vec = arr.data(); os << p.data();
os << "array["; } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
PrintInterleave( const auto& vec = arr.data();
vec.begin(), os << "array[";
vec.end(), PrintInterleave(
[this](Attribute v) { this->PrintAttribute(v); }, vec.begin(),
[this]() { this->os << ","; }); vec.end(),
os << "]"; [this](Attribute v) { this->PrintAttribute(v); },
} else { [this]() { this->os << ","; });
auto& dialect = attr.dialect(); os << "]";
dialect.PrintAttribute(attr, os); } else {
} auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os);
} }
}
public: void IrPrinter::PrintProgram(Program* program) {
std::ostream& os; auto top_level_op = program->module_op();
}; for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
auto& region = top_level_op->GetRegion(i);
class IRPrinter : public BasicIRPrinter { for (auto it = region.begin(); it != region.end(); ++it) {
public: auto* block = *it;
explicit IRPrinter(std::ostream& os) : BasicIRPrinter(os) {} os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) {
/// @brief print program PrintOperation(*it);
/// @param program os << newline;
/// @example
void PrintProgram(Program* program) { PrintOperation(program->module_op()); }
/// @brief print operation
/// @param op
/// @example
void PrintOperation(Operation* op) {
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) {
auto* op = *it;
// TODO(lyk): add API to get opresults directly
PrintOpResult(op);
os << " =";
os << " \"" << op->name() << "\"";
// TODO(lyk): add API to get operands directly
PrintOpOperands(op);
PrintAttributeMap(op);
os << " :";
// PrintOpSingature
PrintOperandsType(op);
os << " -> ";
// TODO(lyk): add API to get opresults directly
PrintOpReturnType(op);
os << newline;
}
os << "}\n";
} }
os << "}\n";
} }
} }
}
private: void IrPrinter::PrintOperation(Operation* op) {
void PrintValue(Value v) { if (auto* dialect = op->dialect()) {
if (!v) { dialect->PrintOperation(op, *this);
os << "<<NULL VALUE>>"; return;
return; }
}
const void* key = static_cast<const void*>(v.impl()); PrintGeneralOperation(op);
auto ret = aliases_.find(key); }
if (ret != aliases_.end()) {
os << ret->second; void IrPrinter::PrintGeneralOperation(Operation* op) {
return; // TODO(lyk): add API to get opresults directly
} PrintOpResult(op);
os << " =";
os << " \"" << op->name() << "\"";
// TODO(lyk): add API to get operands directly
PrintOpOperands(op);
PrintAttributeMap(op);
os << " :";
// PrintOpSingature
PrintOperandsType(op);
os << " -> ";
// TODO(lyk): add API to get opresults directly
PrintOpReturnType(op);
}
void IrPrinter::PrintFullOperation(Operation* op) {
PrintOperation(op);
if (op->num_regions() > 0) {
os << newline;
}
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
PrintRegion(region);
}
}
std::string new_name = "%" + std::to_string(cur_var_number_); void IrPrinter::PrintRegion(const Region& region) {
cur_var_number_++; for (auto it = region.begin(); it != region.end(); ++it) {
aliases_[key] = new_name; auto* block = *it;
os << new_name; PrintBlock(block);
} }
}
void PrintOpResult(Operation* op) { void IrPrinter::PrintBlock(Block* block) {
os << " ("; os << "{\n";
auto num_op_result = op->num_results(); for (auto it = block->begin(); it != block->end(); ++it) {
std::vector<OpResult> op_results; PrintOperation(*it);
op_results.reserve(num_op_result); os << newline;
for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx));
}
PrintInterleave(
op_results.begin(),
op_results.end(),
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
} }
os << "}\n";
}
void PrintAttributeMap(Operation* op) { void IrPrinter::PrintValue(Value v) {
os << " {"; if (!v) {
os << "<<NULL VALUE>>";
return;
}
const void* key = static_cast<const void*>(v.impl());
auto ret = aliases_.find(key);
if (ret != aliases_.end()) {
os << ret->second;
return;
}
PrintInterleave( std::string new_name = "%" + std::to_string(cur_var_number_);
op->attributes().begin(), cur_var_number_++;
op->attributes().end(), aliases_[key] = new_name;
[this](std::pair<std::string, Attribute> it) { os << new_name;
this->os << it.first; }
this->os << ":";
this->PrintAttribute(it.second);
},
[this]() { this->os << ","; });
os << "}"; void IrPrinter::PrintOpResult(Operation* op) {
os << " (";
auto num_op_result = op->num_results();
std::vector<OpResult> op_results;
op_results.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx));
} }
PrintInterleave(
op_results.begin(),
op_results.end(),
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOpOperands(Operation* op) { void IrPrinter::PrintAttributeMap(Operation* op) {
os << " ("; os << " {";
auto num_op_operands = op->num_operands();
std::vector<Value> op_operands; PrintInterleave(
op_operands.reserve(num_op_operands); op->attributes().begin(),
for (size_t idx = 0; idx < num_op_operands; idx++) { op->attributes().end(),
op_operands.push_back(op->GetOperandByIndex(idx).source()); [this](std::pair<std::string, Attribute> it) {
} this->os << it.first;
PrintInterleave( this->os << ":";
op_operands.begin(), this->PrintAttribute(it.second);
op_operands.end(), },
[this](Value v) { this->PrintValue(v); }, [this]() { this->os << ","; });
[this]() { this->os << ", "; });
os << ")"; os << "}";
}
void IrPrinter::PrintOpOperands(Operation* op) {
os << " (";
auto num_op_operands = op->num_operands();
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->GetOperandByIndex(idx).source());
} }
PrintInterleave(
op_operands.begin(),
op_operands.end(),
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOperandsType(Operation* op) { void IrPrinter::PrintOperandsType(Operation* op) {
auto num_op_operands = op->num_operands(); auto num_op_operands = op->num_operands();
std::vector<Type> op_operand_types; std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands); op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) { for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->GetOperandByIndex(idx); auto op_operand = op->GetOperandByIndex(idx);
if (op_operand) { if (op_operand) {
op_operand_types.push_back(op->GetOperandByIndex(idx).source().type()); op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
} else { } else {
op_operand_types.push_back(Type(nullptr)); op_operand_types.push_back(Type(nullptr));
}
} }
os << " (";
PrintInterleave(
op_operand_types.begin(),
op_operand_types.end(),
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
os << ")";
} }
os << " (";
PrintInterleave(
op_operand_types.begin(),
op_operand_types.end(),
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOpReturnType(Operation* op) { void IrPrinter::PrintOpReturnType(Operation* op) {
auto num_op_result = op->num_results(); auto num_op_result = op->num_results();
std::vector<Type> op_result_types; std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result); op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) { for (size_t idx = 0; idx < num_op_result; idx++) {
auto op_result = op->GetResultByIndex(idx); auto op_result = op->GetResultByIndex(idx);
if (op_result) { if (op_result) {
op_result_types.push_back(op_result.type()); op_result_types.push_back(op_result.type());
} else { } else {
op_result_types.push_back(Type(nullptr)); op_result_types.push_back(Type(nullptr));
}
} }
PrintInterleave(
op_result_types.begin(),
op_result_types.end(),
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
} }
PrintInterleave(
op_result_types.begin(),
op_result_types.end(),
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
}
private: void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const {
size_t cur_var_number_{0}; printer.PrintGeneralOperation(op);
std::unordered_map<const void*, std::string> aliases_; }
};
void Program::Print(std::ostream& os) { void Program::Print(std::ostream& os) {
IRPrinter printer(os); IrPrinter printer(os);
printer.PrintProgram(this); printer.PrintProgram(this);
} }
void Operation::Print(std::ostream& os) { void Operation::Print(std::ostream& os) {
IRPrinter printer(os); IrPrinter printer(os);
printer.PrintOperation(this); printer.PrintFullOperation(this);
} }
void Type::Print(std::ostream& os) const { void Type::Print(std::ostream& os) const {
BasicIRPrinter printer(os); BasicIrPrinter printer(os);
printer.PrintType(*this); printer.PrintType(*this);
} }
void Attribute::Print(std::ostream& os) const { void Attribute::Print(std::ostream& os) const {
BasicIRPrinter printer(os); BasicIrPrinter printer(os);
printer.PrintAttribute(*this); printer.PrintAttribute(*this);
} }
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ostream>
#include <string>
#include <unordered_map>
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/value.h"
namespace ir {
class BasicIrPrinter {
public:
explicit BasicIrPrinter(std::ostream& os) : os(os) {}
void PrintType(Type type);
void PrintAttribute(const Attribute& attr);
public:
std::ostream& os;
};
class IrPrinter : public BasicIrPrinter {
public:
explicit IrPrinter(std::ostream& os) : BasicIrPrinter(os) {}
/// @brief print program
/// @param program
void PrintProgram(Program* program);
/// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(Operation* op);
/// @brief print operation itself without its regions
void PrintGeneralOperation(Operation* op);
/// @brief print operation and its regions
void PrintFullOperation(Operation* op);
void PrintRegion(const Region& Region);
void PrintBlock(Block* block);
void PrintValue(Value v);
void PrintOpResult(Operation* op);
void PrintAttributeMap(Operation* op);
void PrintOpOperands(Operation* op);
void PrintOperandsType(Operation* op);
void PrintOpReturnType(Operation* op);
private:
size_t cur_var_number_{0};
std::unordered_map<const void*, std::string> aliases_;
};
} // namespace ir
...@@ -29,6 +29,7 @@ bool OpInfo::HasInterface(TypeId interface_id) const { ...@@ -29,6 +29,7 @@ bool OpInfo::HasInterface(TypeId interface_id) const {
IrContext *OpInfo::ir_context() const { IrContext *OpInfo::ir_context() const {
return impl_ ? impl_->ir_context() : nullptr; return impl_ ? impl_->ir_context() : nullptr;
} }
Dialect *OpInfo::dialect() const { return impl_ ? impl_->dialect() : nullptr; }
const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <unordered_map> #include <unordered_map>
#include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_id.h"
namespace ir { namespace ir {
...@@ -23,6 +24,7 @@ class IrContext; ...@@ -23,6 +24,7 @@ class IrContext;
class OpResult; class OpResult;
class Type; class Type;
class Attribute; class Attribute;
class Dialect;
class OpInfo { class OpInfo {
public: public:
...@@ -41,6 +43,7 @@ class OpInfo { ...@@ -41,6 +43,7 @@ class OpInfo {
bool operator!() const { return impl_ == nullptr; } bool operator!() const { return impl_ == nullptr; }
IrContext *ir_context() const; IrContext *ir_context() const;
Dialect *dialect() const;
const char *name() const; const char *name() const;
......
...@@ -166,6 +166,8 @@ void Operation::Destroy() { ...@@ -166,6 +166,8 @@ void Operation::Destroy() {
IrContext *Operation::ir_context() const { return info_.ir_context(); } IrContext *Operation::ir_context() const { return info_.ir_context(); }
Dialect *Operation::dialect() const { return info_.dialect(); }
Operation::Operation(const AttributeMap &attributes, Operation::Operation(const AttributeMap &attributes,
ir::OpInfo op_info, ir::OpInfo op_info,
uint32_t num_results, uint32_t num_results,
......
...@@ -47,7 +47,7 @@ class alignas(8) Operation final { ...@@ -47,7 +47,7 @@ class alignas(8) Operation final {
void Destroy(); void Destroy();
IrContext *ir_context() const; IrContext *ir_context() const;
Dialect *dialect() const;
OpResult GetResultByIndex(uint32_t index) const; OpResult GetResultByIndex(uint32_t index) const;
OpOperand GetOperandByIndex(uint32_t index) const; OpOperand GetOperandByIndex(uint32_t index) const;
......
...@@ -35,6 +35,8 @@ class Region { ...@@ -35,6 +35,8 @@ class Region {
iterator begin() { return blocks_.begin(); } iterator begin() { return blocks_.begin(); }
iterator end() { return blocks_.end(); } iterator end() { return blocks_.end(); }
const_iterator begin() const { return blocks_.begin(); }
const_iterator end() const { return blocks_.end(); }
reverse_iterator rbegin() { return blocks_.rbegin(); } reverse_iterator rbegin() { return blocks_.rbegin(); }
reverse_iterator rend() { return blocks_.rend(); } reverse_iterator rend() { return blocks_.rend(); }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <sstream>
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h" #include "paddle/ir/core/builder.h"
...@@ -22,6 +23,7 @@ ...@@ -22,6 +23,7 @@
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h" #include "paddle/ir/core/region.h"
...@@ -150,6 +152,15 @@ class TestDialect : public ir::Dialect { ...@@ -150,6 +152,15 @@ class TestDialect : public ir::Dialect {
} }
static const char *name() { return "test"; } static const char *name() { return "test"; }
void PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const override {
printer.PrintOpResult(op);
printer.os << " =";
printer.os << " \"" << op->name() << "\"";
printer.PrintOpOperands(op);
}
private: private:
void initialize() { RegisterOps<Operation1, Operation2>(); } void initialize() { RegisterOps<Operation1, Operation2>(); }
}; };
...@@ -222,6 +233,11 @@ TEST(op_test, region_test) { ...@@ -222,6 +233,11 @@ TEST(op_test, region_test) {
ir::Region *region = argument.regions.back().get(); ir::Region *region = argument.regions.back().get();
EXPECT_EQ(region->empty(), true); EXPECT_EQ(region->empty(), true);
// (3) Test custom operation printer
std::stringstream ss;
op1->Print(ss);
EXPECT_EQ(ss.str(), " (%0) = \"test.operation1\" ()");
region->push_back(new ir::Block()); region->push_back(new ir::Block());
region->push_front(new ir::Block()); region->push_front(new ir::Block());
region->insert(region->begin(), new ir::Block()); region->insert(region->begin(), new ir::Block());
...@@ -236,7 +252,6 @@ TEST(op_test, module_op_death) { ...@@ -236,7 +252,6 @@ TEST(op_test, module_op_death) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name());
// (3) Test uses for op.
std::vector<ir::OpResult> inputs{ir::OpResult()}; std::vector<ir::OpResult> inputs{ir::OpResult()};
ir::AttributeMap attrs{{"program", ir::Int32Attribute::get(ctx, 1)}}; ir::AttributeMap attrs{{"program", ir::Int32Attribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <chrono> #include <chrono>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <string> #include <string>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
...@@ -61,7 +62,9 @@ TEST(PaddleDialectTest, MainProgram) { ...@@ -61,7 +62,9 @@ TEST(PaddleDialectTest, MainProgram) {
EXPECT_EQ(op_size, EXPECT_EQ(op_size,
p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8); p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
program->Print(std::cout); std::stringstream ss;
program->Print(ss);
EXPECT_GT(ss.str().size(), 0u);
} }
TEST(PaddleDialectTest, StartupProgram) { TEST(PaddleDialectTest, StartupProgram) {
...@@ -79,5 +82,7 @@ TEST(PaddleDialectTest, StartupProgram) { ...@@ -79,5 +82,7 @@ TEST(PaddleDialectTest, StartupProgram) {
// + consant_op for guassian // + consant_op for guassian
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53); EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53);
program->Print(std::cout); std::stringstream ss;
program->Print(ss);
EXPECT_GT(ss.str().size(), 0u);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册