loop_transformer.py 28.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import copy
18
from paddle.utils import gast
19 20 21

from collections import defaultdict
from paddle.fluid import unique_name
22
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
23
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
24
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
25
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
26
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
27
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
28
from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer
29
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
30
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
31 32 33 34 35 36 37
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node

# Names exported by this module.
__all__ = ['LoopTransformer', 'NameVisitor']

# Prefixes passed to unique_name.generate() to name the condition/body
# functions that LoopTransformer.get_while_stmt_nodes creates for a `while`.
WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body'

38 39
# Prefixes passed to unique_name.generate() to name the condition/body
# functions that LoopTransformer.get_for_stmt_nodes creates for a `for`.
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
40
# Prefix of the generated names that replace dotted loop variables (such as
# `foo.x`) inside the generated condition/body functions.
GENERATE_VARIABLE_PREFIX = 'generate_variable'
41

42
# Prefix of the temporary names that create_while_nodes uses to receive the
# loop outputs which correspond to dotted loop variables (such as `foo.x`).
ATTRIBUTE_VARIABLE_PREFIX = '__attribute_variable'
43

44 45 46 47 48 49 50 51 52 53 54 55 56

def create_while_nodes(condition_name, body_name, loop_var_names):
    """
    Builds the gast statements that call
    paddle.jit.dy2static.convert_while_loop for a converted loop.

    When no loop variable is an attribute, a single statement is produced:

    [a, b, c] = paddle.jit.dy2static.convert_while_loop(
            condition_name, body_name, [a, b, c])

    where a, b, c are in loop_var_names.

    A dotted name such as foo.x may refer to a Python property, which cannot
    be used as an assignment target here. Such outputs are therefore received
    in generated temporary variables, and each one is copied back to the
    attribute only if the attribute is not a property at runtime:

    [a, b, __attribute_variable_1] = paddle.jit.dy2static.convert_while_loop(
            condition_name, body_name, [a, b, foo.x])
    if not isinstance(getattr(type(foo), 'x', None), property): foo.x = __attribute_variable_1

    Args:
        condition_name: name of the generated loop condition function.
        body_name: name of the generated loop body function.
        loop_var_names: iterable of loop variable names (str); a name
            containing "." is treated as an attribute.

    Returns:
        A list of gast nodes: the call statement followed by one guarded
        assignment per attribute loop variable.
    """
    # NOTE(liym27):
    # The statements are written as source code and parsed, instead of being
    # assembled node by node, because hand-built nodes easily get the node
    # type wrong: `a` must be a gast.Name while `foo.x` must be a
    # gast.Attribute.

    # A set has no order, so freeze the names into a list first: inputs and
    # outputs of the call must be listed in the same order.
    ordered_names = list(loop_var_names)

    output_names = []
    # (generated temporary name, original dotted name), in creation order
    attr_pairs = []
    for var_name in ordered_names:
        if "." not in var_name:
            output_names.append(var_name)
            continue
        # var_name is an attribute variable such as foo.x
        generated = unique_name.generate(ATTRIBUTE_VARIABLE_PREFIX)
        attr_pairs.append((generated, var_name))
        output_names.append(generated)

    call_src = "[{}] = {}({}, {}, [{}])".format(
        ",".join(output_names), "paddle.jit.dy2static.convert_while_loop",
        condition_name, body_name, ",".join(ordered_names))
    stmts = [gast.parse(call_src).body[0]]

    for generated, dotted in attr_pairs:
        # Split `a.b.c` into owner `a.b` and attribute `c`.
        owner, _, attr = dotted.rpartition(".")
        guard_src = "if not isinstance(getattr(type({}), '{}', None), property): {} = {}".format(
            owner, attr, dotted, generated)
        stmts.append(gast.parse(guard_src).body[0])
    return stmts
109 110 111 112 113 114 115 116


class NameVisitor(gast.NodeVisitor):
    """
    Analysis name liveness for loop transformer.

    The tree under `root_node` is visited once at construction time. While
    visiting, the gast.Name / gast.Attribute nodes seen before each loop body,
    inside each loop, inside each loop condition and written inside each loop
    are recorded. `get_loop_var_names` combines these records into the names
    that must be passed through a converted loop.
    """

    def __init__(self, root_node):
        # Set of gast.Name or gast.Attribute for variables
        self.current_seen_vars = set()

        # List of gast.While/gast.For nodes, outermost first
        self.current_loop = []

        # List of nodes that have scope of variables.
        self.nodes_with_scope = []

        # Names that are never recorded as variables; function names are
        # added by visit_FunctionDef.
        self.blacklist_names = {"False", "True", "None"}

        # Mapping from gast.While/gast.For to variable nodes
        self.before_loop_body_vars = defaultdict(set)
        # NOTE: Use ordered list as dict value
        self.in_loop_vars = defaultdict(list)

        # Mapping from gast.While/gast.For to variable nodes which is condition
        # of loop or being modified during the loop
        self.write_in_loop = defaultdict(set)
        self.condition_vars = defaultdict(set)
        # True only while the header (target/iter or test) of a loop is visited
        self.in_condition = False

        # Some names are types, we shouldn't record them as loop var names.
        self.type_vars = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

        self.visit(root_node)

    def get_loop_var_names(self, node):
        """
        Computes the loop variables of a loop.

        :param node: the gast.While or gast.For node to analyse.
        :return: tuple (loop_var_names, create_var_names) of two sets of str.
            loop_var_names are the names that have to be passed through the
            loop; create_var_names is the subset that is first assigned inside
            the loop and so has to be created before the loop.
        """
        assert isinstance(
            node, (gast.While, gast.For)), "Input node is not gast loop node"
        loop_var_names = set()
        create_var_names = set()
        read_context = {type(gast.Load()), type(gast.AugLoad())}

        in_loop_vars_list = self.in_loop_vars[node]

        # Contexts of every occurrence of a name in the loop, in the order the
        # occurrences were recorded.
        var_name_to_ctxs = defaultdict(list)
        for var_node in in_loop_vars_list:
            var_name_to_ctxs[self._var_node_to_name(var_node)].append(
                var_node.ctx)

        in_loop_vars = set(in_loop_vars_list)
        in_loop_vars = self._remove_unnecessary_vars(in_loop_vars, node)
        in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)

        before_loop_body_vars = self.before_loop_body_vars[node]
        before_loop_body_vars = self._remove_unnecessary_vars(
            before_loop_body_vars, node)
        before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)

        after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
        after_loop_vars = self._remove_unnecessary_vars(after_loop_vars, node)
        # Only reads after the loop matter.
        after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
                                                        read_context)
        condition_vars = self.condition_vars[node]
        condition_names = self._var_nodes_to_names(condition_vars)

        write_vars = self.write_in_loop[node]
        write_names = self._var_nodes_to_names(write_vars)

        name_to_type = {}
        for var in in_loop_vars:
            wrapper = self.node_to_wrapper_map[var]
            name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type

        for name in in_loop_name_strs:
            if name in before_loop_name_strs:
                # If a variable is used in loop and created before loop

                # If this var is a basic variable and read-only and not
                # condition var, it may not be loop_var else it should
                # be in loop_var as input
                if (name not in condition_names
                        and name not in write_names
                        and self._node_var_type_is_basic(name_to_type[name])):
                    continue
                loop_var_names.add(name)

            elif name in after_loop_name_strs:
                # If a variable is created in the while loop and read after
                # loop, it should be in loop_var and we should create it

                # because name in after_loop_name must be initialized in loop
                # So it is write-only, we don't have to filter read-only basic
                # vars out
                loop_var_names.add(name)
                create_var_names.add(name)
            else:
                # If a variable is used and created in loop, but used before created,
                # it should be in loop_var and we should create it.

                # For example, `var_a` should be in loop_var and we should create it.
                #
                #   res = 0
                #   for i, x in enumerate(x_array):
                #       if i > 2:
                #           x = func1(var_a)
                #       var_a = func2(x)
                #
                ctxs = var_name_to_ctxs[name]
                is_created = any(isinstance(ctx, gast.Store) for ctx in ctxs)

                # First recorded occurrence is a read, and it is stored
                # somewhere in the loop.
                if isinstance(ctxs[0], gast.Load) and is_created:
                    loop_var_names.add(name)
                    create_var_names.add(name)

        return loop_var_names, create_var_names

    def visit_Name(self, node):
        """
        Records a gast.Name as a seen variable, as a variable of every loop
        currently being visited, and (when applicable) as a written variable
        or a condition variable.
        """
        # Names of called functions and blacklisted names are not variables.
        if self._is_call_func_name_node(node):
            self.generic_visit(node)
            return
        if node.id in self.blacklist_names:
            self.generic_visit(node)
            return

        self.current_seen_vars.add(node)
        write_context = {
            type(gast.Store()), type(gast.AugStore()), type(gast.Del())
        }
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)
            if type(node.ctx) in write_context:
                self.write_in_loop[loop_node].add(node)
        # `in_condition` is only set by visit_For/visit_While after the loop
        # was pushed, so the innermost loop owns the condition. Look it up
        # explicitly instead of relying on the variable leaked by the `for`
        # statement above.
        if self.in_condition and self.current_loop:
            self.condition_vars[self.current_loop[-1]].add(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """
        Visits a function definition; its name is blacklisted, and for a
        nested function the variables seen inside are dropped afterwards.
        """
        self.nodes_with_scope.append(node)
        self.blacklist_names.add(node.name)
        # The variables in the function are not visible to the outside scope.
        before_func_seen_vars = copy.copy(self.current_seen_vars)

        self.generic_visit(node)
        self.nodes_with_scope.pop()
        # After exiting the scope of the node, variables in this scope
        # should be removed from self.current_seen_vars.
        # This only happens when an enclosing scope node is still open,
        # i.e. for nested function definitions.
        if self.nodes_with_scope:
            self.current_seen_vars = before_func_seen_vars

    def visit(self, node):
        """
        Dispatches to `visit_<ClassName>` if defined, else to generic_visit.
        """
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        ret = visitor(node)
        return ret

    def visit_Attribute(self, node):
        """
        Records a gast.Attribute as a seen variable and as a variable of
        every loop currently being visited.
        """
        if self._is_call_func_name_node(node):
            return
        attr_full_name = get_attribute_full_name(node)
        # Class variables are not allowed to appear in the arguments list
        # of defined function under class methods in Python.
        #
        #   def class_func(self):
        #       def while_loop_body(self.x, y) # `self.x` is illegal.
        #
        # TODO: If do change the variable with `self.var`, need a better
        # way to deal with this case.
        if attr_full_name.startswith("self."):
            return
        self.current_seen_vars.add(node)

        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)

        # sub-nodes are visited during get_attribute_full_name and we shouldn't
        # visit again

    def visit_For(self, node):
        """
        Visits a `for` loop: target and iter are visited as condition, then
        the names seen so far are snapshotted before the whole node is
        visited.
        """
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.target)
        self.visit(node.iter)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_While(self, node):
        """
        Visits a `while` loop: the test is visited as condition, then the
        names seen so far are snapshotted before the whole node is visited.
        """
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.test)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_Call(self, node):
        """
        Collects the type names used as second argument of `isinstance`.
        """
        # Store type var names such as "isinstance(x, some_type_names)" and
        # Remove them later
        # The length check skips calls without a second positional argument,
        # e.g. `isinstance(*args)`, which used to raise IndexError here.
        if (isinstance(node.func, gast.Name) and node.func.id == 'isinstance'
                and len(node.args) >= 2):
            type_node = node.args[1]
            if isinstance(type_node, gast.Tuple):
                for element in type_node.elts:
                    self.type_vars.add(ast_to_source_code(element).strip())
            else:
                self.type_vars.add(ast_to_source_code(type_node).strip())
        self.generic_visit(node)

    def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
        """
        Returns the set of names of the nodes in `node_set`; when
        `ctx_filter_set` is given, only nodes whose ctx type is in it count.
        """
        ret = set()
        for node in node_set:
            if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
                ret.add(self._var_node_to_name(node))
        return ret

    def _var_node_to_name(self, node):
        """
        Returns the id of a gast.Name or the full dotted name of a
        gast.Attribute; None for any other node.
        """
        if isinstance(node, gast.Name):
            return node.id
        elif isinstance(node, gast.Attribute):
            return get_attribute_full_name(node)

    def _node_var_type_is_basic(self, node_var_type):
        """
        Returns True if any type in `node_var_type` is BOOLEAN, INT, FLOAT
        or STRING.
        """
        basic_types = {
            NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
            NodeVarType.STRING
        }
        return any(t in basic_types for t in node_var_type)

    def _is_call_func_name_node(self, node):
        """
        Returns True if `node` is the `func` of its parent gast.Call.
        """
        parent_node = self._get_parent_node(node)
        if isinstance(parent_node, gast.Call) and parent_node.func == node:
            return True
        return False

    def _is_ancestor_node(self, ancestor_node, node):
        """
        Returns True if `ancestor_node` is a (strict) ancestor of `node`.
        """
        parent_node = self._get_parent_node(node)

        while parent_node is not None:
            if parent_node == ancestor_node:
                return True
            parent_node = self._get_parent_node(parent_node)
        return False

    def _get_parent_node(self, node):
        """
        Returns the parent gast node of `node` according to the wrapper map,
        or None if the node or its parent is unknown.
        """
        wrapper_node = self.node_to_wrapper_map.get(node)
        if wrapper_node:
            if wrapper_node.parent:
                parent_node = wrapper_node.parent.node
                return parent_node
        return None

    def _remove_unnecessary_vars(self, loop_vars, loop_node):
        """
        Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
            1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
            2. Remove vars only in gast.comprehension.
            3. Remove vars that are type names, for example: "isinstance(x, var_type_name)"
        :param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
        :param loop_node: Current loop node.
        :return: a new set; `loop_vars` itself is not modified.
        """

        def filter_name_nodes_from(root_node, target_var_names):
            """
            Filter children with gast.Name type from node.(inclusively)
            """
            # gast.walk also yields root_node itself, so a root that is a
            # gast.Name needs no separate check.
            return {
                child_node
                for child_node in gast.walk(root_node)
                if isinstance(child_node, gast.Name) and
                child_node.id in target_var_names
            }

        vars_of_list_generator = set()
        target_vars_of_for_node = set()

        for name_node in loop_vars:
            if not isinstance(name_node, gast.Name):
                continue

            parent_node = self._get_parent_node(name_node)

            # NOTE: gast.For.target or gast.comprehension.target can be gast.Tuple.
            #  For examples:
            #   1) `for i, j in enumerate(x)` has two target vars: i and j
            #   2) `[x for x,y in array]` has two target vars: x and y
            if isinstance(parent_node, gast.Tuple):
                parent_node = self._get_parent_node(parent_node)

            # 1. Get vars only in gast.comprehension.
            # For examples:
            #  1) [x for x,y in array] -> x, x, y
            #  2) [f(x) for x in array] -> x
            #  3) [func(x, y) for x in array] -> x, x
            if isinstance(parent_node, gast.comprehension):
                # 1.1 target vars in list/set comprehensions
                target_node = parent_node.target
                if isinstance(target_node, gast.Tuple):
                    target_vars = target_node.elts
                else:
                    target_vars = [target_node]

                vars_of_list_generator = vars_of_list_generator | set(
                    target_vars)

                # 1.2 vars from target vars used in elt_node
                # A target tuple may hold elements without an `id` (e.g. a
                # nested tuple); only plain names can be matched by name.
                target_var_names = {
                    var.id
                    for var in target_vars if isinstance(var, gast.Name)
                }
                comp_node = self._get_parent_node(parent_node)
                elt_nodes = []
                if isinstance(comp_node, gast.ListComp):
                    elt_nodes.append(comp_node.elt)
                elif isinstance(comp_node, gast.DictComp):
                    elt_nodes.extend([comp_node.key, comp_node.value])

                for elt_node in elt_nodes:
                    vars_of_list_generator |= filter_name_nodes_from(
                        elt_node, target_var_names)

            # 2. Get target vars or vars from target vars used in for-loop but the for-loop is
            #   1) not the "loop_node" itself
            #   2) not the ancestor of the "loop_node"
            #
            # For examples:
            #   for k in range(x):   # if it's this "loop_node", i or j both should be target vars.
            #      # do something
            #
            #   for i in range(a):   # if it's this "loop_node", k or j should be in target vars but i should not.
            #     for j in range(a): # if it's this "loop_node", k should be in target_vars but i or j should not.
            #       x = i+j
            elif isinstance(parent_node, gast.For):
                if parent_node is loop_node:
                    continue
                if self._is_ancestor_node(parent_node, loop_node):
                    continue
                # 2.1 target vars in gast.For node.
                target_node = parent_node.target
                if isinstance(target_node, gast.Tuple):
                    target_vars = target_node.elts
                else:
                    target_vars = [target_node]

                target_vars_of_for_node = target_vars_of_for_node | set(
                    target_vars)

        # 2.2 vars from target vars used in for-loop
        # As in 1.2, skip target elements that are not plain names.
        target_vars_name_strs = {
            var.id
            for var in target_vars_of_for_node if isinstance(var, gast.Name)
        }
        for var in loop_vars:
            if not isinstance(var, gast.Name):
                continue
            if var.id in target_vars_name_strs and var not in self.condition_vars[
                    loop_node]:
                target_vars_of_for_node.add(var)

        removed_vars = target_vars_of_for_node | vars_of_list_generator

        # 3. Remove var type names which are stored in self.type_vars
        for var in loop_vars:
            if ast_to_source_code(var).strip() in self.type_vars:
                removed_vars.add(var)

        return loop_vars - removed_vars
484

485 486 487 488 489 490 491 492 493

class LoopTransformer(gast.NodeTransformer):
    """
    This class transforms python while/for statement into Static Graph Ast.

    Every gast.While / gast.For found in a `body` or `orelse` statement list
    is replaced in place by: a generated condition function, a generated body
    function and the statements returned by `create_while_nodes`.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

    def transform(self):
        """
        Runs ForLoopTuplePreTransformer, analyses names with NameVisitor and
        then rewrites the loops of the tree in place.
        """
        ForLoopTuplePreTransformer(self.wrapper_root).transform()
        self.name_visitor = NameVisitor(self.root)
        self.visit(self.root)

    def visit(self, node):
        """
        Visits children first, then rewrites the loops that are direct
        statements of `node.body` / `node.orelse`.
        """
        self.generic_visit(node)
        # All parent nodes that may contain gast.While/gast.For
        if hasattr(node, 'body'):
            self.replace_stmt_list(node.body)
        if hasattr(node, 'orelse'):
            self.replace_stmt_list(node.orelse)
        return node

    def replace_stmt_list(self, body_list):
        """
        Replaces, in place, every gast.While / gast.For in `body_list` by its
        transformed statements. Non-list arguments are ignored.
        """
        if not isinstance(body_list, list):
            return

        i = 0
        while i < len(body_list):
            stmt = body_list[i]
            if isinstance(stmt, gast.While):
                new_stmts = self.get_while_stmt_nodes(stmt)
            elif isinstance(stmt, gast.For):
                new_stmts = self.get_for_stmt_nodes(stmt)
            else:
                i += 1
                continue
            body_list[i:i + 1] = new_stmts
            # Skip over the inserted statements; they are already final.
            i += len(new_stmts)

    def _create_loop_func_node(self, name_prefix, loop_var_names, body):
        """
        Creates the gast.FunctionDef of a generated condition/body function.

        The function gets a unique name starting with `name_prefix`, takes
        `loop_var_names` as positional arguments and has `body` as body.
        Dotted names (such as `foo.x`) are not legal argument names, so each
        of them is renamed to a generated variable inside the function.
        """
        func_node = gast.FunctionDef(
            name=unique_name.generate(name_prefix),
            args=gast.arguments(
                args=[
                    gast.Name(
                        id=name,
                        ctx=gast.Param(),
                        annotation=None,
                        type_comment=None) for name in loop_var_names
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=None,
                kwarg=None,
                defaults=[]),
            body=body,
            decorator_list=[],
            returns=None,
            type_comment=None)
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func_node

    def get_for_stmt_nodes(self, node):
        """
        Returns the list of statements that replaces the gast.For `node`, or
        `[node]` unchanged when ForNodeVisitor cannot parse the loop.
        """
        # TODO: consider for - else in python

        # 1. get key statements for different cases
        # NOTE 1: three key statements:
        #   1). init_stmts: list[node], prepare nodes of for loop, may not only one
        #   2). cond_stmt: node, condition node to judge whether continue loop
        #   3). body_stmts: list[node], updated loop body, sometimes we should change
        #       the original statement in body, not just append new statement
        #
        # NOTE 2: The following `for` statements will be transformed to `while` statements:
        #   1). for x in range(*)
        #   2). for x in iter_var
        #   3). for i, x in enumerate(*)

        current_for_node_parser = ForNodeVisitor(node)
        stmts_tuple = current_for_node_parser.parse()
        if stmts_tuple is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = stmts_tuple

        # 2. get original loop vars
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
        # we need append new loop var & remove useless loop var
        #   1. for x in var -> x is no need
        #   2. for i, x in enumerate(var) -> x is no need
        if current_for_node_parser.is_for_iter(
        ) or current_for_node_parser.is_for_enumerate_iter():
            iter_var_name = current_for_node_parser.iter_var_name
            iter_idx_name = current_for_node_parser.iter_idx_name
            loop_var_names.add(iter_idx_name)
            if iter_var_name not in create_var_names:
                # discard, not remove: the name may already have been
                # filtered out of the loop vars, which must not be an error.
                loop_var_names.discard(iter_var_name)

        # 3. prepare result statement list
        new_stmts = []
        # Python can create variable in loop and use it out of loop, E.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        # 4. append init statements
        new_stmts.extend(init_stmts)

        # 5. create & append condition function node
        condition_func_node = self._create_loop_func_node(
            FOR_CONDITION_PREFIX, loop_var_names,
            [gast.Return(value=cond_stmt)])
        new_stmts.append(condition_func_node)

        # 6. create & append loop body function node
        # append return values for loop body
        body_stmts.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True)))
        body_func_node = self._create_loop_func_node(
            FOR_BODY_PREFIX, loop_var_names, body_stmts)
        new_stmts.append(body_func_node)

        # 7. create & append while loop node
        while_loop_nodes = create_while_nodes(
            condition_func_node.name, body_func_node.name, loop_var_names)
        new_stmts.extend(while_loop_nodes)

        return new_stmts

    def get_while_stmt_nodes(self, node):
        """
        Returns the list of statements that replaces the gast.While `node`.

        NOTE: a Return statement is appended to `node.body` itself, which
        then becomes the body of the generated body function.
        """
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []

        # Python can create variable in loop and use it out of loop, E.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create static variable for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        condition_func_node = self._create_loop_func_node(
            WHILE_CONDITION_PREFIX, loop_var_names,
            [gast.Return(value=node.test)])
        new_stmts.append(condition_func_node)

        new_body = node.body
        # gen_tuple_if_single=True keeps this consistent with
        # get_for_stmt_nodes: the loop vars are passed and unpacked as a
        # list by the statement from create_while_nodes, so the body must
        # return a tuple even when there is a single loop var.
        new_body.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True)))
        body_func_node = self._create_loop_func_node(
            WHILE_BODY_PREFIX, loop_var_names, new_body)
        new_stmts.append(body_func_node)

        while_loop_nodes = create_while_nodes(
            condition_func_node.name, body_func_node.name, loop_var_names)
        new_stmts.extend(while_loop_nodes)
        return new_stmts