未验证 提交 897cec81 编写于 作者: C Chen Weihang 提交者: GitHub

[Dy2static] fix some print transformer problems (#24516)

* fix some print transformer problems, test=develop

* simplify writing & avoid bud, test=develop

* polish detail, test=develop
上级 45ef6ff3
......@@ -15,8 +15,14 @@
from __future__ import print_function
import gast
import logging
from paddle.fluid import log_helper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
_logger = log_helper.get_logger(
__name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s')
class PrintTransformer(gast.NodeTransformer):
......@@ -40,19 +46,60 @@ class PrintTransformer(gast.NodeTransformer):
# NOTE: deal with print in PY3
def visit_Call(self, node):
assert isinstance(node, gast.Call)
if isinstance(node.func, gast.Name) and node.func.id == 'print':
var = self._get_print_var(node)
return self._construct_print_node(var)
parent_node = self.node_to_wrapper_map[node].parent.node
if isinstance(parent_node, gast.Expr):
# NOTE: why need transform to gast.Assign node
# only fluid.layers.Print(x) will be pruned when exe.run(use_prune=True)
print_assign_node = self._create_assign_node(node)
if print_assign_node is not None:
return print_assign_node
else:
return self._transform_call_node(node)
return node
# NOTE: deal with print in PY2
def visit_Print(self, node):
var = self._get_print_var(node)
print_call_node = self._construct_print_node(var)
return gast.Expr(value=print_call_node)
print_assign_node = self._create_assign_node(node)
if print_assign_node is not None:
return print_assign_node
return node
def _transform_call_node(self, node):
assert isinstance(node, gast.Call), "visit Node is not gast.Call node."
var_node = self._get_print_var_node(node)
if var_node is None:
return node
if self._need_transform(var_node, node):
return self._build_print_call_node(var_node)
return node
def _get_print_var(self, node):
def _create_assign_node(self, node):
var_node = self._get_print_var_node(node)
if var_node is None:
return None
if self._need_transform(var_node, node):
return gast.Assign(
targets=[var_node], value=self._build_print_call_node(var_node))
return None
def _build_print_call_node(self, node):
return gast.Call(
func=gast.parse('fluid.layers.Print').body[0].value,
args=[node],
keywords=[
gast.keyword(
arg='summarize',
value=gast.UnaryOp(
op=gast.USub(),
operand=gast.Constant(
value=1, kind=None))), gast.keyword(
arg='print_phase',
value=gast.Constant(
value='forward', kind=None))
])
def _get_print_var_node(self, node):
if isinstance(node, gast.Call):
var_list = node.args
elif isinstance(node, gast.Print):
......@@ -60,24 +107,27 @@ class PrintTransformer(gast.NodeTransformer):
if isinstance(var_list[0], gast.Tuple):
var_list = var_list[0].elts
# TODO: support print multiple Var
assert len(var_list) == 1, "Now only support print one Variable."
if len(var_list) == 1:
return var_list[0]
else:
_logger.warning(
"ProgramTranslator could not transform printing multiple values like < %s > now and will run it as-is."
% ast_to_source_code(node).strip())
return None
def _construct_print_node(self, node):
if isinstance(node, gast.Name):
if self._is_tensor_node(node):
print_node = gast.Call(
func=gast.parse('fluid.layers.Print').body[0].value,
args=[node],
keywords=[])
return print_node
def _need_transform(self, var_node, print_node):
if isinstance(var_node, gast.Name):
if self._is_tensor_node(var_node):
return True
else:
raise TypeError(
"print object type error, only support print Variable now.")
_logger.warning(
"ProgramTranslator could not transform printing value that are not Tensor like < %s > now and will run it as-is."
% ast_to_source_code(print_node).strip())
else:
# TODO: may not only print with format
raise NotImplementedError(
"cannot transform print with format temporarily.")
_logger.warning(
"ProgramTranslator could not transform < %s > now and will run it as-is."
% ast_to_source_code(print_node).strip())
return False
def _is_tensor_node(self, node):
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
......
......@@ -185,35 +185,20 @@ class TestPrintVariable(TestPrintBase):
self.get_static_output()
class TestPrintNdArray(TestPrintBase):
class TestPrintNdArray(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_ndarray
def test_transform_static_error(self):
with self.assertRaises(TypeError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithFormat(TestPrintBase):
class TestPrintWithFormat(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_with_format
def test_transform_static_error(self):
with self.assertRaises(NotImplementedError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithFormat2(TestPrintBase):
class TestPrintWithFormat2(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_with_format2
def test_transform_static_error(self):
with self.assertRaises(NotImplementedError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithIfElse(TestPrintVariable):
def set_test_func(self):
......@@ -225,15 +210,10 @@ class TestPrintMultipleVar(TestPrintVariable):
self.dygraph_func = dyfunc_print_multi_vars
class TestPrintContinueVar(TestPrintBase):
class TestPrintContinueVar(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_continue_vars
def test_transform_static_error(self):
with self.assertRaises(AssertionError):
self.get_dygraph_output()
self.get_static_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册