From ad55f609d5b4d8a2bc056d8f9175aa90cbec1dcd Mon Sep 17 00:00:00 2001
From: liym27 <33742067+liym27@users.noreply.github.com>
Date: Fri, 8 Jan 2021 12:10:16 +0800
Subject: [PATCH] [Dy2Stat] Don't convert to paddle.shape if var_x.shape is
 not negative (#29965)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. When x is a Variable, call nn.shape(x) only in the following cases:
   1) The shape of x is used in a control-flow condition.
   2) The dim to be used is negative.
2. When x is a Variable but x.shape or x.shape[idx] contains no negative
   value, don't convert to paddle.shape().
---
 .../dygraph_to_static/convert_operators.py    |  28 +++-
 .../tensor_shape_transformer.py               |  90 ++++++----
 .../dygraph_to_static/test_tensor_shape.py    | 154 +++++++++++++++++-
 3 files changed, 234 insertions(+), 38 deletions(-)

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
index 383ee9deb1..13574832bd 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
@@ -262,14 +262,34 @@ def convert_len(var):
         return len(var)
 
 
-def convert_var_shape(x):
+def convert_var_shape(x, idx=None, in_control_flow=False):
     """
     A function representation of the shape of variable.
     """
-    if isinstance(x, Variable):
-        return nn.shape(x)
+
+    def has_negative(list_shape, idx=None):
+        if idx is not None:
+            return list_shape[idx] < 0
+
+        num_negative = sum([1 if i < 0 else 0 for i in list_shape])
+        return num_negative > 0
+
+    # When `x` is a Variable, call nn.shape(x) in the following cases:
+    # (1) The shape of `x` is used in a control-flow condition.
+    # ```
+    # if x.shape[0] == 1:
+    #     y = XX
+    # ```
+    # (2) The dim to be used is negative.
+    # ```
+    # # Assume x.shape=[3, -1] in static mode
+    # y = paddle.reshape(x, shape=[1, x.shape[1]])
+    # ```
+    if isinstance(x, Variable) and (in_control_flow or
+                                    has_negative(x.shape, idx)):
+        return nn.shape(x) if idx is None else nn.shape(x)[idx]
     else:
-        return x.shape
+        return x.shape if idx is None else x.shape[idx]
 
 
 def convert_shape_compare(left, *args):
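As a quick illustration of the fallback path above, here is a minimal sketch. It only exercises the non-Variable branch of the patched `convert_var_shape`; the numpy input is an assumed stand-in, and the import path is inferred from the generated call string `paddle.jit.dy2static.convert_var_shape` used elsewhere in this patch.

```
import numpy

# Hedged sketch: for a non-Variable input, the patched convert_var_shape
# returns the plain Python shape (or a single dim) instead of inserting
# a `shape` op into the program.
from paddle.jit.dy2static import convert_var_shape

x = numpy.ones((5, 7))              # not a paddle Variable
print(convert_var_shape(x))         # (5, 7) -- no shape op emitted
print(convert_var_shape(x, idx=0))  # 5      -- no slice op emitted
```

The same Python-value result is returned for a Variable whose shape contains no negative dims, as long as the access is not inside a control-flow condition.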
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
index 1fd4e5b6c7..7c45c10a48 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
@@ -24,21 +24,26 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
 
 
-def create_convert_shape_node(var_shape_node):
+def create_convert_shape_node(var_shape_node,
+                              slice_node=None,
+                              in_control_flow=False):
     assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))
 
-    convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape"
-
     if isinstance(var_shape_node, gast.Attribute):
-        api_shape_node = gast.Call(
-            func=gast.parse(convert_var_shape_func).body[0].value,
-            args=[var_shape_node.value],
-            keywords=[])
+        args = [ast_to_source_code(var_shape_node.value).strip()]
+        if slice_node:
+            args.append(ast_to_source_code(slice_node).strip())
+
+        convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
+            ",".join(args), in_control_flow)
+
+        api_shape_node = gast.parse(convert_var_shape_func).body[0].value
         return api_shape_node
 
     if isinstance(var_shape_node, gast.Subscript):
         result_node = copy.deepcopy(var_shape_node)
-        result_node.value = create_convert_shape_node(result_node.value)
+        result_node = create_convert_shape_node(
+            result_node.value, result_node.slice, in_control_flow)
         return result_node
 
 
@@ -72,14 +77,30 @@ class TensorShapeTransformer(gast.NodeTransformer):
         self.generic_visit(node)
         return node
 
+    def visit_Subscript(self, node):
+        value_node = node.value
+        slice_node = node.slice
+        if isinstance(value_node, gast.Name):
+            if self._is_var_shape(value_node) and self._used_by_paddle_api(
+                    value_node):
+                var_shape_node = self.name_to_var_shape[value_node.id]
+                return create_convert_shape_node(var_shape_node, slice_node)
+
+        if isinstance(value_node, gast.Attribute):
+            if self._used_by_paddle_api(value_node) and self._is_var_shape(
+                    value_node):
+                return create_convert_shape_node(value_node, slice_node)
+
+        return node
+
     def visit_Attribute(self, node):
         if self._used_by_paddle_api(node):
-            if self.is_var_shape(node):
+            if self._is_var_shape(node):
                 return create_convert_shape_node(node)
         return node
 
     def visit_Name(self, node):
-        if node.id in self.name_to_var_shape:
+        if self._is_var_shape(node):
             if self._used_by_paddle_api(node):
                 var_shape_node = self.name_to_var_shape[node.id]
                 return create_convert_shape_node(var_shape_node)
@@ -126,7 +147,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
             return False
         args = node.iter.args
         for idx, arg in enumerate(args):
-            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
+            if isinstance(arg, gast.Name) and self._is_var_shape(arg):
                 args[idx] = create_convert_shape_node(self.name_to_var_shape[
                     arg.id])
 
@@ -136,11 +157,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
         need_transformed = False
         for child_node in gast.walk(cond):
             var_shape_node = None
-            if isinstance(child_node, (gast.Attribute)):
-                if self.is_var_shape(child_node):
+            if isinstance(child_node, (gast.Attribute, gast.Subscript)):
+                if self._is_var_shape(child_node):
                     var_shape_node = child_node
             elif isinstance(child_node, (gast.Name)):
-                if child_node.id in self.name_to_var_shape:
+                if self._is_var_shape(child_node):
                     var_shape_node = self.name_to_var_shape[child_node.id]
 
             if var_shape_node:
@@ -150,7 +171,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
                 for field, value in gast.iter_fields(parent_node):
                     if child_node is value:
                         setattr(parent_node, field,
-                                create_convert_shape_node(var_shape_node))
+                                create_convert_shape_node(var_shape_node, None,
+                                                          True))
                         break
                     # Some child_node may be in a list such as gast.Compare
                     if isinstance(value, list):
@@ -158,7 +180,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
                         for i, v in enumerate(value):
                             if child_node is v:
                                 value[i] = create_convert_shape_node(
-                                    var_shape_node)
+                                    var_shape_node, None, True)
                                 has_converted_shape = True
                                 break
                         if has_converted_shape:
@@ -182,24 +204,30 @@ class TensorShapeTransformer(gast.NodeTransformer):
 
         return False
 
-    def is_var_shape(self, node):
+    def _is_var_shape(self, node):
         """
-        Return True if node is like `x.shape`, return False otherwise.
+        Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
""" - assert isinstance(node, gast.Attribute) - - if node.attr != 'shape': + if not isinstance(node, (gast.Name, gast.Attribute, gast.Subscript)): return False - try: - value_id = node.value.id - except AttributeError: - return False + if isinstance(node, gast.Name) and node.id in self.name_to_var_shape: + return True + + if isinstance(node, gast.Attribute): + if node.attr != 'shape': + return False + + if not isinstance(node.value, gast.Name): + return False - if value_id in self.name_to_var_shape: return True - return True + if isinstance(node, gast.Subscript): + value_node = node.value + return self._is_var_shape(value_node) + + return False def _update_name_to_var_shape(self, node): assert isinstance(node, gast.Assign) @@ -223,7 +251,7 @@ class TensorShapeTransformer(gast.NodeTransformer): self.name_to_var_shape[target_id] = sub_node has_updated = True if isinstance(value_node, gast.Attribute): - if self.is_var_shape(value_node): # eg: x.shape + if self._is_var_shape(value_node): # eg: x.shape index_value_node = gast.Constant(value=idx, kind=None) slice_index_node = gast.Index(value=index_value_node) sub_node = gast.Subscript( @@ -238,17 +266,17 @@ class TensorShapeTransformer(gast.NodeTransformer): target_id = ast_to_source_code(target_node).strip() if isinstance(value_node, gast.Name): - if value_node.id in self.name_to_var_shape: + if self._is_var_shape(value_node): self.name_to_var_shape[target_id] = self.name_to_var_shape[ value_node.id] return True if isinstance(value_node, gast.Attribute): - if self.is_var_shape(value_node): # eg: x.shape + if self._is_var_shape(value_node): # eg: x.shape self.name_to_var_shape[target_id] = value_node return True if isinstance(value_node, gast.Subscript): if isinstance(value_node.value, gast.Attribute): - if self.is_var_shape(value_node.value): # eg: x.shape[0] + if self._is_var_shape(value_node.value): # eg: x.shape[0] self.name_to_var_shape[target_id] = value_node return True return False diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 53dbb07c97..dfc8d2429f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -192,11 +192,16 @@ class TestTensorShapeBasic(unittest.TestCase): self.input = numpy.ones(5).astype("int32") self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( ) else fluid.CPUPlace() + self._set_input_spec() + self._set_expected_op_num() self.init_test_func() def init_test_func(self): self.dygraph_func = dyfunc_tensor_shape_1 + def _set_input_spec(self): + self.input_spec = [paddle.static.InputSpec(shape=[5], dtype="int32")] + def _run(self, to_static): with fluid.dygraph.guard(): if to_static: @@ -219,6 +224,30 @@ class TestTensorShapeBasic(unittest.TestCase): msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, static_res)) + def _set_expected_op_num(self): + self.expected_op_num = 2 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 + + def _compute_op_num(self, program): + self.op_num = sum([len(block.ops) for block in program.blocks]) + self.shape_op_num = 0 + self.slice_op_num = 0 + + for block in program.blocks: + self.shape_op_num += len( + [op for op in block.ops if op.type == "shape"]) + self.slice_op_num += len( + [op for op in block.ops if op.type == "slice"]) + + def test_op_num(self): + static_layer = paddle.jit.to_static(self.dygraph_func, 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
index 53dbb07c97..dfc8d2429f 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
@@ -192,11 +192,16 @@ class TestTensorShapeBasic(unittest.TestCase):
         self.input = numpy.ones(5).astype("int32")
         self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
+        self._set_input_spec()
+        self._set_expected_op_num()
         self.init_test_func()
 
     def init_test_func(self):
         self.dygraph_func = dyfunc_tensor_shape_1
 
+    def _set_input_spec(self):
+        self.input_spec = [paddle.static.InputSpec(shape=[5], dtype="int32")]
+
     def _run(self, to_static):
         with fluid.dygraph.guard():
             if to_static:
@@ -219,6 +224,30 @@ class TestTensorShapeBasic(unittest.TestCase):
             msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
                                                              static_res))
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 2
+        self.expected_shape_op_num = 0
+        self.expected_slice_op_num = 0
+
+    def _compute_op_num(self, program):
+        self.op_num = sum([len(block.ops) for block in program.blocks])
+        self.shape_op_num = 0
+        self.slice_op_num = 0
+
+        for block in program.blocks:
+            self.shape_op_num += len(
+                [op for op in block.ops if op.type == "shape"])
+            self.slice_op_num += len(
+                [op for op in block.ops if op.type == "slice"])
+
+    def test_op_num(self):
+        static_layer = paddle.jit.to_static(self.dygraph_func,
+                                            self.input_spec)
+        program = static_layer.main_program
+        self._compute_op_num(program)
+        self.assertEqual(self.op_num, self.expected_op_num)
+        self.assertEqual(self.shape_op_num, self.expected_shape_op_num)
+        self.assertEqual(self.slice_op_num, self.expected_slice_op_num)
+
 
 class TestTensorShapeBasic2(TestTensorShapeBasic):
     def init_test_func(self):
@@ -243,12 +272,14 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
 class TestTupleShape1(TestTensorShapeBasic):
     def init_test_func(self):
         self.input = numpy.ones((5, 7)).astype("int32")
+        self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
         self.dygraph_func = dyfunc_tuple_shape_1
 
 
 class TestTupleShape2(TestTensorShapeBasic):
     def init_test_func(self):
         self.input = numpy.ones((5, 7)).astype("int32")
+        self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
         self.dygraph_func = dyfunc_tuple_shape_2
 
 
@@ -257,30 +288,45 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
     def init_test_func(self):
         self.dygraph_func = dyfunc_with_if_1
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 26
+        self.expected_shape_op_num = 2
+        self.expected_slice_op_num = 2
+
 
 class TestTensorShapeInIf2(TestTensorShapeBasic):
     def init_test_func(self):
         self.dygraph_func = dyfunc_with_if_2
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 14
+        self.expected_shape_op_num = 2
+        self.expected_slice_op_num = 1
+
 
 # 3. Tests with control flow for loop
 class TestTensorShapeInFor1(TestTensorShapeBasic):
     def init_test_func(self):
         self.dygraph_func = dyfunc_with_for_1
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 22
+        self.expected_shape_op_num = 3
+        self.expected_slice_op_num = 3
+
 
-class TestTensorShapeInFor2(TestTensorShapeBasic):
+class TestTensorShapeInFor2(TestTensorShapeInFor1):
     def init_test_func(self):
         self.dygraph_func = dyfunc_with_for_2
 
 
 # 4. Tests with control flow while loop
-class TestTensorShapeInWhile1(TestTensorShapeBasic):
+class TestTensorShapeInWhile1(TestTensorShapeInFor1):
     def init_test_func(self):
         self.dygraph_func = dyfunc_with_while_1
 
 
-class TestTensorShapeInWhile2(TestTensorShapeBasic):
+class TestTensorShapeInWhile2(TestTensorShapeInFor1):
     def init_test_func(self):
         self.dygraph_func = dyfunc_with_while_2
 
 
@@ -289,11 +335,113 @@ class TestTensorShapeInWhile3(TestTensorShapeBasic):
     def init_test_func(self):
         self.dygraph_func = dyfunc_with_while_3
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 25
+        self.expected_shape_op_num = 6
+        self.expected_slice_op_num = 3
+
 
 class TestTensorShapeInWhile4(TestTensorShapeBasic):
     def init_test_func(self):
         self.dygraph_func = dyfunc_with_while_4
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 5
+        self.expected_shape_op_num = 0
+        self.expected_slice_op_num = 0
+
+
+# 5. Test op num for negative dim
+class TestOpNumBasicWithTensorShape(unittest.TestCase):
+    def setUp(self):
+        self._set_input_spec()
+        self._set_test_func()
+        self._set_expected_op_num()
+
+    def _set_input_spec(self):
+        self.input_spec = [
+            paddle.static.InputSpec(
+                shape=[-1, 5], dtype="int32")
+        ]
+
+    def _set_test_func(self):
+        self.dygraph_func = dyfunc_tensor_shape_1
+
+    def _set_expected_op_num(self):
+        self.expected_op_num = 3
+        self.expected_shape_op_num = 1
+        self.expected_slice_op_num = 0
+
+    def _compute_op_num(self, program):
+        self.op_num = sum([len(block.ops) for block in program.blocks])
+        self.shape_op_num = 0
+        self.slice_op_num = 0
+
+        for block in program.blocks:
+            self.shape_op_num += len(
+                [op for op in block.ops if op.type == "shape"])
+            self.slice_op_num += len(
+                [op for op in block.ops if op.type == "slice"])
+
+    def test_op_num(self):
+        static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
+        program = static_layer.main_program
+
+        self._compute_op_num(program)
+        self.assertEqual(self.op_num, self.expected_op_num)
+        self.assertEqual(self.shape_op_num, self.expected_shape_op_num)
+        self.assertEqual(self.slice_op_num, self.expected_slice_op_num)
+
+
+class TestOpNumBasicWithTensorShape4(TestOpNumBasicWithTensorShape):
+    def _set_test_func(self):
+        self.dygraph_func = dyfunc_tensor_shape_4
+
+    def _set_expected_op_num(self):
+        self.expected_op_num = 6
+        self.expected_shape_op_num = 1
+        self.expected_slice_op_num = 1
+
+
+class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
+    def _set_test_func(self):
+        self.dygraph_func = dyfunc_tuple_shape_1
+
+    def _set_expected_op_num(self):
+        self.expected_op_num = 5
+        self.expected_shape_op_num = 1
+        self.expected_slice_op_num = 1
+
+
+class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
+    def _set_test_func(self):
+        self.dygraph_func = dyfunc_with_if_1
+
+    def _set_expected_op_num(self):
+        self.expected_op_num = 28
+        self.expected_shape_op_num = 4
+        self.expected_slice_op_num = 2
+
+
+class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
+    def _set_test_func(self):
+        self.dygraph_func = dyfunc_with_for_1
+
+    def _set_expected_op_num(self):
+        self.expected_op_num = 22
+        self.expected_shape_op_num = 3
+        self.expected_slice_op_num = 3
+
+
+class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
+    def _set_test_func(self):
+        self.dygraph_func = dyfunc_with_while_1
+
+    def _set_expected_op_num(self):
+        self.expected_op_num = 22
+        self.expected_shape_op_num = 3
+        self.expected_slice_op_num = 3
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
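For reference, a hedged end-to-end sketch of the behavior these op-count tests pin down. The function and InputSpec below are illustrative, not taken from the test file: with a fully specified shape, the converted program should contain no `shape` or `slice` ops.

```
import paddle

def dyfunc(x):
    # x.shape[0] is statically known (5), so Dy2Stat keeps it as a
    # Python int instead of emitting shape/slice ops.
    return paddle.reshape(x, shape=[x.shape[0], -1])

static_fn = paddle.jit.to_static(
    dyfunc,
    input_spec=[paddle.static.InputSpec(shape=[5, 7], dtype="float32")])
program = static_fn.main_program

op_types = [op.type for block in program.blocks for op in block.ops]
assert "shape" not in op_types and "slice" not in op_types
```

Changing the spec to `shape=[-1, 7]` should flip the outcome: the unknown dim forces a real `shape` op, which is exactly what the `TestOpNum*` cases above assert.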