ifelse_transformer.py 15.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections import defaultdict

18 19 20 21
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
22
from paddle.utils import gast
23
from paddle.fluid import unique_name
24

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
from paddle.fluid.dygraph.dygraph_to_static.utils import (
    create_funcDef_node,
    ast_to_source_code,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
    FunctionNameLivenessAnalysis,
)
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
    AstNodeWrapper,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
    create_nonlocal_stmt_nodes,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
    create_get_args_node,
    create_set_args_node,
)
42
from .base_transformer import (
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
    BaseTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
    FOR_ITER_INDEX_PREFIX,
    FOR_ITER_TUPLE_PREFIX,
    FOR_ITER_TUPLE_INDEX_PREFIX,
    FOR_ITER_VAR_LEN_PREFIX,
    FOR_ITER_VAR_NAME_PREFIX,
    FOR_ITER_ZIP_TO_LIST_PREFIX,
    FOR_ITER_TARGET_PREFIX,
    FOR_ITER_ITERATOR_PREFIX,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
    GetterSetterHelper,
    create_name_str,
)
59

60 61
# Names exported by `from <this module> import *`.
__all__ = ['IfElseTransformer']

62 63
# Prefixes passed to `unique_name.generate` when naming the functions created
# from the `if` body and the `else` body respectively.
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
# Substrings identifying the generated getter/setter helper functions;
# `NameVisitor.visit_FunctionDef` skips any function whose name contains them.
GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
# Variable names starting with this prefix are filtered out of the
# nonlocal/returned names in `transform_if_else`.
ARGS_NAME = '__args'
67 68


69
class IfElseTransformer(BaseTransformer):
    """
    Transform if/else statement of Dygraph into Static Graph.

    `if` statements are replaced by generated helper functions followed by a
    `_jst.IfElse(...)` call; conditional expressions (`a if c else b`) are
    replaced by a `_jst.IfElse(...)` call built from lambdas.
    """

    def __init__(self, wrapper_root):
        """
        Args:
            wrapper_root (AstNodeWrapper): wrapper whose `.node` is the AST
                root to be transformed.

        Raises:
            AssertionError: if `wrapper_root` is not an `AstNodeWrapper`.
        """
        assert isinstance(wrapper_root, AstNodeWrapper), (
            "Type of input node should be AstNodeWrapper, but received %s ."
            % type(wrapper_root)
        )
        self.root = wrapper_root.node
        # Name analysis of current ast tree. The result object is discarded;
        # presumably the analysis annotates the nodes in place (e.g. the
        # `pd_scope` attribute read by `transform_if_else`) -- TODO confirm.
        FunctionNameLivenessAnalysis(self.root)

    def transform(self):
        """
        Main function to transform AST.
        """
        self.visit(self.root)

    def visit_If(self, node):
        """
        Replace an `if` statement with its static-graph equivalent.

        Children are transformed first, so nested `if` statements are
        rewritten before their parent.

        Returns:
            list: `[get_args, set_args, true_fn, false_fn, ifelse_call]`
            nodes, in that order, which take the place of the `if` node.
        """
        self.generic_visit(node)
        (
            true_func_node,
            false_func_node,
            get_args_node,
            set_args_node,
            return_name_ids,
            push_pop_ids,
        ) = transform_if_else(node, self.root)

        new_node = create_convert_ifelse_node(
            return_name_ids,
            push_pop_ids,
            node.test,
            true_func_node,
            false_func_node,
            get_args_node,
            set_args_node,
        )

        # The helper definitions must precede the call that references them.
        return [
            get_args_node,
            set_args_node,
            true_func_node,
            false_func_node,
            new_node,
        ]

    def visit_Call(self, node):
        """
        Strip a `.numpy()` call and return the object it was called on.
        """
        # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
        func = node.func
        if isinstance(func, gast.Attribute) and func.attr == 'numpy':
            node = func.value
        # NOTE(review): only the children of the (possibly replaced) node are
        # visited here; the replacement node itself is not dispatched to its
        # own `visit_*` method.
        self.generic_visit(node)
        return node

    def visit_IfExp(self, node):
        """
        Transformation with `true_fn(x) if Tensor > 0 else false_fn(x)`
        """
        self.generic_visit(node)

        # No getter/setter functions or returned names are involved for a
        # conditional expression, hence the `None` arguments.
        new_node = create_convert_ifelse_node(
            None, None, node.test, node.body, node.orelse, None, None, True
        )
        # Note: A blank line will be added separately if transform gast.Expr
        # into source code. Using gast.Expr.value instead to avoid syntax error
        # in python.
        if isinstance(new_node, gast.Expr):
            new_node = new_node.value

        return new_node

144

145
class NameVisitor(gast.NodeVisitor):
    """
    Collect the names used in an AST together with the contexts
    (`gast.Store` / `gast.Load` / `gast.Param`) in which they appear.

    Names are only recorded while the traversal is "in range", i.e. after
    `after_node` has been reached (or from the start when it is None) and
    before `end_node` has been reached. Results are stored in `name_ids`.
    """

    def __init__(self, after_node=None, end_node=None):
        # The start node (exclusive) of the visitor
        self.after_node = after_node
        # The terminate node of the visitor.
        self.end_node = end_node
        # Dict to store the names and ctxs of vars.
        self.name_ids = defaultdict(list)
        # List of current visited nodes
        self.ancestor_nodes = []
        # True when in range (after_node, end_node).
        self._in_range = after_node is None
        # Only names appearing in one of these contexts are recorded.
        self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
        # Names of functions defined while in range; they are never recorded.
        self._def_func_names = set()

    def visit(self, node):
        """
        Visit a node.

        Reaching `after_node` opens the range and reaching `end_node` closes
        it; neither of those two nodes is descended into.
        """
        if self.after_node is not None and node == self.after_node:
            self._in_range = True
            return
        if node == self.end_node:
            self._in_range = False
            return

        # Track the path from the root so that a node can look up its parent.
        self.ancestor_nodes.append(node)
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        ret = visitor(node)
        self.ancestor_nodes.pop()

        return ret

    def visit_If(self, node):
        """
        For nested `if/else`, the created vars are not always visible for parent node.
        In addition, the vars created in `if.body` are not visible for `if.orelse`.

        Case 1:
            x = 1
            if m > 1:
                res = new_tensor
            res = res + 1   # Error, `res` is not visible here.

        Case 2:
            if x_tensor > 0:
                res = new_tensor
            else:
                res = res + 1   # Error, `res` is not visible here.

        In the above two cases, the scope of vars has to be managed so that
        the arguments and returned vars are parsed correctly.
        """
        if not self._in_range or not self.end_node:
            self.generic_visit(node)
            return

        before_if_name_ids = copy.deepcopy(self.name_ids)
        body_name_ids = self._visit_child(node.body)
        # If traversal process stops early in `if.body`, return the currently seen name_ids.
        if not self._in_range:
            self._update_name_ids(before_if_name_ids)
            return

        else_name_ids = self._visit_child(node.orelse)
        # If traversal process stops early in `if.orelse`, return the currently seen name_ids.
        if not self._in_range:
            self._update_name_ids(before_if_name_ids)
            return

        # Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
        # into name_ids.
        new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids)
        for new_name_id in new_name_ids:
            before_if_name_ids[new_name_id].append(gast.Store())

        self.name_ids = before_if_name_ids

    def visit_Attribute(self, node):
        """
        Descend into an attribute unless, while in range, it is the callee of
        a call (see `_is_call_func_name_node`).
        """
        if not self._in_range or not self._is_call_func_name_node(node):
            self.generic_visit(node)

    def visit_Name(self, node):
        """Record `node.id` with its context when it is a candidate name."""
        if not self._in_range:
            self.generic_visit(node)
            return
        blacklist = {'True', 'False', 'None'}
        if node.id in blacklist:
            return
        if node.id in self._def_func_names:
            return
        if not self._is_call_func_name_node(node):
            if isinstance(node.ctx, self._candidate_ctxs):
                self.name_ids[node.id].append(node.ctx)

    def visit_Assign(self, node):
        """
        Visit the right-hand side of an assignment before its targets, so
        that loads in `value` are recorded before the stores in `targets`.
        """
        if not self._in_range:
            self.generic_visit(node)
            return
        # Visit `value` firstly. The children are visited explicitly instead
        # of overwriting `node._fields`, which would permanently alter the
        # AST node for every later pass over the same tree.
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)

    def visit_FunctionDef(self, node):
        """
        Handle a nested function definition.

        While in range with an `end_node` set, names seen inside the function
        are discarded afterwards unless the range was closed inside it.
        """
        # NOTE: We skip visiting names of get_args and set_args, because they
        # contain nonlocal statements such as 'nonlocal x, self' where 'self'
        # should not be parsed as a returned value in control flow.
        if (
            GET_ARGS_FUNC_PREFIX in node.name
            or SET_ARGS_FUNC_PREFIX in node.name
        ):
            return

        if not self._in_range:
            self.generic_visit(node)
            return

        self._def_func_names.add(node.name)
        if not self.end_node:
            self.generic_visit(node)
            return

        # Collect the function's names into a fresh dict so they can be
        # dropped again afterwards.
        before_name_ids = copy.deepcopy(self.name_ids)
        self.name_ids = defaultdict(list)
        self.generic_visit(node)

        if not self._in_range:
            # `end_node` was reached inside the function: keep what was seen
            # and put the earlier names in front of it.
            self._update_name_ids(before_name_ids)
        else:
            self.name_ids = before_name_ids

    def _visit_child(self, node):
        """
        Visit `node` (an AST node or a list of them) starting from an empty
        `name_ids` and return a deep copy of the names collected.
        """
        self.name_ids = defaultdict(list)
        if isinstance(node, list):
            for item in node:
                if isinstance(item, gast.AST):
                    self.visit(item)
        elif isinstance(node, gast.AST):
            self.visit(node)

        return copy.deepcopy(self.name_ids)

    def _find_new_name_ids(self, body_name_ids, else_name_ids):
        """
        Return the set of names that have a `gast.Store` context in both
        `body_name_ids` and `else_name_ids`.
        """

        def has_store(ctxs):
            return any(isinstance(ctx, gast.Store) for ctx in ctxs)

        candidate_name_ids = set(body_name_ids.keys()) & set(
            else_name_ids.keys()
        )
        return {
            name_id
            for name_id in candidate_name_ids
            if has_store(body_name_ids[name_id])
            and has_store(else_name_ids[name_id])
        }

    def _is_call_func_name_node(self, node):
        """
        Return True when `node` is the `func` of its parent `gast.Call`,
        except for attribute callees named `append` or `extend`.
        """
        white_func_names = {'append', 'extend'}
        if len(self.ancestor_nodes) > 1:
            assert self.ancestor_nodes[-1] == node
            parent_node = self.ancestor_nodes[-2]
            if isinstance(parent_node, gast.Call) and parent_node.func == node:
                # e.g: var_list.append(elem), var_list is also a name_id.
                should_skip = (
                    isinstance(node, gast.Attribute)
                    and node.attr in white_func_names
                )
                if not should_skip:
                    return True
        return False

    def _update_name_ids(self, new_name_ids):
        """
        Merge `new_name_ids` into `self.name_ids`, placing the given contexts
        in front of the ones already recorded for each name.
        """
        for name_id, ctxs in new_name_ids.items():
            self.name_ids[name_id] = ctxs + self.name_ids[name_id]

323

324 325 326 327
def _valid_nonlocal_names(return_name_ids, nonlocal_names):
    """
    All var in return_name_ids should be in nonlocal_names.
    Moreover, we will always put return_name_ids in front of nonlocal_names.
328

329 330 331 332 333 334 335 336 337 338 339 340
    For Example:

        return_name_ids: [x, y]
        nonlocal_names : [a, y, b, x]

    Return:
        nonlocal_names : [x, y, a, b]
    """
    assert isinstance(return_name_ids, list)
    for name in return_name_ids:
        if name not in nonlocal_names:
            raise ValueError(
341 342 343 344
                "Required returned var '{}' must be in 'nonlocal' statement '', but not found.".format(
                    name
                )
            )
345 346 347 348 349
        nonlocal_names.remove(name)

    return return_name_ids + nonlocal_names


350 351 352 353
def transform_if_else(node, root):
    """
    Transform ast.If into control flow statement of Paddle static graph.

    Args:
        node (gast.If): the `if` statement to transform. It must carry a
            `pd_scope` attribute providing `modified_vars()` and
            `variadic_length_vars()`.
        root: root of the AST being transformed; not used by this function.

    Returns:
        tuple: `(true_func_node, false_func_node, get_args_node,
        set_args_node, return_name_ids, push_pop_ids)`.
    """
    # TODO(liym27): Consider variable like `self.a` modified in if/else node.
    return_name_ids = sorted(node.pd_scope.modified_vars())
    push_pop_ids = sorted(node.pd_scope.variadic_length_vars())
    # NOTE: All var in return_name_ids should be in nonlocal_names.
    nonlocal_names = _valid_nonlocal_names(
        return_name_ids, list(return_name_ids)
    )

    # TODO(dev): Need a better way to deal this.
    # LoopTransformer creates some special vars which are not visible to
    # users, so it is safe to remove them.
    internal_prefixes = (
        ARGS_NAME,
        FOR_ITER_INDEX_PREFIX,
        FOR_ITER_TUPLE_PREFIX,
        FOR_ITER_TARGET_PREFIX,
        FOR_ITER_ITERATOR_PREFIX,
        FOR_ITER_TUPLE_INDEX_PREFIX,
        FOR_ITER_VAR_LEN_PREFIX,
        FOR_ITER_VAR_NAME_PREFIX,
        FOR_ITER_ZIP_TO_LIST_PREFIX,
    )
    nonlocal_names = [
        name
        for name in nonlocal_names
        if not name.startswith(internal_prefixes)
    ]
    # The returned names are exactly the filtered nonlocal names.
    return_name_ids = nonlocal_names

    nonlocal_stmt_node = create_nonlocal_stmt_nodes(nonlocal_names)

    # Both generated functions take no arguments: variables are shared with
    # the enclosing scope through the `nonlocal` statement instead.
    empty_arg_node = gast.arguments(
        args=[],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=None,
        kwarg=None,
        defaults=[],
    )

    true_func_node = create_funcDef_node(
        nonlocal_stmt_node + node.body,
        name=unique_name.generate(TRUE_FUNC_PREFIX),
        input_args=empty_arg_node,
        return_name_ids=[],
    )
    false_func_node = create_funcDef_node(
        nonlocal_stmt_node + node.orelse,
        name=unique_name.generate(FALSE_FUNC_PREFIX),
        input_args=empty_arg_node,
        return_name_ids=[],
    )

    helper = GetterSetterHelper(None, None, nonlocal_names, push_pop_ids)
    get_args_node = create_get_args_node(helper.union())
    set_args_node = create_set_args_node(helper.union())

    return (
        true_func_node,
        false_func_node,
        get_args_node,
        set_args_node,
        return_name_ids,
        push_pop_ids,
    )


def create_convert_ifelse_node(
    return_name_ids,
    push_pop_ids,
    pred,
    true_func,
    false_func,
    get_args_func,
    set_args_func,
    is_if_expr=False,
):
    """
    Create `_jst.IfElse(
            pred, true_fn, false_fn, get_args, set_args, return_name_ids,
            push_pop_names=push_pop_ids)`
    to replace original `python if/else` statement.

    Args:
        return_name_ids: names rendered with `create_name_str` as the
            `return_name_ids` argument of the generated call.
        push_pop_ids: names rendered with `create_name_str` as the
            `push_pop_names` keyword argument of the generated call.
        pred: AST node of the condition.
        true_func: function-definition node of the true branch, or the
            true-branch expression node when `is_if_expr` is True.
        false_func: same as `true_func`, for the false branch.
        get_args_func: function-definition node of the getter; ignored when
            `is_if_expr` is True.
        set_args_func: function-definition node of the setter; ignored when
            `is_if_expr` is True.
        is_if_expr (bool): True when converting a conditional expression
            (`a if c else b`) rather than an `if` statement.

    Returns:
        The first statement of the parsed `_jst.IfElse(...)` source.
    """
    if is_if_expr:
        # Branches are expressions: wrap their source code in lambdas.
        true_func_source = "lambda : {}".format(ast_to_source_code(true_func))
        false_func_source = "lambda : {}".format(ast_to_source_code(false_func))
        # TODO: better way to deal with this
        get_args_source = 'lambda: None'
        set_args_source = 'lambda args: None'
    else:
        # Branches are generated functions: refer to them by name.
        true_func_source = true_func.name
        false_func_source = false_func.name
        get_args_source = get_args_func.name
        set_args_source = set_args_func.name

    call_template = (
        '_jst.IfElse('
        '{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, '
        '{return_name_ids}, push_pop_names={push_pop_ids})'
    )
    call_source = call_template.format(
        pred=ast_to_source_code(pred),
        true_fn=true_func_source,
        false_fn=false_func_source,
        get_args=get_args_source,
        set_args=set_args_source,
        return_name_ids=create_name_str(return_name_ids),
        push_pop_ids=create_name_str(push_pop_ids),
    )

    # `gast.parse` yields a module; its single statement is the call.
    convert_ifelse_layer = gast.parse(call_source).body[0]

    return convert_ifelse_layer