未验证 提交 f16e2778 编写于 作者: L liym27 提交者: GitHub

[Dy2Static]Convert var.shape stmt and Convert the return variables of...

[Dy2Static]Convert var.shape stmt and Convert the return variables of Tensor-dependent 'if' staments to Tensor if it not  (#24911)

* Support int and long: int or long -> six.integer_types. 

* Modify test_tensor_shape: fix bug and modify comment. 

* Support convert_var_shape to convert var.shape stmt

* Modify code in ifelse_simple_func.py because don't support return non-Tensor in Tensor-dependent 'if' stament currently. 

* Convert the return variables of Tensor-dependent 'if' staments to Tensor if it not. test=develop
上级 25a4dac4
......@@ -43,6 +43,7 @@ def convert_while_loop(cond, body, loop_vars):
def _run_paddle_while_loop(cond, body, loop_vars):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Variable.
loop_vars = [to_static_variable(var) for var in loop_vars]
loop_vars = control_flow.while_loop(cond, body, loop_vars)
return loop_vars
......@@ -146,7 +147,7 @@ def _run_py_logical_not(x):
return not x
def convert_ifelse(pred, true_fn, false_fn):
def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars):
"""
A function representation of a Python ``if/else`` statement.
......@@ -154,25 +155,45 @@ def convert_ifelse(pred, true_fn, false_fn):
pred(bool|Variable): A boolean variable which determines whether to return the result of ``true_fn`` or ``false_fn`` .
true_fn(callable): A callable to be performed if ``pred`` is true.
false_fn(callable): A callable to be performed if ``pred`` is false.
true_args(tuple): Parameters of ``true_fn``.
false_args(tuple): Parameters of ``false_fn``.
return_vars(tuple): Return variables of ``true_fn`` and ``false_fn``.
Returns:
``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .
``true_fn(true_args)`` if the predicate ``pred`` is true else ``false_fn(false_args)`` .
"""
if isinstance(pred, Variable):
return _run_paddle_cond(pred, true_fn, false_fn)
return _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
return_vars)
else:
return _run_py_ifelse(pred, true_fn, false_fn)
return _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
def _run_paddle_cond(pred, true_fn, false_fn):
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
return_vars):
return_var_ids = [id(var) for var in return_vars]
# NOTE 1: return vars of Paddle op `control_flow.cond` must be Paddle Variable
# NOTE 2: Here uses id(var) not var, because `if var in return_var` use operator `==`,
# which will call `fluid.layers.equal` and causes error when var in return_vars is not initialized.
true_args = [
to_static_variable(var) if id(var) in return_var_ids else var
for var in true_args
]
false_args = [
to_static_variable(var) if id(var) in return_var_ids else var
for var in false_args
]
pred = cast_bool_if_necessary(pred)
return control_flow.cond(pred, true_fn, false_fn)
return control_flow.cond(pred, lambda: true_fn(*true_args),
lambda: false_fn(*false_args))
def _run_py_ifelse(pred, true_fn, false_fn):
def _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args):
return true_fn() if pred else false_fn()
return true_fn(*true_args) if pred else false_fn(*false_args)
def convert_len(var):
......@@ -202,6 +223,16 @@ def convert_len(var):
return len(var)
def convert_var_shape(x):
"""
A function representation of the shape of variable.
"""
if isinstance(x, Variable):
return nn.shape(x)
else:
return x.shape
def cast_bool_if_necessary(var):
assert isinstance(var, Variable)
if convert_dtype(var.dtype) not in ['bool']:
......
......@@ -24,23 +24,14 @@ from collections import defaultdict
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import compare_with_none
from paddle.fluid.dygraph.dygraph_to_static.utils import is_candidate_node
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
LOGIC_AND_PREFIX = 'logic_and'
LOGIC_OR_PREFIX = 'logic_or'
LOGIC_NOT_PREFIX = 'logic_not'
PLAIN_TENSOR_PREFIX = 'bool_tensor'
class IfElseTransformer(gast.NodeTransformer):
......@@ -66,8 +57,9 @@ class IfElseTransformer(gast.NodeTransformer):
self.generic_visit(node)
new_vars_stmts, true_func_node, false_func_node, return_name_ids = transform_if_else(
node, self.root)
new_node = create_cond_node(return_name_ids, node.test, true_func_node,
false_func_node)
new_node = create_convert_ifelse_node(return_name_ids, node.test,
true_func_node, false_func_node)
return new_vars_stmts + [true_func_node, false_func_node] + [new_node]
......@@ -86,8 +78,8 @@ class IfElseTransformer(gast.NodeTransformer):
"""
self.generic_visit(node)
new_node = create_cond_node(None, node.test, node.body, node.orelse,
True)
new_node = create_convert_ifelse_node(None, node.test, node.body,
node.orelse, True)
# Note: A blank line will be added separately if transform gast.Expr
# into source code. Using gast.Expr.value instead to avoid syntax error
# in python.
......@@ -108,6 +100,7 @@ class NameVisitor(gast.NodeVisitor):
# Available only when end_node is set.
self._is_finished = False
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
self._def_func_names = set()
def visit(self, node):
"""Visit a node."""
......@@ -173,6 +166,8 @@ class NameVisitor(gast.NodeVisitor):
def visit_Name(self, node):
blacklist = {'True', 'False', 'None'}
if node.id in blacklist: return
if node.id in self._def_func_names:
return
if not self._is_call_func_name_node(node):
if isinstance(node.ctx, self._candidate_ctxs):
self.name_ids[node.id].append(node.ctx)
......@@ -183,6 +178,7 @@ class NameVisitor(gast.NodeVisitor):
self.generic_visit(node)
def visit_FunctionDef(self, node):
self._def_func_names.add(node.name)
if not self.end_node:
self.generic_visit(node)
else:
......@@ -274,6 +270,7 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
kw_defaults=None,
kwarg=None,
defaults=[])
return arguments
......@@ -453,56 +450,59 @@ def transform_if_else(node, root):
name=unique_name.generate(FALSE_FUNC_PREFIX),
input_args=parse_cond_args(orelse_name_ids, modified_name_ids),
return_name_ids=return_name_ids)
return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids
def create_cond_node(return_name_ids,
pred,
true_func,
false_func,
is_if_expr=False):
def create_convert_ifelse_node(return_name_ids,
pred,
true_func,
false_func,
is_if_expr=False):
"""
Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace
original `python if/else` statement.
Create `fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
pred, true_fn, false_fn, true_args, false_args, return_vars)`
to replace original `python if/else` statement.
"""
def create_lambda_node(func_or_expr_node, is_if_expr=False):
body = func_or_expr_node
if not is_if_expr:
body = gast.Call(
func=gast.Name(
id=func_or_expr_node.name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
args=[func_or_expr_node.args],
keywords=[])
lambda_node = gast.Lambda(
args=gast.arguments(
args=[],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=body)
return lambda_node
cond_api = gast.parse(
'fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse'
).body[0].value
true_func_lambda = create_lambda_node(true_func, is_if_expr)
false_func_lambda = create_lambda_node(false_func, is_if_expr)
cond_layer = gast.Call(
func=cond_api,
args=[pred, true_func_lambda, false_func_lambda],
keywords=[])
def create_name_nodes(name_ids):
if not name_ids:
return gast.Tuple(elts=[], ctx=gast.Load())
gast_names = [
gast.Name(
id=name_id, ctx=gast.Load(), annotation=None, type_comment=None)
for name_id in name_ids
]
name_node = gast.Tuple(elts=gast_names, ctx=gast.Load())
return name_node
if is_if_expr:
true_args = gast.Tuple(elts=[], ctx=gast.Load())
false_args = gast.Tuple(elts=[], ctx=gast.Load())
true_func_source = "lambda : {}".format(ast_to_source_code(true_func))
false_func_source = "lambda : {}".format(ast_to_source_code(false_func))
else:
true_args = gast.Tuple(elts=true_func.args.args, ctx=gast.Load())
false_args = gast.Tuple(elts=false_func.args.args, ctx=gast.Load())
true_func_source = true_func.name
false_func_source = false_func.name
return_vars = create_name_nodes(return_name_ids)
convert_ifelse_layer = gast.parse(
'fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse('
'{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'.
format(
pred=ast_to_source_code(pred),
true_fn=true_func_source,
false_fn=false_func_source,
true_args=ast_to_source_code(true_args),
false_args=ast_to_source_code(false_args),
return_vars=ast_to_source_code(return_vars))).body[0].value
if return_name_ids:
_, cond_node = create_assign_node(return_name_ids, cond_layer)
_, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer)
else: # No variables can be returned if no assign statement in if.body.
cond_node = gast.Expr(value=cond_layer)
cond_node = gast.Expr(value=convert_ifelse_layer)
return cond_node
......@@ -14,16 +14,36 @@
from __future__ import print_function
import copy
import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
def create_convert_shape_node(var_shape_node):
assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))
convert_var_shape_func = "fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape"
if isinstance(var_shape_node, gast.Attribute):
api_shape_node = gast.Call(
func=gast.parse(convert_var_shape_func).body[0].value,
args=[var_shape_node.value],
keywords=[])
return api_shape_node
if isinstance(var_shape_node, gast.Subscript):
result_node = copy.deepcopy(var_shape_node)
result_node.value = create_convert_shape_node(result_node.value)
return result_node
class TensorShapeTransformer(gast.NodeTransformer):
"""
This class transforms Tensor.shape used in Paddle Apis and control flow conditions into Static Graph Ast.
This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
"""
def __init__(self, wrapper_root):
......@@ -32,7 +52,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.name_to_tensor_shape = {}
self.name_to_var_shape = {}
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
......@@ -42,58 +62,60 @@ class TensorShapeTransformer(gast.NodeTransformer):
self.scope_var_type_dict = var_env.get_scope_var_type()
def transform(self):
SplitAssignTransformer(self.root).transform()
self.visit(self.root)
def visit_Assign(self, node):
if self._update_name_to_tensor_shape(node):
if self._update_name_to_var_shape(node):
return node
self.generic_visit(node)
return node
def visit_Attribute(self, node):
if self._used_by_paddle_api(node):
if self.is_tensor_shape(node):
return create_api_shape_node(node)
if self.is_var_shape(node):
return create_convert_shape_node(node)
return node
def visit_Name(self, node):
if node.id in self.name_to_tensor_shape:
if node.id in self.name_to_var_shape:
if self._used_by_paddle_api(node):
tensor_shape_node = self.name_to_tensor_shape[node.id]
return create_api_shape_node(tensor_shape_node)
var_shape_node = self.name_to_var_shape[node.id]
return create_convert_shape_node(var_shape_node)
return node
def visit_Call(self, node):
assert isinstance(node, gast.Call)
if is_paddle_api(node):
# Visit gast.Attribute and gast.Name to replace tensor.shape if necessary.
# Visit gast.Attribute and gast.Name to replace var.shape if necessary.
self.generic_visit(node)
return node
def visit_If(self, node):
# Call generic_visit first to transform Tensor.shape that is used in Paddle Api.
# Call generic_visit first to transform var.shape that is used in Paddle Api.
self.generic_visit(node)
cond = node.test
self._transform_tensor_shape_if_necessary(cond)
self._transform_var_shape_if_necessary(cond)
return node
def visit_While(self, node):
self.generic_visit(node)
cond = node.test
self._transform_tensor_shape_if_necessary(cond)
self._transform_var_shape_if_necessary(cond)
return node
def visit_For(self, node):
self.generic_visit(node)
iter = node.iter
self._transform_tensor_shape_if_necessary(iter)
self._transform_var_shape_if_necessary(iter)
# If tensor.shape is a gast.Name and it is used in range function, transform it
self._transform_tensor_shape_in_range(node)
# If var.shape is a gast.Name and it is used in range function, transform it
self._transform_var_shape_in_range(node)
return node
def _transform_tensor_shape_in_range(self, node):
def _transform_var_shape_in_range(self, node):
assert isinstance(node, gast.For)
if not isinstance(node.iter, gast.Call):
return False
......@@ -103,31 +125,33 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False
args = node.iter.args
for idx, arg in enumerate(args):
if isinstance(arg,
gast.Name) and arg.id in self.name_to_tensor_shape:
args[idx] = create_api_shape_node(self.name_to_tensor_shape[
if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
args[idx] = create_convert_shape_node(self.name_to_var_shape[
arg.id])
return True
def _transform_tensor_shape_if_necessary(self, cond):
def _transform_var_shape_if_necessary(self, cond):
need_transformed = False
for child_node in gast.walk(cond):
tensor_shape_node = None
var_shape_node = None
if isinstance(child_node, (gast.Attribute)):
if self.is_tensor_shape(child_node):
tensor_shape_node = child_node
if self.is_var_shape(child_node):
var_shape_node = child_node
elif isinstance(child_node, (gast.Name)):
if child_node.id in self.name_to_tensor_shape:
tensor_shape_node = self.name_to_tensor_shape[child_node.id]
if child_node.id in self.name_to_var_shape:
var_shape_node = self.name_to_var_shape[child_node.id]
if tensor_shape_node:
if var_shape_node:
need_transformed = True
wrapper_node = self.node_to_wrapper_map.get(child_node)
parent_node = wrapper_node.parent.node
for field, value in gast.iter_fields(parent_node):
if child_node is value:
setattr(parent_node, field,
create_api_shape_node(tensor_shape_node))
create_convert_shape_node(var_shape_node))
break
return need_transformed
def _used_by_paddle_api(self, node):
assert isinstance(node, (gast.Attribute, gast.Name))
......@@ -146,11 +170,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False
def is_tensor_shape(self, node):
def is_var_shape(self, node):
"""
Return True if node is like `x.shape` and x is Tensor, return False otherwise.
Return True if node is like `x.shape`, return False otherwise.
"""
assert isinstance(node, gast.Attribute)
if node.attr != 'shape':
return False
......@@ -159,26 +184,13 @@ class TensorShapeTransformer(gast.NodeTransformer):
except AttributeError:
return False
if value_id in self.name_to_tensor_shape:
if value_id in self.name_to_var_shape:
return True
# TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
# Need a better way to confirm whether `value_id` is a Tensor.
try:
var_type_set = self.scope_var_type_dict[value_id]
except KeyError:
return False
if NodeVarType.NUMPY_NDARRAY in var_type_set:
return False
if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
return False
return True
def _update_name_to_tensor_shape(self, node):
def _update_name_to_var_shape(self, node):
assert isinstance(node, gast.Assign)
# TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
target_node = node.targets[0]
try:
target_id = target_node.id
......@@ -187,17 +199,17 @@ class TensorShapeTransformer(gast.NodeTransformer):
value_node = node.value
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_tensor_shape:
self.name_to_tensor_shape[
target_id] = self.name_to_tensor_shape[value_node.id]
if value_node.id in self.name_to_var_shape:
self.name_to_var_shape[target_id] = self.name_to_var_shape[
value_node.id]
return True
if isinstance(value_node, gast.Attribute):
if self.is_tensor_shape(value_node): # eg: x.shape
self.name_to_tensor_shape[target_id] = value_node
if self.is_var_shape(value_node): # eg: x.shape
self.name_to_var_shape[target_id] = value_node
return True
if isinstance(value_node, gast.Subscript):
if isinstance(value_node.value, gast.Attribute):
if self.is_tensor_shape(value_node.value): # eg: x.shape[0]
self.name_to_tensor_shape[target_id] = value_node
if self.is_var_shape(value_node.value): # eg: x.shape[0]
self.name_to_var_shape[target_id] = value_node
return True
return False
......@@ -117,11 +117,7 @@ def to_static_variable(x):
if isinstance(x, float):
return fill_constant(shape=[1], dtype='float64', value=x)
if six.PY2:
if isinstance(x, (int, long)):
return fill_constant(shape=[1], dtype='int64', value=x)
else:
if isinstance(x, int):
return fill_constant(shape=[1], dtype='int64', value=x)
if isinstance(x, six.integer_types):
return fill_constant(shape=[1], dtype='int64', value=x)
return x
......@@ -112,9 +112,9 @@ def _yield_flat_nest(nest):
def flatten(nest):
"""
:alias_main: paddle.flatten
:alias: paddle.flatten,paddle.tensor.flatten,paddle.tensor.manipulation.flatten
:old_api: paddle.fluid.layers.flatten
:alias_main: paddle.flatten
:alias: paddle.flatten,paddle.tensor.flatten,paddle.tensor.manipulation.flatten
:old_api: paddle.fluid.layers.flatten
Traverse all entries in the nested structure and put them into an list.
"""
......@@ -341,7 +341,7 @@ def _convert_to_tensor_list(old_list, dtype="int32"):
ele.stop_gradient = True
new_list_tensor.append(ele)
else:
assert (isinstance(ele, int))
assert isinstance(ele, six.integer_types)
temp_out = fill_constant([1], dtype, ele, force_cpu=True)
new_list_tensor.append(temp_out)
return new_list_tensor
......@@ -42,7 +42,10 @@ def dyfunc_with_if_else(x_v, label=None):
def dyfunc_with_if_else2(x, col=100):
row = 0
if abs(col) > x.shape[-1]:
col = -1
# TODO: Don't support return non-Tensor in Tensor-dependent `if` stament currently.
# `x` is Tensor, `col` is not Tensor, and `col` is the return value of `true_fn` after transformed.
# col = -1
col = fluid.layers.fill_constant(shape=[1], value=-1, dtype="int64")
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[row][col]:
y = fluid.layers.relu(x)
else:
......@@ -101,7 +104,12 @@ def nested_if_else(x_v):
feat_size = x_v.shape[-1]
bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1)
if x_v.shape[0] != batch_size:
batch_size = x_v.shape[0]
# TODO: Don't support return non-Tensor in Tensor-dependent `if` stament currently.
# `x_v.shape[0]` is not Tensor, and `batch_size` is the return value of `true_fn` after transformed.
# col = -1
# batch_size = x_v.shape[0]
batch_size = fluid.layers.shape(x_v)[0]
# if tensor.shape is [1], now support to compare with numpy.
if fluid.layers.mean(x_v).numpy() < 0:
y = x_v + bias
......
......@@ -72,10 +72,8 @@ class StaticCode1():
return x_v
x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5,
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_0)(x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_0)(x_v)
)
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ), (x_v, ))
def true_fn_1(label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
......@@ -86,9 +84,7 @@ class StaticCode1():
return
fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
label is not None,
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_1)(label, x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_1)())
label is not None, true_fn_1, false_fn_1, (label, x_v), (), ())
return x_v
......@@ -104,10 +100,8 @@ class StaticCode2():
return x_v
x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5,
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_2)(x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_2)(x_v)
)
fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ),
(x_v, ), (x_v, ))
def true_fn_3(label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
......@@ -118,9 +112,7 @@ class StaticCode2():
return
fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
label is not None,
lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_3)(label, x_v),
lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_3)())
label is not None, true_fn_3, false_fn_3, (label, x_v), (), ())
return x_v
......@@ -138,7 +130,6 @@ class TestDygraphToStaticCode(unittest.TestCase):
self.maxDiff = None
def test_decorator(self):
x_v = None
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
answer = get_source_code(StaticCode1.dyfunc_with_if_else)
......
......@@ -36,7 +36,7 @@ def dyfunc_tensor_shape_2(x):
def dyfunc_tensor_shape_3(x):
# Don't transform y.shape because y is numpy.ndarray
# Transform y.shape but run y.shape actually because y is not Tensor
x = fluid.dygraph.to_variable(x)
y = numpy.ones(5)
res = fluid.layers.reshape(x, shape=y.shape)
......@@ -51,7 +51,8 @@ def dyfunc_tensor_shape_4(x):
def dyfunc_tensor_shape_5(x):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1, fluid.layers.shape(x)[0]))`
# `res = fluid.layers.reshape(x, shape=(-1,
# fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]))`
x = fluid.dygraph.to_variable(x)
s = x.shape[0]
res = fluid.layers.reshape(x, shape=(-1, s))
......@@ -63,7 +64,8 @@ def dyfunc_with_if_1(x):
res = fluid.layers.reshape(x, [-1, 1])
x_shape_0 = x.shape[0]
if x_shape_0 < 1:
# `res.shape[0] > 1` is transformed into `if fluid.layers.shape(res)[0] > 1`
# `res.shape[0]` is transformed into
# `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(res)[0]`
if res.shape[0] > 1:
res = fluid.layers.fill_constant(
value=2, shape=x.shape, dtype="int32")
......@@ -75,7 +77,7 @@ def dyfunc_with_if_1(x):
def dyfunc_with_if_2(x):
x = fluid.dygraph.to_variable(x)
# `len(x.shape)` will not be transformed.
# `len(x.shape)` will not be transformed because x.shape is not used by Paddle api.
if len(x.shape) < 1:
res = x
else:
......@@ -87,7 +89,7 @@ def dyfunc_with_if_2(x):
def dyfunc_with_for_1(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]`
# `x.shape[0]` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
for i in range(x.shape[0]):
res += 1
return res
......@@ -98,7 +100,7 @@ def dyfunc_with_for_2(x):
x_shape_0 = x.shape[0]
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x_shape_0` is transformed into `fluid.layers.shape(x)[0]`
# `x_shape_0` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
for i in range(x_shape_0):
res += 1
return res
......@@ -122,7 +124,7 @@ def dyfunc_with_for_3(x):
def dyfunc_with_while_1(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]`
# `x.shape[0]` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
i = 1
while i < x.shape[0]:
res += 1
......@@ -135,19 +137,14 @@ def dyfunc_with_while_2(x):
x_shape_0 = x.shape[0]
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
i = 1
# `x_shape_0` is transformed into `fluid.layers.shape(x)[0]`
# TODO(liym27): If `x_shape_0` is at right like `while i < x_shape_0`, it will not be transformed.
# Fix this bug next PR.
while x_shape_0 > i:
# `x_shape_0` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
while i < x_shape_0:
res += 1
i = i + 2
return res
def dyfunc_with_while_3(x):
# TODO(liym27):
# It will fail to run because the same problem as `dyfunc_with_for_3`.
# After the AST tranformation of for loop is improved, add TestTensorShapeInWhile3.
x = fluid.dygraph.to_variable(x)
x_shape = x.shape
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
......@@ -160,6 +157,19 @@ def dyfunc_with_while_3(x):
return res
def dyfunc_with_while_4(x):
x = fluid.dygraph.to_variable(x)
y = numpy.ones(5)
y_shape_0 = y.shape[0]
i = 1
# Transform y_shape_0 but run y.shape[0] actually because y is not Tensor
while y_shape_0 > i:
x += 1
i += 1
return x
# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
......@@ -183,7 +193,7 @@ class TestTensorShapeBasic(unittest.TestCase):
return self._run(to_static=False)
def get_static_output(self):
return self._run(to_static=False)
return self._run(to_static=True)
def test_transformed_static_result(self):
static_res = self.get_static_output()
......@@ -247,5 +257,15 @@ class TestTensorShapeInWhile2(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_while_2
class TestTensorShapeInWhile3(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_while_3
class TestTensorShapeInWhile4(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_with_while_4
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册