From 01e2679974a995534657c42924912b8d0dc6f0a4 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Mon, 13 Dec 2021 14:10:41 +0800 Subject: [PATCH] [Dy2stat]Remove all comments of users code when dy2stat (#38003) (#38039) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 动转静时,将函数中的注释行进行删除。 有函数体外的注释行,使用gast库进行func2ast转换时会导致出错,本PR之后将注释行(#开头的行)进行了删除 --- .../fluid/dygraph/dygraph_to_static/utils.py | 6 +++++- .../dygraph_to_static/test_program_translator.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 320f2ef5b33..cccf37f8fb9 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -546,7 +546,11 @@ def func_to_source_code(function, dedent=True): raise TypeError( "The type of 'function' should be a function or method, but received {}.". format(type(function).__name__)) - source_code = inspect.getsource(function) + source_code_list, _ = inspect.getsourcelines(function) + source_code_list = [ + line for line in source_code_list if not line.lstrip().startswith('#') + ] + source_code = ''.join(source_code_list) if dedent: source_code = textwrap.dedent(source_code) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index c08a8d350f8..d4c41781078 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -26,6 +26,7 @@ import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code from ifelse_simple_func import dyfunc_with_if_else @@ -344,5 +345,18 @@ class TestFunctionTrainEvalMode(unittest.TestCase): net.foo.train() +class TestRemoveCommentInDy2St(unittest.TestCase): + def func_with_comment(self): + # Comment1 + x = paddle.to_tensor([1, 2, 3]) + # Comment2 + # Comment3 + y = paddle.to_tensor([4, 5, 6]) + + def test_remove_comment(self): + code_string = func_to_source_code(self.func_with_comment) + self.assertEqual('#' not in code_string, True) + + if __name__ == '__main__': unittest.main() -- GitLab