From 897cec811a1125983931f08a2acbc2dc02e73d8b Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 15 May 2020 11:15:32 +0800 Subject: [PATCH] [Dy2static] fix some print transformer problems (#24516) * fix some print transformer problems, test=develop * simplify writing & avoid bud, test=develop * polish detail, test=develop --- .../dygraph_to_static/print_transformer.py | 96 ++++++++++++++----- .../unittests/dygraph_to_static/test_print.py | 28 +----- 2 files changed, 77 insertions(+), 47 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py index 69035133335..5c45cc8a600 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py @@ -15,8 +15,14 @@ from __future__ import print_function import gast +import logging +from paddle.fluid import log_helper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code + +_logger = log_helper.get_logger( + __name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s') class PrintTransformer(gast.NodeTransformer): @@ -40,19 +46,60 @@ class PrintTransformer(gast.NodeTransformer): # NOTE: deal with print in PY3 def visit_Call(self, node): - assert isinstance(node, gast.Call) if isinstance(node.func, gast.Name) and node.func.id == 'print': - var = self._get_print_var(node) - return self._construct_print_node(var) + parent_node = self.node_to_wrapper_map[node].parent.node + if isinstance(parent_node, gast.Expr): + # NOTE: why need transform to gast.Assign node + # only fluid.layers.Print(x) will be pruned when exe.run(use_prune=True) + print_assign_node = self._create_assign_node(node) + if print_assign_node is not None: + return print_assign_node + else: + return self._transform_call_node(node) return node # NOTE: deal with print in PY2 def visit_Print(self, node): - var = self._get_print_var(node) - print_call_node = self._construct_print_node(var) - return gast.Expr(value=print_call_node) + print_assign_node = self._create_assign_node(node) + if print_assign_node is not None: + return print_assign_node + return node + + def _transform_call_node(self, node): + assert isinstance(node, gast.Call), "visit Node is not gast.Call node." + var_node = self._get_print_var_node(node) + if var_node is None: + return node + if self._need_transform(var_node, node): + return self._build_print_call_node(var_node) + return node + + def _create_assign_node(self, node): + var_node = self._get_print_var_node(node) + if var_node is None: + return None + if self._need_transform(var_node, node): + return gast.Assign( + targets=[var_node], value=self._build_print_call_node(var_node)) + return None - def _get_print_var(self, node): + def _build_print_call_node(self, node): + return gast.Call( + func=gast.parse('fluid.layers.Print').body[0].value, + args=[node], + keywords=[ + gast.keyword( + arg='summarize', + value=gast.UnaryOp( + op=gast.USub(), + operand=gast.Constant( + value=1, kind=None))), gast.keyword( + arg='print_phase', + value=gast.Constant( + value='forward', kind=None)) + ]) + + def _get_print_var_node(self, node): if isinstance(node, gast.Call): var_list = node.args elif isinstance(node, gast.Print): @@ -60,24 +107,27 @@ class PrintTransformer(gast.NodeTransformer): if isinstance(var_list[0], gast.Tuple): var_list = var_list[0].elts # TODO: support print multiple Var - assert len(var_list) == 1, "Now only support print one Variable." - return var_list[0] - - def _construct_print_node(self, node): - if isinstance(node, gast.Name): - if self._is_tensor_node(node): - print_node = gast.Call( - func=gast.parse('fluid.layers.Print').body[0].value, - args=[node], - keywords=[]) - return print_node + if len(var_list) == 1: + return var_list[0] + else: + _logger.warning( + "ProgramTranslator could not transform printing multiple values like < %s > now and will run it as-is." + % ast_to_source_code(node).strip()) + return None + + def _need_transform(self, var_node, print_node): + if isinstance(var_node, gast.Name): + if self._is_tensor_node(var_node): + return True else: - raise TypeError( - "print object type error, only support print Variable now.") + _logger.warning( + "ProgramTranslator could not transform printing value that are not Tensor like < %s > now and will run it as-is." + % ast_to_source_code(print_node).strip()) else: - # TODO: may not only print with format - raise NotImplementedError( - "cannot transform print with format temporarily.") + _logger.warning( + "ProgramTranslator could not transform < %s > now and will run it as-is." + % ast_to_source_code(print_node).strip()) + return False def _is_tensor_node(self, node): tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES} diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py index 8f2e75aa4c8..aabfd3b2c48 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py @@ -185,35 +185,20 @@ class TestPrintVariable(TestPrintBase): self.get_static_output() -class TestPrintNdArray(TestPrintBase): +class TestPrintNdArray(TestPrintVariable): def set_test_func(self): self.dygraph_func = dyfunc_print_ndarray - def test_transform_static_error(self): - with self.assertRaises(TypeError): - self.get_dygraph_output() - self.get_static_output() - -class TestPrintWithFormat(TestPrintBase): +class TestPrintWithFormat(TestPrintVariable): def set_test_func(self): self.dygraph_func = dyfunc_print_with_format - def test_transform_static_error(self): - with self.assertRaises(NotImplementedError): - self.get_dygraph_output() - self.get_static_output() - -class TestPrintWithFormat2(TestPrintBase): +class TestPrintWithFormat2(TestPrintVariable): def set_test_func(self): self.dygraph_func = dyfunc_print_with_format2 - def test_transform_static_error(self): - with self.assertRaises(NotImplementedError): - self.get_dygraph_output() - self.get_static_output() - class TestPrintWithIfElse(TestPrintVariable): def set_test_func(self): @@ -225,15 +210,10 @@ class TestPrintMultipleVar(TestPrintVariable): self.dygraph_func = dyfunc_print_multi_vars -class TestPrintContinueVar(TestPrintBase): +class TestPrintContinueVar(TestPrintVariable): def set_test_func(self): self.dygraph_func = dyfunc_print_continue_vars - def test_transform_static_error(self): - with self.assertRaises(AssertionError): - self.get_dygraph_output() - self.get_static_output() - if __name__ == '__main__': unittest.main() -- GitLab