From c8874f23c436bd4af17706e1fc26ba6c2e02d9f1 Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Thu, 1 Sep 2022 14:17:18 +0800 Subject: [PATCH] [CINN] add fetch and prune for build cinn pass (#45531) * add fetch and prune for build cinn pass * add prune flag --- python/paddle/distributed/passes/cpp_pass.py | 43 +++++++++++++++++--- python/paddle/fluid/executor.py | 23 +++++++++-- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/passes/cpp_pass.py b/python/paddle/distributed/passes/cpp_pass.py index 1d99a93624..c729a919c1 100644 --- a/python/paddle/distributed/passes/cpp_pass.py +++ b/python/paddle/distributed/passes/cpp_pass.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from paddle.static import Executor from .pass_base import PassType, CPPPassWrapper, register_pass from paddle.fluid.framework import core, _apply_pass as _apply_cpp_pass @@ -102,6 +103,13 @@ class InplaceAddtoOpPass(CPPPassWrapper): return PassType.CALC_OPT +def _set_cinn_op_flag(flag_name, extra_ops): + values = core.globals()[flag_name] + values = [v.strip() for v in values.split(";") if v.strip()] + values.extend(extra_ops) + core.globals()[flag_name] = ";".join(values) + + @register_pass("build_cinn") class BuildCINNPass(CPPPassWrapper): @@ -118,18 +126,41 @@ class BuildCINNPass(CPPPassWrapper): return PassType.CALC_OPT def _apply_single_impl(self, main_program, startup_program, context): - allow_ops = ";".join(self.get_attr("allow_ops")) - deny_ops = ";".join(self.get_attr("deny_ops")) assert 'FLAGS_allow_cinn_ops' in core.globals( ), "PaddlePaddle is not compiled with CINN support" old_allow_ops = core.globals()['FLAGS_allow_cinn_ops'] old_deny_ops = core.globals()['FLAGS_deny_cinn_ops'] try: - core.globals()['FLAGS_allow_cinn_ops'] = allow_ops - core.globals()['FLAGS_deny_cinn_ops'] = deny_ops - _apply_cpp_pass(main_program, startup_program, self.cpp_name, {}, - self.cpp_attr_types) + _set_cinn_op_flag('FLAGS_allow_cinn_ops', + self.get_attr("allow_ops")) + _set_cinn_op_flag('FLAGS_deny_cinn_ops', self.get_attr("deny_ops")) + + feed = self.get_attr('feed', []) + fetch_list = self.get_attr('fetch_list', []) + prune_program = self.get_attr('prune_program', True) + + if prune_program: + tmp_main_program = Executor._prune_program( + main_program, feed, fetch_list, []) + + tmp_main_program = Executor._add_fetch_ops( + tmp_main_program, fetch_list, 'fetch') + + else: + + tmp_main_program = Executor._add_fetch_ops( + main_program, fetch_list, 'fetch') + + _apply_cpp_pass(tmp_main_program, startup_program, self.cpp_name, + {}, self.cpp_attr_types) + + tmp_main_program = Executor._remove_fetch_ops(tmp_main_program) + + tmp_main_program = core.ProgramDesc(tmp_main_program.desc) + + main_program._rebuild_from_desc(tmp_main_program) + finally: core.globals()['FLAGS_allow_cinn_ops'] = old_allow_ops core.globals()['FLAGS_deny_cinn_ops'] = old_deny_ops diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 2e608fdcd7..5b92df7838 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -978,7 +978,8 @@ class Executor(object): ] return outs - def _split_optimize_ops_in_fetch_list(self, fetch_list): + @classmethod + def _split_optimize_ops_in_fetch_list(cls, fetch_list): """ Split optimize_ops from fetch_list, which provided to specify program prunning. Args: @@ -1030,7 +1031,8 @@ class Executor(object): return _fetch_list, _optimize_ops - def _prune_program(self, + @classmethod + def _prune_program(cls, program, feed=None, fetch_list=None, @@ -1093,7 +1095,8 @@ class Executor(object): return program - def _update_feed(self, program, feed): + @classmethod + def _update_feed(cls, program, feed): """ Update the feed dict, remove the feed item which is pruned in program. @@ -2379,7 +2382,8 @@ class Executor(object): return tmp_program - def _add_fetch_ops(self, + @classmethod + def _add_fetch_ops(cls, program, fetch_list, fetch_var_name, @@ -2416,6 +2420,17 @@ class Executor(object): return tmp_program + @classmethod + def _remove_fetch_ops(cls, program, fetch_op_name='fetch'): + tmp_program = program.clone() + global_block = tmp_program.global_block() + op_num = len(global_block.ops) + for idx in reversed(range(op_num)): + if global_block.ops[idx].type == fetch_op_name: + global_block._remove_op(idx) + + return tmp_program + def _run_pipeline(self, program=None, dataset=None, -- GitLab