未验证 提交 2d873008 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Replace paddle.jit.dy2stat with _jst (#42947)

* [Dy2Stat]Replace paddle.jit.dy2stat with _jst

* [Dy2Stat]Replace paddle.jit.dy2stat with _jst

* refine code style

* refine code style
上级 a76f2b33
...@@ -37,7 +37,7 @@ class AssertTransformer(gast.NodeTransformer): ...@@ -37,7 +37,7 @@ class AssertTransformer(gast.NodeTransformer):
def visit_Assert(self, node): def visit_Assert(self, node):
convert_assert_node = gast.parse( convert_assert_node = gast.parse(
'paddle.jit.dy2static.convert_assert({test}, {msg})'.format( '_jst.convert_assert({test}, {msg})'.format(
test=ast_to_source_code(node.test), test=ast_to_source_code(node.test),
msg=ast_to_source_code(node.msg) msg=ast_to_source_code(node.msg)
if node.msg else "")).body[0].value if node.msg else "")).body[0].value
......
...@@ -71,7 +71,7 @@ class CallTransformer(gast.NodeTransformer): ...@@ -71,7 +71,7 @@ class CallTransformer(gast.NodeTransformer):
if PDB_SET in func_str: if PDB_SET in func_str:
return node return node
new_func_str = "paddle.jit.dy2static.convert_call({})".format(func_str) new_func_str = "_jst.convert_call({})".format(func_str)
new_func_ast = gast.parse(new_func_str).body[0].value new_func_ast = gast.parse(new_func_str).body[0].value
node.func = new_func_ast node.func = new_func_ast
......
...@@ -39,8 +39,8 @@ class CastTransformer(gast.NodeTransformer): ...@@ -39,8 +39,8 @@ class CastTransformer(gast.NodeTransformer):
func_str = ast_to_source_code(node.func).strip() func_str = ast_to_source_code(node.func).strip()
if func_str in self._castable_type and len(node.args) > 0: if func_str in self._castable_type and len(node.args) > 0:
args_str = ast_to_source_code(node.args[0]).strip() args_str = ast_to_source_code(node.args[0]).strip()
new_func_str = "paddle.jit.dy2static.convert_var_dtype({}, '{}')".format( new_func_str = "_jst.convert_var_dtype({}, '{}')".format(args_str,
args_str, func_str) func_str)
new_node = gast.parse(new_func_str).body[0].value new_node = gast.parse(new_func_str).body[0].value
return new_node return new_node
......
...@@ -536,7 +536,7 @@ def create_convert_ifelse_node(return_name_ids, ...@@ -536,7 +536,7 @@ def create_convert_ifelse_node(return_name_ids,
return_vars = create_name_nodes(return_name_ids) return_vars = create_name_nodes(return_name_ids)
convert_ifelse_layer = gast.parse( convert_ifelse_layer = gast.parse(
'paddle.jit.dy2static.convert_ifelse(' '_jst.convert_ifelse('
'{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'. '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'.
format( format(
pred=ast_to_source_code(pred), pred=ast_to_source_code(pred),
......
...@@ -129,7 +129,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -129,7 +129,7 @@ class ListTransformer(gast.NodeTransformer):
elif slice_is_num(target_node): elif slice_is_num(target_node):
value_code = ast_to_source_code(node.value) value_code = ast_to_source_code(node.value)
i = "paddle.cast(" \ i = "paddle.cast(" \
"x=paddle.jit.dy2static.to_static_variable({})," \ "x=_jst.to_static_variable({})," \
"dtype='int64')".format(ast_to_source_code(slice_node)) "dtype='int64')".format(ast_to_source_code(slice_node))
assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \ assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
.format(target_name, value_code, i, target_name) .format(target_name, value_code, i, target_name)
...@@ -252,7 +252,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -252,7 +252,7 @@ class ListTransformer(gast.NodeTransformer):
# 2. pop stmt for a list or dict if len(args_str) == 1 # 2. pop stmt for a list or dict if len(args_str) == 1
# 3. pop stmt for a dict if len(args_str) == 2 # 3. pop stmt for a dict if len(args_str) == 2
if len(args_str) <= 2: if len(args_str) <= 2:
new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\ new_pop_str = "_jst.convert_pop({}, {})"\
.format(target_str, ",".join(args_str)) .format(target_str, ",".join(args_str))
new_pop_node = gast.parse(new_pop_str).body[0].value new_pop_node = gast.parse(new_pop_str).body[0].value
return new_pop_node return new_pop_node
......
...@@ -57,8 +57,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -57,8 +57,7 @@ class LogicalTransformer(gast.NodeTransformer):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.op, gast.Not): if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand) arg = ast_to_source_code(node.operand)
new_node_str = "paddle.jit.dy2static.convert_logical_not({})".format( new_node_str = "_jst.convert_logical_not({})".format(arg)
arg)
# NOTE: gast.parse returns Module(body=[expr(value=...)]) # NOTE: gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value new_node = gast.parse(new_node_str).body[0].value
return new_node return new_node
...@@ -67,13 +66,12 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -67,13 +66,12 @@ class LogicalTransformer(gast.NodeTransformer):
def visit_Compare(self, node): def visit_Compare(self, node):
self.generic_visit(node) self.generic_visit(node)
left_str = ast_to_source_code(node.left).strip() left_str = ast_to_source_code(node.left).strip()
if left_str.startswith("paddle.jit.dy2static.convert_var_shape"): if left_str.startswith("_jst.convert_var_shape"):
# check left and comparators are all converted var shape # check left and comparators are all converted var shape
compare_arg_strs = left_str compare_arg_strs = left_str
for i, comparator in enumerate(node.comparators): for i, comparator in enumerate(node.comparators):
comparator_str = ast_to_source_code(comparator).strip() comparator_str = ast_to_source_code(comparator).strip()
if not comparator_str.startswith( if not comparator_str.startswith("_jst.convert_var_shape"):
"paddle.jit.dy2static.convert_var_shape"):
return node return node
op_str = cmpop_node_to_str(node.ops[i]) op_str = cmpop_node_to_str(node.ops[i])
compare_arg_strs += (", '" + op_str + "', " + comparator_str) compare_arg_strs += (", '" + op_str + "', " + comparator_str)
...@@ -81,7 +79,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -81,7 +79,7 @@ class LogicalTransformer(gast.NodeTransformer):
# Now all left and comparators are converted shape # Now all left and comparators are converted shape
# Replace some comparsion operation because of difference between # Replace some comparsion operation because of difference between
# Python and Paddle # Python and Paddle
new_node_str = "paddle.jit.dy2static.convert_shape_compare({})".format( new_node_str = "_jst.convert_shape_compare({})".format(
compare_arg_strs) compare_arg_strs)
new_node = gast.parse(new_node_str).body[0].value new_node = gast.parse(new_node_str).body[0].value
return new_node return new_node
...@@ -119,7 +117,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -119,7 +117,7 @@ class LogicalTransformer(gast.NodeTransformer):
nodes = [pre_logic_node] + [post_logic_node] nodes = [pre_logic_node] + [post_logic_node]
args = [ast_to_source_code(child) for child in nodes] args = [ast_to_source_code(child) for child in nodes]
new_node_str = "paddle.jit.dy2static.convert_logical_{}(lambda:{}, lambda:{})".format( new_node_str = "_jst.convert_logical_{}(lambda:{}, lambda:{})".format(
api_type, args[0], args[1]) api_type, args[0], args[1])
# NOTE: gast.parse return Module(body=[expr(...)]) # NOTE: gast.parse return Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value new_node = gast.parse(new_node_str).body[0].value
......
...@@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names): ...@@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
else: else:
assign_loop_var_names.append(name) assign_loop_var_names.append(name)
while_func_name = "paddle.jit.dy2static.convert_while_loop" while_func_name = "_jst.convert_while_loop"
while_node_str = "[{}] = {}({}, {}, [{}])".format( while_node_str = "[{}] = {}({}, {}, [{}])".format(
",".join(assign_loop_var_names), while_func_name, condition_name, ",".join(assign_loop_var_names), while_func_name, condition_name,
body_name, ",".join(loop_var_names)) body_name, ",".join(loop_var_names))
......
...@@ -50,6 +50,5 @@ class PrintTransformer(gast.NodeTransformer): ...@@ -50,6 +50,5 @@ class PrintTransformer(gast.NodeTransformer):
return gast.Expr(value=convert_print_node) return gast.Expr(value=convert_print_node)
def _create_print_node(self, print_args): def _create_print_node(self, print_args):
convert_print_func = gast.parse( convert_print_func = gast.parse('_jst.convert_print').body[0].value
'paddle.jit.dy2static.convert_print').body[0].value
return gast.Call(func=convert_print_func, args=print_args, keywords=[]) return gast.Call(func=convert_print_func, args=print_args, keywords=[])
...@@ -336,7 +336,7 @@ class ReturnTransformer(gast.NodeTransformer): ...@@ -336,7 +336,7 @@ class ReturnTransformer(gast.NodeTransformer):
# Here assume that the parent node of return is gast.If # Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If): if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = True' # Prepend control flow boolean nodes such as '__return@1 = True'
node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, True)".format( node_str = "{} = _jst.create_bool_as_type({}, True)".format(
return_name, return_name,
ast_to_source_code(parent_node_of_return.test).strip()) ast_to_source_code(parent_node_of_return.test).strip())
...@@ -449,7 +449,7 @@ class ReturnTransformer(gast.NodeTransformer): ...@@ -449,7 +449,7 @@ class ReturnTransformer(gast.NodeTransformer):
# Here assume that the parent node of return is gast.If # Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If): if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = False' # Prepend control flow boolean nodes such as '__return@1 = False'
node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format( node_str = "{} = _jst.create_bool_as_type({}, False)".format(
return_name, return_name,
ast_to_source_code(parent_node_of_return.test).strip()) ast_to_source_code(parent_node_of_return.test).strip())
assign_false_node = gast.parse(node_str).body[0] assign_false_node = gast.parse(node_str).body[0]
......
...@@ -42,7 +42,7 @@ def create_convert_shape_node(var_shape_node, ...@@ -42,7 +42,7 @@ def create_convert_shape_node(var_shape_node,
if slice_node is not None and slice_is_num(slice_node): if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip()) args.append(ast_to_source_code(slice_node.slice).strip())
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format( convert_var_shape_func = "_jst.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow) ",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value api_shape_node = gast.parse(convert_var_shape_func).body[0].value
...@@ -59,14 +59,14 @@ def create_convert_shape_node(var_shape_node, ...@@ -59,14 +59,14 @@ def create_convert_shape_node(var_shape_node,
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format( eval_exist_func = "_jst.eval_if_exist_else_none('{}', globals())".format(
api_shape_name) api_shape_name)
args = [attr_shape_name, eval_exist_func] args = [attr_shape_name, eval_exist_func]
if slice_node is not None and slice_is_num(slice_node): if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip()) args.append(ast_to_source_code(slice_node.slice).strip())
choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format( choose_shape_func = "_jst.choose_shape_attr_or_api({})".format(",".join(
",".join(args)) args))
choose_shape_node = gast.parse(choose_shape_func).body[0].value choose_shape_node = gast.parse(choose_shape_func).body[0].value
if slice_node is not None and not slice_is_num(slice_node): if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript( return gast.Subscript(
...@@ -84,7 +84,7 @@ class ShapeAttributeTransformer(gast.NodeTransformer): ...@@ -84,7 +84,7 @@ class ShapeAttributeTransformer(gast.NodeTransformer):
def visit_Attribute(self, node): def visit_Attribute(self, node):
if node.attr == 'shape': if node.attr == 'shape':
args = ast_to_source_code(node.value).strip() args = ast_to_source_code(node.value).strip()
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format( convert_var_shape_func = "_jst.convert_var_shape_simple({})".format(
args) args)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value api_shape_node = gast.parse(convert_var_shape_func).body[0].value
return api_shape_node return api_shape_node
......
...@@ -185,6 +185,7 @@ def is_api_in_module(node, module_prefix): ...@@ -185,6 +185,7 @@ def is_api_in_module(node, module_prefix):
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph import paddle.fluid.dygraph as dygraph
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import paddle.jit.dy2static as _jst
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle import to_tensor from paddle import to_tensor
...@@ -521,8 +522,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -521,8 +522,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
def _inject_import_statements(): def _inject_import_statements():
import_statements = [ import_statements = [
"import paddle", "from paddle import Tensor", "import paddle", "from paddle import Tensor",
"import paddle.fluid as fluid", "from typing import *", "import paddle.fluid as fluid", "import paddle.jit.dy2static as _jst",
"import numpy as np" "from typing import *", "import numpy as np"
] ]
return '\n'.join(import_statements) + '\n' return '\n'.join(import_statements) + '\n'
...@@ -1168,7 +1169,7 @@ class ForNodeVisitor(object): ...@@ -1168,7 +1169,7 @@ class ForNodeVisitor(object):
else: else:
iter_var_name = ast_to_source_code(self.iter_node).strip() iter_var_name = ast_to_source_code(self.iter_node).strip()
convert_len_node_source_str = '{} = paddle.jit.dy2static.convert_len({})'.format( convert_len_node_source_str = '{} = _jst.convert_len({})'.format(
self.iter_var_len_name, iter_var_name) self.iter_var_len_name, iter_var_name)
convert_len_node = gast.parse(convert_len_node_source_str).body[0] convert_len_node = gast.parse(convert_len_node_source_str).body[0]
......
...@@ -77,14 +77,12 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ...@@ -77,14 +77,12 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
def to_static_variable_gast_node(name): def to_static_variable_gast_node(name):
func_code = "{} = paddle.jit.dy2static.to_static_variable({})".format(name, func_code = "{} = _jst.to_static_variable({})".format(name, name)
name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
def create_static_variable_gast_node(name): def create_static_variable_gast_node(name):
func_code = "{} = paddle.jit.dy2static\ func_code = "{} = _jst.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
name, unique_name.generate(name)) name, unique_name.generate(name))
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
......
...@@ -24,6 +24,7 @@ import paddle.fluid as fluid ...@@ -24,6 +24,7 @@ import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import CONVERSION_OPTIONS from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import CONVERSION_OPTIONS
from test_program_translator import get_source_code from test_program_translator import get_source_code
import paddle.jit.dy2static as _jst
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
...@@ -255,7 +256,7 @@ class TestDynamicToStaticCode(unittest.TestCase): ...@@ -255,7 +256,7 @@ class TestDynamicToStaticCode(unittest.TestCase):
return get_source_code(self.answer_func) return get_source_code(self.answer_func)
def _get_transformed_code(self): def _get_transformed_code(self):
transformed_func = paddle.jit.dy2static.convert_call(self.func) transformed_func = _jst.convert_call(self.func)
return get_source_code(transformed_func) return get_source_code(transformed_func)
def test_code(self): def test_code(self):
...@@ -275,7 +276,7 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode): ...@@ -275,7 +276,7 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode):
def set_answer_func(self): def set_answer_func(self):
class StaticCode(): class StaticCode():
def func_convert_then_not_to_static(x): def func_convert_then_not_to_static(x):
y = paddle.jit.dy2static.convert_call(func_not_to_static)(x) y = _jst.convert_call(func_not_to_static)(x)
return y return y
self.answer_func = StaticCode.func_convert_then_not_to_static self.answer_func = StaticCode.func_convert_then_not_to_static
......
...@@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase): ...@@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase):
self.func = simple_func self.func = simple_func
def set_static_lineno(self): def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7, 8] self.static_abs_lineno_list = [7, 8, 9]
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 3 self.line_num = 3
...@@ -149,7 +149,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo): ...@@ -149,7 +149,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo):
self.func = nested_func self.func = nested_func
def set_static_lineno(self): def set_static_lineno(self):
self.static_abs_lineno_list = [6, 8, 9, 10, 11] self.static_abs_lineno_list = [7, 9, 10, 11, 12]
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 5 self.line_num = 5
...@@ -174,7 +174,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo): ...@@ -174,7 +174,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
self.func = decorated_func self.func = decorated_func
def set_static_lineno(self): def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7] self.static_abs_lineno_list = [7, 8]
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 2 self.line_num = 2
...@@ -208,7 +208,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo): ...@@ -208,7 +208,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo):
self.func = decorated_func2 self.func = decorated_func2
def set_static_lineno(self): def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7] self.static_abs_lineno_list = [7, 8]
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 2 self.line_num = 2
......
...@@ -27,6 +27,7 @@ from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator ...@@ -27,6 +27,7 @@ from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
import paddle.jit.dy2static as _jst
from ifelse_simple_func import dyfunc_with_if_else from ifelse_simple_func import dyfunc_with_if_else
...@@ -76,39 +77,37 @@ class StaticCode1(): ...@@ -76,39 +77,37 @@ class StaticCode1():
x_v = x_v + 1 x_v = x_v + 1
return x_v return x_v
x_v = paddle.jit.dy2static.convert_ifelse( x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ), (x_v, )) (x_v, ), (x_v, ))
__return_0 = paddle.jit.dy2static.create_bool_as_type(label is not None, __return_0 = _jst.create_bool_as_type(label is not None, False)
False)
def true_fn_1(__return_0, __return_value_0, label, x_v): def true_fn_1(__return_0, __return_value_0, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label) loss = fluid.layers.cross_entropy(x_v, label)
__return_0 = paddle.jit.dy2static.create_bool_as_type( __return_0 = _jst.create_bool_as_type(label is not None, True)
label is not None, True)
__return_value_0 = loss __return_value_0 = loss
return __return_0, __return_value_0 return __return_0, __return_value_0
def false_fn_1(__return_0, __return_value_0): def false_fn_1(__return_0, __return_value_0):
return __return_0, __return_value_0 return __return_0, __return_value_0
__return_0, __return_value_0 = (paddle.jit.dy2static.convert_ifelse( __return_0, __return_value_0 = _jst.convert_ifelse(
label is not None, true_fn_1, false_fn_1, label is not None, true_fn_1, false_fn_1,
(__return_0, __return_value_0, label, x_v), (__return_0, __return_value_0, label, x_v),
(__return_0, __return_value_0), (__return_0, __return_value_0))) (__return_0, __return_value_0), (__return_0, __return_value_0))
def true_fn_2(__return_0, __return_value_0, x_v): def true_fn_2(__return_0, __return_value_0, x_v):
__return_1 = paddle.jit.dy2static.create_bool_as_type( __return_1 = _jst.create_bool_as_type(
paddle.jit.dy2static.convert_logical_not(__return_0), True) _jst.convert_logical_not(__return_0), True)
__return_value_0 = x_v __return_value_0 = x_v
return __return_value_0 return __return_value_0
def false_fn_2(__return_value_0): def false_fn_2(__return_value_0):
return __return_value_0 return __return_value_0
__return_value_0 = paddle.jit.dy2static.convert_ifelse( __return_value_0 = _jst.convert_ifelse(
paddle.jit.dy2static.convert_logical_not(__return_0), true_fn_2, _jst.convert_logical_not(__return_0), true_fn_2, false_fn_2,
false_fn_2, (__return_0, __return_value_0, (__return_0, __return_value_0,
x_v), (__return_value_0, ), (__return_value_0, )) x_v), (__return_value_0, ), (__return_value_0, ))
return __return_value_0 return __return_value_0
...@@ -128,39 +127,37 @@ class StaticCode2(): ...@@ -128,39 +127,37 @@ class StaticCode2():
x_v = x_v + 1 x_v = x_v + 1
return x_v return x_v
x_v = paddle.jit.dy2static.convert_ifelse( x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ), fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
(x_v, ), (x_v, )) (x_v, ), (x_v, ))
__return_2 = paddle.jit.dy2static.create_bool_as_type(label is not None, __return_2 = _jst.create_bool_as_type(label is not None, False)
False)
def true_fn_4(__return_2, __return_value_1, label, x_v): def true_fn_4(__return_2, __return_value_1, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label) loss = fluid.layers.cross_entropy(x_v, label)
__return_2 = paddle.jit.dy2static.create_bool_as_type( __return_2 = _jst.create_bool_as_type(label is not None, True)
label is not None, True)
__return_value_1 = loss __return_value_1 = loss
return __return_2, __return_value_1 return __return_2, __return_value_1
def false_fn_4(__return_2, __return_value_1): def false_fn_4(__return_2, __return_value_1):
return __return_2, __return_value_1 return __return_2, __return_value_1
__return_2, __return_value_1 = paddle.jit.dy2static.convert_ifelse( __return_2, __return_value_1 = _jst.convert_ifelse(
label is not None, true_fn_4, false_fn_4, ( label is not None, true_fn_4, false_fn_4,
__return_2, __return_value_1, label, x_v), (__return_2, __return_value_1, label, x_v),
(__return_2, __return_value_1), (__return_2, __return_value_1)) (__return_2, __return_value_1), (__return_2, __return_value_1))
def true_fn_5(__return_2, __return_value_1, x_v): def true_fn_5(__return_2, __return_value_1, x_v):
__return_3 = paddle.jit.dy2static.create_bool_as_type( __return_3 = _jst.create_bool_as_type(
paddle.jit.dy2static.convert_logical_not(__return_2), True) _jst.convert_logical_not(__return_2), True)
__return_value_1 = x_v __return_value_1 = x_v
return __return_value_1 return __return_value_1
def false_fn_5(__return_value_1): def false_fn_5(__return_value_1):
return __return_value_1 return __return_value_1
__return_value_1 = paddle.jit.dy2static.convert_ifelse( __return_value_1 = _jst.convert_ifelse(
paddle.jit.dy2static.convert_logical_not(__return_2), true_fn_5, _jst.convert_logical_not(__return_2), true_fn_5, false_fn_5,
false_fn_5, (__return_2, __return_value_1, (__return_2, __return_value_1,
x_v), (__return_value_1, ), (__return_value_1, )) x_v), (__return_value_1, ), (__return_value_1, ))
return __return_value_1 return __return_value_1
......
...@@ -597,9 +597,11 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): ...@@ -597,9 +597,11 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
class TestPaddleShape(unittest.TestCase): class TestPaddleShape(unittest.TestCase):
def test_paddle_shape(self): def test_paddle_shape(self):
func = paddle.jit.to_static(dyfunc_len_paddle_shape) func = paddle.jit.to_static(dyfunc_len_paddle_shape)
self.assertEqual('paddle.shape(x)' in func.code, True) func_code = func.code.replace("\n", "").replace(" ", "")
self.assertEqual('paddle.shape(x)' in func_code, True)
func = paddle.jit.to_static(dyfunc_dict_assign_shape) func = paddle.jit.to_static(dyfunc_dict_assign_shape)
self.assertEqual("__static_convert_var_shape_suffix" in func.code, True) func_code = func.code.replace("\n", "").replace(" ", "")
self.assertEqual("__static_convert_var_shape_suffix" in func_code, True)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册