ifelse_transformer.py 14.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections import defaultdict

from paddle.fluid import unique_name
19
from paddle.jit.dy2static.utils import (
20
    FOR_ITER_INDEX_PREFIX,
21 22
    FOR_ITER_ITERATOR_PREFIX,
    FOR_ITER_TARGET_PREFIX,
23
    FOR_ITER_TUPLE_INDEX_PREFIX,
24
    FOR_ITER_TUPLE_PREFIX,
25 26 27
    FOR_ITER_VAR_LEN_PREFIX,
    FOR_ITER_VAR_NAME_PREFIX,
    FOR_ITER_ZIP_TO_LIST_PREFIX,
28
    FunctionNameLivenessAnalysis,
29
    GetterSetterHelper,
30 31 32
    ast_to_source_code,
    create_funcDef_node,
    create_get_args_node,
33
    create_name_str,
34 35
    create_nonlocal_stmt_nodes,
    create_set_args_node,
36
)
37

38 39 40 41 42 43 44
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
from paddle.utils import gast

from .base_transformer import BaseTransformer
45
from .utils import FALSE_FUNC_PREFIX, TRUE_FUNC_PREFIX
46

47
__all__ = []

# Name prefixes of the generated getter/setter helper functions.
# NameVisitor.visit_FunctionDef skips any function whose name contains either.
GET_ARGS_FUNC_PREFIX, SET_ARGS_FUNC_PREFIX = 'get_args', 'set_args'
# Reserved generated name; transform_if_else drops every name starting with it
# from the nonlocal / returned names.
ARGS_NAME = '__args'
52 53


54
class IfElseTransformer(BaseTransformer):
    """
    AST transformer that rewrites Dygraph ``if``/``else`` statements and
    conditional expressions into static-graph ``_jst.IfElse`` calls.
    """

    def __init__(self, root):
        self.root = root
        # Run name analysis over the whole tree up front. The result is not
        # stored here; presumably the analysis annotates the nodes (e.g. the
        # `pd_scope` that transform_if_else reads) -- TODO confirm.
        FunctionNameLivenessAnalysis(self.root)

    def transform(self):
        """Entry point: rewrite every if/else reachable from the root."""
        self.visit(self.root)

    def visit_If(self, node):
        """
        Replace an ``if`` statement with four generated function definitions
        (getter, setter, true branch, false branch) followed by the generated
        ``_jst.IfElse`` call.
        """
        # Convert nested statements first (post-order traversal).
        self.generic_visit(node)

        (
            true_fn,
            false_fn,
            getter,
            setter,
            returned_names,
            push_pop_names,
        ) = transform_if_else(node, self.root)

        ifelse_call = create_convert_ifelse_node(
            returned_names,
            push_pop_names,
            node.test,
            true_fn,
            false_fn,
            getter,
            setter,
        )
        return [getter, setter, true_fn, false_fn, ifelse_call]

    def visit_Call(self, node):
        """Drop ``.numpy()`` calls, e.g. ``Tensor.numpy()[i]`` -> ``Tensor[i]``."""
        callee = node.func
        if isinstance(callee, gast.Attribute) and callee.attr == 'numpy':
            node = callee.value
        self.generic_visit(node)
        return node

    def visit_IfExp(self, node):
        """
        Convert ``true_fn(x) if Tensor > 0 else false_fn(x)`` into an
        ``_jst.IfElse`` expression.
        """
        self.generic_visit(node)

        converted = create_convert_ifelse_node(
            None, None, node.test, node.body, node.orelse, None, None, True
        )
        # Turning a gast.Expr back into source code adds a separate blank
        # line, which can produce a Python syntax error here; hand back the
        # wrapped value instead of the Expr statement.
        if isinstance(converted, gast.Expr):
            return converted.value
        return converted

125

126
class NameVisitor(gast.NodeVisitor):
    """
    Collect variable names and the contexts (Store / Load / Param) they are
    used in. Collection is optionally limited to the nodes visited after
    `after_node` (exclusive) and before `end_node`.
    """

    # Identifiers that are never recorded as variable names.
    _IGNORED_NAME_IDS = frozenset({'True', 'False', 'None'})
    # Callee attributes that are NOT treated as plain function names, so their
    # receiver is still collected, e.g. `var_list` in `var_list.append(elem)`.
    _RECORDED_CALLEE_ATTRS = frozenset({'append', 'extend'})

    def __init__(self, after_node=None, end_node=None):
        # The start node (exclusive) of the visitor.
        self.after_node = after_node
        # The terminate node of the visitor.
        self.end_node = end_node
        # Maps each name id to the list of ctx objects it was seen with.
        self.name_ids = defaultdict(list)
        # Stack of the nodes currently being visited (outermost first).
        self.ancestor_nodes = []
        # True when in range (after_node, end_node).
        self._in_range = after_node is None
        # Only names used in one of these contexts are recorded.
        self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
        # Names of functions defined in range; references to them are skipped.
        self._def_func_names = set()

    def visit(self, node):
        """Visit a node, toggling the in-range flag at the boundary nodes."""
        if self.after_node is not None and node == self.after_node:
            # The range starts *after* this node, so its subtree is skipped.
            self._in_range = True
            return
        if node == self.end_node:
            self._in_range = False
            return

        self.ancestor_nodes.append(node)
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        ret = visitor(node)
        self.ancestor_nodes.pop()

        return ret

    def visit_If(self, node):
        """
        For nested `if/else`, the created vars are not always visible to the
        parent node. In addition, the vars created in `if.body` are not
        visible in `if.orelse`.

        Case 1:
            x = 1
            if m > 1:
                res = new_tensor
            res = res + 1   # Error, `res` is not visible here.

        Case 2:
            if x_tensor > 0:
                res = new_tensor
            else:
                res = res + 1   # Error, `res` is not visible here.

        In both cases the scope of vars must be managed so that the arguments
        and returned vars are parsed correctly.
        """
        if not self._in_range or not self.end_node:
            self.generic_visit(node)
            return

        before_if_name_ids = copy.deepcopy(self.name_ids)
        body_name_ids = self._visit_child(node.body)
        # Traversal stopped inside `if.body`: merge the names seen before the
        # `if` with those seen so far and stop.
        if not self._in_range:
            self._update_name_ids(before_if_name_ids)
            return

        else_name_ids = self._visit_child(node.orelse)
        # Traversal stopped inside `if.orelse`: same merge as above.
        if not self._in_range:
            self._update_name_ids(before_if_name_ids)
            return

        # Hide the vars of each branch; only vars stored in BOTH branches are
        # added (as a Store) to the names visible after the statement.
        for new_name_id in self._find_new_name_ids(body_name_ids, else_name_ids):
            before_if_name_ids[new_name_id].append(gast.Store())
        self.name_ids = before_if_name_ids

    def visit_Attribute(self, node):
        # Skip the subtree of an attribute that is itself the called function.
        if not self._in_range or not self._is_call_func_name_node(node):
            self.generic_visit(node)

    def visit_Name(self, node):
        if not self._in_range:
            self.generic_visit(node)
            return
        if node.id in self._IGNORED_NAME_IDS or node.id in self._def_func_names:
            return
        if self._is_call_func_name_node(node):
            return
        if isinstance(node.ctx, self._candidate_ctxs):
            self.name_ids[node.id].append(node.ctx)

    def visit_Assign(self, node):
        if not self._in_range:
            self.generic_visit(node)
            return
        # Visit `value` first so that, e.g. in `x = x + 1`, the Load of `x` is
        # recorded before its Store. The children are visited explicitly
        # rather than by overriding `node._fields`, which would permanently
        # alter the shared AST node (dropping `type_comment` and changing the
        # traversal order for every later pass).
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)

    def visit_FunctionDef(self, node):
        # NOTE: We skip to visit names of get_args and set_args, because they
        # contain nonlocal statements such as 'nonlocal x, self' where 'self'
        # should not be parsed as a returned value in control flow.
        if GET_ARGS_FUNC_PREFIX in node.name or SET_ARGS_FUNC_PREFIX in node.name:
            return

        if not self._in_range:
            self.generic_visit(node)
            return

        self._def_func_names.add(node.name)
        if not self.end_node:
            self.generic_visit(node)
            return

        # Collect the function body's names separately; they are merged back
        # only if traversal ended inside the function, otherwise discarded.
        before_name_ids = copy.deepcopy(self.name_ids)
        self.name_ids = defaultdict(list)
        self.generic_visit(node)

        if not self._in_range:
            self._update_name_ids(before_name_ids)
        else:
            self.name_ids = before_name_ids

    def _visit_child(self, node):
        """Visit `node` (a node or a list of nodes) with a fresh name table
        and return a deep copy of the names collected."""
        self.name_ids = defaultdict(list)
        if isinstance(node, list):
            for item in node:
                if isinstance(item, gast.AST):
                    self.visit(item)
        elif isinstance(node, gast.AST):
            self.visit(node)

        return copy.deepcopy(self.name_ids)

    def _find_new_name_ids(self, body_name_ids, else_name_ids):
        """Return the set of names stored in both the body and the else
        branch."""

        def has_store(ctxs):
            return any(isinstance(ctx, gast.Store) for ctx in ctxs)

        shared_name_ids = set(body_name_ids.keys()) & set(else_name_ids.keys())
        return {
            name_id
            for name_id in shared_name_ids
            if has_store(body_name_ids[name_id])
            and has_store(else_name_ids[name_id])
        }

    def _is_call_func_name_node(self, node):
        """
        Return True if `node` (the node currently being visited) is the
        callee of its parent Call, e.g. `foo` in `foo(x)`. Attributes named
        in `_RECORDED_CALLEE_ATTRS` (e.g. `var_list.append`) return False so
        that their receiver is still collected as a name_id.
        """
        if len(self.ancestor_nodes) > 1:
            assert self.ancestor_nodes[-1] == node
            parent_node = self.ancestor_nodes[-2]
            if isinstance(parent_node, gast.Call) and parent_node.func == node:
                is_recorded_attr = (
                    isinstance(node, gast.Attribute)
                    and node.attr in self._RECORDED_CALLEE_ATTRS
                )
                if not is_recorded_attr:
                    return True
        return False

    def _update_name_ids(self, new_name_ids):
        """Prepend the ctxs from `new_name_ids` to the current entries."""
        for name_id, ctxs in new_name_ids.items():
            self.name_ids[name_id] = ctxs + self.name_ids[name_id]

304

305 306 307 308
def _valid_nonlocal_names(return_name_ids, nonlocal_names):
    """
    All var in return_name_ids should be in nonlocal_names.
    Moreover, we will always put return_name_ids in front of nonlocal_names.
309

310 311 312 313 314 315 316 317 318 319 320 321
    For Example:

        return_name_ids: [x, y]
        nonlocal_names : [a, y, b, x]

    Return:
        nonlocal_names : [x, y, a, b]
    """
    assert isinstance(return_name_ids, list)
    for name in return_name_ids:
        if name not in nonlocal_names:
            raise ValueError(
322 323 324 325
                "Required returned var '{}' must be in 'nonlocal' statement '', but not found.".format(
                    name
                )
            )
326 327 328 329 330
        nonlocal_names.remove(name)

    return return_name_ids + nonlocal_names


331 332 333 334
def transform_if_else(node, root):
    """
    Transform ast.If into control flow statement of Paddle static graph.

    Args:
        node (gast.If): the `if` statement to convert. It must carry the
            `pd_scope` attribute read below (presumably attached by
            FunctionNameLivenessAnalysis -- confirm).
        root: root of the AST being transformed; not used in this function.

    Returns:
        tuple: (true_func_node, false_func_node, get_args_node,
        set_args_node, return_name_ids, push_pop_ids).
    """

    # TODO(liym27): Consider variable like `self.a` modified in if/else node.
    return_name_ids = sorted(node.pd_scope.modified_vars())
    push_pop_ids = sorted(node.pd_scope.variadic_length_vars())
    # NOTE: All var in return_name_ids should be in nonlocal_names.
    nonlocal_names = _valid_nonlocal_names(return_name_ids, list(return_name_ids))

    # TODO(dev): Need a better way to deal this.
    # LoopTransformer will create some special vars, which are not visible to
    # users, so it is safe to remove them.
    internal_prefixes = (
        ARGS_NAME,
        FOR_ITER_INDEX_PREFIX,
        FOR_ITER_TUPLE_PREFIX,
        FOR_ITER_TARGET_PREFIX,
        FOR_ITER_ITERATOR_PREFIX,
        FOR_ITER_TUPLE_INDEX_PREFIX,
        FOR_ITER_VAR_LEN_PREFIX,
        FOR_ITER_VAR_NAME_PREFIX,
        FOR_ITER_ZIP_TO_LIST_PREFIX,
    )
    nonlocal_names = [
        name for name in nonlocal_names if not name.startswith(internal_prefixes)
    ]
    return_name_ids = nonlocal_names

    def make_branch_func(body, name_prefix):
        # Each branch function gets its own `nonlocal` statement nodes and its
        # own empty `arguments` node, so the two generated functions never
        # share AST node objects. A later in-place rewrite of one therefore
        # cannot leak into the other. `kw_defaults` is a list in the AST
        # grammar, so it is an empty list rather than None.
        empty_arg_node = gast.arguments(
            args=[],
            posonlyargs=[],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        )
        return create_funcDef_node(
            create_nonlocal_stmt_nodes(nonlocal_names) + body,
            name=unique_name.generate(name_prefix),
            input_args=empty_arg_node,
            return_name_ids=[],
        )

    # The order matters: unique names are generated for TRUE first, then FALSE.
    true_func_node = make_branch_func(node.body, TRUE_FUNC_PREFIX)
    false_func_node = make_branch_func(node.orelse, FALSE_FUNC_PREFIX)

    helper = GetterSetterHelper(None, None, nonlocal_names, push_pop_ids)
    get_args_node = create_get_args_node(helper.union())
    set_args_node = create_set_args_node(helper.union())

    return (
        true_func_node,
        false_func_node,
        get_args_node,
        set_args_node,
        return_name_ids,
        push_pop_ids,
    )


def create_convert_ifelse_node(
    return_name_ids,
    push_pop_ids,
    pred,
    true_func,
    false_func,
    get_args_func,
    set_args_func,
    is_if_expr=False,
):
    """
    Build the statement
    `_jst.IfElse(pred, true_fn, false_fn, get_args, set_args, return_name_ids,
    push_pop_names=push_pop_ids)` that replaces the original python `if/else`.

    Args:
        return_name_ids: names rendered with create_name_str as the returned
            vars argument.
        push_pop_ids: names rendered with create_name_str as `push_pop_names`.
        pred: AST node of the condition.
        true_func, false_func: for an `if` statement, the generated function
            nodes (only their `.name` is used); for an if expression
            (`is_if_expr=True`), the branch expression nodes, each emitted as
            `lambda : <expr>`.
        get_args_func, set_args_func: the generated getter/setter function
            nodes (only their `.name` is used). Ignored when `is_if_expr` is
            True, where no-op lambdas are emitted instead.
        is_if_expr: True when converting a conditional expression.

    Returns:
        The first statement of the parsed source, i.e. the expression
        statement wrapping the `_jst.IfElse(...)` call.
    """
    if is_if_expr:
        true_src = f"lambda : {ast_to_source_code(true_func)}"
        false_src = f"lambda : {ast_to_source_code(false_func)}"
        getter_src = 'lambda: None'  # TODO: better way to deal with this
        setter_src = 'lambda args: None'
    else:
        true_src, false_src = true_func.name, false_func.name
        getter_src, setter_src = get_args_func.name, set_args_func.name

    pred_src = ast_to_source_code(pred)
    returned_src = create_name_str(return_name_ids)
    push_pop_src = create_name_str(push_pop_ids)

    call_src = (
        f"_jst.IfElse({pred_src}, {true_src}, {false_src}, "
        f"{getter_src}, {setter_src}, {returned_src}, "
        f"push_pop_names={push_pop_src})"
    )
    return gast.parse(call_src).body[0]