From 63b707409f120c51af1ca815a45dd849ad504de5 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 25 Aug 2023 10:34:29 +0800 Subject: [PATCH] bind python interface for pass manager (#56638) * bind python interface for pass manager * add ut * revert unused change --- paddle/fluid/pybind/ir.cc | 61 ++++++++++++++++++++++++++ python/paddle/ir/__init__.py | 1 + test/ir/new_ir/test_pass_manager.py | 67 +++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+) create mode 100644 test/ir/new_ir/test_pass_manager.py diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 98f8b35c3e8..1f42052ca7a 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -37,6 +37,9 @@ #include "paddle/ir/core/program.h" #include "paddle/ir/core/type.h" #include "paddle/ir/core/value.h" +#include "paddle/ir/pass/pass.h" +#include "paddle/ir/pass/pass_manager.h" +#include "paddle/ir/transforms/dead_code_elimination_pass.h" #include "paddle/phi/core/enforce.h" #include "pybind11/stl.h" @@ -45,6 +48,8 @@ using ir::Block; using ir::Operation; using ir::OpOperand; using ir::OpResult; +using ir::Pass; +using ir::PassManager; using ir::Program; using ir::Type; using ir::Value; @@ -465,6 +470,60 @@ void BindUtils(pybind11::module *m) { )DOC"); } +void BindIrPass(pybind11::module *m) { + py::class_> pass(*m, + "Pass", + R"DOC( + Pass class. + + )DOC"); + pass.def("name", &Pass::name) + .def("opt_level", + [](const Pass &self) { return self.pass_info().opt_level; }) + .def("dependents", + [](const Pass &self) { return self.pass_info().dependents; }); +} + +// TODO(zhiqiu): refine pass registry +std::unique_ptr CreatePassByName(std::string name) { + if (name == "DeadCodeEliminationPass") { + return ir::CreateDeadCodeEliminationPass(); + } else { + IR_THROW("The %s pass is not registed", name); + } +} + +void BindPassManager(pybind11::module *m) { + py::class_> pass_manager( + *m, + "PassManager", + R"DOC( + A class that manages all passes. + + )DOC"); + pass_manager + .def( + "__init__", + [](PassManager &self, uint8_t opt_level) { + new (&self) PassManager(ir::IrContext::Instance(), opt_level); + }, + py::arg("opt_level") = 2) + .def("add_pass", + [](PassManager &self, std::string pass_name) { + self.AddPass(std::move(CreatePassByName(pass_name))); + }) + .def("passes", + [](PassManager &self) { + std::vector pass_names; + for (const auto &pass : self.passes()) { + pass_names.emplace_back(pass->name()); + } + return pass_names; + }) + .def("run", [](PassManager &self, Program *p) { self.Run(p); }) + .def("empty", &PassManager::Empty); +} + void BindNewIR(pybind11::module *module) { auto ir_module = module->def_submodule("ir"); BindProgram(&ir_module); @@ -475,6 +534,8 @@ void BindNewIR(pybind11::module *module) { BindOpResult(&ir_module); BindType(&ir_module); BindUtils(&ir_module); + BindIrPass(&ir_module); + BindPassManager(&ir_module); auto ops_modules = ir_module.def_submodule("ops"); BindOpsAPI(&ops_modules); } diff --git a/python/paddle/ir/__init__.py b/python/paddle/ir/__init__.py index 17f61a883fd..be8ddeba229 100644 --- a/python/paddle/ir/__init__.py +++ b/python/paddle/ir/__init__.py @@ -28,6 +28,7 @@ from paddle.fluid.libpaddle.ir import ( reset_insertion_point_to_start, reset_insertion_point_to_end, check_unregistered_ops, + PassManager, ) # noqa: F401 from . import core diff --git a/test/ir/new_ir/test_pass_manager.py b/test/ir/new_ir/test_pass_manager.py new file mode 100644 index 00000000000..580dea77677 --- /dev/null +++ b/test/ir/new_ir/test_pass_manager.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +from paddle import ir +from paddle.fluid import core +from paddle.framework import LayerHelper + +paddle.enable_static() + + +class TestShadowOutputSlice(unittest.TestCase): + def test_op(self): + place = core.Place() + place.set_place(paddle.CPUPlace()) + new_scope = paddle.static.Scope() + main_program = paddle.static.Program() + with paddle.static.scope_guard(new_scope): + with paddle.static.program_guard(main_program): + x = paddle.ones([3, 9, 5], dtype='float32') + y = paddle.static.data( + name="y", shape=[3, 9, 5], dtype="float32" + ) + paddle.rand([10]) # will be eliminated + + _, out, _ = paddle.split(x, num_or_sections=3, axis=1) + helper = LayerHelper('shadow_output') + helper.append_op( + type="shadow_output", + inputs={"x": [out.name]}, + outputs={"out": [y.name]}, + attrs={"name": out.name}, + ) + + new_program = ir.translate_to_new_ir(main_program.desc) + op_names = [op.name() for op in new_program.block().ops] + # print(op_names) + self.assertTrue('pd.uniform' in op_names) + pm = ir.PassManager() + pm.add_pass( + 'DeadCodeEliminationPass' + ) # apply pass to elimitate dead code + pm.run(new_program) + op_names = [op.name() for op in new_program.block().ops] + # print(op_names) + self.assertEqual(pm.passes(), ['DeadCodeEliminationPass']) + self.assertFalse(pm.empty()) + self.assertTrue( + 'pd.uniform' not in op_names + ) # uniform is elimited because its output is not used + + +if __name__ == "__main__": + unittest.main() -- GitLab