未验证 提交 69717717 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] Add closure analysis for control flow and add some unittest (#43713)

* add closure analysis for control flow and add some unittest

* finetune the design of FunctionScopeVisitor

* fix

* fix python check

* fix code by code review
上级 af97b310
...@@ -108,6 +108,99 @@ def create_while_nodes(condition_name, body_name, loop_var_names): ...@@ -108,6 +108,99 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
return ret return ret
class NameScope:
def __init__(self):
""" we don't analyze the read only variable
because they keep the same in control flow.
"""
self.globals = set()
self.nonlocals = set()
self.args = set()
self.w_vars = set() # all vars been stored,
# may be globals or non-locals
def created_vars(self):
return self.w_vars - self.globals - self.nonlocals - self.args
def write_vars(self):
return self.w_vars
def global_vars(self):
return self.globals
class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" analyze the liveness of a function.
every variables stored in this scope will be collected,
in addition with global/nonlocal information.
1. global variable is stored in node.var_globals.
2. nonlocal variable is stored in node.var_nonlocals.
3. arguments is stored in node.var_args.
For example:
def func(*args, **kargs):
a = 12
global i,j
nonlocal x,y
print(a)
i = k
for m in range(10):
q = 12
After this visitor we have:
# node is the FunctionDef node with name: "func"
node.pd_scope = NameScope(
globals = ['i', 'j'],
nonlocals = ['x', 'y'],
args = ['args', 'kargs'],
wr_vars = ['a', 'i', 'q', 'm']
)
"""
def __init__(self, root_node):
self.funcdef_stack = []
self.visit(root_node)
def _current_funcdef_scope(self):
return self.funcdef_stack[-1].pd_scope
def visit_Name(self, node):
self.generic_visit(node)
write_context = (gast.Store, gast.AugStore, gast.Del)
if isinstance(node.ctx, write_context):
self._current_funcdef_scope().w_vars.add(node.id)
def visit_FunctionDef(self, node):
setattr(node, 'pd_scope', NameScope())
self.funcdef_stack.append(node)
self._current_funcdef_scope().args |= set(
self._get_argument_names(node))
self.generic_visit(node)
self.funcdef_stack.pop()
def visit_Global(self, node):
self._current_funcdef_scope().globals |= set(node.names)
def visit_Nonlocal(self, node):
self._current_funcdef_scope().nonlocals |= set(node.names)
def _get_argument_names(self, node):
""" get all arguments name in the functiondef node.
this node is local to the function and shouldn't
be created.
"""
assert isinstance(
node, gast.FunctionDef), "Input node is not function define node"
names = [a for a in node.args.args]
names.append(node.args.vararg)
names.append(node.args.kwarg)
names = [i.id for i in names if i is not None]
return names
class NameVisitor(gast.NodeVisitor): class NameVisitor(gast.NodeVisitor):
''' '''
Analysis name liveness for loop transformer Analysis name liveness for loop transformer
...@@ -122,7 +215,6 @@ class NameVisitor(gast.NodeVisitor): ...@@ -122,7 +215,6 @@ class NameVisitor(gast.NodeVisitor):
# List of nodes that have scope of variables. # List of nodes that have scope of variables.
self.nodes_with_scope = [] self.nodes_with_scope = []
self.blacklist_names = {"False", "True", "None"} self.blacklist_names = {"False", "True", "None"}
# Mapping from gast.While/gast.For to variable nodes # Mapping from gast.While/gast.For to variable nodes
...@@ -244,6 +336,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -244,6 +336,7 @@ class NameVisitor(gast.NodeVisitor):
type(gast.AugStore()), type(gast.AugStore()),
type(gast.Del()) type(gast.Del())
} }
for loop_node in self.current_loop: for loop_node in self.current_loop:
self.in_loop_vars[loop_node].append(node) self.in_loop_vars[loop_node].append(node)
if type(node.ctx) in write_context: if type(node.ctx) in write_context:
...@@ -255,6 +348,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -255,6 +348,7 @@ class NameVisitor(gast.NodeVisitor):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
self.nodes_with_scope.append(node) self.nodes_with_scope.append(node)
self.blacklist_names.add(node.name) self.blacklist_names.add(node.name)
# The variables in the function are not visible to the outside scope. # The variables in the function are not visible to the outside scope.
before_func_seen_vars = copy.copy(self.current_seen_vars) before_func_seen_vars = copy.copy(self.current_seen_vars)
...@@ -353,6 +447,9 @@ class NameVisitor(gast.NodeVisitor): ...@@ -353,6 +447,9 @@ class NameVisitor(gast.NodeVisitor):
return True return True
return False return False
def _is_global_or_nonlocal(self, node):
return False
def _is_ancestor_node(self, ancestor_node, node): def _is_ancestor_node(self, ancestor_node, node):
parent_node = self._get_parent_node(node) parent_node = self._get_parent_node(node)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import FunctionNameLivenessAnalysis
from paddle.utils import gast
import inspect
class JudgeVisitor(gast.NodeVisitor):
def __init__(self, ans):
self.ans = ans
def visit_FunctionDef(self, node):
scope = node.pd_scope
expected = self.ans.get(node.name, set())
assert scope.created_vars() == expected, "Not Equals."
self.generic_visit(node)
def test_normal_0(x):
def func():
if True:
i = 1
func()
return i
def test_normal_argument(x):
x = 1
def func():
if True:
print(x)
i = 1
func()
return x
def test_global(x):
global t
t = 10
def func():
if True:
print(x)
i = 1
func()
return x
def test_nonlocal(x, *args, **kargs):
i = 10
def func(*args, **kargs):
nonlocal i
k = 10
if True:
print(x)
i = 1
func(*args, **kargs)
return x
class TestClosureAnalysis(unittest.TestCase):
def setUp(self):
self.init_dygraph_func()
def init_dygraph_func(self):
self.all_dygraph_funcs = [
test_nonlocal, test_global, test_normal_0, test_normal_argument
]
self.answer = [
{
'func': set('k'),
'test_nonlocal': set('i')
},
{
'func': set({'i'}),
},
{
'func': set('i'),
},
{
'func': set('i'),
},
]
def test_main(self):
for ans, func in zip(self.answer, self.all_dygraph_funcs):
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = FunctionNameLivenessAnalysis(gast_root)
JudgeVisitor(ans).visit(gast_root)
def TestClosureAnalysis_Attribute_func():
# in this function, only self is a Name, self.current is a Attribute. self is read and self.current.function is store()
i = 0
self.current.function = 12
class TestClosureAnalysis_Attribute(TestClosureAnalysis):
def init_dygraph_func(self):
self.all_dygraph_funcs = [TestClosureAnalysis_Attribute_func]
self.answer = [{"TestClosureAnalysis_Attribute_func": set({'i'})}]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册