From 023eb3f9af73a09d4445557d37d260c1d942f56a Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 20 Oct 2021 14:08:07 +0800 Subject: [PATCH] catch the generatorfunction and intercept it. (#35369) (#36536) * catch the generatorfunction and intercept it. * add test generator * add test case * refine the testcase --- .../dygraph_to_static/convert_call_func.py | 11 +++++ .../test_convert_call_generator.py | 49 +++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call_generator.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index b62c16989fb..300586969ff 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -167,6 +167,17 @@ def convert_call(func): if is_builtin(func) or is_unsupported(func): return func + if inspect.isgeneratorfunction(func): + # NOTE(xiongkun03): inspect.isfunction() will return True even though func is a generator function. + # If we don't deal generatorfunction here, we will regard it as normal function and get errors in some + # occasion. + number_of_stars = 30 + translator_logger.warn( + "\n\n" + "*" * number_of_stars + + "\nYour function:`{}` doesn't support to transform to static function because it is a generator function, it will be run as-is." + .format(func.__name__) + "\n" + "*" * number_of_stars + "\n\n") + return func + if inspect.isfunction(func): # TODO(liym27): If func is a lambda function, special conversion is needed. if func.__name__ == '': diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call_generator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call_generator.py new file mode 100644 index 00000000000..cfe9e191ed4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call_generator.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import logging +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph import ProgramTranslator +from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import CONVERSION_OPTIONS +from test_program_translator import get_source_code +from paddle.jit import to_static + + +def dyfunc_generator(): + for i in range(100): + yield paddle.to_tensor([i] * 10) + + +def main_func(): + """ Error will raise, but we only report a warning not intercept + """ + for i in dyfunc_generator(): + print(i) + + +class TestConvertGenerator(unittest.TestCase): + def test_raise_error(self): + with self.assertRaises(Exception): + to_static(main_func)() + + +if __name__ == '__main__': + unittest.main() -- GitLab