未验证 提交 e64829e2 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] fix program cache key (#37500)

* fix program cache key

* bug fix

* fix cache problem

* remove unused code
上级 50f75fb5
......@@ -580,26 +580,6 @@ class _ExecutorCache(object):
self._place = place
self._cached_executors = {}
def run(self, program, scope, feed, fetch_list, return_numpy=True):
new_exe = self._get_exe_from_cache(program, scope)
return new_exe.run(feed, fetch_list, return_numpy)
def _get_exe_from_cache(self, program, scope):
"""
Return cached _StandaloneExecutor instance. If not found, create associated
_StandaloneExecutor instance with given program and cache it.
"""
assert isinstance(
program, Program), "Required type(Program), but received {}".format(
type(program).__name__)
if str(program) not in self._cached_executors:
new_program = program.clone()
new_exe = _StandaloneExecutor(self._place, new_program, scope)
self._cached_executors[str(program)] = new_exe
return self._cached_executors[str(program)]
class Executor(object):
"""
......@@ -1361,6 +1341,10 @@ class Executor(object):
"feed requires dict as its Parameter. But you passed in %s"
% (type(feed)))
feed = self._update_feed(program, feed)
key = _get_strong_program_cache_key(inner_program, feed,
fetch_list)
program = self._add_feed_fetch_ops(
program=inner_program,
feed=feed,
......@@ -1369,10 +1353,15 @@ class Executor(object):
fetch_var_name=fetch_var_name,
use_fetch_v2=True)
# NPTE(zhiqiu): Construct standalone_executor first, so
# the scope is binded with the variable_scope of standalone_executor
new_exe = self._executor_cache._get_exe_from_cache(program,
scope)
# a little bit tricy here, use inner_program before _add_feed_fetch_ops to get key
# while use program to geet _StandaloneExecutor
if key not in self._executor_cache._cached_executors:
new_program = program.clone()
new_exe = _StandaloneExecutor(self.place, new_program,
scope)
self._executor_cache._cached_executors[key] = new_exe
new_exe = self._executor_cache._cached_executors[key]
self._feed_data(program, feed, feed_var_name, scope)
if hasattr(program, 'lr_sheduler'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册