Unverified commit 01c7daad, authored by chengduo, committed by GitHub

Add checking for the fetch_list of Executor.run (#18957)

* update exe.run
Parent e53f517a
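The core of this change is a normalization and type check for `fetch_list` in `Executor._run_impl`: a single `Variable` (or a variable name as a string) is wrapped into a one-element list, a list or tuple is accepted as-is, `None` becomes an empty list, and any other type now fails an assertion up front instead of failing later with a less direct error. A minimal sketch of what callers see under the fluid 1.x API of this branch; the tiny network below is illustrative, not part of the PR:

```python
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

feed = {'x': np.ones((2, 1), dtype='float32')}

# A single Variable (or its name as a str) is wrapped into [y] internally.
out_a = exe.run(feed=feed, fetch_list=y)
out_b = exe.run(feed=feed, fetch_list=[y])   # equivalent call

# Anything that is not a Variable/str/list/tuple now trips the new assertion,
# whose message points the user back to executor.run(...).
try:
    exe.run(feed=feed, fetch_list={'y': y})
except AssertionError as err:
    print(err)
```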
......@@ -38,7 +38,7 @@ paddle.fluid.DistributeTranspilerConfig.__init__
paddle.fluid.ParallelExecutor ('paddle.fluid.parallel_executor.ParallelExecutor', ('document', '2b4d2e859f2e0c6161f4fed995f7956d'))
paddle.fluid.ParallelExecutor.__init__ (ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.ParallelExecutor.drop_local_exe_scopes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '77c739744ea5708b80fb1b37cc89db40'))
paddle.fluid.ParallelExecutor.run (ArgSpec(args=['self', 'fetch_list', 'feed', 'feed_dict', 'return_numpy'], varargs=None, keywords=None, defaults=(None, None, True)), ('document', '33ce6ec50f8eeb05d340e6b114b026fd'))
paddle.fluid.ParallelExecutor.run (ArgSpec(args=['self', 'fetch_list', 'feed', 'feed_dict', 'return_numpy'], varargs=None, keywords=None, defaults=(None, None, True)), ('document', '0af092676e5b1320bb4232396154ce4b'))
paddle.fluid.create_lod_tensor (ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None), ('document', 'b82ea20e2dc5ff2372e0643169ca47ff'))
paddle.fluid.create_random_int_lodtensor (ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None), ('document', '74dc6d23185d90a7a50fbac19f5b65fb'))
paddle.fluid.DataFeedDesc ('paddle.fluid.data_feed_desc.DataFeedDesc', ('document', '43877a0d9357db94d3dbc7359cbe8c73'))
......
......@@ -44,12 +44,12 @@ class SlimGraphExecutor(object):
feed = None
if data is not None:
feeder = DataFeeder(
feed_list=graph.in_nodes.values(),
feed_list=list(graph.in_nodes.values()),
place=self.place,
program=graph.program)
feed = feeder.feed(data)
fetch_list = graph.out_nodes.values()
fetch_list = list(graph.out_nodes.values())
program = graph.compiled_graph if graph.compiled_graph else graph.program
results = self.exe.run(program,
scope=scope,
......
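The `list(...)` wrapping in `SlimGraphExecutor` follows from the new check: on Python 3, `dict.values()` returns a view object rather than a list, so passing it straight through as `fetch_list` would now fail the list/tuple assertion. A small standalone illustration (the node names are made up):

```python
out_nodes = {'loss': 'mean_0.tmp_0', 'acc': 'accuracy_0.tmp_0'}  # hypothetical names

values = out_nodes.values()
print(isinstance(values, (list, tuple)))      # False on Python 3 (dict_values view)

fetch_list = list(out_nodes.values())
print(isinstance(fetch_list, (list, tuple)))  # True, safe to hand to Executor.run
```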
......@@ -18,6 +18,7 @@ import logging
import os
import multiprocessing
import sys
import warnings
import numpy as np
from .wrapped_decorator import signature_safe_contextmanager
import six
......@@ -611,17 +612,30 @@ class Executor(object):
except Exception as e:
if not isinstance(e, core.EOFException):
print("An exception was thrown!\n {}".format(str(e)))
raise e
six.reraise(*sys.exc_info())
def _run_impl(self, program, feed, fetch_list, feed_var_name,
fetch_var_name, scope, return_numpy, use_program_cache):
if self._closed:
raise RuntimeError("Attempted to use a closed Executor")
if program is None:
program = default_main_program()
if isinstance(program,Program) and \
len(program.global_block().ops) == 0:
warnings.warn("The current program is empty.")
if scope is None:
scope = global_scope()
if fetch_list is None:
if fetch_list is not None:
if isinstance(fetch_list, Variable) or isinstance(fetch_list, str):
fetch_list = [fetch_list]
assert isinstance(fetch_list, tuple) or isinstance(fetch_list, list), \
"Currently, the fetch_list type should only be list or tuple, \n"\
"but the input type is {}. For more information please refer to \n"\
"the executor.run(...).".format(type(fetch_list))
else:
fetch_list = []
compiled = isinstance(program, compiler.CompiledProgram)
......@@ -679,9 +693,8 @@ class Executor(object):
raise TypeError(
"feed requires dict as its Parameter. But you passed in %s" %
(type(feed)))
if program is None:
program = default_main_program()
assert program is not None, "The program should not be Empty"
if not isinstance(program, Program):
raise TypeError(
"Executor requires Program as its Parameter. But you passed in %s"
......
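Two smaller changes ride along in the same file: the exception caught around the run call is re-raised with `six.reraise(*sys.exc_info())`, which keeps the original traceback on both Python 2 and Python 3 (a bare `raise e` restarts the traceback at the re-raise site on Python 2), and running a `Program` with no ops now emits `warnings.warn("The current program is empty.")`. A minimal, Paddle-free sketch of the re-raise pattern:

```python
import sys
import six

def might_fail():
    raise ValueError("original failure")

try:
    might_fail()
except Exception as e:
    print("An exception was thrown!\n {}".format(str(e)))
    # Re-raise the original exception type, value and traceback, so the
    # reported stack still points inside might_fail() on Python 2 and 3.
    six.reraise(*sys.exc_info())
```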
......@@ -180,7 +180,7 @@ class ParallelExecutor(object):
The feed parameter can be a dict or a list. If feed is a dict, the
feed data will be split into multiple devices. If feed is a list, we
assume the data has been splitted into multiple devices, the each
assume the data has been split into multiple devices, the each
element in the list will be copied to each device directly.
Examples:
......@@ -212,7 +212,6 @@ class ParallelExecutor(object):
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
startup_program.random_seed=1
exe.run(startup_program)
train_exe = fluid.ParallelExecutor(use_cuda=use_cuda,
......@@ -239,7 +238,7 @@ class ParallelExecutor(object):
Args:
fetch_list(list): The fetched variable names
feed(list|dict|None): The feed variables. If the feed is a dict,
tensors in that dict will be splitted into each devices. If
tensors in that dict will be split into each devices. If
the feed is a list, each element of the list will be copied
to each device. Default None.
feed_dict: Alias for feed parameter, for backward compatibility.
......
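The docstring above describes the two feed forms that `ParallelExecutor.run` accepts: a single dict whose tensors the executor splits across the devices, or a list with one already-split dict per device. A sketch of the two call styles, continuing the docstring's `train_exe`/`loss` example and assuming two devices and a float32 input named `x` (names and shapes are illustrative):

```python
import numpy as np

batch = np.random.rand(8, 13).astype('float32')

# Form 1: a single dict; the executor splits the batch across the devices.
loss_vals = train_exe.run(fetch_list=[loss.name],
                          feed={'x': batch})

# Form 2: a list with one dict per device; each element is copied to its
# device as-is, so the caller has already done the splitting.
loss_vals = train_exe.run(fetch_list=[loss.name],
                          feed=[{'x': batch[:4]},
                                {'x': batch[4:]}])
```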