未验证 提交 24c4e37f 编写于 作者: Y Yuanle Liu 提交者: GitHub

[IR&PASS] part 2-3: add PassTiming (#54348)

上级 3ea7d577
......@@ -13,26 +13,24 @@
// limitations under the License.
#include <ostream>
#include <string>
#include <unordered_map>
#include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/utils.h"
namespace ir {
namespace {
void PrintIR(Operation *op, bool print_module, std::ostream &os) {
// Otherwise, check to see if we are not printing at module scope.
if (print_module) {
if (!print_module) {
op->Print(os << "\n");
return;
}
// Otherwise, we are printing at module scope.
os << " ('" << op->name() << "' operation)\n";
// Find the top-level operation.
auto *top_op = op;
while (auto *parent_op = top_op->GetParentOp()) {
......@@ -55,7 +53,9 @@ class IRPrinting : public PassInstrumentation {
}
option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) {
os << "// *** IR Dump Before " << pass->pass_info().name << " ***";
std::string header =
"IRPrinting on " + op->name() + " before " + pass->name() + " pass";
detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os);
os << "\n\n";
});
......@@ -66,8 +66,10 @@ class IRPrinting : public PassInstrumentation {
// TODO(liuyuanle): support print on change
}
option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) {
os << "// *** IR Dump After " << pass->pass_info().name << " ***";
option_->PrintAfterIfEnabled(pass, op, [&](std::ostream &os) {
std::string header =
"IRPrinting on " + op->name() + " after " + pass->name() + " pass";
detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os);
os << "\n\n";
});
......
......@@ -33,13 +33,11 @@ bool Pass::CanApplyOn(Operation* op) const { return op->num_regions() > 0; }
//----------------------------------------------------------------------------------------------//
// PassAdaptor
//----------------------------------------------------------------------------------------------//
void detail::PassAdaptor::Run(ir::Operation* op,
uint8_t opt_level,
bool verify) {
void detail::PassAdaptor::Run(Operation* op, uint8_t opt_level, bool verify) {
RunImpl(op, opt_level, verify);
}
void detail::PassAdaptor::RunImpl(ir::Operation* op,
void detail::PassAdaptor::RunImpl(Operation* op,
uint8_t opt_level,
bool verify) {
auto last_am = analysis_manager();
......@@ -60,7 +58,7 @@ void detail::PassAdaptor::RunImpl(ir::Operation* op,
}
bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
ir::Operation* op,
Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify) {
......@@ -90,7 +88,7 @@ bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
}
bool detail::PassAdaptor::RunPass(Pass* pass,
ir::Operation* op,
Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify) {
......@@ -122,26 +120,26 @@ bool detail::PassAdaptor::RunPass(Pass* pass,
//----------------------------------------------------------------------------------------------//
// PassManager
//----------------------------------------------------------------------------------------------//
PassManager::PassManager(ir::IrContext* context, uint8_t opt_level)
PassManager::PassManager(IrContext* context, uint8_t opt_level)
: context_(context), opt_level_(opt_level) {
pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
}
bool PassManager::Run(ir::Program* program) {
bool PassManager::Run(Program* program) {
if (!Initialize(context_)) {
return false;
}
return Run(program->module_op());
}
bool PassManager::Run(ir::Operation* op) {
bool PassManager::Run(Operation* op) {
// Construct a analysis manager for the pipeline.
AnalysisManagerHolder am(op, instrumentor_.get());
return detail::PassAdaptor::RunPipeline(*this, op, am, opt_level_, verify_);
}
bool PassManager::Initialize(ir::IrContext* context) {
bool PassManager::Initialize(IrContext* context) {
for (auto& pass : passes()) {
if (!pass->Initialize(context)) return false;
}
......@@ -170,13 +168,15 @@ PassInstrumentor::PassInstrumentor()
PassInstrumentor::~PassInstrumentor() = default;
void PassInstrumentor::RunBeforePipeline(ir::Operation* op) {
void PassInstrumentor::RunBeforePipeline(Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforePipeline(op);
}
}
void PassInstrumentor::RunAfterPipeline(ir::Operation* op) {
void PassInstrumentor::RunAfterPipeline(Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
......@@ -184,13 +184,15 @@ void PassInstrumentor::RunAfterPipeline(ir::Operation* op) {
}
}
void PassInstrumentor::RunBeforePass(Pass* pass, ir::Operation* op) {
void PassInstrumentor::RunBeforePass(Pass* pass, Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforePass(pass, op);
}
}
void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) {
void PassInstrumentor::RunAfterPass(Pass* pass, Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
......@@ -199,16 +201,18 @@ void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) {
}
void PassInstrumentor::RunBeforeAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {
TypeId id,
Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforeAnalysis(name, id, op);
}
}
void PassInstrumentor::RunAfterAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {
TypeId id,
Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
......
......@@ -77,7 +77,7 @@ class Pass {
virtual ~Pass();
std::string name() { return pass_info().name; }
std::string name() const { return pass_info().name; }
const detail::PassInfo& pass_info() const { return pass_info_; }
......
......@@ -46,7 +46,7 @@ class PassAdaptor final : public Pass {
uint8_t opt_level,
bool verify);
// Use for RunImpl later.
private:
PassManager* pm_;
// For accessing RunPipeline.
......
......@@ -32,19 +32,24 @@ class PassInstrumentation {
PassInstrumentation() = default;
virtual ~PassInstrumentation() = default;
/// A callback to run before a pass pipeline is executed.
// A callback to run before a pass pipeline is executed.
virtual void RunBeforePipeline(Operation* op) {}
// A callback to run after a pass pipeline is executed.
virtual void RunAfterPipeline(Operation* op) {}
// A callback to run before a pass is executed.
virtual void RunBeforePass(Pass* pass, Operation* op) {}
// A callback to run after a pass is executed.
virtual void RunAfterPass(Pass* pass, Operation* op) {}
// A callback to run before a analysis is executed.
virtual void RunBeforeAnalysis(const std::string& name,
TypeId id,
Operation* op) {}
// A callback to run after a analysis is executed.
virtual void RunAfterAnalysis(const std::string& name,
TypeId id,
Operation* op) {}
......
......@@ -42,7 +42,7 @@ class PassManager {
const std::vector<std::unique_ptr<Pass>> &passes() const { return passes_; }
bool empty() const { return passes_.empty(); }
bool Empty() const { return passes_.empty(); }
IrContext *context() const { return context_; }
......@@ -111,6 +111,8 @@ class PassManager {
void EnableIRPrinting(std::unique_ptr<IRPrinterOption> config);
void EnablePassTiming(bool print_module = true);
void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);
private:
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <chrono>
#include <iomanip>
#include <ostream>
#include <string>
#include <unordered_map>
#include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/utils.h"
namespace ir {
namespace {
class Timer {
public:
Timer() = default;
~Timer() = default;
void Start() { start_time_ = std::chrono::steady_clock::now(); }
void Stop() { walk_time += std::chrono::steady_clock::now() - start_time_; }
double GetTimePerSecond() const {
return std::chrono::duration_cast<std::chrono::duration<double>>(walk_time)
.count();
}
private:
std::chrono::time_point<std::chrono::steady_clock> start_time_;
std::chrono::nanoseconds walk_time = std::chrono::nanoseconds(0);
};
} // namespace
class PassTimer : public PassInstrumentation {
public:
explicit PassTimer(bool print_module) : print_module_(print_module) {}
~PassTimer() = default;
void RunBeforePipeline(ir::Operation* op) override {
pipeline_timers_[op] = Timer();
pipeline_timers_[op].Start();
}
void RunAfterPipeline(Operation* op) override {
pipeline_timers_[op].Stop();
PrintTime(op, std::cout);
}
void RunBeforePass(Pass* pass, Operation* op) override {
if (!pass_timers_.count(op)) {
pass_timers_[op] = {};
}
pass_timers_[op][pass->name()] = Timer();
pass_timers_[op][pass->name()].Start();
}
void RunAfterPass(Pass* pass, Operation* op) override {
pass_timers_[op][pass->name()].Stop();
}
private:
void PrintTime(Operation* op, std::ostream& os) {
if (print_module_ && op->name() != "builtin.module") return;
std::string header = "PassTiming on " + op->name();
detail::PrintHeader(header, os);
os << " Total Execution Time: " << std::fixed << std::setprecision(3)
<< pipeline_timers_[op].GetTimePerSecond() << " seconds\n\n";
os << " ----Walk Time---- ----Name----\n";
auto& map = pass_timers_[op];
std::vector<std::pair<std::string, Timer>> pairs(map.begin(), map.end());
std::sort(pairs.begin(),
pairs.end(),
[](const std::pair<std::string, Timer>& lhs,
const std::pair<std::string, Timer>& rhs) {
return lhs.second.GetTimePerSecond() >
rhs.second.GetTimePerSecond();
});
for (auto& v : pairs) {
os << " " << std::fixed << std::setw(8) << std::setprecision(3)
<< v.second.GetTimePerSecond() << " (" << std::setw(5)
<< std::setprecision(1)
<< 100 * v.second.GetTimePerSecond() /
pipeline_timers_[op].GetTimePerSecond()
<< "%)"
<< " " << v.first << "\n";
}
}
private:
bool print_module_;
std::unordered_map<Operation*, Timer> pipeline_timers_;
std::unordered_map<Operation*,
std::unordered_map<std::string /*pass name*/, Timer>>
pass_timers_;
};
void PassManager::EnablePassTiming(bool print_module) {
AddInstrumentation(std::make_unique<PassTimer>(print_module));
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pass/utils.h"
namespace ir {
namespace detail {
void PrintHeader(const std::string &header, std::ostream &os) {
unsigned padding = (80 - header.size()) / 2;
os << "===" << std::string(73, '-') << "===\n";
os << std::string(padding, ' ') << header << "\n";
os << "===" << std::string(73, '-') << "===\n";
}
} // namespace detail
} // namespace ir
// paddle/pass/utils.h
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -16,6 +14,8 @@
#pragma once
#include <ostream>
#include <string>
#include <type_traits>
namespace ir {
......@@ -41,5 +41,11 @@ struct detector<void_t<Op<Args...>>, Op, Args...> {
template <template <class...> class Op, class... Args>
using is_detected = typename detector<void, Op, Args...>::value_t;
// Print content as follows.
// ===-------------------------------------------------------------------------===
// header
// ===-------------------------------------------------------------------------===
void PrintHeader(const std::string &header, std::ostream &os);
} // namespace detail
} // namespace ir
......@@ -254,13 +254,15 @@ TEST(pass_manager_test, pass_manager) {
pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
[](ir::Pass *pass, ir::Operation *op) {
return pass->pass_info().name == "TestPass";
return pass->name() == "TestPass";
},
[](ir::Pass *pass, ir::Operation *op) {
return pass->pass_info().name == "TestPass";
return pass->name() == "TestPass";
},
true,
false));
true));
pm.EnablePassTiming(true);
CHECK_EQ(pm.Run(&program), true);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册