diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 9bd8c7452cde2e7a1ec8855c104f25ada9320c2c..b531000896bf6ba5da3cb9e1af8b5e993b9de5c1 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -23,6 +23,7 @@ #include #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" @@ -198,20 +199,18 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, inline ir::Operation* InsertFullArrayOperationForAttributeInput( ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) { - std::string constant_op_name(ir::ConstantOp::name()); - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); - - ir::Type null_type = paddle::dialect::DenseTensorType::get( - ctx, - ir::Type(nullptr), - phi::DDim{}, - paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED, - phi::LoD{}, - 0); // TODO(lyk): to be done - ir::Operation* operation = - ir::Operation::Create({}, {{"value", attr}}, {null_type}, op_info); - program->block()->push_back(operation); - return operation; + IR_ENFORCE(attr.isa(), + "Encounter non IntArray type when trying to insert IntArray " + "mutable attribute"); + + phi::IntArray int_array = + attr.dyn_cast().data(); + + ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block()); + paddle::dialect::FullIntArrayOp full_int_array_op = + builder.Build( + int_array.GetData(), phi::DataType::INT64, phi::CPUPlace()); + return full_int_array_op.operation(); } inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx, diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h index 1eabc8010d670dc71f93387b2d1f497d01c672cb..0edce0e5ab585bb3eac6d00668f5eac74decbb03 100644 --- a/paddle/ir/core/dialect.h +++ b/paddle/ir/core/dialect.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/ir/core/attribute.h" @@ -25,6 +26,10 @@ #include "paddle/ir/core/type_base.h" namespace ir { + +class Operation; +class IrPrinter; + class DialectInterface; /// /// \brief Dialect can basically be understood as a namespace. In Dialect, we @@ -136,10 +141,13 @@ class Dialect { IR_THROW("dialect has no registered type printing hook"); } - virtual void PrintAttribute(Attribute type, std::ostream &os) const { + virtual void PrintAttribute(Attribute attr, std::ostream &os) const { IR_THROW("dialect has no registered attribute printing hook"); } + virtual void PrintOperation(Operation *op, + IrPrinter &printer) const; // NOLINT + private: Dialect(const Dialect &) = delete; diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 900a6a147753e197cc52e56c9caf46ca3af46ac6..6899ac1467c82bcf8d222aab3c3f9eb582c1eb43 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -21,6 +21,7 @@ #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/ir_printer.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/program.h" #include "paddle/ir/core/utils.h" @@ -32,255 +33,274 @@ namespace { constexpr char newline[] = "\n"; } // namespace -class BasicIRPrinter { - public: - explicit BasicIRPrinter(std::ostream& os) : os(os) {} - - void PrintType(Type type) { - if (!type) { - os << "<>"; - return; - } +void BasicIrPrinter::PrintType(Type type) { + if (!type) { + os << "<>"; + return; + } - if (type.isa()) { - os << "f16"; - } else if (type.isa()) { - os << "f32"; - } else if (type.isa()) { - os << "f64"; - } else if (type.isa()) { - os << "i16"; - } else if (type.isa()) { - os << "i32"; - } else if (type.isa()) { - os << "i64"; - } else if (type.isa()) { - os << "vec["; - auto inner_types = type.dyn_cast().data(); - PrintInterleave( - inner_types.begin(), - inner_types.end(), - [this](Type v) { this->PrintType(v); }, - [this]() { this->os << ","; }); - os << "]"; - } else { - auto& dialect = type.dialect(); - dialect.PrintType(type, os); - } + if (type.isa()) { + os << "f16"; + } else if (type.isa()) { + os << "f32"; + } else if (type.isa()) { + os << "f64"; + } else if (type.isa()) { + os << "i16"; + } else if (type.isa()) { + os << "i32"; + } else if (type.isa()) { + os << "i64"; + } else if (type.isa()) { + os << "vec["; + auto inner_types = type.dyn_cast().data(); + PrintInterleave( + inner_types.begin(), + inner_types.end(), + [this](Type v) { this->PrintType(v); }, + [this]() { this->os << ","; }); + os << "]"; + } else { + auto& dialect = type.dialect(); + dialect.PrintType(type, os); } +} - void PrintAttribute(const Attribute& attr) { - if (!attr) { - os << "<#AttrNull>"; - return; - } +void BasicIrPrinter::PrintAttribute(const Attribute& attr) { + if (!attr) { + os << "<#AttrNull>"; + return; + } - if (auto s = attr.dyn_cast()) { - os << s.data(); - } else if (auto b = attr.dyn_cast()) { - os << b.data(); - } else if (auto f = attr.dyn_cast()) { - os << f.data(); - } else if (auto d = attr.dyn_cast()) { - os << d.data(); - } else if (auto i = attr.dyn_cast()) { - os << i.data(); - } else if (auto i = attr.dyn_cast()) { - os << i.data(); - } else if (auto arr = attr.dyn_cast()) { - const auto& vec = arr.data(); - os << "array["; - PrintInterleave( - vec.begin(), - vec.end(), - [this](Attribute v) { this->PrintAttribute(v); }, - [this]() { this->os << ","; }); - os << "]"; - } else { - auto& dialect = attr.dialect(); - dialect.PrintAttribute(attr, os); - } + if (auto s = attr.dyn_cast()) { + os << s.data(); + } else if (auto b = attr.dyn_cast()) { + os << b.data(); + } else if (auto f = attr.dyn_cast()) { + os << f.data(); + } else if (auto d = attr.dyn_cast()) { + os << d.data(); + } else if (auto i = attr.dyn_cast()) { + os << i.data(); + } else if (auto i = attr.dyn_cast()) { + os << i.data(); + } else if (auto p = attr.dyn_cast()) { + os << p.data(); + } else if (auto arr = attr.dyn_cast()) { + const auto& vec = arr.data(); + os << "array["; + PrintInterleave( + vec.begin(), + vec.end(), + [this](Attribute v) { this->PrintAttribute(v); }, + [this]() { this->os << ","; }); + os << "]"; + } else { + auto& dialect = attr.dialect(); + dialect.PrintAttribute(attr, os); } +} - public: - std::ostream& os; -}; - -class IRPrinter : public BasicIRPrinter { - public: - explicit IRPrinter(std::ostream& os) : BasicIRPrinter(os) {} - - /// @brief print program - /// @param program - /// @example - void PrintProgram(Program* program) { PrintOperation(program->module_op()); } - - /// @brief print operation - /// @param op - /// @example - void PrintOperation(Operation* op) { - for (size_t i = 0; i < op->num_regions(); ++i) { - auto& region = op->GetRegion(i); - for (auto it = region.begin(); it != region.end(); ++it) { - auto* block = *it; - os << "{\n"; - for (auto it = block->begin(); it != block->end(); ++it) { - auto* op = *it; - // TODO(lyk): add API to get opresults directly - PrintOpResult(op); - os << " ="; - - os << " \"" << op->name() << "\""; - - // TODO(lyk): add API to get operands directly - PrintOpOperands(op); - - PrintAttributeMap(op); - os << " :"; - - // PrintOpSingature - PrintOperandsType(op); - os << " -> "; - - // TODO(lyk): add API to get opresults directly - PrintOpReturnType(op); - - os << newline; - } - os << "}\n"; +void IrPrinter::PrintProgram(Program* program) { + auto top_level_op = program->module_op(); + for (size_t i = 0; i < top_level_op->num_regions(); ++i) { + auto& region = top_level_op->GetRegion(i); + for (auto it = region.begin(); it != region.end(); ++it) { + auto* block = *it; + os << "{\n"; + for (auto it = block->begin(); it != block->end(); ++it) { + PrintOperation(*it); + os << newline; } + os << "}\n"; } } +} - private: - void PrintValue(Value v) { - if (!v) { - os << "<>"; - return; - } - const void* key = static_cast(v.impl()); - auto ret = aliases_.find(key); - if (ret != aliases_.end()) { - os << ret->second; - return; - } +void IrPrinter::PrintOperation(Operation* op) { + if (auto* dialect = op->dialect()) { + dialect->PrintOperation(op, *this); + return; + } + + PrintGeneralOperation(op); +} + +void IrPrinter::PrintGeneralOperation(Operation* op) { + // TODO(lyk): add API to get opresults directly + PrintOpResult(op); + os << " ="; + + os << " \"" << op->name() << "\""; + + // TODO(lyk): add API to get operands directly + PrintOpOperands(op); + + PrintAttributeMap(op); + os << " :"; + + // PrintOpSingature + PrintOperandsType(op); + os << " -> "; + + // TODO(lyk): add API to get opresults directly + PrintOpReturnType(op); +} + +void IrPrinter::PrintFullOperation(Operation* op) { + PrintOperation(op); + if (op->num_regions() > 0) { + os << newline; + } + for (size_t i = 0; i < op->num_regions(); ++i) { + auto& region = op->GetRegion(i); + PrintRegion(region); + } +} - std::string new_name = "%" + std::to_string(cur_var_number_); - cur_var_number_++; - aliases_[key] = new_name; - os << new_name; +void IrPrinter::PrintRegion(const Region& region) { + for (auto it = region.begin(); it != region.end(); ++it) { + auto* block = *it; + PrintBlock(block); } +} - void PrintOpResult(Operation* op) { - os << " ("; - auto num_op_result = op->num_results(); - std::vector op_results; - op_results.reserve(num_op_result); - for (size_t idx = 0; idx < num_op_result; idx++) { - op_results.push_back(op->GetResultByIndex(idx)); - } - PrintInterleave( - op_results.begin(), - op_results.end(), - [this](Value v) { this->PrintValue(v); }, - [this]() { this->os << ", "; }); - os << ")"; +void IrPrinter::PrintBlock(Block* block) { + os << "{\n"; + for (auto it = block->begin(); it != block->end(); ++it) { + PrintOperation(*it); + os << newline; } + os << "}\n"; +} - void PrintAttributeMap(Operation* op) { - os << " {"; +void IrPrinter::PrintValue(Value v) { + if (!v) { + os << "<>"; + return; + } + const void* key = static_cast(v.impl()); + auto ret = aliases_.find(key); + if (ret != aliases_.end()) { + os << ret->second; + return; + } - PrintInterleave( - op->attributes().begin(), - op->attributes().end(), - [this](std::pair it) { - this->os << it.first; - this->os << ":"; - this->PrintAttribute(it.second); - }, - [this]() { this->os << ","; }); + std::string new_name = "%" + std::to_string(cur_var_number_); + cur_var_number_++; + aliases_[key] = new_name; + os << new_name; +} - os << "}"; +void IrPrinter::PrintOpResult(Operation* op) { + os << " ("; + auto num_op_result = op->num_results(); + std::vector op_results; + op_results.reserve(num_op_result); + for (size_t idx = 0; idx < num_op_result; idx++) { + op_results.push_back(op->GetResultByIndex(idx)); } + PrintInterleave( + op_results.begin(), + op_results.end(), + [this](Value v) { this->PrintValue(v); }, + [this]() { this->os << ", "; }); + os << ")"; +} - void PrintOpOperands(Operation* op) { - os << " ("; - auto num_op_operands = op->num_operands(); - std::vector op_operands; - op_operands.reserve(num_op_operands); - for (size_t idx = 0; idx < num_op_operands; idx++) { - op_operands.push_back(op->GetOperandByIndex(idx).source()); - } - PrintInterleave( - op_operands.begin(), - op_operands.end(), - [this](Value v) { this->PrintValue(v); }, - [this]() { this->os << ", "; }); - os << ")"; +void IrPrinter::PrintAttributeMap(Operation* op) { + os << " {"; + + PrintInterleave( + op->attributes().begin(), + op->attributes().end(), + [this](std::pair it) { + this->os << it.first; + this->os << ":"; + this->PrintAttribute(it.second); + }, + [this]() { this->os << ","; }); + + os << "}"; +} + +void IrPrinter::PrintOpOperands(Operation* op) { + os << " ("; + auto num_op_operands = op->num_operands(); + std::vector op_operands; + op_operands.reserve(num_op_operands); + for (size_t idx = 0; idx < num_op_operands; idx++) { + op_operands.push_back(op->GetOperandByIndex(idx).source()); } + PrintInterleave( + op_operands.begin(), + op_operands.end(), + [this](Value v) { this->PrintValue(v); }, + [this]() { this->os << ", "; }); + os << ")"; +} - void PrintOperandsType(Operation* op) { - auto num_op_operands = op->num_operands(); - std::vector op_operand_types; - op_operand_types.reserve(num_op_operands); - for (size_t idx = 0; idx < num_op_operands; idx++) { - auto op_operand = op->GetOperandByIndex(idx); - if (op_operand) { - op_operand_types.push_back(op->GetOperandByIndex(idx).source().type()); - } else { - op_operand_types.push_back(Type(nullptr)); - } +void IrPrinter::PrintOperandsType(Operation* op) { + auto num_op_operands = op->num_operands(); + std::vector op_operand_types; + op_operand_types.reserve(num_op_operands); + for (size_t idx = 0; idx < num_op_operands; idx++) { + auto op_operand = op->GetOperandByIndex(idx); + if (op_operand) { + op_operand_types.push_back(op->GetOperandByIndex(idx).source().type()); + } else { + op_operand_types.push_back(Type(nullptr)); } - os << " ("; - PrintInterleave( - op_operand_types.begin(), - op_operand_types.end(), - [this](Type t) { this->PrintType(t); }, - [this]() { this->os << ", "; }); - os << ")"; } + os << " ("; + PrintInterleave( + op_operand_types.begin(), + op_operand_types.end(), + [this](Type t) { this->PrintType(t); }, + [this]() { this->os << ", "; }); + os << ")"; +} - void PrintOpReturnType(Operation* op) { - auto num_op_result = op->num_results(); - std::vector op_result_types; - op_result_types.reserve(num_op_result); - for (size_t idx = 0; idx < num_op_result; idx++) { - auto op_result = op->GetResultByIndex(idx); - if (op_result) { - op_result_types.push_back(op_result.type()); - } else { - op_result_types.push_back(Type(nullptr)); - } +void IrPrinter::PrintOpReturnType(Operation* op) { + auto num_op_result = op->num_results(); + std::vector op_result_types; + op_result_types.reserve(num_op_result); + for (size_t idx = 0; idx < num_op_result; idx++) { + auto op_result = op->GetResultByIndex(idx); + if (op_result) { + op_result_types.push_back(op_result.type()); + } else { + op_result_types.push_back(Type(nullptr)); } - PrintInterleave( - op_result_types.begin(), - op_result_types.end(), - [this](Type t) { this->PrintType(t); }, - [this]() { this->os << ", "; }); } + PrintInterleave( + op_result_types.begin(), + op_result_types.end(), + [this](Type t) { this->PrintType(t); }, + [this]() { this->os << ", "; }); +} - private: - size_t cur_var_number_{0}; - std::unordered_map aliases_; -}; +void Dialect::PrintOperation(Operation* op, IrPrinter& printer) const { + printer.PrintGeneralOperation(op); +} void Program::Print(std::ostream& os) { - IRPrinter printer(os); + IrPrinter printer(os); printer.PrintProgram(this); } void Operation::Print(std::ostream& os) { - IRPrinter printer(os); - printer.PrintOperation(this); + IrPrinter printer(os); + printer.PrintFullOperation(this); } void Type::Print(std::ostream& os) const { - BasicIRPrinter printer(os); + BasicIrPrinter printer(os); printer.PrintType(*this); } void Attribute::Print(std::ostream& os) const { - BasicIRPrinter printer(os); + BasicIrPrinter printer(os); printer.PrintAttribute(*this); } diff --git a/paddle/ir/core/ir_printer.h b/paddle/ir/core/ir_printer.h new file mode 100644 index 0000000000000000000000000000000000000000..ed5565bcd812143baac51b86d400118286afa624 --- /dev/null +++ b/paddle/ir/core/ir_printer.h @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/block.h" +#include "paddle/ir/core/operation.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/core/region.h" +#include "paddle/ir/core/type.h" +#include "paddle/ir/core/value.h" + +namespace ir { + +class BasicIrPrinter { + public: + explicit BasicIrPrinter(std::ostream& os) : os(os) {} + + void PrintType(Type type); + + void PrintAttribute(const Attribute& attr); + + public: + std::ostream& os; +}; + +class IrPrinter : public BasicIrPrinter { + public: + explicit IrPrinter(std::ostream& os) : BasicIrPrinter(os) {} + + /// @brief print program + /// @param program + void PrintProgram(Program* program); + + /// @brief dispatch to custom printer function or PrintGeneralOperation + void PrintOperation(Operation* op); + /// @brief print operation itself without its regions + void PrintGeneralOperation(Operation* op); + /// @brief print operation and its regions + void PrintFullOperation(Operation* op); + + void PrintRegion(const Region& Region); + void PrintBlock(Block* block); + + void PrintValue(Value v); + + void PrintOpResult(Operation* op); + + void PrintAttributeMap(Operation* op); + + void PrintOpOperands(Operation* op); + + void PrintOperandsType(Operation* op); + + void PrintOpReturnType(Operation* op); + + private: + size_t cur_var_number_{0}; + std::unordered_map aliases_; +}; + +} // namespace ir diff --git a/paddle/ir/core/op_info.cc b/paddle/ir/core/op_info.cc index b52cdf113875d3b8ea56b97eaee6d08246a6cf8a..e2e1d877fa2b72e20b44c97a4bfabc39bace8d7a 100644 --- a/paddle/ir/core/op_info.cc +++ b/paddle/ir/core/op_info.cc @@ -29,6 +29,7 @@ bool OpInfo::HasInterface(TypeId interface_id) const { IrContext *OpInfo::ir_context() const { return impl_ ? impl_->ir_context() : nullptr; } +Dialect *OpInfo::dialect() const { return impl_ ? impl_->dialect() : nullptr; } const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } diff --git a/paddle/ir/core/op_info.h b/paddle/ir/core/op_info.h index 1d8cf19f5c90bc727e91d0b07dc0050c3e10e1d5..345f6c984d9ffa87956777a98ca9ec6de4c95178 100644 --- a/paddle/ir/core/op_info.h +++ b/paddle/ir/core/op_info.h @@ -15,6 +15,7 @@ #pragma once #include #include + #include "paddle/ir/core/type_id.h" namespace ir { @@ -23,6 +24,7 @@ class IrContext; class OpResult; class Type; class Attribute; +class Dialect; class OpInfo { public: @@ -41,6 +43,7 @@ class OpInfo { bool operator!() const { return impl_ == nullptr; } IrContext *ir_context() const; + Dialect *dialect() const; const char *name() const; diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 26dc06e29b5f52173d9cb93732ff947688f192df..db733137ad27143fc09dcd97ac9267458a017341 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -166,6 +166,8 @@ void Operation::Destroy() { IrContext *Operation::ir_context() const { return info_.ir_context(); } +Dialect *Operation::dialect() const { return info_.dialect(); } + Operation::Operation(const AttributeMap &attributes, ir::OpInfo op_info, uint32_t num_results, diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index 87149fa562e5eb89a252b4460101633b6f05ac5d..b4506ceb659f311b19d852b468de83f3d5282dbd 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -47,7 +47,7 @@ class alignas(8) Operation final { void Destroy(); IrContext *ir_context() const; - + Dialect *dialect() const; OpResult GetResultByIndex(uint32_t index) const; OpOperand GetOperandByIndex(uint32_t index) const; diff --git a/paddle/ir/core/region.h b/paddle/ir/core/region.h index fa150e4889a88ef69fb9ef4755f33f0c0068343a..5c08c675798223da3d26c19c5774ed7de6de6686 100644 --- a/paddle/ir/core/region.h +++ b/paddle/ir/core/region.h @@ -35,6 +35,8 @@ class Region { iterator begin() { return blocks_.begin(); } iterator end() { return blocks_.end(); } + const_iterator begin() const { return blocks_.begin(); } + const_iterator end() const { return blocks_.end(); } reverse_iterator rbegin() { return blocks_.rbegin(); } reverse_iterator rend() { return blocks_.rend(); } diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 6d9524634c71368484d053c0fde2eed948147c82..95f6b7c598e5cde7dfcd01e19a0278c91eb38c13 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include "paddle/ir/core/block.h" #include "paddle/ir/core/builder.h" @@ -22,6 +23,7 @@ #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/enforce.h" #include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/ir_printer.h" #include "paddle/ir/core/op_base.h" #include "paddle/ir/core/program.h" #include "paddle/ir/core/region.h" @@ -150,6 +152,15 @@ class TestDialect : public ir::Dialect { } static const char *name() { return "test"; } + void PrintOperation(ir::Operation *op, + ir::IrPrinter &printer) const override { + printer.PrintOpResult(op); + printer.os << " ="; + + printer.os << " \"" << op->name() << "\""; + printer.PrintOpOperands(op); + } + private: void initialize() { RegisterOps(); } }; @@ -222,6 +233,11 @@ TEST(op_test, region_test) { ir::Region *region = argument.regions.back().get(); EXPECT_EQ(region->empty(), true); + // (3) Test custom operation printer + std::stringstream ss; + op1->Print(ss); + EXPECT_EQ(ss.str(), " (%0) = \"test.operation1\" ()"); + region->push_back(new ir::Block()); region->push_front(new ir::Block()); region->insert(region->begin(), new ir::Block()); @@ -236,7 +252,6 @@ TEST(op_test, module_op_death) { ir::IrContext *ctx = ir::IrContext::Instance(); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name()); - // (3) Test uses for op. std::vector inputs{ir::OpResult()}; ir::AttributeMap attrs{{"program", ir::Int32Attribute::get(ctx, 1)}}; std::vector output_types = {ir::Float32Type::get(ctx)}; diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index bac606a5e12928d9a7a9bf6ddad0cbc52a78e167..c50c579ca51daad6d258bbebfbdf231a7bd32f62 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include "paddle/fluid/framework/framework.pb.h" @@ -61,7 +62,9 @@ TEST(PaddleDialectTest, MainProgram) { EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8); - program->Print(std::cout); + std::stringstream ss; + program->Print(ss); + EXPECT_GT(ss.str().size(), 0u); } TEST(PaddleDialectTest, StartupProgram) { @@ -79,5 +82,7 @@ TEST(PaddleDialectTest, StartupProgram) { // + consant_op for guassian EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53); - program->Print(std::cout); + std::stringstream ss; + program->Print(ss); + EXPECT_GT(ss.str().size(), 0u); }