未验证 提交 897cec81 编写于 作者: C Chen Weihang 提交者: GitHub

[Dy2static] fix some print transformer problems (#24516)

* fix some print transformer problems, test=develop

* simplify writing & avoid bud, test=develop

* polish detail, test=develop
上级 45ef6ff3
...@@ -15,8 +15,14 @@ ...@@ -15,8 +15,14 @@
from __future__ import print_function from __future__ import print_function
import gast import gast
import logging
from paddle.fluid import log_helper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
_logger = log_helper.get_logger(
__name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s')
class PrintTransformer(gast.NodeTransformer): class PrintTransformer(gast.NodeTransformer):
...@@ -40,19 +46,60 @@ class PrintTransformer(gast.NodeTransformer): ...@@ -40,19 +46,60 @@ class PrintTransformer(gast.NodeTransformer):
# NOTE: deal with print in PY3 # NOTE: deal with print in PY3
def visit_Call(self, node): def visit_Call(self, node):
assert isinstance(node, gast.Call)
if isinstance(node.func, gast.Name) and node.func.id == 'print': if isinstance(node.func, gast.Name) and node.func.id == 'print':
var = self._get_print_var(node) parent_node = self.node_to_wrapper_map[node].parent.node
return self._construct_print_node(var) if isinstance(parent_node, gast.Expr):
# NOTE: why need transform to gast.Assign node
# only fluid.layers.Print(x) will be pruned when exe.run(use_prune=True)
print_assign_node = self._create_assign_node(node)
if print_assign_node is not None:
return print_assign_node
else:
return self._transform_call_node(node)
return node return node
# NOTE: deal with print in PY2 # NOTE: deal with print in PY2
def visit_Print(self, node): def visit_Print(self, node):
var = self._get_print_var(node) print_assign_node = self._create_assign_node(node)
print_call_node = self._construct_print_node(var) if print_assign_node is not None:
return gast.Expr(value=print_call_node) return print_assign_node
return node
def _transform_call_node(self, node):
assert isinstance(node, gast.Call), "visit Node is not gast.Call node."
var_node = self._get_print_var_node(node)
if var_node is None:
return node
if self._need_transform(var_node, node):
return self._build_print_call_node(var_node)
return node
def _get_print_var(self, node): def _create_assign_node(self, node):
var_node = self._get_print_var_node(node)
if var_node is None:
return None
if self._need_transform(var_node, node):
return gast.Assign(
targets=[var_node], value=self._build_print_call_node(var_node))
return None
def _build_print_call_node(self, node):
return gast.Call(
func=gast.parse('fluid.layers.Print').body[0].value,
args=[node],
keywords=[
gast.keyword(
arg='summarize',
value=gast.UnaryOp(
op=gast.USub(),
operand=gast.Constant(
value=1, kind=None))), gast.keyword(
arg='print_phase',
value=gast.Constant(
value='forward', kind=None))
])
def _get_print_var_node(self, node):
if isinstance(node, gast.Call): if isinstance(node, gast.Call):
var_list = node.args var_list = node.args
elif isinstance(node, gast.Print): elif isinstance(node, gast.Print):
...@@ -60,24 +107,27 @@ class PrintTransformer(gast.NodeTransformer): ...@@ -60,24 +107,27 @@ class PrintTransformer(gast.NodeTransformer):
if isinstance(var_list[0], gast.Tuple): if isinstance(var_list[0], gast.Tuple):
var_list = var_list[0].elts var_list = var_list[0].elts
# TODO: support print multiple Var # TODO: support print multiple Var
assert len(var_list) == 1, "Now only support print one Variable." if len(var_list) == 1:
return var_list[0] return var_list[0]
else:
_logger.warning(
"ProgramTranslator could not transform printing multiple values like < %s > now and will run it as-is."
% ast_to_source_code(node).strip())
return None
def _construct_print_node(self, node): def _need_transform(self, var_node, print_node):
if isinstance(node, gast.Name): if isinstance(var_node, gast.Name):
if self._is_tensor_node(node): if self._is_tensor_node(var_node):
print_node = gast.Call( return True
func=gast.parse('fluid.layers.Print').body[0].value,
args=[node],
keywords=[])
return print_node
else: else:
raise TypeError( _logger.warning(
"print object type error, only support print Variable now.") "ProgramTranslator could not transform printing value that are not Tensor like < %s > now and will run it as-is."
% ast_to_source_code(print_node).strip())
else: else:
# TODO: may not only print with format _logger.warning(
raise NotImplementedError( "ProgramTranslator could not transform < %s > now and will run it as-is."
"cannot transform print with format temporarily.") % ast_to_source_code(print_node).strip())
return False
def _is_tensor_node(self, node): def _is_tensor_node(self, node):
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES} tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
......
...@@ -185,35 +185,20 @@ class TestPrintVariable(TestPrintBase): ...@@ -185,35 +185,20 @@ class TestPrintVariable(TestPrintBase):
self.get_static_output() self.get_static_output()
class TestPrintNdArray(TestPrintBase): class TestPrintNdArray(TestPrintVariable):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dyfunc_print_ndarray self.dygraph_func = dyfunc_print_ndarray
def test_transform_static_error(self):
with self.assertRaises(TypeError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithFormat(TestPrintBase): class TestPrintWithFormat(TestPrintVariable):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dyfunc_print_with_format self.dygraph_func = dyfunc_print_with_format
def test_transform_static_error(self):
with self.assertRaises(NotImplementedError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithFormat2(TestPrintBase): class TestPrintWithFormat2(TestPrintVariable):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dyfunc_print_with_format2 self.dygraph_func = dyfunc_print_with_format2
def test_transform_static_error(self):
with self.assertRaises(NotImplementedError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithIfElse(TestPrintVariable): class TestPrintWithIfElse(TestPrintVariable):
def set_test_func(self): def set_test_func(self):
...@@ -225,15 +210,10 @@ class TestPrintMultipleVar(TestPrintVariable): ...@@ -225,15 +210,10 @@ class TestPrintMultipleVar(TestPrintVariable):
self.dygraph_func = dyfunc_print_multi_vars self.dygraph_func = dyfunc_print_multi_vars
class TestPrintContinueVar(TestPrintBase): class TestPrintContinueVar(TestPrintVariable):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dyfunc_print_continue_vars self.dygraph_func = dyfunc_print_continue_vars
def test_transform_static_error(self):
with self.assertRaises(AssertionError):
self.get_dygraph_output()
self.get_static_output()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册