diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index 115b3aa925cf7f76160ba8c1e20c58314d73bf5a..5e9505c64828a7fe4fe6a541da59811bfd3d26c3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -16,7 +16,7 @@ from __future__ import print_function import gast import warnings -from .utils import is_paddle_api, is_dygraph_api, is_numpy_api +from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] @@ -260,20 +260,27 @@ class StaticAnalysisVisitor(object): def get_var_env(self): return self.var_env + def _get_constant_node_type(self, node): + assert isinstance(node, gast.Constant), \ + "Type of input node should be gast.Constant, but received %s" % type(node) + # singleton: None, True or False + if node.value is None: + return {NodeVarType.NONE} + if isinstance(node.value, bool): + return {NodeVarType.BOOLEAN} + if isinstance(node.value, int): + return {NodeVarType.INT} + if isinstance(node.value, float): + return {NodeVarType.FLOAT} + if isinstance(node.value, str): + return {NodeVarType.STRING} + + return {NodeVarType.UNKNOWN} + def _get_node_var_type(self, cur_wrapper): node = cur_wrapper.node if isinstance(node, gast.Constant): - # singleton: None, True or False - if node.value is None: - return {NodeVarType.NONE} - if isinstance(node.value, bool): - return {NodeVarType.BOOLEAN} - if isinstance(node.value, int): - return {NodeVarType.INT} - if isinstance(node.value, float): - return {NodeVarType.FLOAT} - if isinstance(node.value, str): - return {NodeVarType.STRING} + return self._get_constant_node_type(node) if isinstance(node, gast.BoolOp): return {NodeVarType.BOOLEAN} @@ -308,8 +315,28 @@ class StaticAnalysisVisitor(object): if isinstance(node, gast.Name): if node.id == "None": return {NodeVarType.NONE} - if node.id == "True" or node.id == "False": + if node.id in {"True", "False"}: return {NodeVarType.BOOLEAN} + # If node is child of functionDef.arguments + parent_node_wrapper = cur_wrapper.parent + if parent_node_wrapper and isinstance(parent_node_wrapper.node, + gast.arguments): + parent_node = parent_node_wrapper.node + var_type = {NodeVarType.UNKNOWN} + if parent_node.defaults: + index = index_in_list(parent_node.args, node) + args_len = len(parent_node.args) + if index != -1 and args_len - index <= len( + parent_node.defaults): + defaults_node = parent_node.defaults[index - args_len] + if isinstance(defaults_node, gast.Constant): + var_type = self._get_constant_node_type( + defaults_node) + + # Add node with identified type into cur_env. + self.var_env.set_var_type(node.id, var_type) + return var_type + return self.var_env.get_var_type(node.id) if isinstance(node, gast.Return): diff --git a/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py similarity index 88% rename from python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py rename to python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py index 6822397b1fe8135237d1a1db6dfcf443367373bd..e72688d800ba59f63503248f2a5d385da23d6882 100644 --- a/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import numpy as np import paddle.fluid as fluid import unittest -from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static import NodeVarType, StaticAnalysisVisitor def func_to_test1(a, b): @@ -117,12 +117,34 @@ result_var_type5 = { 'inner_unknown_func': {NodeVarType.UNKNOWN}, } + +def func_to_test6(x, y=1): + i = fluid.dygraph.to_variable(x) + + def add(x, y): + return x + y + + while x < 10: + i = add(i, x) + x = x + y + + return i + + +result_var_type6 = { + 'i': {NodeVarType.INT}, + 'x': {NodeVarType.INT}, + 'y': {NodeVarType.INT}, + 'add': {NodeVarType.INT} +} + test_funcs = [ - func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5 + func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5, + func_to_test6 ] result_var_type = [ result_var_type1, result_var_type2, result_var_type3, result_var_type4, - result_var_type5 + result_var_type5, result_var_type6 ] @@ -150,8 +172,8 @@ class TestStaticAnalysis(unittest.TestCase): self._check_wrapper(wrapper_root, node_to_wrapper_map) def test_var_env(self): - for i in range(5): - func = test_funcs[i] + + for i, func in enumerate(test_funcs): var_type = result_var_type[i] test_source_code = inspect.getsource(func) ast_root = gast.parse(test_source_code) @@ -164,6 +186,7 @@ class TestStaticAnalysis(unittest.TestCase): var_env.cur_scope = var_env.cur_scope.sub_scopes[0] scope_var_type = var_env.get_scope_var_type() + print(scope_var_type) self.assertEqual(len(scope_var_type), len(var_type)) for name in scope_var_type: print("Test var name %s" % (name))