提交 c4b09a71 编写于 作者: X Xin Pan

polish

test=develop
上级 5f0a0286
...@@ -77,6 +77,7 @@ class PreparedOp { ...@@ -77,6 +77,7 @@ class PreparedOp {
framework::OperatorWithKernel::OpKernelFunc func; framework::OperatorWithKernel::OpKernelFunc func;
platform::DeviceContext* dev_ctx; platform::DeviceContext* dev_ctx;
}; };
class OpBase; class OpBase;
class VarBase { class VarBase {
......
...@@ -208,20 +208,20 @@ def _fetch_var(name, scope=None, return_numpy=True): ...@@ -208,20 +208,20 @@ def _fetch_var(name, scope=None, return_numpy=True):
return tensor return tensor
def _get_program_cache_key(feed, fetch_list): def _to_name_str(var):
feed_var_names = list(feed.keys()) if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
def to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
fetch_var_names = list(map(to_name_str, fetch_list)) def _get_program_cache_key(feed, fetch_list):
feed_var_names = list(feed.keys())
fetch_var_names = list(map(_to_name_str, fetch_list))
return str(feed_var_names + fetch_var_names) return str(feed_var_names + fetch_var_names)
...@@ -397,11 +397,8 @@ class Executor(object): ...@@ -397,11 +397,8 @@ class Executor(object):
self.executor.close() self.executor.close()
self._closed = True self._closed = True
def _run_parallel(self, def _run_parallel(self, scope, feed, fetch_list, fetch_var_name,
scope, return_numpy):
feed=None,
fetch_list=None,
return_numpy=True):
if isinstance(feed, dict): if isinstance(feed, dict):
feed_tensor_dict = dict() feed_tensor_dict = dict()
for feed_name in feed: for feed_name in feed:
...@@ -437,8 +434,8 @@ class Executor(object): ...@@ -437,8 +434,8 @@ class Executor(object):
res.append(res_dict) res.append(res_dict)
self.executor.feed_tensors_into_local_scopes(res) self.executor.feed_tensors_into_local_scopes(res)
fetch_var_name = '@FETCHED_VAR_NAME@' fetch_var_names = list(map(_to_name_str, fetch_list))
self.executor.run(fetch_list, fetch_var_name) self.executor.run(fetch_var_names, fetch_var_name)
arr = scope.find_var(fetch_var_name).get_lod_tensor_array() arr = scope.find_var(fetch_var_name).get_lod_tensor_array()
if return_numpy: if return_numpy:
...@@ -504,6 +501,8 @@ class Executor(object): ...@@ -504,6 +501,8 @@ class Executor(object):
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
if fetch_list is None:
fetch_list = []
compiled = isinstance(program, compiler.CompiledProgram) compiled = isinstance(program, compiler.CompiledProgram)
# For backward compatibility, run directly. # For backward compatibility, run directly.
...@@ -529,6 +528,7 @@ class Executor(object): ...@@ -529,6 +528,7 @@ class Executor(object):
scope=scope, scope=scope,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
fetch_var_name=fetch_var_name,
return_numpy=return_numpy) return_numpy=return_numpy)
else: else:
# TODO(panyx0718): Can compile program to optimize executor # TODO(panyx0718): Can compile program to optimize executor
...@@ -552,8 +552,6 @@ class Executor(object): ...@@ -552,8 +552,6 @@ class Executor(object):
raise TypeError( raise TypeError(
"feed requires dict as its Parameter. But you passed in %s" % "feed requires dict as its Parameter. But you passed in %s" %
(type(feed))) (type(feed)))
if fetch_list is None:
fetch_list = []
if program is None: if program is None:
program = default_main_program() program = default_main_program()
......
...@@ -279,7 +279,7 @@ class ParallelExecutor(object): ...@@ -279,7 +279,7 @@ class ParallelExecutor(object):
res.append(res_dict) res.append(res_dict)
self.executor.feed_tensors_into_local_scopes(res) self.executor.feed_tensors_into_local_scopes(res)
fetch_var_name = '@FETCHED_VAR_NAME@' fetch_var_name = 'fetch'
self.executor.run(fetch_list, fetch_var_name) self.executor.run(fetch_list, fetch_var_name)
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册