未验证 提交 3a72408f 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Cherry-pick][Dy2stat] Cherry-pick of PR31082 and PR31051 (#31101)

Cherry-pick of #31051 and #31082
上级 29467060
...@@ -267,12 +267,12 @@ def convert_var_shape(x, idx=None, in_control_flow=False): ...@@ -267,12 +267,12 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
A function representation of the shape of variable. A function representation of the shape of variable.
""" """
def has_negetive(list_shape, idx=None): def has_negative(list_shape, idx=None):
if idx is not None: if idx is not None:
return list_shape[idx] < 0 return list_shape[idx] < 0
num_negetive = sum([1 if i < 0 else 0 for i in list_shape]) num_negative = sum([1 if i < 0 else 0 for i in list_shape])
return num_negetive > 0 return num_negative > 0
# When `x` is Variable, call nn.shape(x) in following cases: # When `x` is Variable, call nn.shape(x) in following cases:
# (1) The shape of `x` is used in control flow condition. # (1) The shape of `x` is used in control flow condition.
...@@ -280,18 +280,62 @@ def convert_var_shape(x, idx=None, in_control_flow=False): ...@@ -280,18 +280,62 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
# if x.shape[0] == 1: # if x.shape[0] == 1:
# y = XX # y = XX
# ``` # ```
# (2) The dim to be used is negetive # (2) The dim to be used is negative
# ``` # ```
# # Assume x.shape=[3, -1] in static mode # # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]]) # y = paddle.reshape(x, shape=[1, x.shape[1]])
# ``` # ```
if isinstance(x, Variable) and (in_control_flow or has_negetive(x.shape, if isinstance(x, Variable) and (in_control_flow or has_negative(x.shape,
idx)): idx)):
return nn.shape(x) if idx is None else nn.shape(x)[idx] return nn.shape(x) if idx is None else nn.shape(x)[idx]
else: else:
return x.shape if idx is None else x.shape[idx] return x.shape if idx is None else x.shape[idx]
def convert_var_shape_simple(x):
"""
A function representation of the shape of variable.
"""
if isinstance(x, Variable):
return nn.shape(x)
else:
return x.shape
def eval_if_exist_else_none(name):
try:
return eval(name)
except:
return None
def choose_shape_attr_or_api(attr_shape, api_shape, idx=None):
"""
Input can be attribute `x.shape` or api `shape(x)`, this function
chooses which one to return to use in dy2stat.
Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer.
"""
if api_shape is None:
return attr_shape if idx is None else attr_shape[idx]
if not isinstance(attr_shape, (list, tuple)):
# some variables like x.shape[0] is no longer a list or tuple
if isinstance(attr_shape, int) and attr_shape < 0:
return api_shape if idx is None else api_shape[idx]
return attr_shape if idx is None else attr_shape[idx]
def has_negative(list_shape, idx=None):
if idx is not None:
return list_shape[idx] < 0
num_negative = sum([1 if i < 0 else 0 for i in list_shape])
return num_negative > 0
if has_negative(attr_shape, idx):
return api_shape if idx is None else api_shape[idx]
return attr_shape if idx is None else attr_shape[idx]
def convert_shape_compare(left, *args): def convert_shape_compare(left, *args):
""" """
A function handles comparison difference between Paddle and Python. A function handles comparison difference between Paddle and Python.
......
...@@ -17,12 +17,15 @@ from __future__ import print_function ...@@ -17,12 +17,15 @@ from __future__ import print_function
import copy import copy
import gast import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
STATIC_CONVERT_VAR_SHAPE_SUFFIX = '__static_convert_var_shape_suffix'
def create_convert_shape_node(var_shape_node, def create_convert_shape_node(var_shape_node,
slice_node=None, slice_node=None,
...@@ -31,13 +34,20 @@ def create_convert_shape_node(var_shape_node, ...@@ -31,13 +34,20 @@ def create_convert_shape_node(var_shape_node,
if isinstance(var_shape_node, gast.Attribute): if isinstance(var_shape_node, gast.Attribute):
args = [ast_to_source_code(var_shape_node.value).strip()] args = [ast_to_source_code(var_shape_node.value).strip()]
if slice_node: # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if slice_node is not None and isinstance(slice_node, gast.Index):
args.append(ast_to_source_code(slice_node).strip()) args.append(ast_to_source_code(slice_node).strip())
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format( convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow) ",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value api_shape_node = gast.parse(convert_var_shape_func).body[0].value
if slice_node is not None and not isinstance(slice_node, gast.Index):
return gast.Subscript(
value=api_shape_node, slice=slice_node, ctx=gast.Load())
return api_shape_node return api_shape_node
if isinstance(var_shape_node, gast.Subscript): if isinstance(var_shape_node, gast.Subscript):
...@@ -47,6 +57,39 @@ def create_convert_shape_node(var_shape_node, ...@@ -47,6 +57,39 @@ def create_convert_shape_node(var_shape_node,
return result_node return result_node
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}')".format(
api_shape_name)
args = [attr_shape_name, eval_exist_func]
if slice_node is not None and isinstance(slice_node, gast.Index):
args.append(ast_to_source_code(slice_node).strip())
choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
",".join(args))
choose_shape_node = gast.parse(choose_shape_func).body[0].value
if slice_node is not None and not isinstance(slice_node, gast.Index):
return gast.Subscript(
value=choose_shape_node, slice=slice_node, ctx=gast.Load())
return choose_shape_node
class ShapeAttributeTransformer(gast.NodeTransformer):
"""
Input a node like `x.shape` or `x[4].shape[0]` (self._is_var_shape(node) is True),
return a new node changes input to static shape API like `convert_var_shape(x)`,
`convert_var_shape(x[4])[0]`.
"""
def visit_Attribute(self, node):
if node.attr == 'shape':
args = ast_to_source_code(node.value).strip()
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format(
args)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
return api_shape_node
return node
class TensorShapeTransformer(gast.NodeTransformer): class TensorShapeTransformer(gast.NodeTransformer):
""" """
This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast. This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
...@@ -58,6 +101,8 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -58,6 +101,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer." ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self.wrapper_root = wrapper_root self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
# stores origin var string name (like "x" in `x = t.shape`) to
# static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
self.name_to_var_shape = {} self.name_to_var_shape = {}
self.static_analysis_visitor = StaticAnalysisVisitor(self.root) self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
...@@ -72,8 +117,11 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -72,8 +117,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
self.visit(self.root) self.visit(self.root)
def visit_Assign(self, node): def visit_Assign(self, node):
if self._update_name_to_var_shape(node): update_static_shape_var_node = self._update_name_to_var_shape(node)
return node if update_static_shape_var_node is not None:
ret = [node]
ret.extend(update_static_shape_var_node)
return ret
self.generic_visit(node) self.generic_visit(node)
return node return node
...@@ -81,37 +129,44 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -81,37 +129,44 @@ class TensorShapeTransformer(gast.NodeTransformer):
value_node = node.value value_node = node.value
slice_node = node.slice slice_node = node.slice
if isinstance(value_node, gast.Name): if isinstance(value_node, gast.Name):
if self._is_var_shape(value_node) and self._used_by_paddle_api( if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
value_node):
var_shape_node = self.name_to_var_shape[value_node.id]
return create_convert_shape_node(var_shape_node, slice_node)
if isinstance(value_node, gast.Attribute):
if self._used_by_paddle_api(value_node) and self._is_var_shape(
value_node): value_node):
return create_choose_shape_node(
value_node.id, self.name_to_var_shape[value_node.id],
slice_node)
elif isinstance(value_node, gast.Attribute):
if self._used_by_paddle_api(value_node):
value_name = ast_to_source_code(value_node).strip()
if value_name in self.name_to_var_shape:
return create_choose_shape_node(
value_name, self.name_to_var_shape[value_name],
slice_node)
if self._is_var_shape(value_node):
return create_convert_shape_node(value_node, slice_node) return create_convert_shape_node(value_node, slice_node)
return node return node
def visit_Attribute(self, node): def visit_Attribute(self, node):
if self._used_by_paddle_api(node): if self._used_by_paddle_api(node):
name = ast_to_source_code(node).strip()
if name in self.name_to_var_shape:
return create_choose_shape_node(name,
self.name_to_var_shape[name])
if self._is_var_shape(node): if self._is_var_shape(node):
return create_convert_shape_node(node) return create_convert_shape_node(node)
return node return node
def visit_Name(self, node): def visit_Name(self, node):
if self._is_var_shape(node): if node.id in self.name_to_var_shape:
if self._used_by_paddle_api(node): if self._used_by_paddle_api(node):
var_shape_node = self.name_to_var_shape[node.id] return create_choose_shape_node(node.id,
return create_convert_shape_node(var_shape_node) self.name_to_var_shape[node.id])
return node return node
def visit_Call(self, node): def visit_Call(self, node):
assert isinstance(node, gast.Call)
if is_paddle_api(node): if is_paddle_api(node):
# Visit gast.Attribute and gast.Name to replace var.shape if necessary. # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
self.generic_visit(node) self.generic_visit(node)
# Don't have to visit other APIs
return node return node
def visit_If(self, node): def visit_If(self, node):
...@@ -147,22 +202,23 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -147,22 +202,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False return False
args = node.iter.args args = node.iter.args
for idx, arg in enumerate(args): for idx, arg in enumerate(args):
if isinstance(arg, gast.Name) and self._is_var_shape(arg): if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
args[idx] = create_convert_shape_node(self.name_to_var_shape[ args[idx] = create_choose_shape_node(
arg.id]) arg.id, self.name_to_var_shape[arg.id])
return True return True
def _transform_var_shape_if_necessary(self, cond): def _transform_var_shape_if_necessary(self, cond):
need_transformed = False need_transformed = False
for child_node in gast.walk(cond): for child_node in gast.walk(cond):
var_shape_node = None var_shape_node = None
if isinstance(child_node, (gast.Attribute, gast.Subscript)): if isinstance(child_node,
if self._is_var_shape(child_node): (gast.Name, gast.Attribute, gast.Subscript)):
child_name = ast_to_source_code(child_node).strip()
if child_name in self.name_to_var_shape:
var_shape_node = create_choose_shape_node(
child_name, self.name_to_var_shape[child_name])
elif self._is_var_shape(child_node):
var_shape_node = child_node var_shape_node = child_node
elif isinstance(child_node, (gast.Name)):
if self._is_var_shape(child_node):
var_shape_node = self.name_to_var_shape[child_node.id]
if var_shape_node: if var_shape_node:
need_transformed = True need_transformed = True
...@@ -170,17 +226,23 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -170,17 +226,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
parent_node = wrapper_node.parent.node parent_node = wrapper_node.parent.node
for field, value in gast.iter_fields(parent_node): for field, value in gast.iter_fields(parent_node):
if child_node is value: if child_node is value:
if var_shape_node is child_node:
setattr(parent_node, field, setattr(parent_node, field,
create_convert_shape_node(var_shape_node, None, create_convert_shape_node(var_shape_node,
True)) None, True))
else:
setattr(parent_node, field, var_shape_node)
break break
# Some child_node may be in a list such as gast.Compare # Some child_node may be in a list such as gast.Compare
if isinstance(value, list): if isinstance(value, list):
has_converted_shape = False has_converted_shape = False
for i, v in enumerate(value): for i, v in enumerate(value):
if child_node is v: if child_node is v:
if var_shape_node is child_node:
value[i] = create_convert_shape_node( value[i] = create_convert_shape_node(
var_shape_node, None, True) var_shape_node, None, True)
else:
value[i] = var_shape_node
has_converted_shape = True has_converted_shape = True
break break
if has_converted_shape: if has_converted_shape:
...@@ -217,19 +279,12 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -217,19 +279,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
""" """
Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise. Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
""" """
if not isinstance(node, (gast.Name, gast.Attribute, gast.Subscript)): if not isinstance(node, (gast.Attribute, gast.Subscript)):
return False return False
if isinstance(node, gast.Name) and node.id in self.name_to_var_shape:
return True
if isinstance(node, gast.Attribute): if isinstance(node, gast.Attribute):
if node.attr != 'shape': if node.attr != 'shape':
return False return False
if not isinstance(node.value, gast.Name):
return False
return True return True
if isinstance(node, gast.Subscript): if isinstance(node, gast.Subscript):
...@@ -243,49 +298,94 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -243,49 +298,94 @@ class TensorShapeTransformer(gast.NodeTransformer):
target_node = node.targets[0] target_node = node.targets[0]
value_node = node.value value_node = node.value
update_static_shape_var_node = None
if isinstance(target_node, gast.Tuple): if isinstance(target_node, gast.Tuple):
has_updated = False update_static_shape_var_node = []
for idx, element in enumerate(target_node.elts): for idx, element in enumerate(target_node.elts):
target_id = ast_to_source_code(element).strip() target_id = ast_to_source_code(element).strip()
if isinstance(value_node, gast.Name): if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape: if value_node.id in self.name_to_var_shape:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_name = self.name_to_var_shape[
value_node.id]
static_shape_value_node = gast.parse(
static_shape_value_name).body[0].value
index_value_node = gast.Constant(value=idx, kind=None) index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node) slice_index_node = gast.Index(value=index_value_node)
var_shape_node = self.name_to_var_shape[value_node.id]
sub_node = gast.Subscript( sub_node = gast.Subscript(
value=var_shape_node, value=static_shape_value_node,
slice=slice_index_node, slice=slice_index_node,
ctx=gast.Load()) ctx=gast.Load())
self.name_to_var_shape[target_id] = sub_node
has_updated = True update_static_shape_var_node.append(
gast.Assign(
targets=[static_shape_var_node],
value=sub_node))
self.name_to_var_shape[
target_id] = static_shape_var_name
if isinstance(value_node, gast.Attribute): if isinstance(value_node, gast.Attribute):
if self._is_var_shape(value_node): # eg: x.shape if self._is_var_shape(value_node): # eg: x.shape
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_node = copy.deepcopy(value_node)
# x.shape becomes convert_var_shape_simple(x)
ShapeAttributeTransformer().visit(
static_shape_value_node)
index_value_node = gast.Constant(value=idx, kind=None) index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node) slice_index_node = gast.Index(value=index_value_node)
sub_node = gast.Subscript( sub_node = gast.Subscript(
value=value_node, value=static_shape_value_node,
slice=slice_index_node, slice=slice_index_node,
ctx=gast.Load()) ctx=gast.Load())
self.name_to_var_shape[target_id] = sub_node
has_updated = True
return has_updated update_static_shape_var_node.append(
gast.Assign(
targets=[static_shape_var_node],
value=sub_node))
self.name_to_var_shape[
target_id] = static_shape_var_name
return update_static_shape_var_node
else: else:
target_id = ast_to_source_code(target_node).strip() target_id = ast_to_source_code(target_node).strip()
if isinstance(value_node, gast.Name): if isinstance(value_node, gast.Name):
if self._is_var_shape(value_node): if value_node.id in self.name_to_var_shape:
self.name_to_var_shape[target_id] = self.name_to_var_shape[ static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_name = self.name_to_var_shape[
value_node.id] value_node.id]
return True static_shape_value_node = gast.parse(
if isinstance(value_node, gast.Attribute): static_shape_value_name).body[0].value
if self._is_var_shape(value_node): # eg: x.shape
self.name_to_var_shape[target_id] = value_node update_static_shape_var_node = [
return True gast.Assign(
if isinstance(value_node, gast.Subscript): targets=[static_shape_var_node],
if isinstance(value_node.value, gast.Attribute): value=static_shape_value_node)
if self._is_var_shape(value_node.value): # eg: x.shape[0] ]
self.name_to_var_shape[target_id] = value_node self.name_to_var_shape[target_id] = static_shape_var_name
return True elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0]
return False static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(static_shape_var_name).body[
0].value
static_shape_value_node = copy.deepcopy(value_node)
# x.shape becomes convert_var_shape_simple(x)
ShapeAttributeTransformer().visit(static_shape_value_node)
update_static_shape_var_node = [
gast.Assign(
targets=[static_shape_var_node],
value=static_shape_value_node)
]
self.name_to_var_shape[target_id] = static_shape_var_name
return update_static_shape_var_node
...@@ -136,5 +136,58 @@ class TestConvertShapeCompare(unittest.TestCase): ...@@ -136,5 +136,58 @@ class TestConvertShapeCompare(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
class TestChooseShapeAttrOrApi(unittest.TestCase):
def test_api_shape_is_none(self):
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([1, 2], None),
[1, 2])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([1], None), [1])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([2, 3, 7], None, 0),
2)
def test_attr_shape_is_int(self):
x = paddle.zeros([1, 3, 5, 7])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(x.shape[0],
paddle.shape(x)[0]),
1)
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(x.shape[1],
paddle.shape(x)[1]),
3)
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(-1,
paddle.shape(x)[0]),
paddle.shape(x)[0])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(-1,
paddle.shape(x), 0),
paddle.shape(x)[0])
def test_positive_attr_shape(self):
x = paddle.zeros([1, 3, 5, 7])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(x.shape,
paddle.shape(x)),
x.shape)
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(x.shape,
paddle.shape(x), 3),
x.shape[3])
def test_negative_attr_shape(self):
x = paddle.zeros([7])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([-1],
paddle.shape(x), 0),
paddle.shape(x)[0])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([-1],
paddle.shape(x)),
paddle.shape(x))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x): ...@@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x):
return res return res
def dyfunc_tensor_shape_6(x):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1,
# paddle.jit.dy2static.convert_var_shape(x)[0:]))`
x = fluid.dygraph.to_variable(x)
s = x.shape[0:]
res = fluid.layers.reshape(x, shape=s)
return res
def dyfunc_tuple_shape_1(x): def dyfunc_tuple_shape_1(x):
x = paddle.to_tensor(x) x = paddle.to_tensor(x)
a, b = x.shape a, b = x.shape
...@@ -197,6 +207,14 @@ def dyfunc_with_while_4(x): ...@@ -197,6 +207,14 @@ def dyfunc_with_while_4(x):
return x return x
def dyfunc_change_shape_after_assign(x):
x = paddle.to_tensor(x)
a, b = x.shape
x = paddle.reshape(x, shape=(-1, 1))
res = paddle.reshape(x, shape=(b, a))
return res
# 1. Basic tests without control flow # 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase): class TestTensorShapeBasic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -279,6 +297,21 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): ...@@ -279,6 +297,21 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_5 self.dygraph_func = dyfunc_tensor_shape_5
def _set_expected_op_num(self):
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestTensorShapeBasic6(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_6
def _set_expected_op_num(self):
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestTupleShape1(TestTensorShapeBasic): class TestTupleShape1(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
...@@ -312,9 +345,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic): ...@@ -312,9 +345,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_if_1 self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self): def _set_expected_op_num(self):
self.expected_op_num = 26 self.expected_op_num = 4
self.expected_shape_op_num = 2 self.expected_shape_op_num = 1
self.expected_slice_op_num = 2 self.expected_slice_op_num = 1
class TestTensorShapeInIf2(TestTensorShapeBasic): class TestTensorShapeInIf2(TestTensorShapeBasic):
...@@ -342,6 +375,11 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1): ...@@ -342,6 +375,11 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_for_2 self.dygraph_func = dyfunc_with_for_2
def _set_expected_op_num(self):
self.expected_op_num = 9
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
# 4. Tests with control flow while loop # 4. Tests with control flow while loop
class TestTensorShapeInWhile1(TestTensorShapeInFor1): class TestTensorShapeInWhile1(TestTensorShapeInFor1):
...@@ -353,15 +391,20 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1): ...@@ -353,15 +391,20 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_while_2 self.dygraph_func = dyfunc_with_while_2
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestTensorShapeInWhile3(TestTensorShapeBasic): class TestTensorShapeInWhile3(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_while_3 self.dygraph_func = dyfunc_with_while_3
def _set_expected_op_num(self): def _set_expected_op_num(self):
self.expected_op_num = 25 self.expected_op_num = 2
self.expected_shape_op_num = 6 self.expected_shape_op_num = 0
self.expected_slice_op_num = 3 self.expected_slice_op_num = 0
class TestTensorShapeInWhile4(TestTensorShapeBasic): class TestTensorShapeInWhile4(TestTensorShapeBasic):
...@@ -431,9 +474,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape): ...@@ -431,9 +474,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_tuple_shape_1 self.dygraph_func = dyfunc_tuple_shape_1
def _set_expected_op_num(self): def _set_expected_op_num(self):
self.expected_op_num = 5 self.expected_op_num = 2
self.expected_shape_op_num = 1 self.expected_shape_op_num = 0
self.expected_slice_op_num = 1 self.expected_slice_op_num = 0
class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
...@@ -441,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): ...@@ -441,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_with_if_1 self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self): def _set_expected_op_num(self):
self.expected_op_num = 28 self.expected_op_num = 19
self.expected_shape_op_num = 4 self.expected_shape_op_num = 4
self.expected_slice_op_num = 2 self.expected_slice_op_num = 2
...@@ -466,5 +509,17 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape): ...@@ -466,5 +509,17 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
self.expected_slice_op_num = 3 self.expected_slice_op_num = 3
class TestChangeShapeAfterAssign(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((2, 3)).astype("int32")
self.input_spec = [paddle.static.InputSpec(shape=[2, 3], dtype="int32")]
self.dygraph_func = dyfunc_change_shape_after_assign
def _set_expected_op_num(self):
self.expected_op_num = 3
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -25,11 +25,15 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print ...@@ -25,11 +25,15 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape_simple #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import eval_if_exist_else_none #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import choose_shape_attr_or_api #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS
__all__ = [ __all__ = [
'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len', 'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len',
'convert_logical_and', 'convert_logical_not', 'convert_logical_or', 'convert_logical_and', 'convert_logical_not', 'convert_logical_or',
'convert_pop', 'convert_print', 'convert_shape_compare', 'convert_pop', 'convert_print', 'convert_shape_compare',
'convert_var_dtype', 'convert_var_shape', 'convert_while_loop' 'convert_var_dtype', 'convert_var_shape', 'convert_var_shape_simple',
'eval_if_exist_else_none', 'choose_shape_attr_or_api', 'convert_while_loop'
] ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册