未验证 提交 c10aa24f 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Add BaseTransformer for dy2st error message (#44054)

* Add BaseTransformer for dy2st error message

* Fix return_transformer error

* Polish dy2st error info in runtime

* Fix UT error

* Polish runtime error code
上级 7a212593
......@@ -18,9 +18,10 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class AssertTransformer(gast.NodeTransformer):
class AssertTransformer(BaseTransformer):
"""
A class transforms python assert to convert_assert.
"""
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
# See details in https://github.com/serge-sans-paille/gast/
import os
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import EarlyReturnTransformer
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
......@@ -58,7 +59,7 @@ def apply_optimization(transformers):
transformers.insert(3, BreakTransformOptimizer)
class DygraphToStaticAst(gast.NodeTransformer):
class DygraphToStaticAst(BaseTransformer):
"""
Main class to transform Dygraph to Static Graph
"""
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.origin_info import ORIGI_INFO
class BaseTransformer(gast.NodeTransformer):
def visit(self, node):
if not isinstance(node, gast.AST):
msg = ('Expected "gast.AST", but got "{}".').format(type(node))
raise ValueError(msg)
origin_info = getattr(node, ORIGI_INFO, None)
result = super(BaseTransformer, self).visit(node)
iter_result = result
if iter_result is not node and iter_result is not None:
if not isinstance(iter_result, (list, tuple)):
iter_result = (iter_result, )
if origin_info is not None:
for n in iter_result:
setattr(n, ORIGI_INFO, origin_info)
return result
......@@ -17,9 +17,10 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class BasicApiTransformer(gast.NodeTransformer):
class BasicApiTransformer(BaseTransformer):
"""
Class to transform basic API from dygraph to static graph.
"""
......@@ -98,7 +99,7 @@ class BasicApiTransformer(gast.NodeTransformer):
return False
class ToTensorTransformer(gast.NodeTransformer):
class ToTensorTransformer(BaseTransformer):
"""
Class to transform paddle.to_tensor and paddle.to_variable to paddle.assign
"""
......
......@@ -21,6 +21,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
__all__ = ['BreakContinueTransformer']
......@@ -28,7 +29,7 @@ BREAK_NAME_PREFIX = '__break'
CONTINUE_NAME_PREFIX = '__continue'
class ForToWhileTransformer(gast.NodeTransformer):
class ForToWhileTransformer(BaseTransformer):
"""
Transform python for loop into while loop and add condition node in the
loop test
......
......@@ -18,11 +18,12 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
PDB_SET = "pdb.set_trace"
class CallTransformer(gast.NodeTransformer):
class CallTransformer(BaseTransformer):
"""
This class transforms function calls into Static Graph Ast.
"""
......
......@@ -17,9 +17,10 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class CastTransformer(gast.NodeTransformer):
class CastTransformer(BaseTransformer):
"""
This class transforms type casting into Static Graph Ast.
"""
......
......@@ -16,9 +16,10 @@ from __future__ import print_function
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class EarlyReturnTransformer(gast.NodeTransformer):
class EarlyReturnTransformer(BaseTransformer):
"""
Transform if/else return statement of Dygraph into Static Graph.
"""
......
......@@ -274,19 +274,25 @@ class ErrorData(object):
bottom_error_message = error_value_lines[empty_line_idx + 1:]
revise_suggestion = self._create_revise_suggestion(bottom_error_message)
user_filepath = ''
error_traceback = []
user_code_traceback_index = []
pattern = 'File "(?P<filepath>.+)", line (?P<lineno>.+), in (?P<function_name>.+)'
# Distinguish user code and framework code using static_info_map
static_info_map = {}
for k, v in self.origin_info_map.items():
origin_filepath = v.location.filepath
origin_lineno = v.location.lineno
static_info_map[(origin_filepath, origin_lineno)] = k
for i in range(0, len(error_value_lines_strip), 2):
if error_value_lines_strip[i].startswith("File "):
re_result = re.search(pattern, error_value_lines_strip[i])
tmp_filepath, lineno_str, function_name = re_result.groups()
code = error_value_lines_strip[
i + 1] if i + 1 < len(error_value_lines_strip) else ''
if i == 0:
user_filepath = tmp_filepath
if tmp_filepath == user_filepath:
if static_info_map.get((tmp_filepath, int(lineno_str))):
user_code_traceback_index.append(len(error_traceback))
error_traceback.append(
......
......@@ -19,9 +19,10 @@ import warnings
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class GradTransformer(gast.NodeTransformer):
class GradTransformer(BaseTransformer):
"""
A class transforms dygraph paddle.grad to static graph paddle.gradients. The
transformation is applied to support double grad mode.
......
......@@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
......@@ -41,7 +42,7 @@ SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args'
class IfElseTransformer(gast.NodeTransformer):
class IfElseTransformer(BaseTransformer):
"""
Transform if/else statement of Dygraph into Static Graph.
"""
......
......@@ -21,11 +21,11 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class ListTransformer(gast.NodeTransformer):
class ListTransformer(BaseTransformer):
"""
This class transforms python list used in control flow into Static Graph Ast.
"""
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
cmpop_type_to_str = {
gast.Eq: "==",
......@@ -35,7 +36,7 @@ def cmpop_node_to_str(node):
return cmpop_type_to_str[type(node)]
class LogicalTransformer(gast.NodeTransformer):
class LogicalTransformer(BaseTransformer):
"""
Transform python boolean op into Paddle logical op.
......
......@@ -32,6 +32,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_un
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node, create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
__all__ = ['LoopTransformer', 'NameVisitor']
......@@ -566,7 +567,7 @@ class NameVisitor(gast.NodeVisitor):
return loop_vars - removed_vars
class LoopTransformer(gast.NodeTransformer):
class LoopTransformer(BaseTransformer):
"""
This class transforms python while/for statement into Static Graph Ast
"""
......
......@@ -17,9 +17,10 @@ from __future__ import print_function
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class PrintTransformer(gast.NodeTransformer):
class PrintTransformer(BaseTransformer):
"""
This class transforms python print function to fluid.layers.Print.
"""
......
......@@ -21,6 +21,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
__all__ = [
'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer'
......@@ -57,7 +58,7 @@ def get_return_size(return_node):
return return_length
class ReplaceReturnNoneTransformer(gast.NodeTransformer):
class ReplaceReturnNoneTransformer(BaseTransformer):
"""
Replace 'return None' to 'return' because 'None' cannot be a valid input
in control flow. In ReturnTransformer single 'Return' will be appended no
......@@ -133,7 +134,7 @@ class ReturnAnalysisVisitor(gast.NodeVisitor):
return self.max_return_length[func_node]
class ReturnTransformer(gast.NodeTransformer):
class ReturnTransformer(BaseTransformer):
"""
Transforms return statements into equivalent python statements containing
only one return statement at last. The basics idea is using a return value
......
......@@ -18,9 +18,10 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class TensorShapeTransformer(gast.NodeTransformer):
class TensorShapeTransformer(BaseTransformer):
"""
This class transforms variable.shape into Static Graph Ast.
All 'xxx.shape' will be converted int '_jst.Shape(x)'.
......
......@@ -644,7 +644,12 @@ def ast_to_source_code(ast_node):
type(ast_node))
if isinstance(ast_node, gast.AST):
ast_node = gast.gast_to_ast(ast_node)
source_code = astor.to_source(ast_node)
# Do not wrap lines even if they are too long
def pretty_source(source):
return ''.join(source)
source_code = astor.to_source(ast_node, pretty_source=pretty_source)
return source_code
......
......@@ -196,14 +196,17 @@ class TestDygraphToStaticCode(unittest.TestCase):
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
answer = get_source_code(StaticCode1.dyfunc_with_if_else)
self.assertEqual(answer, code)
self.assertEqual(
answer.replace('\n', '').replace(' ', ''),
code.replace('\n', '').replace(' ', ''))
def test_program_translator(self):
answer = get_source_code(StaticCode2.dyfunc_with_if_else)
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
# print(code)
self.assertEqual(answer, code)
self.assertEqual(
answer.replace('\n', '').replace(' ', ''),
code.replace('\n', '').replace(' ', ''))
class TestEnableDeclarative(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册