diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index c4df01c4c7654a7874a4279869daf813efc543f1..51e85901e7d558c697382c5d433c2cc7661c3213 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -547,7 +547,11 @@ def func_to_source_code(function, dedent=True): raise TypeError( "The type of 'function' should be a function or method, but received {}.". format(type(function).__name__)) - source_code = inspect.getsource(function) + source_code_list, _ = inspect.getsourcelines(function) + source_code_list = [ + line for line in source_code_list if not line.lstrip().startswith('#') + ] + source_code = ''.join(source_code_list) if dedent: source_code = textwrap.dedent(source_code) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index c08a8d350f8aa83eb2c7e2eae8726917c02bba4f..d4c41781078b760feadbb5e2c928968414d17063 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -26,6 +26,7 @@ import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code from ifelse_simple_func import dyfunc_with_if_else @@ -344,5 +345,18 @@ class TestFunctionTrainEvalMode(unittest.TestCase): net.foo.train() +class TestRemoveCommentInDy2St(unittest.TestCase): + def func_with_comment(self): + # Comment1 + x = paddle.to_tensor([1, 2, 3]) + # Comment2 + # Comment3 + y = paddle.to_tensor([4, 5, 6]) + + def test_remove_comment(self): + code_string = func_to_source_code(self.func_with_comment) + self.assertEqual('#' not in code_string, True) + + if __name__ == '__main__': unittest.main()