From cae050e822b74db29540d852bf73304fa6bb5f70 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 15 Sep 2021 10:41:05 +0800 Subject: [PATCH] Enhance Check mechanism and Support single tuple/list of fetch_list in Executor (#35726) * Enhance Check mechanism of fetch_list in Executor * support single tuple * fix typo * fix typo --- python/paddle/fluid/executor.py | 47 +++++++++---- .../test_executor_check_fetch_list.py | 67 +++++++++++++++++++ 2 files changed, 102 insertions(+), 12 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_executor_check_fetch_list.py diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 74b6ec3480e..8cd8bc39941 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -833,7 +833,7 @@ class Executor(object): "The item in fetch_list should be str, variable or optimize_op, but recieved %s.", type(item)) - for item in fetch_list: + for index, item in enumerate(fetch_list): # NOTE(zhiqiu): to support (optimizer_ops, param_and_grads) and optimizer_ops in fetch_list # we should handle tuple and list in fetch_list. # TODO(zhiqiu): find a better way to handle that. @@ -841,6 +841,10 @@ class Executor(object): for i in item: _get_targets(_optimize_ops, _fetch_list, i) elif isinstance(item, tuple): + if not isinstance(item[0], (list, tuple)): + raise TypeError( + "Requires fetch_list[{}][0] shall be one of (list, tuple) when type(fetch_list[{}]) is `tuple`, but received fetch_list[{}][0]'s type is `{}`.". + format(index, index, index, type(item[0]).__name__)) for i in item[0]: _get_targets(_optimize_ops, _fetch_list, i) else: @@ -1249,17 +1253,7 @@ class Executor(object): if program is None: program = default_main_program() - if fetch_list is not None: - if isinstance(fetch_list, Variable) or isinstance( - fetch_list, str) or isinstance(fetch_list, - six.string_types): - fetch_list = [fetch_list] - assert isinstance(fetch_list, tuple) or isinstance(fetch_list, list), \ - "Currently , The fetch_list type only should be list or tuple, \n"\ - "but the input type is {}. For more information please refer to \n"\ - "the executor.run(...).".format(type(fetch_list)) - else: - fetch_list = [] + fetch_list = self._check_fetch_list(fetch_list) if isinstance(program, Program) and program._pipeline_opt: if "startup_program" in program._pipeline_opt: @@ -1479,6 +1473,35 @@ class Executor(object): def _run_inference(self, exe, feed): return exe.run(feed) + def _check_fetch_list(self, fetch_list): + is_fetch_var = lambda var: isinstance(var, (Variable, str, six.string_types)) + is_tuple_list = lambda var: isinstance(var, (tuple, list)) + + if fetch_list is None: return [] + if is_fetch_var(fetch_list): return [fetch_list] + + assert is_tuple_list(fetch_list), \ + "Currently , The fetch_list type only should be list or tuple, \n"\ + "but the input type is {}. For more information please refer to \n"\ + "the executor.run(...).".format(type(fetch_list)) + + res = [] + for i, var in enumerate(fetch_list): + if is_fetch_var(var): + res.append(var) + # such as [x, 'mean_out', loss] + elif is_tuple_list(var): + if all(is_fetch_var(v) for v in var): + res.extend(list(var)) + else: + res.append(var) + else: + raise TypeError( + "Require fetch_list[{}] 's type shall be one of (Variable, str), but received {}.". + format(i, type(var).__name__)) + + return res + def _dump_debug_info(self, program=None, trainer=None): with open(str(id(program)) + "_train_desc.prototxt", "w") as fout: fout.write(str(trainer)) diff --git a/python/paddle/fluid/tests/unittests/test_executor_check_fetch_list.py b/python/paddle/fluid/tests/unittests/test_executor_check_fetch_list.py new file mode 100644 index 00000000000..1af2009f217 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_executor_check_fetch_list.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import numpy as np +import paddle +import unittest + + +class TestCheckFetchList(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.feed = {"x": np.array([[0], [0], [1], [0]], dtype='float32')} + self.expected = np.array([[0], [1], [0]], dtype='float32') + self.build_program() + self.exe = paddle.static.Executor(paddle.CPUPlace()) + + def build_program(self): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data(name='x', shape=[4, 1], dtype='float32') + output = paddle.unique_consecutive( + x, return_inverse=True, return_counts=True, axis=0) + + self.main_program = main_program + self.fetch_list = output + + def test_with_tuple(self): + + res = self.exe.run( + self.main_program, + feed=self.feed, + fetch_list=[self.fetch_list], # support single list/tuple + return_numpy=True) + + self.assertTrue(np.array_equal(res[0], self.expected)) + + def test_with_error(self): + with self.assertRaises(TypeError): + fetch_list = [23] + res = self.exe.run(self.main_program, + feed=self.feed, + fetch_list=fetch_list) + + with self.assertRaises(TypeError): + fetch_list = [(self.fetch_list[0], 32)] + res = self.exe.run(self.main_program, + feed=self.feed, + fetch_list=fetch_list) + + +if __name__ == '__main__': + unittest.main() -- GitLab