Unverified commit 2961a4f0, authored by liym27 and committed by GitHub

[Dy2Stat] Optimize loop cond (#24049)

* Simplify code for gast.If in is_control_flow_to_transform.
* Move IsControlFlowVisitor to file utils.
* Don't use convert_call for built-in functions in CallTransformer.
* Optimize the API is_control_flow_to_transform.
* Polish the documentation of IsControlFlowVisitor.
Parent commit: aa0330f4
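The diff below moves `IsControlFlowVisitor` into `utils.py` and routes `is_control_flow_to_transform` through it for `gast.If`, `gast.For` and `gast.While` nodes alike. As a rough usage sketch (not part of this commit), the snippet parses a toy dygraph-style loop and asks whether it must be rewritten as static-graph control flow; the helper function `fn` and the `{'x': {NodeVarType.TENSOR}}` hint passed through `var_name_to_type` are assumptions made only for this illustration.

```python
# Illustrative sketch only, not from this commit. Assumes a Paddle source tree
# that already contains this change, plus the `gast` package it depends on.
import gast

from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
    NodeVarType, StaticAnalysisVisitor)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
    is_control_flow_to_transform)

code = (
    "def fn(x):\n"
    "    while x < 10:\n"
    "        x = x + 1\n"
    "    return x\n"
)
root = gast.parse(code)
while_node = root.body[0].body[0]  # the gast.While node inside fn

# Share one StaticAnalysisVisitor over the whole AST, as the transformers do,
# and hint that `x` is a Tensor so the loop condition counts as control flow.
visitor = StaticAnalysisVisitor(root)
need_transform = is_control_flow_to_transform(
    while_node, visitor, var_name_to_type={'x': {NodeVarType.TENSOR}})
print(need_transform)  # expected to print True under the type hint above
```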
@@ -32,6 +32,15 @@ class CallTransformer(gast.NodeTransformer):
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

+    def _is_builtin_call(self, node):
+        assert isinstance(node, gast.Call)
+        func_str = ast_to_source_code(node.func).strip()
+        try:
+            from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
+            return eval("is_builtin({})".format(func_str))
+        except Exception:
+            return False
+
    def transform(self):
        self.visit(self.root)
@@ -39,6 +48,10 @@ class CallTransformer(gast.NodeTransformer):
        self.generic_visit(node)
        if is_paddle_api(node):
            return node
+
+        if self._is_builtin_call(node):
+            return node
+
        func_str = ast_to_source_code(node.func).strip()
        new_func_str = "fluid.dygraph.dygraph_to_static.convert_call({})".format(
            func_str)
...
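For context (not part of the diff), the new `_is_builtin_call` short-circuit means a call such as `len(x)` is left untouched instead of being wrapped in `fluid.dygraph.dygraph_to_static.convert_call`. A rough sketch of the check it performs, with a hand-built `gast.Call` node as an assumed input:

```python
# Sketch of the check performed by the new CallTransformer._is_builtin_call;
# illustrative only, and assumes the Paddle tree containing this commit.
import gast
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code

call_node = gast.parse("len(x)").body[0].value          # gast.Call for `len(x)`
func_str = ast_to_source_code(call_node.func).strip()   # -> "len"
print(eval("is_builtin({})".format(func_str)))          # expected: True, so the call is not rewritten
```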
@@ -25,12 +25,15 @@ import gast
import six

from paddle.fluid import unique_name
+from paddle.fluid.dygraph.dygraph_to_static.utils import compare_with_none
+from paddle.fluid.dygraph.dygraph_to_static.utils import is_candidate_node
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
+from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
-from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
+from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper

TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
@@ -142,145 +145,6 @@ class IfElseTransformer(gast.NodeTransformer):
        return self.new_func_nodes


-def is_candidate_node(node):
-    """
-    Nodes with specified type will be dependent on tensor.
-    """
-    is_compare_node = isinstance(node,
-                                 (gast.Compare, gast.BoolOp, gast.UnaryOp))
-    # TODO(Aurelius84): `.numpy()` may be an customized function,
-    # and should consider a more elegant way to solve this problem.
-    has_numpy_attr = ".numpy()" in ast_to_source_code(node)
-    return is_compare_node or has_numpy_attr
-
-
-def compare_with_none(node):
-    """
-    Whether the comparator of `gast.Compare` node is `None`.
-    """
-    if isinstance(node, gast.Compare):
-        for child in [node.left, node.comparators]:
-            # node.comparators is a list.
-            if isinstance(child, list):
-                child = child[0]
-            if (isinstance(child, gast.Constant) and child.value is None) or (
-                    isinstance(child, gast.Name) and child.id == 'None'):
-                return True
-    return False
-
-
-class IsControlFlowVisitor(gast.NodeVisitor):
-    """
-    Judge whether the node.test from Dygraph code dependent on paddle Tensor.
-    If does, it should satisfy:
-        1. must involve at least one var whose type is Tensor.
-        2. the Tensor var should call `.numpy()[]` interface or Tensor.shape is [1].
-        3. involve Tensor.shape[i] and the shape[i] is unknown in compile time.
-    The following examples should not be considered as control_flow_if:
-        1. `if Tensor_var` or `if Tensor_var is None`
-        2. if Tensor.shape[i] is determined with fixed value (not -1 or None)
-
-    Note: pred in ConditionalBlock require variable, which means all vars should be Tensor
-          or transformed into Tensor, like fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]).
-
-    TODO: 1. need to deal with `tensor.shape[i]` which need to eval the data of shape[i],
-             because reshape_op may be called before this statement.
-    """
-
-    def __init__(self,
-                 ast_node,
-                 static_analysis_visitor=None,
-                 node_var_type_map=None):
-        assert isinstance(
-            ast_node, gast.AST
-        ), "Type of input node should be gast.AST, but received %s." % type(
-            ast_node)
-        self.ast_root = ast_node
-        if static_analysis_visitor is None:
-            static_analysis_visitor = StaticAnalysisVisitor(ast_node)
-        self.static_analysis_visitor = static_analysis_visitor
-        self.node_var_type_map = node_var_type_map
-
-        self.is_control_flow_num = 0
-        self._compare_node_tenor_set = set()
-
-    def transform(self):
-        node = self.ast_root
-        if is_candidate_node(node):
-            self.visit(node)
-        return self.is_control_flow_num > 0
-
-    def visit_BoolOp(self, node):
-        for i, child in enumerate(node.values):
-            if is_candidate_node(child):
-                self.visit(child)
-        return node
-
-    def visit_Compare(self, node):
-        # Ignores child node with `if x` or `if x is None`
-        # TODO(Aurelius84): `if tensor` will be supported in dygraph
-        # and should be considered as is_control_flow.
-        pre_control_flow_num = self.is_control_flow_num
-        if not compare_with_none(node):
-            self.generic_visit(node)
-            for child in gast.walk(node):
-                if isinstance(child, gast.Subscript):
-                    self._visit_Subscript(child)
-        if self.is_control_flow_num > pre_control_flow_num:
-            self._compare_node_tenor_set.add(node)
-        return node
-
-    def _visit_Subscript(self, node):
-        self.generic_visit(node)
-        if hasattr(node, 'value') and isinstance(node.value, gast.Call):
-            self._visit_Call(node.value)
-        return node
-
-    def _visit_Call(self, node):
-        assert isinstance(node, gast.Call)
-        if isinstance(node.func, gast.Attribute):
-            attr_node = node.func
-            if attr_node.attr == 'numpy':
-                self.is_control_flow_num += 1
-
-    def visit_Call(self, node):
-        self._visit_Call(node)
-        if is_paddle_api(node):
-            self.is_control_flow_num += 1
-        return node
-
-    def visit_Name(self, node):
-        if self._is_node_with_tensor(node, node.id):
-            self.is_control_flow_num += 1
-        return node
-
-    def visit_Constant(self, node):
-        if self._is_node_with_tensor(node, node.value):
-            self.is_control_flow_num += 1
-        return node
-
-    def _is_node_with_tensor(self, node, name_id):
-        tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
-        # Look up the node_var_type_map by name_id.
-        if self.node_var_type_map:
-            if name_id and isinstance(name_id, six.string_types):
-                var_type = self.node_var_type_map.get(name_id, None)
-                if var_type and var_type & tensor_types:
-                    return True
-        # if not found, look up the node_to_wrapper_map by node.
-        node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
-        )
-        wrapper_node = node_to_wrapper_map.get(node, None)
-        if wrapper_node is not None:
-            if wrapper_node.node_var_type & tensor_types:
-                return True
-
-        return False
-
-    def get_compare_nodes_with_tensor(self):
-        return self._compare_node_tenor_set
-
-
class NodeTestTransformer(gast.NodeTransformer):
    def __init__(self, ast_node, compare_nodes_with_tensor=None):
        if compare_nodes_with_tensor is None:
...
@@ -55,19 +55,22 @@ class ListTransformer(gast.NodeTransformer):

    def visit_If(self, node):
        self.generic_visit(node)
-        if is_control_flow_to_transform(node, self.scope_var_type_dict):
+        if is_control_flow_to_transform(node, self.static_analysis_visitor,
+                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
-        if is_control_flow_to_transform(node, self.scope_var_type_dict):
+        if is_control_flow_to_transform(node, self.static_analysis_visitor,
+                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
-        if is_control_flow_to_transform(node, self.scope_var_type_dict):
+        if is_control_flow_to_transform(node, self.static_analysis_visitor,
+                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node
...
@@ -26,6 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
+from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
@@ -150,8 +151,9 @@
        self.visit(root_node)

    def is_control_flow_loop(self, node):
-        # TODO: make a better condition
-        return True
+        need_transform = is_control_flow_to_transform(
+            node, self.static_analysis_visitor)
+        return need_transform

    def get_loop_var_names(self, node):
        assert isinstance(
...
@@ -15,15 +15,8 @@
from __future__ import print_function

import gast
-import astor
-import copy
-from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
-from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
-from paddle.fluid import unique_name
-from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
-from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
-from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
-from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node
+from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
+from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
...
@@ -50,8 +50,9 @@ def is_api_in_module(node, module_prefix):
        # source_file = inspect.getfile(dyfunc)
        # import_statements = ImportVisitor(source_file).transform()
        # import_str = "".join(import_statements)
-        import paddle.fluid as fluid
        import paddle
+        import paddle.fluid as fluid
+        import paddle.fluid.layers as layers
        from paddle.fluid.dygraph import to_variable
        import paddle.fluid.dygraph as dygraph
        return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
@@ -88,29 +89,19 @@ def is_numpy_api(node):
    return False


-def is_control_flow_to_transform(node, var_name_to_type):
+def is_control_flow_to_transform(node,
+                                 static_analysis_visitor=None,
+                                 var_name_to_type=None):
    """
-    Determines whether the node is a Paddle control flow statement which needs to
-    transform into a static graph control flow statement.
+    Determines whether the node is a PaddlePaddle control flow statement which needs to
+    be transformed into a static graph control flow statement.
    """
    assert isinstance(node, gast.AST), \
        "The type of input node must be gast.AST, but received %s." % type(node)
-
-    if isinstance(node, gast.If):
-        from .ifelse_transformer import IfConditionVisitor
-        if_visitor = IfConditionVisitor(
-            node.test, node_var_type_map=var_name_to_type)
-        return if_visitor.is_control_flow()
-
-    if isinstance(node, gast.For):
-        # TODO: make a better condition
-        return True
-
-    if isinstance(node, gast.While):
-        # TODO: make a better condition
-        return True
-
-    return False
+    visitor = IsControlFlowVisitor(
+        node, static_analysis_visitor, node_var_type_map=var_name_to_type)
+    need_to_transform = visitor.transform()
+    return need_to_transform


def _delete_keywords_from(node):
@@ -415,3 +406,216 @@ def ast_to_source_code(ast_node):
    ast_node = gast.gast_to_ast(ast_node)
    source_code = astor.to_source(ast_node)
    return source_code
+
+
+def is_candidate_node(node):
+    """
+    Nodes with specified type will be dependent on tensor.
+    """
+    is_compare_node = isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp,
+                                        gast.For, gast.If, gast.While))
+    # TODO(Aurelius84): `.numpy()` may be an customized function,
+    # and should consider a more elegant way to solve this problem.
+    has_numpy_attr = ".numpy()" in ast_to_source_code(node)
+    return is_compare_node or has_numpy_attr
+
+
+def compare_with_none(node):
+    """
+    Whether the comparator of `gast.Compare` node is `None`.
+    """
+    if isinstance(node, gast.Compare):
+        for child in [node.left, node.comparators]:
+            # node.comparators is a list.
+            if isinstance(child, list):
+                child = child[0]
+            if (isinstance(child, gast.Constant) and child.value is None) or (
+                    isinstance(child, gast.Name) and child.id == 'None'):
+                return True
+    return False
+
+
+class IsControlFlowVisitor(gast.NodeVisitor):
+    """
+    Judge whether the ast_node of control flow from Dygraph code dependent on paddle Tensor.
+    `ast_node` can be gast.If, gast.For, gast.While, gast.If.test (gast.Compare, gast.BoolOp, gast.UnaryOp).
+
+    If returns True,
+    gast.If.test must meet at least one of the following requirements:
+        1. involves at least one var whose type is Tensor.
+        2. the Tensor var calls `.numpy()[]` interface or Tensor.shape is [1].
+        3. involves Tensor.shape[i] and the shape[i] is unknown in compile time.
+    gast.While must meet at least one of the requirements 1 to 5:
+        4. has `break` statement.
+        5. has `continue` statement.
+    gast.For must meet at least one of the requirements 4 to 6:
+        6. calls `range` function in `for` statement and the argument of range is Tensor.
+        TODO: Support non-range case
+
+    The following examples should not be considered as control_flow_if:
+        1. `if Tensor_var` or `if Tensor_var is None`
+        2. if Tensor.shape[i] is determined with fixed value (not -1 or None)
+
+    Note: pred in ConditionalBlock require variable, which means all vars should be Tensor
+          or transformed into Tensor, like fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]).
+
+    TODO: 1. need to deal with `tensor.shape[i]` which need to eval the data of shape[i],
+             because reshape_op may be called before this statement.
+    """
+
+    def __init__(self,
+                 ast_node,
+                 static_analysis_visitor=None,
+                 node_var_type_map=None):
+        assert isinstance(
+            ast_node, gast.AST
+        ), "Type of input node should be gast.AST, but received %s." % type(
+            ast_node)
+        self.ast_root = ast_node
+        if static_analysis_visitor is None:
+            from .static_analysis import StaticAnalysisVisitor
+            static_analysis_visitor = StaticAnalysisVisitor(ast_node)
+        self.static_analysis_visitor = static_analysis_visitor
+        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
+        )
+        self.node_var_type_map = node_var_type_map
+
+        self.is_control_flow_num = 0
+        self._compare_node_tenor_set = set()
+
+    def transform(self):
+        node = self.ast_root
+        if is_candidate_node(node):
+            if isinstance(node, gast.If):
+                self._visit_If(node)
+            if isinstance(node, gast.For):
+                self._visit_For(node)
+            elif isinstance(node, gast.While):
+                self._visit_While(node)
+            else:
+                self.visit(node)
+        return self.is_control_flow_num > 0
+
+    def _visit_If(self, node):
+        assert isinstance(node, gast.If)
+        self.visit(node.test)
+        return
+
+    def _visit_For(self, node):
+        assert isinstance(node, gast.For)
+        # TODO
+        # self.is_control_flow_num += 1
+        if not isinstance(node.iter, gast.Call):
+            return
+        if not isinstance(node.iter.func, gast.Name):
+            return
+        if node.iter.func.id != "range":
+            return
+        for arg in node.iter.args:
+            self.visit(arg)
+
+        for child_node in gast.walk(node):
+            if isinstance(child_node, (gast.Continue, gast.Break)):
+                self._visit_break_continue(child_node)
+        return
+
+    def _visit_While(self, node):
+        assert isinstance(node, gast.While)
+        test = node.test
+        self.generic_visit(test)
+
+        for child_node in gast.walk(node):
+            if isinstance(child_node, (gast.Continue, gast.Break)):
+                self._visit_break_continue(child_node)
+        return
+
+    def _visit_break_continue(self, node):
+        assert isinstance(node, (gast.Break, gast.Continue))
+        wrapper_node = self.node_to_wrapper_map.get(node)
+        if not wrapper_node:
+            # Transformed node is not in node_to_wrapper_map
+            return
+
+        while wrapper_node.parent:
+            parent_node = wrapper_node.parent.node
+            if isinstance(parent_node, (gast.For, gast.While)):
+                if parent_node is self.ast_root:
+                    self.is_control_flow_num += 1
+                    return
+                else:
+                    return
+            wrapper_node = wrapper_node.parent
+
+        return
+
+    def visit_BoolOp(self, node):
+        for i, child in enumerate(node.values):
+            if is_candidate_node(child):
+                self.visit(child)
+        return node
+
+    def visit_Compare(self, node):
+        # Ignores child node with `if x` or `if x is None`
+        # TODO(Aurelius84): `if tensor` will be supported in dygraph
+        # and should be considered as is_control_flow.
+        pre_control_flow_num = self.is_control_flow_num
+        if not compare_with_none(node):
+            self.generic_visit(node)
+            for child in gast.walk(node):
+                if isinstance(child, gast.Subscript):
+                    self._visit_Subscript(child)
+        if self.is_control_flow_num > pre_control_flow_num:
+            self._compare_node_tenor_set.add(node)
+        return node
+
+    def _visit_Subscript(self, node):
+        self.generic_visit(node)
+        if hasattr(node, 'value') and isinstance(node.value, gast.Call):
+            self._visit_Call(node.value)
+        return node
+
+    def _visit_Call(self, node):
+        assert isinstance(node, gast.Call)
+        if isinstance(node.func, gast.Attribute):
+            attr_node = node.func
+            if attr_node.attr == 'numpy':
+                self.is_control_flow_num += 1
+
+    def visit_Call(self, node):
+        self._visit_Call(node)
+        if is_paddle_api(node):
+            self.is_control_flow_num += 1
+        return node
+
+    def visit_Name(self, node):
+        if self._is_node_with_tensor(node, node.id):
+            self.is_control_flow_num += 1
+        return node
+
+    def visit_Constant(self, node):
+        if self._is_node_with_tensor(node, node.value):
+            self.is_control_flow_num += 1
+        return node
+
+    def _is_node_with_tensor(self, node, name_id):
+        from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
+
+        tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
+        # Look up the node_var_type_map by name_id.
+        if self.node_var_type_map:
+            if name_id and isinstance(name_id, six.string_types):
+                var_type = self.node_var_type_map.get(name_id, None)
+                if var_type and var_type & tensor_types:
+                    return True
+        # if not found, look up the node_to_wrapper_map by node.
+        node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
+        )
+        wrapper_node = node_to_wrapper_map.get(node, None)
+        if wrapper_node is not None:
+            if wrapper_node.node_var_type & tensor_types:
+                return True
+
+        return False
+
+    def get_compare_nodes_with_tensor(self):
+        return self._compare_node_tenor_set
@@ -19,9 +19,9 @@ import textwrap
import gast

from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import get_name_ids
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfConditionVisitor
-from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IsControlFlowVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
+from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor


class TestGetNameIds(unittest.TestCase):
...
@@ -47,6 +47,10 @@ def test_list_in_if(x):

def test_list_in_for_loop(x, iter_num):
    x = fluid.dygraph.to_variable(x)
+    # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
+    iter_num = fluid.layers.fill_constant(
+        shape=[1], value=iter_num, dtype="int32"
+    )  # TODO(liym27): Delete it if the type of parameter iter_num can be resolved
    a = []
    for i in range(iter_num):
        a.append(x)
@@ -56,6 +60,10 @@ def test_list_in_for_loop(x, iter_num):

def test_list_in_for_loop_with_concat(x, iter_num):
    x = fluid.dygraph.to_variable(x)
    a = []
+    # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
+    iter_num = fluid.layers.fill_constant(
+        shape=[1], value=iter_num, dtype="int32"
+    )  # TODO(liym27): Delete it if the type of parameter iter_num can be resolved
    for i in range(iter_num):
        a.append(x)
    a = fluid.layers.concat(a, axis=0)
...
@@ -29,6 +29,9 @@ np.random.seed(SEED)

def while_loop_dyfunc(x):
    i = fluid.dygraph.to_variable(x)
+    # Use `to_variable` so that static analysis can analyze the type of X is Tensor
+    x = fluid.dygraph.to_variable(
+        x)  # TODO(liym27): Delete it if the type of parameter x can be resolved
    while x < 10:
        i = i + x
        x = x + 1
@@ -37,6 +40,9 @@ def while_loop_dyfunc(x):

def while_loop_dyfun_with_conflict_var(x):
    i = fluid.dygraph.to_variable(x)
+    # Use `to_variable` so that static analysis can analyze the type of X is Tensor
+    x = fluid.dygraph.to_variable(
+        x)  # TODO(liym27): Delete it if the type of parameter x can be resolved

    def relu(y):
        # 'y' is not visible outside the scope.
@@ -56,6 +62,9 @@ def while_loop_dyfunc_with_none(x):
    i = fluid.dygraph.to_variable(x)\
        if x is not None \
        else fluid.dygraph.to_variable(x+1)
+    # Use `to_variable` so that static analysis can analyze the type of X is Tensor
+    x = fluid.dygraph.to_variable(
+        x)  # TODO(liym27): Delete it if the type of parameter x can be resolved
    flag = 1
    while x < 10:
        i = i + x if flag is not None else x + i
@@ -72,6 +81,10 @@ def for_loop_dyfunc(max_len):

def while_loop_bool_op(x):
    i = fluid.dygraph.to_variable(x)
+    # Use `to_variable` so that static analysis can analyze the type of X is Tensor
+    x = fluid.dygraph.to_variable(
+        x)  # TODO(liym27): Delete it if the type of parameter x can be resolved
    while (x >= 0 and x < 10) or x <= -1 or x < -3 or (x < -7 or x < -5):
        i = i + x
        x = x + 1
@@ -102,6 +115,11 @@ def for_loop_class_var(max_len):
            self.c = 5

    foo = Foo()
+    # Use `to_variable` so that static analysis can analyze the type of X is Tensor
+    # TODO(liym27): Delete it if the type of parameter x can be resolved
+    max_len = fluid.layers.fill_constant(
+        shape=[1], value=max_len, dtype="int32")
    for i in range(max_len):
        foo.b = fluid.layers.zeros(shape=[1], dtype='float32')
        foo.c = foo.b + foo.a
...
@@ -69,6 +69,11 @@ def test_slice_in_while_loop(x, iter_num):

def test_slice_in_for_loop(x, iter_num):
    x = fluid.dygraph.to_variable(x)
    a = []
+    # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
+    iter_num = fluid.layers.fill_constant(
+        shape=[1], value=iter_num, dtype="int32"
+    )  # TODO(liym27): Delete it if the type of parameter iter_num can be resolved
    for i in range(iter_num):
        a.append(x)
...