partial_program.py 39.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from copy import deepcopy

17
import numpy as np
18

19
import paddle
20 21
from paddle import _legacy_C_ops
from paddle.fluid import backward, core, framework, program_guard
22
from paddle.fluid.compiler import BuildStrategy
23 24 25 26 27 28
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass

from . import logging_utils
from .return_transformer import RETURN_NO_VALUE_MAGIC_NUM
29
from .utils import _out_grad_names, _param_grad_names
30

31 32
__all__ = []

33

34
class NestSequence:
    """
    A wrapper that flattens a (possibly nested) sequence into a single list
    and can rebuild the original nested structure from a flat list of values.
    """

    def __init__(self, raw_input, need_check=False):
        self.__raw_input = raw_input
        # Flat view of `raw_input`, computed once and indexed by __getitem__.
        self.__input_list = self.tolist()
        # Positions in the flat list that hold tensor-like values.
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    def tolist(self):
        """
        Return the wrapped nested sequence flattened into a single list.
        """
        return paddle.utils.flatten(self.__raw_input)

    def restore(self, value_list):
        """
        Rebuild the original nested structure, filling it with ``value_list``.

        ``value_list`` must have exactly as many entries as the flattened input.
        """
        expected_len = len(self.__input_list)
        assert expected_len == len(value_list)
        return paddle.utils.pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        """
        Return the indices of flat entries that are Variable/VarBase/eager Tensor.
        """
        return [
            position
            for position, item in enumerate(self.__input_list)
            if isinstance(
                item, (framework.Variable, core.VarBase, core.eager.Tensor)
            )
        ]

    def _check_non_variable(self, need_check):
        """
        When ``need_check`` is set, warn once about the types of any flat
        entries that are not tensor-like values.
        """
        if not need_check:
            return
        non_tensor_types = {
            type(item)
            for item in self.__input_list
            if not isinstance(
                item, (framework.Variable, core.VarBase, core.eager.Tensor)
            )
        }
        if not non_tensor_types:
            return
        logging_utils.warn(
            "Output of traced function contains non-tensor type values: {}. "
            "Currently, We don't support to update them while training and will return "
            "what we first saw. Please try to return them as tensor.".format(
                list(non_tensor_types)
            )
        )

    @property
    def var_ids(self):
        """Indices of the tensor-like entries in the flattened sequence."""
        return self.__var_ids

    def __getitem__(self, item):
        # Also makes instances iterable via the sequence protocol.
        return self.__input_list[item]


class LazyInitialized:
    """
    Descriptor to implement lazy initialization of a property.

    On first access through an instance, the wrapped function is called with
    that instance and the result is stored on the instance under the
    function's name. Because this is a non-data descriptor (no ``__set__``),
    the stored instance attribute then shadows it, so the function runs at
    most once per instance. Accessing the attribute on the class itself
    returns the descriptor instead of calling the function.
    """

    def __init__(self, function):
        self.function = function
        # Expose the wrapped function's docstring for help()/introspection.
        self.__doc__ = function.__doc__

    def __get__(self, instance, cls):
        if instance is None:
            # Accessed on the class: there is no instance to initialize, and
            # calling `function(None)` / `setattr(None, ...)` would fail.
            return self
        val = self.function(instance)
        setattr(instance, self.function.__name__, val)
        return val


def _change_is_test_status(program, is_test):
    # change all `is_test` attributes
    for block in program.blocks:
        for op in block.ops:
            if op.has_attr('is_test'):
                op._set_attr('is_test', is_test)
    return program


120 121 122 123 124
class ProgramInfo:
    """
    A helper class to record the inference Program and its op count for
    each precision mode ('fp32', 'amp', 'fp16').
    """

    def __init__(self):
        # -1 marks "not recorded yet" for each mode.
        self.op_size = dict.fromkeys(('fp32', 'amp', 'fp16'), -1)
        self.programs = {}
        self.mode = "infer"

    def __call__(self, key, prog_creator):
        """
        Return ``(infer_program, op_size)`` for precision mode ``key``.

        On the first call for a key, ``prog_creator(is_infer_mode=True)`` is
        invoked; the program and the op count of its block 0 are cached.
        Later calls return the cached values without calling the creator.
        """
        assert key in ('fp32', 'amp', 'fp16')
        if key in self.programs:
            return self.programs[key], self.op_size[key]

        program = prog_creator(is_infer_mode=True)
        self.programs[key] = program
        self.op_size[key] = program.desc.block(0).op_size()
        return program, self.op_size[key]


class PartialProgramLayer:
    """
    PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
    and executes them as a static subgraph.

    .. note::
        **1. This is a very low level API. Users should not use this API
             directly. Please use `partial_program_from(concrete_program)`
             to create it.
        **2. LoDTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains the ops to be executed.
        inputs(list[Variable]): The input list of the function decorated by `@to_static`.
        outputs(list[Variable]): The output list of the function decorated by `@to_static`.
        parameters(list[VarBase]|None): All trainable parameters included in the program. Default None.

    Returns:
        PartialProgramLayer: A callable object that runs all ops internally in static graph mode.
    """

    def __init__(
        self, main_program, inputs, outputs, parameters=None, **kwargs
    ):
        """
        Args:
            main_program(Program): program holding the forward ops.
            inputs: (possibly nested) inputs of the decorated function.
            outputs: (possibly nested) outputs of the decorated function.
            parameters(list|None): trainable parameters; None means none.
            **kwargs: only ``build_strategy`` (a BuildStrategy) is read.
        """
        super().__init__()
        self._inputs = NestSequence(inputs)
        self._outputs = NestSequence(outputs, need_check=True)
        # `_verify_program` below prunes `self._params`, so set it first.
        self._params = [] if parameters is None else parameters

        if 'build_strategy' in kwargs:
            self._build_strategy = kwargs['build_strategy']
        else:
            self._build_strategy = BuildStrategy()
        assert isinstance(self._build_strategy, BuildStrategy)

        self._origin_main_program = self._verify_program(main_program)
        self._cuda_graph_vec = self._create_cuda_graph_vec()
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0
        # Programs run in train mode unless switched otherwise.
        self.training = True
        self._infer_info = ProgramInfo()

        # Use the custom AMP op lists registered on the active dygraph
        # tracer, if there is one.
        tracer = framework._dygraph_tracer()
        if tracer:
            white_list, black_list = tracer._get_amp_op_list()
        else:
            white_list = black_list = None
        self._amp_list = paddle.static.amp.fp16_lists.AutoMixedPrecisionLists(
            custom_white_list=white_list,
            custom_black_list=black_list,
        )

        # program_id -> list of cached core.Scope objects.
        self._scope_cache = {}

    def __call__(self, inputs):
        """
        Run the static sub-graph via `run_program` and return the outputs as
        dynamic tensors, restored to the original nested structure with
        no-value placeholders removed.
        """
        in_vars, out_vars = self._prepare(inputs)
        self._cast_fp16_if_pure_fp16(in_vars)
        attrs = self._prepare_attributes()

        feed_vars = self._valid_vars(in_vars)
        param_vars = self._valid_vars(self._params)
        fetch_vars = self._valid_vars(out_vars)
        scopes = self._create_scope_vec(
            program_id=self.program_id, use_scope_cache=True
        )
        _legacy_C_ops.run_program(
            feed_vars,
            param_vars,
            fetch_vars,
            scopes,
            self._double_grads,
            self._cuda_graph_vec,
            *attrs
        )
        return self._remove_no_value(self._restore_out(out_vars))

    def _get_scope(self, program_id=None, use_scope_cache=False):
        """
        Return a core.Scope to execute the program in.

        Without ``use_scope_cache`` a fresh scope is returned every time.
        With it, scopes are cached per ``program_id``: the first cached scope
        whose ``_can_reuesd`` flag is truthy is returned; otherwise a new
        scope is created, added to that program's cache and returned.
        """
        if not use_scope_cache:
            return core.Scope()

        cached_scopes = self._scope_cache.get(program_id)
        if cached_scopes is not None:
            for candidate in cached_scopes:
                # NOTE: `_can_reuesd` is the attribute name exposed by core.
                if candidate._can_reuesd:
                    return candidate

        new_scope = core.Scope()
        if cached_scopes is None:
            self._scope_cache[program_id] = [new_scope]
        else:
            cached_scopes.append(new_scope)
        return new_scope

    @LazyInitialized
    def _double_grads(self):
        """Tensors for the `@GRAD` vars of the origin program (built once)."""
        origin_program = self._origin_main_program
        return self._get_double_grads(origin_program)

    # Whole (fp32) program creator.
    @switch_to_static_graph
    def _create_program(self, is_infer_mode=False):
        """
        Build the fp32 program.

        In infer mode this is a test-mode clone of the origin program.
        Otherwise backward ops are appended (via `_append_backward_desc`)
        and the parameters' grad types are set on the result.
        """
        if is_infer_mode:
            return self._origin_main_program.clone(for_test=is_infer_mode)

        program_with_grad = self._append_backward_desc(
            self._origin_main_program
        )
        # Note: Only set grad type once after initializing train program. So we put it here.
        self._set_grad_type(self._params, program_with_grad)
        return program_with_grad

    @switch_to_static_graph
    def _create_amp_program(self, is_infer_mode=False):
        """
        Build the AMP program: clone the origin program, rewrite it with
        `fp16_utils.rewrite_program` using ``self._amp_list``, and, unless in
        infer mode, append backward ops and set the parameters' grad types.
        """
        rewritten = self._origin_main_program.clone(for_test=is_infer_mode)
        with program_guard(rewritten):
            paddle.static.amp.fp16_utils.rewrite_program(
                rewritten, self._amp_list
            )
        if is_infer_mode:
            return rewritten

        rewritten_with_grad = self._append_backward_desc(rewritten)
        self._set_grad_type(self._params, rewritten_with_grad)
        return rewritten_with_grad

    @switch_to_static_graph
    def _create_pure_fp16_program(self, is_infer_mode=False):
        """
        Build the pure-fp16 program: clone the origin program, cast it with
        `fp16_utils.cast_model_to_fp16`, call `to_prim` on its blocks and,
        unless in infer mode, append backward ops and set grad types.
        """
        fp16_program = self._origin_main_program.clone(for_test=is_infer_mode)
        with program_guard(fp16_program):
            paddle.static.amp.fp16_utils.cast_model_to_fp16(
                fp16_program, self._amp_list, use_fp16_guard=False
            )

        core.check_and_set_prim_all_enabled()
        from paddle.incubate.autograd.primapi import to_prim

        to_prim(fp16_program.blocks)
        if is_infer_mode:
            return fp16_program

        fp16_program_with_grad = self._append_backward_desc(fp16_program)
        self._set_grad_type(self._params, fp16_program_with_grad)
        return fp16_program_with_grad

    @switch_to_static_graph
    def _create_forward_backward_train_program(self):
        """
        Split the fp32 train program into ``[forward, backward]`` programs,
        using the op count of the cached fp32 infer program as forward end.
        """
        train_program = self._train_program
        _, forward_end = self._infer_info('fp32', self._create_program)
        assert forward_end >= 0
        return self._get_forward_backward_program_form(
            train_program, forward_end
        )

    @switch_to_static_graph
    def _create_forward_backward_train_amp_program(self):
        """
        Split the AMP train program into ``[forward, backward]`` programs,
        using the op count of the cached AMP infer program as forward end.
        """
        amp_train_program = self._train_amp_program
        _, amp_forward_end = self._infer_info('amp', self._create_amp_program)
        assert amp_forward_end >= 0
        return self._get_forward_backward_program_form(
            amp_train_program, amp_forward_end
        )

    @switch_to_static_graph
    def _create_forward_backward_train_pure_fp16_program(self):
        """
        Split the pure-fp16 train program into ``[forward, backward]``
        programs, using the op count of the cached fp16 infer program as
        forward end.
        """
        fp16_train_program = self._train_pure_fp16_program
        _, fp16_forward_end = self._infer_info(
            'fp16', self._create_pure_fp16_program
        )
        assert fp16_forward_end >= 0
        return self._get_forward_backward_program_form(
            fp16_train_program, fp16_forward_end
        )

    @LazyInitialized
    def _train_program(self):
        """fp32 program with backward ops appended (built on first access)."""
        return self._create_program(is_infer_mode=False)

    @LazyInitialized
    def _infer_program(self):
        """fp32 infer program with build strategy and inplace pass applied."""
        fp32_infer, fp32_op_size = self._infer_info(
            'fp32', self._create_program
        )
        return self._build_infer_program(fp32_infer, fp32_op_size)

    @LazyInitialized
    def _train_amp_program(self):
        """AMP program with backward ops appended (built on first access)."""
        return self._create_amp_program(is_infer_mode=False)

    @LazyInitialized
    def _infer_amp_program(self):
        """AMP infer program with build strategy and inplace pass applied."""
        amp_infer, amp_op_size = self._infer_info(
            'amp', self._create_amp_program
        )
        return self._build_infer_program(amp_infer, amp_op_size)

    @LazyInitialized
    def _train_pure_fp16_program(self):
        """Pure-fp16 program with backward ops (built on first access)."""
        return self._create_pure_fp16_program(is_infer_mode=False)

    @LazyInitialized
    def _infer_pure_fp16_program(self):
        """Pure-fp16 infer program with build strategy and inplace pass."""
        fp16_infer, fp16_op_size = self._infer_info(
            'fp16', self._create_pure_fp16_program
        )
        return self._build_infer_program(fp16_infer, fp16_op_size)

    @LazyInitialized
    def _train_forward_backward_program(self):
        """``[forward, backward]`` programs split from the fp32 train program."""
        return self._create_forward_backward_train_program()

    @LazyInitialized
    def _train_amp_forward_backward_program(self):
        """``[forward, backward]`` programs split from the AMP train program."""
        return self._create_forward_backward_train_amp_program()

    @LazyInitialized
    def _empty_backward_program_for_eval(self):
        """
        Empty Program used as the backward program in eval mode. It is cached
        on the instance, so the same object stays alive across calls.
        """
        empty_program = paddle.static.Program()
        return empty_program

    @LazyInitialized
    def _train_pure_fp16_forward_backward_program(self):
        """``[forward, backward]`` programs split from the pure-fp16 train program."""
        return self._create_forward_backward_train_pure_fp16_program()

    @LazyInitialized
    def _train_program_id(self):
        """
        Hash id of the fp32 train program (computed once). As a side effect,
        the instance's build strategy is registered for that id.
        """
        train_id = paddle.utils._hash_with_id(self._train_program, self)
        core._set_cached_executor_build_strategy(
            train_id, self._build_strategy
        )
        return train_id

    @LazyInitialized
    def _infer_program_id(self):
        """Hash id of the fp32 infer program (computed once)."""
        infer_id = paddle.utils._hash_with_id(self._infer_program, self)
        return infer_id

    @LazyInitialized
    def _train_amp_program_id(self):
        """
        Hash id of the AMP train program (computed once). As a side effect,
        the instance's build strategy is registered for that id.
        """
        amp_train_id = paddle.utils._hash_with_id(self._train_amp_program, self)
        core._set_cached_executor_build_strategy(
            amp_train_id, self._build_strategy
        )
        return amp_train_id

    @LazyInitialized
    def _infer_amp_program_id(self):
        """Hash id of the AMP infer program (computed once)."""
        amp_infer_id = paddle.utils._hash_with_id(self._infer_amp_program, self)
        return amp_infer_id

    @LazyInitialized
    def _train_pure_fp16_program_id(self):
        """
        Hash id of the pure-fp16 train program (computed once). As a side
        effect, the instance's build strategy is registered for that id.
        """
        fp16_train_program = self._train_pure_fp16_program
        fp16_train_id = paddle.utils._hash_with_id(fp16_train_program, self)
        core._set_cached_executor_build_strategy(
            fp16_train_id, self._build_strategy
        )
        return fp16_train_id

    @LazyInitialized
    def _infer_pure_fp16_program_id(self):
        """Hash id of the pure-fp16 infer program (computed once)."""
        fp16_infer_program = self._infer_pure_fp16_program
        return paddle.utils._hash_with_id(fp16_infer_program, self)

    @LazyInitialized
    def _param_grad_names(self):
        """
        Gradient variable names of ``self._params`` as found in the fp32
        train program (computed once by the module-level helper).
        """
        train_desc = self._train_program.desc
        return _param_grad_names(train_desc, self._params)

    @LazyInitialized
    def _out_grad_names(self):
        """
        Gradient variable names of the tensor outputs, as found in the fp32
        train program (computed once).

        The forward op count comes from the cached fp32 infer program in
        ``self._infer_info`` (the same source `_create_forward_backward_train_program`
        uses) rather than cloning a brand-new infer program just to count ops.
        """
        train_desc = self._train_program.desc
        _, forward_op_size = self._infer_info('fp32', self._create_program)
        return _out_grad_names(
            train_desc,
            forward_op_size,
            len(self._outputs.var_ids),
        )

    @property
    def program(self):
        """
        Return the train program in training mode, else the infer program.
        """
        return self.train_program if self.training else self.infer_program

    @property
    def program_id(self):
        """
        Return the hash id of the program matching the current mode
        (train/eval) and AMP state (AMP, pure fp16 or plain fp32).
        """
        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

        training = self.training
        if _in_amp_guard():
            return (
                self._train_amp_program_id
                if training
                else self._infer_amp_program_id
            )
        if _in_pure_fp16_guard():
            return (
                self._train_pure_fp16_program_id
                if training
                else self._infer_pure_fp16_program_id
            )
        return self._train_program_id if training else self._infer_program_id

    @property
    def train_program(self):
        """
        Return the train program for the active AMP state: the AMP program
        under an AMP guard, the pure-fp16 program under a pure-fp16 guard,
        otherwise the fp32 program.
        """
        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

        if _in_amp_guard():
            return self._train_amp_program
        if _in_pure_fp16_guard():
            return self._train_pure_fp16_program
        return self._train_program

    @property
    def infer_program(self):
        """
        Return the infer program for the active AMP state: the AMP program
        under an AMP guard, the pure-fp16 program under a pure-fp16 guard,
        otherwise the fp32 program.
        """
        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

        if _in_amp_guard():
            return self._infer_amp_program
        if _in_pure_fp16_guard():
            return self._infer_pure_fp16_program
        return self._infer_program

    @property
    def forward_program(self):
        """
        Return the forward program for the current mode: in training mode,
        the first element of the ``[forward, backward]`` pair matching the
        AMP state; in eval mode, the infer program.
        """
        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

        if not self.training:
            return self.infer_program

        if _in_amp_guard():
            forward_backward = self._train_amp_forward_backward_program
        elif _in_pure_fp16_guard():
            forward_backward = self._train_pure_fp16_forward_backward_program
        else:
            forward_backward = self._train_forward_backward_program
        return forward_backward[0]

    @property
    def backward_program(self):
        """
        Return the backward program for the current mode.

        In training mode this is the second element of the cached
        ``[forward, backward]`` pair matching the AMP state (AMP, pure fp16,
        or fp32). In eval mode an empty Program cached on the instance is
        returned.
        """
        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

        if self.training:
            if _in_amp_guard():
                progs = self._train_amp_forward_backward_program
            elif _in_pure_fp16_guard():
                progs = self._train_pure_fp16_forward_backward_program
            else:
                progs = self._train_forward_backward_program
            return progs[1]
        # Can't just return paddle.static.Program() here: because
        # backward_program is a property, a temporary Program() would be
        # created on every access and garbage-collected immediately after
        # `self.backward_program.desc.block(0)` is evaluated (see
        # `_prepare_attributes`), so RunProgramAPI could receive an invalid
        # backward_program address. Return a Program cached on the instance.
        return self._empty_backward_program_for_eval

    def _verify_program(self, main_program):
        """
        Check that the parameters of ``main_program`` are initialized and
        prune the parameters that no op in it uses from ``self._params``.

        Returns:
            Program: ``main_program`` itself (the same object).
        """
        # 1. Check all params from main program can be found in self._params
        self._check_params_all_inited(main_program)
        # 2. Prune the parameters not used anywhere in the program.
        self._prune_unused_params(main_program)

        return main_program

    def prepare_gradient_aggregation(
        self, start_idx, main_program, target_program
    ):
        """
        Insert `sum` ops into ``target_program`` so that the gradient of an
        output that is also consumed inside the program is accumulated
        instead of overwritten.

        Why do we need gradient aggregation? If a non-leaf node is used as an
        output, its gradient can be overwritten, e.g.

            def forward(self, inp):
                x = 2 * inp  # <---- x is a non-leaf node in program.
                y = x + 3
                return x, y

            loss = forward(inp)[0].sum()
            loss.backward()  # <----- x@grad would be overwritten by elementwise_add_grad Op

        Args:
            start_idx(int): only ops of ``target_program`` block 0 at index
                ``>= start_idx`` are searched for writers of an output grad.
            main_program(Program): block 0 of it is scanned to decide whether
                an output is used as an op input.
            target_program(Program): program modified in place.
        """

        def _need_aggregation(var):
            """
            Return True if ``var`` is a float32/float64 LOD_TENSOR or
            SELECTED_ROWS Variable that is an input of some op in block 0
            of ``main_program``.
            """
            if not isinstance(var, framework.Variable) or var.type not in [
                core.VarDesc.VarType.LOD_TENSOR,
                core.VarDesc.VarType.SELECTED_ROWS,
            ]:
                return False
            if var.dtype not in [paddle.float32, paddle.float64]:
                return False
            return any(
                var.name in op.input_arg_names
                for op in main_program.block(0).ops
            )

        def _insert_aggregation_ops_for_var(target_program, var):
            suffix = "@dy2static"
            var_grad_name = var.grad_name
            new_grad_name = var.name + suffix + "@GRAD"
            # (index, op) pairs at or after `start_idx` that write var@GRAD.
            found_ops = [
                (idx, op)
                for idx, op in enumerate(target_program.block(0).ops)
                if idx >= start_idx
                and any(
                    out_arg == var_grad_name
                    for out_arg in op.output_arg_names
                )
            ]

            # len(found_ops) may be zero when stop_gradient works.
            # len(found_ops) may be > 1, because we may have a fill_constant op.
            if not found_ops:
                return None
            # step1: create a new var named var.name@dy2static@GRAD
            target_program.block(0).create_var(
                name=new_grad_name,
                type=var.type,
                dtype=var.dtype,
                shape=var.shape,
            )
            # step2: in those ops, rename var.name@GRAD to var.name@dy2static@GRAD
            for _, op in found_ops:
                op._rename_input(var_grad_name, new_grad_name)
                op._rename_output(var_grad_name, new_grad_name)
            # step3: right after the last of those ops, insert a sum op to
            #        aggregate the gradient:
            #        var.name@GRAD = sum(var.name@GRAD, var.name@dy2static@GRAD)
            target_program.block(0)._insert_op(
                found_ops[-1][0] + 1,
                type='sum',
                inputs={'X': [var_grad_name, new_grad_name]},
                outputs={"Out": var_grad_name},
            )
            return None

        # Decide for every output before any op is inserted.
        to_processed_vars = list(
            filter(_need_aggregation, self._outputs.tolist())
        )
        for _var in to_processed_vars:
            _insert_aggregation_ops_for_var(target_program, _var)

    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        """
        Return a train-mode clone of ``main_program`` with backward ops
        appended for every static Variable output, plus gradient-aggregation
        `sum` ops (see `prepare_gradient_aggregation`).
        """
        # make sure all status of is_test are False in train mode.
        program = _change_is_test_status(main_program.clone(), is_test=False)
        targets = [
            program.global_block().var(out.name)
            for out in self._outputs.tolist()
            if isinstance(out, framework.Variable)
        ]

        if targets:
            # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
            core.check_and_set_prim_all_enabled()
            backward.gradients(targets=targets, inputs=[])

        # Backward ops are searched starting after the forward ops plus one
        # op per gradient target. Only Variable outputs are targets, so count
        # `targets`, not every (possibly non-tensor) output: counting
        # non-tensor outputs would skip real backward ops. This matches the
        # `len(self._outputs.var_ids)` offset in
        # `_get_forward_backward_program_form`.
        # NOTE(review): assumes `backward.gradients` inserts exactly one
        # gradient-initializing op per target — confirm.
        start_idx = len(main_program.block(0).ops) + len(targets)

        self.prepare_gradient_aggregation(start_idx, main_program, program)

        return program

    def _prune_unused_params(self, program):
        """
        Keep in ``self._params`` only the parameters that appear as an input
        or output of at least one op in any block of ``program``.

        `@to_static` may decorate only a sub function, so some parameters
        created in `__init__` can be unused; pruning them avoids unnecessary
        operations in `run_program_op`.
        """

        def _is_used(param):
            return any(
                param.name in op.input_arg_names
                or param.name in op.output_arg_names
                for block in program.blocks
                for op in block.ops
            )

        self._params = [param for param in self._params if _is_used(param)]

    def _get_double_grads(self, program):
        """
        Create a tensor for every variable whose name contains "@GRAD" in
        any block of ``program`` and return them passed through
        `_valid_vars`.

        The tensors are `core.eager.Tensor` in eager mode, `core.VarBase`
        otherwise; each is built from the variable desc's dtype, shape,
        name and type.
        """
        grad_tensors = []
        for block in program.blocks:
            for var_name in block.vars:
                if "@GRAD" not in var_name:
                    continue
                desc = block.vars[var_name].desc
                tensor_cls = (
                    core.eager.Tensor
                    if framework.global_var._in_eager_mode_
                    else core.VarBase
                )
                grad_tensors.append(
                    tensor_cls(
                        desc.dtype(),
                        desc.shape(),
                        desc.name(),
                        desc.type(),
                        False,
                    )
                )
        return self._valid_vars(grad_tensors)

    def _cast_fp16_if_pure_fp16(self, in_vars):
        """
        Under a pure-fp16 guard, replace (in place in ``in_vars``) every input
        whose same-named variable in the current program's global block is
        float16 with a float16 cast of it that keeps the original name.
        """
        from paddle.amp.auto_cast import _in_pure_fp16_guard

        if not _in_pure_fp16_guard() or not in_vars:
            return
        # Loop-invariant: resolve the current program's global block once
        # instead of twice per input.
        global_block = self.program.global_block()
        for i, var in enumerate(in_vars):
            name = var.name
            if (
                global_block.has_var(name)
                and global_block.var(name).dtype == paddle.float16
            ):
                in_vars[i] = var.astype('float16')
                in_vars[i].name = name

    def _prepare_attributes(self):
        """
        Build the flat ``[name, value, name, value, ...]`` attribute list
        passed to `run_program`.
        """
        named_attrs = [
            ('forward_global_block', self.forward_program.desc.block(0)),
            ('backward_global_block', self.backward_program.desc.block(0)),
            ('is_test', not self.training),
            ('program_id', self.program_id),
        ]
        if self.training:
            # NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
            # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
            # the correct names of the parameter grads from program. And out grads are similar to above.
            named_attrs.append(('param_grad_names', self._param_grad_names))
            named_attrs.append(('out_grad_names', self._out_grad_names))
        if self._cuda_graph_capture_mode:
            named_attrs.append(
                ('cuda_graph_capture_mode', self._cuda_graph_capture_mode)
            )
            named_attrs.append(('cuda_graph_pool_id', self._cuda_graph_pool_id))
        return [entry for pair in named_attrs for entry in pair]

    @switch_to_static_graph
    def _build_infer_program(self, infer_program, forward_end_op_index):
        """
        Apply ``self._build_strategy`` to ops ``[0, forward_end_op_index)`` of
        ``infer_program``, then run the inplace pass on the result.
        """
        skip_gc_vars = self._parse_skip_gc_vars(infer_program)
        built_program = add_build_strategy_for(
            infer_program,
            0,
            forward_end_op_index,
            self._build_strategy,
            skip_gc_vars,
        )
        self._apply_inplace_pass(built_program, None)
        return built_program

    @switch_to_static_graph
    def _get_forward_backward_program_form(
        self, whole_program, forward_end_op_index
    ):
        """
        Split ``whole_program`` and build both halves with
        ``self._build_strategy``.

        The forward half covers block-0 ops ``[0, forward_end_op_index)``; the
        backward half covers ``[forward_end_op_index + len(self._outputs.var_ids),
        op_size)``.

        Returns:
            list: ``[forward_program, backward_program]``.
        """
        # NOTE(dev): build the backward program first, because the forward
        # skip-GC list is derived from it; this avoids skipping more GC
        # variables than needed.
        backward_begin = forward_end_op_index + len(self._outputs.var_ids)
        backward_stop = whole_program.desc.block(0).op_size()
        # For the backward process in CINN, every param@GRAD should be skipped
        # for GC, because they are shared in scope and used by the optimizer.
        backward_gc_skip = (
            self._parse_skip_gc_vars(whole_program) + self._param_grad_names
        )
        backward_part = add_build_strategy_for(
            whole_program,
            backward_begin,
            backward_stop,
            self._build_strategy,
            backward_gc_skip,
        )

        forward_gc_skip = self._parse_skip_gc_vars(whole_program, backward_part)
        forward_part = add_build_strategy_for(
            whole_program,
            0,
            forward_end_op_index,
            self._build_strategy,
            forward_gc_skip,
        )

        self._apply_inplace_pass(forward_part, backward_part)
        return [forward_part, backward_part]

    def _apply_inplace_pass(self, forward_program, backward_program):
        """
        Run `buffer_shared_inplace_pass` on ``forward_program`` and on
        ``backward_program``, each only if it is truthy.

        The forward skip list also keeps the variables that
        ``backward_program`` needs. The backward skip list is derived from
        ``forward_program`` (its in/out and data variables).
        """
        pass_attr_types = {
            "use_cuda": "bool",
            "mem_opt_skip_vars": "list[str]",
            "for_partial_block": "bool",
        }
        empty_startup_program = paddle.static.Program()
        use_cuda = bool(core.is_compiled_with_cuda())
        # skip data var
        forward_skip_vars = self._parse_skip_gc_vars(
            forward_program, backward_program
        )
        backward_skip_vars = self._parse_skip_gc_vars(forward_program)

        for target_program, skip_vars in (
            (forward_program, forward_skip_vars),
            (backward_program, backward_skip_vars),
        ):
            if not target_program:
                continue
            _apply_pass(
                target_program,
                empty_startup_program,
                "buffer_shared_inplace_pass",
                {
                    "use_cuda": use_cuda,
                    "mem_opt_skip_vars": skip_vars,
                    "for_partial_block": True,
                },
                pass_attr_types,
            )

    @LazyInitialized
    def _inout_var_names(self):
        """
        Collect the desc names of the static Variables in this layer's
        inputs and outputs.

        Returns:
            list[str]: names from ``self._inputs`` first, then names from
            ``self._outputs``. Entries that are not
            ``paddle.fluid.framework.Variable`` are ignored.
        """
        return [
            item.desc.name()
            for sequence in (self._inputs, self._outputs)
            for item in sequence
            if isinstance(item, paddle.fluid.framework.Variable)
        ]

    def _parse_skip_gc_vars(self, program, backward_program=None):
        """
        Parse variables that need to skip GC after execute it.
        If specify backward_program, it will keep the variables used in backward.
        """
        # skip data var, DO NOT ignore this deepcopy
        skip_vars = deepcopy(self._inout_var_names)
        for var_name, var in program.global_block().vars.items():
            if var.is_data:
                skip_vars.append(var_name)

        if backward_program:
            for var_name in core.parse_safe_eager_deletion_skip_vars(
847
                backward_program.desc, True
848 849 850 851
            ):
                skip_vars.append(var_name)
        return skip_vars

852 853 854 855 856
    def _prepare(self, inputs):
        """
        Prepare inputs, outputs, attrs.

        Args:
            inputs (tuple|list): possibly nested inputs of the layer. After
                flattening, ``np.ndarray`` leaves are wrapped into tensors,
                ``core.VarBase`` / ``core.eager.Tensor`` leaves are reused (or
                copied to the expected place), and any other leaf is skipped.

        Returns:
            tuple: ``(input_vars, out_vars)``. ``input_vars`` holds the
            converted inputs, each renamed to the desc name of the
            corresponding ``self._inputs`` entry. ``out_vars`` holds newly
            created tensors that receive the outputs, one per entry of
            ``self._outputs.var_ids``; outputs sharing a name share a tensor.
        """
        assert isinstance(inputs, (tuple, list))
        # Flatten inputs with nested structure into single list.
        flatten_inputs = paddle.utils.flatten(inputs)
        # Legacy dygraph builds core.VarBase, eager mode builds
        # core.eager.Tensor; both accept the same constructor arguments below.
        tensor_cls = (
            core.eager.Tensor
            if framework.global_var._in_eager_mode_
            else core.VarBase
        )

        # Convert variable into VarBase and feed in training data.
        input_vars = []
        expected_place = framework._current_expected_place()
        for i, value in enumerate(flatten_inputs):
            if isinstance(value, np.ndarray):
                var = tensor_cls(
                    value=value,
                    name=self._inputs[i].desc.name(),
                    persistable=False,
                    place=expected_place,
                    zero_copy=True,
                )
            elif isinstance(value, (core.VarBase, core.eager.Tensor)):
                # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
                # into CUDAPlace when it's as input of multi Ops. so we move it in advance
                # to avoid this problem.
                if value.stop_gradient and not value.place._equals(
                    expected_place
                ):
                    var = value._copy_to(expected_place, False)
                    var.stop_gradient = True
                else:
                    var = value
                var.name = self._inputs[i].desc.name()
            else:
                continue
            input_vars.append(var)

        # mapping from name(string) -> VarBase
        out_varbase_map = {}

        def create_out(var_id):
            var = self._outputs[var_id]
            assert isinstance(var, framework.Variable)
            var_desc = var.desc
            # Outputs with the same name share a single receiving tensor.
            if var_desc.name() in out_varbase_map:
                return out_varbase_map[var_desc.name()]

            var_base = tensor_cls(
                var_desc.dtype(),
                var_desc.shape(),
                var_desc.name(),
                var_desc.type(),
                False,
            )
            var_base.stop_gradient = var.stop_gradient
            out_varbase_map[var_desc.name()] = var_base
            return var_base

        # Create VarBase to receive output data.
        out_vars = list(map(create_out, self._outputs.var_ids))

        return input_vars, out_vars

    def _create_scope_vec(self, program_id=None, use_scope_cache=False):
        """
        Wrap the inner scope returned by ``self._get_scope`` for the run op.

        Args:
            program_id: forwarded to ``self._get_scope``.
            use_scope_cache (bool): forwarded to ``self._get_scope``.

        Returns:
            In eager mode, a one-element list holding the inner scope.
            Otherwise, a ``core.VarBase`` named ``"program_out_scope"`` of
            type ``STEP_SCOPES`` whose value is bound to the inner scope.
        """
        # Hold forward variables
        inner_scope = self._get_scope(
            program_id=program_id, use_scope_cache=use_scope_cache
        )
        if framework.global_var._in_eager_mode_:
            return [inner_scope]

        tmp_scope_vec = core.VarBase(
            core.VarDesc.VarType.FP32,
            [],
            "program_out_scope",
            core.VarDesc.VarType.STEP_SCOPES,
            True,
        )
        tmp_scope_vec.value().set_scope(inner_scope)
        return tmp_scope_vec

    def _create_cuda_graph_vec(self):
        """
        Create a ``core.VarBase`` named ``"cuda_graph"`` of type ``RAW``.

        NOTE(review): presumably a placeholder that carries CUDA graph state
        into the run-program op; confirm against the caller.

        Returns:
            core.VarBase: the new variable, with ``stop_gradient=True``.
        """
        var = core.VarBase(
            core.VarDesc.VarType.FP32,
            [],
            "cuda_graph",
            core.VarDesc.VarType.RAW,
            True,
        )
        var.stop_gradient = True
        return var

    def _restore_out(self, out_vars):
        """
        Restores same nested outputs by only replacing the Variable with VarBase.
        """

        flatten_outputs = self._outputs.tolist()
        for i, idx in enumerate(self._outputs.var_ids):
            flatten_outputs[idx] = out_vars[i]
        outs = self._outputs.restore(flatten_outputs)
973
        if outs is not None and len(outs) == 1:
974 975 976 977
            outs = outs[0]

        return outs

978 979 980 981
    @switch_to_static_graph
    def _clone_for_test(self, main_program):
        """
        Return ``main_program.clone(for_test=True)``.

        The ``switch_to_static_graph`` decorator presumably runs the clone in
        static-graph mode; confirm with its definition.
        """
        return main_program.clone(for_test=True)

    def _is_no_value(self, var):
        """
        Check whether ``var`` is the "no return value" placeholder.

        Args:
            var: any object.

        Returns:
            bool: True only if ``var`` is a ``core.VarBase`` /
            ``core.eager.Tensor`` of shape ``[1]`` whose single element
            equals ``RETURN_NO_VALUE_MAGIC_NUM``.
        """
        if not isinstance(var, (core.VarBase, core.eager.Tensor)):
            return False
        if var.shape != [1]:
            return False
        # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
        return bool(var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM)

    def _remove_no_value(self, out_vars):
        """
        Removes invalid value for various-length return statement.

        Args:
            out_vars: a single VarBase/eager Tensor, a tuple or list of
                outputs, or any other object.

        Returns:
            - single tensor: ``None`` if it is the no-value placeholder,
              otherwise the tensor.
            - tuple/list: a new tuple/list without placeholders. If something
              was removed, ``None`` is returned when nothing is left, and the
              element itself when exactly one is left.
            - anything else: ``out_vars`` unchanged.
        """
        if isinstance(out_vars, (core.VarBase, core.eager.Tensor)):
            return None if self._is_no_value(out_vars) else out_vars

        if not isinstance(out_vars, (tuple, list)):
            return out_vars

        kept = [var for var in out_vars if not self._is_no_value(var)]
        res = tuple(kept) if isinstance(out_vars, tuple) else kept

        # len(out_vars) > len(res) means we have removed var. This is
        # preventing out_vars is empty or just one element at the beginning
        has_removed = len(out_vars) > len(res)
        if has_removed and not res:
            return None
        if has_removed and len(res) == 1:
            return res[0]
        return res

    def _set_grad_type(self, params, train_program):
        """
        Set each parameter's gradient type from ``train_program``.

        For every param, the variable ``param.name + core.grad_var_suffix()``
        is looked up in block 0 of ``train_program.desc``. If it is found,
        ``param._set_grad_type`` is called with its type.

        Args:
            params: iterable of parameters (each needs ``name`` and
                ``_set_grad_type``).
            train_program: program whose block 0 declares the grad vars.
        """
        # NOTE: if user set sparse gradient mode, the param's gradient
        # will be SelectedRows, not LoDTensor. But tracer will just
        # set param grad VarBase by forward VarBase(LoDTensor)
        # If we don't change grad_var type here, RunProgramOp need
        # transform SelectedRows to LoDTensor forcibly, it may not
        # be user wanted result.
        grad_suffix = core.grad_var_suffix()
        for param in params:
            grad_name = param.name + grad_suffix
            grad_var = train_program.desc.block(0).find_var(grad_name.encode())
            # NOTE: cannot find var desc maybe no problem, such as in batch_norm
            if grad_var is None:
                continue
            param._set_grad_type(grad_var.type())

    def _remove_op_call_stack(self, main_program):
        """
        Remove op's python call stack with redundant low-level error messages related to
        transformations to avoid confusing users.

        Args:
            main_program (framework.Program): program whose ops are modified
                in place; the ``op_callstack`` attribute is removed from every
                op that has it, in every block.

        Returns:
            framework.Program: the same ``main_program`` object.
        """
        assert isinstance(main_program, framework.Program)
        for block in main_program.blocks:
            for op in block.ops:
                if op.has_attr("op_callstack"):
                    op._remove_attr("op_callstack")

        return main_program

    def _check_params_all_inited(self, main_program):
        """
        Check all params from main program are already initialized, see details as follows:
            1. ``self._params`` must be a list or tuple whose entries are all
               ``core.VarBase`` / ``core.eager.Tensor`` (created in dygraph).
            2. every ``framework.Parameter`` declared in any block of
               ``main_program`` must have its name among ``self._params``,
               because they share the same data with the dygraph tensors.

        Raises:
            TypeError: if ``self._params`` is not a list/tuple, or one of its
                entries has an unexpected type.
            ValueError: if ``main_program`` declares a Parameter that is not
                in ``self._params`` (i.e. it was created inside the decorated
                function).
        """
        if not isinstance(self._params, (list, tuple)):
            raise TypeError(
                "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
                % type(self._params)
            )

        param_and_buffer_names_set = set()
        for i, var in enumerate(self._params):
            # self._params contains parameters and buffers with persistable=True.
            if not isinstance(var, (core.VarBase, core.eager.Tensor)):
                raise TypeError(
                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.format(
                        i, type(var)
                    )
                )
            param_and_buffer_names_set.add(var.name)

        for block in main_program.blocks:
            for name, var in block.vars.items():
                if (
                    isinstance(var, framework.Parameter)
                    and name not in param_and_buffer_names_set
                ):
                    raise ValueError(
                        "\n\tWe don't support to define layer with parameters in the function decorated by `@to_static`."
                        "\n\tBut we found parameter(%s) was created in the decorated function."
                        "\n"
                        "\n\tRevise suggestion: "
                        "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
                        "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
                        % name
                    )

    def _valid_vars(self, vars):
1086
        return vars if vars else None
1087

1088 1089 1090 1091 1092 1093

def partial_program_from(concrete_program):
    """
    Build a ``PartialProgramLayer`` from a concrete program.

    If the first input is a ``layers.Layer`` instance, it is dropped from the
    inputs. It is presumably the bound ``self`` of a decorated method; verify
    against the caller.

    Args:
        concrete_program: object providing ``main_program``, ``inputs``,
            ``outputs``, ``parameters`` and ``kwargs``.

    Returns:
        PartialProgramLayer: the constructed layer; ``concrete_program.kwargs``
        is forwarded as keyword arguments.
    """
    inputs = concrete_program.inputs
    if inputs and isinstance(inputs[0], layers.Layer):
        inputs = inputs[1:]

    return PartialProgramLayer(
        concrete_program.main_program,
        inputs,
        concrete_program.outputs,
        concrete_program.parameters,
        **concrete_program.kwargs
    )


@switch_to_static_graph
def add_build_strategy_for(
    program, start_op_index, end_op_index, build_strategy=None, skip_vars=None
):
    """
    Compile an op range of ``program`` with ``build_strategy``.

    The graph built from ``program.desc`` between ``start_op_index`` and
    ``end_op_index`` is compiled and converted back into a Program.

    Args:
        program: the source program.
        start_op_index (int): start of the op range.
        end_op_index (int): end of the op range. When it is not greater than
            ``start_op_index``, nothing is compiled.
        build_strategy: optional ``BuildStrategy`` passed to
            ``paddle.static.CompiledProgram``.
        skip_vars: optional iterable of variable names stored on the graph
            as ``"skip_gc_vars"`` (only when non-empty).

    Returns:
        The program converted from the compiled graph. Any ``lr_sheduler``
        attribute of the compiled program's source is copied over. If the op
        range is empty, ``program`` itself is returned.
    """
    if start_op_index >= end_op_index:
        return program

    compiled_program = paddle.static.CompiledProgram(
        core.Graph(program.desc, start_op_index, end_op_index),
        build_strategy=build_strategy,
    )
    if skip_vars:
        # TODO(Aurelius84): Need to unify name with C++, such as kSkipVarNames.
        compiled_program._graph.set("skip_gc_vars", set(skip_vars))
    compiled_program._compile(
        core.Scope(), framework._current_expected_place()
    )
    ir_graph = framework.IrGraph(compiled_program._graph)
    builded_program = ir_graph.to_program()
    # Keep the learning-rate scheduler attached to the source program.
    if hasattr(compiled_program._program, 'lr_sheduler'):
        builded_program.lr_sheduler = compiled_program._program.lr_sheduler
    return builded_program