From d82d5b8ce8044320813734f6a0bc8bca0b2eff71 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 27 Jun 2022 16:31:40 +0800 Subject: [PATCH] [Dy2Stat]Refactor convert_shape transformer logic (#43846) * [Dy2Stat]Refactor convert_shape transformer logic * clean usless unittest --- .../dygraph_to_static/convert_operators.py | 90 +---- .../dygraph_to_static/logical_transformer.py | 22 -- .../tensor_shape_transformer.py | 364 +----------------- .../test_convert_operators.py | 103 ----- .../dygraph_to_static/test_tensor_shape.py | 94 +++-- python/paddle/jit/dy2static/__init__.py | 5 +- .../paddle/jit/dy2static/convert_operators.py | 5 +- 7 files changed, 73 insertions(+), 610 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 0346e4f1ef..bf97362ab7 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -338,88 +338,30 @@ def convert_zip(*args): return zip(*args) -def convert_var_shape(x, idx=None, in_control_flow=False): +def convert_shape(x): """ A function representation of the shape of variable. """ - def has_negative(list_shape, idx=None): - if idx is not None: - return list_shape[idx] < 0 - - num_negative = sum([1 if i < 0 else 0 for i in list_shape]) - return num_negative > 0 - - # When `x` is Variable, call nn.shape(x) in following cases: - # (1) The shape of `x` is used in control flow condition. - # ``` - # if x.shape[0] == 1: - # y = XX - # ``` - # (2) The dim to be used is negative - # ``` - # # Assume x.shape=[3, -1] in static mode - # y = paddle.reshape(x, shape=[1, x.shape[1]]) - # ``` - if isinstance(x, Variable) and has_negative(x.shape, idx): - return nn.shape(x) if idx is None else nn.shape(x)[idx] - else: - return list(x.shape) if idx is None else x.shape[idx] + def has_negative(list_shape): + return any([x < 0 for x in list_shape]) + # When `x` is Variable: + # (1) if x.shape contains -1, such as [2, -1, 64], returns [2, var, 64], + # where var = paddle.shape(x)[1] + + # (2) if x.shape does not contains -1, return lsit(x.shape) directly -def convert_var_shape_simple(x): - """ - A function representation of the shape of variable. - """ if isinstance(x, Variable): - return nn.shape(x) + values = list(x.shape) + if has_negative(values): + shape_tensor = nn.shape(x) + for i, v in enumerate(values): + if v is None or v < 0: + values[i] = shape_tensor[i] + return values else: - # Use list() to make returned type consistant with dygraph - return list(x.shape) - - -def eval_if_exist_else_none(name, global_symbol_table): - """ - Args: - name([str]): Expression passed into `eval`. - local_symbol_table(dict): Specified from `globals()`. DO NOT use `locals()`, - because all STATIC_CONVERT_VAR_SHAPE_SUFFIX vars is - declared with keyword `global`. - - Returns: - Return the variable if found in global_symbol_table else None. - """ - try: - return eval(name, global_symbol_table) - except: - return None - - -def choose_shape_attr_or_api(attr_shape, api_shape, idx=None): - """ - Input can be attribute `x.shape` or api `shape(x)`, this function - chooses which one to return to use in dy2stat. - - Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer. - """ - if api_shape is None: - return attr_shape if idx is None else attr_shape[idx] - if not isinstance(attr_shape, (list, tuple)): - # some variables like x.shape[0] is no longer a list or tuple - if isinstance(attr_shape, int) and attr_shape < 0: - return api_shape if idx is None else api_shape[idx] - return attr_shape if idx is None else attr_shape[idx] - - def has_negative(list_shape, idx=None): - if idx is not None: - return list_shape[idx] < 0 - - num_negative = sum([1 if i < 0 else 0 for i in list_shape]) - return num_negative > 0 - - if has_negative(attr_shape, idx): - return api_shape if idx is None else api_shape[idx] - return attr_shape if idx is None else attr_shape[idx] + return x.shape def convert_shape_compare(left, *args): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py index bd573521f1..5cf8c61001 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py @@ -63,28 +63,6 @@ class LogicalTransformer(gast.NodeTransformer): return new_node return node - def visit_Compare(self, node): - self.generic_visit(node) - left_str = ast_to_source_code(node.left).strip() - if left_str.startswith("_jst.convert_var_shape"): - # check left and comparators are all converted var shape - compare_arg_strs = left_str - for i, comparator in enumerate(node.comparators): - comparator_str = ast_to_source_code(comparator).strip() - if not comparator_str.startswith("_jst.convert_var_shape"): - return node - op_str = cmpop_node_to_str(node.ops[i]) - compare_arg_strs += (", '" + op_str + "', " + comparator_str) - - # Now all left and comparators are converted shape - # Replace some comparsion operation because of difference between - # Python and Paddle - new_node_str = "_jst.convert_shape_compare({})".format( - compare_arg_strs) - new_node = gast.parse(new_node_str).body[0].value - return new_node - return node - def visit_BoolOp(self, node): self.generic_visit(node) if isinstance(node.op, gast.And): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index a04171dfc3..9c19b9fc25 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -25,77 +25,11 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor -STATIC_CONVERT_VAR_SHAPE_SUFFIX = '__static_convert_var_shape_suffix' - - -def create_convert_shape_node(var_shape_node, - slice_node=None, - in_control_flow=False): - assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript)) - - if isinstance(var_shape_node, gast.Attribute): - args = [ast_to_source_code(var_shape_node.value).strip()] - # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index or gast.Constant - # (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index or gast.Constant - # In (1) case, we pass the number as 'idx' argument in convert_var_shape - # In (2) case, we have to make it like `convert_var_shape(x)[slice]` - if slice_node is not None and slice_is_num(slice_node): - args.append(ast_to_source_code(slice_node.slice).strip()) - - convert_var_shape_func = "_jst.convert_var_shape({}, in_control_flow={})".format( - ",".join(args), in_control_flow) - api_shape_node = gast.parse(convert_var_shape_func).body[0].value - - if slice_node is not None and not slice_is_num(slice_node): - return gast.Subscript(value=api_shape_node, - slice=slice_node.slice, - ctx=gast.Load()) - return api_shape_node - - if isinstance(var_shape_node, gast.Subscript): - result_node = copy.deepcopy(var_shape_node) - result_node = create_convert_shape_node(result_node.value, result_node, - in_control_flow) - return result_node - - -def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): - eval_exist_func = "_jst.eval_if_exist_else_none('{}', globals())".format( - api_shape_name) - args = [attr_shape_name, eval_exist_func] - - if slice_node is not None and slice_is_num(slice_node): - args.append(ast_to_source_code(slice_node.slice).strip()) - choose_shape_func = "_jst.choose_shape_attr_or_api({})".format( - ",".join(args)) - choose_shape_node = gast.parse(choose_shape_func).body[0].value - if slice_node is not None and not slice_is_num(slice_node): - return gast.Subscript(value=choose_shape_node, - slice=slice_node.slice, - ctx=gast.Load()) - return choose_shape_node - - -class ShapeAttributeTransformer(gast.NodeTransformer): - """ - Input a node like `x.shape` or `x[4].shape[0]` (self._is_var_shape(node) is True), - return a new node changes input to static shape API like `convert_var_shape(x)`, - `convert_var_shape(x[4])[0]`. - """ - - def visit_Attribute(self, node): - if node.attr == 'shape': - args = ast_to_source_code(node.value).strip() - convert_var_shape_func = "_jst.convert_var_shape_simple({})".format( - args) - api_shape_node = gast.parse(convert_var_shape_func).body[0].value - return api_shape_node - return node - class TensorShapeTransformer(gast.NodeTransformer): """ - This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast. + This class transforms variable.shape into Static Graph Ast. + All 'xxx.shape' will be converted int '_jst.convert_shape(x)'. """ def __init__(self, wrapper_root): @@ -104,295 +38,17 @@ class TensorShapeTransformer(gast.NodeTransformer): ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer." self.wrapper_root = wrapper_root self.root = wrapper_root.node - # stores origin var string name (like "x" in `x = t.shape`) to - # static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`) - self.name_to_var_shape = {} - - self.static_analysis_visitor = StaticAnalysisVisitor(self.root) - self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( - ) - var_env = self.static_analysis_visitor.get_var_env() - var_env.cur_scope = var_env.cur_scope.sub_scopes[0] - self.scope_var_type_dict = var_env.get_scope_var_type() def transform(self): - SplitAssignTransformer(self.root).transform() self.visit(self.root) - def visit_Assign(self, node): - update_static_shape_var_node = self._update_name_to_var_shape(node) - if update_static_shape_var_node is not None: - ret = [node] - ret.extend(update_static_shape_var_node) - return ret - self.generic_visit(node) - return node - - def visit_Subscript(self, node): - value_node = node.value - slice_node = node.slice - if isinstance(value_node, gast.Name): - if value_node.id in self.name_to_var_shape and self._used_by_paddle_api( - value_node): - return create_choose_shape_node( - value_node.id, self.name_to_var_shape[value_node.id], node) - elif isinstance(value_node, gast.Attribute): - if self._used_by_paddle_api(value_node): - value_name = ast_to_source_code(value_node).strip() - if value_name in self.name_to_var_shape: - return create_choose_shape_node( - value_name, self.name_to_var_shape[value_name], node) - if self._is_var_shape(value_node): - return create_convert_shape_node(value_node, node) - return node - def visit_Attribute(self, node): - if self._used_by_paddle_api(node): - name = ast_to_source_code(node).strip() - if name in self.name_to_var_shape: - return create_choose_shape_node(name, - self.name_to_var_shape[name]) - if self._is_var_shape(node): - return create_convert_shape_node(node) - return node - - def visit_Name(self, node): - if node.id in self.name_to_var_shape: - if self._used_by_paddle_api(node): - return create_choose_shape_node(node.id, - self.name_to_var_shape[node.id]) - return node - - def visit_Call(self, node): - if is_paddle_api(node): - # Visit gast.Attribute and gast.Name to replace var.shape if necessary. - self.generic_visit(node) - # Don't have to visit other APIs - return node - - def visit_If(self, node): - # Call generic_visit first to transform var.shape that is used in Paddle Api. - self.generic_visit(node) - cond = node.test - self._transform_var_shape_if_necessary(cond) - - return node - - def visit_While(self, node): - self.generic_visit(node) - cond = node.test - self._transform_var_shape_if_necessary(cond) - return node - - def visit_For(self, node): - self.generic_visit(node) - iter = node.iter - self._transform_var_shape_if_necessary(iter) - - # If var.shape is a gast.Name and it is used in range function, transform it - self._transform_var_shape_in_range(node) + if node.attr == 'shape': + args = ast_to_source_code(node.value).strip() + # NOTE(dev): we can deal with paddle.shape in this case, but it's + # not pretty to modify into 'convert_shape(paddle)(x)[0]'. + if args != 'paddle': + convert_shape_func = "_jst.convert_shape({})".format(args) + shape_node = gast.parse(convert_shape_func).body[0].value + return shape_node return node - - def _transform_var_shape_in_range(self, node): - assert isinstance(node, gast.For) - if not isinstance(node.iter, gast.Call): - return False - if not isinstance(node.iter.func, gast.Name): - return False - if node.iter.func.id != "range": - return False - args = node.iter.args - for idx, arg in enumerate(args): - if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape: - args[idx] = create_choose_shape_node( - arg.id, self.name_to_var_shape[arg.id]) - return True - - def _transform_var_shape_if_necessary(self, cond): - need_transformed = False - for child_node in gast.walk(cond): - var_shape_node = None - if isinstance(child_node, - (gast.Name, gast.Attribute, gast.Subscript)): - child_name = ast_to_source_code(child_node).strip() - if child_name in self.name_to_var_shape: - var_shape_node = create_choose_shape_node( - child_name, self.name_to_var_shape[child_name]) - elif self._is_var_shape(child_node): - var_shape_node = child_node - - if var_shape_node: - need_transformed = True - wrapper_node = self.node_to_wrapper_map.get(child_node) - parent_node = wrapper_node.parent.node - for field, value in gast.iter_fields(parent_node): - if child_node is value: - if var_shape_node is child_node: - setattr( - parent_node, field, - create_convert_shape_node( - var_shape_node, None, True)) - else: - setattr(parent_node, field, var_shape_node) - break - # Some child_node may be in a list such as gast.Compare - if isinstance(value, list): - has_converted_shape = False - for i, v in enumerate(value): - if child_node is v: - if var_shape_node is child_node: - value[i] = create_convert_shape_node( - var_shape_node, None, True) - else: - value[i] = var_shape_node - has_converted_shape = True - break - if has_converted_shape: - break - return need_transformed - - def _used_by_paddle_api(self, node): - """ - Whether node is used in paddle api as arguments. - For example: - 1) Return True in `paddle.relu(x)` where node is `x` (gast.Name) - 2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute) - 3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute), - because the role of node is not arguments but `gast.Call.func`. - """ - assert isinstance(node, (gast.Attribute, gast.Name)) - wrapper_node = self.node_to_wrapper_map.get(node) - if not wrapper_node: - # Transformed node is not in node_to_wrapper_map - return False - while wrapper_node.parent: - parent_node = wrapper_node.parent.node - if isinstance(parent_node, gast.Call): - # Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`. - if is_paddle_api(parent_node) and parent_node.func != node: - return True - else: - return False - wrapper_node = wrapper_node.parent - - return False - - def _is_var_shape(self, node): - """ - Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise. - """ - if not isinstance(node, (gast.Attribute, gast.Subscript)): - return False - - if isinstance(node, gast.Attribute): - # If node is `paddle.shape`, return False - if (node.attr == 'shape' and isinstance(node.value, gast.Name) - and node.value.id == 'paddle'): - return False - if node.attr != 'shape': - return False - return True - - if isinstance(node, gast.Subscript): - value_node = node.value - return self._is_var_shape(value_node) - - return False - - def _update_name_to_var_shape(self, node): - assert isinstance(node, gast.Assign) - target_node = node.targets[0] - value_node = node.value - - update_static_shape_var_node = None - if isinstance(target_node, gast.Tuple): - update_static_shape_var_node = [] - for idx, element in enumerate(target_node.elts): - target_id = ast_to_source_code(element).strip() - - if isinstance(value_node, gast.Name): - if value_node.id in self.name_to_var_shape: - # TODO(zhhsplendid): is context a problem for the result node of gast.parse? - static_shape_var_name = unique_name.generate( - STATIC_CONVERT_VAR_SHAPE_SUFFIX) - static_shape_var_node = gast.parse( - static_shape_var_name).body[0].value - - static_shape_value_name = self.name_to_var_shape[ - value_node.id] - - sub_node_str = "{}[{}]".format(static_shape_value_name, - idx) - sub_node = gast.parse(sub_node_str).body[0].value - - update_static_shape_var_node.append( - gast.Assign(targets=[static_shape_var_node], - value=sub_node)) - - self.name_to_var_shape[ - target_id] = static_shape_var_name - if isinstance(value_node, gast.Attribute): - if self._is_var_shape(value_node): # eg: x.shape - static_shape_var_name = unique_name.generate( - STATIC_CONVERT_VAR_SHAPE_SUFFIX) - static_shape_var_node = gast.parse( - static_shape_var_name).body[0].value - - static_shape_value_node = copy.deepcopy(value_node) - # x.shape becomes convert_var_shape_simple(x) - static_shape_value_node = ShapeAttributeTransformer( - ).visit(static_shape_value_node) - - sub_node_str = "{}[{}]".format( - ast_to_source_code(static_shape_value_node).strip(), - idx) - sub_node = gast.parse(sub_node_str).body[0].value - # Note(Aurelius84): Becuase static_shape_var_name is used in - # eval_if_exist_else_none() as plain string, so it will not - # be pasred as argument in convert_loop/ifelse. We delcare it - # as global var because it has unique name. - update_static_shape_var_node.append( - gast.Global(names=[static_shape_var_name])) - - update_static_shape_var_node.append( - gast.Assign(targets=[static_shape_var_node], - value=sub_node)) - self.name_to_var_shape[ - target_id] = static_shape_var_name - return update_static_shape_var_node - else: - target_id = ast_to_source_code(target_node).strip() - if isinstance(value_node, gast.Name): - if value_node.id in self.name_to_var_shape: - static_shape_var_name = unique_name.generate( - STATIC_CONVERT_VAR_SHAPE_SUFFIX) - static_shape_var_node = gast.parse( - static_shape_var_name).body[0].value - static_shape_value_name = self.name_to_var_shape[ - value_node.id] - static_shape_value_node = gast.parse( - static_shape_value_name).body[0].value - - update_static_shape_var_node = [ - gast.Assign(targets=[static_shape_var_node], - value=static_shape_value_node) - ] - self.name_to_var_shape[target_id] = static_shape_var_name - elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0] - static_shape_var_name = unique_name.generate( - STATIC_CONVERT_VAR_SHAPE_SUFFIX) - static_shape_var_node = gast.parse( - static_shape_var_name).body[0].value - static_shape_value_node = copy.deepcopy(value_node) - # x.shape becomes convert_var_shape_simple(x) - static_shape_value_node = ShapeAttributeTransformer().visit( - static_shape_value_node) - # Declare static_shape_var_name as global var - update_static_shape_var_node = [ - gast.Global(names=[static_shape_var_name]) - ] - update_static_shape_var_node.append( - gast.Assign(targets=[static_shape_var_node], - value=static_shape_value_node)) - self.name_to_var_shape[target_id] = static_shape_var_name - return update_static_shape_var_node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 375873aa14..b5ccf735ce 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -15,7 +15,6 @@ import numpy as np import paddle import unittest -from paddle.jit.dy2static.convert_operators import eval_if_exist_else_none class CallNotExist(paddle.nn.Layer): @@ -143,108 +142,6 @@ class TestConvertShapeCompare(unittest.TestCase): paddle.disable_static() -class TestChooseShapeAttrOrApi(unittest.TestCase): - - def test_api_shape_is_none(self): - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api([1, 2], None), [1, 2]) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api([1], None), [1]) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api([2, 3, 7], None, 0), - 2) - - def test_attr_shape_is_int(self): - x = paddle.zeros([1, 3, 5, 7]) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api( - x.shape[0], - paddle.shape(x)[0]), 1) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api( - x.shape[1], - paddle.shape(x)[1]), 3) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api( - -1, - paddle.shape(x)[0]), - paddle.shape(x)[0]) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api( - -1, paddle.shape(x), 0), - paddle.shape(x)[0]) - - def test_positive_attr_shape(self): - x = paddle.zeros([1, 3, 5, 7]) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api( - x.shape, paddle.shape(x)), x.shape) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api( - x.shape, paddle.shape(x), 3), x.shape[3]) - - def test_negative_attr_shape(self): - x = paddle.zeros([7]) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api([-1], paddle.shape(x), - 0), - paddle.shape(x)[0]) - self.assertEqual( - paddle.jit.dy2static.choose_shape_attr_or_api([-1], - paddle.shape(x)), - paddle.shape(x)) - - -class TestEvaIfExistElseNone(unittest.TestCase): - - def test_globals(self): - global x_shape - x_shape = [1, 2, 3] - self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None) - self.assertEqual(eval_if_exist_else_none('x_shape', globals()), x_shape) - - del x_shape - - def test_enclosing_scope(self): - global x_shape - x_shape = [1, 2, 3] - - def foo(): - y_shape = [2, 3, 4] - self.assertEqual(eval_if_exist_else_none('x_shape', globals()), - [1, 2, 3]) - self.assertEqual(eval_if_exist_else_none('y_shape', locals()), - [2, 3, 4]) - - foo() - del x_shape - - def test_global_in_func(self): - x_shape = [1, 2, 3] - - def foo(): - global y_shape - y_shape = [2, 3, 4] - - self.assertEqual(eval_if_exist_else_none('y_shape', globals()), - [2, 3, 4]) - self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None) - self.assertEqual(eval_if_exist_else_none('x_shape', globals()), - None) - - del y_shape - - foo() - - def test_none(self): - - def foo(): - x_shape = [2, 3, 4] - return x_shape - - self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None) - - class ShapeLayer(paddle.nn.Layer): def __init__(self): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 70ff91eff5..54c7866503 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -275,6 +275,7 @@ class TestTensorShapeBasic(unittest.TestCase): self.expected_slice_op_num = 0 def _compute_op_num(self, program): + print(program) self.op_num = sum([len(block.ops) for block in program.blocks]) self.shape_op_num = 0 self.slice_op_num = 0 @@ -300,8 +301,8 @@ class TestTensorShapeBasic2(TestTensorShapeBasic): self.dygraph_func = dyfunc_tensor_shape_2 def _set_expected_op_num(self): - self.expected_op_num = 3 - self.expected_shape_op_num = 1 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 self.expected_slice_op_num = 0 @@ -323,9 +324,9 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): self.dygraph_func = dyfunc_tensor_shape_5 def _set_expected_op_num(self): - self.expected_op_num = 4 - self.expected_shape_op_num = 1 - self.expected_slice_op_num = 1 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestTensorShapeBasic6(TestTensorShapeBasic): @@ -334,21 +335,23 @@ class TestTensorShapeBasic6(TestTensorShapeBasic): self.dygraph_func = dyfunc_tensor_shape_6 def _set_expected_op_num(self): - self.expected_op_num = 4 - self.expected_shape_op_num = 1 - self.expected_slice_op_num = 1 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestTupleShape1(TestTensorShapeBasic): def init_test_func(self): self.input = numpy.ones((5, 7)).astype("int32") - self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")] + self.input_spec = [ + paddle.static.InputSpec(shape=[-1, -1], dtype="int32") + ] self.dygraph_func = dyfunc_tuple_shape_1 def _set_expected_op_num(self): - self.expected_op_num = 6 - self.expected_shape_op_num = 2 + self.expected_op_num = 5 + self.expected_shape_op_num = 1 self.expected_slice_op_num = 2 @@ -356,13 +359,15 @@ class TestTupleShape2(TestTensorShapeBasic): def init_test_func(self): self.input = numpy.ones((5, 7)).astype("int32") - self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")] + self.input_spec = [ + paddle.static.InputSpec(shape=[-1, 7], dtype="int32") + ] self.dygraph_func = dyfunc_tuple_shape_2 def _set_expected_op_num(self): self.expected_op_num = 5 self.expected_shape_op_num = 1 - self.expected_slice_op_num = 2 + self.expected_slice_op_num = 1 class TestTupleShape3(TestTensorShapeBasic): @@ -398,9 +403,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic): self.dygraph_func = dyfunc_with_if_1 def _set_expected_op_num(self): - self.expected_op_num = 4 - self.expected_shape_op_num = 1 - self.expected_slice_op_num = 1 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestTensorShapeInIf2(TestTensorShapeBasic): @@ -432,9 +437,9 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1): self.dygraph_func = dyfunc_with_for_2 def _set_expected_op_num(self): - self.expected_op_num = 9 - self.expected_shape_op_num = 1 - self.expected_slice_op_num = 1 + self.expected_op_num = 7 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestTensorShapeInFor3(TestTensorShapeInFor1): @@ -466,9 +471,9 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1): self.dygraph_func = dyfunc_with_while_2 def _set_expected_op_num(self): - self.expected_op_num = 6 - self.expected_shape_op_num = 1 - self.expected_slice_op_num = 1 + self.expected_op_num = 4 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestTensorShapeInWhile3(TestTensorShapeBasic): @@ -477,8 +482,8 @@ class TestTensorShapeInWhile3(TestTensorShapeBasic): self.dygraph_func = dyfunc_with_while_3 def _set_expected_op_num(self): - self.expected_op_num = 3 - self.expected_shape_op_num = 1 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 self.expected_slice_op_num = 0 @@ -510,9 +515,9 @@ class TestOpNumBasicWithTensorShape(unittest.TestCase): self.dygraph_func = dyfunc_tensor_shape_1 def _set_expected_op_num(self): - self.expected_op_num = 3 + self.expected_op_num = 5 self.expected_shape_op_num = 1 - self.expected_slice_op_num = 0 + self.expected_slice_op_num = 1 def _compute_op_num(self, program): self.op_num = sum([len(block.ops) for block in program.blocks]) @@ -541,9 +546,9 @@ class TestOpNumBasicWithTensorShape4(TestOpNumBasicWithTensorShape): self.dygraph_func = dyfunc_tensor_shape_4 def _set_expected_op_num(self): - self.expected_op_num = 6 - self.expected_shape_op_num = 1 - self.expected_slice_op_num = 1 + self.expected_op_num = 8 + self.expected_shape_op_num = 2 + self.expected_slice_op_num = 2 class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape): @@ -552,9 +557,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape): self.dygraph_func = dyfunc_tuple_shape_1 def _set_expected_op_num(self): - self.expected_op_num = 7 - self.expected_shape_op_num = 2 - self.expected_slice_op_num = 2 + self.expected_op_num = 5 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): @@ -563,9 +568,9 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): self.dygraph_func = dyfunc_with_if_1 def _set_expected_op_num(self): - self.expected_op_num = 28 + self.expected_op_num = 32 self.expected_shape_op_num = 4 - self.expected_slice_op_num = 2 + self.expected_slice_op_num = 4 class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape): @@ -594,13 +599,15 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic): def init_test_func(self): self.input = numpy.ones((2, 3)).astype("int32") - self.input_spec = [paddle.static.InputSpec(shape=[2, 3], dtype="int32")] + self.input_spec = [ + paddle.static.InputSpec(shape=[-1, 3], dtype="int32") + ] self.dygraph_func = dyfunc_change_shape_after_assign def _set_expected_op_num(self): - self.expected_op_num = 7 - self.expected_shape_op_num = 2 - self.expected_slice_op_num = 2 + self.expected_op_num = 6 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 def dyfunc_with_static_convert_var_shape(x): @@ -627,16 +634,5 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): func.concrete_program -class TestPaddleShape(unittest.TestCase): - - def test_paddle_shape(self): - func = paddle.jit.to_static(dyfunc_len_paddle_shape) - func_code = func.code.replace("\n", "").replace(" ", "") - self.assertEqual('paddle.shape(x)' in func_code, True) - func = paddle.jit.to_static(dyfunc_dict_assign_shape) - func_code = func.code.replace("\n", "").replace(" ", "") - self.assertEqual("__static_convert_var_shape_suffix" in func_code, True) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py index ebe3ba716f..7f20c00024 100644 --- a/python/paddle/jit/dy2static/__init__.py +++ b/python/paddle/jit/dy2static/__init__.py @@ -26,10 +26,7 @@ from .convert_operators import convert_pop # noqa: F401 from .convert_operators import convert_print # noqa: F401 from .convert_operators import convert_shape_compare # noqa: F401 from .convert_operators import convert_var_dtype # noqa: F401 -from .convert_operators import convert_var_shape # noqa: F401 -from .convert_operators import convert_var_shape_simple # noqa: F401 -from .convert_operators import eval_if_exist_else_none # noqa: F401 -from .convert_operators import choose_shape_attr_or_api # noqa: F401 +from .convert_operators import convert_shape # noqa: F401 from .convert_operators import convert_while_loop # noqa: F401 from .variable_trans_func import create_bool_as_type # noqa: F401 from .variable_trans_func import create_fill_constant_node # noqa: F401 diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 8d67e06d9b..59ffedef0a 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -24,10 +24,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop # from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print # noqa: F401 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare # noqa: F401 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype # noqa: F401 -from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape # noqa: F401 -from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape_simple # noqa: F401 -from ...fluid.dygraph.dygraph_to_static.convert_operators import eval_if_exist_else_none # noqa: F401 -from ...fluid.dygraph.dygraph_to_static.convert_operators import choose_shape_attr_or_api # noqa: F401 +from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape # noqa: F401 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop # noqa: F401 __all__ = [] -- GitLab