From 452be8950a30318549e7a0c0b1d9dc709280d53a Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 30 Jul 2020 10:18:10 +0800 Subject: [PATCH] [Dy2Stat-ErrorMessage] Add interface:create_origin_info_map and attach_origin_info for AST node (#25627) * Add interface:create_origin_info_map and attach_origin_info for AST node. test=develop * Fix code according to comments from reviewers. test=develop --- .../dygraph/dygraph_to_static/origin_info.py | 236 ++++++++++++++++++ .../dygraph_to_static/test_origin_info.py | 215 ++++++++++++++++ 2 files changed, 451 insertions(+) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py new file mode 100644 index 0000000000..429fa27f61 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py @@ -0,0 +1,236 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import collections +import inspect + +import gast + +# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. +ORIGI_INFO = "Original information of source code for ast node." + + +class Location(object): + """ + Location information of source code. + """ + __slots__ = ( + "filepath", + "lineno", + "col_offset", ) + + def __init__(self, filepath, lineno, col_offset=None): + self.filepath = filepath + self.lineno = lineno + self.col_offset = col_offset + + def __str__(self): + return "location: {}:{}:{}".format(self.filepath, self.lineno, + self.col_offset) + + @property + def line_location(self): + return (self.filepath, self.lineno) + + +class OriginInfo(object): + """ + Original information of source code. + """ + __slots__ = ( + "location", + "function_name", + "source_code", ) + + def __init__(self, location, function_name, source_code): + self.location = location + self.function_name = function_name + self.source_code = source_code + + def __str__(self): + return "{} \nsource_code: {} in function {}\n ".format( + self.location, self.source_code, self.function_name) + + +class OriginInfoAttacher(gast.NodeTransformer): + """ + Attach original source information to AST node according corresponding function. + """ + + def __init__(self, root, func): + self.root = root + self.func = unwrap(func) + self.filepath = inspect.getsourcefile(self.func) + self.source_code = inspect.getsource(self.func) + self.current_func = [] + + def transform(self): + source_lines, begin_lineno = inspect.getsourcelines(self.func) + begin_line = source_lines[0] + self.col_offset = len(begin_line) - len(begin_line.lstrip()) + self.source_lines = [line.strip("\n") for line in source_lines] + self.lineno_offset = begin_lineno - 1 + self.visit(self.root) + + def visit(self, node): + if isinstance(node, gast.FunctionDef): + self.current_func.append(node) + if hasattr(node, "lineno"): + self._attach_origin_info(node) + self.generic_visit(node) + + if isinstance(node, gast.FunctionDef): + self.current_func.pop() + return node + + def _attach_origin_info(self, node): + assert isinstance(node, gast.AST) + assert hasattr(node, "lineno") + + lineno = self._abs_lineno(node) + col_offset = self._abs_col_offset(node) + loc = Location(self.filepath, lineno, col_offset) + func_name = self.current_func[-1].name + code_line = self.source_lines[node.lineno - 1] + + origin_info = OriginInfo(loc, func_name, code_line) + setattr(node, ORIGI_INFO, origin_info) + + def _abs_lineno(self, node): + # NOTE(liym27): + # If the first gast.FunctionDef has decorator, its lineno is 1, which + # equals to the lineno of the first decorator node. + return self.lineno_offset + node.lineno + + def _abs_col_offset(self, node): + return self.col_offset + node.col_offset + + +def create_origin_info_map(transformed_node, static_func): + """ + Creates a original information map between transformed static function and original dygraph function. + + Args: + transformed_node(gast.AST): The AST node of transformed dygraph function with attached source information of original dygraph function. + static_func(Callable): The static function transformed by dygraph function corresponding to transformed_node. + + Returns: + The original information map. + """ + + origin_info_map = {} + static_source = inspect.getsource(static_func) + static_node = gast.parse(static_source) + static_node = attach_origin_info(static_node, static_func) + + for t_node, s_node in ast_walk(transformed_node, static_node): + assert type(t_node) == type(s_node), \ + "The node types should be the same, but received type(t_node) is {}, and type(s_node) is {}." \ + .format(type(t_node), type(s_node)) + dygraph_info = getattr(t_node, ORIGI_INFO, None) + static_info = getattr(s_node, ORIGI_INFO, None) + + if dygraph_info is None or static_info is None: + continue + static_loc = static_info.location.line_location + exist_origin_info = origin_info_map.get(static_loc) + + if exist_origin_info is not None: + if exist_origin_info.location.lineno >= dygraph_info.location.lineno: + continue + if exist_origin_info.location.col_offset <= dygraph_info.location.col_offset: + continue + + origin_info_map[static_loc] = dygraph_info + + return origin_info_map + + +def attach_origin_info(ast_node, func): + """ + Attach original source information to AST node according corresponding function. + + Args: + ast_node(gast.AST): The AST node to attach original source information. + func(Callable): The corresponding function of ast_node. Parse the original information from this function. + + Returns: + An AST node attached original source information. + """ + resolver = OriginInfoAttacher(ast_node, func) + resolver.transform() + return ast_node + + +# NOTE: inspect.unwrap() exits in PY3 but not in PY2. +def unwrap(func): + def _is_wrapped(f): + return hasattr(f, '__wrapped__') + + unwrapped_f = func + while (_is_wrapped(unwrapped_f)): + unwrapped_f = unwrapped_f.__wrapped__ + + return unwrapped_f + + +def ast_walk(transformed_node, static_node): + """ + Recursively yield all descendant nodes in the trees starting at transformed_node and static_node (including itself) in parallel. + + NOTE(liym27): + Function ast.walk is not used because it yield all descendant nodes in no specified order. + """ + + def _as_list(x): + if x is None: + return [] + return list(x) if isinstance(x, collections.Sequence) else [x] + + transformed_node_list = _as_list(transformed_node) + static_node_list = _as_list(static_node) + + while transformed_node_list: + assert len(transformed_node_list) == len(static_node_list) + t_node = transformed_node_list.pop() + s_node = static_node_list.pop() + if type(t_node) != type(s_node): + # NOTE(liym27): + # Node types should be strictly required, but there is no strict distinction between gast.Load and gast.Param + # in the ast transformation process. + if isinstance(t_node, (gast.Load, gast.Param)) or isinstance( + s_node, (gast.Load, gast.Param)): + continue + + assert type(t_node) == type(s_node), \ + "The node types should be the same, but received type(t_node) is {}, and type(s_node) is {}."\ + .format(type(t_node), type(s_node)) + + yield t_node, s_node + + for field in t_node._fields: + t_node_child = getattr(t_node, field) + s_node_child = getattr(s_node, field) + + if isinstance(t_node_child, gast.AST): + transformed_node_list.append(t_node_child) + static_node_list.append(s_node_child) + elif isinstance(t_node_child, (list, tuple)): + assert len(t_node_child) == len(s_node_child) + for d_item, s_item in zip(t_node_child, s_node_child): + if isinstance(d_item, gast.AST): + transformed_node_list.append(d_item) + static_node_list.append(s_item) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py new file mode 100644 index 0000000000..631655ec74 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py @@ -0,0 +1,215 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst +from paddle.fluid.dygraph.dygraph_to_static.origin_info import * +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func +from paddle.fluid.dygraph.jit import declarative + + +def simple_func(x): + y = x + 1 + return y + + +def nested_func(x): + def f1(a): + return a + + result = f1(x) + return result + + +@declarative +def decorated_func(x): + return x + + +@declarative +@declarative +def decorated_func2(x): + return x + + +class TestOriginInfo(unittest.TestCase): + def setUp(self): + self.set_test_func() + self.dygraph_func = unwrap(self.func) + self.dygraph_filepath = inspect.getfile(self.dygraph_func) + self.source_code = inspect.getsource(self.dygraph_func) + lines, self.start_lineno = inspect.getsourcelines(self.dygraph_func) + lines = [line.strip("\n") for line in lines] + self.lines = [line for line in lines + if line != ""] # Delete empty lines + + self.set_static_lineno() + self.set_dygraph_info() + + def set_test_func(self): + self.func = simple_func + + def set_static_lineno(self): + self.static_abs_lineno_list = [2, 3, 4] + + def set_dygraph_info(self): + self.line_num = 3 + self.line_index_list = [0, 1, 2] + self.dy_rel_lineno_list = [0, 1, 2] + self.dy_abs_col_offset = [0, 4, 4] + self.dy_func_name = [self.dygraph_func.__name__] * 3 + + def set_origin_info_list(self, dygraph_ast): + assert isinstance(dygraph_ast, gast.Module) + self.transformed_node_list = [ + dygraph_ast.body[0], dygraph_ast.body[0].body[0], + dygraph_ast.body[0].body[1] + ] + + def _get_OriginInfo_map(self): + # step1 + dygraph_ast = gast.parse(self.source_code) + dygraph_ast = attach_origin_info(dygraph_ast, self.dygraph_func) + + # step2 + transformed_ast = DygraphToStaticAst().get_static_ast(dygraph_ast).node + + # step3 + self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func) + info_map = create_origin_info_map(dygraph_ast, self.static_func) + + return info_map + + def test_origin_info_map(self): + self.set_static_lineno() + origin_info_map = self._get_OriginInfo_map() + static_filepath = inspect.getfile(self.static_func) + start_lineno = self.start_lineno + dygraph_abs_lineno_list = [ + start_lineno + lineno for lineno in self.dy_rel_lineno_list + ] + + for i in range(self.line_num): + static_lineno = self.static_abs_lineno_list[i] + staic_loc = Location(static_filepath, static_lineno) + self.assertIn(staic_loc.line_location, origin_info_map) + + dy_lineno = dygraph_abs_lineno_list[i] + dy_col_offset = self.dy_abs_col_offset[i] + line_idx = self.line_index_list[i] + code = self.lines[line_idx] + origin_info = OriginInfo( + Location(self.dygraph_filepath, dy_lineno, dy_col_offset), + self.dy_func_name[i], code) + self.assertEqual( + str(origin_info_map[staic_loc.line_location]), str(origin_info)) + + def test_attach_origin_info(self): + dygraph_ast = gast.parse(self.source_code) + dygraph_ast = attach_origin_info(dygraph_ast, self.dygraph_func) + self.set_origin_info_list(dygraph_ast) + start_lineno = self.start_lineno + + filepath = inspect.getfile(self.dygraph_func) + + for i in range(self.line_num): + node = self.transformed_node_list[i] + origin_info = getattr(node, ORIGI_INFO) + dy_rel_lineno = self.dy_rel_lineno_list[i] + dy_abs_lineno = start_lineno + dy_rel_lineno + dy_col_offset = self.dy_abs_col_offset[i] + func_name = self.dy_func_name[i] + line_idx = self.line_index_list[i] + code = self.lines[line_idx] + self.assertEqual(origin_info.location.filepath, filepath) + self.assertEqual(origin_info.location.lineno, dy_abs_lineno) + self.assertEqual(origin_info.location.col_offset, dy_col_offset) + self.assertEqual(origin_info.function_name, func_name) + self.assertEqual(origin_info.source_code, code) + + +class TestOriginInfoWithNestedFunc(TestOriginInfo): + def set_test_func(self): + self.func = nested_func + + def set_static_lineno(self): + self.static_abs_lineno_list = [2, 4, 5, 6, 7] + + def set_dygraph_info(self): + self.line_num = 5 + self.line_index_list = [0, 1, 2, 3, 4] + self.dy_rel_lineno_list = [0, 1, 2, 4, 5] + self.dy_abs_col_offset = [0, 4, 8, 4, 4] + self.dy_func_name = [self.dygraph_func.__name__] + \ + ["f1"] * 2 + \ + [self.dygraph_func.__name__] * 2 + + def set_origin_info_list(self, dygraph_ast): + assert isinstance(dygraph_ast, gast.Module) + self.transformed_node_list = [ + dygraph_ast.body[0], dygraph_ast.body[0].body[0], + dygraph_ast.body[0].body[0].body[0], dygraph_ast.body[0].body[1], + dygraph_ast.body[0].body[2] + ] + + +class TestOriginInfoWithDecoratedFunc(TestOriginInfo): + def set_test_func(self): + self.func = decorated_func + + def set_static_lineno(self): + self.static_abs_lineno_list = [2, 3] + + def set_dygraph_info(self): + self.line_num = 2 + self.line_index_list = [0, 2] + self.dy_rel_lineno_list = [0, 2] + self.dy_abs_col_offset = [0, 4] + self.dy_func_name = [self.dygraph_func.__name__] * self.line_num + + def set_origin_info_list(self, dygraph_ast): + assert isinstance(dygraph_ast, gast.Module) + self.transformed_node_list = [ + dygraph_ast.body[0], + dygraph_ast.body[0].body[0], + ] + + +class TestOriginInfoWithDecoratedFunc2(TestOriginInfo): + def set_test_func(self): + self.func = decorated_func2 + + def set_static_lineno(self): + self.static_abs_lineno_list = [2, 3] + + def set_dygraph_info(self): + self.line_num = 2 + self.line_index_list = [0, 3] + self.dy_rel_lineno_list = [0, 3] + self.dy_abs_col_offset = [0, 4] + self.dy_func_name = [self.dygraph_func.__name__] * self.line_num + + def set_origin_info_list(self, dygraph_ast): + assert isinstance(dygraph_ast, gast.Module) + self.transformed_node_list = [ + dygraph_ast.body[0], + dygraph_ast.body[0].body[0], + ] + + +if __name__ == '__main__': + unittest.main() -- GitLab