未验证 提交 c8874f23 编写于 作者: K kuizhiqing 提交者: GitHub

[CINN] add fetch and prune for build cinn pass (#45531)

* add fetch and prune for build cinn pass

* add prune flag
上级 13d62e12
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.static import Executor
from .pass_base import PassType, CPPPassWrapper, register_pass
from paddle.fluid.framework import core, _apply_pass as _apply_cpp_pass
......@@ -102,6 +103,13 @@ class InplaceAddtoOpPass(CPPPassWrapper):
return PassType.CALC_OPT
def _set_cinn_op_flag(flag_name, extra_ops):
    """Append ``extra_ops`` to the ';'-separated list held in a global flag.

    The current value of ``core.globals()[flag_name]`` is split on ';',
    each piece is stripped and empty pieces are dropped, the items of
    ``extra_ops`` are appended unchanged, and the ';'-joined result is
    written back to the same flag.

    Args:
        flag_name: key of the flag inside ``core.globals()``.
        extra_ops: iterable of op names to append to the flag's list.
    """
    merged = []
    for piece in core.globals()[flag_name].split(";"):
        op_name = piece.strip()
        # Skip blanks produced by empty flags or stray separators.
        if op_name:
            merged.append(op_name)
    merged += extra_ops
    core.globals()[flag_name] = ";".join(merged)
@register_pass("build_cinn")
class BuildCINNPass(CPPPassWrapper):
......@@ -118,18 +126,41 @@ class BuildCINNPass(CPPPassWrapper):
return PassType.CALC_OPT
def _apply_single_impl(self, main_program, startup_program, context):
# Applies the C++ pass named by self.cpp_name with the two CINN op-list
# flags temporarily changed, then restores the flags.
#   main_program:    program the pass is applied to; rebuilt in place at the
#                    end via _rebuild_from_desc.
#   startup_program: forwarded unchanged to _apply_cpp_pass.
#   context:         not referenced anywhere in this body.
#
# NOTE(review): this span looks like a rendered diff with indentation
# stripped and removed/added lines interleaved. The `allow_ops`/`deny_ops`
# joins, the direct flag assignments and the first `_apply_cpp_pass` call
# appear to be the pre-change version of the logic that follows them --
# confirm against the real file before relying on the statement order.
allow_ops = ";".join(self.get_attr("allow_ops"))
deny_ops = ";".join(self.get_attr("deny_ops"))
# Fail early when the CINN flag is not registered in core.globals().
assert 'FLAGS_allow_cinn_ops' in core.globals(
), "PaddlePaddle is not compiled with CINN support"
# Remember the current flag values; the finally clause writes them back.
old_allow_ops = core.globals()['FLAGS_allow_cinn_ops']
old_deny_ops = core.globals()['FLAGS_deny_cinn_ops']
try:
# Overwrite both flags with the pass attributes and run the pass directly
# on main_program (presumably the removed side of the diff -- see NOTE).
core.globals()['FLAGS_allow_cinn_ops'] = allow_ops
core.globals()['FLAGS_deny_cinn_ops'] = deny_ops
_apply_cpp_pass(main_program, startup_program, self.cpp_name, {},
self.cpp_attr_types)
# Append the pass attributes to whatever the flags currently hold.
_set_cinn_op_flag('FLAGS_allow_cinn_ops',
self.get_attr("allow_ops"))
_set_cinn_op_flag('FLAGS_deny_cinn_ops', self.get_attr("deny_ops"))
# Optional attributes; the second argument is presumably the fallback used
# when the attribute is not set -- TODO confirm against get_attr.
feed = self.get_attr('feed', [])
fetch_list = self.get_attr('fetch_list', [])
prune_program = self.get_attr('prune_program', True)
if prune_program:
# Prune with feed/fetch_list first, then add fetch ops to the pruned result.
tmp_main_program = Executor._prune_program(
main_program, feed, fetch_list, [])
tmp_main_program = Executor._add_fetch_ops(
tmp_main_program, fetch_list, 'fetch')
else:
# No pruning: only add fetch ops, starting from main_program.
tmp_main_program = Executor._add_fetch_ops(
main_program, fetch_list, 'fetch')
# Run the C++ pass on the temporary program rather than on main_program.
_apply_cpp_pass(tmp_main_program, startup_program, self.cpp_name,
{}, self.cpp_attr_types)
# Drop the fetch ops again, wrap the resulting desc in a core.ProgramDesc
# and rebuild main_program from it.
tmp_main_program = Executor._remove_fetch_ops(tmp_main_program)
tmp_main_program = core.ProgramDesc(tmp_main_program.desc)
main_program._rebuild_from_desc(tmp_main_program)
finally:
# Always put back the flag values that were saved before the try block.
core.globals()['FLAGS_allow_cinn_ops'] = old_allow_ops
core.globals()['FLAGS_deny_cinn_ops'] = old_deny_ops
......@@ -978,7 +978,8 @@ class Executor(object):
]
return outs
def _split_optimize_ops_in_fetch_list(self, fetch_list):
@classmethod
def _split_optimize_ops_in_fetch_list(cls, fetch_list):
"""
Split optimize_ops from fetch_list, which is provided to specify program pruning.
Args:
......@@ -1030,7 +1031,8 @@ class Executor(object):
return _fetch_list, _optimize_ops
def _prune_program(self,
@classmethod
def _prune_program(cls,
program,
feed=None,
fetch_list=None,
......@@ -1093,7 +1095,8 @@ class Executor(object):
return program
def _update_feed(self, program, feed):
@classmethod
def _update_feed(cls, program, feed):
"""
Update the feed dict, remove the feed item which is pruned in program.
......@@ -2379,7 +2382,8 @@ class Executor(object):
return tmp_program
def _add_fetch_ops(self,
@classmethod
def _add_fetch_ops(cls,
program,
fetch_list,
fetch_var_name,
......@@ -2416,6 +2420,17 @@ class Executor(object):
return tmp_program
@classmethod
def _remove_fetch_ops(cls, program, fetch_op_name='fetch'):
    """Return a clone of ``program`` without ops of type ``fetch_op_name``.

    The work is done on ``program.clone()``; only ops in the clone's
    global block are examined.

    Args:
        program: object providing ``clone()``; the clone must provide
            ``global_block()``.
        fetch_op_name: op ``type`` value to remove (default ``'fetch'``).

    Returns:
        The cloned program after the matching ops were removed.
    """
    pruned = program.clone()
    block = pruned.global_block()
    # Walk indices from last to first so that removing an op never shifts
    # the position of an op that still has to be examined.
    for pos in range(len(block.ops) - 1, -1, -1):
        if block.ops[pos].type == fetch_op_name:
            block._remove_op(pos)
    return pruned
def _run_pipeline(self,
program=None,
dataset=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册