# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

from numpy import append

import paddle
from paddle.jit.dy2static.utils import FunctionNameLivenessAnalysis
from paddle.utils import gast

# Module-level list that TestPushPopTrans.test5 appends to from inside a
# function converted by paddle.jit.to_static.
global_a = list()


class JudgeVisitor(gast.NodeVisitor):
    """Assert the variable sets recorded on every visited ``FunctionDef``.

    For each function node, ``node.pd_scope.existed_vars()`` is compared
    with ``ans[node.name]`` and ``node.pd_scope.modified_vars()`` with
    ``mod[node.name]``. A function name missing from either dict is
    expected to have an empty set. Each node must already carry a
    ``pd_scope`` attribute; an ``AssertionError`` naming the function is
    raised on the first mismatch.
    """

    def __init__(self, ans, mod):
        # ans: function name -> expected existed_vars()
        # mod: function name -> expected modified_vars()
        self.ans = ans
        self.mod = mod

    def visit_FunctionDef(self, node):
        scope = node.pd_scope
        expected = self.ans.get(node.name, set())
        exp_mod = self.mod.get(node.name, set())
        existed = scope.existed_vars()
        assert (
            existed == expected
        ), "Not Equals in function:{} . expect {} , but get {}".format(
            node.name, expected, existed
        )
        modified = scope.modified_vars()
        assert (
            modified == exp_mod
        ), "Not Equals in function:{} . expect {} , but get {}".format(
            node.name, exp_mod, modified
        )
        # Recurse so nested function definitions are checked as well.
        self.generic_visit(node)


class JudgePushPopVisitor(gast.NodeVisitor):
    """Assert ``node.pd_scope.push_pop_vars`` for every visited function.

    ``push_pop_vars`` maps a function name to its expected set; a name that
    is missing from the mapping is expected to have an empty set. Each node
    must already carry a ``pd_scope`` attribute.
    """

    def __init__(self, push_pop_vars):
        self.pp_var = push_pop_vars

    def visit_FunctionDef(self, node):
        scope = node.pd_scope
        expected = self.pp_var.get(node.name, set())
        assert (
            scope.push_pop_vars == expected
        ), "Not Equals in function:{} . expect {} , but get {}".format(
            node.name, expected, scope.push_pop_vars
        )
        # Recurse so nested function definitions are checked as well.
        self.generic_visit(node)


def test_normal_0(x):
    # Analysis fixture: TestClosureAnalysis parses this function's source
    # text; it is not called anywhere in this file. `i` is only bound inside
    # the nested `func`, so the outer `return i` would not resolve if this
    # ran. Do not edit statements here without updating the expected sets
    # in TestClosureAnalysis.init_dygraph_func.
    def func():
        if True:
            i = 1

    func()
    return i


def test_normal_argument(x):
    # Analysis fixture (source is parsed, not called): reassigning the
    # parameter `x` makes it a modified var of this outer function, while
    # the nested `func` only reads `x`. The expected sets for `func` in
    # TestClosureAnalysis contain just `i`.
    x = 1

    def func():
        if True:
            print(x)
            i = 1

    func()
    return x


def test_global(x):
    # Analysis fixture (source is parsed, not called): the assignment to the
    # `global t` name is expected among this function's modified vars, but
    # no existed vars are expected for `test_global` itself.
    global t
    t = 10

    def func():
        if True:
            print(x)
            i = 1

    func()
    return x


def test_nonlocal(x, *args, **kargs):
    # Analysis fixture (source is parsed, not called): because of
    # `nonlocal i`, the nested `func` is expected to report `i` as modified
    # ({'k', 'i'}) but only `k` as one of its own existed vars.
    i = 10

    def func(*args, **kargs):
        nonlocal i
        k = 10
        if True:
            print(x)
            i = 1

    func(*args, **kargs)
    return x


def test_push_pop_1(x, *args, **kargs):
    """push_pop_vars in main_function is : `l`, `k`"""
    # Analysis fixture: TestClosureAnalysis_PushPop parses this source; the
    # function is not called in this file. Both `l.append` and `k.pop`
    # occur directly in this function's loop.
    l = []
    k = []
    for i in range(10):
        l.append(i)
        k.pop(i)
    return l


def test_push_pop_2(x, *args, **kargs):
    """push_pop_vars in main_function is : `k`"""
    # Analysis fixture (source is parsed, not called): `l.append` lives in
    # the nested `func`, so it is expected under `func`, not this function.
    l = []
    k = []

    def func():
        l.append(0)

    for i in range(10):
        k.append(i)
    return l, k


def test_push_pop_3(x, *args, **kargs):
    """push_pop_vars in main_function is : `k`
    NOTE: One may expect `k` and `l` because l
          is nonlocal. Name bind analysis is
          not implemented yet.
    """
    # Analysis fixture (source is parsed, not called). Same as
    # test_push_pop_2, but `func` declares `nonlocal l`.
    l = []
    k = []

    def func():
        nonlocal l
        l.append(0)

    for i in range(10):
        k.append(i)
    return l, k


def test_push_pop_4(x, *args, **kargs):
    """push_pop_vars in main_function is : `l`, `k`"""
    # Analysis fixture (source is parsed, not called): push/pop calls in
    # both branches of an `if` nested in two loops are all attributed to
    # this function, matching the expected {'k', 'l'} in
    # TestClosureAnalysis_PushPop.
    l = []
    k = []
    for i in range(10):
        for j in range(10):
            if True:
                l.append(j)
            else:
                k.pop()
    return l, k


class TestClosureAnalysis(unittest.TestCase):
    """Check existed/modified variable sets computed for nested functions.

    Subclasses override ``init_dygraph_func`` to supply other fixtures and
    may switch ``judge_type`` to ``"push_pop_vars"`` to compare push/pop
    variables instead.
    """

    def setUp(self):
        # Assign the default mode first: init_dygraph_func may override it
        # (TestClosureAnalysis_PushPop does).
        self.judge_type = "var and w_vars"
        self.init_dygraph_func()

    def init_dygraph_func(self):
        # The i-th entry of ``answer`` / ``modified_var`` maps function names
        # found in the i-th fixture to their expected existed_vars() /
        # modified_vars(). Names that are absent are expected to be empty.
        self.all_dygraph_funcs = [
            test_nonlocal,
            test_global,
            test_normal_0,
            test_normal_argument,
        ]
        self.answer = [
            {'func': {'k'}, 'test_nonlocal': {'i'}},
            {'func': {'i'}},
            {'func': {'i'}},
            {'func': {'i'}},
        ]
        self.modified_var = [
            {'func': {'k', 'i'}, 'test_nonlocal': {'i'}},
            {'func': {'i'}, 'test_global': {'t'}},
            {'func': {'i'}},
            {'func': {'i'}, 'test_normal_argument': {'x'}},
        ]

    @staticmethod
    def _analyze(func):
        """Parse ``func``'s source and run the liveness analysis over it."""
        gast_root = gast.parse(inspect.getsource(func))
        # Only the side effect on the tree is needed: the judge visitors
        # read the ``pd_scope`` attribute of each FunctionDef afterwards.
        FunctionNameLivenessAnalysis(gast_root)
        return gast_root

    def test_main(self):
        n_funcs = len(self.all_dygraph_funcs)
        if self.judge_type == 'push_pop_vars':
            # zip() would silently skip fixtures without an expected answer.
            self.assertEqual(len(self.push_pop_vars), n_funcs)
            for push_pop_vars, func in zip(
                self.push_pop_vars, self.all_dygraph_funcs
            ):
                with self.subTest(func=func.__name__):
                    gast_root = self._analyze(func)
                    JudgePushPopVisitor(push_pop_vars).visit(gast_root)
        else:
            self.assertEqual(len(self.answer), n_funcs)
            self.assertEqual(len(self.modified_var), n_funcs)
            for mod, ans, func in zip(
                self.modified_var, self.answer, self.all_dygraph_funcs
            ):
                with self.subTest(func=func.__name__):
                    gast_root = self._analyze(func)
                    JudgeVisitor(ans, mod).visit(gast_root)


def TestClosureAnalysis_Attribute_func():
    # Analysis fixture: its source is parsed, and `self` is undefined here,
    # so it is never executed. Only `self` is a Name node; `self.current` and
    # `self.current.function` are Attribute nodes. `self` is only read, while
    # `self.current.function` is stored to, so the expected modified vars
    # are `i` and `self.current.function`.
    i = 0
    self.current.function = 12


class TestClosureAnalysis_Attribute(TestClosureAnalysis):
    """Attribute store targets appear in a function's modified vars."""

    def init_dygraph_func(self):
        self.all_dygraph_funcs = [TestClosureAnalysis_Attribute_func]
        self.answer = [{"TestClosureAnalysis_Attribute_func": {'i'}}]
        # The dotted store target is expected as one full-path entry.
        self.modified_var = [
            {
                "TestClosureAnalysis_Attribute_func": {
                    'i',
                    'self.current.function',
                }
            }
        ]


class TestClosureAnalysis_PushPop(TestClosureAnalysis):
    """Check ``push_pop_vars`` recorded for each fixture function.

    The fixtures exercise ``list.append`` / ``list.pop`` at different
    nesting levels.
    """

    def init_dygraph_func(self):
        # Route test_main to the push/pop comparison branch.
        self.judge_type = "push_pop_vars"
        self.all_dygraph_funcs = [
            test_push_pop_1,
            test_push_pop_2,
            test_push_pop_3,
            test_push_pop_4,
        ]
        self.push_pop_vars = [
            {"test_push_pop_1": {'l', 'k'}},
            {"test_push_pop_2": {'k'}, "func": {'l'}},
            {"test_push_pop_3": {'k'}, "func": {'l'}},
            {"test_push_pop_4": {'k', 'l'}},
        ]


class TestPushPopTrans(unittest.TestCase):
    """Smoke tests for ``paddle.jit.to_static`` on small append snippets.

    Each test prints the converted source and the result of calling the
    converted function on ``paddle.to_tensor([3])``. Nothing is asserted,
    so a test fails only if conversion or execution raises.
    """

    @staticmethod
    def _convert_and_run(func):
        """Print ``to_static(func).code`` and the result of calling it."""
        x = paddle.to_tensor([3])
        # Convert once and reuse the object for both the code dump and the
        # call, instead of converting the same function twice.
        static_func = paddle.jit.to_static(func)
        print(static_func.code)
        print(static_func(x))

    def test(self):
        # list.append on a list stored inside a dict, within a for loop.
        def vlist_of_dict(x):
            ma = {'a': []}
            for i in range(3):
                ma['a'].append(1)
            return ma

        self._convert_and_run(vlist_of_dict)

    def test2(self):
        import numpy as np

        # Attribute call `np.append` inside a loop; its result is discarded.
        def vlist_of_dict(x):
            a = np.array([1, 2, 3])
            for i in range(3):
                np.append(a, 4)
            return a

        self._convert_and_run(vlist_of_dict)

    def test3(self):
        import numpy as np

        # A trivial `if` block with a numpy array defined before it.
        def vlist_of_dict(x):
            a = np.array([1, 2, 3])
            if True:
                pass
            return a

        self._convert_and_run(vlist_of_dict)

    def test4(self):
        import numpy as np

        # Bare-name call to `append` (the module-level
        # `from numpy import append`) inside a loop.
        def vlist_of_dict(x):
            a = np.array([1, 2, 3])
            for i in range(3):
                append(a, 4)
            return a

        self._convert_and_run(vlist_of_dict)

    def test5(self):
        import numpy as np

        # `global_a.append` inside a loop. NOTE: this mutates the
        # module-level `global_a`, which is never reset between runs.
        def vlist_of_dict(x):
            a = np.array([1, 2, 3])
            for i in range(3):
                global_a.append(4)
            return a

        self._convert_and_run(vlist_of_dict)


if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main()