未验证 提交 452be895 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat-ErrorMessage] Add interface:create_origin_info_map and...

[Dy2Stat-ErrorMessage]  Add interface:create_origin_info_map and attach_origin_info for AST node (#25627)

* Add interface:create_origin_info_map and attach_origin_info for AST node. test=develop

* Fix code according to comments from reviewers. test=develop
上级 c2a21ca9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import collections
import inspect
import gast
# NOTE(liym27): ORIGI_INFO is the attribute name under which the original
# source information is stored on an AST node. The name contains spaces, so it
# cannot be read with the `.` operator; use `getattr(ast_node, ORIGI_INFO)`.
ORIGI_INFO = "Original information of source code for ast node."
class Location(object):
    """
    Position of a piece of source code.

    Attributes:
        filepath: Path of the file that contains the code.
        lineno: Line number inside that file.
        col_offset: Column offset inside the line; None when not given.
    """
    __slots__ = ("filepath", "lineno", "col_offset")

    def __init__(self, filepath, lineno, col_offset=None):
        self.filepath, self.lineno, self.col_offset = (filepath, lineno,
                                                       col_offset)

    def __str__(self):
        fields = (self.filepath, self.lineno, self.col_offset)
        return "location: {}:{}:{}".format(*fields)

    @property
    def line_location(self):
        """The (filepath, lineno) pair, i.e. the location without the column."""
        return self.filepath, self.lineno
class OriginInfo(object):
    """
    Origin record of a piece of source code.

    Attributes:
        location: Where the code lives (any object; rendered with str()).
        function_name: Name of the function the code belongs to.
        source_code: The text of the code itself.
    """
    __slots__ = ("location", "function_name", "source_code")

    def __init__(self, location, function_name, source_code):
        self.location, self.function_name, self.source_code = (
            location, function_name, source_code)

    def __str__(self):
        template = "{} \nsource_code: {} in function {}\n "
        return template.format(self.location, self.source_code,
                               self.function_name)
class OriginInfoAttacher(gast.NodeTransformer):
    """
    Walk an AST and store an OriginInfo on every node that has a `lineno`.

    The information (file path, line, column, enclosing function name and the
    text of the source line) is derived from the source of `func`, after
    following its `__wrapped__` chain with `unwrap`.
    """

    def __init__(self, root, func):
        self.root = root
        self.func = unwrap(func)
        self.filepath = inspect.getsourcefile(self.func)
        self.source_code = inspect.getsource(self.func)
        # Stack of the gast.FunctionDef nodes enclosing the node being visited.
        self.current_func = []

    def transform(self):
        """Compute the line/column offsets of `func` and annotate the tree."""
        raw_lines, first_lineno = inspect.getsourcelines(self.func)
        header = raw_lines[0]
        # Indentation width of the first source line of the function.
        self.col_offset = len(header) - len(header.lstrip())
        self.source_lines = [text.strip("\n") for text in raw_lines]
        # Added to a node's lineno to obtain the line number in the file.
        self.lineno_offset = first_lineno - 1
        self.visit(self.root)

    def visit(self, node):
        """Annotate `node`, then recurse, tracking the enclosing functions."""
        entering_function = isinstance(node, gast.FunctionDef)
        if entering_function:
            self.current_func.append(node)

        if hasattr(node, "lineno"):
            self._attach_origin_info(node)
        self.generic_visit(node)

        if entering_function:
            self.current_func.pop()
        return node

    def _attach_origin_info(self, node):
        """Build the OriginInfo of `node` and store it under ORIGI_INFO."""
        assert isinstance(node, gast.AST)
        assert hasattr(node, "lineno")

        location = Location(self.filepath,
                            self._abs_lineno(node),
                            self._abs_col_offset(node))
        # The innermost function currently being visited owns this node.
        enclosing_name = self.current_func[-1].name
        origin_info = OriginInfo(location, enclosing_name,
                                 self.source_lines[node.lineno - 1])
        setattr(node, ORIGI_INFO, origin_info)

    def _abs_lineno(self, node):
        # NOTE(liym27):
        #   If the first gast.FunctionDef has decorator, its lineno is 1, which
        #   equals to the lineno of the first decorator node.
        return node.lineno + self.lineno_offset

    def _abs_col_offset(self, node):
        return node.col_offset + self.col_offset
def create_origin_info_map(transformed_node, static_func):
    """
    Build a map from locations in the static function to the origin
    information of the dygraph function.

    Args:
        transformed_node(gast.AST): The AST node of transformed dygraph function with attached source information of original dygraph function.
        static_func(Callable): The static function transformed by dygraph function corresponding to transformed_node.

    Returns:
        dict: Keys are `(filepath, lineno)` tuples of the static function, values
        are the OriginInfo objects taken from the nodes of `transformed_node`.
    """
    # Parse the generated static source and annotate it with its own locations.
    static_tree = attach_origin_info(
        gast.parse(inspect.getsource(static_func)), static_func)

    origin_info_map = {}
    for t_node, s_node in ast_walk(transformed_node, static_tree):
        assert type(t_node) == type(s_node), \
            "The node types should be the same, but received type(t_node) is {}, and type(s_node) is {}." \
            .format(type(t_node), type(s_node))

        dygraph_info = getattr(t_node, ORIGI_INFO, None)
        static_info = getattr(s_node, ORIGI_INFO, None)
        # Only node pairs annotated on both sides can contribute an entry.
        if dygraph_info is None or static_info is None:
            continue

        key = static_info.location.line_location
        previous = origin_info_map.get(key)
        # An already recorded entry is kept unless it has a smaller lineno AND
        # a larger col_offset than the candidate.
        if previous is not None and (
                previous.location.lineno >= dygraph_info.location.lineno or
                previous.location.col_offset <=
                dygraph_info.location.col_offset):
            continue

        origin_info_map[key] = dygraph_info

    return origin_info_map
def attach_origin_info(ast_node, func):
    """
    Annotate the nodes of `ast_node` with source information taken from `func`.

    Args:
        ast_node(gast.AST): The AST node to attach original source information.
        func(Callable): The corresponding function of ast_node. Parse the original information from this function.

    Returns:
        The same `ast_node` object, annotated by OriginInfoAttacher.
    """
    OriginInfoAttacher(ast_node, func).transform()
    return ast_node
# NOTE: inspect.unwrap() exists in PY3 but not in PY2.
def unwrap(func):
    """
    Follow the `__wrapped__` chain of `func` and return the innermost object.

    Returns `func` itself when it has no `__wrapped__` attribute.
    """
    innermost = func
    while hasattr(innermost, '__wrapped__'):
        innermost = getattr(innermost, '__wrapped__')
    return innermost
def ast_walk(transformed_node, static_node):
    """
    Yield pairs of corresponding nodes from two trees walked in parallel.

    Both arguments may be a single node, a sequence of nodes, or None. Nodes
    are taken from the end of an internal work list, and the children of each
    yielded pair are pushed in the order of `t_node._fields`.

    NOTE(liym27):
        Function ast.walk is not used because it yield all descendant nodes in no specified order.

    Args:
        transformed_node: Root node(s) of the transformed tree.
        static_node: Root node(s) of the static tree, expected to have the
            same structure as `transformed_node`.

    Yields:
        tuple: `(t_node, s_node)` pairs of the same type.

    Raises:
        AssertionError: If the two trees differ in shape or in node types
            (other than gast.Load / gast.Param mismatches, which are skipped).
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc` on PY3, while PY2 only provides `collections.Sequence`.
    try:
        from collections.abc import Sequence
    except ImportError:
        from collections import Sequence

    def _as_list(x):
        if x is None:
            return []
        return list(x) if isinstance(x, Sequence) else [x]

    transformed_node_list = _as_list(transformed_node)
    static_node_list = _as_list(static_node)

    while transformed_node_list:
        assert len(transformed_node_list) == len(static_node_list)
        t_node = transformed_node_list.pop()
        s_node = static_node_list.pop()
        if type(t_node) != type(s_node):
            # NOTE(liym27):
            # Node types should be strictly required, but there is no strict distinction between gast.Load and gast.Param
            # in the ast transformation process.
            if isinstance(t_node, (gast.Load, gast.Param)) or isinstance(
                    s_node, (gast.Load, gast.Param)):
                continue

            # Reached only on a real type mismatch, so this always fails.
            assert type(t_node) == type(s_node), \
                "The node types should be the same, but received type(t_node) is {}, and type(s_node) is {}."\
                .format(type(t_node), type(s_node))

        yield t_node, s_node

        for field in t_node._fields:
            t_node_child = getattr(t_node, field)
            s_node_child = getattr(s_node, field)

            if isinstance(t_node_child, gast.AST):
                transformed_node_list.append(t_node_child)
                static_node_list.append(s_node_child)
            elif isinstance(t_node_child, (list, tuple)):
                assert len(t_node_child) == len(s_node_child)
                for d_item, s_item in zip(t_node_child, s_node_child):
                    # Non-AST list items (e.g. plain strings) are not walked.
                    if isinstance(d_item, gast.AST):
                        transformed_node_list.append(d_item)
                        static_node_list.append(s_item)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.origin_info import *
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.jit import declarative
# Fixture for TestOriginInfo. The tests assert the relative line numbers,
# column offsets and source text of this function, so its layout must not
# change (keep comments outside of the definition).
def simple_func(x):
    y = x + 1
    return y
# Fixture for TestOriginInfoWithNestedFunc. The tests assert the relative line
# numbers (0, 1, 2, 4, 5), column offsets and source text of this function, so
# its layout, including the blank line, must not change.
def nested_func(x):
    def f1(a):
        return a

    result = f1(x)
    return result
# Fixture for TestOriginInfoWithDecoratedFunc. The tests assert the relative
# line numbers and source text of this function starting at the decorator
# line, so its layout must not change.
@declarative
def decorated_func(x):
    return x
# Fixture for TestOriginInfoWithDecoratedFunc2 (two stacked decorators). The
# tests assert the relative line numbers and source text of this function
# starting at the first decorator line, so its layout must not change.
@declarative
@declarative
def decorated_func2(x):
    return x
class TestOriginInfo(unittest.TestCase):
    """
    Checks attach_origin_info and create_origin_info_map on `simple_func`.

    Subclasses override the `set_*` hooks to describe another function under
    test and the values expected for it.
    """

    def setUp(self):
        self.set_test_func()
        self.dygraph_func = unwrap(self.func)
        self.dygraph_filepath = inspect.getfile(self.dygraph_func)
        self.source_code = inspect.getsource(self.dygraph_func)

        raw_lines, self.start_lineno = inspect.getsourcelines(
            self.dygraph_func)
        stripped = (raw.strip("\n") for raw in raw_lines)
        # Empty lines are dropped; `line_index_list` indexes into this list.
        self.lines = [text for text in stripped if text != ""]

        self.set_static_lineno()
        self.set_dygraph_info()

    def set_test_func(self):
        """Select the function under test."""
        self.func = simple_func

    def set_static_lineno(self):
        """Line numbers of the static function expected as keys of the map."""
        self.static_abs_lineno_list = [2, 3, 4]

    def set_dygraph_info(self):
        """Expected origin information, one entry per checked node."""
        self.line_num = 3
        self.line_index_list = [0, 1, 2]
        self.dy_rel_lineno_list = [0, 1, 2]
        self.dy_abs_col_offset = [0, 4, 4]
        self.dy_func_name = [self.dygraph_func.__name__] * 3

    def set_origin_info_list(self, dygraph_ast):
        """Collect the AST nodes whose origin information is checked."""
        assert isinstance(dygraph_ast, gast.Module)
        func_def = dygraph_ast.body[0]
        self.transformed_node_list = [
            func_def,
            func_def.body[0],
            func_def.body[1],
        ]

    def _get_OriginInfo_map(self):
        # step1: parse the dygraph source and annotate it.
        dygraph_ast = attach_origin_info(
            gast.parse(self.source_code), self.dygraph_func)

        # step2: transform the annotated AST to its static counterpart.
        transformed_ast = DygraphToStaticAst().get_static_ast(dygraph_ast).node

        # step3: build the static function and map it back to the dygraph code.
        self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func)
        return create_origin_info_map(dygraph_ast, self.static_func)

    def test_origin_info_map(self):
        self.set_static_lineno()
        origin_info_map = self._get_OriginInfo_map()
        static_filepath = inspect.getfile(self.static_func)

        first_lineno = self.start_lineno
        expected_abs_linenos = [
            first_lineno + rel for rel in self.dy_rel_lineno_list
        ]

        for idx in range(self.line_num):
            static_loc = Location(static_filepath,
                                  self.static_abs_lineno_list[idx])
            key = static_loc.line_location
            self.assertIn(key, origin_info_map)

            expected_location = Location(self.dygraph_filepath,
                                         expected_abs_linenos[idx],
                                         self.dy_abs_col_offset[idx])
            expected_code = self.lines[self.line_index_list[idx]]
            expected_info = OriginInfo(expected_location,
                                       self.dy_func_name[idx], expected_code)
            # Origin infos are compared through their string rendering.
            self.assertEqual(str(origin_info_map[key]), str(expected_info))

    def test_attach_origin_info(self):
        dygraph_ast = attach_origin_info(
            gast.parse(self.source_code), self.dygraph_func)
        self.set_origin_info_list(dygraph_ast)

        first_lineno = self.start_lineno
        expected_filepath = inspect.getfile(self.dygraph_func)

        for idx in range(self.line_num):
            actual = getattr(self.transformed_node_list[idx], ORIGI_INFO)

            expected_lineno = first_lineno + self.dy_rel_lineno_list[idx]
            expected_col_offset = self.dy_abs_col_offset[idx]
            expected_func_name = self.dy_func_name[idx]
            expected_code = self.lines[self.line_index_list[idx]]

            self.assertEqual(actual.location.filepath, expected_filepath)
            self.assertEqual(actual.location.lineno, expected_lineno)
            self.assertEqual(actual.location.col_offset, expected_col_offset)
            self.assertEqual(actual.function_name, expected_func_name)
            self.assertEqual(actual.source_code, expected_code)
class TestOriginInfoWithNestedFunc(TestOriginInfo):
    """Runs the TestOriginInfo checks on `nested_func` and its inner `f1`."""

    def set_test_func(self):
        self.func = nested_func

    def set_static_lineno(self):
        self.static_abs_lineno_list = [2, 4, 5, 6, 7]

    def set_dygraph_info(self):
        self.line_num = 5
        self.line_index_list = [0, 1, 2, 3, 4]
        self.dy_rel_lineno_list = [0, 1, 2, 4, 5]
        self.dy_abs_col_offset = [0, 4, 8, 4, 4]
        # The second and third entries belong to the inner function `f1`.
        outer_name = self.dygraph_func.__name__
        self.dy_func_name = [outer_name, "f1", "f1", outer_name, outer_name]

    def set_origin_info_list(self, dygraph_ast):
        assert isinstance(dygraph_ast, gast.Module)
        outer_def = dygraph_ast.body[0]
        inner_def = outer_def.body[0]
        self.transformed_node_list = [
            outer_def,
            inner_def,
            inner_def.body[0],
            outer_def.body[1],
            outer_def.body[2],
        ]
class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
    """Runs the TestOriginInfo checks on the singly decorated `decorated_func`."""

    def set_test_func(self):
        self.func = decorated_func

    def set_static_lineno(self):
        self.static_abs_lineno_list = [2, 3]

    def set_dygraph_info(self):
        self.line_num = 2
        # Index 1 (the `def` line) is not checked.
        self.line_index_list = [0, 2]
        self.dy_rel_lineno_list = [0, 2]
        self.dy_abs_col_offset = [0, 4]
        func_name = self.dygraph_func.__name__
        self.dy_func_name = [func_name] * self.line_num

    def set_origin_info_list(self, dygraph_ast):
        assert isinstance(dygraph_ast, gast.Module)
        func_def = dygraph_ast.body[0]
        self.transformed_node_list = [func_def, func_def.body[0]]
class TestOriginInfoWithDecoratedFunc2(TestOriginInfo):
    """Runs the TestOriginInfo checks on the doubly decorated `decorated_func2`."""

    def set_test_func(self):
        self.func = decorated_func2

    def set_static_lineno(self):
        self.static_abs_lineno_list = [2, 3]

    def set_dygraph_info(self):
        self.line_num = 2
        # Indices 1 and 2 (second decorator and `def` line) are not checked.
        self.line_index_list = [0, 3]
        self.dy_rel_lineno_list = [0, 3]
        self.dy_abs_col_offset = [0, 4]
        func_name = self.dygraph_func.__name__
        self.dy_func_name = [func_name] * self.line_num

    def set_origin_info_list(self, dygraph_ast):
        assert isinstance(dygraph_ast, gast.Module)
        func_def = dygraph_ast.body[0]
        self.transformed_node_list = [func_def, func_def.body[0]]
# Run the test cases when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册