未验证 提交 4aa415ef 编写于 作者: Y Yuanle Liu 提交者: GitHub

[IR&PASS] part 2-2: add IRPrinting beforce and after pass, refine IRPrinter...

[IR&PASS] part 2-2: add IRPrinting beforce and after pass, refine IRPrinter use for Program/Operation... printing (#54308)

* add IRPrinting before/after pass, update pass_manager ut to run ir::Program, IRPrinter support region/block

* rename printer.cc to ir_printer.cc

* fix comment

* remove newline
上级 37930a69
...@@ -46,13 +46,13 @@ void PrintInterleave(ForwardIterator begin, ...@@ -46,13 +46,13 @@ void PrintInterleave(ForwardIterator begin,
} // namespace } // namespace
class Printer { class BasicIRPrinter {
public: public:
explicit Printer(std::ostream& os) : os(os) {} explicit BasicIRPrinter(std::ostream& os) : os(os) {}
void PrintType(ir::Type type) { void PrintType(ir::Type type) {
if (!type) { if (!type) {
os << "<!TypeNull>"; os << "<<NULL TYPE>>";
return; return;
} }
...@@ -75,7 +75,7 @@ class Printer { ...@@ -75,7 +75,7 @@ class Printer {
inner_types.begin(), inner_types.begin(),
inner_types.end(), inner_types.end(),
[this](ir::Type v) { this->PrintType(v); }, [this](ir::Type v) { this->PrintType(v); },
[this]() { this->os << ","; }); [this]() { this->os << ", "; });
os << ">"; os << ">";
} else { } else {
auto& dialect = type.dialect(); auto& dialect = type.dialect();
...@@ -83,65 +83,79 @@ class Printer { ...@@ -83,65 +83,79 @@ class Printer {
} }
} }
public: void PrintAttribute(ir::Operation* op) { os << " { ATTRIBUTE }"; }
protected:
std::ostream& os; std::ostream& os;
}; };
void Type::print(std::ostream& os) const { class IRPrinter : public BasicIRPrinter {
Printer p(os);
p.PrintType(*this);
}
class ProgramPrinter : public Printer {
public: public:
explicit ProgramPrinter(std::ostream& os) : Printer(os), cur_var_number(0) {} explicit IRPrinter(std::ostream& os) : BasicIRPrinter(os) {}
/// @brief print program
/// @param program
/// @example
void PrintProgram(ir::Program* program) {
PrintOperation(program->module_op());
}
/// @brief print operation
/// @param op
/// @example
void PrintOperation(ir::Operation* op) {
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
os << "{\n";
for (auto it = block->begin(); it != block->end(); ++it) {
auto* op = *it;
// TODO(lyk): add API to get opresults directly
PrintOpResult(op);
os << " =";
os << " \"" << op->name() << "\"";
// TODO(lyk): add API to get operands directly
PrintOpOperands(op);
PrintAttribute(op);
os << " :";
// PrintOpSingature
PrintOperandsType(op);
os << " -> ";
// TODO(lyk): add API to get opresults directly
PrintOpReturnType(op);
void Print(ir::Program& program) {
auto iterator = program.block()->begin();
while (iterator != program.block()->end()) {
PrintOperation(*iterator);
os << newline; os << newline;
iterator++; }
os << "}\n";
}
} }
} }
private:
void PrintValue(ir::Value v) { void PrintValue(ir::Value v) {
if (!v) { if (!v) {
os << "<<NULL VALUE>>"; os << "<<NULL VALUE>>";
return; return;
} }
const void* key = static_cast<const void*>(v.impl()); const void* key = static_cast<const void*>(v.impl());
auto ret = aliases.find(key); auto ret = aliases_.find(key);
if (ret != aliases.end()) { if (ret != aliases_.end()) {
os << ret->second; os << ret->second;
return; return;
} }
std::string new_name = "%" + std::to_string(cur_var_number); std::string new_name = "%" + std::to_string(cur_var_number_);
cur_var_number++; cur_var_number_++;
aliases[key] = new_name; aliases_[key] = new_name;
os << new_name; os << new_name;
} }
/// @brief print operation
/// @param op
/// @example
void PrintOperation(ir::Operation* op) {
PrintOpResult(op); // TODO(lyk): add API to get opresults directly
os << " = ";
os << "\"" << op->op_name() << "\"";
PrintOpOperands(op); // TODO(lyk): add API to get operands directly
PrintAttribute(op);
os << " : ";
// PrintOpSingature
PrintOperandsType(op);
os << " -> ";
PrintOpReturnType(op); // TODO(lyk): add API to get opresults directly
}
void PrintOpResult(ir::Operation* op) { void PrintOpResult(ir::Operation* op) {
os << " ("; os << " (";
auto num_op_result = op->num_results(); auto num_op_result = op->num_results();
...@@ -154,12 +168,10 @@ class ProgramPrinter : public Printer { ...@@ -154,12 +168,10 @@ class ProgramPrinter : public Printer {
op_results.begin(), op_results.begin(),
op_results.end(), op_results.end(),
[this](ir::Value v) { this->PrintValue(v); }, [this](ir::Value v) { this->PrintValue(v); },
[this]() { this->os << ","; }); [this]() { this->os << ", "; });
os << ") "; os << ")";
} }
void PrintAttribute(ir::Operation* op) { os << " { ATTRIBUTE } "; }
void PrintOpOperands(ir::Operation* op) { void PrintOpOperands(ir::Operation* op) {
os << " ("; os << " (";
auto num_op_operands = op->num_operands(); auto num_op_operands = op->num_operands();
...@@ -172,8 +184,8 @@ class ProgramPrinter : public Printer { ...@@ -172,8 +184,8 @@ class ProgramPrinter : public Printer {
op_operands.begin(), op_operands.begin(),
op_operands.end(), op_operands.end(),
[this](ir::Value v) { this->PrintValue(v); }, [this](ir::Value v) { this->PrintValue(v); },
[this]() { this->os << ","; }); [this]() { this->os << ", "; });
os << ") "; os << ")";
} }
void PrintOperandsType(ir::Operation* op) { void PrintOperandsType(ir::Operation* op) {
...@@ -188,11 +200,13 @@ class ProgramPrinter : public Printer { ...@@ -188,11 +200,13 @@ class ProgramPrinter : public Printer {
op_operand_types.push_back(ir::Type(nullptr)); op_operand_types.push_back(ir::Type(nullptr));
} }
} }
os << " (";
PrintInterleave( PrintInterleave(
op_operand_types.begin(), op_operand_types.begin(),
op_operand_types.end(), op_operand_types.end(),
[this](ir::Type t) { this->PrintType(t); }, [this](ir::Type t) { this->PrintType(t); },
[this]() { this->os << ","; }); [this]() { this->os << ", "; });
os << ")";
} }
void PrintOpReturnType(ir::Operation* op) { void PrintOpReturnType(ir::Operation* op) {
...@@ -211,18 +225,27 @@ class ProgramPrinter : public Printer { ...@@ -211,18 +225,27 @@ class ProgramPrinter : public Printer {
op_result_types.begin(), op_result_types.begin(),
op_result_types.end(), op_result_types.end(),
[this](ir::Type t) { this->PrintType(t); }, [this](ir::Type t) { this->PrintType(t); },
[this]() { this->os << ","; }); [this]() { this->os << ", "; });
} }
private: private:
size_t cur_var_number; size_t cur_var_number_{0};
std::unordered_map<const void*, std::string> aliases; std::unordered_map<const void*, std::string> aliases_;
}; };
std::ostream& operator<<(std::ostream& os, Program& program) { void Program::print(std::ostream& os) {
ProgramPrinter printer(os); IRPrinter printer(os);
printer.Print(program); printer.PrintProgram(this);
return os; }
void Operation::Print(std::ostream& os) {
IRPrinter printer(os);
printer.PrintOperation(this);
}
void Type::print(std::ostream& os) const {
BasicIRPrinter printer(os);
printer.PrintType(*this);
} }
} // namespace ir } // namespace ir
...@@ -113,7 +113,7 @@ class OpInterfaceBase : public OpBase { ...@@ -113,7 +113,7 @@ class OpInterfaceBase : public OpBase {
static ConcreteInterface dyn_cast(Operation *op) { static ConcreteInterface dyn_cast(Operation *op) {
if (op && op->HasInterface<ConcreteInterface>()) { if (op && op->HasInterface<ConcreteInterface>()) {
return ConcreteInterface( return ConcreteInterface(
op, op->op_info().GetInterfaceImpl<ConcreteInterface>()); op, op->info().GetInterfaceImpl<ConcreteInterface>());
} }
return ConcreteInterface(nullptr, nullptr); return ConcreteInterface(nullptr, nullptr);
} }
...@@ -184,7 +184,7 @@ class Op : public OpBase { ...@@ -184,7 +184,7 @@ class Op : public OpBase {
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type; typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
static ConcreteOp dyn_cast(Operation *op) { static ConcreteOp dyn_cast(Operation *op) {
if (op && op->op_info().id() == TypeId::get<ConcreteOp>()) { if (op && op->info().id() == TypeId::get<ConcreteOp>()) {
return ConcreteOp(op); return ConcreteOp(op);
} }
return ConcreteOp(nullptr); return ConcreteOp(nullptr);
......
...@@ -12,9 +12,12 @@ ...@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/operation.h" #include <ostream>
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h" #include "paddle/ir/core/region.h"
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
...@@ -160,7 +163,7 @@ void Operation::destroy() { ...@@ -160,7 +163,7 @@ void Operation::destroy() {
aligned_free(reinterpret_cast<void *>(aligned_ptr)); aligned_free(reinterpret_cast<void *>(aligned_ptr));
} }
IrContext *Operation::ir_context() const { return op_info_.ir_context(); } IrContext *Operation::ir_context() const { return info_.ir_context(); }
Operation::Operation(const AttributeMap &attributes, Operation::Operation(const AttributeMap &attributes,
ir::OpInfo op_info, ir::OpInfo op_info,
...@@ -168,7 +171,7 @@ Operation::Operation(const AttributeMap &attributes, ...@@ -168,7 +171,7 @@ Operation::Operation(const AttributeMap &attributes,
uint32_t num_operands, uint32_t num_operands,
uint32_t num_regions) uint32_t num_regions)
: attributes_(attributes), : attributes_(attributes),
op_info_(op_info), info_(op_info),
num_results_(num_results), num_results_(num_results),
num_operands_(num_operands), num_operands_(num_operands),
num_regions_(num_regions) {} num_regions_(num_regions) {}
...@@ -203,28 +206,11 @@ ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const { ...@@ -203,28 +206,11 @@ ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const {
return ir::OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr)); return ir::OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr));
} }
std::string Operation::print() { std::string Operation::name() const {
std::stringstream result; auto p_name = info_.name();
result << "{ " << num_results_ << " outputs, " << num_operands_ return p_name ? p_name : "";
<< " inputs } : ";
result << "[ ";
for (size_t idx = num_results_; idx > 0; idx--) {
result << GetResultByIndex(idx - 1).impl_ << ", ";
}
result << "] = ";
result << this << "( ";
for (size_t idx = 0; idx < num_operands_; idx++) {
result << reinterpret_cast<void *>(reinterpret_cast<char *>(this) +
sizeof(Operation) +
idx * sizeof(detail::OpOperandImpl))
<< ", ";
}
result << ")";
return result.str();
} }
std::string Operation::op_name() const { return op_info_.name(); }
Region *Operation::GetParentRegion() const { Region *Operation::GetParentRegion() const {
return parent_ ? parent_->GetParentRegion() : nullptr; return parent_ ? parent_->GetParentRegion() : nullptr;
} }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include <iostream> #include <ostream>
#include "paddle/ir/core/op_info.h" #include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/type.h" #include "paddle/ir/core/type.h"
...@@ -52,7 +52,7 @@ class alignas(8) Operation final { ...@@ -52,7 +52,7 @@ class alignas(8) Operation final {
OpOperand GetOperandByIndex(uint32_t index) const; OpOperand GetOperandByIndex(uint32_t index) const;
std::string print(); void Print(std::ostream &os);
const AttributeMap &attributes() const { return attributes_; } const AttributeMap &attributes() const { return attributes_; }
...@@ -60,7 +60,7 @@ class alignas(8) Operation final { ...@@ -60,7 +60,7 @@ class alignas(8) Operation final {
attributes_[key] = value; attributes_[key] = value;
} }
ir::OpInfo op_info() const { return op_info_; } ir::OpInfo info() const { return info_; }
uint32_t num_results() const { return num_results_; } uint32_t num_results() const { return num_results_; }
...@@ -68,7 +68,7 @@ class alignas(8) Operation final { ...@@ -68,7 +68,7 @@ class alignas(8) Operation final {
uint32_t num_regions() const { return num_regions_; } uint32_t num_regions() const { return num_regions_; }
std::string op_name() const; std::string name() const;
template <typename T> template <typename T>
T dyn_cast() { T dyn_cast() {
...@@ -77,12 +77,12 @@ class alignas(8) Operation final { ...@@ -77,12 +77,12 @@ class alignas(8) Operation final {
template <typename Trait> template <typename Trait>
bool HasTrait() const { bool HasTrait() const {
return op_info_.HasTrait<Trait>(); return info_.HasTrait<Trait>();
} }
template <typename Interface> template <typename Interface>
bool HasInterface() const { bool HasInterface() const {
return op_info_.HasInterface<Interface>(); return info_.HasInterface<Interface>();
} }
Block *GetParentBlock() const { return parent_; } Block *GetParentBlock() const { return parent_; }
...@@ -122,7 +122,7 @@ class alignas(8) Operation final { ...@@ -122,7 +122,7 @@ class alignas(8) Operation final {
AttributeMap attributes_; AttributeMap attributes_;
OpInfo op_info_; OpInfo info_;
const uint32_t num_results_ = 0; const uint32_t num_results_ = 0;
const uint32_t num_operands_ = 0; const uint32_t num_operands_ = 0;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <list> #include <list>
#include <ostream>
#include <unordered_map> #include <unordered_map>
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
...@@ -49,6 +50,8 @@ class Program { ...@@ -49,6 +50,8 @@ class Program {
ModuleOp module_op() { return module_; } ModuleOp module_op() { return module_; }
void print(std::ostream& os);
Block* block() { return module_.block(); } Block* block() { return module_.block(); }
Parameter* GetParameter(std::string name) const; Parameter* GetParameter(std::string name) const;
...@@ -66,6 +69,4 @@ class Program { ...@@ -66,6 +69,4 @@ class Program {
ParameterMap parameters_; ParameterMap parameters_;
}; };
std::ostream& operator<<(std::ostream& os, Program& program);
} // namespace ir } // namespace ir
...@@ -41,19 +41,17 @@ class PreservedAnalyses { ...@@ -41,19 +41,17 @@ class PreservedAnalyses {
public: public:
/// Mark all analyses as preserved. /// Mark all analyses as preserved.
void PreserveAll() { void PreserveAll() { preserved_ids_.insert(TypeId::get<AllAnalysesType>()); }
preserved_ids_.insert(ir::TypeId::get<AllAnalysesType>());
}
bool IsAll() const { bool IsAll() const {
return preserved_ids_.count(ir::TypeId::get<AllAnalysesType>()); return preserved_ids_.count(TypeId::get<AllAnalysesType>());
} }
bool IsNone() const { return preserved_ids_.empty(); } bool IsNone() const { return preserved_ids_.empty(); }
template <typename AnalysisT> template <typename AnalysisT>
void Preserve() { void Preserve() {
Preserve(ir::TypeId::get<AnalysisT>()); Preserve(TypeId::get<AnalysisT>());
} }
template <typename AnalysisT, typename AnalysisT2, typename... OtherAnalysesT> template <typename AnalysisT, typename AnalysisT2, typename... OtherAnalysesT>
...@@ -62,25 +60,25 @@ class PreservedAnalyses { ...@@ -62,25 +60,25 @@ class PreservedAnalyses {
Preserve<AnalysisT2, OtherAnalysesT...>(); Preserve<AnalysisT2, OtherAnalysesT...>();
} }
void Preserve(ir::TypeId id) { preserved_ids_.insert(id); } void Preserve(TypeId id) { preserved_ids_.insert(id); }
template <typename AnalysisT> template <typename AnalysisT>
bool IsPreserved() const { bool IsPreserved() const {
return IsPreserved(ir::TypeId::get<AnalysisT>()); return IsPreserved(TypeId::get<AnalysisT>());
} }
bool IsPreserved(ir::TypeId id) const { return preserved_ids_.count(id); } bool IsPreserved(TypeId id) const { return preserved_ids_.count(id); }
template <typename AnalysisT> template <typename AnalysisT>
void Unpreserve() { void Unpreserve() {
preserved_ids_.erase(ir::TypeId::get<AnalysisT>()); preserved_ids_.erase(TypeId::get<AnalysisT>());
} }
private: private:
template <typename> template <typename>
friend struct AnalysisModel; friend struct AnalysisModel;
std::unordered_set<ir::TypeId> preserved_ids_; std::unordered_set<TypeId> preserved_ids_;
}; };
namespace detail { namespace detail {
...@@ -132,11 +130,11 @@ struct AnalysisModel : public AnalysisConcept { ...@@ -132,11 +130,11 @@ struct AnalysisModel : public AnalysisConcept {
/// All computation, caching and invalidation of analyses takes place here. /// All computation, caching and invalidation of analyses takes place here.
class AnalysisMap { class AnalysisMap {
public: public:
explicit AnalysisMap(ir::Operation* ir) : ir_(ir) {} explicit AnalysisMap(Operation* ir) : ir_(ir) {}
template <typename AnalysisT> template <typename AnalysisT>
AnalysisT& GetAnalysis(PassInstrumentor* pi, AnalysisManager& am) { AnalysisT& GetAnalysis(PassInstrumentor* pi, AnalysisManager& am) {
return GetAnalysisImpl<AnalysisT, ir::Operation*>(pi, ir_, am); return GetAnalysisImpl<AnalysisT, Operation*>(pi, ir_, am);
} }
template <typename AnalysisT, typename OpT> template <typename AnalysisT, typename OpT>
...@@ -151,12 +149,12 @@ class AnalysisMap { ...@@ -151,12 +149,12 @@ class AnalysisMap {
template <typename AnalysisT> template <typename AnalysisT>
paddle::optional<std::reference_wrapper<AnalysisT>> GetCachedAnalysis() paddle::optional<std::reference_wrapper<AnalysisT>> GetCachedAnalysis()
const { const {
auto res = analyses_.find(ir::TypeId::get<AnalysisT>()); auto res = analyses_.find(TypeId::get<AnalysisT>());
if (res == analyses_.end()) return paddle::none; if (res == analyses_.end()) return paddle::none;
return {static_cast<AnalysisModel<AnalysisT>&>(*res->second).analysis}; return {static_cast<AnalysisModel<AnalysisT>&>(*res->second).analysis};
} }
ir::Operation* getOperation() const { return ir_; } Operation* getOperation() const { return ir_; }
void Clear() { analyses_.clear(); } void Clear() { analyses_.clear(); }
...@@ -190,7 +188,7 @@ class AnalysisMap { ...@@ -190,7 +188,7 @@ class AnalysisMap {
AnalysisT& GetAnalysisImpl(PassInstrumentor* pi, AnalysisT& GetAnalysisImpl(PassInstrumentor* pi,
OpT op, OpT op,
AnalysisManager& am) { // NOLINT AnalysisManager& am) { // NOLINT
ir::TypeId id = ir::TypeId::get<AnalysisT>(); TypeId id = TypeId::get<AnalysisT>();
auto it = analyses_.find(id); auto it = analyses_.find(id);
if (it == analyses_.end()) { if (it == analyses_.end()) {
if (pi) { if (pi) {
...@@ -234,8 +232,8 @@ class AnalysisMap { ...@@ -234,8 +232,8 @@ class AnalysisMap {
} }
private: private:
ir::Operation* ir_; Operation* ir_;
std::unordered_map<ir::TypeId, std::unique_ptr<AnalysisConcept>> analyses_; std::unordered_map<TypeId, std::unique_ptr<AnalysisConcept>> analyses_;
}; };
} // namespace detail } // namespace detail
...@@ -273,7 +271,7 @@ class AnalysisManager { ...@@ -273,7 +271,7 @@ class AnalysisManager {
PassInstrumentor* GetPassInstrumentor() const { return instrumentor_; } PassInstrumentor* GetPassInstrumentor() const { return instrumentor_; }
ir::Operation* GetOperation() { return analyses_->getOperation(); } Operation* GetOperation() { return analyses_->getOperation(); }
private: private:
AnalysisManager(detail::AnalysisMap* impl, PassInstrumentor* pi) AnalysisManager(detail::AnalysisMap* impl, PassInstrumentor* pi)
...@@ -292,7 +290,7 @@ class AnalysisManager { ...@@ -292,7 +290,7 @@ class AnalysisManager {
/// analyses. /// analyses.
class AnalysisManagerHolder { class AnalysisManagerHolder {
public: public:
AnalysisManagerHolder(ir::Operation* op, PassInstrumentor* pi) AnalysisManagerHolder(Operation* op, PassInstrumentor* pi)
: analyses_(op), pi_(pi) {} : analyses_(op), pi_(pi) {}
AnalysisManagerHolder(const AnalysisManagerHolder&) = delete; AnalysisManagerHolder(const AnalysisManagerHolder&) = delete;
AnalysisManagerHolder& operator=(const AnalysisManagerHolder&) = delete; AnalysisManagerHolder& operator=(const AnalysisManagerHolder&) = delete;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <ostream>
#include <unordered_map>
#include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
namespace ir {
namespace {
void PrintIR(Operation *op, bool print_module, std::ostream &os) {
// Otherwise, check to see if we are not printing at module scope.
if (print_module) {
op->Print(os << "\n");
return;
}
// Otherwise, we are printing at module scope.
os << " ('" << op->name() << "' operation)\n";
// Find the top-level operation.
auto *top_op = op;
while (auto *parent_op = top_op->GetParentOp()) {
top_op = parent_op;
}
top_op->Print(os);
}
} // namespace
class IRPrinting : public PassInstrumentation {
public:
explicit IRPrinting(std::unique_ptr<PassManager::IRPrinterOption> option)
: option_(std::move(option)) {}
~IRPrinting() = default;
void RunBeforePass(Pass *pass, Operation *op) override {
if (option_->EnablePrintOnChange()) {
// TODO(liuyuanle): support print on change
}
option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) {
os << "// *** IR Dump Before " << pass->pass_info().name << " ***";
PrintIR(op, option_->EnablePrintModule(), os);
os << "\n\n";
});
}
void RunAfterPass(Pass *pass, Operation *op) override {
if (option_->EnablePrintOnChange()) {
// TODO(liuyuanle): support print on change
}
option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) {
os << "// *** IR Dump After " << pass->pass_info().name << " ***";
PrintIR(op, option_->EnablePrintModule(), os);
os << "\n\n";
});
}
private:
std::unique_ptr<PassManager::IRPrinterOption> option_;
// TODO(liuyuanle): Add IRFingerPrint to support print on change.
};
void PassManager::EnableIRPrinting(std::unique_ptr<IRPrinterOption> option) {
AddInstrumentation(std::make_unique<IRPrinting>(std::move(option)));
}
} // namespace ir
...@@ -16,12 +16,20 @@ ...@@ -16,12 +16,20 @@
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/pass/pass_adaptor.h" #include "paddle/ir/pass/pass_adaptor.h"
#include "paddle/ir/pass/pass_instrumentation.h" #include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h" #include "paddle/ir/pass/pass_manager.h"
namespace ir { namespace ir {
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
Pass::~Pass() = default;
bool Pass::CanApplyOn(Operation* op) const { return op->num_regions() > 0; }
//----------------------------------------------------------------------------------------------// //----------------------------------------------------------------------------------------------//
// PassAdaptor // PassAdaptor
//----------------------------------------------------------------------------------------------// //----------------------------------------------------------------------------------------------//
...@@ -34,7 +42,20 @@ void detail::PassAdaptor::Run(ir::Operation* op, ...@@ -34,7 +42,20 @@ void detail::PassAdaptor::Run(ir::Operation* op,
void detail::PassAdaptor::RunImpl(ir::Operation* op, void detail::PassAdaptor::RunImpl(ir::Operation* op,
uint8_t opt_level, uint8_t opt_level,
bool verify) { bool verify) {
// TODO(liuyuanle): Support block, region, etc. auto last_am = analysis_manager();
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
for (auto it = block->begin(); it != block->end(); ++it) {
auto* op = *it;
AnalysisManagerHolder am(op, last_am.GetPassInstrumentor());
if (!RunPipeline(*pm_, op, am, opt_level, verify))
return SignalPassFailure();
}
}
}
return; return;
} }
...@@ -49,7 +70,7 @@ bool detail::PassAdaptor::RunPipeline(const PassManager& pm, ...@@ -49,7 +70,7 @@ bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
} }
for (auto& pass : pm.passes()) { for (auto& pass : pm.passes()) {
if (pass->CanScheduleOn(op)) { if (pass->CanApplyOn(op)) {
if (!RunPass(pass.get(), op, am, opt_level, verify)) { if (!RunPass(pass.get(), op, am, opt_level, verify)) {
return false; return false;
} }
...@@ -106,21 +127,21 @@ PassManager::PassManager(ir::IrContext* context, uint8_t opt_level) ...@@ -106,21 +127,21 @@ PassManager::PassManager(ir::IrContext* context, uint8_t opt_level)
pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this); pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
} }
// bool PassManager::Run(ir::Program* program) const { bool PassManager::Run(ir::Program* program) {
// if (!Initialize(context_)) { if (!Initialize(context_)) {
// return false; return false;
// } }
// return Run(program->operation()); return Run(program->module_op());
// } }
bool PassManager::Run(ir::Operation* op) const { bool PassManager::Run(ir::Operation* op) {
// Construct a analysis manager for the pipeline. // Construct a analysis manager for the pipeline.
AnalysisManagerHolder am(op, instrumentor_.get()); AnalysisManagerHolder am(op, instrumentor_.get());
return detail::PassAdaptor::RunPipeline(*this, op, am, opt_level_, verify_); return detail::PassAdaptor::RunPipeline(*this, op, am, opt_level_, verify_);
} }
bool PassManager::Initialize(ir::IrContext* context) const { bool PassManager::Initialize(ir::IrContext* context) {
for (auto& pass : passes()) { for (auto& pass : passes()) {
if (!pass->Initialize(context)) return false; if (!pass->Initialize(context)) return false;
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <string>
#include <vector> #include <vector>
#include "paddle/ir/pass/analysis_manager.h" #include "paddle/ir/pass/analysis_manager.h"
...@@ -33,11 +34,11 @@ class PassAdaptor; ...@@ -33,11 +34,11 @@ class PassAdaptor;
namespace detail { namespace detail {
struct PassExecutionState { struct PassExecutionState {
explicit PassExecutionState(ir::Operation* ir, const AnalysisManager& am) explicit PassExecutionState(Operation* ir, const AnalysisManager& am)
: ir(ir), pass_failed(false), am(am) {} : ir(ir), pass_failed(false), am(am) {}
// The IR currently being processed by pass. // The IR currently being processed by pass.
ir::Operation* ir; Operation* ir;
bool pass_failed; bool pass_failed;
AnalysisManager am; AnalysisManager am;
...@@ -45,13 +46,13 @@ struct PassExecutionState { ...@@ -45,13 +46,13 @@ struct PassExecutionState {
}; };
struct PassInfo { struct PassInfo {
PassInfo(const char* name, PassInfo(const std::string& name,
uint8_t opt_level, uint8_t opt_level,
const std::vector<const char* /* pass name */>& dependents = {}) const std::vector<std::string /* pass name */>& dependents = {})
: name(name), opt_level(opt_level), dependents(dependents) {} : name(name), opt_level(opt_level), dependents(dependents) {}
// Pass name. // Pass name.
const char* name; std::string name;
// opt_level=0: the basic pass which framework need. // opt_level=0: the basic pass which framework need.
// opt_level=1: the fusion logical pass. // opt_level=1: the fusion logical pass.
...@@ -61,7 +62,7 @@ struct PassInfo { ...@@ -61,7 +62,7 @@ struct PassInfo {
// The list which pass depends on. // The list which pass depends on.
// PassManager will check the constraint(TODO). // PassManager will check the constraint(TODO).
std::vector<const char*> dependents; std::vector<std::string> dependents;
}; };
} // namespace detail } // namespace detail
...@@ -69,22 +70,23 @@ struct PassInfo { ...@@ -69,22 +70,23 @@ struct PassInfo {
/// We can access pass only from PassManager. /// We can access pass only from PassManager.
class Pass { class Pass {
public: public:
explicit Pass(const char* name, explicit Pass(const std::string& name,
uint8_t opt_level, uint8_t opt_level,
const std::vector<const char*>& dependents = {}) const std::vector<std::string>& dependents = {})
: pass_info_(name, opt_level, dependents) {} : pass_info_(name, opt_level, dependents) {}
virtual ~Pass() = default; virtual ~Pass();
std::string name() { return pass_info().name; }
const detail::PassInfo& pass_info() const { return pass_info_; } const detail::PassInfo& pass_info() const { return pass_info_; }
protected: protected:
virtual void Run(ir::Operation* op) = 0; virtual void Run(Operation* op) = 0;
// TODO(liuyuanle): Add block/region judgement. virtual inline bool CanApplyOn(Operation* op) const;
virtual inline bool CanScheduleOn(ir::Operation* op) const { return true; }
virtual bool Initialize(ir::IrContext* context) { return true; } virtual bool Initialize(IrContext* context) { return true; }
AnalysisManager analysis_manager() { return pass_state().am; } AnalysisManager analysis_manager() { return pass_state().am; }
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
namespace ir { namespace ir {
class Operation; class Operation;
class PassManager; class PassManager;
namespace detail { namespace detail {
...@@ -28,21 +27,21 @@ class PassAdaptor final : public Pass { ...@@ -28,21 +27,21 @@ class PassAdaptor final : public Pass {
public: public:
explicit PassAdaptor(PassManager* pm) : Pass("pass_adaptor", 0), pm_(pm) {} explicit PassAdaptor(PassManager* pm) : Pass("pass_adaptor", 0), pm_(pm) {}
void Run(ir::Operation*) override {} void Run(Operation*) override {}
void Run(ir::Operation*, uint8_t opt_level, bool verify); void Run(Operation*, uint8_t opt_level, bool verify);
private: private:
void RunImpl(ir::Operation* op, uint8_t opt_level, bool verify); void RunImpl(Operation* op, uint8_t opt_level, bool verify);
static bool RunPass(Pass* pass, static bool RunPass(Pass* pass,
ir::Operation* op, Operation* op,
AnalysisManager am, AnalysisManager am,
uint8_t opt_level, uint8_t opt_level,
bool verify); bool verify);
static bool RunPipeline(const PassManager& pm, static bool RunPipeline(const PassManager& pm,
ir::Operation* op, Operation* op,
AnalysisManager am, AnalysisManager am,
uint8_t opt_level, uint8_t opt_level,
bool verify); bool verify);
......
...@@ -33,21 +33,21 @@ class PassInstrumentation { ...@@ -33,21 +33,21 @@ class PassInstrumentation {
virtual ~PassInstrumentation() = default; virtual ~PassInstrumentation() = default;
/// A callback to run before a pass pipeline is executed. /// A callback to run before a pass pipeline is executed.
virtual void RunBeforePipeline(ir::Operation* op) {} virtual void RunBeforePipeline(Operation* op) {}
virtual void RunAfterPipeline(ir::Operation* op) {} virtual void RunAfterPipeline(Operation* op) {}
virtual void RunBeforePass(Pass* pass, ir::Operation* op) {} virtual void RunBeforePass(Pass* pass, Operation* op) {}
virtual void RunAfterPass(Pass* pass, ir::Operation* op) {} virtual void RunAfterPass(Pass* pass, Operation* op) {}
virtual void RunBeforeAnalysis(const std::string& name, virtual void RunBeforeAnalysis(const std::string& name,
ir::TypeId id, TypeId id,
ir::Operation* op) {} Operation* op) {}
virtual void RunAfterAnalysis(const std::string& name, virtual void RunAfterAnalysis(const std::string& name,
ir::TypeId id, TypeId id,
ir::Operation* op) {} Operation* op) {}
}; };
/// This class holds a collection of PassInstrumentation obejcts, and invokes /// This class holds a collection of PassInstrumentation obejcts, and invokes
...@@ -61,21 +61,17 @@ class PassInstrumentor { ...@@ -61,21 +61,17 @@ class PassInstrumentor {
void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi); void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);
void RunBeforePipeline(ir::Operation* op); void RunBeforePipeline(Operation* op);
void RunAfterPipeline(ir::Operation* op); void RunAfterPipeline(Operation* op);
void RunBeforePass(Pass* pass, ir::Operation* op); void RunBeforePass(Pass* pass, Operation* op);
void RunAfterPass(Pass* pass, ir::Operation* op); void RunAfterPass(Pass* pass, Operation* op);
void RunBeforeAnalysis(const std::string& name, void RunBeforeAnalysis(const std::string& name, TypeId id, Operation* op);
ir::TypeId id /* */,
ir::Operation* op);
void RunAfterAnalysis(const std::string& name, void RunAfterAnalysis(const std::string& name, TypeId id, Operation* op);
ir::TypeId id,
ir::Operation* op);
// TODO(wilber): Add other hooks. // TODO(wilber): Add other hooks.
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <iostream>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -35,7 +36,7 @@ class PassAdaptor; ...@@ -35,7 +36,7 @@ class PassAdaptor;
class PassManager { class PassManager {
public: public:
explicit PassManager(ir::IrContext *context, uint8_t opt_level = 2); explicit PassManager(IrContext *context, uint8_t opt_level = 2);
~PassManager() = default; ~PassManager() = default;
...@@ -43,22 +44,82 @@ class PassManager { ...@@ -43,22 +44,82 @@ class PassManager {
bool empty() const { return passes_.empty(); } bool empty() const { return passes_.empty(); }
ir::IrContext *context() const { return context_; } IrContext *context() const { return context_; }
// bool Run(ir::Program *program) const; bool Run(Program *program);
bool Run(ir::Operation *op) const;
void AddPass(std::unique_ptr<Pass> pass) { void AddPass(std::unique_ptr<Pass> pass) {
passes_.emplace_back(std::move(pass)); passes_.emplace_back(std::move(pass));
} }
class IRPrinterOption {
public:
using PrintCallBack = std::function<void(std::ostream &)>;
explicit IRPrinterOption(
const std::function<bool(Pass *, Operation *)> &enable_print_before =
[](Pass *, Operation *) { return true; },
const std::function<bool(Pass *, Operation *)> &enable_print_after =
[](Pass *, Operation *) { return true; },
bool print_module = true,
bool print_on_change = true,
std::ostream &os = std::cout)
: enable_print_before_(enable_print_before),
enable_print_after_(enable_print_after),
print_module_(print_module),
print_on_change_(print_on_change),
os(os) {
assert((enable_print_before_ || enable_print_after_) &&
"expected at least one valid filter function");
}
~IRPrinterOption() = default;
void PrintBeforeIfEnabled(Pass *pass,
Operation *op,
const PrintCallBack &print_callback) {
if (enable_print_before_ && enable_print_before_(pass, op)) {
print_callback(os);
}
}
void PrintAfterIfEnabled(Pass *pass,
Operation *op,
const PrintCallBack &print_callback) {
if (enable_print_after_ && enable_print_after_(pass, op)) {
print_callback(os);
}
}
bool EnablePrintModule() const { return print_module_; }
bool EnablePrintOnChange() const { return print_on_change_; }
private:
// The enable_print_before_ and enable_print_after_ can be used to specify
// the pass to be printed. The default is to print all passes.
std::function<bool(Pass *, Operation *)> enable_print_before_;
std::function<bool(Pass *, Operation *)> enable_print_after_;
bool print_module_;
bool print_on_change_;
std::ostream &os;
// TODO(liuyuanle): Add flags to control printing behavior.
};
void EnableIRPrinting(std::unique_ptr<IRPrinterOption> config);
void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi); void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);
private: private:
bool Initialize(ir::IrContext *context) const; bool Initialize(IrContext *context);
bool Run(Operation *op);
private: private:
ir::IrContext *context_; IrContext *context_;
uint8_t opt_level_; uint8_t opt_level_;
...@@ -70,6 +131,7 @@ class PassManager { ...@@ -70,6 +131,7 @@ class PassManager {
std::unique_ptr<PassInstrumentor> instrumentor_; std::unique_ptr<PassInstrumentor> instrumentor_;
// For access member of pass_adaptor_.
friend class detail::PassAdaptor; friend class detail::PassAdaptor;
}; };
......
...@@ -230,6 +230,8 @@ TEST(program_test, program) { ...@@ -230,6 +230,8 @@ TEST(program_test, program) {
// (8) Traverse Program // (8) Traverse Program
EXPECT_EQ(program.block()->size() == 4, true); EXPECT_EQ(program.block()->size() == 4, true);
EXPECT_EQ(program.parameters_num() == 3, true); EXPECT_EQ(program.parameters_num() == 3, true);
program.print(std::cout);
} }
TEST(program_test, slice_combine_test) { TEST(program_test, slice_combine_test) {
......
...@@ -43,7 +43,7 @@ TEST(value_test, value_test) { ...@@ -43,7 +43,7 @@ TEST(value_test, value_test) {
CreateAttributeMap("op1_name", "op1_attr"), CreateAttributeMap("op1_name", "op1_attr"),
op1_output_types, op1_output_types,
nullptr); nullptr);
VLOG(0) << op1->print(); op1->Print(std::cout);
// 2. Construct OP2: b = OP2(); // 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {}; std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
...@@ -52,7 +52,7 @@ TEST(value_test, value_test) { ...@@ -52,7 +52,7 @@ TEST(value_test, value_test) {
CreateAttributeMap("op2_name", "op2_attr"), CreateAttributeMap("op2_name", "op2_attr"),
op2_output_types, op2_output_types,
nullptr); nullptr);
VLOG(0) << op2->print() << std::endl; op2->Print(std::cout);
// 3. Construct OP3: c = OP3(a, b); // 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
op2->GetResultByIndex(0)}; op2->GetResultByIndex(0)};
...@@ -62,7 +62,7 @@ TEST(value_test, value_test) { ...@@ -62,7 +62,7 @@ TEST(value_test, value_test) {
CreateAttributeMap("op3_name", "op3_attr"), CreateAttributeMap("op3_name", "op3_attr"),
op3_output_types, op3_output_types,
nullptr); nullptr);
VLOG(0) << op3->print() << std::endl; op3->Print(std::cout);
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
op3->GetResultByIndex(0)}; op3->GetResultByIndex(0)};
...@@ -75,7 +75,7 @@ TEST(value_test, value_test) { ...@@ -75,7 +75,7 @@ TEST(value_test, value_test) {
CreateAttributeMap("op4_name", "op4_attr"), CreateAttributeMap("op4_name", "op4_attr"),
op4_output_types, op4_output_types,
nullptr); nullptr);
VLOG(0) << op4->print() << std::endl; op4->Print(std::cout);
// Test 1: // Test 1:
EXPECT_EQ(op1->GetResultByIndex(0).GetDefiningOp(), op1); EXPECT_EQ(op1->GetResultByIndex(0).GetDefiningOp(), op1);
......
cc_test_old(pass_manager_test SRCS pass_manager_test.cc DEPS new_pass gtest) cc_test_old(
pass_manager_test
SRCS
pass_manager_test.cc
DEPS
new_pass
pd_dialect
phi
gtest)
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cstring> #include "paddle/fluid/dialect/pd_dialect.h"
#include "paddle/fluid/dialect/pd_interface.h"
#include "glog/logging.h" #include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
...@@ -25,89 +27,240 @@ ...@@ -25,89 +27,240 @@
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h" #include "paddle/ir/pass/pass_manager.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
ir::AttributeMap CreateAttributeMap(ir::IrContext *ctx, class AddOp : public ir::Op<AddOp> {
std::string attribute_name,
std::string attribute) {
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attribute);
ir::AttributeMap attr_map;
attr_map.insert(
std::pair<std::string, ir::Attribute>(attribute_name, attr_value));
return attr_map;
}
class TestOp : public ir::Op<TestOp> {
public: public:
using Op::Op; using Op::Op;
static const char *name() { return "TestDialect.TestOp"; } static const char *name() { return "test.add"; }
static constexpr uint32_t attributes_num = 1; static constexpr const char **attributes_name = nullptr;
static const char *attributes_name[attributes_num]; static constexpr uint32_t attributes_num = 0;
static void verify(const std::vector<ir::OpResult> &inputs, static void verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (attributes.count("op1_attr1") == 0 || if (inputs.size() != 2) {
!attributes.at("op1_attr1").isa<ir::StrAttribute>()) { throw("The size of inputs must be equal to 2.");
throw("Type of attribute: parameter_name is not right.");
} }
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
} }
};
const char *TestOp::attributes_name[attributes_num] = {"op1_attr1"};
class TestDialect : public ir::Dialect {
public:
explicit TestDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize();
} }
static const char *name() { return "TestDialect"; }
private:
void initialize() { RegisterOps<TestOp>(); }
}; };
class TestPass : public ir::Pass { class TestPass : public ir::Pass {
public: public:
TestPass() : ir::Pass("TestPass", 1) {} TestPass() : ir::Pass("TestPass", 1) {}
void Run(ir::Operation *op) override { void Run(ir::Operation *op) override {
auto test_op = op->dyn_cast<TestOp>(); auto module_op = op->dyn_cast<ir::ModuleOp>();
CHECK_EQ(test_op.operation(), op); CHECK_EQ(module_op.operation(), op);
CHECK_EQ(test_op.name(), test_op->op_info().name()); CHECK_EQ(module_op.name(), module_op->name());
LOG(INFO) << "In " << pass_info().name << ": " << test_op->op_info().name(); LOG(INFO) << "In " << pass_info().name << ": " << module_op->name()
<< std::endl;
} }
bool CanScheduleOn(ir::Operation *op) const override { bool CanApplyOn(ir::Operation *op) const override {
return std::strcmp(op->op_info().name(), "TestDialect.TestOp") == 0; return op->name() == "builtin.module" && op->num_regions() > 0;
} }
}; };
TEST(pass_manager_test, pass_manager) { TEST(pass_manager_test, pass_manager) {
// (1) Register Dialect, Operation into IrContext. // (1) Init environment.
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>(); ir::Dialect *builtin_dialect =
CHECK_EQ(test_dialect != nullptr, true); ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
builtin_dialect->RegisterOp<AddOp>();
// (2) Get registered operations. ir::Dialect *paddle_dialect =
std::string op_name = std::string(TestOp::name()); ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
auto op_info = ctx->GetRegisteredOpInfo(op_name);
CHECK_EQ(op_info != nullptr, true); // (2) Create an empty program object
ir::Program program(ctx);
// (3) Test uses for op.
std::vector<ir::OpResult> op_inputs = {}; // (3) Create a float32 DenseTensor Parameter and save into Program
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)}; ir::Type fp32_dtype = ir::Float32Type::get(ctx);
ir::Operation *op = paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2};
ir::Operation::create(op_inputs, paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout =
CreateAttributeMap(ctx, "op1_attr1", "op1_attr1"), paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
op_output_types, paddle::dialect::DenseTensorTypeStorage::LoD lod = {{0, 1, 2}};
op_info); size_t offset = 0;
ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
CHECK_EQ(op != nullptr, true); ctx, fp32_dtype, dims, data_layout, lod, offset);
// (4) Test pass manager for op. std::vector<float> data_a = {1, 2, 3, 4};
std::unique_ptr<ir::Parameter> parameter_a =
std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_a.data()),
4 * sizeof(float),
dense_tensor_dtype);
program.SetParameter("a", std::move(parameter_a));
EXPECT_EQ(program.parameters_num() == 1, true);
std::vector<float> data_b = {5, 6, 7, 8};
std::unique_ptr<ir::Parameter> parameter_b =
std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_b.data()),
4 * sizeof(float),
dense_tensor_dtype);
program.SetParameter("b", std::move(parameter_b));
EXPECT_EQ(program.parameters_num() == 2, true);
// (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
std::string op1_name = ir::GetParameterOp::name();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 =
ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
ir::Block *block = program.block();
block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion());
EXPECT_EQ(program.module_op(), block->GetParentOp());
EXPECT_EQ(&program, op1->GetParentProgram());
EXPECT_EQ(op1->GetResultByIndex(0).type().dialect().id(),
paddle_dialect->id());
using Interface = paddle::dialect::ParameterConvertInterface;
Interface *a_interface = op1->GetResultByIndex(0)
.type()
.dialect()
.GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> a_var =
a_interface->ParameterToVariable(program.GetParameter("a"));
const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
EXPECT_EQ(a_tensor.numel(), 4);
EXPECT_EQ(a_tensor.dims(), dims);
EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
EXPECT_EQ(a_tensor.layout(), data_layout);
EXPECT_EQ(a_tensor.lod(), lod);
EXPECT_EQ(a_tensor.offset(), offset);
for (int64_t i = 0; i < a_tensor.numel(); i++) {
EXPECT_EQ(*(a_tensor.data<float>() + i), data_a[i]);
}
// (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
std::string op2_name =
builtin_dialect->name() + "." + std::string(ir::GetParameterOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 =
ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2);
EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(),
paddle_dialect->id());
Interface *b_interface = op2->GetResultByIndex(0)
.type()
.dialect()
.GetRegisteredInterface<Interface>();
std::shared_ptr<paddle::framework::Variable> b_var =
b_interface->ParameterToVariable(program.GetParameter("b"));
const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
EXPECT_EQ(b_tensor.numel(), 4);
EXPECT_EQ(b_tensor.dims(), dims);
EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
EXPECT_EQ(b_tensor.layout(), data_layout);
EXPECT_EQ(b_tensor.lod(), lod);
EXPECT_EQ(b_tensor.offset(), offset);
for (int64_t i = 0; i < b_tensor.numel(); i++) {
EXPECT_EQ(*(b_tensor.data<float>() + i), data_b[i]);
}
// (6) Def c = AddOp(a, b), execute this op.
std::string op3_name =
builtin_dialect->name() + "." + std::string(AddOp::name());
ir::OpInfo op3_info = ctx->GetRegisteredOpInfo(op3_name);
std::unordered_map<std::string, ir::Attribute> op3_attribute;
ir::Operation *op3 = ir::Operation::create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
op3_attribute,
{dense_tensor_dtype},
op3_info);
block->push_back(op3);
phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace()));
phi::DenseTensor c_tensor =
phi::Add<float, phi::CPUContext>(*dev_ctx, a_tensor, b_tensor);
std::shared_ptr<paddle::framework::Variable> variable_c =
std::make_shared<paddle::framework::Variable>();
auto *dst_tensor = variable_c->GetMutable<phi::DenseTensor>();
*dst_tensor = c_tensor;
EXPECT_EQ(dst_tensor->numel(), b_tensor.numel());
EXPECT_EQ(dst_tensor->dims(), b_tensor.dims());
EXPECT_EQ(dst_tensor->dtype(), b_tensor.dtype());
EXPECT_EQ(dst_tensor->layout(), b_tensor.layout());
EXPECT_EQ(dst_tensor->lod(), b_tensor.lod());
EXPECT_EQ(dst_tensor->offset(), b_tensor.offset());
for (int64_t i = 0; i < dst_tensor->numel(); i++) {
EXPECT_EQ(*(dst_tensor->data<float>() + i), data_a[i] + data_b[i]);
}
// (7) Def AbsOp(b)
ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs");
std::vector<ir::OpResult> operands = {op1->GetResultByIndex(0)};
std::unordered_map<std::string, ir::Attribute> abs_op_attribute;
std::vector<ir::Type> output_types = {dense_tensor_dtype};
ir::OperationArgument abs_argument(abs_info);
abs_argument.AddOperands(operands.begin(), operands.end());
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument));
paddle::dialect::GetOpInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::GetOpInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
// (8) Def SetParameterOp(c, "c")
std::string op4_name =
builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name());
ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);
std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::OperationArgument op4_argument(
{op3->GetResultByIndex(0)}, {}, {}, op4_info);
op4_argument.AddAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::create(std::move(op4_argument));
block->push_back(op4);
EXPECT_EQ(op4->GetOperandByIndex(0).source().type().dialect().id(),
paddle_dialect->id());
Interface *c_interface = op4->GetOperandByIndex(0)
.source()
.type()
.dialect()
.GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get());
std::unique_ptr<ir::Parameter> parameter_c =
c_interface->VariableToParameter(variable_c.get());
EXPECT_EQ(parameter_c->type(), dense_tensor_dtype);
for (int64_t i = 0; i < dst_tensor->numel(); i++) {
EXPECT_EQ(*(dst_tensor->data<float>() + i),
*(static_cast<float *>(parameter_c->data()) + i));
}
program.SetParameter("c", std::move(parameter_c));
// (8) Traverse Program
EXPECT_EQ(program.block()->size() == 4, true);
EXPECT_EQ(program.parameters_num() == 3, true);
// (9) Test pass manager for program.
ir::PassManager pm(ctx); ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>()); pm.AddPass(std::make_unique<TestPass>());
CHECK_EQ(pm.Run(op), true); pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
[](ir::Pass *pass, ir::Operation *op) {
return pass->pass_info().name == "TestPass";
},
[](ir::Pass *pass, ir::Operation *op) {
return pass->pass_info().name == "TestPass";
},
true,
false));
op->destroy(); CHECK_EQ(pm.Run(&program), true);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册