未验证 提交 63b70740 编写于 作者: L Leo Chen 提交者: GitHub

bind python interface for pass manager (#56638)

* bind python interface for pass manager

* add ut

* revert unused change
上级 9dbc8f02
...@@ -37,6 +37,9 @@ ...@@ -37,6 +37,9 @@
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/type.h" #include "paddle/ir/core/type.h"
#include "paddle/ir/core/value.h" #include "paddle/ir/core/value.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/transforms/dead_code_elimination_pass.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
...@@ -45,6 +48,8 @@ using ir::Block; ...@@ -45,6 +48,8 @@ using ir::Block;
using ir::Operation; using ir::Operation;
using ir::OpOperand; using ir::OpOperand;
using ir::OpResult; using ir::OpResult;
using ir::Pass;
using ir::PassManager;
using ir::Program; using ir::Program;
using ir::Type; using ir::Type;
using ir::Value; using ir::Value;
...@@ -465,6 +470,60 @@ void BindUtils(pybind11::module *m) { ...@@ -465,6 +470,60 @@ void BindUtils(pybind11::module *m) {
)DOC"); )DOC");
} }
void BindIrPass(pybind11::module *m) {
py::class_<Pass, std::shared_ptr<Pass>> pass(*m,
"Pass",
R"DOC(
Pass class.
)DOC");
pass.def("name", &Pass::name)
.def("opt_level",
[](const Pass &self) { return self.pass_info().opt_level; })
.def("dependents",
[](const Pass &self) { return self.pass_info().dependents; });
}
// TODO(zhiqiu): refine pass registry
std::unique_ptr<Pass> CreatePassByName(std::string name) {
if (name == "DeadCodeEliminationPass") {
return ir::CreateDeadCodeEliminationPass();
} else {
IR_THROW("The %s pass is not registed", name);
}
}
void BindPassManager(pybind11::module *m) {
py::class_<PassManager, std::shared_ptr<PassManager>> pass_manager(
*m,
"PassManager",
R"DOC(
A class that manages all passes.
)DOC");
pass_manager
.def(
"__init__",
[](PassManager &self, uint8_t opt_level) {
new (&self) PassManager(ir::IrContext::Instance(), opt_level);
},
py::arg("opt_level") = 2)
.def("add_pass",
[](PassManager &self, std::string pass_name) {
self.AddPass(std::move(CreatePassByName(pass_name)));
})
.def("passes",
[](PassManager &self) {
std::vector<std::string> pass_names;
for (const auto &pass : self.passes()) {
pass_names.emplace_back(pass->name());
}
return pass_names;
})
.def("run", [](PassManager &self, Program *p) { self.Run(p); })
.def("empty", &PassManager::Empty);
}
void BindNewIR(pybind11::module *module) { void BindNewIR(pybind11::module *module) {
auto ir_module = module->def_submodule("ir"); auto ir_module = module->def_submodule("ir");
BindProgram(&ir_module); BindProgram(&ir_module);
...@@ -475,6 +534,8 @@ void BindNewIR(pybind11::module *module) { ...@@ -475,6 +534,8 @@ void BindNewIR(pybind11::module *module) {
BindOpResult(&ir_module); BindOpResult(&ir_module);
BindType(&ir_module); BindType(&ir_module);
BindUtils(&ir_module); BindUtils(&ir_module);
BindIrPass(&ir_module);
BindPassManager(&ir_module);
auto ops_modules = ir_module.def_submodule("ops"); auto ops_modules = ir_module.def_submodule("ops");
BindOpsAPI(&ops_modules); BindOpsAPI(&ops_modules);
} }
......
...@@ -28,6 +28,7 @@ from paddle.fluid.libpaddle.ir import ( ...@@ -28,6 +28,7 @@ from paddle.fluid.libpaddle.ir import (
reset_insertion_point_to_start, reset_insertion_point_to_start,
reset_insertion_point_to_end, reset_insertion_point_to_end,
check_unregistered_ops, check_unregistered_ops,
PassManager,
) # noqa: F401 ) # noqa: F401
from . import core from . import core
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle import ir
from paddle.fluid import core
from paddle.framework import LayerHelper
paddle.enable_static()
class TestShadowOutputSlice(unittest.TestCase):
def test_op(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x = paddle.ones([3, 9, 5], dtype='float32')
y = paddle.static.data(
name="y", shape=[3, 9, 5], dtype="float32"
)
paddle.rand([10]) # will be eliminated
_, out, _ = paddle.split(x, num_or_sections=3, axis=1)
helper = LayerHelper('shadow_output')
helper.append_op(
type="shadow_output",
inputs={"x": [out.name]},
outputs={"out": [y.name]},
attrs={"name": out.name},
)
new_program = ir.translate_to_new_ir(main_program.desc)
op_names = [op.name() for op in new_program.block().ops]
# print(op_names)
self.assertTrue('pd.uniform' in op_names)
pm = ir.PassManager()
pm.add_pass(
'DeadCodeEliminationPass'
) # apply pass to elimitate dead code
pm.run(new_program)
op_names = [op.name() for op in new_program.block().ops]
# print(op_names)
self.assertEqual(pm.passes(), ['DeadCodeEliminationPass'])
self.assertFalse(pm.empty())
self.assertTrue(
'pd.uniform' not in op_names
) # uniform is elimited because its output is not used
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册