未验证 提交 01c7daad 编写于 作者: C chengduo 提交者: GitHub

Add checking for the fetch_list of Executor.run (#18957)

* update exe.run
上级 e53f517a
...@@ -38,7 +38,7 @@ paddle.fluid.DistributeTranspilerConfig.__init__ ...@@ -38,7 +38,7 @@ paddle.fluid.DistributeTranspilerConfig.__init__
paddle.fluid.ParallelExecutor ('paddle.fluid.parallel_executor.ParallelExecutor', ('document', '2b4d2e859f2e0c6161f4fed995f7956d')) paddle.fluid.ParallelExecutor ('paddle.fluid.parallel_executor.ParallelExecutor', ('document', '2b4d2e859f2e0c6161f4fed995f7956d'))
paddle.fluid.ParallelExecutor.__init__ (ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.ParallelExecutor.__init__ (ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.ParallelExecutor.drop_local_exe_scopes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '77c739744ea5708b80fb1b37cc89db40')) paddle.fluid.ParallelExecutor.drop_local_exe_scopes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '77c739744ea5708b80fb1b37cc89db40'))
paddle.fluid.ParallelExecutor.run (ArgSpec(args=['self', 'fetch_list', 'feed', 'feed_dict', 'return_numpy'], varargs=None, keywords=None, defaults=(None, None, True)), ('document', '33ce6ec50f8eeb05d340e6b114b026fd')) paddle.fluid.ParallelExecutor.run (ArgSpec(args=['self', 'fetch_list', 'feed', 'feed_dict', 'return_numpy'], varargs=None, keywords=None, defaults=(None, None, True)), ('document', '0af092676e5b1320bb4232396154ce4b'))
paddle.fluid.create_lod_tensor (ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None), ('document', 'b82ea20e2dc5ff2372e0643169ca47ff')) paddle.fluid.create_lod_tensor (ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None), ('document', 'b82ea20e2dc5ff2372e0643169ca47ff'))
paddle.fluid.create_random_int_lodtensor (ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None), ('document', '74dc6d23185d90a7a50fbac19f5b65fb')) paddle.fluid.create_random_int_lodtensor (ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None), ('document', '74dc6d23185d90a7a50fbac19f5b65fb'))
paddle.fluid.DataFeedDesc ('paddle.fluid.data_feed_desc.DataFeedDesc', ('document', '43877a0d9357db94d3dbc7359cbe8c73')) paddle.fluid.DataFeedDesc ('paddle.fluid.data_feed_desc.DataFeedDesc', ('document', '43877a0d9357db94d3dbc7359cbe8c73'))
......
...@@ -44,12 +44,12 @@ class SlimGraphExecutor(object): ...@@ -44,12 +44,12 @@ class SlimGraphExecutor(object):
feed = None feed = None
if data is not None: if data is not None:
feeder = DataFeeder( feeder = DataFeeder(
feed_list=graph.in_nodes.values(), feed_list=list(graph.in_nodes.values()),
place=self.place, place=self.place,
program=graph.program) program=graph.program)
feed = feeder.feed(data) feed = feeder.feed(data)
fetch_list = graph.out_nodes.values() fetch_list = list(graph.out_nodes.values())
program = graph.compiled_graph if graph.compiled_graph else graph.program program = graph.compiled_graph if graph.compiled_graph else graph.program
results = self.exe.run(program, results = self.exe.run(program,
scope=scope, scope=scope,
......
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
import os import os
import multiprocessing import multiprocessing
import sys import sys
import warnings
import numpy as np import numpy as np
from .wrapped_decorator import signature_safe_contextmanager from .wrapped_decorator import signature_safe_contextmanager
import six import six
...@@ -611,17 +612,30 @@ class Executor(object): ...@@ -611,17 +612,30 @@ class Executor(object):
except Exception as e: except Exception as e:
if not isinstance(e, core.EOFException): if not isinstance(e, core.EOFException):
print("An exception was thrown!\n {}".format(str(e))) print("An exception was thrown!\n {}".format(str(e)))
raise e six.reraise(*sys.exc_info())
def _run_impl(self, program, feed, fetch_list, feed_var_name, def _run_impl(self, program, feed, fetch_list, feed_var_name,
fetch_var_name, scope, return_numpy, use_program_cache): fetch_var_name, scope, return_numpy, use_program_cache):
if self._closed: if self._closed:
raise RuntimeError("Attempted to use a closed Executor") raise RuntimeError("Attempted to use a closed Executor")
if program is None:
program = default_main_program()
if isinstance(program,Program) and \
len(program.global_block().ops) == 0:
warnings.warn("The current program is empty.")
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
if fetch_list is None:
if fetch_list is not None:
if isinstance(fetch_list, Variable) or isinstance(fetch_list, str):
fetch_list = [fetch_list]
assert isinstance(fetch_list, tuple) or isinstance(fetch_list, list), \
"Currently , The fetch_list type only should be list or tuple, \n"\
"but the input type is {}. For more information please refer to \n"\
"the executor.run(...).".format(type(fetch_list))
else:
fetch_list = [] fetch_list = []
compiled = isinstance(program, compiler.CompiledProgram) compiled = isinstance(program, compiler.CompiledProgram)
...@@ -679,9 +693,8 @@ class Executor(object): ...@@ -679,9 +693,8 @@ class Executor(object):
raise TypeError( raise TypeError(
"feed requires dict as its Parameter. But you passed in %s" % "feed requires dict as its Parameter. But you passed in %s" %
(type(feed))) (type(feed)))
if program is None:
program = default_main_program()
assert program is not None, "The program should not be Empty"
if not isinstance(program, Program): if not isinstance(program, Program):
raise TypeError( raise TypeError(
"Executor requires Program as its Parameter. But you passed in %s" "Executor requires Program as its Parameter. But you passed in %s"
......
...@@ -180,7 +180,7 @@ class ParallelExecutor(object): ...@@ -180,7 +180,7 @@ class ParallelExecutor(object):
The feed parameter can be a dict or a list. If feed is a dict, the The feed parameter can be a dict or a list. If feed is a dict, the
feed data will be split into multiple devices. If feed is a list, we feed data will be split into multiple devices. If feed is a list, we
assume the data has been splitted into multiple devices, the each assume the data has been split into multiple devices, the each
element in the list will be copied to each device directly. element in the list will be copied to each device directly.
Examples: Examples:
...@@ -212,7 +212,6 @@ class ParallelExecutor(object): ...@@ -212,7 +212,6 @@ class ParallelExecutor(object):
loss = fluid.layers.mean(hidden) loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss) fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
startup_program.random_seed=1
exe.run(startup_program) exe.run(startup_program)
train_exe = fluid.ParallelExecutor(use_cuda=use_cuda, train_exe = fluid.ParallelExecutor(use_cuda=use_cuda,
...@@ -239,7 +238,7 @@ class ParallelExecutor(object): ...@@ -239,7 +238,7 @@ class ParallelExecutor(object):
Args: Args:
fetch_list(list): The fetched variable names fetch_list(list): The fetched variable names
feed(list|dict|None): The feed variables. If the feed is a dict, feed(list|dict|None): The feed variables. If the feed is a dict,
tensors in that dict will be splitted into each devices. If tensors in that dict will be split into each devices. If
the feed is a list, each element of the list will be copied the feed is a list, each element of the list will be copied
to each device. Default None. to each device. Default None.
feed_dict: Alias for feed parameter, for backward compatibility. feed_dict: Alias for feed parameter, for backward compatibility.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册