From 9999e8497a485ea76df8fd198afd9c8db08c926f Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 29 Aug 2023 19:14:55 +0800 Subject: [PATCH] [New-IR] add pass registry (#56729) * add pass registry * add pass registry macro --- paddle/fluid/pybind/ir.cc | 15 +-- paddle/ir/pass/pass.h | 1 + paddle/ir/pass/pass_registry.cc | 23 ++++ paddle/ir/pass/pass_registry.h | 104 ++++++++++++++++++ .../transforms/dead_code_elimination_pass.cc | 3 + test/ir/new_ir/test_pass_manager.py | 3 +- 6 files changed, 138 insertions(+), 11 deletions(-) create mode 100644 paddle/ir/pass/pass_registry.cc create mode 100644 paddle/ir/pass/pass_registry.h diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 6c6957c3e00..675e6f2acd2 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -39,6 +39,7 @@ #include "paddle/ir/core/value.h" #include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass_manager.h" +#include "paddle/ir/pass/pass_registry.h" #include "paddle/ir/transforms/dead_code_elimination_pass.h" #include "paddle/phi/core/enforce.h" #include "pybind11/stl.h" @@ -57,6 +58,8 @@ using paddle::dialect::APIBuilder; using paddle::dialect::DenseTensorType; using pybind11::return_value_policy; +USE_PASS(dead_code_elimination); + namespace paddle { namespace pybind { @@ -488,15 +491,6 @@ void BindIrPass(pybind11::module *m) { [](const Pass &self) { return self.pass_info().dependents; }); } -// TODO(zhiqiu): refine pass registry -std::unique_ptr CreatePassByName(std::string name) { - if (name == "DeadCodeEliminationPass") { - return ir::CreateDeadCodeEliminationPass(); - } else { - IR_THROW("The %s pass is not registed", name); - } -} - void BindPassManager(pybind11::module *m) { py::class_> pass_manager( *m, @@ -514,7 +508,8 @@ void BindPassManager(pybind11::module *m) { py::arg("opt_level") = 2) .def("add_pass", [](PassManager &self, std::string pass_name) { - self.AddPass(std::move(CreatePassByName(pass_name))); + self.AddPass( + std::move(ir::PassRegistry::Instance().Get(pass_name))); }) .def("passes", [](PassManager &self) { diff --git a/paddle/ir/pass/pass.h b/paddle/ir/pass/pass.h index 484651e87e2..4a4cbf629d6 100644 --- a/paddle/ir/pass/pass.h +++ b/paddle/ir/pass/pass.h @@ -20,6 +20,7 @@ #include "paddle/ir/core/enforce.h" #include "paddle/ir/pass/analysis_manager.h" +#include "paddle/ir/pass/pass_registry.h" #include "paddle/phi/core/enforce.h" namespace ir { diff --git a/paddle/ir/pass/pass_registry.cc b/paddle/ir/pass/pass_registry.cc new file mode 100644 index 00000000000..a0239219a69 --- /dev/null +++ b/paddle/ir/pass/pass_registry.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/pass/pass_registry.h" + +namespace ir { +PassRegistry &PassRegistry::Instance() { + static PassRegistry g_pass_info_map; + return g_pass_info_map; +} + +} // namespace ir diff --git a/paddle/ir/pass/pass_registry.h b/paddle/ir/pass/pass_registry.h new file mode 100644 index 00000000000..c35dc0ba90a --- /dev/null +++ b/paddle/ir/pass/pass_registry.h @@ -0,0 +1,104 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/ir/core/enforce.h" +#include "paddle/ir/core/macros.h" +#include "paddle/ir/pass/pass.h" + +namespace ir { + +class Pass; + +using PassCreator = std::function()>; + +class PassRegistry { + public: + static PassRegistry &Instance(); + + bool Has(const std::string &pass_type) const { + return pass_map_.find(pass_type) != pass_map_.end(); + } + + void Insert(const std::string &pass_type, const PassCreator &pass_creator) { + IR_ENFORCE( + Has(pass_type) != true, "Pass %s has been registered.", pass_type); + pass_map_.insert({pass_type, pass_creator}); + } + + std::unique_ptr Get(const std::string &pass_type) const { + IR_ENFORCE( + Has(pass_type) == true, "Pass %s has not been registered.", pass_type); + return pass_map_.at(pass_type)(); + } + + private: + PassRegistry() = default; + std::unordered_map pass_map_; + + DISABLE_COPY_AND_ASSIGN(PassRegistry); +}; + +template +class PassRegistrar { + public: + // In our design, various kinds of passes, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which + // are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_PASS macros to + // call this method. So, as long as the callee code calls USE_PASS, the global + // registrar variable won't be removed by the linker. + void Touch() {} + explicit PassRegistrar(const char *pass_type) { + PassRegistry::Instance().Insert( + pass_type, []() { return std::make_unique(); }); + } +}; + +#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +// Register a new pass that can be applied on the IR. +#define REGISTER_PASS(pass_type, pass_class) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __reg_pass__##pass_type, \ + "REGISTER_PASS must be called in global namespace"); \ + static ::ir::PassRegistrar __pass_registrar_##pass_type##__( \ + #pass_type); \ + int TouchPassRegistrar_##pass_type() { \ + __pass_registrar_##pass_type##__.Touch(); \ + return 0; \ + } \ + static ::ir::PassRegistrar &__pass_tmp_registrar_##pass_type##__ \ + UNUSED = __pass_registrar_##pass_type##__ + +#define USE_PASS(pass_type) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __use_pass_itself_##pass_type, \ + "USE_PASS must be called in global namespace"); \ + extern int TouchPassRegistrar_##pass_type(); \ + static int use_pass_itself_##pass_type##_ UNUSED = \ + TouchPassRegistrar_##pass_type() + +} // namespace ir diff --git a/paddle/ir/transforms/dead_code_elimination_pass.cc b/paddle/ir/transforms/dead_code_elimination_pass.cc index f58a4485fc7..d56b83b8446 100644 --- a/paddle/ir/transforms/dead_code_elimination_pass.cc +++ b/paddle/ir/transforms/dead_code_elimination_pass.cc @@ -17,6 +17,7 @@ #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/program.h" #include "paddle/ir/pass/pass.h" +#include "paddle/ir/pass/pass_registry.h" namespace { @@ -75,3 +76,5 @@ std::unique_ptr CreateDeadCodeEliminationPass() { } } // namespace ir + +REGISTER_PASS(dead_code_elimination, DeadCodeEliminationPass); diff --git a/test/ir/new_ir/test_pass_manager.py b/test/ir/new_ir/test_pass_manager.py index 580dea77677..2f31e945f31 100644 --- a/test/ir/new_ir/test_pass_manager.py +++ b/test/ir/new_ir/test_pass_manager.py @@ -51,11 +51,12 @@ class TestShadowOutputSlice(unittest.TestCase): self.assertTrue('pd.uniform' in op_names) pm = ir.PassManager() pm.add_pass( - 'DeadCodeEliminationPass' + 'dead_code_elimination' ) # apply pass to elimitate dead code pm.run(new_program) op_names = [op.name() for op in new_program.block().ops] # print(op_names) + # TODO(zhiqiu): unify the name of pass self.assertEqual(pm.passes(), ['DeadCodeEliminationPass']) self.assertFalse(pm.empty()) self.assertTrue( -- GitLab