未验证 提交 c10aa24f 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Add BaseTransformer for dy2st error message (#44054)

* Add BaseTransformer for dy2st error message

* Fix return_transformer error

* Polish dy2st error info in runtime

* Fix UT error

* Polish runtime error code
上级 7a212593
...@@ -18,9 +18,10 @@ from paddle.utils import gast ...@@ -18,9 +18,10 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class AssertTransformer(gast.NodeTransformer): class AssertTransformer(BaseTransformer):
""" """
A class transforms python assert to convert_assert. A class transforms python assert to convert_assert.
""" """
......
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
# See details in https://github.com/serge-sans-paille/gast/ # See details in https://github.com/serge-sans-paille/gast/
import os import os
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import EarlyReturnTransformer from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import EarlyReturnTransformer
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
...@@ -58,7 +59,7 @@ def apply_optimization(transformers): ...@@ -58,7 +59,7 @@ def apply_optimization(transformers):
transformers.insert(3, BreakTransformOptimizer) transformers.insert(3, BreakTransformOptimizer)
class DygraphToStaticAst(gast.NodeTransformer): class DygraphToStaticAst(BaseTransformer):
""" """
Main class to transform Dygraph to Static Graph Main class to transform Dygraph to Static Graph
""" """
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.origin_info import ORIGI_INFO
class BaseTransformer(gast.NodeTransformer):
def visit(self, node):
if not isinstance(node, gast.AST):
msg = ('Expected "gast.AST", but got "{}".').format(type(node))
raise ValueError(msg)
origin_info = getattr(node, ORIGI_INFO, None)
result = super(BaseTransformer, self).visit(node)
iter_result = result
if iter_result is not node and iter_result is not None:
if not isinstance(iter_result, (list, tuple)):
iter_result = (iter_result, )
if origin_info is not None:
for n in iter_result:
setattr(n, ORIGI_INFO, origin_info)
return result
...@@ -17,9 +17,10 @@ from paddle.utils import gast ...@@ -17,9 +17,10 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils from paddle.fluid.dygraph.dygraph_to_static import utils
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class BasicApiTransformer(gast.NodeTransformer): class BasicApiTransformer(BaseTransformer):
""" """
Class to transform basic API from dygraph to static graph. Class to transform basic API from dygraph to static graph.
""" """
...@@ -98,7 +99,7 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -98,7 +99,7 @@ class BasicApiTransformer(gast.NodeTransformer):
return False return False
class ToTensorTransformer(gast.NodeTransformer): class ToTensorTransformer(BaseTransformer):
""" """
Class to transform paddle.to_tensor and paddle.to_variable to paddle.assign Class to transform paddle.to_tensor and paddle.to_variable to paddle.assign
""" """
......
...@@ -21,6 +21,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list ...@@ -21,6 +21,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
__all__ = ['BreakContinueTransformer'] __all__ = ['BreakContinueTransformer']
...@@ -28,7 +29,7 @@ BREAK_NAME_PREFIX = '__break' ...@@ -28,7 +29,7 @@ BREAK_NAME_PREFIX = '__break'
CONTINUE_NAME_PREFIX = '__continue' CONTINUE_NAME_PREFIX = '__continue'
class ForToWhileTransformer(gast.NodeTransformer): class ForToWhileTransformer(BaseTransformer):
""" """
Transform python for loop into while loop and add condition node in the Transform python for loop into while loop and add condition node in the
loop test loop test
......
...@@ -18,11 +18,12 @@ from paddle.utils import gast ...@@ -18,11 +18,12 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
PDB_SET = "pdb.set_trace" PDB_SET = "pdb.set_trace"
class CallTransformer(gast.NodeTransformer): class CallTransformer(BaseTransformer):
""" """
This class transforms function calls into Static Graph Ast. This class transforms function calls into Static Graph Ast.
""" """
......
...@@ -17,9 +17,10 @@ from paddle.utils import gast ...@@ -17,9 +17,10 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class CastTransformer(gast.NodeTransformer): class CastTransformer(BaseTransformer):
""" """
This class transforms type casting into Static Graph Ast. This class transforms type casting into Static Graph Ast.
""" """
......
...@@ -16,9 +16,10 @@ from __future__ import print_function ...@@ -16,9 +16,10 @@ from __future__ import print_function
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class EarlyReturnTransformer(gast.NodeTransformer): class EarlyReturnTransformer(BaseTransformer):
""" """
Transform if/else return statement of Dygraph into Static Graph. Transform if/else return statement of Dygraph into Static Graph.
""" """
......
...@@ -274,19 +274,25 @@ class ErrorData(object): ...@@ -274,19 +274,25 @@ class ErrorData(object):
bottom_error_message = error_value_lines[empty_line_idx + 1:] bottom_error_message = error_value_lines[empty_line_idx + 1:]
revise_suggestion = self._create_revise_suggestion(bottom_error_message) revise_suggestion = self._create_revise_suggestion(bottom_error_message)
user_filepath = ''
error_traceback = [] error_traceback = []
user_code_traceback_index = [] user_code_traceback_index = []
pattern = 'File "(?P<filepath>.+)", line (?P<lineno>.+), in (?P<function_name>.+)' pattern = 'File "(?P<filepath>.+)", line (?P<lineno>.+), in (?P<function_name>.+)'
# Distinguish user code and framework code using static_info_map
static_info_map = {}
for k, v in self.origin_info_map.items():
origin_filepath = v.location.filepath
origin_lineno = v.location.lineno
static_info_map[(origin_filepath, origin_lineno)] = k
for i in range(0, len(error_value_lines_strip), 2): for i in range(0, len(error_value_lines_strip), 2):
if error_value_lines_strip[i].startswith("File "): if error_value_lines_strip[i].startswith("File "):
re_result = re.search(pattern, error_value_lines_strip[i]) re_result = re.search(pattern, error_value_lines_strip[i])
tmp_filepath, lineno_str, function_name = re_result.groups() tmp_filepath, lineno_str, function_name = re_result.groups()
code = error_value_lines_strip[ code = error_value_lines_strip[
i + 1] if i + 1 < len(error_value_lines_strip) else '' i + 1] if i + 1 < len(error_value_lines_strip) else ''
if i == 0:
user_filepath = tmp_filepath if static_info_map.get((tmp_filepath, int(lineno_str))):
if tmp_filepath == user_filepath:
user_code_traceback_index.append(len(error_traceback)) user_code_traceback_index.append(len(error_traceback))
error_traceback.append( error_traceback.append(
......
...@@ -19,9 +19,10 @@ import warnings ...@@ -19,9 +19,10 @@ import warnings
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils from paddle.fluid.dygraph.dygraph_to_static import utils
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class GradTransformer(gast.NodeTransformer): class GradTransformer(BaseTransformer):
""" """
A class transforms dygraph paddle.grad to static graph paddle.gradients. The A class transforms dygraph paddle.grad to static graph paddle.gradients. The
transformation is applied to support double grad mode. transformation is applied to support double grad mode.
......
...@@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe ...@@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
TRUE_FUNC_PREFIX = 'true_fn' TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn' FALSE_FUNC_PREFIX = 'false_fn'
...@@ -41,7 +42,7 @@ SET_ARGS_FUNC_PREFIX = 'set_args' ...@@ -41,7 +42,7 @@ SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args' ARGS_NAME = '__args'
class IfElseTransformer(gast.NodeTransformer): class IfElseTransformer(BaseTransformer):
""" """
Transform if/else statement of Dygraph into Static Graph. Transform if/else statement of Dygraph into Static Graph.
""" """
......
...@@ -21,11 +21,11 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe ...@@ -21,11 +21,11 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class ListTransformer(gast.NodeTransformer): class ListTransformer(BaseTransformer):
""" """
This class transforms python list used in control flow into Static Graph Ast. This class transforms python list used in control flow into Static Graph Ast.
""" """
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
cmpop_type_to_str = { cmpop_type_to_str = {
gast.Eq: "==", gast.Eq: "==",
...@@ -35,7 +36,7 @@ def cmpop_node_to_str(node): ...@@ -35,7 +36,7 @@ def cmpop_node_to_str(node):
return cmpop_type_to_str[type(node)] return cmpop_type_to_str[type(node)]
class LogicalTransformer(gast.NodeTransformer): class LogicalTransformer(BaseTransformer):
""" """
Transform python boolean op into Paddle logical op. Transform python boolean op into Paddle logical op.
......
...@@ -32,6 +32,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_un ...@@ -32,6 +32,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_un
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node, create_get_args_node, create_set_args_node from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node, create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
__all__ = ['LoopTransformer', 'NameVisitor'] __all__ = ['LoopTransformer', 'NameVisitor']
...@@ -566,7 +567,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -566,7 +567,7 @@ class NameVisitor(gast.NodeVisitor):
return loop_vars - removed_vars return loop_vars - removed_vars
class LoopTransformer(gast.NodeTransformer): class LoopTransformer(BaseTransformer):
""" """
This class transforms python while/for statement into Static Graph Ast This class transforms python while/for statement into Static Graph Ast
""" """
......
...@@ -17,9 +17,10 @@ from __future__ import print_function ...@@ -17,9 +17,10 @@ from __future__ import print_function
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class PrintTransformer(gast.NodeTransformer): class PrintTransformer(BaseTransformer):
""" """
This class transforms python print function to fluid.layers.Print. This class transforms python print function to fluid.layers.Print.
""" """
......
...@@ -21,6 +21,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list ...@@ -21,6 +21,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
__all__ = [ __all__ = [
'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer' 'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer'
...@@ -57,7 +58,7 @@ def get_return_size(return_node): ...@@ -57,7 +58,7 @@ def get_return_size(return_node):
return return_length return return_length
class ReplaceReturnNoneTransformer(gast.NodeTransformer): class ReplaceReturnNoneTransformer(BaseTransformer):
""" """
Replace 'return None' to 'return' because 'None' cannot be a valid input Replace 'return None' to 'return' because 'None' cannot be a valid input
in control flow. In ReturnTransformer single 'Return' will be appended no in control flow. In ReturnTransformer single 'Return' will be appended no
...@@ -133,7 +134,7 @@ class ReturnAnalysisVisitor(gast.NodeVisitor): ...@@ -133,7 +134,7 @@ class ReturnAnalysisVisitor(gast.NodeVisitor):
return self.max_return_length[func_node] return self.max_return_length[func_node]
class ReturnTransformer(gast.NodeTransformer): class ReturnTransformer(BaseTransformer):
""" """
Transforms return statements into equivalent python statements containing Transforms return statements into equivalent python statements containing
only one return statement at last. The basics idea is using a return value only one return statement at last. The basics idea is using a return value
......
...@@ -18,9 +18,10 @@ from paddle.utils import gast ...@@ -18,9 +18,10 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
class TensorShapeTransformer(gast.NodeTransformer): class TensorShapeTransformer(BaseTransformer):
""" """
This class transforms variable.shape into Static Graph Ast. This class transforms variable.shape into Static Graph Ast.
All 'xxx.shape' will be converted int '_jst.Shape(x)'. All 'xxx.shape' will be converted int '_jst.Shape(x)'.
......
...@@ -644,7 +644,12 @@ def ast_to_source_code(ast_node): ...@@ -644,7 +644,12 @@ def ast_to_source_code(ast_node):
type(ast_node)) type(ast_node))
if isinstance(ast_node, gast.AST): if isinstance(ast_node, gast.AST):
ast_node = gast.gast_to_ast(ast_node) ast_node = gast.gast_to_ast(ast_node)
source_code = astor.to_source(ast_node)
# Do not wrap lines even if they are too long
def pretty_source(source):
return ''.join(source)
source_code = astor.to_source(ast_node, pretty_source=pretty_source)
return source_code return source_code
......
...@@ -196,14 +196,17 @@ class TestDygraphToStaticCode(unittest.TestCase): ...@@ -196,14 +196,17 @@ class TestDygraphToStaticCode(unittest.TestCase):
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else) code = program_translator.get_code(dyfunc_with_if_else)
answer = get_source_code(StaticCode1.dyfunc_with_if_else) answer = get_source_code(StaticCode1.dyfunc_with_if_else)
self.assertEqual(answer, code) self.assertEqual(
answer.replace('\n', '').replace(' ', ''),
code.replace('\n', '').replace(' ', ''))
def test_program_translator(self): def test_program_translator(self):
answer = get_source_code(StaticCode2.dyfunc_with_if_else) answer = get_source_code(StaticCode2.dyfunc_with_if_else)
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else) code = program_translator.get_code(dyfunc_with_if_else)
# print(code) self.assertEqual(
self.assertEqual(answer, code) answer.replace('\n', '').replace(' ', ''),
code.replace('\n', '').replace(' ', ''))
class TestEnableDeclarative(unittest.TestCase): class TestEnableDeclarative(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册