未验证 提交 60da8854 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Modify print for dynamic type (#25612)

Modify the print in Dy2stat for dynamic type. Unit test is covered in old test_print.py
上级 dfe4e67e
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import Variable, core from paddle.fluid.framework import core, Variable
from paddle.fluid.layers import Assert, cast, control_flow, logical_and, logical_not, logical_or, nn from paddle.fluid.layers import Assert, Print
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
def convert_while_loop(cond, body, loop_vars): def convert_while_loop(cond, body, loop_vars):
...@@ -271,3 +272,16 @@ def convert_assert(cond, message=""): ...@@ -271,3 +272,16 @@ def convert_assert(cond, message=""):
return Assert(cond) return Assert(cond)
else: else:
assert cond, message assert cond, message
def convert_print(*args):
"""
A function representing Python ``print`` statement. Note: this is a basic
python function so we haven't handle sep, end, file and flush parameters of
python function.
"""
for var in args:
if isinstance(var, Variable):
var = Print(var)
else:
print(var)
...@@ -47,84 +47,17 @@ class PrintTransformer(gast.NodeTransformer): ...@@ -47,84 +47,17 @@ class PrintTransformer(gast.NodeTransformer):
# NOTE: deal with print in PY3 # NOTE: deal with print in PY3
def visit_Call(self, node): def visit_Call(self, node):
if isinstance(node.func, gast.Name) and node.func.id == 'print': if isinstance(node.func, gast.Name) and node.func.id == 'print':
parent_node = self.node_to_wrapper_map[node].parent.node convert_print_node = self._create_print_node(node.args)
if isinstance(parent_node, gast.Expr): return gast.Expr(value=convert_print_node)
# NOTE: why need transform to gast.Assign node
# only fluid.layers.Print(x) will be pruned when exe.run(use_prune=True)
print_assign_node = self._create_assign_node(node)
if print_assign_node is not None:
return print_assign_node
else:
return self._transform_call_node(node)
return node return node
# NOTE: deal with print in PY2 # NOTE: deal with print in PY2
def visit_Print(self, node): def visit_Print(self, node):
print_assign_node = self._create_assign_node(node) convert_print_node = self._create_print_node(node.values)
if print_assign_node is not None: return gast.Expr(value=convert_print_node)
return print_assign_node
return node def _create_print_node(self, print_args):
convert_print_func = gast.parse(
def _transform_call_node(self, node): 'fluid.dygraph.dygraph_to_static.convert_operators.convert_print'
assert isinstance(node, gast.Call), "visit Node is not gast.Call node." ).body[0].value
var_node = self._get_print_var_node(node) return gast.Call(func=convert_print_func, args=print_args, keywords=[])
if var_node is None:
return node
if self._need_transform(var_node, node):
return self._build_print_call_node(var_node)
return node
def _create_assign_node(self, node):
var_node = self._get_print_var_node(node)
if var_node is None:
return None
if self._need_transform(var_node, node):
return gast.Assign(
targets=[var_node], value=self._build_print_call_node(var_node))
return None
def _build_print_call_node(self, node):
return gast.Call(
func=gast.parse('fluid.layers.Print').body[0].value,
args=[node],
keywords=[
gast.keyword(
arg='summarize',
value=gast.UnaryOp(
op=gast.USub(),
operand=gast.Constant(
value=1, kind=None))), gast.keyword(
arg='print_phase',
value=gast.Constant(
value='forward', kind=None))
])
def _get_print_var_node(self, node):
if isinstance(node, gast.Call):
var_list = node.args
elif isinstance(node, gast.Print):
var_list = node.values
if isinstance(var_list[0], gast.Tuple):
var_list = var_list[0].elts
# TODO: support print multiple Var
if len(var_list) == 1:
return var_list[0]
else:
_logger.warning(
"ProgramTranslator could not transform printing multiple values like < %s > now and will run it as-is."
% ast_to_source_code(node).strip())
return None
def _need_transform(self, var_node, print_node):
if isinstance(var_node, gast.Name):
if self.static_analysis_visitor.is_tensor_node(var_node):
return True
else:
_logger.warning(
"ProgramTranslator could not transform printing value that are not Tensor like < %s > now and will run it as-is."
% ast_to_source_code(print_node).strip())
else:
_logger.warning(
"ProgramTranslator could not transform < %s > now and will run it as-is."
% ast_to_source_code(print_node).strip())
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册