From 24c4e37f324a984978fda8263a4b542cc05d0435 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Tue, 6 Jun 2023 15:27:14 +0800 Subject: [PATCH] [IR&PASS] part 2-3: add PassTiming (#54348) --- paddle/ir/pass/ir_printing.cc | 18 ++-- paddle/ir/pass/pass.cc | 40 +++++---- paddle/ir/pass/pass.h | 2 +- paddle/ir/pass/pass_adaptor.h | 2 +- paddle/ir/pass/pass_instrumentation.h | 7 +- paddle/ir/pass/pass_manager.h | 4 +- paddle/ir/pass/pass_timing.cc | 123 ++++++++++++++++++++++++++ paddle/ir/pass/utils.cc | 28 ++++++ paddle/ir/pass/utils.h | 10 ++- test/cpp/ir/pass/pass_manager_test.cc | 8 +- 10 files changed, 207 insertions(+), 35 deletions(-) create mode 100644 paddle/ir/pass/pass_timing.cc create mode 100644 paddle/ir/pass/utils.cc diff --git a/paddle/ir/pass/ir_printing.cc b/paddle/ir/pass/ir_printing.cc index 370307a7759..1d6ab50d794 100644 --- a/paddle/ir/pass/ir_printing.cc +++ b/paddle/ir/pass/ir_printing.cc @@ -13,26 +13,24 @@ // limitations under the License. #include +#include #include #include "paddle/ir/core/operation.h" #include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass_instrumentation.h" #include "paddle/ir/pass/pass_manager.h" +#include "paddle/ir/pass/utils.h" namespace ir { namespace { void PrintIR(Operation *op, bool print_module, std::ostream &os) { - // Otherwise, check to see if we are not printing at module scope. - if (print_module) { + if (!print_module) { op->Print(os << "\n"); return; } - // Otherwise, we are printing at module scope. - os << " ('" << op->name() << "' operation)\n"; - // Find the top-level operation. auto *top_op = op; while (auto *parent_op = top_op->GetParentOp()) { @@ -55,7 +53,9 @@ class IRPrinting : public PassInstrumentation { } option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) { - os << "// *** IR Dump Before " << pass->pass_info().name << " ***"; + std::string header = + "IRPrinting on " + op->name() + " before " + pass->name() + " pass"; + detail::PrintHeader(header, os); PrintIR(op, option_->EnablePrintModule(), os); os << "\n\n"; }); @@ -66,8 +66,10 @@ class IRPrinting : public PassInstrumentation { // TODO(liuyuanle): support print on change } - option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) { - os << "// *** IR Dump After " << pass->pass_info().name << " ***"; + option_->PrintAfterIfEnabled(pass, op, [&](std::ostream &os) { + std::string header = + "IRPrinting on " + op->name() + " after " + pass->name() + " pass"; + detail::PrintHeader(header, os); PrintIR(op, option_->EnablePrintModule(), os); os << "\n\n"; }); diff --git a/paddle/ir/pass/pass.cc b/paddle/ir/pass/pass.cc index 83c925b6468..5ede04a97c8 100644 --- a/paddle/ir/pass/pass.cc +++ b/paddle/ir/pass/pass.cc @@ -33,13 +33,11 @@ bool Pass::CanApplyOn(Operation* op) const { return op->num_regions() > 0; } //----------------------------------------------------------------------------------------------// // PassAdaptor //----------------------------------------------------------------------------------------------// -void detail::PassAdaptor::Run(ir::Operation* op, - uint8_t opt_level, - bool verify) { +void detail::PassAdaptor::Run(Operation* op, uint8_t opt_level, bool verify) { RunImpl(op, opt_level, verify); } -void detail::PassAdaptor::RunImpl(ir::Operation* op, +void detail::PassAdaptor::RunImpl(Operation* op, uint8_t opt_level, bool verify) { auto last_am = analysis_manager(); @@ -60,7 +58,7 @@ void detail::PassAdaptor::RunImpl(ir::Operation* op, } bool detail::PassAdaptor::RunPipeline(const PassManager& pm, - ir::Operation* op, + Operation* op, AnalysisManager am, uint8_t opt_level, bool verify) { @@ -90,7 +88,7 @@ bool detail::PassAdaptor::RunPipeline(const PassManager& pm, } bool detail::PassAdaptor::RunPass(Pass* pass, - ir::Operation* op, + Operation* op, AnalysisManager am, uint8_t opt_level, bool verify) { @@ -122,26 +120,26 @@ bool detail::PassAdaptor::RunPass(Pass* pass, //----------------------------------------------------------------------------------------------// // PassManager //----------------------------------------------------------------------------------------------// -PassManager::PassManager(ir::IrContext* context, uint8_t opt_level) +PassManager::PassManager(IrContext* context, uint8_t opt_level) : context_(context), opt_level_(opt_level) { pass_adaptor_ = std::make_unique(this); } -bool PassManager::Run(ir::Program* program) { +bool PassManager::Run(Program* program) { if (!Initialize(context_)) { return false; } return Run(program->module_op()); } -bool PassManager::Run(ir::Operation* op) { +bool PassManager::Run(Operation* op) { // Construct a analysis manager for the pipeline. AnalysisManagerHolder am(op, instrumentor_.get()); return detail::PassAdaptor::RunPipeline(*this, op, am, opt_level_, verify_); } -bool PassManager::Initialize(ir::IrContext* context) { +bool PassManager::Initialize(IrContext* context) { for (auto& pass : passes()) { if (!pass->Initialize(context)) return false; } @@ -170,13 +168,15 @@ PassInstrumentor::PassInstrumentor() PassInstrumentor::~PassInstrumentor() = default; -void PassInstrumentor::RunBeforePipeline(ir::Operation* op) { +void PassInstrumentor::RunBeforePipeline(Operation* op) { + if (op->num_regions() == 0) return; for (auto& instr : impl_->instrumentations) { instr->RunBeforePipeline(op); } } -void PassInstrumentor::RunAfterPipeline(ir::Operation* op) { +void PassInstrumentor::RunAfterPipeline(Operation* op) { + if (op->num_regions() == 0) return; for (auto it = impl_->instrumentations.rbegin(); it != impl_->instrumentations.rend(); ++it) { @@ -184,13 +184,15 @@ void PassInstrumentor::RunAfterPipeline(ir::Operation* op) { } } -void PassInstrumentor::RunBeforePass(Pass* pass, ir::Operation* op) { +void PassInstrumentor::RunBeforePass(Pass* pass, Operation* op) { + if (op->num_regions() == 0) return; for (auto& instr : impl_->instrumentations) { instr->RunBeforePass(pass, op); } } -void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) { +void PassInstrumentor::RunAfterPass(Pass* pass, Operation* op) { + if (op->num_regions() == 0) return; for (auto it = impl_->instrumentations.rbegin(); it != impl_->instrumentations.rend(); ++it) { @@ -199,16 +201,18 @@ void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) { } void PassInstrumentor::RunBeforeAnalysis(const std::string& name, - ir::TypeId id, - ir::Operation* op) { + TypeId id, + Operation* op) { + if (op->num_regions() == 0) return; for (auto& instr : impl_->instrumentations) { instr->RunBeforeAnalysis(name, id, op); } } void PassInstrumentor::RunAfterAnalysis(const std::string& name, - ir::TypeId id, - ir::Operation* op) { + TypeId id, + Operation* op) { + if (op->num_regions() == 0) return; for (auto it = impl_->instrumentations.rbegin(); it != impl_->instrumentations.rend(); ++it) { diff --git a/paddle/ir/pass/pass.h b/paddle/ir/pass/pass.h index 2c399d8c326..31b865e14f6 100644 --- a/paddle/ir/pass/pass.h +++ b/paddle/ir/pass/pass.h @@ -77,7 +77,7 @@ class Pass { virtual ~Pass(); - std::string name() { return pass_info().name; } + std::string name() const { return pass_info().name; } const detail::PassInfo& pass_info() const { return pass_info_; } diff --git a/paddle/ir/pass/pass_adaptor.h b/paddle/ir/pass/pass_adaptor.h index 520c2a4c3bc..4b81e8362a7 100644 --- a/paddle/ir/pass/pass_adaptor.h +++ b/paddle/ir/pass/pass_adaptor.h @@ -46,7 +46,7 @@ class PassAdaptor final : public Pass { uint8_t opt_level, bool verify); - // Use for RunImpl later. + private: PassManager* pm_; // For accessing RunPipeline. diff --git a/paddle/ir/pass/pass_instrumentation.h b/paddle/ir/pass/pass_instrumentation.h index 1e15adc9dd8..2105d4b261b 100644 --- a/paddle/ir/pass/pass_instrumentation.h +++ b/paddle/ir/pass/pass_instrumentation.h @@ -32,19 +32,24 @@ class PassInstrumentation { PassInstrumentation() = default; virtual ~PassInstrumentation() = default; - /// A callback to run before a pass pipeline is executed. + // A callback to run before a pass pipeline is executed. virtual void RunBeforePipeline(Operation* op) {} + // A callback to run after a pass pipeline is executed. virtual void RunAfterPipeline(Operation* op) {} + // A callback to run before a pass is executed. virtual void RunBeforePass(Pass* pass, Operation* op) {} + // A callback to run after a pass is executed. virtual void RunAfterPass(Pass* pass, Operation* op) {} + // A callback to run before a analysis is executed. virtual void RunBeforeAnalysis(const std::string& name, TypeId id, Operation* op) {} + // A callback to run after a analysis is executed. virtual void RunAfterAnalysis(const std::string& name, TypeId id, Operation* op) {} diff --git a/paddle/ir/pass/pass_manager.h b/paddle/ir/pass/pass_manager.h index 281e0a7a340..d5cefb477d2 100644 --- a/paddle/ir/pass/pass_manager.h +++ b/paddle/ir/pass/pass_manager.h @@ -42,7 +42,7 @@ class PassManager { const std::vector> &passes() const { return passes_; } - bool empty() const { return passes_.empty(); } + bool Empty() const { return passes_.empty(); } IrContext *context() const { return context_; } @@ -111,6 +111,8 @@ class PassManager { void EnableIRPrinting(std::unique_ptr config); + void EnablePassTiming(bool print_module = true); + void AddInstrumentation(std::unique_ptr pi); private: diff --git a/paddle/ir/pass/pass_timing.cc b/paddle/ir/pass/pass_timing.cc new file mode 100644 index 00000000000..c9f74ac4a1a --- /dev/null +++ b/paddle/ir/pass/pass_timing.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "paddle/ir/core/operation.h" +#include "paddle/ir/pass/pass.h" +#include "paddle/ir/pass/pass_instrumentation.h" +#include "paddle/ir/pass/pass_manager.h" +#include "paddle/ir/pass/utils.h" +namespace ir { +namespace { +class Timer { + public: + Timer() = default; + + ~Timer() = default; + + void Start() { start_time_ = std::chrono::steady_clock::now(); } + + void Stop() { walk_time += std::chrono::steady_clock::now() - start_time_; } + + double GetTimePerSecond() const { + return std::chrono::duration_cast>(walk_time) + .count(); + } + + private: + std::chrono::time_point start_time_; + + std::chrono::nanoseconds walk_time = std::chrono::nanoseconds(0); +}; +} // namespace + +class PassTimer : public PassInstrumentation { + public: + explicit PassTimer(bool print_module) : print_module_(print_module) {} + ~PassTimer() = default; + + void RunBeforePipeline(ir::Operation* op) override { + pipeline_timers_[op] = Timer(); + pipeline_timers_[op].Start(); + } + + void RunAfterPipeline(Operation* op) override { + pipeline_timers_[op].Stop(); + PrintTime(op, std::cout); + } + + void RunBeforePass(Pass* pass, Operation* op) override { + if (!pass_timers_.count(op)) { + pass_timers_[op] = {}; + } + pass_timers_[op][pass->name()] = Timer(); + pass_timers_[op][pass->name()].Start(); + } + + void RunAfterPass(Pass* pass, Operation* op) override { + pass_timers_[op][pass->name()].Stop(); + } + + private: + void PrintTime(Operation* op, std::ostream& os) { + if (print_module_ && op->name() != "builtin.module") return; + + std::string header = "PassTiming on " + op->name(); + detail::PrintHeader(header, os); + + os << " Total Execution Time: " << std::fixed << std::setprecision(3) + << pipeline_timers_[op].GetTimePerSecond() << " seconds\n\n"; + os << " ----Walk Time---- ----Name----\n"; + + auto& map = pass_timers_[op]; + std::vector> pairs(map.begin(), map.end()); + std::sort(pairs.begin(), + pairs.end(), + [](const std::pair& lhs, + const std::pair& rhs) { + return lhs.second.GetTimePerSecond() > + rhs.second.GetTimePerSecond(); + }); + + for (auto& v : pairs) { + os << " " << std::fixed << std::setw(8) << std::setprecision(3) + << v.second.GetTimePerSecond() << " (" << std::setw(5) + << std::setprecision(1) + << 100 * v.second.GetTimePerSecond() / + pipeline_timers_[op].GetTimePerSecond() + << "%)" + << " " << v.first << "\n"; + } + } + + private: + bool print_module_; + + std::unordered_map pipeline_timers_; + + std::unordered_map> + pass_timers_; +}; + +void PassManager::EnablePassTiming(bool print_module) { + AddInstrumentation(std::make_unique(print_module)); +} + +} // namespace ir diff --git a/paddle/ir/pass/utils.cc b/paddle/ir/pass/utils.cc new file mode 100644 index 00000000000..8c890943420 --- /dev/null +++ b/paddle/ir/pass/utils.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/pass/utils.h" + +namespace ir { +namespace detail { + +void PrintHeader(const std::string &header, std::ostream &os) { + unsigned padding = (80 - header.size()) / 2; + os << "===" << std::string(73, '-') << "===\n"; + os << std::string(padding, ' ') << header << "\n"; + os << "===" << std::string(73, '-') << "===\n"; +} + +} // namespace detail +} // namespace ir diff --git a/paddle/ir/pass/utils.h b/paddle/ir/pass/utils.h index b3724431d11..61ee43037e8 100644 --- a/paddle/ir/pass/utils.h +++ b/paddle/ir/pass/utils.h @@ -1,5 +1,3 @@ -// paddle/pass/utils.h - // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +14,8 @@ #pragma once +#include +#include #include namespace ir { @@ -41,5 +41,11 @@ struct detector>, Op, Args...> { template