Unverified commit 3a72408f, authored by Huihuang Zheng, committed via GitHub

[Cherry-pick][Dy2stat] Cherry-pick of PR31082 and PR31051 (#31101)

Cherry-pick of #31051 and #31082
Parent commit: 29467060
......@@ -267,12 +267,12 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
A function representation of the shape of variable.
"""
def has_negetive(list_shape, idx=None):
def has_negative(list_shape, idx=None):
if idx is not None:
return list_shape[idx] < 0
num_negetive = sum([1 if i < 0 else 0 for i in list_shape])
return num_negetive > 0
num_negative = sum([1 if i < 0 else 0 for i in list_shape])
return num_negative > 0
# When `x` is Variable, call nn.shape(x) in following cases:
# (1) The shape of `x` is used in control flow condition.
......@@ -280,18 +280,62 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
# if x.shape[0] == 1:
# y = XX
# ```
# (2) The dim to be used is negetive
# (2) The dim to be used is negative
# ```
# # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# ```
if isinstance(x, Variable) and (in_control_flow or has_negetive(x.shape,
if isinstance(x, Variable) and (in_control_flow or has_negative(x.shape,
idx)):
return nn.shape(x) if idx is None else nn.shape(x)[idx]
else:
return x.shape if idx is None else x.shape[idx]
def convert_var_shape_simple(x):
    """
    Return the shape of `x` for dynamic-to-static converted code.

    A static-graph `Variable` may carry unknown (-1) dimensions, so the
    runtime `nn.shape` op is emitted for it; any other input simply
    exposes its `.shape` attribute.
    """
    if not isinstance(x, Variable):
        return x.shape
    return nn.shape(x)
def eval_if_exist_else_none(name):
    """
    Evaluate the variable named by the string `name` and return its value,
    or None when the name cannot be resolved.

    Used by dy2stat-generated code to probe whether a cached static-shape
    variable (e.g. "x__static_convert_var_shape_suffix") exists at runtime.

    NOTE(review): `eval` resolves `name` in *this* function's scope, so only
    this module's globals and builtins are visible -- caller locals are not.
    The generated names are assumed to be resolvable here; confirm against
    the code generator.
    """
    try:
        return eval(name)
    # Catch Exception, not a bare `except:`, so KeyboardInterrupt/SystemExit
    # still propagate; an unresolvable name falls back to None and lets
    # choose_shape_attr_or_api() use the attribute shape instead.
    except Exception:
        return None
def choose_shape_attr_or_api(attr_shape, api_shape, idx=None):
    """
    Choose between the static shape attribute and the shape API result.

    Input can be attribute `x.shape` or api `shape(x)`; this function
    chooses which one to use in dy2stat-generated code.

    Args:
        attr_shape: value of `x.shape` seen at transform time; a list/tuple,
            or a plain int when the user wrote something like `x.shape[3]`.
        api_shape: runtime result of `shape(x)`, or None when the cached
            static shape variable does not exist.
        idx: optional index into the shape.

    Returns:
        `attr_shape` (indexed by `idx` if given) when it is fully known,
        otherwise `api_shape`, because -1 in a static shape means the
        dimension is only known at run time.
    """

    def pick(shape):
        # Apply the optional index uniformly to whichever shape is chosen.
        return shape if idx is None else shape[idx]

    if api_shape is None:
        return pick(attr_shape)
    if not isinstance(attr_shape, (list, tuple)):
        # e.g. user wrote `x.shape[0]`, so attr_shape is a bare int; only a
        # negative value means "unknown, defer to the API shape".
        if isinstance(attr_shape, int) and attr_shape < 0:
            return pick(api_shape)
        return pick(attr_shape)
    # Any -1 entry (or a negative entry at `idx`) means the static shape is
    # unknown in the dimension we need.
    if idx is not None:
        negative = attr_shape[idx] < 0
    else:
        negative = any(dim < 0 for dim in attr_shape)
    return pick(api_shape) if negative else pick(attr_shape)
def convert_shape_compare(left, *args):
"""
A function handles comparison difference between Paddle and Python.
......
......@@ -17,12 +17,15 @@ from __future__ import print_function
import copy
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
STATIC_CONVERT_VAR_SHAPE_SUFFIX = '__static_convert_var_shape_suffix'
def create_convert_shape_node(var_shape_node,
slice_node=None,
......@@ -31,13 +34,20 @@ def create_convert_shape_node(var_shape_node,
if isinstance(var_shape_node, gast.Attribute):
args = [ast_to_source_code(var_shape_node.value).strip()]
if slice_node:
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if slice_node is not None and isinstance(slice_node, gast.Index):
args.append(ast_to_source_code(slice_node).strip())
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
if slice_node is not None and not isinstance(slice_node, gast.Index):
return gast.Subscript(
value=api_shape_node, slice=slice_node, ctx=gast.Load())
return api_shape_node
if isinstance(var_shape_node, gast.Subscript):
......@@ -47,6 +57,39 @@ def create_convert_shape_node(var_shape_node,
return result_node
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
    """
    Build an AST node calling choose_shape_attr_or_api(attr, api[, idx]).

    A `gast.Index` slice (a plain index such as 1 or -2) is passed through
    as the `idx` argument of the call; any other slice (e.g. a bound slice
    like 2:-1) is instead applied as a subscript on the call result.
    """
    is_index_slice = slice_node is not None and isinstance(slice_node,
                                                           gast.Index)
    call_args = [
        attr_shape_name,
        "paddle.jit.dy2static.eval_if_exist_else_none('{}')".format(
            api_shape_name),
    ]
    if is_index_slice:
        call_args.append(ast_to_source_code(slice_node).strip())
    call_source = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
        ",".join(call_args))
    result_node = gast.parse(call_source).body[0].value
    if slice_node is not None and not is_index_slice:
        result_node = gast.Subscript(
            value=result_node, slice=slice_node, ctx=gast.Load())
    return result_node
class ShapeAttributeTransformer(gast.NodeTransformer):
    """
    Rewrite `.shape` attribute accesses into the static shape API.

    Given a node such as `x.shape` or `x[4].shape[0]`, the visitor produces
    `paddle.jit.dy2static.convert_var_shape_simple(x)` (respectively
    `convert_var_shape_simple(x[4])[0]`) so the shape is resolved at graph
    run time.
    """

    def visit_Attribute(self, node):
        # Only `.shape` accesses are rewritten; all other attribute nodes
        # pass through unchanged (children are deliberately not visited,
        # matching the single-level replacement this transformer performs).
        if node.attr != 'shape':
            return node
        owner_source = ast_to_source_code(node.value).strip()
        call_source = "paddle.jit.dy2static.convert_var_shape_simple({})".format(
            owner_source)
        return gast.parse(call_source).body[0].value
class TensorShapeTransformer(gast.NodeTransformer):
"""
This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
......@@ -58,6 +101,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
# stores origin var string name (like "x" in `x = t.shape`) to
# static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
self.name_to_var_shape = {}
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
......@@ -72,8 +117,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
self.visit(self.root)
def visit_Assign(self, node):
if self._update_name_to_var_shape(node):
return node
update_static_shape_var_node = self._update_name_to_var_shape(node)
if update_static_shape_var_node is not None:
ret = [node]
ret.extend(update_static_shape_var_node)
return ret
self.generic_visit(node)
return node
......@@ -81,37 +129,44 @@ class TensorShapeTransformer(gast.NodeTransformer):
value_node = node.value
slice_node = node.slice
if isinstance(value_node, gast.Name):
if self._is_var_shape(value_node) and self._used_by_paddle_api(
value_node):
var_shape_node = self.name_to_var_shape[value_node.id]
return create_convert_shape_node(var_shape_node, slice_node)
if isinstance(value_node, gast.Attribute):
if self._used_by_paddle_api(value_node) and self._is_var_shape(
if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
value_node):
return create_convert_shape_node(value_node, slice_node)
return create_choose_shape_node(
value_node.id, self.name_to_var_shape[value_node.id],
slice_node)
elif isinstance(value_node, gast.Attribute):
if self._used_by_paddle_api(value_node):
value_name = ast_to_source_code(value_node).strip()
if value_name in self.name_to_var_shape:
return create_choose_shape_node(
value_name, self.name_to_var_shape[value_name],
slice_node)
if self._is_var_shape(value_node):
return create_convert_shape_node(value_node, slice_node)
return node
def visit_Attribute(self, node):
if self._used_by_paddle_api(node):
name = ast_to_source_code(node).strip()
if name in self.name_to_var_shape:
return create_choose_shape_node(name,
self.name_to_var_shape[name])
if self._is_var_shape(node):
return create_convert_shape_node(node)
return node
def visit_Name(self, node):
if self._is_var_shape(node):
if node.id in self.name_to_var_shape:
if self._used_by_paddle_api(node):
var_shape_node = self.name_to_var_shape[node.id]
return create_convert_shape_node(var_shape_node)
return create_choose_shape_node(node.id,
self.name_to_var_shape[node.id])
return node
def visit_Call(self, node):
assert isinstance(node, gast.Call)
if is_paddle_api(node):
# Visit gast.Attribute and gast.Name to replace var.shape if necessary.
self.generic_visit(node)
# Don't have to visit other APIs
return node
def visit_If(self, node):
......@@ -147,22 +202,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False
args = node.iter.args
for idx, arg in enumerate(args):
if isinstance(arg, gast.Name) and self._is_var_shape(arg):
args[idx] = create_convert_shape_node(self.name_to_var_shape[
arg.id])
if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
args[idx] = create_choose_shape_node(
arg.id, self.name_to_var_shape[arg.id])
return True
def _transform_var_shape_if_necessary(self, cond):
need_transformed = False
for child_node in gast.walk(cond):
var_shape_node = None
if isinstance(child_node, (gast.Attribute, gast.Subscript)):
if self._is_var_shape(child_node):
if isinstance(child_node,
(gast.Name, gast.Attribute, gast.Subscript)):
child_name = ast_to_source_code(child_node).strip()
if child_name in self.name_to_var_shape:
var_shape_node = create_choose_shape_node(
child_name, self.name_to_var_shape[child_name])
elif self._is_var_shape(child_node):
var_shape_node = child_node
elif isinstance(child_node, (gast.Name)):
if self._is_var_shape(child_node):
var_shape_node = self.name_to_var_shape[child_node.id]
if var_shape_node:
need_transformed = True
......@@ -170,17 +226,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
parent_node = wrapper_node.parent.node
for field, value in gast.iter_fields(parent_node):
if child_node is value:
setattr(parent_node, field,
create_convert_shape_node(var_shape_node, None,
True))
if var_shape_node is child_node:
setattr(parent_node, field,
create_convert_shape_node(var_shape_node,
None, True))
else:
setattr(parent_node, field, var_shape_node)
break
# Some child_node may be in a list such as gast.Compare
if isinstance(value, list):
has_converted_shape = False
for i, v in enumerate(value):
if child_node is v:
value[i] = create_convert_shape_node(
var_shape_node, None, True)
if var_shape_node is child_node:
value[i] = create_convert_shape_node(
var_shape_node, None, True)
else:
value[i] = var_shape_node
has_converted_shape = True
break
if has_converted_shape:
......@@ -217,19 +279,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
"""
Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
"""
if not isinstance(node, (gast.Name, gast.Attribute, gast.Subscript)):
if not isinstance(node, (gast.Attribute, gast.Subscript)):
return False
if isinstance(node, gast.Name) and node.id in self.name_to_var_shape:
return True
if isinstance(node, gast.Attribute):
if node.attr != 'shape':
return False
if not isinstance(node.value, gast.Name):
return False
return True
if isinstance(node, gast.Subscript):
......@@ -243,49 +298,94 @@ class TensorShapeTransformer(gast.NodeTransformer):
target_node = node.targets[0]
value_node = node.value
update_static_shape_var_node = None
if isinstance(target_node, gast.Tuple):
has_updated = False
update_static_shape_var_node = []
for idx, element in enumerate(target_node.elts):
target_id = ast_to_source_code(element).strip()
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_name = self.name_to_var_shape[
value_node.id]
static_shape_value_node = gast.parse(
static_shape_value_name).body[0].value
index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node)
var_shape_node = self.name_to_var_shape[value_node.id]
sub_node = gast.Subscript(
value=var_shape_node,
value=static_shape_value_node,
slice=slice_index_node,
ctx=gast.Load())
self.name_to_var_shape[target_id] = sub_node
has_updated = True
update_static_shape_var_node.append(
gast.Assign(
targets=[static_shape_var_node],
value=sub_node))
self.name_to_var_shape[
target_id] = static_shape_var_name
if isinstance(value_node, gast.Attribute):
if self._is_var_shape(value_node): # eg: x.shape
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_node = copy.deepcopy(value_node)
# x.shape becomes convert_var_shape_simple(x)
ShapeAttributeTransformer().visit(
static_shape_value_node)
index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node)
sub_node = gast.Subscript(
value=value_node,
value=static_shape_value_node,
slice=slice_index_node,
ctx=gast.Load())
self.name_to_var_shape[target_id] = sub_node
has_updated = True
return has_updated
update_static_shape_var_node.append(
gast.Assign(
targets=[static_shape_var_node],
value=sub_node))
self.name_to_var_shape[
target_id] = static_shape_var_name
return update_static_shape_var_node
else:
target_id = ast_to_source_code(target_node).strip()
if isinstance(value_node, gast.Name):
if self._is_var_shape(value_node):
self.name_to_var_shape[target_id] = self.name_to_var_shape[
if value_node.id in self.name_to_var_shape:
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_name = self.name_to_var_shape[
value_node.id]
return True
if isinstance(value_node, gast.Attribute):
if self._is_var_shape(value_node): # eg: x.shape
self.name_to_var_shape[target_id] = value_node
return True
if isinstance(value_node, gast.Subscript):
if isinstance(value_node.value, gast.Attribute):
if self._is_var_shape(value_node.value): # eg: x.shape[0]
self.name_to_var_shape[target_id] = value_node
return True
return False
static_shape_value_node = gast.parse(
static_shape_value_name).body[0].value
update_static_shape_var_node = [
gast.Assign(
targets=[static_shape_var_node],
value=static_shape_value_node)
]
self.name_to_var_shape[target_id] = static_shape_var_name
elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0]
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(static_shape_var_name).body[
0].value
static_shape_value_node = copy.deepcopy(value_node)
# x.shape becomes convert_var_shape_simple(x)
ShapeAttributeTransformer().visit(static_shape_value_node)
update_static_shape_var_node = [
gast.Assign(
targets=[static_shape_var_node],
value=static_shape_value_node)
]
self.name_to_var_shape[target_id] = static_shape_var_name
return update_static_shape_var_node
......@@ -136,5 +136,58 @@ class TestConvertShapeCompare(unittest.TestCase):
paddle.disable_static()
class TestChooseShapeAttrOrApi(unittest.TestCase):
    """Unit tests for paddle.jit.dy2static.choose_shape_attr_or_api."""

    def test_api_shape_is_none(self):
        # When api_shape is None the attribute shape is always returned.
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api([1, 2], None),
            [1, 2])
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api([1], None), [1])
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api([2, 3, 7], None, 0),
            2)

    def test_attr_shape_is_int(self):
        # attr_shape may be a bare int (user wrote x.shape[i]); a negative
        # value means "unknown", so the API shape is chosen instead.
        x = paddle.zeros([1, 3, 5, 7])
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api(x.shape[0],
                                                          paddle.shape(x)[0]),
            1)
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api(x.shape[1],
                                                          paddle.shape(x)[1]),
            3)
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api(-1,
                                                          paddle.shape(x)[0]),
            paddle.shape(x)[0])
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api(-1,
                                                          paddle.shape(x), 0),
            paddle.shape(x)[0])

    def test_positive_attr_shape(self):
        # Fully known (all non-negative) attribute shapes win over the API.
        x = paddle.zeros([1, 3, 5, 7])
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api(x.shape,
                                                          paddle.shape(x)),
            x.shape)
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api(x.shape,
                                                          paddle.shape(x), 3),
            x.shape[3])

    def test_negative_attr_shape(self):
        # A -1 entry in the attribute shape defers to the runtime API shape.
        x = paddle.zeros([7])
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api([-1],
                                                          paddle.shape(x), 0),
            paddle.shape(x)[0])
        self.assertEqual(
            paddle.jit.dy2static.choose_shape_attr_or_api([-1],
                                                          paddle.shape(x)),
            paddle.shape(x))
if __name__ == '__main__':
unittest.main()
......@@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x):
return res
def dyfunc_tensor_shape_6(x):
    # Fixture: exercises slicing a Tensor's shape with a bound slice (`[0:]`).
    # Dy2stat is expected to transform the shape read so that
    #   `res = fluid.layers.reshape(x, shape=s)` effectively becomes
    #   `res = fluid.layers.reshape(x, shape=
    #       paddle.jit.dy2static.convert_var_shape(x)[0:])`.
    # NOTE: the exact statement shape here is what the transformer tests
    # assert on -- do not restructure.
    x = fluid.dygraph.to_variable(x)
    s = x.shape[0:]
    res = fluid.layers.reshape(x, shape=s)
    return res
def dyfunc_tuple_shape_1(x):
x = paddle.to_tensor(x)
a, b = x.shape
......@@ -197,6 +207,14 @@ def dyfunc_with_while_4(x):
return x
def dyfunc_change_shape_after_assign(x):
    # Fixture: captures the shape *before* reshaping. The cached static
    # shape values (a, b) must keep the pre-assignment dimensions even
    # though `x` itself is rebound to a differently-shaped tensor below.
    x = paddle.to_tensor(x)
    a, b = x.shape
    x = paddle.reshape(x, shape=(-1, 1))
    res = paddle.reshape(x, shape=(b, a))
    return res
# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
......@@ -279,6 +297,21 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_5
def _set_expected_op_num(self):
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestTensorShapeBasic6(TestTensorShapeBasic):
    # Basic (no control flow) case driving dyfunc_tensor_shape_6, which
    # slices the shape with a bound slice.
    def init_test_func(self):
        self.dygraph_func = dyfunc_tensor_shape_6

    def _set_expected_op_num(self):
        # One runtime shape op and one slice op are expected to remain in
        # the converted static program.
        self.expected_op_num = 4
        self.expected_shape_op_num = 1
        self.expected_slice_op_num = 1
class TestTupleShape1(TestTensorShapeBasic):
def init_test_func(self):
......@@ -312,9 +345,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self):
self.expected_op_num = 26
self.expected_shape_op_num = 2
self.expected_slice_op_num = 2
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestTensorShapeInIf2(TestTensorShapeBasic):
......@@ -342,6 +375,11 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1):
def init_test_func(self):
self.dygraph_func = dyfunc_with_for_2
def _set_expected_op_num(self):
self.expected_op_num = 9
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
# 4. Tests with control flow while loop
class TestTensorShapeInWhile1(TestTensorShapeInFor1):
......@@ -353,15 +391,20 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1):
def init_test_func(self):
self.dygraph_func = dyfunc_with_while_2
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestTensorShapeInWhile3(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_while_3
def _set_expected_op_num(self):
self.expected_op_num = 25
self.expected_shape_op_num = 6
self.expected_slice_op_num = 3
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTensorShapeInWhile4(TestTensorShapeBasic):
......@@ -431,9 +474,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_tuple_shape_1
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
......@@ -441,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self):
self.expected_op_num = 28
self.expected_op_num = 19
self.expected_shape_op_num = 4
self.expected_slice_op_num = 2
......@@ -466,5 +509,17 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
self.expected_slice_op_num = 3
class TestChangeShapeAfterAssign(TestTensorShapeBasic):
    # Verifies that shape values captured before `x` is reassigned keep
    # their pre-assignment values (see dyfunc_change_shape_after_assign).
    def init_test_func(self):
        self.input = numpy.ones((2, 3)).astype("int32")
        self.input_spec = [paddle.static.InputSpec(shape=[2, 3], dtype="int32")]
        self.dygraph_func = dyfunc_change_shape_after_assign

    def _set_expected_op_num(self):
        # With a fixed [2, 3] input spec the shapes are fully static: no
        # runtime shape or slice ops should remain.
        self.expected_op_num = 3
        self.expected_shape_op_num = 0
        self.expected_slice_op_num = 0
if __name__ == '__main__':
unittest.main()
......@@ -25,11 +25,15 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape_simple #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import eval_if_exist_else_none #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import choose_shape_attr_or_api #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS
__all__ = [
'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len',
'convert_logical_and', 'convert_logical_not', 'convert_logical_or',
'convert_pop', 'convert_print', 'convert_shape_compare',
'convert_var_dtype', 'convert_var_shape', 'convert_while_loop'
'convert_var_dtype', 'convert_var_shape', 'convert_var_shape_simple',
'eval_if_exist_else_none', 'choose_shape_attr_or_api', 'convert_while_loop'
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册