未验证 提交 d82d5b8c 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Refactor convert_shape transformer logic (#43846)

* [Dy2Stat]Refactor convert_shape transformer logic

* clean usless unittest
上级 a5dc0a79
......@@ -338,88 +338,30 @@ def convert_zip(*args):
return zip(*args)
def convert_var_shape(x, idx=None, in_control_flow=False):
def convert_shape(x):
"""
A function representation of the shape of variable.
"""
def has_negative(list_shape, idx=None):
if idx is not None:
return list_shape[idx] < 0
num_negative = sum([1 if i < 0 else 0 for i in list_shape])
return num_negative > 0
# When `x` is Variable, call nn.shape(x) in following cases:
# (1) The shape of `x` is used in control flow condition.
# ```
# if x.shape[0] == 1:
# y = XX
# ```
# (2) The dim to be used is negative
# ```
# # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# ```
if isinstance(x, Variable) and has_negative(x.shape, idx):
return nn.shape(x) if idx is None else nn.shape(x)[idx]
else:
return list(x.shape) if idx is None else x.shape[idx]
def has_negative(list_shape):
return any([x < 0 for x in list_shape])
# When `x` is Variable:
# (1) if x.shape contains -1, such as [2, -1, 64], returns [2, var, 64],
# where var = paddle.shape(x)[1]
# (2) if x.shape does not contains -1, return lsit(x.shape) directly
def convert_var_shape_simple(x):
"""
A function representation of the shape of variable.
"""
if isinstance(x, Variable):
return nn.shape(x)
values = list(x.shape)
if has_negative(values):
shape_tensor = nn.shape(x)
for i, v in enumerate(values):
if v is None or v < 0:
values[i] = shape_tensor[i]
return values
else:
# Use list() to make returned type consistant with dygraph
return list(x.shape)
def eval_if_exist_else_none(name, global_symbol_table):
"""
Args:
name([str]): Expression passed into `eval`.
local_symbol_table(dict): Specified from `globals()`. DO NOT use `locals()`,
because all STATIC_CONVERT_VAR_SHAPE_SUFFIX vars is
declared with keyword `global`.
Returns:
Return the variable if found in global_symbol_table else None.
"""
try:
return eval(name, global_symbol_table)
except:
return None
def choose_shape_attr_or_api(attr_shape, api_shape, idx=None):
"""
Input can be attribute `x.shape` or api `shape(x)`, this function
chooses which one to return to use in dy2stat.
Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer.
"""
if api_shape is None:
return attr_shape if idx is None else attr_shape[idx]
if not isinstance(attr_shape, (list, tuple)):
# some variables like x.shape[0] is no longer a list or tuple
if isinstance(attr_shape, int) and attr_shape < 0:
return api_shape if idx is None else api_shape[idx]
return attr_shape if idx is None else attr_shape[idx]
def has_negative(list_shape, idx=None):
if idx is not None:
return list_shape[idx] < 0
num_negative = sum([1 if i < 0 else 0 for i in list_shape])
return num_negative > 0
if has_negative(attr_shape, idx):
return api_shape if idx is None else api_shape[idx]
return attr_shape if idx is None else attr_shape[idx]
return x.shape
def convert_shape_compare(left, *args):
......
......@@ -63,28 +63,6 @@ class LogicalTransformer(gast.NodeTransformer):
return new_node
return node
def visit_Compare(self, node):
self.generic_visit(node)
left_str = ast_to_source_code(node.left).strip()
if left_str.startswith("_jst.convert_var_shape"):
# check left and comparators are all converted var shape
compare_arg_strs = left_str
for i, comparator in enumerate(node.comparators):
comparator_str = ast_to_source_code(comparator).strip()
if not comparator_str.startswith("_jst.convert_var_shape"):
return node
op_str = cmpop_node_to_str(node.ops[i])
compare_arg_strs += (", '" + op_str + "', " + comparator_str)
# Now all left and comparators are converted shape
# Replace some comparsion operation because of difference between
# Python and Paddle
new_node_str = "_jst.convert_shape_compare({})".format(
compare_arg_strs)
new_node = gast.parse(new_node_str).body[0].value
return new_node
return node
def visit_BoolOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.And):
......
......@@ -25,77 +25,11 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
STATIC_CONVERT_VAR_SHAPE_SUFFIX = '__static_convert_var_shape_suffix'
def create_convert_shape_node(var_shape_node,
slice_node=None,
in_control_flow=False):
assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))
if isinstance(var_shape_node, gast.Attribute):
args = [ast_to_source_code(var_shape_node.value).strip()]
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index or gast.Constant
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index or gast.Constant
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip())
convert_var_shape_func = "_jst.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript(value=api_shape_node,
slice=slice_node.slice,
ctx=gast.Load())
return api_shape_node
if isinstance(var_shape_node, gast.Subscript):
result_node = copy.deepcopy(var_shape_node)
result_node = create_convert_shape_node(result_node.value, result_node,
in_control_flow)
return result_node
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
eval_exist_func = "_jst.eval_if_exist_else_none('{}', globals())".format(
api_shape_name)
args = [attr_shape_name, eval_exist_func]
if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip())
choose_shape_func = "_jst.choose_shape_attr_or_api({})".format(
",".join(args))
choose_shape_node = gast.parse(choose_shape_func).body[0].value
if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript(value=choose_shape_node,
slice=slice_node.slice,
ctx=gast.Load())
return choose_shape_node
class ShapeAttributeTransformer(gast.NodeTransformer):
"""
Input a node like `x.shape` or `x[4].shape[0]` (self._is_var_shape(node) is True),
return a new node changes input to static shape API like `convert_var_shape(x)`,
`convert_var_shape(x[4])[0]`.
"""
def visit_Attribute(self, node):
if node.attr == 'shape':
args = ast_to_source_code(node.value).strip()
convert_var_shape_func = "_jst.convert_var_shape_simple({})".format(
args)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
return api_shape_node
return node
class TensorShapeTransformer(gast.NodeTransformer):
"""
This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
This class transforms variable.shape into Static Graph Ast.
All 'xxx.shape' will be converted int '_jst.convert_shape(x)'.
"""
def __init__(self, wrapper_root):
......@@ -104,295 +38,17 @@ class TensorShapeTransformer(gast.NodeTransformer):
), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
# stores origin var string name (like "x" in `x = t.shape`) to
# static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
self.name_to_var_shape = {}
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
var_env = self.static_analysis_visitor.get_var_env()
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
self.scope_var_type_dict = var_env.get_scope_var_type()
def transform(self):
SplitAssignTransformer(self.root).transform()
self.visit(self.root)
def visit_Assign(self, node):
update_static_shape_var_node = self._update_name_to_var_shape(node)
if update_static_shape_var_node is not None:
ret = [node]
ret.extend(update_static_shape_var_node)
return ret
self.generic_visit(node)
return node
def visit_Subscript(self, node):
value_node = node.value
slice_node = node.slice
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
value_node):
return create_choose_shape_node(
value_node.id, self.name_to_var_shape[value_node.id], node)
elif isinstance(value_node, gast.Attribute):
if self._used_by_paddle_api(value_node):
value_name = ast_to_source_code(value_node).strip()
if value_name in self.name_to_var_shape:
return create_choose_shape_node(
value_name, self.name_to_var_shape[value_name], node)
if self._is_var_shape(value_node):
return create_convert_shape_node(value_node, node)
return node
def visit_Attribute(self, node):
if self._used_by_paddle_api(node):
name = ast_to_source_code(node).strip()
if name in self.name_to_var_shape:
return create_choose_shape_node(name,
self.name_to_var_shape[name])
if self._is_var_shape(node):
return create_convert_shape_node(node)
return node
def visit_Name(self, node):
if node.id in self.name_to_var_shape:
if self._used_by_paddle_api(node):
return create_choose_shape_node(node.id,
self.name_to_var_shape[node.id])
return node
def visit_Call(self, node):
if is_paddle_api(node):
# Visit gast.Attribute and gast.Name to replace var.shape if necessary.
self.generic_visit(node)
# Don't have to visit other APIs
return node
def visit_If(self, node):
# Call generic_visit first to transform var.shape that is used in Paddle Api.
self.generic_visit(node)
cond = node.test
self._transform_var_shape_if_necessary(cond)
return node
def visit_While(self, node):
self.generic_visit(node)
cond = node.test
self._transform_var_shape_if_necessary(cond)
return node
def visit_For(self, node):
self.generic_visit(node)
iter = node.iter
self._transform_var_shape_if_necessary(iter)
# If var.shape is a gast.Name and it is used in range function, transform it
self._transform_var_shape_in_range(node)
if node.attr == 'shape':
args = ast_to_source_code(node.value).strip()
# NOTE(dev): we can deal with paddle.shape in this case, but it's
# not pretty to modify into 'convert_shape(paddle)(x)[0]'.
if args != 'paddle':
convert_shape_func = "_jst.convert_shape({})".format(args)
shape_node = gast.parse(convert_shape_func).body[0].value
return shape_node
return node
def _transform_var_shape_in_range(self, node):
assert isinstance(node, gast.For)
if not isinstance(node.iter, gast.Call):
return False
if not isinstance(node.iter.func, gast.Name):
return False
if node.iter.func.id != "range":
return False
args = node.iter.args
for idx, arg in enumerate(args):
if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
args[idx] = create_choose_shape_node(
arg.id, self.name_to_var_shape[arg.id])
return True
def _transform_var_shape_if_necessary(self, cond):
need_transformed = False
for child_node in gast.walk(cond):
var_shape_node = None
if isinstance(child_node,
(gast.Name, gast.Attribute, gast.Subscript)):
child_name = ast_to_source_code(child_node).strip()
if child_name in self.name_to_var_shape:
var_shape_node = create_choose_shape_node(
child_name, self.name_to_var_shape[child_name])
elif self._is_var_shape(child_node):
var_shape_node = child_node
if var_shape_node:
need_transformed = True
wrapper_node = self.node_to_wrapper_map.get(child_node)
parent_node = wrapper_node.parent.node
for field, value in gast.iter_fields(parent_node):
if child_node is value:
if var_shape_node is child_node:
setattr(
parent_node, field,
create_convert_shape_node(
var_shape_node, None, True))
else:
setattr(parent_node, field, var_shape_node)
break
# Some child_node may be in a list such as gast.Compare
if isinstance(value, list):
has_converted_shape = False
for i, v in enumerate(value):
if child_node is v:
if var_shape_node is child_node:
value[i] = create_convert_shape_node(
var_shape_node, None, True)
else:
value[i] = var_shape_node
has_converted_shape = True
break
if has_converted_shape:
break
return need_transformed
def _used_by_paddle_api(self, node):
"""
Whether node is used in paddle api as arguments.
For example:
1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute),
because the role of node is not arguments but `gast.Call.func`.
"""
assert isinstance(node, (gast.Attribute, gast.Name))
wrapper_node = self.node_to_wrapper_map.get(node)
if not wrapper_node:
# Transformed node is not in node_to_wrapper_map
return False
while wrapper_node.parent:
parent_node = wrapper_node.parent.node
if isinstance(parent_node, gast.Call):
# Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`.
if is_paddle_api(parent_node) and parent_node.func != node:
return True
else:
return False
wrapper_node = wrapper_node.parent
return False
def _is_var_shape(self, node):
"""
Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
"""
if not isinstance(node, (gast.Attribute, gast.Subscript)):
return False
if isinstance(node, gast.Attribute):
# If node is `paddle.shape`, return False
if (node.attr == 'shape' and isinstance(node.value, gast.Name)
and node.value.id == 'paddle'):
return False
if node.attr != 'shape':
return False
return True
if isinstance(node, gast.Subscript):
value_node = node.value
return self._is_var_shape(value_node)
return False
def _update_name_to_var_shape(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
value_node = node.value
update_static_shape_var_node = None
if isinstance(target_node, gast.Tuple):
update_static_shape_var_node = []
for idx, element in enumerate(target_node.elts):
target_id = ast_to_source_code(element).strip()
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name = unique_name.generate(
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_name = self.name_to_var_shape[
value_node.id]
sub_node_str = "{}[{}]".format(static_shape_value_name,
idx)
sub_node = gast.parse(sub_node_str).body[0].value
update_static_shape_var_node.append(
gast.Assign(targets=[static_shape_var_node],
value=sub_node))
self.name_to_var_shape[
target_id] = static_shape_var_name
if isinstance(value_node, gast.Attribute):
if self._is_var_shape(value_node): # eg: x.shape
static_shape_var_name = unique_name.generate(
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_node = copy.deepcopy(value_node)
# x.shape becomes convert_var_shape_simple(x)
static_shape_value_node = ShapeAttributeTransformer(
).visit(static_shape_value_node)
sub_node_str = "{}[{}]".format(
ast_to_source_code(static_shape_value_node).strip(),
idx)
sub_node = gast.parse(sub_node_str).body[0].value
# Note(Aurelius84): Becuase static_shape_var_name is used in
# eval_if_exist_else_none() as plain string, so it will not
# be pasred as argument in convert_loop/ifelse. We delcare it
# as global var because it has unique name.
update_static_shape_var_node.append(
gast.Global(names=[static_shape_var_name]))
update_static_shape_var_node.append(
gast.Assign(targets=[static_shape_var_node],
value=sub_node))
self.name_to_var_shape[
target_id] = static_shape_var_name
return update_static_shape_var_node
else:
target_id = ast_to_source_code(target_node).strip()
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape:
static_shape_var_name = unique_name.generate(
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_name = self.name_to_var_shape[
value_node.id]
static_shape_value_node = gast.parse(
static_shape_value_name).body[0].value
update_static_shape_var_node = [
gast.Assign(targets=[static_shape_var_node],
value=static_shape_value_node)
]
self.name_to_var_shape[target_id] = static_shape_var_name
elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0]
static_shape_var_name = unique_name.generate(
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_node = copy.deepcopy(value_node)
# x.shape becomes convert_var_shape_simple(x)
static_shape_value_node = ShapeAttributeTransformer().visit(
static_shape_value_node)
# Declare static_shape_var_name as global var
update_static_shape_var_node = [
gast.Global(names=[static_shape_var_name])
]
update_static_shape_var_node.append(
gast.Assign(targets=[static_shape_var_node],
value=static_shape_value_node))
self.name_to_var_shape[target_id] = static_shape_var_name
return update_static_shape_var_node
......@@ -15,7 +15,6 @@
import numpy as np
import paddle
import unittest
from paddle.jit.dy2static.convert_operators import eval_if_exist_else_none
class CallNotExist(paddle.nn.Layer):
......@@ -143,108 +142,6 @@ class TestConvertShapeCompare(unittest.TestCase):
paddle.disable_static()
class TestChooseShapeAttrOrApi(unittest.TestCase):
def test_api_shape_is_none(self):
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([1, 2], None), [1, 2])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([1], None), [1])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([2, 3, 7], None, 0),
2)
def test_attr_shape_is_int(self):
x = paddle.zeros([1, 3, 5, 7])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(
x.shape[0],
paddle.shape(x)[0]), 1)
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(
x.shape[1],
paddle.shape(x)[1]), 3)
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(
-1,
paddle.shape(x)[0]),
paddle.shape(x)[0])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(
-1, paddle.shape(x), 0),
paddle.shape(x)[0])
def test_positive_attr_shape(self):
x = paddle.zeros([1, 3, 5, 7])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(
x.shape, paddle.shape(x)), x.shape)
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api(
x.shape, paddle.shape(x), 3), x.shape[3])
def test_negative_attr_shape(self):
x = paddle.zeros([7])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([-1], paddle.shape(x),
0),
paddle.shape(x)[0])
self.assertEqual(
paddle.jit.dy2static.choose_shape_attr_or_api([-1],
paddle.shape(x)),
paddle.shape(x))
class TestEvaIfExistElseNone(unittest.TestCase):
def test_globals(self):
global x_shape
x_shape = [1, 2, 3]
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
self.assertEqual(eval_if_exist_else_none('x_shape', globals()), x_shape)
del x_shape
def test_enclosing_scope(self):
global x_shape
x_shape = [1, 2, 3]
def foo():
y_shape = [2, 3, 4]
self.assertEqual(eval_if_exist_else_none('x_shape', globals()),
[1, 2, 3])
self.assertEqual(eval_if_exist_else_none('y_shape', locals()),
[2, 3, 4])
foo()
del x_shape
def test_global_in_func(self):
x_shape = [1, 2, 3]
def foo():
global y_shape
y_shape = [2, 3, 4]
self.assertEqual(eval_if_exist_else_none('y_shape', globals()),
[2, 3, 4])
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
self.assertEqual(eval_if_exist_else_none('x_shape', globals()),
None)
del y_shape
foo()
def test_none(self):
def foo():
x_shape = [2, 3, 4]
return x_shape
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
class ShapeLayer(paddle.nn.Layer):
def __init__(self):
......
......@@ -275,6 +275,7 @@ class TestTensorShapeBasic(unittest.TestCase):
self.expected_slice_op_num = 0
def _compute_op_num(self, program):
print(program)
self.op_num = sum([len(block.ops) for block in program.blocks])
self.shape_op_num = 0
self.slice_op_num = 0
......@@ -300,8 +301,8 @@ class TestTensorShapeBasic2(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tensor_shape_2
def _set_expected_op_num(self):
self.expected_op_num = 3
self.expected_shape_op_num = 1
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
......@@ -323,9 +324,9 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tensor_shape_5
def _set_expected_op_num(self):
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTensorShapeBasic6(TestTensorShapeBasic):
......@@ -334,21 +335,23 @@ class TestTensorShapeBasic6(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tensor_shape_6
def _set_expected_op_num(self):
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTupleShape1(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
self.input_spec = [
paddle.static.InputSpec(shape=[-1, -1], dtype="int32")
]
self.dygraph_func = dyfunc_tuple_shape_1
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_shape_op_num = 2
self.expected_op_num = 5
self.expected_shape_op_num = 1
self.expected_slice_op_num = 2
......@@ -356,13 +359,15 @@ class TestTupleShape2(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
self.input_spec = [
paddle.static.InputSpec(shape=[-1, 7], dtype="int32")
]
self.dygraph_func = dyfunc_tuple_shape_2
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_shape_op_num = 1
self.expected_slice_op_num = 2
self.expected_slice_op_num = 1
class TestTupleShape3(TestTensorShapeBasic):
......@@ -398,9 +403,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self):
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTensorShapeInIf2(TestTensorShapeBasic):
......@@ -432,9 +437,9 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1):
self.dygraph_func = dyfunc_with_for_2
def _set_expected_op_num(self):
self.expected_op_num = 9
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
self.expected_op_num = 7
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTensorShapeInFor3(TestTensorShapeInFor1):
......@@ -466,9 +471,9 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1):
self.dygraph_func = dyfunc_with_while_2
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
self.expected_op_num = 4
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTensorShapeInWhile3(TestTensorShapeBasic):
......@@ -477,8 +482,8 @@ class TestTensorShapeInWhile3(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_while_3
def _set_expected_op_num(self):
self.expected_op_num = 3
self.expected_shape_op_num = 1
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
......@@ -510,9 +515,9 @@ class TestOpNumBasicWithTensorShape(unittest.TestCase):
self.dygraph_func = dyfunc_tensor_shape_1
def _set_expected_op_num(self):
self.expected_op_num = 3
self.expected_op_num = 5
self.expected_shape_op_num = 1
self.expected_slice_op_num = 0
self.expected_slice_op_num = 1
def _compute_op_num(self, program):
self.op_num = sum([len(block.ops) for block in program.blocks])
......@@ -541,9 +546,9 @@ class TestOpNumBasicWithTensorShape4(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_tensor_shape_4
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
self.expected_op_num = 8
self.expected_shape_op_num = 2
self.expected_slice_op_num = 2
class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
......@@ -552,9 +557,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_tuple_shape_1
def _set_expected_op_num(self):
self.expected_op_num = 7
self.expected_shape_op_num = 2
self.expected_slice_op_num = 2
self.expected_op_num = 5
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
......@@ -563,9 +568,9 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self):
self.expected_op_num = 28
self.expected_op_num = 32
self.expected_shape_op_num = 4
self.expected_slice_op_num = 2
self.expected_slice_op_num = 4
class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
......@@ -594,13 +599,15 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((2, 3)).astype("int32")
self.input_spec = [paddle.static.InputSpec(shape=[2, 3], dtype="int32")]
self.input_spec = [
paddle.static.InputSpec(shape=[-1, 3], dtype="int32")
]
self.dygraph_func = dyfunc_change_shape_after_assign
def _set_expected_op_num(self):
self.expected_op_num = 7
self.expected_shape_op_num = 2
self.expected_slice_op_num = 2
self.expected_op_num = 6
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
def dyfunc_with_static_convert_var_shape(x):
......@@ -627,16 +634,5 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
func.concrete_program
class TestPaddleShape(unittest.TestCase):
def test_paddle_shape(self):
func = paddle.jit.to_static(dyfunc_len_paddle_shape)
func_code = func.code.replace("\n", "").replace(" ", "")
self.assertEqual('paddle.shape(x)' in func_code, True)
func = paddle.jit.to_static(dyfunc_dict_assign_shape)
func_code = func.code.replace("\n", "").replace(" ", "")
self.assertEqual("__static_convert_var_shape_suffix" in func_code, True)
if __name__ == '__main__':
unittest.main()
......@@ -26,10 +26,7 @@ from .convert_operators import convert_pop # noqa: F401
from .convert_operators import convert_print # noqa: F401
from .convert_operators import convert_shape_compare # noqa: F401
from .convert_operators import convert_var_dtype # noqa: F401
from .convert_operators import convert_var_shape # noqa: F401
from .convert_operators import convert_var_shape_simple # noqa: F401
from .convert_operators import eval_if_exist_else_none # noqa: F401
from .convert_operators import choose_shape_attr_or_api # noqa: F401
from .convert_operators import convert_shape # noqa: F401
from .convert_operators import convert_while_loop # noqa: F401
from .variable_trans_func import create_bool_as_type # noqa: F401
from .variable_trans_func import create_fill_constant_node # noqa: F401
......
......@@ -24,10 +24,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop #
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape_simple # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import eval_if_exist_else_none # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import choose_shape_attr_or_api # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop # noqa: F401
__all__ = []
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册