未验证 提交 cae050e8 编写于 作者: A Aurelius84 提交者: GitHub

Enhance Check mechanism and Support single tuple/list of fetch_list in Executor (#35726)

* Enhance Check mechanism of fetch_list in Executor

* support single tuple

* fix typo

* fix typo
上级 9f4d201a
...@@ -833,7 +833,7 @@ class Executor(object): ...@@ -833,7 +833,7 @@ class Executor(object):
"The item in fetch_list should be str, variable or optimize_op, but recieved %s.", "The item in fetch_list should be str, variable or optimize_op, but recieved %s.",
type(item)) type(item))
for item in fetch_list: for index, item in enumerate(fetch_list):
# NOTE(zhiqiu): to support (optimizer_ops, param_and_grads) and optimizer_ops in fetch_list # NOTE(zhiqiu): to support (optimizer_ops, param_and_grads) and optimizer_ops in fetch_list
# we should handle tuple and list in fetch_list. # we should handle tuple and list in fetch_list.
# TODO(zhiqiu): find a better way to handle that. # TODO(zhiqiu): find a better way to handle that.
...@@ -841,6 +841,10 @@ class Executor(object): ...@@ -841,6 +841,10 @@ class Executor(object):
for i in item: for i in item:
_get_targets(_optimize_ops, _fetch_list, i) _get_targets(_optimize_ops, _fetch_list, i)
elif isinstance(item, tuple): elif isinstance(item, tuple):
if not isinstance(item[0], (list, tuple)):
raise TypeError(
"Requires fetch_list[{}][0] shall be one of (list, tuple) when type(fetch_list[{}]) is `tuple`, but received fetch_list[{}][0]'s type is `{}`.".
format(index, index, index, type(item[0]).__name__))
for i in item[0]: for i in item[0]:
_get_targets(_optimize_ops, _fetch_list, i) _get_targets(_optimize_ops, _fetch_list, i)
else: else:
...@@ -1249,17 +1253,7 @@ class Executor(object): ...@@ -1249,17 +1253,7 @@ class Executor(object):
if program is None: if program is None:
program = default_main_program() program = default_main_program()
if fetch_list is not None: fetch_list = self._check_fetch_list(fetch_list)
if isinstance(fetch_list, Variable) or isinstance(
fetch_list, str) or isinstance(fetch_list,
six.string_types):
fetch_list = [fetch_list]
assert isinstance(fetch_list, tuple) or isinstance(fetch_list, list), \
"Currently , The fetch_list type only should be list or tuple, \n"\
"but the input type is {}. For more information please refer to \n"\
"the executor.run(...).".format(type(fetch_list))
else:
fetch_list = []
if isinstance(program, Program) and program._pipeline_opt: if isinstance(program, Program) and program._pipeline_opt:
if "startup_program" in program._pipeline_opt: if "startup_program" in program._pipeline_opt:
...@@ -1479,6 +1473,35 @@ class Executor(object): ...@@ -1479,6 +1473,35 @@ class Executor(object):
def _run_inference(self, exe, feed): def _run_inference(self, exe, feed):
return exe.run(feed) return exe.run(feed)
def _check_fetch_list(self, fetch_list):
is_fetch_var = lambda var: isinstance(var, (Variable, str, six.string_types))
is_tuple_list = lambda var: isinstance(var, (tuple, list))
if fetch_list is None: return []
if is_fetch_var(fetch_list): return [fetch_list]
assert is_tuple_list(fetch_list), \
"Currently , The fetch_list type only should be list or tuple, \n"\
"but the input type is {}. For more information please refer to \n"\
"the executor.run(...).".format(type(fetch_list))
res = []
for i, var in enumerate(fetch_list):
if is_fetch_var(var):
res.append(var)
# such as [x, 'mean_out', loss]
elif is_tuple_list(var):
if all(is_fetch_var(v) for v in var):
res.extend(list(var))
else:
res.append(var)
else:
raise TypeError(
"Require fetch_list[{}] 's type shall be one of (Variable, str), but received {}.".
format(i, type(var).__name__))
return res
def _dump_debug_info(self, program=None, trainer=None): def _dump_debug_info(self, program=None, trainer=None):
with open(str(id(program)) + "_train_desc.prototxt", "w") as fout: with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
fout.write(str(trainer)) fout.write(str(trainer))
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import unittest
class TestCheckFetchList(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.feed = {"x": np.array([[0], [0], [1], [0]], dtype='float32')}
self.expected = np.array([[0], [1], [0]], dtype='float32')
self.build_program()
self.exe = paddle.static.Executor(paddle.CPUPlace())
def build_program(self):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.static.data(name='x', shape=[4, 1], dtype='float32')
output = paddle.unique_consecutive(
x, return_inverse=True, return_counts=True, axis=0)
self.main_program = main_program
self.fetch_list = output
def test_with_tuple(self):
res = self.exe.run(
self.main_program,
feed=self.feed,
fetch_list=[self.fetch_list], # support single list/tuple
return_numpy=True)
self.assertTrue(np.array_equal(res[0], self.expected))
def test_with_error(self):
with self.assertRaises(TypeError):
fetch_list = [23]
res = self.exe.run(self.main_program,
feed=self.feed,
fetch_list=fetch_list)
with self.assertRaises(TypeError):
fetch_list = [(self.fetch_list[0], 32)]
res = self.exe.run(self.main_program,
feed=self.feed,
fetch_list=fetch_list)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册