#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import copy
import gast

from collections import defaultdict
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node

__all__ = ['LoopTransformer', 'NameVisitor']

# Prefixes handed to ``unique_name.generate`` when naming the condition and
# body functions that replace a ``while`` loop ...
WHILE_CONDITION_PREFIX, WHILE_BODY_PREFIX = 'while_condition', 'while_body'

# ... and the ones that replace a ``for`` loop.
FOR_CONDITION_PREFIX, FOR_BODY_PREFIX = 'for_loop_condition', 'for_loop_body'

# Prefix for the fresh identifiers that replace dotted loop variables
# (e.g. ``a.b``) inside the generated functions.
GENERATE_VARIABLE_PREFIX = 'generate_variable'


def create_while_node(condition_name, body_name, loop_var_names):
    """
    Build the statement that runs a transformed loop.

    The returned node corresponds to the source::

        [v1, v2, ...] = fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop(
            condition_name, body_name, [v1, v2, ...])

    :param condition_name: name of the generated loop-condition function.
    :param body_name: name of the generated loop-body function.
    :param loop_var_names: iterable of loop variable names. Dotted names such
        as ``a.b`` are allowed.
    :return: a gast.Assign node.
    """
    # Parse the statement from source instead of assembling gast nodes by
    # hand. The parser:
    #   * gives every node its proper context (Store for the assignment
    #     targets, Load for the call arguments),
    #   * creates distinct node objects for the two occurrences of each
    #     variable instead of sharing one node in two places of the tree,
    #   * turns dotted names such as `a.b` into gast.Attribute rather than a
    #     gast.Name with an invalid id.
    # A list target (`[...] = ...`) stays valid even when there are no loop
    # vars, unlike an empty tuple target.
    # Join once so both occurrences use the same order, even for a set input.
    names_str = ", ".join(loop_var_names)
    while_func_name = (
        'fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop')
    while_node_str = "[{}] = {}({}, {}, [{}])".format(
        names_str, while_func_name, condition_name, body_name, names_str)
    return gast.parse(while_node_str).body[0]


class NameVisitor(gast.NodeVisitor):
    """
    Analyze name liveness around loops for the loop transformer.

    On construction the visitor walks ``root_node`` once and records, for
    every gast.While / gast.For node it meets:

    - ``before_loop_body_vars``: variable nodes seen before the loop body,
    - ``in_loop_vars``: variable nodes used inside the loop, in visit order,
    - ``write_in_loop``: variable nodes written inside the loop,
    - ``condition_vars``: variable nodes in the while-test or in the
      for-target/iter.

    ``get_loop_var_names`` turns this information into the names that must be
    passed through the generated loop functions.
    """

    def __init__(self, root_node):
        # Set of gast.Name or gast.Attribute nodes seen so far.
        self.current_seen_vars = set()

        # Stack of the gast.While/gast.For nodes being visited, innermost last.
        self.current_loop = []

        # Stack of nodes that have their own variable scope (gast.FunctionDef).
        self.nodes_with_scope = []

        # Names that are never recorded as variables.
        self.blacklist_names = {"False", "True", "None"}

        # Mapping from gast.While/gast.For to variable nodes seen before the
        # loop body.
        self.before_loop_body_vars = defaultdict(set)
        # Mapping from gast.While/gast.For to variable nodes used in the loop.
        # NOTE: an ordered list is used as the dict value so that the context
        # of the first use of a name can be recovered.
        self.in_loop_vars = defaultdict(list)

        # Mapping from gast.While/gast.For to variable nodes which are the
        # condition of the loop or are modified during the loop.
        self.write_in_loop = defaultdict(set)
        self.condition_vars = defaultdict(set)
        # True while the condition part of a loop is being visited.
        self.in_condition = False

        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

        self.visit(root_node)

    def get_loop_var_names(self, node):
        """
        Compute the loop variables of ``node``.

        :param node: a gast.While or gast.For node under the visited root.
        :return: a tuple ``(loop_var_names, create_var_names)`` of sets of
            variable names. ``loop_var_names`` become the parameters and
            return values of the generated loop functions.
            ``create_var_names`` is the subset that is not seen before the
            loop and therefore has to be created before it.
        """
        assert isinstance(
            node, (gast.While, gast.For)), "Input node is not gast loop node"
        loop_var_names = set()
        create_var_names = set()
        read_context = {type(gast.Load()), type(gast.AugLoad())}

        in_loop_vars_list = self.in_loop_vars[node]

        # Map each variable name to its contexts in visit order, so the
        # context of its first use inside the loop can be inspected below.
        var_name_to_ctxs = defaultdict(list)
        for var_node in in_loop_vars_list:
            var_name_to_ctxs[self._var_node_to_name(var_node)].append(
                var_node.ctx)

        in_loop_vars = set(in_loop_vars_list)
        in_loop_vars = self._remove_unnecessary_vars(in_loop_vars, node)
        in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)

        before_loop_body_vars = self.before_loop_body_vars[node]
        before_loop_body_vars = self._remove_unnecessary_vars(
            before_loop_body_vars, node)
        before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)

        after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
        after_loop_vars = self._remove_unnecessary_vars(after_loop_vars, node)
        # Only reads after the loop matter for liveness.
        after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
                                                        read_context)

        condition_names = self._var_nodes_to_names(self.condition_vars[node])
        write_names = self._var_nodes_to_names(self.write_in_loop[node])

        name_to_type = {}
        for var in in_loop_vars:
            wrapper = self.node_to_wrapper_map[var]
            name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type

        for name in in_loop_name_strs:
            if name in before_loop_name_strs:
                # The variable is used in the loop and created before it.
                # A read-only variable of basic type that is not a condition
                # var need not be a loop var; anything else is passed into the
                # loop as input.
                if (name not in condition_names and
                        name not in write_names and
                        self._node_var_type_is_basic(name_to_type[name])):
                    continue
                loop_var_names.add(name)

            elif name in after_loop_name_strs:
                # The variable is created in the loop and read after it, so it
                # must be a loop var and be created before the loop. Such a
                # name must be initialized in the loop, i.e. it is written
                # there, so there is no read-only basic var to filter out.
                loop_var_names.add(name)
                create_var_names.add(name)

            else:
                # The variable is used and created in the loop, but it is
                # read before it is assigned. Then it must be a loop var and
                # be created before the loop. For example, `var_a` below:
                #
                #   res = 0
                #   for i, x in enumerate(x_array):
                #       if i > 2:
                #           x = func1(var_a)
                #       var_a = func2(x)
                #
                if isinstance(var_name_to_ctxs[name][0], gast.Load):
                    loop_var_names.add(name)
                    create_var_names.add(name)

        return loop_var_names, create_var_names

    def visit_Name(self, node):
        """Record a variable use in all enclosing loops."""
        if self._is_call_func_name_node(node):
            self.generic_visit(node)
            return

        if node.id in self.blacklist_names:
            self.generic_visit(node)
            return

        self.current_seen_vars.add(node)
        write_context = {
            type(gast.Store()), type(gast.AugStore()), type(gast.Del())
        }
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)
            if type(node.ctx) in write_context:
                self.write_in_loop[loop_node].add(node)
        if self.in_condition and self.current_loop:
            # The condition being visited belongs to the innermost loop.
            self.condition_vars[self.current_loop[-1]].add(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.nodes_with_scope.append(node)
        self.blacklist_names.add(node.name)
        # The variables in the function are not visible to the outside scope.
        before_func_seen_vars = copy.copy(self.current_seen_vars)

        self.generic_visit(node)
        self.nodes_with_scope.pop()
        # After exiting the scope of a nested function, the variables seen in
        # that scope are removed from self.current_seen_vars.
        if self.nodes_with_scope:
            self.current_seen_vars = before_func_seen_vars

    def visit(self, node):
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        ret = visitor(node)
        return ret

    def visit_Attribute(self, node):
        """Record a dotted variable (e.g. ``a.b``) use in all enclosing loops."""
        if self._is_call_func_name_node(node):
            return
        attr_full_name = get_attribute_full_name(node)
        # Class variables are not allowed to appear in the arguments list
        # of defined function under class methods in Python.
        """
        def class_func(self):
            def while_loop_body(self.x, y) # `self.x` is illegal.
        """
        # TODO: If do change the variable with `self.var`, need a better
        # way to deal with this case.
        if attr_full_name.startswith("self."):
            return
        self.current_seen_vars.add(node)

        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)

        # The sub-nodes are deliberately not visited again here: the whole
        # attribute is recorded once under its full dotted name.

    def visit_For(self, node):
        self.current_loop.append(node)
        # The target and the iterable form the "condition" of a for loop.
        self.in_condition = True
        self.visit(node.target)
        self.visit(node.iter)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_While(self, node):
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.test)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
        """
        Return the set of names of ``node_set``, optionally keeping only nodes
        whose ctx type is in ``ctx_filter_set``.
        """
        ret = set()
        for node in node_set:
            if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
                ret.add(self._var_node_to_name(node))
        return ret

    def _var_node_to_name(self, node):
        if isinstance(node, gast.Name):
            return node.id
        elif isinstance(node, gast.Attribute):
            return get_attribute_full_name(node)

    def _node_var_type_is_basic(self, node_var_type):
        basic_types = {
            NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
            NodeVarType.STRING
        }
        return any(t in basic_types for t in node_var_type)

    def _is_call_func_name_node(self, node):
        """Return True if ``node`` is the callee of a gast.Call."""
        parent_node = self._get_parent_node(node)
        return isinstance(parent_node, gast.Call) and parent_node.func == node

    def _get_parent_node(self, node):
        wrapper_node = self.node_to_wrapper_map.get(node)
        if wrapper_node:
            parent_node = wrapper_node.parent.node
            return parent_node
        return None

    def _get_target_name_nodes(self, target_node):
        """
        Return the gast.Name nodes bound by a for/comprehension target.

        Tuple/List targets (possibly nested) and starred elements are
        flattened. Attribute/Subscript targets bind no plain name and
        contribute nothing.
        """
        if isinstance(target_node, gast.Name):
            return [target_node]
        if isinstance(target_node, (gast.Tuple, gast.List)):
            names = []
            for elt in target_node.elts:
                names.extend(self._get_target_name_nodes(elt))
            return names
        if isinstance(target_node, gast.Starred):
            return self._get_target_name_nodes(target_node.value)
        return []

    @staticmethod
    def _get_comprehension_elts(comp_node):
        """
        Return the element expressions of a comprehension node: ``key`` and
        ``value`` for gast.DictComp, ``elt`` for the other comprehensions.
        """
        if isinstance(comp_node, gast.DictComp):
            return [comp_node.key, comp_node.value]
        if isinstance(comp_node, (gast.ListComp, gast.SetComp,
                                  gast.GeneratorExp)):
            return [comp_node.elt]
        return []

    def _remove_unnecessary_vars(self, loop_vars, loop_node):
        """
        Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
            1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
            2. Remove vars only in gast.comprehension.
        :param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
        :param loop_node: Current loop node.
        :return: a new set: ``loop_vars`` without the removed nodes.
        """
        vars_of_list_generator = set()
        target_vars_of_for_node = set()

        for name_node in loop_vars:
            if not isinstance(name_node, gast.Name):
                continue

            parent_node = self._get_parent_node(name_node)

            # NOTE: gast.For.target or gast.comprehension.target can be gast.Tuple.
            #  For examples:
            #   1) `for i, j in enumerate(x)` has two target vars: i and j
            #   2) `[x for x,y in array]` has two target vars: x and y
            if isinstance(parent_node, gast.Tuple):
                parent_node = self._get_parent_node(parent_node)

            # 1. Get vars only in gast.comprehension.
            # For examples:
            #  1) [x for x,y in array] -> x, x, y
            #  2) [f(x) for x in array] -> x
            #  3) [func(x, y) for x in array] -> x, x
            if isinstance(parent_node, gast.comprehension):
                # 1.1 target vars in list/set/dict comprehensions and
                # generator expressions
                target_vars = self._get_target_name_nodes(parent_node.target)
                vars_of_list_generator |= set(target_vars)

                # 1.2 vars from target vars used in the element expression(s).
                # gast.walk also yields the expression node itself.
                target_var_names = {var.id for var in target_vars}
                comp_node = self._get_parent_node(parent_node)
                for elt_node in self._get_comprehension_elts(comp_node):
                    for child_node in gast.walk(elt_node):
                        if (isinstance(child_node, gast.Name) and
                                child_node.id in target_var_names):
                            vars_of_list_generator.add(child_node)

            # 2. Get target vars or vars from target vars used in for-loop.
            elif isinstance(parent_node,
                            gast.For) and parent_node is not loop_node:
                # 2.1 target vars in gast.For node.
                target_vars_of_for_node |= set(
                    self._get_target_name_nodes(parent_node.target))

        # 2.2 vars from target vars used in for-loop, except those that belong
        # to the condition of loop_node itself.
        target_vars_name_strs = {var.id for var in target_vars_of_for_node}
        for var in loop_vars:
            if not isinstance(var, gast.Name):
                continue
            if var.id in target_vars_name_strs and var not in self.condition_vars[
                    loop_node]:
                target_vars_of_for_node.add(var)

        removed_vars = target_vars_of_for_node | vars_of_list_generator
        return loop_vars - removed_vars


class LoopTransformer(gast.NodeTransformer):
    """
    This class transforms python while/for statement into Static Graph Ast
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Name liveness is analysed once, on the tree before transformation.
        self.name_visitor = NameVisitor(self.root)

    def transform(self):
        """Transform the supported loops under the root node in place."""
        self.visit(self.root)

    def visit(self, node):
        """
        Visit the children first, then replace loops in the statement lists
        of ``node``.
        """
        self.generic_visit(node)
        # All statement lists that may contain gast.While/gast.For nodes.
        # `finalbody` (of gast.Try) is scanned just like `body`/`orelse`.
        for field in ('body', 'orelse', 'finalbody'):
            if hasattr(node, field):
                self.replace_stmt_list(getattr(node, field))
        return node

    def replace_stmt_list(self, body_list):
        """
        Replace each gast.While/gast.For in ``body_list`` (in place) with the
        statements produced for it. Non-list fields, e.g. the expression
        `body` of a lambda, are ignored.
        """
        if not isinstance(body_list, list):
            return

        i = 0
        while i < len(body_list):
            stmt = body_list[i]
            if isinstance(stmt, gast.While):
                new_stmts = self.get_while_stmt_nodes(stmt)
            elif isinstance(stmt, gast.For):
                new_stmts = self.get_for_stmt_nodes(stmt)
            else:
                i += 1
                continue
            body_list[i:i + 1] = new_stmts
            # Skip over the freshly inserted statements.
            i += len(new_stmts)

    def get_for_stmt_nodes(self, node):
        """
        Return the statements replacing the gast.For ``node``, or ``[node]``
        when ForNodeVisitor cannot parse it.
        """
        # TODO: consider for - else in python

        # 1. get key statements for different cases
        # NOTE 1: three key statements:
        #   1). init_stmts: list[node], prepare nodes of for loop, may not only one
        #   2). cond_stmt: node, condition node to judge whether continue loop
        #   3). body_stmts: list[node], updated loop body, sometimes we should change
        #       the original statement in body, not just append new statement
        #
        # NOTE 2: The following `for` statements will be transformed to `while` statements:
        #   1). for x in range(*)
        #   2). for x in iter_var
        #   3). for i, x in enumerate(*)

        current_for_node_parser = ForNodeVisitor(node)
        stmts_tuple = current_for_node_parser.parse()
        if stmts_tuple is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = stmts_tuple

        # 2. get original loop vars
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
        # we need append new loop var & remove useless loop var
        #   1. for x in var -> x is no need
        #   2. for i, x in enumerate(var) -> x is no need
        if current_for_node_parser.is_for_iter(
        ) or current_for_node_parser.is_for_enumerate_iter():
            iter_var_name = current_for_node_parser.iter_var_name
            iter_idx_name = current_for_node_parser.iter_idx_name
            loop_var_names.add(iter_idx_name)
            if iter_var_name not in create_var_names:
                # discard: the name may already be absent from the loop vars.
                loop_var_names.discard(iter_var_name)

        # 3. prepare result statement list
        # Python can create variable in loop and use it out of loop, E.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variable for those variables
        new_stmts = self._create_static_variable_nodes(create_var_names)

        # 4. append init statements
        new_stmts.extend(init_stmts)

        # 5-7. create & append condition function, body function and the
        # while loop node
        new_stmts.extend(
            self._create_loop_nodes(FOR_CONDITION_PREFIX, cond_stmt,
                                    FOR_BODY_PREFIX, body_stmts,
                                    loop_var_names))
        return new_stmts

    def get_while_stmt_nodes(self, node):
        """Return the statements replacing the gast.While ``node``."""
        # NOTE(review): `node.orelse` (while-else) is not carried over into
        # the returned statements.
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)

        # Python can create variable in loop and use it out of loop, E.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create static variable for those variables
        new_stmts = self._create_static_variable_nodes(create_var_names)

        new_stmts.extend(
            self._create_loop_nodes(WHILE_CONDITION_PREFIX, node.test,
                                    WHILE_BODY_PREFIX, node.body,
                                    loop_var_names))
        return new_stmts

    def _create_static_variable_nodes(self, create_var_names):
        """
        Return one static-variable creation statement per plain (non-dotted)
        name, in sorted order so the generated code is deterministic.
        """
        return [
            create_static_variable_gast_node(name)
            for name in sorted(create_var_names) if "." not in name
        ]

    def _create_loop_nodes(self, cond_prefix, cond_value, body_prefix,
                           body_stmts, loop_var_names):
        """
        Return ``[condition_func, body_func, while_assign]`` for one loop.

        ``body_stmts`` is extended in place with a return of the loop vars.
        """
        # A sorted list fixes one order for the parameters, return values and
        # while-loop arguments, independent of string hash randomization.
        loop_var_names = sorted(loop_var_names)

        condition_func_node = self._create_loop_func_node(
            cond_prefix, [gast.Return(value=cond_value)], loop_var_names)

        # The body function returns the updated loop vars.
        body_stmts.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load())))
        body_func_node = self._create_loop_func_node(body_prefix, body_stmts,
                                                     loop_var_names)

        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name, loop_var_names)
        return [condition_func_node, body_func_node, while_loop_node]

    def _create_loop_func_node(self, name_prefix, body, loop_var_names):
        """
        Build a gast.FunctionDef named via ``unique_name.generate(name_prefix)``
        that takes ``loop_var_names`` as positional parameters and runs
        ``body``.
        """
        func_node = gast.FunctionDef(
            name=unique_name.generate(name_prefix),
            args=gast.arguments(
                args=[
                    gast.Name(
                        id=name,
                        ctx=gast.Param(),
                        annotation=None,
                        type_comment=None) for name in loop_var_names
                ],
                posonlyargs=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=None,
                kwarg=None,
                defaults=[]),
            body=body,
            decorator_list=[],
            returns=None,
            type_comment=None)
        # Dotted names (e.g. `a.b`) cannot be parameter names, so rename them
        # to fresh plain identifiers inside the generated function.
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        return func_node