Unverified commit e707ee53 authored by Nyakku Shigure, committed by GitHub

[Dy2St] optimize `print` function convertor to display Tensor at compile time (#48672)

* [Dy2St] refactor convert_print to display Tensor at compile time
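
Roughly, the new behavior looks like this (a minimal sketch, not part of the commit; the function name and values are illustrative):

```python
import paddle
from paddle.jit import to_static

@to_static
def foo(x):
    # `print` is routed through `_jst.Print` (`convert_print`): the static
    # Variable description is shown while the function is traced, and a
    # runtime `Print` op dumps the actual Tensor values on execution.
    print("x is", x)
    return x + 1

foo(paddle.to_tensor([1, 2, 3]))
```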
Parent 595338c6
@@ -16,149 +16,82 @@ import unittest
import numpy
import paddle
import paddle.fluid as fluid
-from paddle.jit import ProgramTranslator
-from paddle.jit.api import declarative
+from paddle.jit import ProgramTranslator, to_static
program_translator = ProgramTranslator()
-# 1. print VarBase
-@declarative
+# 1. print Tensor
+@to_static
def dyfunc_print_variable(x):
"""
PY2:
Print(dest=None, values=[Name(id='x_v', annotation=None, type_comment=None)], nl=True)],
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[]))
"""
# NOTE: transform to static code, var name will be changed
x_v = fluid.dygraph.to_variable(x)
print(x_v)
x_t = paddle.to_tensor(x)
print(x_t)
# 2. print ndarray
-@declarative
+@to_static
def dyfunc_print_ndarray(x):
"""
PY2:
Print(dest=None, values=[Name(id='x', annotation=None, type_comment=None)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x', annotation=None, type_comment=None)],
keywords=[]))
"""
print(x)
-# 3. print VarBase with format
-@declarative
+# 3. print Tensor with format
+@to_static
def dyfunc_print_with_format(x):
"""
PY2:
Print(dest=None,
values=[
Call(
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[])],
nl=True)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[
Call(
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[])],
keywords=[]))
"""
x_v = fluid.dygraph.to_variable(x)
print("PrintVariable: {}".format(x_v))
# 4. print VarBase with format 2
@declarative
x_t = paddle.to_tensor(x)
print("PrintTensor: {}".format(x_t))
# 4. print Tensor with format 2
@to_static
def dyfunc_print_with_format2(x):
"""
PY2:
Print(dest=None,
values=[
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
op=Mod,
right=Name(id='x_v', annotation=None, type_comment=None))],
nl=True)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
op=Mod,
right=Name(id='x_v', annotation=None, type_comment=None))],
keywords=[]))
"""
x_v = fluid.dygraph.to_variable(x)
print("PrintVariable: %s" % (x_v))
# 5. print VarBase in control flow1
@declarative
x_t = paddle.to_tensor(x)
print("PrintTensor: %s" % (x_t))
# 5. print Tensor in control flow1
@to_static
def dyfunc_print_with_ifelse(x):
-    x_v = fluid.dygraph.to_variable(x)
-    if len(x_v.shape) > 1:
-        print(x_v)
+    x_t = paddle.to_tensor(x)
+    if len(x_t.shape) > 1:
+        print(x_t)
    else:
-        print(x_v)
+        print(x_t)
-# 6. print mutiple VarBases
-@declarative
-def dyfunc_print_multi_vars(x):
-    """
-    # NOTE: y_v type is error before cur PR in this case
-    Assign(targets=[Name(id='y_v', annotation=None, type_comment=None)],
-        value=BinOp(left=Name(id='x_v', annotation=None, type_comment=None), op=Mult, right=Constant(value=2, kind=None)))
-    """
-    x_v = fluid.dygraph.to_variable(x)
-    y_v = x_v * 2
-    print(x_v)
-    print(y_v)
+# 6. print multiple Tensor
+@to_static
+def dyfunc_print_multi_tensor(x):
+    x_t = paddle.to_tensor(x)
+    y_t = x_t * 2
+    print(x_t)
+    print(y_t)
-# 7. print continue VarBase
-@declarative
+# 7. print continue Tensor
+@to_static
def dyfunc_print_continue_vars(x):
"""
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x_v', annotation=None, type_comment=None),
Name(id='y_v', annotation=None, type_comment=None)],
keywords=[]))
PY2:
Print(dest=None,
values=[
Tuple(
elts=[Name(id='x_v', annotation=None, type_comment=None),
Name(id='y_v', annotation=None, type_comment=None)])],
nl=True)
"""
x_v = fluid.dygraph.to_variable(x)
y_v = x_v * 2
print(x_v, y_v)
x_t = paddle.to_tensor(x)
y_t = x_t * 2
print(x_t, y_t)
# 8. print with kwargs
@to_static
def dyfunc_print_with_kwargs(x):
x_t = paddle.to_tensor(x)
print("Tensor", x_t, end='\n\n', sep=': ')
class TestPrintBase(unittest.TestCase):
    def setUp(self):
        self.input = numpy.ones(5).astype("int32")
        self.place = (
-            fluid.CUDAPlace(0)
-            if fluid.is_compiled_with_cuda()
-            else fluid.CPUPlace()
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
        )
        self.set_test_func()
@@ -207,9 +140,9 @@ class TestPrintWithIfElse(TestPrintVariable):
        self.dygraph_func = dyfunc_print_with_ifelse


-class TestPrintMultipleVar(TestPrintVariable):
+class TestPrintMultipleTensor(TestPrintVariable):
    def set_test_func(self):
-        self.dygraph_func = dyfunc_print_multi_vars
+        self.dygraph_func = dyfunc_print_multi_tensor


class TestPrintContinueVar(TestPrintVariable):
@@ -217,5 +150,10 @@ class TestPrintContinueVar(TestPrintVariable):
        self.dygraph_func = dyfunc_print_continue_vars


+class TestPrintWithKwargs(TestPrintVariable):
+    def set_test_func(self):
+        self.dygraph_func = dyfunc_print_with_kwargs


if __name__ == '__main__':
    unittest.main()
@@ -25,7 +25,6 @@ from .convert_operators import convert_len as Len # noqa: F401
from .convert_operators import convert_logical_not as Not # noqa: F401
from .convert_operators import convert_logical_or as Or # noqa: F401
from .convert_operators import convert_pop as Pop # noqa: F401
-from .convert_operators import convert_print as Print # noqa: F401
from .convert_operators import convert_shape as Shape # noqa: F401
from .convert_operators import convert_while_loop as While # noqa: F401
from .convert_operators import unpack_by_structure as Unpack # noqa: F401
@@ -52,9 +52,6 @@ from .logical_transformer import (
from .loop_transformer import (
    LoopTransformer,
)
-from .print_transformer import (
-    PrintTransformer,
-)
from .return_transformer import (
    ReturnTransformer,
)
@@ -135,7 +132,6 @@ class DygraphToStaticAst(BaseTransformer):
            LoopTransformer,  # for/while -> while_op
            IfElseTransformer,  # if/else -> cond_op
            AssertTransformer,  # assert statement
-            PrintTransformer,  # print statement
            CallTransformer,  # transform call recursively
            CastTransformer,  # type casting statement
            DecoratorTransformer,  # transform decorators to function call
@@ -60,6 +60,7 @@ class CallTransformer(BaseTransformer):
            'zip',
            'range',
            'enumerate',
+            'print',
        }
        is_builtin = eval("is_builtin({})".format(func_str))
        need_convert = func_str in need_convert_builtin_func_list
@@ -30,6 +30,7 @@ from .convert_operators import (
    convert_zip,
    convert_range,
    convert_enumerate,
+    convert_print,
)
from paddle.jit.dy2static.logging_utils import (
@@ -215,6 +216,9 @@ def convert_call(func):
if is_builtin(func, "enumerate"):
return convert_enumerate
if is_builtin(func, "print"):
return convert_print
if is_builtin(func) or is_unsupported(func):
return func
@@ -736,17 +736,15 @@ def convert_assert(cond, message=""):
    assert cond, message


-def convert_print(*args):
+def convert_print(*objects, sep=' ', end='\n', file=None, flush=False):
    """
-    A function representing Python ``print`` statement. Note: this is a basic
-    python function so we haven't handle sep, end, file and flush parameters of
-    python function.
+    A function representing Python ``print`` function. It will print all arguments
+    at compile time and only print the Tensor values at runtime.
    """
-    for var in args:
-        if isinstance(var, Variable):
-            var = Print(var)
-        else:
-            print(var)
+    for obj in objects:
+        if isinstance(obj, Variable):
+            Print(obj)
+    print(*objects, sep=sep, end=end, file=file, flush=flush)
def convert_pop(target, *args):
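For illustration, a sketch of the effect of the rewritten `convert_print` (a hypothetical snippet, not part of this diff): the builtin `print` runs once at compile time with `sep`/`end`/`file`/`flush` forwarded unchanged, and each `Variable` argument additionally gets a runtime `Print` op:

```python
import paddle

@paddle.jit.to_static
def f(x):
    # Compile time: the whole line is printed with the given sep/end,
    # the Tensor appearing as its static Variable description.
    # Runtime: an inserted Print op dumps the actual values of `x`.
    print("value", x, sep=': ', end='\n\n')
    return x * 2

f(paddle.to_tensor([1.0, 2.0]))
```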
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from paddle.utils import gast
-from paddle.jit.dy2static.static_analysis import (
-    AstNodeWrapper,
-    StaticAnalysisVisitor,
-)
-
-from .base_transformer import (
-    BaseTransformer,
-)
-
-
-class PrintTransformer(BaseTransformer):
-    """
-    This class transforms python print function to fluid.layers.Print.
-    """
-
-    def __init__(self, wrapper_root):
-        assert isinstance(
-            wrapper_root, AstNodeWrapper
-        ), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
-        self.wrapper_root = wrapper_root
-        self.root = wrapper_root.node
-        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
-        self.node_to_wrapper_map = (
-            self.static_analysis_visitor.get_node_to_wrapper_map()
-        )
-
-    def transform(self):
-        self.visit(self.root)
-
-    # NOTE: deal with print in PY3
-    def visit_Call(self, node):
-        if isinstance(node.func, gast.Name) and node.func.id == 'print':
-            node = self._create_print_node(node.args)
-        return node
-
-    # NOTE: deal with print in PY2
-    def visit_Print(self, node):
-        convert_print_node = self._create_print_node(node.values)
-        return gast.Expr(value=convert_print_node)
-
-    def _create_print_node(self, print_args):
-        convert_print_func = gast.parse('_jst.Print').body[0].value
-        return gast.Call(func=convert_print_func, args=print_args, keywords=[])