From e64829e25b3bd107e4fd6864121bd4f3b4922647 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 25 Nov 2021 15:11:09 +0800 Subject: [PATCH] [new-exec] fix program cache key (#37500) * fix program cache key * bug fix * fix cache problem * remove unused code --- python/paddle/fluid/executor.py | 37 ++++++++++++--------------------- 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 3f2aebfa4c..1ed52e52fa 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -580,26 +580,6 @@ class _ExecutorCache(object): self._place = place self._cached_executors = {} - def run(self, program, scope, feed, fetch_list, return_numpy=True): - new_exe = self._get_exe_from_cache(program, scope) - return new_exe.run(feed, fetch_list, return_numpy) - - def _get_exe_from_cache(self, program, scope): - """ - Return cached _StandaloneExecutor instance. If not found, create associated - _StandaloneExecutor instance with given program and cache it. - """ - assert isinstance( - program, Program), "Required type(Program), but received {}".format( - type(program).__name__) - - if str(program) not in self._cached_executors: - new_program = program.clone() - new_exe = _StandaloneExecutor(self._place, new_program, scope) - self._cached_executors[str(program)] = new_exe - - return self._cached_executors[str(program)] - class Executor(object): """ @@ -1361,6 +1341,10 @@ class Executor(object): "feed requires dict as its Parameter. But you passed in %s" % (type(feed))) feed = self._update_feed(program, feed) + + key = _get_strong_program_cache_key(inner_program, feed, + fetch_list) + program = self._add_feed_fetch_ops( program=inner_program, feed=feed, @@ -1369,10 +1353,15 @@ class Executor(object): fetch_var_name=fetch_var_name, use_fetch_v2=True) - # NPTE(zhiqiu): Construct standalone_executor first, so - # the scope is binded with the variable_scope of standalone_executor - new_exe = self._executor_cache._get_exe_from_cache(program, - scope) + # a little bit tricy here, use inner_program before _add_feed_fetch_ops to get key + # while use program to geet _StandaloneExecutor + if key not in self._executor_cache._cached_executors: + new_program = program.clone() + new_exe = _StandaloneExecutor(self.place, new_program, + scope) + self._executor_cache._cached_executors[key] = new_exe + + new_exe = self._executor_cache._cached_executors[key] self._feed_data(program, feed, feed_var_name, scope) if hasattr(program, 'lr_sheduler'): -- GitLab