未验证 提交 ad55f609 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Don't convert to paddle.shape if var_x.shape is not negetive (#29965)

1. When x is Variable, call nn.shape(x) only in following cases:
 1)The shape of x is used in control flow condition.
 2)The dim to be used is negetive
2. When x is Variable, but x.shape or x.shape[idx] doesn't contain negetive value, don't convert to paddle.shape()
上级 1f97d61c
...@@ -262,14 +262,34 @@ def convert_len(var): ...@@ -262,14 +262,34 @@ def convert_len(var):
return len(var) return len(var)
def convert_var_shape(x): def convert_var_shape(x, idx=None, in_control_flow=False):
""" """
A function representation of the shape of variable. A function representation of the shape of variable.
""" """
if isinstance(x, Variable):
return nn.shape(x) def has_negetive(list_shape, idx=None):
if idx is not None:
return list_shape[idx] < 0
num_negetive = sum([1 if i < 0 else 0 for i in list_shape])
return num_negetive > 0
# When `x` is Variable, call nn.shape(x) in following cases:
# (1) The shape of `x` is used in control flow condition.
# ```
# if x.shape[0] == 1:
# y = XX
# ```
# (2) The dim to be used is negetive
# ```
# # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# ```
if isinstance(x, Variable) and (in_control_flow or has_negetive(x.shape,
idx)):
return nn.shape(x) if idx is None else nn.shape(x)[idx]
else: else:
return x.shape return x.shape if idx is None else x.shape[idx]
def convert_shape_compare(left, *args): def convert_shape_compare(left, *args):
......
...@@ -24,21 +24,26 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe ...@@ -24,21 +24,26 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
def create_convert_shape_node(var_shape_node): def create_convert_shape_node(var_shape_node,
slice_node=None,
in_control_flow=False):
assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript)) assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape"
if isinstance(var_shape_node, gast.Attribute): if isinstance(var_shape_node, gast.Attribute):
api_shape_node = gast.Call( args = [ast_to_source_code(var_shape_node.value).strip()]
func=gast.parse(convert_var_shape_func).body[0].value, if slice_node:
args=[var_shape_node.value], args.append(ast_to_source_code(slice_node).strip())
keywords=[])
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
return api_shape_node return api_shape_node
if isinstance(var_shape_node, gast.Subscript): if isinstance(var_shape_node, gast.Subscript):
result_node = copy.deepcopy(var_shape_node) result_node = copy.deepcopy(var_shape_node)
result_node.value = create_convert_shape_node(result_node.value) result_node = create_convert_shape_node(
result_node.value, result_node.slice, in_control_flow)
return result_node return result_node
...@@ -72,14 +77,30 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -72,14 +77,30 @@ class TensorShapeTransformer(gast.NodeTransformer):
self.generic_visit(node) self.generic_visit(node)
return node return node
def visit_Subscript(self, node):
value_node = node.value
slice_node = node.slice
if isinstance(value_node, gast.Name):
if self._is_var_shape(value_node) and self._used_by_paddle_api(
value_node):
var_shape_node = self.name_to_var_shape[value_node.id]
return create_convert_shape_node(var_shape_node, slice_node)
if isinstance(value_node, gast.Attribute):
if self._used_by_paddle_api(value_node) and self._is_var_shape(
value_node):
return create_convert_shape_node(value_node, slice_node)
return node
def visit_Attribute(self, node): def visit_Attribute(self, node):
if self._used_by_paddle_api(node): if self._used_by_paddle_api(node):
if self.is_var_shape(node): if self._is_var_shape(node):
return create_convert_shape_node(node) return create_convert_shape_node(node)
return node return node
def visit_Name(self, node): def visit_Name(self, node):
if node.id in self.name_to_var_shape: if self._is_var_shape(node):
if self._used_by_paddle_api(node): if self._used_by_paddle_api(node):
var_shape_node = self.name_to_var_shape[node.id] var_shape_node = self.name_to_var_shape[node.id]
return create_convert_shape_node(var_shape_node) return create_convert_shape_node(var_shape_node)
...@@ -126,7 +147,7 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -126,7 +147,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False return False
args = node.iter.args args = node.iter.args
for idx, arg in enumerate(args): for idx, arg in enumerate(args):
if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape: if isinstance(arg, gast.Name) and self._is_var_shape(arg):
args[idx] = create_convert_shape_node(self.name_to_var_shape[ args[idx] = create_convert_shape_node(self.name_to_var_shape[
arg.id]) arg.id])
...@@ -136,11 +157,11 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -136,11 +157,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
need_transformed = False need_transformed = False
for child_node in gast.walk(cond): for child_node in gast.walk(cond):
var_shape_node = None var_shape_node = None
if isinstance(child_node, (gast.Attribute)): if isinstance(child_node, (gast.Attribute, gast.Subscript)):
if self.is_var_shape(child_node): if self._is_var_shape(child_node):
var_shape_node = child_node var_shape_node = child_node
elif isinstance(child_node, (gast.Name)): elif isinstance(child_node, (gast.Name)):
if child_node.id in self.name_to_var_shape: if self._is_var_shape(child_node):
var_shape_node = self.name_to_var_shape[child_node.id] var_shape_node = self.name_to_var_shape[child_node.id]
if var_shape_node: if var_shape_node:
...@@ -150,7 +171,8 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -150,7 +171,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
for field, value in gast.iter_fields(parent_node): for field, value in gast.iter_fields(parent_node):
if child_node is value: if child_node is value:
setattr(parent_node, field, setattr(parent_node, field,
create_convert_shape_node(var_shape_node)) create_convert_shape_node(var_shape_node, None,
True))
break break
# Some child_node may be in a list such as gast.Compare # Some child_node may be in a list such as gast.Compare
if isinstance(value, list): if isinstance(value, list):
...@@ -158,7 +180,7 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -158,7 +180,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
for i, v in enumerate(value): for i, v in enumerate(value):
if child_node is v: if child_node is v:
value[i] = create_convert_shape_node( value[i] = create_convert_shape_node(
var_shape_node) var_shape_node, None, True)
has_converted_shape = True has_converted_shape = True
break break
if has_converted_shape: if has_converted_shape:
...@@ -182,24 +204,30 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -182,24 +204,30 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False return False
def is_var_shape(self, node): def _is_var_shape(self, node):
""" """
Return True if node is like `x.shape`, return False otherwise. Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
""" """
assert isinstance(node, gast.Attribute) if not isinstance(node, (gast.Name, gast.Attribute, gast.Subscript)):
if node.attr != 'shape':
return False return False
try: if isinstance(node, gast.Name) and node.id in self.name_to_var_shape:
value_id = node.value.id return True
except AttributeError:
return False if isinstance(node, gast.Attribute):
if node.attr != 'shape':
return False
if not isinstance(node.value, gast.Name):
return False
if value_id in self.name_to_var_shape:
return True return True
return True if isinstance(node, gast.Subscript):
value_node = node.value
return self._is_var_shape(value_node)
return False
def _update_name_to_var_shape(self, node): def _update_name_to_var_shape(self, node):
assert isinstance(node, gast.Assign) assert isinstance(node, gast.Assign)
...@@ -223,7 +251,7 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -223,7 +251,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
self.name_to_var_shape[target_id] = sub_node self.name_to_var_shape[target_id] = sub_node
has_updated = True has_updated = True
if isinstance(value_node, gast.Attribute): if isinstance(value_node, gast.Attribute):
if self.is_var_shape(value_node): # eg: x.shape if self._is_var_shape(value_node): # eg: x.shape
index_value_node = gast.Constant(value=idx, kind=None) index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node) slice_index_node = gast.Index(value=index_value_node)
sub_node = gast.Subscript( sub_node = gast.Subscript(
...@@ -238,17 +266,17 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -238,17 +266,17 @@ class TensorShapeTransformer(gast.NodeTransformer):
target_id = ast_to_source_code(target_node).strip() target_id = ast_to_source_code(target_node).strip()
if isinstance(value_node, gast.Name): if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape: if self._is_var_shape(value_node):
self.name_to_var_shape[target_id] = self.name_to_var_shape[ self.name_to_var_shape[target_id] = self.name_to_var_shape[
value_node.id] value_node.id]
return True return True
if isinstance(value_node, gast.Attribute): if isinstance(value_node, gast.Attribute):
if self.is_var_shape(value_node): # eg: x.shape if self._is_var_shape(value_node): # eg: x.shape
self.name_to_var_shape[target_id] = value_node self.name_to_var_shape[target_id] = value_node
return True return True
if isinstance(value_node, gast.Subscript): if isinstance(value_node, gast.Subscript):
if isinstance(value_node.value, gast.Attribute): if isinstance(value_node.value, gast.Attribute):
if self.is_var_shape(value_node.value): # eg: x.shape[0] if self._is_var_shape(value_node.value): # eg: x.shape[0]
self.name_to_var_shape[target_id] = value_node self.name_to_var_shape[target_id] = value_node
return True return True
return False return False
...@@ -192,11 +192,16 @@ class TestTensorShapeBasic(unittest.TestCase): ...@@ -192,11 +192,16 @@ class TestTensorShapeBasic(unittest.TestCase):
self.input = numpy.ones(5).astype("int32") self.input = numpy.ones(5).astype("int32")
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace() ) else fluid.CPUPlace()
self._set_input_spec()
self._set_expected_op_num()
self.init_test_func() self.init_test_func()
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_1 self.dygraph_func = dyfunc_tensor_shape_1
def _set_input_spec(self):
self.input_spec = [paddle.static.InputSpec(shape=[5], dtype="int32")]
def _run(self, to_static): def _run(self, to_static):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
if to_static: if to_static:
...@@ -219,6 +224,30 @@ class TestTensorShapeBasic(unittest.TestCase): ...@@ -219,6 +224,30 @@ class TestTensorShapeBasic(unittest.TestCase):
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res)) static_res))
def _set_expected_op_num(self):
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
def _compute_op_num(self, program):
self.op_num = sum([len(block.ops) for block in program.blocks])
self.shape_op_num = 0
self.slice_op_num = 0
for block in program.blocks:
self.shape_op_num += len(
[op for op in block.ops if op.type == "shape"])
self.slice_op_num += len(
[op for op in block.ops if op.type == "slice"])
def test_op_num(self):
static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
program = static_layer.main_program
self._compute_op_num(program)
self.assertEqual(self.op_num, self.expected_op_num)
self.assertEqual(self.shape_op_num, self.expected_shape_op_num)
self.assertEqual(self.slice_op_num, self.expected_slice_op_num)
class TestTensorShapeBasic2(TestTensorShapeBasic): class TestTensorShapeBasic2(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
...@@ -243,12 +272,14 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): ...@@ -243,12 +272,14 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
class TestTupleShape1(TestTensorShapeBasic): class TestTupleShape1(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32") self.input = numpy.ones((5, 7)).astype("int32")
self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
self.dygraph_func = dyfunc_tuple_shape_1 self.dygraph_func = dyfunc_tuple_shape_1
class TestTupleShape2(TestTensorShapeBasic): class TestTupleShape2(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32") self.input = numpy.ones((5, 7)).astype("int32")
self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
self.dygraph_func = dyfunc_tuple_shape_2 self.dygraph_func = dyfunc_tuple_shape_2
...@@ -257,30 +288,45 @@ class TestTensorShapeInIf1(TestTensorShapeBasic): ...@@ -257,30 +288,45 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_if_1 self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self):
self.expected_op_num = 26
self.expected_shape_op_num = 2
self.expected_slice_op_num = 2
class TestTensorShapeInIf2(TestTensorShapeBasic): class TestTensorShapeInIf2(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_if_2 self.dygraph_func = dyfunc_with_if_2
def _set_expected_op_num(self):
self.expected_op_num = 14
self.expected_shape_op_num = 2
self.expected_slice_op_num = 1
# 3. Tests with control flow for loop # 3. Tests with control flow for loop
class TestTensorShapeInFor1(TestTensorShapeBasic): class TestTensorShapeInFor1(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_for_1 self.dygraph_func = dyfunc_with_for_1
def _set_expected_op_num(self):
self.expected_op_num = 22
self.expected_shape_op_num = 3
self.expected_slice_op_num = 3
class TestTensorShapeInFor2(TestTensorShapeBasic): class TestTensorShapeInFor2(TestTensorShapeInFor1):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_for_2 self.dygraph_func = dyfunc_with_for_2
# 4. Tests with control flow while loop # 4. Tests with control flow while loop
class TestTensorShapeInWhile1(TestTensorShapeBasic): class TestTensorShapeInWhile1(TestTensorShapeInFor1):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_while_1 self.dygraph_func = dyfunc_with_while_1
class TestTensorShapeInWhile2(TestTensorShapeBasic): class TestTensorShapeInWhile2(TestTensorShapeInFor1):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_while_2 self.dygraph_func = dyfunc_with_while_2
...@@ -289,11 +335,113 @@ class TestTensorShapeInWhile3(TestTensorShapeBasic): ...@@ -289,11 +335,113 @@ class TestTensorShapeInWhile3(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_while_3 self.dygraph_func = dyfunc_with_while_3
def _set_expected_op_num(self):
self.expected_op_num = 25
self.expected_shape_op_num = 6
self.expected_slice_op_num = 3
class TestTensorShapeInWhile4(TestTensorShapeBasic): class TestTensorShapeInWhile4(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_while_4 self.dygraph_func = dyfunc_with_while_4
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
# 5. Test op num for negetive dim
class TestOpNumBasicWithTensorShape(unittest.TestCase):
def setUp(self):
self._set_input_spec()
self._set_test_func()
self._set_expected_op_num()
def _set_input_spec(self):
self.input_spec = [
paddle.static.InputSpec(
shape=[-1, 5], dtype="int32")
]
def _set_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_1
def _set_expected_op_num(self):
self.expected_op_num = 3
self.expected_shape_op_num = 1
self.expected_slice_op_num = 0
def _compute_op_num(self, program):
self.op_num = sum([len(block.ops) for block in program.blocks])
self.shape_op_num = 0
self.slice_op_num = 0
for block in program.blocks:
self.shape_op_num += len(
[op for op in block.ops if op.type == "shape"])
self.slice_op_num += len(
[op for op in block.ops if op.type == "slice"])
def test_op_num(self):
static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec)
program = static_layer.main_program
self._compute_op_num(program)
self.assertEqual(self.op_num, self.expected_op_num)
self.assertEqual(self.shape_op_num, self.expected_shape_op_num)
self.assertEqual(self.slice_op_num, self.expected_slice_op_num)
class TestOpNumBasicWithTensorShape4(TestOpNumBasicWithTensorShape):
def _set_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_4
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
def _set_test_func(self):
self.dygraph_func = dyfunc_tuple_shape_1
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
def _set_test_func(self):
self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self):
self.expected_op_num = 28
self.expected_shape_op_num = 4
self.expected_slice_op_num = 2
class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
def _set_test_func(self):
self.dygraph_func = dyfunc_with_for_1
def _set_expected_op_num(self):
self.expected_op_num = 22
self.expected_shape_op_num = 3
self.expected_slice_op_num = 3
class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
def _set_test_func(self):
self.dygraph_func = dyfunc_with_while_1
def _set_expected_op_num(self):
self.expected_op_num = 22
self.expected_shape_op_num = 3
self.expected_slice_op_num = 3
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册