未验证 提交 e707ee53 编写于 作者: N Nyakku Shigure 提交者: GitHub

[Dy2St] optimize `print` function convertor to display Tensor at compile time (#48672)

* [Dy2St] refactor convert_print to display Tensor in compile time
上级 595338c6
...@@ -16,149 +16,82 @@ import unittest ...@@ -16,149 +16,82 @@ import unittest
import numpy import numpy
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.jit import ProgramTranslator from paddle.jit import ProgramTranslator, to_static
from paddle.jit.api import declarative
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
# 1. print VarBase # 1. print Tensor
@declarative @to_static
def dyfunc_print_variable(x): def dyfunc_print_variable(x):
"""
PY2:
Print(dest=None, values=[Name(id='x_v', annotation=None, type_comment=None)], nl=True)],
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[]))
"""
# NOTE: transform to static code, var name will be changed # NOTE: transform to static code, var name will be changed
x_v = fluid.dygraph.to_variable(x) x_t = paddle.to_tensor(x)
print(x_v) print(x_t)
# 2. print ndarray # 2. print ndarray
@declarative @to_static
def dyfunc_print_ndarray(x): def dyfunc_print_ndarray(x):
"""
PY2:
Print(dest=None, values=[Name(id='x', annotation=None, type_comment=None)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x', annotation=None, type_comment=None)],
keywords=[]))
"""
print(x) print(x)
# 3. print VarBase with format # 3. print Tensor with format
@declarative @to_static
def dyfunc_print_with_format(x): def dyfunc_print_with_format(x):
""" x_t = paddle.to_tensor(x)
PY2: print("PrintTensor: {}".format(x_t))
Print(dest=None,
values=[
Call( # 4. print Tensor with format 2
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'), @to_static
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[])],
nl=True)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[
Call(
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[])],
keywords=[]))
"""
x_v = fluid.dygraph.to_variable(x)
print("PrintVariable: {}".format(x_v))
# 4. print VarBase with format 2
@declarative
def dyfunc_print_with_format2(x): def dyfunc_print_with_format2(x):
""" x_t = paddle.to_tensor(x)
PY2: print("PrintTensor: %s" % (x_t))
Print(dest=None,
values=[
BinOp(left=Constant(value='PrintVariable: %s', kind=None), # 5. print Tensor in control flow1
op=Mod, @to_static
right=Name(id='x_v', annotation=None, type_comment=None))],
nl=True)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
op=Mod,
right=Name(id='x_v', annotation=None, type_comment=None))],
keywords=[]))
"""
x_v = fluid.dygraph.to_variable(x)
print("PrintVariable: %s" % (x_v))
# 5. print VarBase in control flow1
@declarative
def dyfunc_print_with_ifelse(x): def dyfunc_print_with_ifelse(x):
x_v = fluid.dygraph.to_variable(x) x_t = paddle.to_tensor(x)
if len(x_v.shape) > 1: if len(x_t.shape) > 1:
print(x_v) print(x_t)
else: else:
print(x_v) print(x_t)
# 6. print mutiple VarBases # 6. print multiple Tensor
@declarative @to_static
def dyfunc_print_multi_vars(x): def dyfunc_print_multi_tensor(x):
""" x_t = paddle.to_tensor(x)
# NOTE: y_v type is error before cur PR in this case y_t = x_t * 2
Assign(targets=[Name(id='y_v', annotation=None, type_comment=None)], print(x_t)
value=BinOp(left=Name(id='x_v', annotation=None, type_comment=None), op=Mult, right=Constant(value=2, kind=None))) print(y_t)
"""
x_v = fluid.dygraph.to_variable(x)
y_v = x_v * 2
print(x_v)
print(y_v)
# 7. print continue VarBase # 7. print continue Tensor
@declarative @to_static
def dyfunc_print_continue_vars(x): def dyfunc_print_continue_vars(x):
""" x_t = paddle.to_tensor(x)
PY3: y_t = x_t * 2
Expr( print(x_t, y_t)
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x_v', annotation=None, type_comment=None),
Name(id='y_v', annotation=None, type_comment=None)], # 8. print with kwargs
keywords=[])) @to_static
PY2: def dyfunc_print_with_kwargs(x):
Print(dest=None, x_t = paddle.to_tensor(x)
values=[ print("Tensor", x_t, end='\n\n', sep=': ')
Tuple(
elts=[Name(id='x_v', annotation=None, type_comment=None),
Name(id='y_v', annotation=None, type_comment=None)])],
nl=True)
"""
x_v = fluid.dygraph.to_variable(x)
y_v = x_v * 2
print(x_v, y_v)
class TestPrintBase(unittest.TestCase): class TestPrintBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = numpy.ones(5).astype("int32") self.input = numpy.ones(5).astype("int32")
self.place = ( self.place = (
fluid.CUDAPlace(0) paddle.CUDAPlace(0)
if fluid.is_compiled_with_cuda() if paddle.is_compiled_with_cuda()
else fluid.CPUPlace() else paddle.CPUPlace()
) )
self.set_test_func() self.set_test_func()
...@@ -207,9 +140,9 @@ class TestPrintWithIfElse(TestPrintVariable): ...@@ -207,9 +140,9 @@ class TestPrintWithIfElse(TestPrintVariable):
self.dygraph_func = dyfunc_print_with_ifelse self.dygraph_func = dyfunc_print_with_ifelse
class TestPrintMultipleVar(TestPrintVariable): class TestPrintMultipleTensor(TestPrintVariable):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dyfunc_print_multi_vars self.dygraph_func = dyfunc_print_multi_tensor
class TestPrintContinueVar(TestPrintVariable): class TestPrintContinueVar(TestPrintVariable):
...@@ -217,5 +150,10 @@ class TestPrintContinueVar(TestPrintVariable): ...@@ -217,5 +150,10 @@ class TestPrintContinueVar(TestPrintVariable):
self.dygraph_func = dyfunc_print_continue_vars self.dygraph_func = dyfunc_print_continue_vars
class TestPrintWithKwargs(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_with_kwargs
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -25,7 +25,6 @@ from .convert_operators import convert_len as Len # noqa: F401 ...@@ -25,7 +25,6 @@ from .convert_operators import convert_len as Len # noqa: F401
from .convert_operators import convert_logical_not as Not # noqa: F401 from .convert_operators import convert_logical_not as Not # noqa: F401
from .convert_operators import convert_logical_or as Or # noqa: F401 from .convert_operators import convert_logical_or as Or # noqa: F401
from .convert_operators import convert_pop as Pop # noqa: F401 from .convert_operators import convert_pop as Pop # noqa: F401
from .convert_operators import convert_print as Print # noqa: F401
from .convert_operators import convert_shape as Shape # noqa: F401 from .convert_operators import convert_shape as Shape # noqa: F401
from .convert_operators import convert_while_loop as While # noqa: F401 from .convert_operators import convert_while_loop as While # noqa: F401
from .convert_operators import unpack_by_structure as Unpack # noqa: F401 from .convert_operators import unpack_by_structure as Unpack # noqa: F401
......
...@@ -52,9 +52,6 @@ from .logical_transformer import ( ...@@ -52,9 +52,6 @@ from .logical_transformer import (
from .loop_transformer import ( from .loop_transformer import (
LoopTransformer, LoopTransformer,
) )
from .print_transformer import (
PrintTransformer,
)
from .return_transformer import ( from .return_transformer import (
ReturnTransformer, ReturnTransformer,
) )
...@@ -135,7 +132,6 @@ class DygraphToStaticAst(BaseTransformer): ...@@ -135,7 +132,6 @@ class DygraphToStaticAst(BaseTransformer):
LoopTransformer, # for/while -> while_op LoopTransformer, # for/while -> while_op
IfElseTransformer, # if/else -> cond_op IfElseTransformer, # if/else -> cond_op
AssertTransformer, # assert statement AssertTransformer, # assert statement
PrintTransformer, # print statement
CallTransformer, # transform call recursively CallTransformer, # transform call recursively
CastTransformer, # type casting statement CastTransformer, # type casting statement
DecoratorTransformer, # transform decorators to function call DecoratorTransformer, # transform decorators to function call
......
...@@ -60,6 +60,7 @@ class CallTransformer(BaseTransformer): ...@@ -60,6 +60,7 @@ class CallTransformer(BaseTransformer):
'zip', 'zip',
'range', 'range',
'enumerate', 'enumerate',
'print',
} }
is_builtin = eval("is_builtin({})".format(func_str)) is_builtin = eval("is_builtin({})".format(func_str))
need_convert = func_str in need_convert_builtin_func_list need_convert = func_str in need_convert_builtin_func_list
......
...@@ -30,6 +30,7 @@ from .convert_operators import ( ...@@ -30,6 +30,7 @@ from .convert_operators import (
convert_zip, convert_zip,
convert_range, convert_range,
convert_enumerate, convert_enumerate,
convert_print,
) )
from paddle.jit.dy2static.logging_utils import ( from paddle.jit.dy2static.logging_utils import (
...@@ -215,6 +216,9 @@ def convert_call(func): ...@@ -215,6 +216,9 @@ def convert_call(func):
if is_builtin(func, "enumerate"): if is_builtin(func, "enumerate"):
return convert_enumerate return convert_enumerate
if is_builtin(func, "print"):
return convert_print
if is_builtin(func) or is_unsupported(func): if is_builtin(func) or is_unsupported(func):
return func return func
......
...@@ -736,17 +736,15 @@ def convert_assert(cond, message=""): ...@@ -736,17 +736,15 @@ def convert_assert(cond, message=""):
assert cond, message assert cond, message
def convert_print(*args): def convert_print(*objects, sep=' ', end='\n', file=None, flush=False):
""" """
A function representing Python ``print`` statement. Note: this is a basic A function representing Python ``print`` function. It will print all arguments
python function so we haven't handle sep, end, file and flush parameters of at compile time and only print the Tensor values at runtime.
python function.
""" """
for var in args: for obj in objects:
if isinstance(var, Variable): if isinstance(obj, Variable):
var = Print(var) Print(obj)
else: print(*objects, sep=sep, end=end, file=file, flush=flush)
print(var)
def convert_pop(target, *args): def convert_pop(target, *args):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.utils import gast
from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper,
StaticAnalysisVisitor,
)
from .base_transformer import (
BaseTransformer,
)
class PrintTransformer(BaseTransformer):
    """
    AST transformer that rewrites Python ``print`` into a ``_jst.Print`` call
    so printing works inside a static-graph (dygraph-to-static) program.
    """

    def __init__(self, wrapper_root):
        # Only an AstNodeWrapper carries the node/analysis info we need.
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Run static analysis up front so the node -> wrapper mapping is available.
        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map()
        )

    def transform(self):
        # Walk the whole tree rooted at self.root and rewrite print usages.
        self.visit(self.root)

    # Python 3 style: ``print(...)`` appears as a Call expression.
    def visit_Call(self, node):
        is_print_call = (
            isinstance(node.func, gast.Name) and node.func.id == 'print'
        )
        if not is_print_call:
            return node
        return self._create_print_node(node.args)

    # Python 2 style: ``print ...`` appears as a dedicated Print statement.
    def visit_Print(self, node):
        return gast.Expr(value=self._create_print_node(node.values))

    def _create_print_node(self, print_args):
        # Build a ``_jst.Print(<args>)`` Call node; parse once to obtain the
        # attribute-access func node, then attach the original arguments.
        func_node = gast.parse('_jst.Print').body[0].value
        return gast.Call(func=func_node, args=print_args, keywords=[])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册