未验证 提交 dff77c23 编写于 作者: Y Yuanle Liu 提交者: GitHub

add pass/analysis_manager.h ir/type_name.h pass/pass_instrumentation.h...

add pass/analysis_manager.h ir/type_name.h pass/pass_instrumentation.h pass/utils.h and adjust pass dir (#54170)
上级 f9065e15
......@@ -9,8 +9,6 @@ add_subdirectory(testing)
add_subdirectory(phi)
add_subdirectory(fluid)
add_subdirectory(pass)
# NOTE(zhiqiu): The changes of cc tests
# Before, (1) the source file of cc tests are distributed in different sub-directories,
# (2) the tests are added and configured by calling `cc_test()` in each `CMakeLists.txt`,
......
......@@ -3,3 +3,4 @@ if(NOT WITH_NEWIR)
endif()
add_subdirectory(core)
add_subdirectory(pass)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <string>
namespace ir {
template <typename DesiredTypeName>
inline std::string get_type_name() {
#if defined(__clang__) || defined(__GNUC__)
std::string name = __PRETTY_FUNCTION__;
std::string key = "DesiredTypeName = ";
name = name.substr(name.find(key));
assert(!name.empty() && "Unable to find the template parameter!");
name = name.substr(key.size());
assert(name.back() == "]" && "Name doesn't end in the substitution key!");
auto sem_pos = name.find_first_of(";");
if (sem_pos == std::string::npos)
name.pop_back();
else
name = name.substr(0, sem_pos);
return name;
#elif defined(_MSC_VER)
std::string name = __FUNCSIG__;
std::string key = "get_type_name<";
name = name.substr(name.find(key));
assert(!name.empty() && "Unable to find the function name!");
name = name.substr(key.size());
for (std::string prefix : {"class ", "struct ", "union ", "enum "}) {
if (name.find(prefix) == 0) {
name = name.substr(prefix.size());
break;
}
}
auto angle_pos = name.rfind('>');
assert(angle_pos != std::string::npos && "Unable to find the closing '>'!");
return name.substr(0, angle_pos);
#else
// No known technique for statically extracting a type name on this compiler.
// We return a string that is unlikely to look like any type in LLVM.
return "UNKNOWN_TYPE";
#endif
}
} // namespace ir
if(NOT WITH_NEWIR)
return()
endif()
file(GLOB NEW_PASS_SRCS "*.cc")
cc_library(
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/ir/core/cast_utils.h"
#include "paddle/ir/core/type_id.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/utils.h"
#include "paddle/utils/optional.h"
namespace ir {
class Operation;
class AnalysisManager;
class PassInstrumentor;
namespace detail {
/// A utility class to reprensent the analyses that are kwnown to be preserved.
class PreservedAnalyses {
struct AllAnalysesType {};
public:
/// Mark all analyses as preserved.
void PreserveAll() {
preserved_ids_.insert(ir::TypeId::get<AllAnalysesType>());
}
bool IsAll() const {
return preserved_ids_.count(ir::TypeId::get<AllAnalysesType>());
}
bool IsNone() const { return preserved_ids_.empty(); }
template <typename AnalysisT>
void Preserve() {
Preserve(ir::TypeId::get<AnalysisT>());
}
template <typename AnalysisT, typename AnalysisT2, typename... OtherAnalysesT>
void Preserve() {
Preserve<AnalysisT>();
Preserve<AnalysisT2, OtherAnalysesT...>();
}
void Preserve(ir::TypeId id) { preserved_ids_.insert(id); }
template <typename AnalysisT>
bool IsPreserved() const {
return IsPreserved(ir::TypeId::get<AnalysisT>());
}
bool IsPreserved(ir::TypeId id) const { return preserved_ids_.count(id); }
template <typename AnalysisT>
void Unpreserve() {
preserved_ids_.erase(ir::TypeId::get<AnalysisT>());
}
private:
template <typename>
friend struct AnalysisModel;
std::unordered_set<ir::TypeId> preserved_ids_;
};
namespace detail {
/// Trait to check if T provides a static `IsInvalidated` method.
template <typename T, typename... Args>
using has_is_invalidated = decltype(std::declval<T&>().IsInvalidated(
std::declval<const PreservedAnalyses&>()));
/// Implementation of `IsInvalidated` if the analysis provides a definition.
template <typename AnalysisT>
std::enable_if_t<is_detected<has_is_invalidated, AnalysisT>::value, bool>
IsInvalidated(AnalysisT& analysis, const PreservedAnalyses& pa) { // NOLINT
return analysis.IsInvalidated(pa);
}
/// Default implementation of `IsInvalidated`.
template <typename AnalysisT>
std::enable_if_t<!is_detected<has_is_invalidated, AnalysisT>::value, bool>
IsInvalidated(AnalysisT& analysis, const PreservedAnalyses& pa) { // NOLINT
return !pa.IsPreserved<AnalysisT>();
}
} // namespace detail
/// Abstract base class representing an analysis.
struct AnalysisConcept {
virtual ~AnalysisConcept() = default;
// A hook used to query analyses for invalidation.
virtual bool Invalidate(PreservedAnalyses& pa) = 0; // NOLINT
};
template <typename AnalysisT>
struct AnalysisModel : public AnalysisConcept {
template <typename... Args>
explicit AnalysisModel(Args&&... args)
: analysis(std::forward<Args>(args)...) {}
bool Invalidate(PreservedAnalyses& pa) final {
bool result = detail::IsInvalidated(analysis, pa);
if (result) pa.Unpreserve<AnalysisT>();
return result;
}
AnalysisT analysis;
};
/// This class represents a cache of analyses for a single operation.
/// All computation, caching and invalidation of analyses takes place here.
class AnalysisMap {
public:
explicit AnalysisMap(ir::Operation* ir) : ir_(ir) {}
template <typename AnalysisT>
AnalysisT& GetAnalysis(PassInstrumentor* pi, AnalysisManager& am) {
return GetAnalysisImpl<AnalysisT, ir::Operation*>(pi, ir_, am);
}
template <typename AnalysisT, typename OpT>
std::enable_if_t<
std::is_constructible<AnalysisT, OpT>::value ||
std::is_constructible<AnalysisT, OpT, AnalysisManager&>::value,
AnalysisT&>
GetAnalysis(PassInstrumentor* pi, AnalysisManager& am) { // NOLINT
return GetAnalysisImpl<AnalysisT, OpT>(pi, ir::cast<OpT>(ir_), am);
}
template <typename AnalysisT>
paddle::optional<std::reference_wrapper<AnalysisT>> GetCachedAnalysis()
const {
auto res = analyses_.find(ir::TypeId::get<AnalysisT>());
if (res == analyses_.end()) return paddle::none;
return {static_cast<AnalysisModel<AnalysisT>&>(*res->second).analysis};
}
ir::Operation* getOperation() const { return ir_; }
void Clear() { analyses_.clear(); }
/// Invalidate any cached analyses based upon the given set of preserved
void Invalidate(const PreservedAnalyses& pa) {
PreservedAnalyses pa_copy(pa);
// Remove any analyses that were invalidaed.
// As using MapVector, order of insertion is preserved and
// dependencies always go before users, so need only one iteration.
for (auto it = analyses_.begin(); it != analyses_.end();) {
if (it->second->Invalidate(pa_copy))
it = analyses_.erase(it);
else
++it;
}
}
private:
template <typename AnalysisT>
static std::string GetAnalysisName() {
std::string name = ir::get_type_name<AnalysisT>();
auto pos = name.rfind("::");
if (pos != std::string::npos) {
name = name.substr(pos + 2);
}
return name;
}
template <typename AnalysisT, typename OpT>
AnalysisT& GetAnalysisImpl(PassInstrumentor* pi,
OpT op,
AnalysisManager& am) { // NOLINT
ir::TypeId id = ir::TypeId::get<AnalysisT>();
auto it = analyses_.find(id);
if (it == analyses_.end()) {
if (pi) {
pi->RunBeforeAnalysis(GetAnalysisName<AnalysisT>(), id, ir_);
}
bool was_inserted;
std::tie(it, was_inserted) =
analyses_.insert({id, ConstructAnalysis<AnalysisT>(am, op)});
assert(was_inserted);
if (pi) {
pi->RunAfterAnalysis(GetAnalysisName<AnalysisT>(), id, ir_);
}
}
return static_cast<AnalysisModel<AnalysisT>&>(*it->second).analysis;
}
/// Construct analysis using two arguments constructor (OpT,
/// AnalysisManager&).
template <
typename AnalysisT,
typename OpT,
std::enable_if_t<
std::is_constructible<AnalysisT, OpT, AnalysisManager&>::value>* =
nullptr>
static auto ConstructAnalysis(AnalysisManager& am, OpT op) { // NOLINT
return std::make_unique<AnalysisModel<AnalysisT>>(op, am);
}
/// Construct analysis using single argument constructor (OpT)
template <
typename AnalysisT,
typename OpT,
std::enable_if_t<
!std::is_constructible<AnalysisT, OpT, AnalysisManager&>::value>* =
nullptr>
static auto ConstructAnalysis(AnalysisManager&, OpT op) {
return std::make_unique<AnalysisModel<AnalysisT>>(op);
}
private:
ir::Operation* ir_;
std::unordered_map<ir::TypeId, std::unique_ptr<AnalysisConcept>> analyses_;
};
} // namespace detail
/// This class is intended to be passed around by value, and can not be
/// constructed direcyly.
class AnalysisManager {
public:
using PreservedAnalyses = detail::PreservedAnalyses;
template <typename AnalysisT>
AnalysisT& GetAnalysis() {
return analyses_->GetAnalysis<AnalysisT>(GetPassInstrumentor(), *this);
}
template <typename AnalysisT, typename OpT>
AnalysisT& GetAnalysis() {
return analyses_->GetAnalysis<AnalysisT, OpT>(GetPassInstrumentor(), *this);
}
template <typename AnalysisT>
paddle::optional<std::reference_wrapper<AnalysisT>> GetCachedAnalysis()
const {
return analyses_->GetCachedAnalysis<AnalysisT>();
}
void Invalidate(const PreservedAnalyses& pa) {
if (pa.IsAll()) return;
// Invalidate the analyses for the current operation directly.
analyses_->Invalidate(pa);
}
void clear() { analyses_->Clear(); }
PassInstrumentor* GetPassInstrumentor() const { return instrumentor_; }
ir::Operation* GetOperation() { return analyses_->getOperation(); }
private:
AnalysisManager(detail::AnalysisMap* impl, PassInstrumentor* pi)
: analyses_(impl), instrumentor_(pi) {}
private:
detail::AnalysisMap* analyses_;
PassInstrumentor* instrumentor_;
// For access constructor.
friend class AnalysisManagerHolder;
};
/// A manager class for the container operation. This class hold the
/// memory for the analyses. AnalysisManager just hold the ref to the
/// analyses.
class AnalysisManagerHolder {
public:
AnalysisManagerHolder(ir::Operation* op, PassInstrumentor* pi)
: analyses_(op), pi_(pi) {}
AnalysisManagerHolder(const AnalysisManagerHolder&) = delete;
AnalysisManagerHolder& operator=(const AnalysisManagerHolder&) = delete;
/// Returns an analysis manager for the current container op.
operator AnalysisManager() { return AnalysisManager(&analyses_, pi_); }
private:
detail::AnalysisMap analyses_;
PassInstrumentor* pi_;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass_adaptor.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
namespace ir {
//----------------------------------------------------------------------------------------------//
// PassAdaptor
//----------------------------------------------------------------------------------------------//
void detail::PassAdaptor::Run(ir::Operation* op,
uint8_t opt_level,
bool verify) {
RunImpl(op, opt_level, verify);
}
void detail::PassAdaptor::RunImpl(ir::Operation* op,
uint8_t opt_level,
bool verify) {
// TODO(liuyuanle): Support block, region, etc.
return;
}
bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
ir::Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify) {
auto* instrumentor = am.GetPassInstrumentor();
if (instrumentor) {
instrumentor->RunBeforePipeline(op);
}
for (auto& pass : pm.passes()) {
if (pass->CanScheduleOn(op)) {
if (!RunPass(pass.get(), op, am, opt_level, verify)) {
return false;
}
}
}
if (instrumentor) {
instrumentor->RunAfterPipeline(op);
}
// Apply pass manager on all nested ir.
if (!RunPass(pm.pass_adaptor_.get(), op, am, opt_level, verify)) {
return false;
}
return true;
}
bool detail::PassAdaptor::RunPass(Pass* pass,
ir::Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify) {
if (opt_level < pass->pass_info().opt_level) return true;
pass->pass_state_ = PassExecutionState(op, am);
PassInstrumentor* instrumentor = am.GetPassInstrumentor();
if (auto* adaptor = dynamic_cast<PassAdaptor*>(pass)) {
adaptor->Run(op, opt_level, verify);
} else {
if (instrumentor) instrumentor->RunBeforePass(pass, op);
pass->Run(op);
if (instrumentor) instrumentor->RunAfterPass(pass, op);
}
bool pass_failed = pass->pass_state().pass_failed;
// TODO(liuyuanle): Support verification of operation
if (!pass_failed && verify) {
// bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
// pass_failed = ir::verify(op, verify_recursively);
}
return !pass_failed;
}
//----------------------------------------------------------------------------------------------//
// PassManager
//----------------------------------------------------------------------------------------------//
PassManager::PassManager(ir::IrContext* context, uint8_t opt_level)
: context_(context), opt_level_(opt_level) {
pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
}
// bool PassManager::Run(ir::Program* program) const {
// if (!Initialize(context_)) {
// return false;
// }
// return Run(program->operation());
// }
bool PassManager::Run(ir::Operation* op) const {
// Construct a analysis manager for the pipeline.
AnalysisManagerHolder am(op, instrumentor_.get());
return detail::PassAdaptor::RunPipeline(*this, op, am, opt_level_, verify_);
}
bool PassManager::Initialize(ir::IrContext* context) const {
for (auto& pass : passes()) {
if (!pass->Initialize(context)) return false;
}
return true;
}
void PassManager::AddInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
if (!instrumentor_) instrumentor_ = std::make_unique<PassInstrumentor>();
instrumentor_->AddInstrumentation(std::move(pi));
}
//----------------------------------------------------------------------------------------------//
// PassInstrumentor
//----------------------------------------------------------------------------------------------//
namespace detail {
struct PassInstrumentorImpl {
// TODO(wilber): Support multi-thread.
std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
};
} // namespace detail
PassInstrumentor::PassInstrumentor()
: impl_(new detail::PassInstrumentorImpl{}) {}
PassInstrumentor::~PassInstrumentor() = default;
void PassInstrumentor::RunBeforePipeline(ir::Operation* op) {
for (auto& instr : impl_->instrumentations) {
instr->RunBeforePipeline(op);
}
}
void PassInstrumentor::RunAfterPipeline(ir::Operation* op) {
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
(*it)->RunAfterPipeline(op);
}
}
void PassInstrumentor::RunBeforePass(Pass* pass, ir::Operation* op) {
for (auto& instr : impl_->instrumentations) {
instr->RunBeforePass(pass, op);
}
}
void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) {
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
(*it)->RunAfterPass(pass, op);
}
}
void PassInstrumentor::RunBeforeAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {
for (auto& instr : impl_->instrumentations) {
instr->RunBeforeAnalysis(name, id, op);
}
}
void PassInstrumentor::RunAfterAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
(*it)->RunBeforeAnalysis(name, id, op);
}
}
void PassInstrumentor::AddInstrumentation(
std::unique_ptr<PassInstrumentation> pi) {
impl_->instrumentations.emplace_back(std::move(pi));
}
} // namespace ir
......@@ -17,6 +17,8 @@
#include <cstdint>
#include <vector>
#include "paddle/ir/pass/analysis_manager.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/optional.h"
namespace ir {
......@@ -31,12 +33,15 @@ class PassAdaptor;
namespace detail {
struct PassExecutionState {
explicit PassExecutionState(ir::Operation* ir) : ir(ir), pass_failed(false) {}
explicit PassExecutionState(ir::Operation* ir, const AnalysisManager& am)
: ir(ir), pass_failed(false), am(am) {}
// The IR currently being processed by pass.
ir::Operation* ir;
bool pass_failed;
// TODO(liuyuanle): Add implementation of AnalysisManager and
// PreservedAnalyses.
AnalysisManager am;
PreservedAnalyses preserved_analyses;
};
struct PassInfo {
......@@ -51,7 +56,7 @@ struct PassInfo {
// opt_level=0: the basic pass which framework need.
// opt_level=1: the fusion logical pass.
// opt_level=2: constant fold, cse, memory optimize, etc.
// opt_level=3: layout.
// opt_level=3: layout, etc.
uint8_t opt_level;
// The list which pass depends on.
......@@ -67,11 +72,11 @@ class Pass {
explicit Pass(const char* name,
uint8_t opt_level,
const std::vector<const char*>& dependents = {})
: info_(name, opt_level, dependents) {}
: pass_info_(name, opt_level, dependents) {}
virtual ~Pass() = default;
const detail::PassInfo& GetPassInfo() const { return info_; }
const detail::PassInfo& pass_info() const { return pass_info_; }
protected:
virtual void Run(ir::Operation* op) = 0;
......@@ -81,9 +86,19 @@ class Pass {
virtual bool Initialize(ir::IrContext* context) { return true; }
void SignalPassFailure() { pass_state_->pass_failed = true; }
AnalysisManager analysis_manager() { return pass_state().am; }
detail::PassExecutionState& pass_state() {
PADDLE_ENFORCE_EQ(pass_state_.is_initialized(),
true,
phi::errors::Fatal("pass state was never initialized"));
return *pass_state_;
}
void SignalPassFailure() { pass_state().pass_failed = true; }
detail::PassInfo info_;
private:
detail::PassInfo pass_info_;
paddle::optional<detail::PassExecutionState> pass_state_;
......
......@@ -14,7 +14,7 @@
#pragma once
#include "paddle/pass/pass.h"
#include "paddle/ir/pass/pass.h"
namespace ir {
......@@ -30,16 +30,22 @@ class PassAdaptor final : public Pass {
void Run(ir::Operation*) override {}
void Run(ir::Operation*, uint8_t opt_level);
void Run(ir::Operation*, uint8_t opt_level, bool verify);
private:
void RunImpl(ir::Operation* op, uint8_t opt_level);
void RunImpl(ir::Operation* op, uint8_t opt_level, bool verify);
static bool RunPass(Pass* pass, ir::Operation* op, uint8_t opt_level);
static bool RunPass(Pass* pass,
ir::Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify);
static bool RunPipeline(const PassManager& pm,
ir::Operation* op,
uint8_t opt_level);
AnalysisManager am,
uint8_t opt_level,
bool verify);
// Use for RunImpl later.
PassManager* pm_;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/ir/core/type_id.h"
namespace ir {
class Operation;
class Pass;
namespace detail {
struct PassInstrumentorImpl;
} // namespace detail
class PassInstrumentation {
public:
PassInstrumentation() = default;
virtual ~PassInstrumentation() = default;
/// A callback to run before a pass pipeline is executed.
virtual void RunBeforePipeline(ir::Operation* op) {}
virtual void RunAfterPipeline(ir::Operation* op) {}
virtual void RunBeforePass(Pass* pass, ir::Operation* op) {}
virtual void RunAfterPass(Pass* pass, ir::Operation* op) {}
virtual void RunBeforeAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {}
virtual void RunAfterAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {}
};
/// This class holds a collection of PassInstrumentation obejcts, and invokes
/// their respective callbacks.
class PassInstrumentor {
public:
PassInstrumentor();
~PassInstrumentor();
PassInstrumentor(PassInstrumentor&&) = delete;
PassInstrumentor(const PassInstrumentor&) = delete;
void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);
void RunBeforePipeline(ir::Operation* op);
void RunAfterPipeline(ir::Operation* op);
void RunBeforePass(Pass* pass, ir::Operation* op);
void RunAfterPass(Pass* pass, ir::Operation* op);
void RunBeforeAnalysis(const std::string& name,
ir::TypeId id /* */,
ir::Operation* op);
void RunAfterAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op);
// TODO(wilber): Add other hooks.
private:
std::unique_ptr<detail::PassInstrumentorImpl> impl_;
};
} // namespace ir
......@@ -14,14 +14,20 @@
#pragma once
#include <cstdint>
#include <memory>
#include <vector>
#include "paddle/ir/core/program.h"
namespace ir {
class IrContext;
class Operation;
class Program;
class Pass;
class PassInstrumentation;
class PassInstrumentor;
namespace detail {
class PassAdaptor;
......@@ -33,34 +39,37 @@ class PassManager {
~PassManager() = default;
const std::vector<std::unique_ptr<Pass>> &GetPasses() const {
return passes_;
}
const std::vector<std::unique_ptr<Pass>> &passes() const { return passes_; }
bool Empty() const { return passes_.empty(); }
bool empty() const { return passes_.empty(); }
ir::IrContext *GetContext() const { return context_; }
ir::IrContext *context() const { return context_; }
bool Run(ir::Operation *op);
// bool Run(ir::Program *program) const;
bool Run(ir::Operation *op) const;
void AddPass(std::unique_ptr<Pass> pass) {
passes_.emplace_back(std::move(pass));
}
private:
bool RunPasses(ir::Operation *op);
void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);
bool Initialize(ir::IrContext *context);
private:
bool Initialize(ir::IrContext *context) const;
private:
ir::IrContext *context_;
uint8_t opt_level_;
bool verify_{true};
std::vector<std::unique_ptr<Pass>> passes_;
std::unique_ptr<Pass> pass_adaptor_;
std::unique_ptr<PassInstrumentor> instrumentor_;
friend class detail::PassAdaptor;
};
......
// paddle/pass/utils.h
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
namespace ir {
namespace detail {
template <typename... Ts>
struct make_void {
typedef void type;
};
template <typename... Ts>
using void_t = typename make_void<Ts...>::type;
template <class, template <class...> class Op, class... Args>
struct detector {
using value_t = std::false_type;
};
template <template <class...> class Op, class... Args>
struct detector<void_t<Op<Args...>>, Op, Args...> {
using value_t = std::true_type;
};
template <template <class...> class Op, class... Args>
using is_detected = typename detector<void, Op, Args...>::value_t;
} // namespace detail
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pass/pass.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/pass/pass_adaptor.h"
#include "paddle/pass/pass_manager.h"
namespace ir {
void detail::PassAdaptor::Run(ir::Operation* op, uint8_t opt_level) {
RunImpl(op, opt_level);
}
void detail::PassAdaptor::RunImpl(ir::Operation* op, uint8_t opt_level) {
// TODO(liuyuanle): Support block, region, etc.
return;
}
bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
ir::Operation* op,
uint8_t opt_level) {
for (auto& pass : pm.GetPasses()) {
if (pass->CanScheduleOn(op)) {
if (!RunPass(pass.get(), op, opt_level)) {
return false;
}
}
}
// Apply pass manager on all nested ir.
if (!RunPass(pm.pass_adaptor_.get(), op, opt_level)) {
return false;
}
return true;
}
bool detail::PassAdaptor::RunPass(Pass* pass,
ir::Operation* op,
uint8_t opt_level) {
if (opt_level < pass->info_.opt_level) return true;
pass->pass_state_ = detail::PassExecutionState(op);
if (auto* adaptor = dynamic_cast<detail::PassAdaptor*>(pass)) {
adaptor->Run(op, opt_level);
} else {
pass->Run(op);
}
bool pass_failed = pass->pass_state_->pass_failed;
return !pass_failed;
}
PassManager::PassManager(ir::IrContext* context, uint8_t opt_level)
: context_(context), opt_level_(opt_level) {
pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
}
bool PassManager::Run(ir::Operation* op) {
if (!Initialize(context_)) {
return false;
}
return RunPasses(op);
}
bool PassManager::RunPasses(ir::Operation* op) {
return detail::PassAdaptor::RunPipeline(*this, op, opt_level_);
}
bool PassManager::Initialize(ir::IrContext* context) {
for (auto& pass : GetPasses()) {
if (!pass->Initialize(context)) return false;
}
return true;
}
} // namespace ir
......@@ -4,7 +4,6 @@ add_subdirectory(new_executor)
add_subdirectory(prim)
add_subdirectory(imperative)
add_subdirectory(ir)
add_subdirectory(pass)
add_subdirectory(inference)
add_subdirectory(eager)
add_subdirectory(fluid)
......@@ -3,3 +3,4 @@ if(NOT WITH_NEWIR)
endif()
add_subdirectory(core)
add_subdirectory(pass)
......@@ -21,6 +21,7 @@
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/type_base.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/ir/core/utils.h"
TEST(type_test, type_id) {
......@@ -212,3 +213,12 @@ TEST(type_test, custom_type_dialect) {
ir::Dialect *dialect_integer2 = ctx->GetRegisteredDialect<IntegerDialect>();
EXPECT_EQ(dialect_integer1, dialect_integer2);
}
namespace TestNamespace {
class TestClass {};
} // namespace TestNamespace
TEST(type_test, get_type_name) {
auto name = ir::get_type_name<TestNamespace::TestClass>();
EXPECT_EQ(name, "TestNamespace::TestClass");
}
cc_test_old(pass_manager_test SRCS pass_manager_test.cc DEPS new_pass gtest)
......@@ -23,8 +23,8 @@
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/operation.h"
#include "paddle/pass/pass.h"
#include "paddle/pass/pass_manager.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
ir::AttributeMap CreateAttributeMap(ir::IrContext *ctx,
std::string attribute_name,
......@@ -72,7 +72,7 @@ class TestPass : public ir::Pass {
auto test_op = op->dyn_cast<TestOp>();
CHECK_EQ(test_op.operation(), op);
CHECK_EQ(test_op.name(), test_op->op_info().name());
LOG(INFO) << "In " << info_.name << ": " << test_op->op_info().name();
LOG(INFO) << "In " << pass_info().name << ": " << test_op->op_info().name();
}
bool CanScheduleOn(ir::Operation *op) const override {
......@@ -80,7 +80,7 @@ class TestPass : public ir::Pass {
}
};
TEST(pass_manager_test, pass_manager_test) {
TEST(pass_manager_test, pass_manager) {
// (1) Register Dialect, Operation into IrContext.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
......@@ -100,9 +100,13 @@ TEST(pass_manager_test, pass_manager_test) {
op_output_types,
op_info);
CHECK_EQ(op != nullptr, true);
// (4) Test pass manager for op.
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
CHECK_EQ(pm.Run(op), true);
op->destroy();
......
if(WITH_NEWIR)
cc_test_old(pass_manager_test SRCS pass_manager_test.cc DEPS new_pass gtest)
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册