From 697177177ce5c39b53be41788428e3a220782a51 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 24 Jun 2022 17:06:17 +0800 Subject: [PATCH] [ Dy2Static ] Add closure analysis for control flow and add some unittest (#43713) * add closure analysis for control flow and add some unittest * finetune the design of FunctionScopeVisitor * fix * fix python check * fix code by code review --- .../dygraph_to_static/loop_transformer.py | 99 ++++++++++++- .../test_closure_analysis.py | 134 ++++++++++++++++++ 2 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_closure_analysis.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 045878ed54..fa401fa3e4 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -108,6 +108,99 @@ def create_while_nodes(condition_name, body_name, loop_var_names): return ret +class NameScope: + + def __init__(self): + """ we don't analyze the read only variable + because they keep the same in control flow. + """ + self.globals = set() + self.nonlocals = set() + self.args = set() + self.w_vars = set() # all vars been stored, + # may be globals or non-locals + def created_vars(self): + return self.w_vars - self.globals - self.nonlocals - self.args + + def write_vars(self): + return self.w_vars + + def global_vars(self): + return self.globals + + +class FunctionNameLivenessAnalysis(gast.NodeVisitor): + """ analyze the liveness of a function. + + every variables stored in this scope will be collected, + in addition with global/nonlocal information. + + 1. global variable is stored in node.var_globals. + 2. nonlocal variable is stored in node.var_nonlocals. + 3. arguments is stored in node.var_args. + + For example: + + def func(*args, **kargs): + a = 12 + global i,j + nonlocal x,y + print(a) + i = k + for m in range(10): + q = 12 + + After this visitor we have: + # node is the FunctionDef node with name: "func" + node.pd_scope = NameScope( + globals = ['i', 'j'], + nonlocals = ['x', 'y'], + args = ['args', 'kargs'], + wr_vars = ['a', 'i', 'q', 'm'] + ) + """ + + def __init__(self, root_node): + self.funcdef_stack = [] + self.visit(root_node) + + def _current_funcdef_scope(self): + return self.funcdef_stack[-1].pd_scope + + def visit_Name(self, node): + self.generic_visit(node) + write_context = (gast.Store, gast.AugStore, gast.Del) + if isinstance(node.ctx, write_context): + self._current_funcdef_scope().w_vars.add(node.id) + + def visit_FunctionDef(self, node): + setattr(node, 'pd_scope', NameScope()) + self.funcdef_stack.append(node) + self._current_funcdef_scope().args |= set( + self._get_argument_names(node)) + self.generic_visit(node) + self.funcdef_stack.pop() + + def visit_Global(self, node): + self._current_funcdef_scope().globals |= set(node.names) + + def visit_Nonlocal(self, node): + self._current_funcdef_scope().nonlocals |= set(node.names) + + def _get_argument_names(self, node): + """ get all arguments name in the functiondef node. + this node is local to the function and shouldn't + be created. + """ + assert isinstance( + node, gast.FunctionDef), "Input node is not function define node" + names = [a for a in node.args.args] + names.append(node.args.vararg) + names.append(node.args.kwarg) + names = [i.id for i in names if i is not None] + return names + + class NameVisitor(gast.NodeVisitor): ''' Analysis name liveness for loop transformer @@ -122,7 +215,6 @@ class NameVisitor(gast.NodeVisitor): # List of nodes that have scope of variables. self.nodes_with_scope = [] - self.blacklist_names = {"False", "True", "None"} # Mapping from gast.While/gast.For to variable nodes @@ -244,6 +336,7 @@ class NameVisitor(gast.NodeVisitor): type(gast.AugStore()), type(gast.Del()) } + for loop_node in self.current_loop: self.in_loop_vars[loop_node].append(node) if type(node.ctx) in write_context: @@ -255,6 +348,7 @@ class NameVisitor(gast.NodeVisitor): def visit_FunctionDef(self, node): self.nodes_with_scope.append(node) self.blacklist_names.add(node.name) + # The variables in the function are not visible to the outside scope. before_func_seen_vars = copy.copy(self.current_seen_vars) @@ -353,6 +447,9 @@ class NameVisitor(gast.NodeVisitor): return True return False + def _is_global_or_nonlocal(self, node): + return False + def _is_ancestor_node(self, ancestor_node, node): parent_node = self._get_parent_node(node) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_closure_analysis.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_closure_analysis.py new file mode 100644 index 0000000000..7986fb1cba --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_closure_analysis.py @@ -0,0 +1,134 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle +from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import FunctionNameLivenessAnalysis +from paddle.utils import gast +import inspect + + +class JudgeVisitor(gast.NodeVisitor): + + def __init__(self, ans): + self.ans = ans + + def visit_FunctionDef(self, node): + scope = node.pd_scope + expected = self.ans.get(node.name, set()) + assert scope.created_vars() == expected, "Not Equals." + self.generic_visit(node) + + +def test_normal_0(x): + + def func(): + if True: + i = 1 + + func() + return i + + +def test_normal_argument(x): + x = 1 + + def func(): + if True: + print(x) + i = 1 + + func() + return x + + +def test_global(x): + global t + t = 10 + + def func(): + if True: + print(x) + i = 1 + + func() + return x + + +def test_nonlocal(x, *args, **kargs): + i = 10 + + def func(*args, **kargs): + nonlocal i + k = 10 + if True: + print(x) + i = 1 + + func(*args, **kargs) + return x + + +class TestClosureAnalysis(unittest.TestCase): + + def setUp(self): + self.init_dygraph_func() + + def init_dygraph_func(self): + self.all_dygraph_funcs = [ + test_nonlocal, test_global, test_normal_0, test_normal_argument + ] + self.answer = [ + { + 'func': set('k'), + 'test_nonlocal': set('i') + }, + { + 'func': set({'i'}), + }, + { + 'func': set('i'), + }, + { + 'func': set('i'), + }, + ] + + def test_main(self): + for ans, func in zip(self.answer, self.all_dygraph_funcs): + test_func = inspect.getsource(func) + gast_root = gast.parse(test_func) + name_visitor = FunctionNameLivenessAnalysis(gast_root) + JudgeVisitor(ans).visit(gast_root) + + +def TestClosureAnalysis_Attribute_func(): + # in this function, only self is a Name, self.current is a Attribute. self is read and self.current.function is store() + i = 0 + self.current.function = 12 + + +class TestClosureAnalysis_Attribute(TestClosureAnalysis): + + def init_dygraph_func(self): + + self.all_dygraph_funcs = [TestClosureAnalysis_Attribute_func] + self.answer = [{"TestClosureAnalysis_Attribute_func": set({'i'})}] + + +if __name__ == '__main__': + unittest.main() -- GitLab