diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py index e2fcf4f2c2712eddc07cf5552738f7ae3aa01e0f..4d5076108cd31ad6c6cde811b49c6042f17a1c3f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py @@ -37,7 +37,7 @@ class AssertTransformer(gast.NodeTransformer): def visit_Assert(self, node): convert_assert_node = gast.parse( - 'paddle.jit.dy2static.convert_assert({test}, {msg})'.format( + '_jst.convert_assert({test}, {msg})'.format( test=ast_to_source_code(node.test), msg=ast_to_source_code(node.msg) if node.msg else "")).body[0].value diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py index a80dfa11402c5c434f278ab2964cf6efda41b106..c16d1ff17f70718c8450f93f9c728da512072e9d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py @@ -71,7 +71,7 @@ class CallTransformer(gast.NodeTransformer): if PDB_SET in func_str: return node - new_func_str = "paddle.jit.dy2static.convert_call({})".format(func_str) + new_func_str = "_jst.convert_call({})".format(func_str) new_func_ast = gast.parse(new_func_str).body[0].value node.func = new_func_ast diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py index ef2d062d2d0187de9a030de2a71dff66b7b51aad..50733e4d896e4a94d8d95e55878283d08f143196 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py @@ -39,8 +39,8 @@ class CastTransformer(gast.NodeTransformer): func_str = ast_to_source_code(node.func).strip() if func_str in self._castable_type and len(node.args) > 0: args_str = ast_to_source_code(node.args[0]).strip() - new_func_str = "paddle.jit.dy2static.convert_var_dtype({}, '{}')".format( - args_str, func_str) + new_func_str = "_jst.convert_var_dtype({}, '{}')".format(args_str, + func_str) new_node = gast.parse(new_func_str).body[0].value return new_node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 8fc5a691d212c22924574f09b07b2bb448d97541..157822430d23427be7a727b0a533aca93e3afa91 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -536,7 +536,7 @@ def create_convert_ifelse_node(return_name_ids, return_vars = create_name_nodes(return_name_ids) convert_ifelse_layer = gast.parse( - 'paddle.jit.dy2static.convert_ifelse(' + '_jst.convert_ifelse(' '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'. format( pred=ast_to_source_code(pred), diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py index e62def897d2eb4cb6653b9fdeb16ced16757618c..0951635162e5e6afdb4526e1b5233ee01b71c897 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py @@ -129,7 +129,7 @@ class ListTransformer(gast.NodeTransformer): elif slice_is_num(target_node): value_code = ast_to_source_code(node.value) i = "paddle.cast(" \ - "x=paddle.jit.dy2static.to_static_variable({})," \ + "x=_jst.to_static_variable({})," \ "dtype='int64')".format(ast_to_source_code(slice_node)) assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \ .format(target_name, value_code, i, target_name) @@ -252,7 +252,7 @@ class ListTransformer(gast.NodeTransformer): # 2. pop stmt for a list or dict if len(args_str) == 1 # 3. pop stmt for a dict if len(args_str) == 2 if len(args_str) <= 2: - new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\ + new_pop_str = "_jst.convert_pop({}, {})"\ .format(target_str, ",".join(args_str)) new_pop_node = gast.parse(new_pop_str).body[0].value return new_pop_node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py index e5c093f9a9255c3d20ec2db4bb91571c6cc57d6b..bd573521f1b4e5aabd91f0c9760bc1c14ba975f7 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py @@ -57,8 +57,7 @@ class LogicalTransformer(gast.NodeTransformer): self.generic_visit(node) if isinstance(node.op, gast.Not): arg = ast_to_source_code(node.operand) - new_node_str = "paddle.jit.dy2static.convert_logical_not({})".format( - arg) + new_node_str = "_jst.convert_logical_not({})".format(arg) # NOTE: gast.parse returns Module(body=[expr(value=...)]) new_node = gast.parse(new_node_str).body[0].value return new_node @@ -67,13 +66,12 @@ class LogicalTransformer(gast.NodeTransformer): def visit_Compare(self, node): self.generic_visit(node) left_str = ast_to_source_code(node.left).strip() - if left_str.startswith("paddle.jit.dy2static.convert_var_shape"): + if left_str.startswith("_jst.convert_var_shape"): # check left and comparators are all converted var shape compare_arg_strs = left_str for i, comparator in enumerate(node.comparators): comparator_str = ast_to_source_code(comparator).strip() - if not comparator_str.startswith( - "paddle.jit.dy2static.convert_var_shape"): + if not comparator_str.startswith("_jst.convert_var_shape"): return node op_str = cmpop_node_to_str(node.ops[i]) compare_arg_strs += (", '" + op_str + "', " + comparator_str) @@ -81,7 +79,7 @@ class LogicalTransformer(gast.NodeTransformer): # Now all left and comparators are converted shape # Replace some comparsion operation because of difference between # Python and Paddle - new_node_str = "paddle.jit.dy2static.convert_shape_compare({})".format( + new_node_str = "_jst.convert_shape_compare({})".format( compare_arg_strs) new_node = gast.parse(new_node_str).body[0].value return new_node @@ -119,7 +117,7 @@ class LogicalTransformer(gast.NodeTransformer): nodes = [pre_logic_node] + [post_logic_node] args = [ast_to_source_code(child) for child in nodes] - new_node_str = "paddle.jit.dy2static.convert_logical_{}(lambda:{}, lambda:{})".format( + new_node_str = "_jst.convert_logical_{}(lambda:{}, lambda:{})".format( api_type, args[0], args[1]) # NOTE: gast.parse return Module(body=[expr(...)]) new_node = gast.parse(new_node_str).body[0].value diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 4e5a3f7b7085137fbe59bc0dc362f7d21e7bc75a..8014a00bff98396888fb46759f9d2ede960a9788 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names): else: assign_loop_var_names.append(name) - while_func_name = "paddle.jit.dy2static.convert_while_loop" + while_func_name = "_jst.convert_while_loop" while_node_str = "[{}] = {}({}, {}, [{}])".format( ",".join(assign_loop_var_names), while_func_name, condition_name, body_name, ",".join(loop_var_names)) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py index 7960617369e3f24c7ba134dfb5a2de923afb538c..f045d01c99bab018afa193ec00cc22106ec7d776 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py @@ -50,6 +50,5 @@ class PrintTransformer(gast.NodeTransformer): return gast.Expr(value=convert_print_node) def _create_print_node(self, print_args): - convert_print_func = gast.parse( - 'paddle.jit.dy2static.convert_print').body[0].value + convert_print_func = gast.parse('_jst.convert_print').body[0].value return gast.Call(func=convert_print_func, args=print_args, keywords=[]) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py index 0c7a8bf421a1282bfa542c8eaf4a93db90a1ad90..8ac659dbead99173115dfc71c3c0c97b9cb435de 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py @@ -336,7 +336,7 @@ class ReturnTransformer(gast.NodeTransformer): # Here assume that the parent node of return is gast.If if isinstance(parent_node_of_return, gast.If): # Prepend control flow boolean nodes such as '__return@1 = True' - node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, True)".format( + node_str = "{} = _jst.create_bool_as_type({}, True)".format( return_name, ast_to_source_code(parent_node_of_return.test).strip()) @@ -449,7 +449,7 @@ class ReturnTransformer(gast.NodeTransformer): # Here assume that the parent node of return is gast.If if isinstance(parent_node_of_return, gast.If): # Prepend control flow boolean nodes such as '__return@1 = False' - node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format( + node_str = "{} = _jst.create_bool_as_type({}, False)".format( return_name, ast_to_source_code(parent_node_of_return.test).strip()) assign_false_node = gast.parse(node_str).body[0] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 7733226cc09f2d6e2f9bcb8403ed1be42aa75e0c..d5b23d2f53b1ce42451310fb66b21d911dca57d5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -42,7 +42,7 @@ def create_convert_shape_node(var_shape_node, if slice_node is not None and slice_is_num(slice_node): args.append(ast_to_source_code(slice_node.slice).strip()) - convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format( + convert_var_shape_func = "_jst.convert_var_shape({}, in_control_flow={})".format( ",".join(args), in_control_flow) api_shape_node = gast.parse(convert_var_shape_func).body[0].value @@ -59,14 +59,14 @@ def create_convert_shape_node(var_shape_node, def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): - eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format( + eval_exist_func = "_jst.eval_if_exist_else_none('{}', globals())".format( api_shape_name) args = [attr_shape_name, eval_exist_func] if slice_node is not None and slice_is_num(slice_node): args.append(ast_to_source_code(slice_node.slice).strip()) - choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format( - ",".join(args)) + choose_shape_func = "_jst.choose_shape_attr_or_api({})".format(",".join( + args)) choose_shape_node = gast.parse(choose_shape_func).body[0].value if slice_node is not None and not slice_is_num(slice_node): return gast.Subscript( @@ -84,7 +84,7 @@ class ShapeAttributeTransformer(gast.NodeTransformer): def visit_Attribute(self, node): if node.attr == 'shape': args = ast_to_source_code(node.value).strip() - convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format( + convert_var_shape_func = "_jst.convert_var_shape_simple({})".format( args) api_shape_node = gast.parse(convert_var_shape_func).body[0].value return api_shape_node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index d440e387da597d20300e50d7862dab80fb161c2a..91c2c5dc65aab7bb38cd957da88d7d9695769241 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -185,6 +185,7 @@ def is_api_in_module(node, module_prefix): import paddle.fluid as fluid import paddle.fluid.dygraph as dygraph import paddle.fluid.layers as layers + import paddle.jit.dy2static as _jst from paddle.fluid.dygraph import to_variable from paddle import to_tensor @@ -521,8 +522,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): def _inject_import_statements(): import_statements = [ "import paddle", "from paddle import Tensor", - "import paddle.fluid as fluid", "from typing import *", - "import numpy as np" + "import paddle.fluid as fluid", "import paddle.jit.dy2static as _jst", + "from typing import *", "import numpy as np" ] return '\n'.join(import_statements) + '\n' @@ -1168,7 +1169,7 @@ class ForNodeVisitor(object): else: iter_var_name = ast_to_source_code(self.iter_node).strip() - convert_len_node_source_str = '{} = paddle.jit.dy2static.convert_len({})'.format( + convert_len_node_source_str = '{} = _jst.convert_len({})'.format( self.iter_var_len_name, iter_var_name) convert_len_node = gast.parse(convert_len_node_source_str).body[0] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index 2cd6c5e43f7e1261d2bb48a8cbfc8151327c7dea..7ce5aede4995dcfa8a5be92907c3dae055848e05 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -77,14 +77,12 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): def to_static_variable_gast_node(name): - func_code = "{} = paddle.jit.dy2static.to_static_variable({})".format(name, - name) + func_code = "{} = _jst.to_static_variable({})".format(name, name) return gast.parse(func_code).body[0] def create_static_variable_gast_node(name): - func_code = "{} = paddle.jit.dy2static\ - .data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format( + func_code = "{} = _jst.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format( name, unique_name.generate(name)) return gast.parse(func_code).body[0] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py index fb918f4ae00edc2d67640f0c90bae767b3188431..2e2918facf896907f9d26ffe5bf64c6447cbc2cb 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py @@ -24,6 +24,7 @@ import paddle.fluid as fluid from paddle.fluid.dygraph import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import CONVERSION_OPTIONS from test_program_translator import get_source_code +import paddle.jit.dy2static as _jst program_translator = ProgramTranslator() @@ -255,7 +256,7 @@ class TestDynamicToStaticCode(unittest.TestCase): return get_source_code(self.answer_func) def _get_transformed_code(self): - transformed_func = paddle.jit.dy2static.convert_call(self.func) + transformed_func = _jst.convert_call(self.func) return get_source_code(transformed_func) def test_code(self): @@ -275,7 +276,7 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode): def set_answer_func(self): class StaticCode(): def func_convert_then_not_to_static(x): - y = paddle.jit.dy2static.convert_call(func_not_to_static)(x) + y = _jst.convert_call(func_not_to_static)(x) return y self.answer_func = StaticCode.func_convert_then_not_to_static diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py index e3d34184a38fc94b00d73aa9466880148207d475..8dac8889935904b54220adf1c5dd01ca63fe9236 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py @@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase): self.func = simple_func def set_static_lineno(self): - self.static_abs_lineno_list = [6, 7, 8] + self.static_abs_lineno_list = [7, 8, 9] def set_dygraph_info(self): self.line_num = 3 @@ -149,7 +149,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo): self.func = nested_func def set_static_lineno(self): - self.static_abs_lineno_list = [6, 8, 9, 10, 11] + self.static_abs_lineno_list = [7, 9, 10, 11, 12] def set_dygraph_info(self): self.line_num = 5 @@ -174,7 +174,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo): self.func = decorated_func def set_static_lineno(self): - self.static_abs_lineno_list = [6, 7] + self.static_abs_lineno_list = [7, 8] def set_dygraph_info(self): self.line_num = 2 @@ -208,7 +208,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo): self.func = decorated_func2 def set_static_lineno(self): - self.static_abs_lineno_list = [6, 7] + self.static_abs_lineno_list = [7, 8] def set_dygraph_info(self): self.line_num = 2 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index b0ffbac88fb42921a744d046665f56b1221537e1..4e90c73baa94404804f389a174d420f2f3a814db 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -27,6 +27,7 @@ from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code +import paddle.jit.dy2static as _jst from ifelse_simple_func import dyfunc_with_if_else @@ -76,40 +77,38 @@ class StaticCode1(): x_v = x_v + 1 return x_v - x_v = paddle.jit.dy2static.convert_ifelse( + x_v = _jst.convert_ifelse( fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), (x_v, ), (x_v, )) - __return_0 = paddle.jit.dy2static.create_bool_as_type(label is not None, - False) + __return_0 = _jst.create_bool_as_type(label is not None, False) def true_fn_1(__return_0, __return_value_0, label, x_v): loss = fluid.layers.cross_entropy(x_v, label) - __return_0 = paddle.jit.dy2static.create_bool_as_type( - label is not None, True) + __return_0 = _jst.create_bool_as_type(label is not None, True) __return_value_0 = loss return __return_0, __return_value_0 def false_fn_1(__return_0, __return_value_0): return __return_0, __return_value_0 - __return_0, __return_value_0 = (paddle.jit.dy2static.convert_ifelse( + __return_0, __return_value_0 = _jst.convert_ifelse( label is not None, true_fn_1, false_fn_1, (__return_0, __return_value_0, label, x_v), - (__return_0, __return_value_0), (__return_0, __return_value_0))) + (__return_0, __return_value_0), (__return_0, __return_value_0)) def true_fn_2(__return_0, __return_value_0, x_v): - __return_1 = paddle.jit.dy2static.create_bool_as_type( - paddle.jit.dy2static.convert_logical_not(__return_0), True) + __return_1 = _jst.create_bool_as_type( + _jst.convert_logical_not(__return_0), True) __return_value_0 = x_v return __return_value_0 def false_fn_2(__return_value_0): return __return_value_0 - __return_value_0 = paddle.jit.dy2static.convert_ifelse( - paddle.jit.dy2static.convert_logical_not(__return_0), true_fn_2, - false_fn_2, (__return_0, __return_value_0, - x_v), (__return_value_0, ), (__return_value_0, )) + __return_value_0 = _jst.convert_ifelse( + _jst.convert_logical_not(__return_0), true_fn_2, false_fn_2, + (__return_0, __return_value_0, + x_v), (__return_value_0, ), (__return_value_0, )) return __return_value_0 @@ -128,40 +127,38 @@ class StaticCode2(): x_v = x_v + 1 return x_v - x_v = paddle.jit.dy2static.convert_ifelse( + x_v = _jst.convert_ifelse( fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ), (x_v, ), (x_v, )) - __return_2 = paddle.jit.dy2static.create_bool_as_type(label is not None, - False) + __return_2 = _jst.create_bool_as_type(label is not None, False) def true_fn_4(__return_2, __return_value_1, label, x_v): loss = fluid.layers.cross_entropy(x_v, label) - __return_2 = paddle.jit.dy2static.create_bool_as_type( - label is not None, True) + __return_2 = _jst.create_bool_as_type(label is not None, True) __return_value_1 = loss return __return_2, __return_value_1 def false_fn_4(__return_2, __return_value_1): return __return_2, __return_value_1 - __return_2, __return_value_1 = paddle.jit.dy2static.convert_ifelse( - label is not None, true_fn_4, false_fn_4, ( - __return_2, __return_value_1, label, x_v), + __return_2, __return_value_1 = _jst.convert_ifelse( + label is not None, true_fn_4, false_fn_4, + (__return_2, __return_value_1, label, x_v), (__return_2, __return_value_1), (__return_2, __return_value_1)) def true_fn_5(__return_2, __return_value_1, x_v): - __return_3 = paddle.jit.dy2static.create_bool_as_type( - paddle.jit.dy2static.convert_logical_not(__return_2), True) + __return_3 = _jst.create_bool_as_type( + _jst.convert_logical_not(__return_2), True) __return_value_1 = x_v return __return_value_1 def false_fn_5(__return_value_1): return __return_value_1 - __return_value_1 = paddle.jit.dy2static.convert_ifelse( - paddle.jit.dy2static.convert_logical_not(__return_2), true_fn_5, - false_fn_5, (__return_2, __return_value_1, - x_v), (__return_value_1, ), (__return_value_1, )) + __return_value_1 = _jst.convert_ifelse( + _jst.convert_logical_not(__return_2), true_fn_5, false_fn_5, + (__return_2, __return_value_1, + x_v), (__return_value_1, ), (__return_value_1, )) return __return_value_1 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index d05be03bbfb193ae25ee039aef1608afdef4f585..5cf9d7749c3581fd7326d767e2dc0fa24e0fad91 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -597,9 +597,11 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): class TestPaddleShape(unittest.TestCase): def test_paddle_shape(self): func = paddle.jit.to_static(dyfunc_len_paddle_shape) - self.assertEqual('paddle.shape(x)' in func.code, True) + func_code = func.code.replace("\n", "").replace(" ", "") + self.assertEqual('paddle.shape(x)' in func_code, True) func = paddle.jit.to_static(dyfunc_dict_assign_shape) - self.assertEqual("__static_convert_var_shape_suffix" in func.code, True) + func_code = func.code.replace("\n", "").replace(" ", "") + self.assertEqual("__static_convert_var_shape_suffix" in func_code, True) if __name__ == '__main__':