From 2961a4f07dd4ea7e078c468b17fc2f9aead19fd1 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Fri, 24 Apr 2020 13:36:43 +0800 Subject: [PATCH] [Dy2Stat] Optimize loop cond (#24049) * Simplify code for gast.If in is_control_flow_to_transform. * Move IsControlFlowVisitor to file utils. * Don't use convert_call for build-in func in CallTransformer. * Optimize api is_control_flow_to_transform. * Polish the document of IsControlFlowVisitor. --- .../dygraph_to_static/call_transformer.py | 13 + .../dygraph_to_static/ifelse_transformer.py | 144 +---------- .../dygraph_to_static/list_transformer.py | 9 +- .../dygraph_to_static/loop_transformer.py | 6 +- .../tensor_shape_transformer.py | 11 +- .../fluid/dygraph/dygraph_to_static/utils.py | 244 ++++++++++++++++-- .../dygraph_to_static/test_ifelse_basic.py | 2 +- .../unittests/dygraph_to_static/test_list.py | 8 + .../unittests/dygraph_to_static/test_loop.py | 18 ++ .../unittests/dygraph_to_static/test_slice.py | 5 + 10 files changed, 285 insertions(+), 175 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py index 9c128a83c47..b872ab723e3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py @@ -32,6 +32,15 @@ class CallTransformer(gast.NodeTransformer): self.wrapper_root = wrapper_root self.root = wrapper_root.node + def _is_builtin_call(self, node): + assert isinstance(node, gast.Call) + func_str = ast_to_source_code(node.func).strip() + try: + from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin + return eval("is_builtin({})".format(func_str)) + except Exception: + return False + def transform(self): self.visit(self.root) @@ -39,6 +48,10 @@ class CallTransformer(gast.NodeTransformer): self.generic_visit(node) if is_paddle_api(node): return node + + if self._is_builtin_call(node): + return node + func_str = ast_to_source_code(node.func).strip() new_func_str = "fluid.dygraph.dygraph_to_static.convert_call({})".format( func_str) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index b57d74f6470..db69e3763f9 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -25,12 +25,15 @@ import gast import six from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.utils import compare_with_none +from paddle.fluid.dygraph.dygraph_to_static.utils import is_candidate_node from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node +from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper TRUE_FUNC_PREFIX = 'true_fn' FALSE_FUNC_PREFIX = 'false_fn' @@ -142,145 +145,6 @@ class IfElseTransformer(gast.NodeTransformer): return self.new_func_nodes -def is_candidate_node(node): - """ - Nodes with specified type will be dependent on tensor. - """ - is_compare_node = isinstance(node, - (gast.Compare, gast.BoolOp, gast.UnaryOp)) - # TODO(Aurelius84): `.numpy()` may be an customized function, - # and should consider a more elegant way to solve this problem. - has_numpy_attr = ".numpy()" in ast_to_source_code(node) - return is_compare_node or has_numpy_attr - - -def compare_with_none(node): - """ - Whether the comparator of `gast.Compare` node is `None`. - """ - if isinstance(node, gast.Compare): - for child in [node.left, node.comparators]: - # node.comparators is a list. - if isinstance(child, list): - child = child[0] - if (isinstance(child, gast.Constant) and child.value is None) or ( - isinstance(child, gast.Name) and child.id == 'None'): - return True - return False - - -class IsControlFlowVisitor(gast.NodeVisitor): - """ - Judge whether the node.test from Dygraph code dependent on paddle Tensor. - If does, it should satisfy: - 1. must involve at least one var whose type is Tensor. - 2. the Tensor var should call `.numpy()[]` interface or Tensor.shape is [1]. - 3. involve Tensor.shape[i] and the shape[i] is unknown in compile time. - The following examples should not be considered as control_flow_if: - 1. `if Tensor_var` or `if Tensor_var is None` - 2. if Tensor.shape[i] is determined with fixed value (not -1 or None) - - Note: pred in ConditionalBlock require variable, which means all vars should be Tensor - or transformed into Tensor, like fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]). - - TODO: 1. need to deal with `tensor.shape[i]` which need to eval the data of shape[i], - because reshape_op may be called before this statement. - """ - - def __init__(self, - ast_node, - static_analysis_visitor=None, - node_var_type_map=None): - assert isinstance( - ast_node, gast.AST - ), "Type of input node should be gast.AST, but received %s." % type( - ast_node) - self.ast_root = ast_node - if static_analysis_visitor is None: - static_analysis_visitor = StaticAnalysisVisitor(ast_node) - self.static_analysis_visitor = static_analysis_visitor - self.node_var_type_map = node_var_type_map - - self.is_control_flow_num = 0 - self._compare_node_tenor_set = set() - - def transform(self): - node = self.ast_root - if is_candidate_node(node): - self.visit(node) - return self.is_control_flow_num > 0 - - def visit_BoolOp(self, node): - for i, child in enumerate(node.values): - if is_candidate_node(child): - self.visit(child) - return node - - def visit_Compare(self, node): - # Ignores child node with `if x` or `if x is None` - # TODO(Aurelius84): `if tensor` will be supported in dygraph - # and should be considered as is_control_flow. - pre_control_flow_num = self.is_control_flow_num - if not compare_with_none(node): - self.generic_visit(node) - for child in gast.walk(node): - if isinstance(child, gast.Subscript): - self._visit_Subscript(child) - if self.is_control_flow_num > pre_control_flow_num: - self._compare_node_tenor_set.add(node) - return node - - def _visit_Subscript(self, node): - self.generic_visit(node) - if hasattr(node, 'value') and isinstance(node.value, gast.Call): - self._visit_Call(node.value) - return node - - def _visit_Call(self, node): - assert isinstance(node, gast.Call) - if isinstance(node.func, gast.Attribute): - attr_node = node.func - if attr_node.attr == 'numpy': - self.is_control_flow_num += 1 - - def visit_Call(self, node): - self._visit_Call(node) - if is_paddle_api(node): - self.is_control_flow_num += 1 - return node - - def visit_Name(self, node): - if self._is_node_with_tensor(node, node.id): - self.is_control_flow_num += 1 - return node - - def visit_Constant(self, node): - if self._is_node_with_tensor(node, node.value): - self.is_control_flow_num += 1 - return node - - def _is_node_with_tensor(self, node, name_id): - tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES} - # Look up the node_var_type_map by name_id. - if self.node_var_type_map: - if name_id and isinstance(name_id, six.string_types): - var_type = self.node_var_type_map.get(name_id, None) - if var_type and var_type & tensor_types: - return True - # if not found, look up the node_to_wrapper_map by node. - node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( - ) - wrapper_node = node_to_wrapper_map.get(node, None) - if wrapper_node is not None: - if wrapper_node.node_var_type & tensor_types: - return True - - return False - - def get_compare_nodes_with_tensor(self): - return self._compare_node_tenor_set - - class NodeTestTransformer(gast.NodeTransformer): def __init__(self, ast_node, compare_nodes_with_tensor=None): if compare_nodes_with_tensor is None: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py index 5579229ccd5..4780325bae3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py @@ -55,19 +55,22 @@ class ListTransformer(gast.NodeTransformer): def visit_If(self, node): self.generic_visit(node) - if is_control_flow_to_transform(node, self.scope_var_type_dict): + if is_control_flow_to_transform(node, self.static_analysis_visitor, + self.scope_var_type_dict): self._transform_list_append_in_control_flow(node) return node def visit_While(self, node): self.generic_visit(node) - if is_control_flow_to_transform(node, self.scope_var_type_dict): + if is_control_flow_to_transform(node, self.static_analysis_visitor, + self.scope_var_type_dict): self._transform_list_append_in_control_flow(node) return node def visit_For(self, node): self.generic_visit(node) - if is_control_flow_to_transform(node, self.scope_var_type_dict): + if is_control_flow_to_transform(node, self.static_analysis_visitor, + self.scope_var_type_dict): self._transform_list_append_in_control_flow(node) return node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 82400ab01b8..900e48269d6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -26,6 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name +from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node @@ -150,8 +151,9 @@ class NameVisitor(gast.NodeVisitor): self.visit(root_node) def is_control_flow_loop(self, node): - # TODO: make a better condition - return True + need_transform = is_control_flow_to_transform( + node, self.static_analysis_visitor) + return need_transform def get_loop_var_names(self, node): assert isinstance( diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 52392bc1e0a..a21a5af4552 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -15,15 +15,8 @@ from __future__ import print_function import gast -import astor -import copy -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor -from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform -from paddle.fluid import unique_name -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func -from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable -from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func -from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api +from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 5a9146420df..d410df1589b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -50,8 +50,9 @@ def is_api_in_module(node, module_prefix): # source_file = inspect.getfile(dyfunc) # import_statements = ImportVisitor(source_file).transform() # import_str = "".join(import_statements) - import paddle.fluid as fluid import paddle + import paddle.fluid as fluid + import paddle.fluid.layers as layers from paddle.fluid.dygraph import to_variable import paddle.fluid.dygraph as dygraph return eval("_is_api_in_module_helper({}, '{}')".format(func_str, @@ -88,29 +89,19 @@ def is_numpy_api(node): return False -def is_control_flow_to_transform(node, var_name_to_type): +def is_control_flow_to_transform(node, + static_analysis_visitor=None, + var_name_to_type=None): """ - Determines whether the node is a Paddle control flow statement which needs to - transform into a static graph control flow statement. + Determines whether the node is a PaddlePaddle control flow statement which needs to + be transformed into a static graph control flow statement. """ assert isinstance(node, gast.AST), \ "The type of input node must be gast.AST, but received %s." % type(node) - - if isinstance(node, gast.If): - from .ifelse_transformer import IfConditionVisitor - if_visitor = IfConditionVisitor( - node.test, node_var_type_map=var_name_to_type) - return if_visitor.is_control_flow() - - if isinstance(node, gast.For): - # TODO: make a better condition - return True - - if isinstance(node, gast.While): - # TODO: make a better condition - return True - - return False + visitor = IsControlFlowVisitor( + node, static_analysis_visitor, node_var_type_map=var_name_to_type) + need_to_transform = visitor.transform() + return need_to_transform def _delete_keywords_from(node): @@ -415,3 +406,216 @@ def ast_to_source_code(ast_node): ast_node = gast.gast_to_ast(ast_node) source_code = astor.to_source(ast_node) return source_code + + +def is_candidate_node(node): + """ + Nodes with specified type will be dependent on tensor. + """ + is_compare_node = isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp, + gast.For, gast.If, gast.While)) + # TODO(Aurelius84): `.numpy()` may be an customized function, + # and should consider a more elegant way to solve this problem. + has_numpy_attr = ".numpy()" in ast_to_source_code(node) + return is_compare_node or has_numpy_attr + + +def compare_with_none(node): + """ + Whether the comparator of `gast.Compare` node is `None`. + """ + if isinstance(node, gast.Compare): + for child in [node.left, node.comparators]: + # node.comparators is a list. + if isinstance(child, list): + child = child[0] + if (isinstance(child, gast.Constant) and child.value is None) or ( + isinstance(child, gast.Name) and child.id == 'None'): + return True + return False + + +class IsControlFlowVisitor(gast.NodeVisitor): + """ + Judge whether the ast_node of control flow from Dygraph code dependent on paddle Tensor. + `ast_node` can be gast.If, gast.For, gast.While, gast.If.test(gast.Compare, gast.BoolOp, gast.UnaryOp). + + If returns True, + gast.If.test must meet at least one of the following requirements: + 1. involves at least one var whose type is Tensor. + 2. the Tensor var calls `.numpy()[]` interface or Tensor.shape is [1]. + 3. involves Tensor.shape[i] and the shape[i] is unknown in compile time. + gast.While must meet at least one of the requirements 1 to 5: + 4. has `break` statement. + 5. has `continue` statement. + gast.For must meet at least one of the requirements 4 to 6: + 6. calls `range` function in `for` statement and the argument of range is Tensor. + TODO: Support non-range case + + The following examples should not be considered as control_flow_if: + 1. `if Tensor_var` or `if Tensor_var is None` + 2. if Tensor.shape[i] is determined with fixed value (not -1 or None) + + Note: pred in ConditionalBlock require variable, which means all vars should be Tensor + or transformed into Tensor, like fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]). + + TODO: 1. need to deal with `tensor.shape[i]` which need to eval the data of shape[i], + because reshape_op may be called before this statement. + """ + + def __init__(self, + ast_node, + static_analysis_visitor=None, + node_var_type_map=None): + assert isinstance( + ast_node, gast.AST + ), "Type of input node should be gast.AST, but received %s." % type( + ast_node) + self.ast_root = ast_node + if static_analysis_visitor is None: + from .static_analysis import StaticAnalysisVisitor + static_analysis_visitor = StaticAnalysisVisitor(ast_node) + self.static_analysis_visitor = static_analysis_visitor + self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( + ) + self.node_var_type_map = node_var_type_map + + self.is_control_flow_num = 0 + self._compare_node_tenor_set = set() + + def transform(self): + node = self.ast_root + if is_candidate_node(node): + if isinstance(node, gast.If): + self._visit_If(node) + if isinstance(node, gast.For): + self._visit_For(node) + elif isinstance(node, gast.While): + self._visit_While(node) + else: + self.visit(node) + return self.is_control_flow_num > 0 + + def _visit_If(self, node): + assert isinstance(node, gast.If) + self.visit(node.test) + return + + def _visit_For(self, node): + assert isinstance(node, gast.For) + # TODO + # self.is_control_flow_num += 1 + if not isinstance(node.iter, gast.Call): + return + if not isinstance(node.iter.func, gast.Name): + return + if node.iter.func.id != "range": + return + for arg in node.iter.args: + self.visit(arg) + + for child_node in gast.walk(node): + if isinstance(child_node, (gast.Continue, gast.Break)): + self._visit_break_continue(child_node) + return + + def _visit_While(self, node): + assert isinstance(node, gast.While) + test = node.test + self.generic_visit(test) + for child_node in gast.walk(node): + if isinstance(child_node, (gast.Continue, gast.Break)): + self._visit_break_continue(child_node) + return + + def _visit_break_continue(self, node): + assert isinstance(node, (gast.Break, gast.Continue)) + wrapper_node = self.node_to_wrapper_map.get(node) + if not wrapper_node: + # Transformed node is not in node_to_wrapper_map + return + + while wrapper_node.parent: + parent_node = wrapper_node.parent.node + if isinstance(parent_node, (gast.For, gast.While)): + if parent_node is self.ast_root: + self.is_control_flow_num += 1 + return + else: + return + + wrapper_node = wrapper_node.parent + + return + + def visit_BoolOp(self, node): + for i, child in enumerate(node.values): + if is_candidate_node(child): + self.visit(child) + return node + + def visit_Compare(self, node): + # Ignores child node with `if x` or `if x is None` + # TODO(Aurelius84): `if tensor` will be supported in dygraph + # and should be considered as is_control_flow. + pre_control_flow_num = self.is_control_flow_num + if not compare_with_none(node): + self.generic_visit(node) + for child in gast.walk(node): + if isinstance(child, gast.Subscript): + self._visit_Subscript(child) + if self.is_control_flow_num > pre_control_flow_num: + self._compare_node_tenor_set.add(node) + return node + + def _visit_Subscript(self, node): + self.generic_visit(node) + if hasattr(node, 'value') and isinstance(node.value, gast.Call): + self._visit_Call(node.value) + return node + + def _visit_Call(self, node): + assert isinstance(node, gast.Call) + if isinstance(node.func, gast.Attribute): + attr_node = node.func + if attr_node.attr == 'numpy': + self.is_control_flow_num += 1 + + def visit_Call(self, node): + self._visit_Call(node) + if is_paddle_api(node): + self.is_control_flow_num += 1 + return node + + def visit_Name(self, node): + if self._is_node_with_tensor(node, node.id): + self.is_control_flow_num += 1 + return node + + def visit_Constant(self, node): + if self._is_node_with_tensor(node, node.value): + self.is_control_flow_num += 1 + return node + + def _is_node_with_tensor(self, node, name_id): + from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType + + tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES} + # Look up the node_var_type_map by name_id. + if self.node_var_type_map: + if name_id and isinstance(name_id, six.string_types): + var_type = self.node_var_type_map.get(name_id, None) + if var_type and var_type & tensor_types: + return True + # if not found, look up the node_to_wrapper_map by node. + node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( + ) + wrapper_node = node_to_wrapper_map.get(node, None) + if wrapper_node is not None: + if wrapper_node.node_var_type & tensor_types: + return True + + return False + + def get_compare_nodes_with_tensor(self): + return self._compare_node_tenor_set diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py index e40fa00ddfe..91026096952 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py @@ -19,9 +19,9 @@ import textwrap import gast from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import get_name_ids from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfConditionVisitor -from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IsControlFlowVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType +from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor class TestGetNameIds(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py index 588d900f316..3e65492524d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py @@ -47,6 +47,10 @@ def test_list_in_if(x): def test_list_in_for_loop(x, iter_num): x = fluid.dygraph.to_variable(x) + # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor + iter_num = fluid.layers.fill_constant( + shape=[1], value=iter_num, dtype="int32" + ) # TODO(liym27): Delete it if the type of parameter iter_num can be resolved a = [] for i in range(iter_num): a.append(x) @@ -56,6 +60,10 @@ def test_list_in_for_loop(x, iter_num): def test_list_in_for_loop_with_concat(x, iter_num): x = fluid.dygraph.to_variable(x) a = [] + # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor + iter_num = fluid.layers.fill_constant( + shape=[1], value=iter_num, dtype="int32" + ) # TODO(liym27): Delete it if the type of parameter iter_num can be resolved for i in range(iter_num): a.append(x) a = fluid.layers.concat(a, axis=0) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index b64fa34500f..2ebd6eff939 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -29,6 +29,9 @@ np.random.seed(SEED) def while_loop_dyfunc(x): i = fluid.dygraph.to_variable(x) + # Use `to_variable` so that static analysis can analyze the type of X is Tensor + x = fluid.dygraph.to_variable( + x) # TODO(liym27): Delete it if the type of parameter x can be resolved while x < 10: i = i + x x = x + 1 @@ -37,6 +40,9 @@ def while_loop_dyfunc(x): def while_loop_dyfun_with_conflict_var(x): i = fluid.dygraph.to_variable(x) + # Use `to_variable` so that static analysis can analyze the type of X is Tensor + x = fluid.dygraph.to_variable( + x) # TODO(liym27): Delete it if the type of parameter x can be resolved def relu(y): # 'y' is not visible outside the scope. @@ -56,6 +62,9 @@ def while_loop_dyfunc_with_none(x): i = fluid.dygraph.to_variable(x)\ if x is not None \ else fluid.dygraph.to_variable(x+1) + # Use `to_variable` so that static analysis can analyze the type of X is Tensor + x = fluid.dygraph.to_variable( + x) # TODO(liym27): Delete it if the type of parameter x can be resolved flag = 1 while x < 10: i = i + x if flag is not None else x + i @@ -72,6 +81,10 @@ def for_loop_dyfunc(max_len): def while_loop_bool_op(x): i = fluid.dygraph.to_variable(x) + + # Use `to_variable` so that static analysis can analyze the type of X is Tensor + x = fluid.dygraph.to_variable( + x) # TODO(liym27): Delete it if the type of parameter x can be resolved while (x >= 0 and x < 10) or x <= -1 or x < -3 or (x < -7 or x < -5): i = i + x x = x + 1 @@ -102,6 +115,11 @@ def for_loop_class_var(max_len): self.c = 5 foo = Foo() + + # Use `to_variable` so that static analysis can analyze the type of X is Tensor + # TODO(liym27): Delete it if the type of parameter x can be resolved + max_len = fluid.layers.fill_constant( + shape=[1], value=max_len, dtype="int32") for i in range(max_len): foo.b = fluid.layers.zeros(shape=[1], dtype='float32') foo.c = foo.b + foo.a diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py index 27d49404ce1..3a450d4554f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py @@ -69,6 +69,11 @@ def test_slice_in_while_loop(x, iter_num): def test_slice_in_for_loop(x, iter_num): x = fluid.dygraph.to_variable(x) a = [] + # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor + iter_num = fluid.layers.fill_constant( + shape=[1], value=iter_num, dtype="int32" + ) # TODO(liym27): Delete it if the type of parameter iter_num can be resolved + for i in range(iter_num): a.append(x) -- GitLab