未验证 提交 9999e849 编写于 作者: L Leo Chen 提交者: GitHub

[New-IR] add pass registry (#56729)

* add pass registry

* add pass registry macro
上级 fc1e1b77
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include "paddle/ir/core/value.h" #include "paddle/ir/core/value.h"
#include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h" #include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/pass_registry.h"
#include "paddle/ir/transforms/dead_code_elimination_pass.h" #include "paddle/ir/transforms/dead_code_elimination_pass.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
...@@ -57,6 +58,8 @@ using paddle::dialect::APIBuilder; ...@@ -57,6 +58,8 @@ using paddle::dialect::APIBuilder;
using paddle::dialect::DenseTensorType; using paddle::dialect::DenseTensorType;
using pybind11::return_value_policy; using pybind11::return_value_policy;
USE_PASS(dead_code_elimination);
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -488,15 +491,6 @@ void BindIrPass(pybind11::module *m) { ...@@ -488,15 +491,6 @@ void BindIrPass(pybind11::module *m) {
[](const Pass &self) { return self.pass_info().dependents; }); [](const Pass &self) { return self.pass_info().dependents; });
} }
// TODO(zhiqiu): refine pass registry
std::unique_ptr<Pass> CreatePassByName(std::string name) {
if (name == "DeadCodeEliminationPass") {
return ir::CreateDeadCodeEliminationPass();
} else {
IR_THROW("The %s pass is not registed", name);
}
}
void BindPassManager(pybind11::module *m) { void BindPassManager(pybind11::module *m) {
py::class_<PassManager, std::shared_ptr<PassManager>> pass_manager( py::class_<PassManager, std::shared_ptr<PassManager>> pass_manager(
*m, *m,
...@@ -514,7 +508,8 @@ void BindPassManager(pybind11::module *m) { ...@@ -514,7 +508,8 @@ void BindPassManager(pybind11::module *m) {
py::arg("opt_level") = 2) py::arg("opt_level") = 2)
.def("add_pass", .def("add_pass",
[](PassManager &self, std::string pass_name) { [](PassManager &self, std::string pass_name) {
self.AddPass(std::move(CreatePassByName(pass_name))); self.AddPass(
std::move(ir::PassRegistry::Instance().Get(pass_name)));
}) })
.def("passes", .def("passes",
[](PassManager &self) { [](PassManager &self) {
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/enforce.h"
#include "paddle/ir/pass/analysis_manager.h" #include "paddle/ir/pass/analysis_manager.h"
#include "paddle/ir/pass/pass_registry.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
namespace ir { namespace ir {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pass/pass_registry.h"
namespace ir {
PassRegistry &PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <unordered_map>
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/macros.h"
#include "paddle/ir/pass/pass.h"
namespace ir {
class Pass;
using PassCreator = std::function<std::unique_ptr<Pass>()>;
class PassRegistry {
public:
static PassRegistry &Instance();
bool Has(const std::string &pass_type) const {
return pass_map_.find(pass_type) != pass_map_.end();
}
void Insert(const std::string &pass_type, const PassCreator &pass_creator) {
IR_ENFORCE(
Has(pass_type) != true, "Pass %s has been registered.", pass_type);
pass_map_.insert({pass_type, pass_creator});
}
std::unique_ptr<Pass> Get(const std::string &pass_type) const {
IR_ENFORCE(
Has(pass_type) == true, "Pass %s has not been registered.", pass_type);
return pass_map_.at(pass_type)();
}
private:
PassRegistry() = default;
std::unordered_map<std::string, PassCreator> pass_map_;
DISABLE_COPY_AND_ASSIGN(PassRegistry);
};
template <typename PassType>
class PassRegistrar {
public:
// In our design, various kinds of passes,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which
// are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_PASS macros to
// call this method. So, as long as the callee code calls USE_PASS, the global
// registrar variable won't be removed by the linker.
void Touch() {}
explicit PassRegistrar(const char *pass_type) {
PassRegistry::Instance().Insert(
pass_type, []() { return std::make_unique<PassType>(); });
}
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::ir::PassRegistrar<pass_class> __pass_registrar_##pass_type##__( \
#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::ir::PassRegistrar<pass_class> &__pass_tmp_registrar_##pass_type##__ \
UNUSED = __pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ UNUSED = \
TouchPassRegistrar_##pass_type()
} // namespace ir
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_registry.h"
namespace { namespace {
...@@ -75,3 +76,5 @@ std::unique_ptr<Pass> CreateDeadCodeEliminationPass() { ...@@ -75,3 +76,5 @@ std::unique_ptr<Pass> CreateDeadCodeEliminationPass() {
} }
} // namespace ir } // namespace ir
REGISTER_PASS(dead_code_elimination, DeadCodeEliminationPass);
...@@ -51,11 +51,12 @@ class TestShadowOutputSlice(unittest.TestCase): ...@@ -51,11 +51,12 @@ class TestShadowOutputSlice(unittest.TestCase):
self.assertTrue('pd.uniform' in op_names) self.assertTrue('pd.uniform' in op_names)
pm = ir.PassManager() pm = ir.PassManager()
pm.add_pass( pm.add_pass(
'DeadCodeEliminationPass' 'dead_code_elimination'
) # apply pass to elimitate dead code ) # apply pass to elimitate dead code
pm.run(new_program) pm.run(new_program)
op_names = [op.name() for op in new_program.block().ops] op_names = [op.name() for op in new_program.block().ops]
# print(op_names) # print(op_names)
# TODO(zhiqiu): unify the name of pass
self.assertEqual(pm.passes(), ['DeadCodeEliminationPass']) self.assertEqual(pm.passes(), ['DeadCodeEliminationPass'])
self.assertFalse(pm.empty()) self.assertFalse(pm.empty())
self.assertTrue( self.assertTrue(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册