未验证 提交 2d873008 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Replace paddle.jit.dy2stat with _jst (#42947)

* [Dy2Stat]Replace paddle.jit.dy2stat with _jst

* [Dy2Stat]Replace paddle.jit.dy2stat with _jst

* refine code style

* refine code style
上级 a76f2b33
......@@ -37,7 +37,7 @@ class AssertTransformer(gast.NodeTransformer):
def visit_Assert(self, node):
convert_assert_node = gast.parse(
'paddle.jit.dy2static.convert_assert({test}, {msg})'.format(
'_jst.convert_assert({test}, {msg})'.format(
test=ast_to_source_code(node.test),
msg=ast_to_source_code(node.msg)
if node.msg else "")).body[0].value
......
......@@ -71,7 +71,7 @@ class CallTransformer(gast.NodeTransformer):
if PDB_SET in func_str:
return node
new_func_str = "paddle.jit.dy2static.convert_call({})".format(func_str)
new_func_str = "_jst.convert_call({})".format(func_str)
new_func_ast = gast.parse(new_func_str).body[0].value
node.func = new_func_ast
......
......@@ -39,8 +39,8 @@ class CastTransformer(gast.NodeTransformer):
func_str = ast_to_source_code(node.func).strip()
if func_str in self._castable_type and len(node.args) > 0:
args_str = ast_to_source_code(node.args[0]).strip()
new_func_str = "paddle.jit.dy2static.convert_var_dtype({}, '{}')".format(
args_str, func_str)
new_func_str = "_jst.convert_var_dtype({}, '{}')".format(args_str,
func_str)
new_node = gast.parse(new_func_str).body[0].value
return new_node
......
......@@ -536,7 +536,7 @@ def create_convert_ifelse_node(return_name_ids,
return_vars = create_name_nodes(return_name_ids)
convert_ifelse_layer = gast.parse(
'paddle.jit.dy2static.convert_ifelse('
'_jst.convert_ifelse('
'{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'.
format(
pred=ast_to_source_code(pred),
......
......@@ -129,7 +129,7 @@ class ListTransformer(gast.NodeTransformer):
elif slice_is_num(target_node):
value_code = ast_to_source_code(node.value)
i = "paddle.cast(" \
"x=paddle.jit.dy2static.to_static_variable({})," \
"x=_jst.to_static_variable({})," \
"dtype='int64')".format(ast_to_source_code(slice_node))
assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
.format(target_name, value_code, i, target_name)
......@@ -252,7 +252,7 @@ class ListTransformer(gast.NodeTransformer):
# 2. pop stmt for a list or dict if len(args_str) == 1
# 3. pop stmt for a dict if len(args_str) == 2
if len(args_str) <= 2:
new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\
new_pop_str = "_jst.convert_pop({}, {})"\
.format(target_str, ",".join(args_str))
new_pop_node = gast.parse(new_pop_str).body[0].value
return new_pop_node
......
......@@ -57,8 +57,7 @@ class LogicalTransformer(gast.NodeTransformer):
self.generic_visit(node)
if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand)
new_node_str = "paddle.jit.dy2static.convert_logical_not({})".format(
arg)
new_node_str = "_jst.convert_logical_not({})".format(arg)
# NOTE: gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value
return new_node
......@@ -67,13 +66,12 @@ class LogicalTransformer(gast.NodeTransformer):
def visit_Compare(self, node):
self.generic_visit(node)
left_str = ast_to_source_code(node.left).strip()
if left_str.startswith("paddle.jit.dy2static.convert_var_shape"):
if left_str.startswith("_jst.convert_var_shape"):
# check left and comparators are all converted var shape
compare_arg_strs = left_str
for i, comparator in enumerate(node.comparators):
comparator_str = ast_to_source_code(comparator).strip()
if not comparator_str.startswith(
"paddle.jit.dy2static.convert_var_shape"):
if not comparator_str.startswith("_jst.convert_var_shape"):
return node
op_str = cmpop_node_to_str(node.ops[i])
compare_arg_strs += (", '" + op_str + "', " + comparator_str)
......@@ -81,7 +79,7 @@ class LogicalTransformer(gast.NodeTransformer):
# Now all left and comparators are converted shape
# Replace some comparsion operation because of difference between
# Python and Paddle
new_node_str = "paddle.jit.dy2static.convert_shape_compare({})".format(
new_node_str = "_jst.convert_shape_compare({})".format(
compare_arg_strs)
new_node = gast.parse(new_node_str).body[0].value
return new_node
......@@ -119,7 +117,7 @@ class LogicalTransformer(gast.NodeTransformer):
nodes = [pre_logic_node] + [post_logic_node]
args = [ast_to_source_code(child) for child in nodes]
new_node_str = "paddle.jit.dy2static.convert_logical_{}(lambda:{}, lambda:{})".format(
new_node_str = "_jst.convert_logical_{}(lambda:{}, lambda:{})".format(
api_type, args[0], args[1])
# NOTE: gast.parse return Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value
......
......@@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
else:
assign_loop_var_names.append(name)
while_func_name = "paddle.jit.dy2static.convert_while_loop"
while_func_name = "_jst.convert_while_loop"
while_node_str = "[{}] = {}({}, {}, [{}])".format(
",".join(assign_loop_var_names), while_func_name, condition_name,
body_name, ",".join(loop_var_names))
......
......@@ -50,6 +50,5 @@ class PrintTransformer(gast.NodeTransformer):
return gast.Expr(value=convert_print_node)
def _create_print_node(self, print_args):
convert_print_func = gast.parse(
'paddle.jit.dy2static.convert_print').body[0].value
convert_print_func = gast.parse('_jst.convert_print').body[0].value
return gast.Call(func=convert_print_func, args=print_args, keywords=[])
......@@ -336,7 +336,7 @@ class ReturnTransformer(gast.NodeTransformer):
# Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = True'
node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, True)".format(
node_str = "{} = _jst.create_bool_as_type({}, True)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())
......@@ -449,7 +449,7 @@ class ReturnTransformer(gast.NodeTransformer):
# Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = False'
node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format(
node_str = "{} = _jst.create_bool_as_type({}, False)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())
assign_false_node = gast.parse(node_str).body[0]
......
......@@ -42,7 +42,7 @@ def create_convert_shape_node(var_shape_node,
if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip())
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
convert_var_shape_func = "_jst.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
......@@ -59,14 +59,14 @@ def create_convert_shape_node(var_shape_node,
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
eval_exist_func = "_jst.eval_if_exist_else_none('{}', globals())".format(
api_shape_name)
args = [attr_shape_name, eval_exist_func]
if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip())
choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
",".join(args))
choose_shape_func = "_jst.choose_shape_attr_or_api({})".format(",".join(
args))
choose_shape_node = gast.parse(choose_shape_func).body[0].value
if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript(
......@@ -84,7 +84,7 @@ class ShapeAttributeTransformer(gast.NodeTransformer):
def visit_Attribute(self, node):
if node.attr == 'shape':
args = ast_to_source_code(node.value).strip()
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format(
convert_var_shape_func = "_jst.convert_var_shape_simple({})".format(
args)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
return api_shape_node
......
......@@ -185,6 +185,7 @@ def is_api_in_module(node, module_prefix):
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
import paddle.fluid.layers as layers
import paddle.jit.dy2static as _jst
from paddle.fluid.dygraph import to_variable
from paddle import to_tensor
......@@ -521,8 +522,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
def _inject_import_statements():
import_statements = [
"import paddle", "from paddle import Tensor",
"import paddle.fluid as fluid", "from typing import *",
"import numpy as np"
"import paddle.fluid as fluid", "import paddle.jit.dy2static as _jst",
"from typing import *", "import numpy as np"
]
return '\n'.join(import_statements) + '\n'
......@@ -1168,7 +1169,7 @@ class ForNodeVisitor(object):
else:
iter_var_name = ast_to_source_code(self.iter_node).strip()
convert_len_node_source_str = '{} = paddle.jit.dy2static.convert_len({})'.format(
convert_len_node_source_str = '{} = _jst.convert_len({})'.format(
self.iter_var_len_name, iter_var_name)
convert_len_node = gast.parse(convert_len_node_source_str).body[0]
......
......@@ -77,14 +77,12 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
def to_static_variable_gast_node(name):
func_code = "{} = paddle.jit.dy2static.to_static_variable({})".format(name,
name)
func_code = "{} = _jst.to_static_variable({})".format(name, name)
return gast.parse(func_code).body[0]
def create_static_variable_gast_node(name):
func_code = "{} = paddle.jit.dy2static\
.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
func_code = "{} = _jst.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
name, unique_name.generate(name))
return gast.parse(func_code).body[0]
......
......@@ -24,6 +24,7 @@ import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import CONVERSION_OPTIONS
from test_program_translator import get_source_code
import paddle.jit.dy2static as _jst
program_translator = ProgramTranslator()
......@@ -255,7 +256,7 @@ class TestDynamicToStaticCode(unittest.TestCase):
return get_source_code(self.answer_func)
def _get_transformed_code(self):
transformed_func = paddle.jit.dy2static.convert_call(self.func)
transformed_func = _jst.convert_call(self.func)
return get_source_code(transformed_func)
def test_code(self):
......@@ -275,7 +276,7 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode):
def set_answer_func(self):
class StaticCode():
def func_convert_then_not_to_static(x):
y = paddle.jit.dy2static.convert_call(func_not_to_static)(x)
y = _jst.convert_call(func_not_to_static)(x)
return y
self.answer_func = StaticCode.func_convert_then_not_to_static
......
......@@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase):
self.func = simple_func
def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7, 8]
self.static_abs_lineno_list = [7, 8, 9]
def set_dygraph_info(self):
self.line_num = 3
......@@ -149,7 +149,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo):
self.func = nested_func
def set_static_lineno(self):
self.static_abs_lineno_list = [6, 8, 9, 10, 11]
self.static_abs_lineno_list = [7, 9, 10, 11, 12]
def set_dygraph_info(self):
self.line_num = 5
......@@ -174,7 +174,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
self.func = decorated_func
def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7]
self.static_abs_lineno_list = [7, 8]
def set_dygraph_info(self):
self.line_num = 2
......@@ -208,7 +208,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo):
self.func = decorated_func2
def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7]
self.static_abs_lineno_list = [7, 8]
def set_dygraph_info(self):
self.line_num = 2
......
......@@ -27,6 +27,7 @@ from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
import paddle.jit.dy2static as _jst
from ifelse_simple_func import dyfunc_with_if_else
......@@ -76,40 +77,38 @@ class StaticCode1():
x_v = x_v + 1
return x_v
x_v = paddle.jit.dy2static.convert_ifelse(
x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ), (x_v, ))
__return_0 = paddle.jit.dy2static.create_bool_as_type(label is not None,
False)
__return_0 = _jst.create_bool_as_type(label is not None, False)
def true_fn_1(__return_0, __return_value_0, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_0 = paddle.jit.dy2static.create_bool_as_type(
label is not None, True)
__return_0 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = loss
return __return_0, __return_value_0
def false_fn_1(__return_0, __return_value_0):
return __return_0, __return_value_0
__return_0, __return_value_0 = (paddle.jit.dy2static.convert_ifelse(
__return_0, __return_value_0 = _jst.convert_ifelse(
label is not None, true_fn_1, false_fn_1,
(__return_0, __return_value_0, label, x_v),
(__return_0, __return_value_0), (__return_0, __return_value_0)))
(__return_0, __return_value_0), (__return_0, __return_value_0))
def true_fn_2(__return_0, __return_value_0, x_v):
__return_1 = paddle.jit.dy2static.create_bool_as_type(
paddle.jit.dy2static.convert_logical_not(__return_0), True)
__return_1 = _jst.create_bool_as_type(
_jst.convert_logical_not(__return_0), True)
__return_value_0 = x_v
return __return_value_0
def false_fn_2(__return_value_0):
return __return_value_0
__return_value_0 = paddle.jit.dy2static.convert_ifelse(
paddle.jit.dy2static.convert_logical_not(__return_0), true_fn_2,
false_fn_2, (__return_0, __return_value_0,
x_v), (__return_value_0, ), (__return_value_0, ))
__return_value_0 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_0), true_fn_2, false_fn_2,
(__return_0, __return_value_0,
x_v), (__return_value_0, ), (__return_value_0, ))
return __return_value_0
......@@ -128,40 +127,38 @@ class StaticCode2():
x_v = x_v + 1
return x_v
x_v = paddle.jit.dy2static.convert_ifelse(
x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
(x_v, ), (x_v, ))
__return_2 = paddle.jit.dy2static.create_bool_as_type(label is not None,
False)
__return_2 = _jst.create_bool_as_type(label is not None, False)
def true_fn_4(__return_2, __return_value_1, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_2 = paddle.jit.dy2static.create_bool_as_type(
label is not None, True)
__return_2 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = loss
return __return_2, __return_value_1
def false_fn_4(__return_2, __return_value_1):
return __return_2, __return_value_1
__return_2, __return_value_1 = paddle.jit.dy2static.convert_ifelse(
label is not None, true_fn_4, false_fn_4, (
__return_2, __return_value_1, label, x_v),
__return_2, __return_value_1 = _jst.convert_ifelse(
label is not None, true_fn_4, false_fn_4,
(__return_2, __return_value_1, label, x_v),
(__return_2, __return_value_1), (__return_2, __return_value_1))
def true_fn_5(__return_2, __return_value_1, x_v):
__return_3 = paddle.jit.dy2static.create_bool_as_type(
paddle.jit.dy2static.convert_logical_not(__return_2), True)
__return_3 = _jst.create_bool_as_type(
_jst.convert_logical_not(__return_2), True)
__return_value_1 = x_v
return __return_value_1
def false_fn_5(__return_value_1):
return __return_value_1
__return_value_1 = paddle.jit.dy2static.convert_ifelse(
paddle.jit.dy2static.convert_logical_not(__return_2), true_fn_5,
false_fn_5, (__return_2, __return_value_1,
x_v), (__return_value_1, ), (__return_value_1, ))
__return_value_1 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_2), true_fn_5, false_fn_5,
(__return_2, __return_value_1,
x_v), (__return_value_1, ), (__return_value_1, ))
return __return_value_1
......
......@@ -597,9 +597,11 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
class TestPaddleShape(unittest.TestCase):
def test_paddle_shape(self):
func = paddle.jit.to_static(dyfunc_len_paddle_shape)
self.assertEqual('paddle.shape(x)' in func.code, True)
func_code = func.code.replace("\n", "").replace(" ", "")
self.assertEqual('paddle.shape(x)' in func_code, True)
func = paddle.jit.to_static(dyfunc_dict_assign_shape)
self.assertEqual("__static_convert_var_shape_suffix" in func.code, True)
func_code = func.code.replace("\n", "").replace(" ", "")
self.assertEqual("__static_convert_var_shape_suffix" in func_code, True)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册