From 1888337a49584f356d47561389bde180a07e9586 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 21 Jan 2019 20:51:18 +0800 Subject: [PATCH] tweak the executor implementation to better match origin behavior test=develop --- python/paddle/fluid/executor.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 20aa6054fe4..f6bee559eaf 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -305,7 +305,9 @@ class Executor(object): def __init__(self, place): self.place = place self.program_caches = dict() - self.executor = None + p = core.Place() + p.set_place(self.place) + self._default_executor = core.Executor(p) self._closed = False def _get_program_cache(self, program_cache_key): @@ -397,12 +399,13 @@ class Executor(object): >>> ... >>> exe.close() """ - if not self._closed and self.executor: - self.executor.close() + if not self._closed: + self._default_executor.close() self._closed = True def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, return_numpy): + exe = program._executor if isinstance(feed, dict): feed_tensor_dict = dict() for feed_name in feed: @@ -414,8 +417,7 @@ class Executor(object): feed_tensor.set(feed[feed_name], core.CPUPlace()) feed_tensor_dict[feed_name] = feed_tensor - self.executor.feed_and_split_tensor_into_local_scopes( - feed_tensor_dict) + exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict) elif isinstance(feed, list) or isinstance(feed, tuple): if len(feed) != len(program._places): raise ValueError( @@ -436,10 +438,10 @@ class Executor(object): tensor = tmp res_dict[feed_name] = tensor res.append(res_dict) - self.executor.feed_tensors_into_local_scopes(res) + exe.feed_tensors_into_local_scopes(res) fetch_var_names = list(map(_to_name_str, fetch_list)) - self.executor.run(fetch_var_names, fetch_var_name) + exe.run(fetch_var_names, fetch_var_name) arr = scope.find_var(fetch_var_name).get_lod_tensor_array() if return_numpy: @@ -511,12 +513,9 @@ class Executor(object): compiled = isinstance(program, compiler.CompiledProgram) # For backward compatibility, run directly. if not compiled: - if not self.executor: - p = core.Place() - p.set_place(self.place) - self.executor = core.Executor(p) return self._run( program, + self._default_executor, feed=feed, fetch_list=fetch_list, feed_var_name=feed_var_name, @@ -526,7 +525,6 @@ class Executor(object): use_program_cache=use_program_cache) program._compile(scope, self.place) - self.executor = program._executor if program._is_data_parallel: return self._run_parallel( program, @@ -542,6 +540,7 @@ class Executor(object): # performance. return self._run( program._program, + self._default_executor, feed=feed, fetch_list=fetch_list, feed_var_name=feed_var_name, @@ -550,8 +549,8 @@ class Executor(object): return_numpy=return_numpy, use_program_cache=use_program_cache) - def _run(self, program, feed, fetch_list, feed_var_name, fetch_var_name, - scope, return_numpy, use_program_cache): + def _run(self, program, exe, feed, fetch_list, feed_var_name, + fetch_var_name, scope, return_numpy, use_program_cache): if feed is None: feed = {} @@ -589,7 +588,7 @@ class Executor(object): fetch_var_name=fetch_var_name) self._feed_data(program, feed, feed_var_name, scope) - self.executor.run(program.desc, scope, 0, True, True) + exe.run(program.desc, scope, 0, True, True) outs = self._fetch_data(fetch_list, fetch_var_name, scope) if return_numpy: outs = as_numpy(outs) -- GitLab