Unverified commit f16e2778, authored by liym27 and committed by GitHub

[Dy2Static] Convert var.shape statements, and convert the return variables of Tensor-dependent 'if' statements to Tensor if they are not (#24911)

* Support int and long: int or long -> six.integer_types.

* Modify test_tensor_shape: fix a bug and update comments.

* Support convert_var_shape to convert var.shape statements.

* Modify code in ifelse_simple_func.py because returning a non-Tensor from a Tensor-dependent 'if' statement is not supported currently.

* Convert the return variables of Tensor-dependent 'if' statements to Tensor if they are not. test=develop
Parent 25a4dac4
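For orientation, a minimal sketch of the rewrite this commit performs; it mirrors the expected code in test_program_translator.py further down, and `x_v` and the function names are illustrative:

# User-written dygraph code:
#     if fluid.layers.mean(x_v)[0] > 5:
#         x_v = x_v - 1
#     else:
#         x_v = x_v + 1
#
# After transformation, each branch becomes a function and the `if` becomes one
# call whose last three arguments are true_args, false_args and return_vars:
def true_fn_0(x_v):
    x_v = x_v - 1
    return x_v

def false_fn_0(x_v):
    x_v = x_v + 1
    return x_v

x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
    fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), (x_v, ),
    (x_v, ))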
@@ -43,6 +43,7 @@ def convert_while_loop(cond, body, loop_vars):

def _run_paddle_while_loop(cond, body, loop_vars):
    # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Variables.
    loop_vars = [to_static_variable(var) for var in loop_vars]
    loop_vars = control_flow.while_loop(cond, body, loop_vars)
    return loop_vars
@@ -146,7 +147,7 @@ def _run_py_logical_not(x):
    return not x


def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars):
    """
    A function representation of a Python ``if/else`` statement.

    Args:
        pred(bool|Variable): A boolean variable which determines whether to return the result of ``true_fn`` or ``false_fn``.
        true_fn(callable): A callable to be performed if ``pred`` is true.
        false_fn(callable): A callable to be performed if ``pred`` is false.
        true_args(tuple): Parameters of ``true_fn``.
        false_args(tuple): Parameters of ``false_fn``.
        return_vars(tuple): Return variables of ``true_fn`` and ``false_fn``.

    Returns:
        ``true_fn(*true_args)`` if the predicate ``pred`` is true else ``false_fn(*false_args)``.
    """
    if isinstance(pred, Variable):
        return _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
                                return_vars)
    else:
        return _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)


def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
                     return_vars):
    return_var_ids = [id(var) for var in return_vars]

    # NOTE 1: Return vars of Paddle op `control_flow.cond` must be Paddle Variables.
    # NOTE 2: Here we use id(var) rather than var itself, because `if var in return_vars`
    # uses the operator `==`, which calls `fluid.layers.equal` and raises an error when
    # a var in return_vars is not yet initialized.
    true_args = [
        to_static_variable(var) if id(var) in return_var_ids else var
        for var in true_args
    ]
    false_args = [
        to_static_variable(var) if id(var) in return_var_ids else var
        for var in false_args
    ]

    pred = cast_bool_if_necessary(pred)
    return control_flow.cond(pred, lambda: true_fn(*true_args),
                             lambda: false_fn(*false_args))


def _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args):
    return true_fn(*true_args) if pred else false_fn(*false_args)
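A minimal usage sketch, assuming the module path shown in this diff and a plain Python predicate so that `_run_py_ifelse` is taken (the values are illustrative):

from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_ifelse

def true_fn(a):
    return a + 1

def false_fn(a):
    return a - 1

# pred is a plain bool, so no static `cond` op is built and return_vars is unused:
out = convert_ifelse(True, true_fn, false_fn, (10, ), (10, ), ())
assert out == 11  # true_fn(*(10, ))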
def convert_len(var):

@@ -202,6 +223,16 @@ def convert_len(var):
    return len(var)


def convert_var_shape(x):
    """
    A function representation of the shape of a variable.
    """
    if isinstance(x, Variable):
        return nn.shape(x)
    else:
        return x.shape
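A behavioral note as a sketch: for a Paddle Variable the shape is computed at run time via `nn.shape`; anything else falls back to the ordinary Python attribute:

import numpy

y = numpy.ones(5)
# y is not a Variable, so the plain attribute comes back unchanged:
assert convert_var_shape(y) == (5, )
# For a static-graph Variable x, convert_var_shape(x) returns nn.shape(x),
# a Tensor holding the runtime shape.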
def cast_bool_if_necessary(var):
    assert isinstance(var, Variable)
    if convert_dtype(var.dtype) not in ['bool']:
...
@@ -24,23 +24,14 @@ from collections import defaultdict

import gast

from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node

TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
class IfElseTransformer(gast.NodeTransformer):

@@ -66,8 +57,9 @@ class IfElseTransformer(gast.NodeTransformer):
        self.generic_visit(node)
        new_vars_stmts, true_func_node, false_func_node, return_name_ids = transform_if_else(
            node, self.root)

        new_node = create_convert_ifelse_node(return_name_ids, node.test,
                                              true_func_node, false_func_node)

        return new_vars_stmts + [true_func_node, false_func_node] + [new_node]
@@ -86,8 +78,8 @@ class IfElseTransformer(gast.NodeTransformer):
        """
        self.generic_visit(node)

        new_node = create_convert_ifelse_node(None, node.test, node.body,
                                              node.orelse, True)
        # Note: A blank line would be added separately if we transformed gast.Expr
        # into source code. Use gast.Expr.value instead to avoid a syntax error
        # in Python.
@@ -108,6 +100,7 @@ class NameVisitor(gast.NodeVisitor):
        # Available only when end_node is set.
        self._is_finished = False
        self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
        self._def_func_names = set()

    def visit(self, node):
        """Visit a node."""

@@ -173,6 +166,8 @@ class NameVisitor(gast.NodeVisitor):
    def visit_Name(self, node):
        blacklist = {'True', 'False', 'None'}
        if node.id in blacklist: return
        if node.id in self._def_func_names:
            return
        if not self._is_call_func_name_node(node):
            if isinstance(node.ctx, self._candidate_ctxs):
                self.name_ids[node.id].append(node.ctx)
@@ -183,6 +178,7 @@ class NameVisitor(gast.NodeVisitor):
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self._def_func_names.add(node.name)
        if not self.end_node:
            self.generic_visit(node)
        else:
@@ -274,6 +270,7 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
        kw_defaults=None,
        kwarg=None,
        defaults=[])

    return arguments
@@ -453,56 +450,59 @@ def transform_if_else(node, root):
        name=unique_name.generate(FALSE_FUNC_PREFIX),
        input_args=parse_cond_args(orelse_name_ids, modified_name_ids),
        return_name_ids=return_name_ids)

    return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids


def create_convert_ifelse_node(return_name_ids,
                               pred,
                               true_func,
                               false_func,
                               is_if_expr=False):
    """
    Create `fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
    pred, true_fn, false_fn, true_args, false_args, return_vars)`
    to replace the original Python `if/else` statement.
    """

    def create_name_nodes(name_ids):
        if not name_ids:
            return gast.Tuple(elts=[], ctx=gast.Load())

        gast_names = [
            gast.Name(
                id=name_id, ctx=gast.Load(), annotation=None, type_comment=None)
            for name_id in name_ids
        ]
        name_node = gast.Tuple(elts=gast_names, ctx=gast.Load())
        return name_node

    if is_if_expr:
        true_args = gast.Tuple(elts=[], ctx=gast.Load())
        false_args = gast.Tuple(elts=[], ctx=gast.Load())
        true_func_source = "lambda : {}".format(ast_to_source_code(true_func))
        false_func_source = "lambda : {}".format(ast_to_source_code(false_func))
    else:
        true_args = gast.Tuple(elts=true_func.args.args, ctx=gast.Load())
        false_args = gast.Tuple(elts=false_func.args.args, ctx=gast.Load())
        true_func_source = true_func.name
        false_func_source = false_func.name

    return_vars = create_name_nodes(return_name_ids)

    convert_ifelse_layer = gast.parse(
        'fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse('
        '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'.
        format(
            pred=ast_to_source_code(pred),
            true_fn=true_func_source,
            false_fn=false_func_source,
            true_args=ast_to_source_code(true_args),
            false_args=ast_to_source_code(false_args),
            return_vars=ast_to_source_code(return_vars))).body[0].value

    if return_name_ids:
        _, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer)
    else:  # No variables can be returned if there is no assign statement in if.body.
        cond_node = gast.Expr(value=convert_ifelse_layer)

    return cond_node
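For reference, for a statement like `if label is not None: ...` with no return variables, the node built here prints as the following call (cf. StaticCode1 in test_program_translator.py below):

fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
    label is not None, true_fn_1, false_fn_1, (label, x_v), (), ())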
@@ -14,16 +14,36 @@

from __future__ import print_function

import copy
import gast

from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor


def create_convert_shape_node(var_shape_node):
    assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))

    convert_var_shape_func = "fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape"

    if isinstance(var_shape_node, gast.Attribute):
        api_shape_node = gast.Call(
            func=gast.parse(convert_var_shape_func).body[0].value,
            args=[var_shape_node.value],
            keywords=[])
        return api_shape_node

    if isinstance(var_shape_node, gast.Subscript):
        result_node = copy.deepcopy(var_shape_node)
        result_node.value = create_convert_shape_node(result_node.value)
        return result_node
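A quick round-trip sketch of what this helper emits, assuming `ast_to_source_code` from paddle.fluid.dygraph.dygraph_to_static.utils:

import gast

node = gast.parse("x.shape[0]").body[0].value  # a gast.Subscript
new_node = create_convert_shape_node(node)
# ast_to_source_code(new_node) renders roughly as:
#   fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]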
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms variable.shape used in Paddle APIs or control flow conditions into static graph AST.
    """

    def __init__(self, wrapper_root):

@@ -32,7 +52,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(

@@ -42,58 +62,60 @@ class TensorShapeTransformer(gast.NodeTransformer):
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        if self._update_name_to_var_shape(node):
            return node
        self.generic_visit(node)
        return node
    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            if self.is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                var_shape_node = self.name_to_var_shape[node.id]
                return create_convert_shape_node(var_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)
        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in a Paddle API.
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter = node.iter
        self._transform_var_shape_if_necessary(iter)
        # If var.shape is a gast.Name and it is used in a range function, transform it.
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False

@@ -103,31 +125,33 @@ class TensorShapeTransformer(gast.NodeTransformer):
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_convert_shape_node(self.name_to_var_shape[
                    arg.id])

        return True

    def _transform_var_shape_if_necessary(self, cond):
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node, (gast.Attribute)):
                if self.is_var_shape(child_node):
                    var_shape_node = child_node
            elif isinstance(child_node, (gast.Name)):
                if child_node.id in self.name_to_var_shape:
                    var_shape_node = self.name_to_var_shape[child_node.id]

            if var_shape_node:
                need_transformed = True
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_convert_shape_node(var_shape_node))
                        break
        return need_transformed
    def _used_by_paddle_api(self, node):
        assert isinstance(node, (gast.Attribute, gast.Name))

@@ -146,11 +170,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
        return False

    def is_var_shape(self, node):
        """
        Return True if node is like `x.shape`, return False otherwise.
        """
        assert isinstance(node, gast.Attribute)

        if node.attr != 'shape':
            return False

@@ -159,26 +184,13 @@ class TensorShapeTransformer(gast.NodeTransformer):
        except AttributeError:
            return False

        if value_id in self.name_to_var_shape:
            return True

        return True
    def _update_name_to_var_shape(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id

@@ -187,17 +199,17 @@ class TensorShapeTransformer(gast.NodeTransformer):
        value_node = node.value
        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape:
                self.name_to_var_shape[target_id] = self.name_to_var_shape[
                    value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_var_shape(value_node):  # eg: x.shape
                self.name_to_var_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_var_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_var_shape[target_id] = value_node
                    return True
        return False
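In effect, an assignment like `s = x.shape[0]` records `s` in `name_to_var_shape`, so a later use of the bare name can still be rewritten. A sketch of the bookkeeping:

# s = x.shape[0]          -> name_to_var_shape['s'] = <AST of x.shape[0]>
# for i in range(s): ...  -> for i in range(convert_var_shape(x)[0]): ...
# while i < s: ...        -> while i < convert_var_shape(x)[0]: ...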
@@ -117,11 +117,7 @@ def to_static_variable(x):
    if isinstance(x, float):
        return fill_constant(shape=[1], dtype='float64', value=x)
    if isinstance(x, six.integer_types):
        return fill_constant(shape=[1], dtype='int64', value=x)

    return x
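A behavioral sketch of the simplified branch: `six.integer_types` is `(int, long)` on Python 2 and `(int,)` on Python 3, so both integer widths now take the same path:

import six

assert isinstance(3, six.integer_types)
# to_static_variable(3)   -> fill_constant(shape=[1], dtype='int64', value=3)
# to_static_variable(3.0) -> fill_constant(shape=[1], dtype='float64', value=3.0)
# to_static_variable('s') -> 's' (returned unchanged)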
@@ -341,7 +341,7 @@ def _convert_to_tensor_list(old_list, dtype="int32"):
            ele.stop_gradient = True
            new_list_tensor.append(ele)
        else:
            assert isinstance(ele, six.integer_types)
            temp_out = fill_constant([1], dtype, ele, force_cpu=True)
            new_list_tensor.append(temp_out)
    return new_list_tensor
@@ -42,7 +42,10 @@ def dyfunc_with_if_else(x_v, label=None):

def dyfunc_with_if_else2(x, col=100):
    row = 0
    if abs(col) > x.shape[-1]:
        # TODO: Returning a non-Tensor from a Tensor-dependent `if` statement is not supported currently.
        # `x` is a Tensor and `col` is not, yet `col` becomes a return value of `true_fn` after transformation.
        # col = -1
        col = fluid.layers.fill_constant(shape=[1], value=-1, dtype="int64")
    if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[row][col]:
        y = fluid.layers.relu(x)
    else:
@@ -101,7 +104,12 @@ def nested_if_else(x_v):
    feat_size = x_v.shape[-1]
    bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1)
    if x_v.shape[0] != batch_size:
        # TODO: Returning a non-Tensor from a Tensor-dependent `if` statement is not supported currently.
        # `x_v.shape[0]` is not a Tensor, yet `batch_size` becomes a return value of `true_fn` after transformation.
        # batch_size = x_v.shape[0]
        batch_size = fluid.layers.shape(x_v)[0]

    # If tensor.shape is [1], comparison with a numpy value is supported.
    if fluid.layers.mean(x_v).numpy() < 0:
        y = x_v + bias
...
@@ -72,10 +72,8 @@ class StaticCode1():
            return x_v

        x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
            fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
            (x_v, ), (x_v, ))

        def true_fn_1(label, x_v):
            loss = fluid.layers.cross_entropy(x_v, label)

@@ -86,9 +84,7 @@ class StaticCode1():
            return

        fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
            label is not None, true_fn_1, false_fn_1, (label, x_v), (), ())
        return x_v

@@ -104,10 +100,8 @@ class StaticCode2():
            return x_v

        x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
            fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ),
            (x_v, ), (x_v, ))

        def true_fn_3(label, x_v):
            loss = fluid.layers.cross_entropy(x_v, label)

@@ -118,9 +112,7 @@ class StaticCode2():
            return

        fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
            label is not None, true_fn_3, false_fn_3, (label, x_v), (), ())
        return x_v

@@ -138,7 +130,6 @@ class TestDygraphToStaticCode(unittest.TestCase):
        self.maxDiff = None

    def test_decorator(self):
        program_translator = ProgramTranslator()
        code = program_translator.get_code(dyfunc_with_if_else)
        answer = get_source_code(StaticCode1.dyfunc_with_if_else)
...
@@ -36,7 +36,7 @@ def dyfunc_tensor_shape_2(x):

def dyfunc_tensor_shape_3(x):
    # y.shape is transformed, but plain y.shape still runs at execution time because y is not a Tensor.
    x = fluid.dygraph.to_variable(x)
    y = numpy.ones(5)
    res = fluid.layers.reshape(x, shape=y.shape)
def dyfunc_tensor_shape_5(x): def dyfunc_tensor_shape_5(x):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to # `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1, fluid.layers.shape(x)[0]))` # `res = fluid.layers.reshape(x, shape=(-1,
# fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]))`
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
s = x.shape[0] s = x.shape[0]
res = fluid.layers.reshape(x, shape=(-1, s)) res = fluid.layers.reshape(x, shape=(-1, s))
@@ -63,7 +64,8 @@ def dyfunc_with_if_1(x):
    res = fluid.layers.reshape(x, [-1, 1])
    x_shape_0 = x.shape[0]
    if x_shape_0 < 1:
        # `res.shape[0]` is transformed into
        # `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(res)[0]`
        if res.shape[0] > 1:
            res = fluid.layers.fill_constant(
                value=2, shape=x.shape, dtype="int32")
def dyfunc_with_if_2(x): def dyfunc_with_if_2(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
# `len(x.shape)` will not be transformed. # `len(x.shape)` will not be transformed because x.shape is not used by Paddle api.
if len(x.shape) < 1: if len(x.shape) < 1:
res = x res = x
else: else:
...@@ -87,7 +89,7 @@ def dyfunc_with_if_2(x): ...@@ -87,7 +89,7 @@ def dyfunc_with_if_2(x):
def dyfunc_with_for_1(x): def dyfunc_with_for_1(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]` # `x.shape[0]` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
for i in range(x.shape[0]): for i in range(x.shape[0]):
res += 1 res += 1
return res return res
@@ -98,7 +100,7 @@ def dyfunc_with_for_2(x):
    x_shape_0 = x.shape[0]
    res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
    # `x_shape_0` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
    for i in range(x_shape_0):
        res += 1
    return res
@@ -122,7 +124,7 @@ def dyfunc_with_for_3(x):

def dyfunc_with_while_1(x):
    x = fluid.dygraph.to_variable(x)
    res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
    # `x.shape[0]` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
    i = 1
    while i < x.shape[0]:
        res += 1
@@ -135,19 +137,14 @@ def dyfunc_with_while_2(x):
    x_shape_0 = x.shape[0]
    res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
    i = 1
    # `x_shape_0` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
    while i < x_shape_0:
        res += 1
        i = i + 2
    return res


def dyfunc_with_while_3(x):
    x = fluid.dygraph.to_variable(x)
    x_shape = x.shape
    res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")

@@ -160,6 +157,19 @@ def dyfunc_with_while_3(x):
    return res


def dyfunc_with_while_4(x):
    x = fluid.dygraph.to_variable(x)
    y = numpy.ones(5)
    y_shape_0 = y.shape[0]
    i = 1

    # y_shape_0 is transformed, but plain y.shape[0] still runs at execution time because y is not a Tensor.
    while y_shape_0 > i:
        x += 1
        i += 1
    return x
# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
    def setUp(self):

@@ -183,7 +193,7 @@ class TestTensorShapeBasic(unittest.TestCase):
        return self._run(to_static=False)

    def get_static_output(self):
        return self._run(to_static=True)

    def test_transformed_static_result(self):
        static_res = self.get_static_output()

@@ -247,5 +257,15 @@ class TestTensorShapeInWhile2(TestTensorShapeBasic):
        self.dygraph_func = dyfunc_with_while_2


class TestTensorShapeInWhile3(TestTensorShapeBasic):
    def init_test_func(self):
        self.dygraph_func = dyfunc_with_while_3


class TestTensorShapeInWhile4(TestTensorShapeBasic):
    def init_test_func(self):
        self.dygraph_func = dyfunc_with_while_4


if __name__ == '__main__':
    unittest.main()