未验证 提交 a8541c15 编写于 作者: R Ruibiao Chen 提交者: GitHub

Refactor ExecutorCache (#45532)

* Refactor ExecutorCache

* Update code

* Fix mkldnn UT errors

* Fix typos

* Fix CI errors
上级 9c1aa6c7
...@@ -37,6 +37,8 @@ from . import framework ...@@ -37,6 +37,8 @@ from . import framework
from .incubate.checkpoint import auto_checkpoint as acp from .incubate.checkpoint import auto_checkpoint as acp
from .compiler import _prune_feed_ops from .compiler import _prune_feed_ops
from functools import lru_cache
__all__ = ['Executor', 'global_scope', 'scope_guard'] __all__ = ['Executor', 'global_scope', 'scope_guard']
g_scope = core.Scope() g_scope = core.Scope()
...@@ -335,6 +337,84 @@ def has_fetch_operators(block, ...@@ -335,6 +337,84 @@ def has_fetch_operators(block,
return fetch_count > 0 return fetch_count > 0
def _add_feed_fetch_ops(program,
feed,
fetch_list,
feed_var_name,
fetch_var_name,
use_fetch_v2=False):
tmp_program = program.clone()
global_block = tmp_program.global_block()
if feed_var_name in global_block.vars:
feed_var = global_block.var(feed_var_name)
else:
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)
if fetch_var_name in global_block.vars:
fetch_var = global_block.var(fetch_var_name)
else:
fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed):
if global_block.has_var(name):
out = global_block.var(name)
global_block._prepend_op(type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
else:
warnings.warn(
"The variable %s is not found in program. It is not declared or is pruned."
% name)
if use_fetch_v2:
fetch_op = 'fetch_v2'
else:
fetch_op = 'fetch'
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name,
fetch_op):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(
var, six.string_types), ("Wrong type for fetch_list[%s]: %s" %
(i, type(var)))
global_block.append_op(type=fetch_op,
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
return tmp_program
def _apply_inplace_addto_pass(program, enable_inplace, enable_addto,
skip_var_names):
use_cuda = True if core.is_compiled_with_cuda() else False
attrs = {"use_cuda": use_cuda, "mem_opt_skip_vars": skip_var_names}
attr_types = {"use_cuda": "bool", "mem_opt_skip_vars": "list[str]"}
empty_startup_program = Program()
if enable_inplace:
pass_name = "buffer_shared_inplace_pass"
_apply_pass(program, empty_startup_program, pass_name, attrs,
attr_types)
if enable_addto and use_cuda:
pass_name = "inplace_addto_op_pass"
_apply_pass(program, empty_startup_program, pass_name, attrs,
attr_types)
def _fetch_var(name, scope=None, return_numpy=True): def _fetch_var(name, scope=None, return_numpy=True):
""" """
Fetch the value of the variable with the given name from the Fetch the value of the variable with the given name from the
...@@ -613,10 +693,114 @@ class _StandaloneExecutor(object): ...@@ -613,10 +693,114 @@ class _StandaloneExecutor(object):
class _ExecutorCache(object): class _ExecutorCache(object):
def __init__(self, place): class _CachedData(object):
# {Program : _StandaloneExecutor}
self._place = place def __init__(self, program, feed, fetch_list, feed_var_name,
self._cached_executors = {} fetch_var_name, place, scope):
self.program = program
self.feed = feed
self.fetch_list = fetch_list
self.feed_var_name = feed_var_name
self.fetch_var_name = fetch_var_name
self.place = place
self.scope = scope
# NOTE(Ruibiao): Not all changeable item is considered for key at present,
# ONLY: program, feed, and fetch_list
if isinstance(self.program, compiler.CompiledProgram):
self.key = hash(
_get_strong_program_cache_key_for_new_exe(
self.program._program, feed, fetch_list))
else:
self.key = hash(
_get_strong_program_cache_key_for_new_exe(
self.program, feed, fetch_list))
def __eq__(self, other):
return isinstance(
other, _ExecutorCache._CachedData) and self.key == other.key
def __hash__(self):
return self.key
def __init__(self):
# NOTE(Ruibiao): Wrap the lru_cache in constructor so that the cache is local to
# the _ExecutorCache instance, otherwise a global cache may not be released after
# the Executor instance deleted
self._get_cached_program_and_executor = lru_cache(maxsize=8)(
self._get_program_and_executor)
def clear(self):
self._get_cached_program_and_executor.cache_clear()
def get_program_and_executor(self, program, feed, fetch_list, feed_var_name,
fetch_var_name, place, scope):
return self._get_cached_program_and_executor(
self._CachedData(program, feed, fetch_list, feed_var_name,
fetch_var_name, place, scope))
def _get_program_and_executor(self, cached_data):
program = cached_data.program
inner_program = program._program if isinstance(
program, compiler.CompiledProgram) else program
feed = cached_data.feed
fetch_list = cached_data.fetch_list
feed_var_name = cached_data.feed_var_name
fetch_var_name = cached_data.fetch_var_name
place = cached_data.place
scope = cached_data.scope
# To apply IR pass, compile the Program to IrGraph and convert it back to Program
if isinstance(program, compiler.CompiledProgram) or isinstance(
program._graph, compiler.CompiledProgram):
compiled_program = program if isinstance(
program, compiler.CompiledProgram) else program._graph
build_strategy = compiled_program._build_strategy
# print(f"Program before convert:\n {inner_program}", flush=True)
compiled_program._compile(scope, place)
ir_graph = framework.IrGraph(compiled_program._graph)
converted_program = ir_graph.to_program()
if hasattr(inner_program, 'lr_sheduler'):
converted_program.lr_sheduler = inner_program.lr_sheduler
inner_program = converted_program
# print(f"Program after convert:\n {inner_program}", flush=True)
warnings.warn(
"FLAGS_USE_STANDALONE_EXECUTOR and FLAGS_CONVERT_GRAPH_TO_PROGRAM is set to 1. Graph will be converted to Program and executed using new executor."
)
else:
build_strategy = None
from paddle.incubate.autograd import prim_enabled, prim2orig
if prim_enabled() and program == default_main_program():
prim2orig()
inner_program = program
program = _add_feed_fetch_ops(program=inner_program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name,
use_fetch_v2=True)
# If there are multiple blocks in the program, subblock will not be executed with the new executor in temporary
if program.num_blocks > 1:
warnings.warn("There are more than 1 block in program.")
# standalone executor will apply buffer_shared_inplace_pass and
# inplace_addto_op_pass to program according to build_strategy
enable_inplace = True if build_strategy is None or build_strategy.enable_inplace else False
enable_addto = True if build_strategy is not None and build_strategy.enable_addto else False
if enable_inplace or enable_addto:
# inplace should skip feed and fetch var
skip_var_names = eval(_get_program_cache_key(feed, fetch_list))
_apply_inplace_addto_pass(program, enable_inplace, enable_addto,
skip_var_names)
new_program = program.clone()
new_exe = _StandaloneExecutor(place, new_program, scope)
return new_program, new_exe
class Executor(object): class Executor(object):
...@@ -720,13 +904,19 @@ class Executor(object): ...@@ -720,13 +904,19 @@ class Executor(object):
# NOTE: Whether to use experimental executor `StandaloneExecutor`. # NOTE: Whether to use experimental executor `StandaloneExecutor`.
self._enable_interpreter_core = _is_enable_standalone_executor() self._enable_interpreter_core = _is_enable_standalone_executor()
self._executor_cache = _ExecutorCache(self.place) self._executor_cache = _ExecutorCache()
self._fleet_executor = None self._fleet_executor = None
# TODO(liyurui): This option will be removed and always true when the functionality # TODO(liyurui): This option will be removed and always true when the functionality
# of fleet executor with standalone executor is ready. # of fleet executor with standalone executor is ready.
self._fleet_executor_with_standalone = False self._fleet_executor_with_standalone = False
def __del__(self):
# NOTE(Ruibiao): The manually call of clear is required. Because in Python, executor_cache
# may not immediately destructed after Executor instance deleted (so does not the _StandaloneExecutor),
# that brings errors to mkl-dnn unit tests (see ClearMKLDNNCache in interpretercore.cc for why).
self._executor_cache.clear()
def _get_scope_cache(self, program_cache_key): def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None) return self.scope_caches.get(program_cache_key, None)
...@@ -763,67 +953,6 @@ class Executor(object): ...@@ -763,67 +953,6 @@ class Executor(object):
def _add_scope_cache(self, scope_cache_key, scope): def _add_scope_cache(self, scope_cache_key, scope):
self.scope_caches[scope_cache_key] = scope self.scope_caches[scope_cache_key] = scope
def _add_feed_fetch_ops(self,
program,
feed,
fetch_list,
feed_var_name,
fetch_var_name,
use_fetch_v2=False):
tmp_program = program.clone()
global_block = tmp_program.global_block()
if feed_var_name in global_block.vars:
feed_var = global_block.var(feed_var_name)
else:
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)
if fetch_var_name in global_block.vars:
fetch_var = global_block.var(fetch_var_name)
else:
fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed):
if global_block.has_var(name):
out = global_block.var(name)
global_block._prepend_op(type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
else:
warnings.warn(
"The variable %s is not found in program. It is not declared or is pruned."
% name)
if use_fetch_v2:
fetch_op = 'fetch_v2'
else:
fetch_op = 'fetch'
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name,
fetch_op):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(
var,
six.string_types), ("Wrong type for fetch_list[%s]: %s" %
(i, type(var)))
global_block.append_op(type=fetch_op,
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
return tmp_program
def _feed_data(self, program, feed, feed_var_name, scope): def _feed_data(self, program, feed, feed_var_name, scope):
# feed var to framework # feed var to framework
global_block = program.global_block() global_block = program.global_block()
...@@ -1482,23 +1611,6 @@ class Executor(object): ...@@ -1482,23 +1611,6 @@ class Executor(object):
assert isinstance(program, Program) assert isinstance(program, Program)
return True return True
def _apply_inplace_addto_pass(program, enable_inplace, enable_addto,
skip_var_names):
use_cuda = True if core.is_compiled_with_cuda() else False
attrs = {"use_cuda": use_cuda, "mem_opt_skip_vars": skip_var_names}
attr_types = {"use_cuda": "bool", "mem_opt_skip_vars": "list[str]"}
empty_startup_program = Program()
if enable_inplace:
pass_name = "buffer_shared_inplace_pass"
_apply_pass(program, empty_startup_program, pass_name, attrs,
attr_types)
if enable_addto and use_cuda:
pass_name = "inplace_addto_op_pass"
_apply_pass(program, empty_startup_program, pass_name, attrs,
attr_types)
# NOTE: This is an experimental feature. If `export FLAGS_USE_STANDALONE_EXECUTOR=1 `, # NOTE: This is an experimental feature. If `export FLAGS_USE_STANDALONE_EXECUTOR=1 `,
# use StandaloneExecutor to run the program. # use StandaloneExecutor to run the program.
if return_merged and self._enable_interpreter_core and _can_use_interpreter_core( if return_merged and self._enable_interpreter_core and _can_use_interpreter_core(
...@@ -1517,68 +1629,9 @@ class Executor(object): ...@@ -1517,68 +1629,9 @@ class Executor(object):
% (type(feed))) % (type(feed)))
feed = self._update_feed(program, feed) feed = self._update_feed(program, feed)
key = _get_strong_program_cache_key_for_new_exe( program, new_exe = self._executor_cache.get_program_and_executor(
inner_program, feed, fetch_list) program, feed, fetch_list, feed_var_name, fetch_var_name,
self.place, scope)
# a little bit tricy here, use inner_program before _add_feed_fetch_ops to get key
# while use program to geet _StandaloneExecutor
if key not in self._executor_cache._cached_executors:
# To apply IR pass, compile the Program to IrGraph and convert it back to Program
if isinstance(program,
compiler.CompiledProgram) or isinstance(
program._graph, compiler.CompiledProgram):
compiled_program = program if isinstance(
program,
compiler.CompiledProgram) else program._graph
build_strategy = compiled_program._build_strategy
# print(f"Program before convert:\n {inner_program}", flush=True)
compiled_program._compile(scope, self.place)
ir_graph = framework.IrGraph(compiled_program._graph)
converted_program = ir_graph.to_program()
if hasattr(inner_program, 'lr_sheduler'):
converted_program.lr_sheduler = inner_program.lr_sheduler
inner_program = converted_program
# print(f"Program after convert:\n {inner_program}", flush=True)
warnings.warn(
"FLAGS_USE_STANDALONE_EXECUTOR and FLAGS_CONVERT_GRAPH_TO_PROGRAM is set to 1. Graph will be converted to Program and executed using new executor."
)
else:
build_strategy = None
from paddle.incubate.autograd import prim_enabled, prim2orig
if prim_enabled() and program == default_main_program():
prim2orig()
program = self._add_feed_fetch_ops(
program=inner_program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name,
use_fetch_v2=True)
# If there are multiple blocks in the program, subblock will not be
# executed with the new executor in temporary
if program.num_blocks > 1:
warnings.warn("There are more than 1 block in program.")
# standalone executor will apply buffer_shared_inplace_pass and
# inplace_addto_op_pass to program according to build_strategy
enable_inplace = True if build_strategy is None or build_strategy.enable_inplace else False
enable_addto = True if build_strategy is not None and build_strategy.enable_addto else False
if enable_inplace or enable_addto:
# inplace should skip feed and fetch var
skip_var_names = eval(
_get_program_cache_key(feed, fetch_list))
_apply_inplace_addto_pass(program, enable_inplace,
enable_addto, skip_var_names)
new_program = program.clone()
new_exe = _StandaloneExecutor(self.place, new_program,
scope)
self._executor_cache._cached_executors[key] = (new_program,
new_exe)
program, new_exe = self._executor_cache._cached_executors[key]
self._feed_data(program, feed, feed_var_name, scope) self._feed_data(program, feed, feed_var_name, scope)
if hasattr(program, 'lr_sheduler'): if hasattr(program, 'lr_sheduler'):
...@@ -1703,7 +1756,7 @@ class Executor(object): ...@@ -1703,7 +1756,7 @@ class Executor(object):
cached_ctx = self._get_ctx_cache(cache_key) cached_ctx = self._get_ctx_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key) cached_scope = self._get_scope_cache(cache_key)
if cached_program is None: if cached_program is None:
cached_program = self._add_feed_fetch_ops( cached_program = _add_feed_fetch_ops(
program=program, program=program,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
...@@ -1727,7 +1780,7 @@ class Executor(object): ...@@ -1727,7 +1780,7 @@ class Executor(object):
ctx = cached_ctx ctx = cached_ctx
scope = cached_scope scope = cached_scope
else: else:
program = self._add_feed_fetch_ops(program=program, program = _add_feed_fetch_ops(program=program,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
feed_var_name=feed_var_name, feed_var_name=feed_var_name,
...@@ -1965,7 +2018,7 @@ class Executor(object): ...@@ -1965,7 +2018,7 @@ class Executor(object):
if fetch_var_name in real_program.global_block().vars: if fetch_var_name in real_program.global_block().vars:
real_fetch_list.append(fetch_var) real_fetch_list.append(fetch_var)
program._pipeline_opt["section_program"] = self._add_feed_fetch_ops( program._pipeline_opt["section_program"] = _add_feed_fetch_ops(
program=program._pipeline_opt["section_program"], program=program._pipeline_opt["section_program"],
feed=[], feed=[],
fetch_list=real_fetch_list, fetch_list=real_fetch_list,
...@@ -2095,7 +2148,7 @@ class Executor(object): ...@@ -2095,7 +2148,7 @@ class Executor(object):
if fetch_var_name in real_program.global_block().vars: if fetch_var_name in real_program.global_block().vars:
real_fetch_list.append(fetch_var) real_fetch_list.append(fetch_var)
real_program = self._add_feed_fetch_ops(program=real_program, real_program = _add_feed_fetch_ops(program=real_program,
feed=[], feed=[],
fetch_list=real_fetch_list, fetch_list=real_fetch_list,
feed_var_name='feed', feed_var_name='feed',
...@@ -2219,8 +2272,7 @@ class Executor(object): ...@@ -2219,8 +2272,7 @@ class Executor(object):
real_program = program real_program = program
if "section_program" in program._pipeline_opt: if "section_program" in program._pipeline_opt:
real_program = program._pipeline_opt["section_program"] real_program = program._pipeline_opt["section_program"]
cached_program = self._add_feed_fetch_ops( cached_program = _add_feed_fetch_ops(program=real_program,
program=real_program,
feed=real_feed, feed=real_feed,
fetch_list=fetch_list, fetch_list=fetch_list,
feed_var_name=feed_var_name, feed_var_name=feed_var_name,
......
...@@ -1381,6 +1381,11 @@ py_test_modules( ...@@ -1381,6 +1381,11 @@ py_test_modules(
test_eager_deletion_padding_rnn_for_interpretercore MODULES test_eager_deletion_padding_rnn_for_interpretercore MODULES
test_eager_deletion_padding_rnn ENVS FLAGS_CONVERT_GRAPH_TO_PROGRAM=true) test_eager_deletion_padding_rnn ENVS FLAGS_CONVERT_GRAPH_TO_PROGRAM=true)
set_tests_properties(
test_buffer_shared_memory_reuse_pass
test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
PROPERTIES ENVIRONMENT FLAGS_CONVERT_GRAPH_TO_PROGRAM=true)
# ExecutionStrategy is deprecated in standalone executor # ExecutionStrategy is deprecated in standalone executor
set_tests_properties(test_parallel_executor_dry_run set_tests_properties(test_parallel_executor_dry_run
PROPERTIES ENVIRONMENT "FLAGS_USE_STANDALONE_EXECUTOR=0") PROPERTIES ENVIRONMENT "FLAGS_USE_STANDALONE_EXECUTOR=0")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册