partial_program.py 39.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from copy import deepcopy

17
import numpy as np
18

19
import paddle
20
from paddle import _legacy_C_ops
21
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
22
from paddle.fluid import backward, core, framework, program_guard
23
from paddle.fluid.compiler import BuildStrategy
24
from paddle.fluid.data_feeder import convert_dtype
25 26
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
27
from paddle.optimizer.lr import LRScheduler
28 29

from . import logging_utils
30 31 32 33 34 35
from .utils import (
    RETURN_NO_VALUE_MAGIC_NUM,
    _out_grad_names,
    _param_grad_names,
    backend_guard,
)
36

37 38
__all__ = []

39

40
class NestSequence:
    """
    A wrapper that makes it easy to flatten a (possibly nested) sequence and
    to restore the original nesting from a flat list of values.

    Args:
        raw_input: The (possibly nested) structure to wrap.
        need_check(bool): If True, warn when the flattened structure contains
            values that are not ``framework.Variable`` / ``core.eager.Tensor``.
    """

    def __init__(self, raw_input, need_check=False):
        self.__raw_input = raw_input
        self.__input_list = self.tolist()
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    def tolist(self):
        """
        Flattens the nested sequence into a single list.
        """
        return paddle.utils.flatten(self.__raw_input)

    def restore(self, value_list):
        """
        Restores the nested sequence from a flat value list.

        ``value_list`` must have exactly as many elements as the flattened
        structure captured at construction time.
        """
        assert len(self.__input_list) == len(value_list)
        return paddle.utils.pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        # Positions (in the flattened list) of the tensor-like entries.
        return [
            idx
            for idx, var in enumerate(self.__input_list)
            if isinstance(var, (framework.Variable, core.eager.Tensor))
        ]

    def _check_non_variable(self, need_check):
        """
        Raises warning if output of traced function contains non-tensor type values.
        """
        if not need_check:
            return
        # A dict keeps first-seen order, so the warning text is deterministic
        # (a set would list the types in arbitrary order).
        warning_types = {}
        for var in self.__input_list:
            if not isinstance(var, (framework.Variable, core.eager.Tensor)):
                warning_types.setdefault(type(var), None)
        if warning_types:
            logging_utils.warn(
                "Output of traced function contains non-tensor type values: {}. "
                "Currently, We don't support to update them while training and will return "
                "what we first saw. Please try to return them as tensor.".format(
                    list(warning_types)
                )
            )

    @property
    def var_ids(self):
        """Indices of the tensor-like entries in the flattened list."""
        return self.__var_ids

    def __getitem__(self, item):
        return self.__input_list[item]


class LazyInitialized:
    """
    Descriptor to implement lazy initialization of a property.

    On first access through an instance, the wrapped function is called once
    and its result is stored as an instance attribute with the same name.
    Because this is a non-data descriptor (it only defines ``__get__``), that
    instance attribute shadows the descriptor, so later accesses read the
    cached value directly.
    """

    def __init__(self, function):
        self.function = function
        # Expose the wrapped function's docstring on the descriptor.
        self.__doc__ = getattr(function, '__doc__', None)

    def __get__(self, instance, cls):
        if instance is None:
            # Accessed on the class itself (e.g. by introspection tools):
            # return the descriptor instead of calling the function with None.
            return self
        val = self.function(instance)
        setattr(instance, self.function.__name__, val)
        return val


class ProgramInfo:
    """
    A helper class to record Program information.

    Caches one inference program per precision key ('fp32', 'amp', 'fp16')
    together with the op count of its block 0.
    """

    # Precision keys accepted by ``__call__``.
    _KEYS = ('fp32', 'amp', 'fp16')

    def __init__(self):
        # -1 marks "program for this key not created yet".
        self.op_size = {key: -1 for key in self._KEYS}
        self.programs = {}
        self.mode = "infer"

    def __call__(self, key, prog_creator):
        """
        Record the infer program and its op size for ``key``.

        Args:
            key(str): One of 'fp32', 'amp' or 'fp16'.
            prog_creator(callable): Called as ``prog_creator(is_infer_mode=True)``
                to build the program the first time ``key`` is requested.

        Returns:
            tuple: ``(program, op_size)`` for ``key``. The program is created
            only once and reused on later calls.
        """
        assert key in self._KEYS, "Unsupported program key: {}".format(key)
        if key not in self.programs:
            infer_prog = prog_creator(is_infer_mode=True)
            self.programs[key] = infer_prog
            self.op_size[key] = infer_prog.desc.block(0).op_size()

        return self.programs[key], self.op_size[key]


class PartialProgramLayerHook:
    """
    Base class for hooks that customize the programs built by
    ``PartialProgramLayer``.

    ``PartialProgramLayer`` assigns the return value of every hook back to
    the program(s) it is building. Each default implementation therefore
    returns its inputs unchanged, and subclasses override only the stages
    they need.
    """

    def before_append_backward(self, forward_program):
        """Called before backward ops are appended; returns the program to use."""
        return forward_program

    def after_append_backward(self, whole_program, backward_start_idx):
        """
        Called after backward ops are appended.

        Returns:
            tuple: ``(whole_program, backward_start_idx)`` to continue with.
        """
        return whole_program, backward_start_idx

    def after_infer(self, infer_program):
        """Called after an inference program is built; returns the program to use."""
        return infer_program


class PartialProgramLayer:
    """
    PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
    and executes them as a static subgraph.

    .. note::
        **1. This is a very low level API. Users should not use this API
             directly. Please use `partial_program_from(concrete_program)`
             to create it.
        **2. LoDTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains the ops to be executed.
        inputs(list[Variable]): The input list of the function decorated by `@to_static`.
        outputs(list[Variable]): The output list of the function decorated by `@to_static`.
        parameters(list[Tensor]|None): All trainable parameters included in the program. Default None.

    Returns:
        Layer: A Layer object that runs all ops internally in static graph mode.
    """

    def __init__(
        self, main_program, inputs, outputs, parameters=None, **kwargs
    ):
        """
        Args:
            main_program(Program): Program holding the forward ops.
            inputs: (Nested) inputs of the decorated function.
            outputs: (Nested) outputs of the decorated function.
            parameters(list[Tensor]|None): Trainable parameters used by the
                program; parameters not referenced by any op are pruned.
            **kwargs: Optional ``build_strategy`` (must be a
                ``BuildStrategy``) and ``backend``.
        """
        super().__init__()
        self._inputs = NestSequence(inputs)
        self._outputs = NestSequence(outputs, need_check=True)
        self._params = parameters if parameters is not None else []

        self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
        assert isinstance(self._build_strategy, BuildStrategy)

        # Needs self._params, so it must run after the assignment above.
        self._origin_main_program = self._verify_program(main_program)
        self._cuda_graph_vec = self._create_cuda_graph_vec()
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0
        # Set default mode to train
        self.training = True
        self._infer_info = ProgramInfo()
        self._forward_end_index_map = {}

        # Default so the attribute always exists. Previously it was only
        # defined in float16/bfloat16 AMP mode, so requesting an AMP program
        # otherwise failed with AttributeError.
        self._amp_list = None
        amp_dtype, custom_white_list, custom_black_list = None, None, None
        tracer = framework._dygraph_tracer()
        if tracer:
            custom_white_list, custom_black_list = tracer._get_amp_op_list()
            amp_dtype = tracer._amp_dtype
        if amp_dtype in ('float16', 'bfloat16'):
            # For AMP training
            self._amp_list = (
                paddle.static.amp.fp16_lists.AutoMixedPrecisionLists(
                    custom_white_list=custom_white_list,
                    custom_black_list=custom_black_list,
                    dtype=amp_dtype,
                )
            )

        # program_id -> list(scope)
        self._scope_cache = {}
        self._hooker = None
        self._backend = kwargs.get('backend', None)

    def __call__(self, inputs):
        """
        Execute static graph by Interpreter and Return dynamic Tensors.

        Args:
            inputs: The (nested) inputs passed to the decorated function.

        Returns:
            The outputs of ``run_program``, restored by ``_restore_out`` and
            post-processed by ``_remove_no_value``.
        """
        in_vars, out_vars = self._prepare(inputs)
        self._cast_fp16_if_pure_fp16(in_vars)
        attrs = self._prepare_attributes()

        # Refresh the learning-rate variable (if any) before running.
        self._sync_lr_value_with_scheduler()

        _legacy_C_ops.run_program(
            self._valid_vars(in_vars),
            self._valid_vars(self._params),
            self._valid_vars(out_vars),
            self._create_scope_vec(
                program_id=self.program_id, use_scope_cache=True
            ),
            self._double_grads,
            self._cuda_graph_vec,
            *attrs
        )
        restored_nest_out = self._restore_out(out_vars)
        return self._remove_no_value(restored_nest_out)

    def _sync_lr_value_with_scheduler(self):
        """
        Update lr_var value with the value calculated by lr_scheduler.

        Does nothing unless the origin main program has both an
        ``lr_scheduler`` and an ``lr_var`` attribute.
        """
        main_program = self._origin_main_program
        if not (
            hasattr(main_program, 'lr_scheduler')
            and hasattr(main_program, 'lr_var')
        ):
            return

        lr_scheduler = main_program.lr_scheduler
        lr_var = main_program.lr_var

        assert isinstance(lr_scheduler, LRScheduler), "must be LRScheduler"
        lr_value = lr_scheduler()
        # Cast to lr_var's dtype so set_value receives matching data.
        data = np.array(lr_value).astype(convert_dtype(lr_var.dtype))
        lr_var.set_value(data)

    def set_hooker(self, hooker):
        """
        Install a hook object used while building programs.

        Args:
            hooker: An object implementing the ``PartialProgramLayerHook``
                methods, or None to disable hooking.
        """
        self._hooker = hooker

    def _get_scope(self, program_id=None, use_scope_cache=False):
        """
        Return a ``core.Scope`` to run the program in.

        Without ``use_scope_cache`` a fresh scope is returned every time. With
        it, scopes are cached per ``program_id``. The first cached scope whose
        ``_can_reuesd`` flag is set is returned; otherwise a new scope is
        created and appended to that program's cache.
        """
        if not use_scope_cache:
            return core.Scope()

        cached_scopes = self._scope_cache.setdefault(program_id, [])
        reusable = next(
            (scope for scope in cached_scopes if scope._can_reuesd), None
        )
        if reusable is not None:
            return reusable

        fresh_scope = core.Scope()
        cached_scopes.append(fresh_scope)
        return fresh_scope

    @LazyInitialized
    def _double_grads(self):
        # Passed as the double-grads argument of ``run_program`` in
        # ``__call__``; currently always None.
        # TODO: check the effects of not providing double grads.
        return None

    # Builders for the whole (not yet split) programs, one per precision.
    @switch_to_static_graph
    def _create_program(self, is_infer_mode=False):
        """
        Build the fp32 program.

        Args:
            is_infer_mode(bool): If True, build the inference program
                (a ``for_test`` clone of the origin main program). Otherwise
                build the train program with backward ops appended.
        """
        if is_infer_mode:
            program = self._origin_main_program.clone(for_test=True)
        else:
            # ``_append_backward_desc`` clones the program itself, so the
            # origin main program is passed directly for training.
            program = self._origin_main_program
        return self._finalize_created_program(program, is_infer_mode)

    @switch_to_static_graph
    def _create_amp_program(self, is_infer_mode=False):
        """Build the AMP (O1) program; see ``_create_program`` for ``is_infer_mode``."""
        amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
        with program_guard(amp_program):
            paddle.static.amp.fp16_utils.cast_model_to_fp16(
                amp_program, self._amp_list, use_fp16_guard=False, level='O1'
            )
        return self._finalize_created_program(amp_program, is_infer_mode)

    @switch_to_static_graph
    def _create_pure_fp16_program(self, is_infer_mode=False):
        """Build the pure fp16 program; see ``_create_program`` for ``is_infer_mode``."""
        pure_fp16_program = self._origin_main_program.clone(
            for_test=is_infer_mode
        )
        with program_guard(pure_fp16_program):
            paddle.static.amp.fp16_utils.cast_model_to_fp16(
                pure_fp16_program, self._amp_list, use_fp16_guard=False
            )
        return self._finalize_created_program(pure_fp16_program, is_infer_mode)

    def _finalize_created_program(self, program, is_infer_mode):
        """
        Shared tail of the ``_create_*program`` builders.

        In infer mode, pass ``program`` through the hooker's ``after_infer``
        (when a hooker is set) and return it. Otherwise append the backward
        ops, set the parameter grad types and return the train program.
        """
        if is_infer_mode:
            if self._hooker:
                program = self._hooker.after_infer(program)
            return program
        train_program = self._append_backward_desc(program)
        # Note: Only set grad type once after initializing train program. So we put it here.
        self._set_grad_type(self._params, train_program)
        return train_program

    @switch_to_static_graph
    def _create_forward_backward_train_program(self):
        """Split the fp32 train program into ``[forward, backward]``."""
        return self._split_whole_program(self._train_program)

    @switch_to_static_graph
    def _create_forward_backward_train_amp_program(self):
        """Split the AMP train program into ``[forward, backward]``."""
        return self._split_whole_program(self._train_amp_program)

    @switch_to_static_graph
    def _create_forward_backward_train_pure_fp16_program(self):
        """Split the pure fp16 train program into ``[forward, backward]``."""
        return self._split_whole_program(self._train_pure_fp16_program)

    def _split_whole_program(self, whole_program):
        """
        Split a whole train program into ``[forward, backward]`` programs.

        Uses the forward end op index recorded for ``whole_program`` by
        ``_append_backward_desc``.
        """
        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
        assert forward_end_op_index >= 0
        return self._get_forward_backward_program_form(
            whole_program, forward_end_op_index
        )

    # Lazily built programs, cached on the instance by ``LazyInitialized``.
    @LazyInitialized
    def _train_program(self):
        return self._create_program()

    @LazyInitialized
    def _infer_program(self):
        program, op_size = self._infer_info('fp32', self._create_program)
        return self._build_infer_program(program, op_size)

    @LazyInitialized
    def _train_amp_program(self):
        return self._create_amp_program()

    @LazyInitialized
    def _infer_amp_program(self):
        program, op_size = self._infer_info('amp', self._create_amp_program)
        return self._build_infer_program(program, op_size)

    @LazyInitialized
    def _train_pure_fp16_program(self):
        return self._create_pure_fp16_program()

    @LazyInitialized
    def _infer_pure_fp16_program(self):
        program, op_size = self._infer_info(
            'fp16', self._create_pure_fp16_program
        )
        return self._build_infer_program(program, op_size)

    @LazyInitialized
    def _train_forward_backward_program(self):
        return self._create_forward_backward_train_program()

    @LazyInitialized
    def _train_amp_forward_backward_program(self):
        return self._create_forward_backward_train_amp_program()

    @LazyInitialized
    def _empty_backward_program_for_eval(self):
        # Cached so the same Program object stays alive across calls
        # (see the note in ``backward_program``).
        return paddle.static.Program()

    @LazyInitialized
    def _train_pure_fp16_forward_backward_program(self):
        return self._create_forward_backward_train_pure_fp16_program()

    @LazyInitialized
    def _train_program_id(self):
        return self._make_train_program_id(self._train_program)

    @LazyInitialized
    def _infer_program_id(self):
        return paddle.utils._hash_with_id(self._infer_program, self)

    @LazyInitialized
    def _train_amp_program_id(self):
        return self._make_train_program_id(self._train_amp_program)

    @LazyInitialized
    def _infer_amp_program_id(self):
        return paddle.utils._hash_with_id(self._infer_amp_program, self)

    @LazyInitialized
    def _train_pure_fp16_program_id(self):
        return self._make_train_program_id(self._train_pure_fp16_program)

    @LazyInitialized
    def _infer_pure_fp16_program_id(self):
        return paddle.utils._hash_with_id(self._infer_pure_fp16_program, self)

    def _make_train_program_id(self, train_program):
        """
        Hash ``train_program`` together with this layer and register this
        layer's build strategy for the resulting id via
        ``core._set_cached_executor_build_strategy``.

        Infer ids are not registered here; presumably because
        ``_build_infer_program`` already applies the build strategy.
        """
        program_id = paddle.utils._hash_with_id(train_program, self)
        core._set_cached_executor_build_strategy(
            program_id, self._build_strategy
        )
        return program_id

    @LazyInitialized
    def _param_grad_names(self):
        # ``_param_grad_names`` below resolves to the module-level helper
        # imported from ``.utils``, not to this method.
        return _param_grad_names(self._train_program.desc, self._params)

    def get_forward_end_op_idx(self, program):
        """
        Return the forward end op index that ``_append_backward_desc``
        recorded for ``program``. Raises KeyError if none was recorded.
        """
        return self._forward_end_index_map[
            paddle.utils._hash_with_id(program, self)
        ]

    @LazyInitialized
    def _out_grad_names(self):
        # Module-level helper from ``.utils`` (shadowed only as an attribute).
        return _out_grad_names(
            self._train_program.desc,
            self.get_forward_end_op_idx(self._train_program),
            len(self._outputs.var_ids),
        )

    @property
    def program(self):
        """
        Return current train or eval program.
        """
        return self.train_program if self.training else self.infer_program

    @property
    def program_id(self):
        """
        Return current train or eval program hash id.
        """
        if self.training:
            if _in_amp_guard():
                return self._train_amp_program_id
            if _in_pure_fp16_guard():
                return self._train_pure_fp16_program_id
            return self._train_program_id
        if _in_amp_guard():
            return self._infer_amp_program_id
        if _in_pure_fp16_guard():
            return self._infer_pure_fp16_program_id
        return self._infer_program_id

    @property
    def train_program(self):
        """Whole train program for the active precision mode."""
        if _in_amp_guard():
            return self._train_amp_program
        if _in_pure_fp16_guard():
            return self._train_pure_fp16_program
        return self._train_program

    @property
    def infer_program(self):
        """Inference program for the active precision mode."""
        if _in_amp_guard():
            return self._infer_amp_program
        if _in_pure_fp16_guard():
            return self._infer_pure_fp16_program
        return self._infer_program

    @property
    def _current_train_forward_backward_program(self):
        """``[forward, backward]`` train programs for the active precision mode."""
        if _in_amp_guard():
            return self._train_amp_forward_backward_program
        if _in_pure_fp16_guard():
            return self._train_pure_fp16_forward_backward_program
        return self._train_forward_backward_program

    @property
    def forward_program(self):
        """Forward program in training, the inference program otherwise."""
        if self.training:
            return self._current_train_forward_backward_program[0]
        return self.infer_program

    @property
    def backward_program(self):
        """Backward program in training, a cached empty Program otherwise."""
        if self.training:
            return self._current_train_forward_backward_program[1]
        # Can't just return paddle.static.Program(), because
        # self.backward_program is a property: each call would create a
        # temporary Program() that is garbage-collected immediately after
        # this line in PartialProgramLayer.__call__ executes:
        #
        #     self.backward_program.desc.block(0),
        #
        # RunProgramAPI could then see an invalid backward_program address.
        return self._empty_backward_program_for_eval

    def _verify_program(self, main_program):
        """
        Check the program's parameters (``_check_params_all_inited``) and
        prune the parameters that no op in the program uses.

        Returns:
            Program: ``main_program`` itself; it is not cloned.
        """
        # 1. Check all params from main program can be found in self._params
        self._check_params_all_inited(main_program)
        # 2. Prune the parameters not used anywhere in the program.
        self._prune_unused_params(main_program)
        return main_program

    def prepare_gradient_aggregation(
        self, start_idx, main_program, target_program
    ):
        """
        Insert ``sum`` ops so that the gradients of non-leaf outputs are
        accumulated instead of overwritten.

        Why is this needed? When a non-leaf node is also returned as an
        output, its gradient can be overwritten, e.g.::

            def forward(self, in):
                x = 2 * in  # <---- x is a non-leaf node in program.
                y = x + 3
                return x, y

            loss = forward(in)[0].sum()
            loss.backward()  # <----- x@GRAD is overwritten by elementwise_add_grad Op

        Args:
            start_idx(int): Only ops at index >= start_idx in block 0 of
                ``target_program`` are treated as writers of a gradient.
            main_program(Program): Program used to decide whether an output
                is read by some op (i.e. is a non-leaf node).
            target_program(Program): Program with backward ops; modified
                in place.
        """
        # Names read by any op in block 0 of main_program. Collected once
        # instead of rescanning every op for each output variable.
        consumed_names = {
            in_arg
            for op in main_program.block(0).ops
            for in_arg in op.input_arg_names
        }

        def _need_aggregation(var):
            """
            Return True if ``var`` is a float32/float64 LOD_TENSOR or
            SELECTED_ROWS Variable that some op in main_program reads.
            """
            if not isinstance(var, framework.Variable) or var.type not in [
                core.VarDesc.VarType.LOD_TENSOR,
                core.VarDesc.VarType.SELECTED_ROWS,
            ]:
                return False
            if var.dtype not in [paddle.float32, paddle.float64]:
                return False
            return var.name in consumed_names

        def _insert_aggregation_ops_for_var(target_program, var):
            suffix = "@dy2static"
            var_grad_name = var.grad_name
            new_grad_name = var.name + suffix + "@GRAD"
            found_ops = [
                (idx, op)
                for idx, op in enumerate(target_program.block(0).ops)
                if idx >= start_idx and var_grad_name in op.output_arg_names
            ]

            # len(found_ops) may equal zero when stop_gradient works.
            # len(found_ops) may be > 1, because we may have fill_constant op.
            if not found_ops:
                return None
            # step1: create a new var named var.name@dy2static@GRAD
            target_program.block(0).create_var(
                name=new_grad_name,
                type=var.type,
                dtype=var.dtype,
                shape=var.shape,
            )
            # step2: rename var.name@GRAD to var.name@dy2static@GRAD in the
            #        ops that write it
            for _, op in found_ops:
                op._rename_input(var_grad_name, new_grad_name)
                op._rename_output(var_grad_name, new_grad_name)
            # step3: insert sum op to aggregate the gradient.
            #        var.name@GRAD = sum(var.name@dy2static@GRAD, var.name@GRAD)
            target_program.block(0)._insert_op(
                found_ops[-1][0] + 1,
                type='sum',
                inputs={'X': [var_grad_name, new_grad_name]},
                outputs={"Out": var_grad_name},
            )
            return None

        to_processed_vars = [
            var for var in self._outputs.tolist() if _need_aggregation(var)
        ]
        for _var in to_processed_vars:
            _insert_aggregation_ops_for_var(target_program, _var)

    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        """
        Clone ``main_program`` and append backward ops for the outputs.

        The forward end op index of the returned program is recorded in
        ``self._forward_end_index_map`` (keyed by the program's hash id) so
        the program can later be split into forward/backward programs.

        Args:
            main_program(Program): The forward program. A clone receives the
                backward ops; ``main_program`` is also used to detect
                non-leaf outputs for gradient aggregation.

        Returns:
            Program: The cloned program with backward ops appended.
        """
        program = main_program.clone(for_test=False)
        if self._hooker:
            program = self._hooker.before_append_backward(program)

        outputs = self._outputs.tolist()
        num_outputs = len(outputs)
        targets = [
            program.global_block().var(out.name)
            for out in outputs
            if isinstance(out, framework.Variable)
        ]

        # The backward part is assumed to start num_outputs ops after the
        # current end of block 0; presumably these are the ops appended for
        # the output grads. TODO confirm.
        start_idx = len(program.block(0).ops) + num_outputs
        if targets:
            with backend_guard(self._backend):
                backward.gradients(targets=targets, inputs=[])

            if self._hooker:
                program, start_idx = self._hooker.after_append_backward(
                    program, start_idx
                )
            self.prepare_gradient_aggregation(
                start_idx + 1, main_program, program
            )

        self._forward_end_index_map[
            paddle.utils._hash_with_id(program, self)
        ] = (start_idx - num_outputs)
        return program

    def _prune_unused_params(self, program):
        """
        Prune the parameters not used anywhere in the program.

        `@to_static` may decorate only a sub function, so the program may
        not use some parameters created in `__init__`. These parameters
        are pruned to avoid unnecessary operations in `run_program_op`.

        A parameter is kept if its name is an input or output argument of
        any op in any block of ``program``. The original order is preserved.
        """
        # Collect every name read or written by any op once, instead of
        # rescanning all ops for every parameter.
        used_names = set()
        for block in program.blocks:
            for op in block.ops:
                used_names.update(op.input_arg_names)
                used_names.update(op.output_arg_names)

        self._params = [
            param for param in self._params if param.name in used_names
        ]

    def _cast_fp16_if_pure_fp16(self, in_vars):
        """
        In pure fp16 mode, replace (in place in ``in_vars``) every input whose
        same-named variable in the current program is float16 with a float16
        cast of it, keeping the original name.

        Args:
            in_vars: Mutable, indexable sequence of input tensors (assumed to
                be a list; entries are reassigned by index).
        """
        if not _in_pure_fp16_guard() or not in_vars:
            return
        # ``self.program`` resolves to the same cached program for every
        # input, so look up its global block once instead of per input.
        global_block = self.program.global_block()
        for i, var in enumerate(in_vars):
            name = var.name
            if (
                global_block.has_var(name)
                and global_block.var(name).dtype == paddle.float16
            ):
                in_vars[i] = var.astype('float16')
                in_vars[i].name = name

    def _prepare_attributes(self):
        """
        Build the flat ``[name, value, name, value, ...]`` attribute list
        passed to ``run_program``.

        Always includes the forward/backward global blocks, ``is_test`` and
        ``program_id``. In training it adds the param/out grad names, and it
        adds the CUDA graph settings when a capture mode is set.
        """
        attrs = [
            'forward_global_block',
            self.forward_program.desc.block(0),
            'backward_global_block',
            self.backward_program.desc.block(0),
            'is_test',
            not self.training,
            'program_id',
            self.program_id,
        ]

        if self.training:
            # NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
            # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
            # the correct names of the parameter grads from program. And out grads are similar to above.
            attrs.extend(
                (
                    'param_grad_names',
                    self._param_grad_names,
                    'out_grad_names',
                    self._out_grad_names,
                )
            )
        if self._cuda_graph_capture_mode:
            attrs.extend(
                (
                    'cuda_graph_capture_mode',
                    self._cuda_graph_capture_mode,
                    'cuda_graph_pool_id',
                    self._cuda_graph_pool_id,
                )
            )
        return attrs

    @switch_to_static_graph
    def _build_infer_program(self, infer_program, forward_end_op_index):
        """
        Apply the build strategy and the inplace pass to an infer program.

        Args:
            infer_program(Program): The inference program.
            forward_end_op_index(int): End op index of the forward part
                (callers pass the op count of block 0).

        Returns:
            The program produced by ``add_build_strategy_for``.
        """
        skip_gc_vars = self._parse_skip_gc_vars(infer_program)
        built_program = add_build_strategy_for(
            infer_program,
            0,
            forward_end_op_index,
            self._build_strategy,
            skip_gc_vars,
        )
        # Inference has no backward program to process.
        self._apply_inplace_pass(built_program, None)
        return built_program

    @switch_to_static_graph
    def _get_forward_backward_program_form(
        self, whole_program, forward_end_op_index
    ):
        """
        Split ``whole_program`` into separately built forward and backward
        programs.

        Ops ``[0, forward_end_op_index)`` form the forward program. Ops from
        ``forward_end_op_index + len(self._outputs.var_ids)`` to the end of
        block 0 form the backward program.

        Returns:
            list: ``[forward_program, backward_program]``.
        """
        # NOTE(dev): We apply build_strategy for backward firstly to
        # avoid skipping more gc variables.
        backward_start_op_index = forward_end_op_index + len(
            self._outputs.var_ids
        )
        backward_end_op_index = whole_program.desc.block(0).op_size()
        # For Backward process in CINN, all param@GRAD should be skipped for GC, because
        # they will be shared in scope and used by optimizer.
        backward_skip_vars = (
            self._parse_skip_gc_vars(whole_program) + self._param_grad_names
        )
        backward_built_program = add_build_strategy_for(
            whole_program,
            backward_start_op_index,
            backward_end_op_index,
            self._build_strategy,
            backward_skip_vars,
        )

        # Passing the backward program makes the forward skip list also keep
        # the variables used in backward.
        forward_skip_vars = self._parse_skip_gc_vars(
            whole_program, backward_built_program
        )
        forward_built_program = add_build_strategy_for(
            whole_program,
            0,
            forward_end_op_index,
            self._build_strategy,
            forward_skip_vars,
        )

        self._apply_inplace_pass(forward_built_program, backward_built_program)
        return [forward_built_program, backward_built_program]

    def _apply_inplace_pass(self, forward_program, backward_program):
        """
        Apply ``buffer_shared_inplace_pass`` to the forward and backward
        programs.

        A falsy program (e.g. ``None`` for inference) is skipped. Each pass
        receives its skip list from ``_parse_skip_gc_vars`` as
        ``mem_opt_skip_vars``. The forward list also includes the variables
        used by the backward program.
        """
        attr_types = {
            "use_cuda": "bool",
            "mem_opt_skip_vars": "list[str]",
            "for_partial_block": "bool",
        }
        empty_startup_program = paddle.static.Program()
        use_cuda = bool(core.is_compiled_with_cuda())
        # skip data var
        forward_mem_opt_skip_vars = self._parse_skip_gc_vars(
            forward_program, backward_program
        )
        # NOTE(review): the backward skip list is derived from the forward
        # program, as in the original code; confirm this is intended.
        backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)

        for program, mem_opt_skip_vars in (
            (forward_program, forward_mem_opt_skip_vars),
            (backward_program, backward_mem_opt_skip_vars),
        ):
            if not program:
                continue
            _apply_pass(
                program,
                empty_startup_program,
                "buffer_shared_inplace_pass",
                {
                    "use_cuda": use_cuda,
                    "mem_opt_skip_vars": mem_opt_skip_vars,
                    "for_partial_block": True,
                },
                attr_types,
            )

    @LazyInitialized
    def _inout_var_names(self):
        """
        Names of every static ``Variable`` in ``self._inputs`` followed by
        those in ``self._outputs``; entries that are not Variables are ignored.

        Returns:
            list[str]: variable desc names, inputs first, then outputs.
        """
        return [
            item.desc.name()
            for sequence in (self._inputs, self._outputs)
            for item in sequence
            if isinstance(item, paddle.fluid.framework.Variable)
        ]

    def _parse_skip_gc_vars(self, program, backward_program=None):
        """
        Parse variables that need to skip GC after execute it.
        If specify backward_program, it will keep the variables used in backward.
        """
        # skip data var, DO NOT ignore this deepcopy
        skip_vars = deepcopy(self._inout_var_names)
        for var_name, var in program.global_block().vars.items():
            if var.is_data:
                skip_vars.append(var_name)

        if backward_program:
            for var_name in core.parse_safe_eager_deletion_skip_vars(
858
                backward_program.desc, True
859 860 861 862
            ):
                skip_vars.append(var_name)
        return skip_vars

863 864 865 866 867
    def _prepare(self, inputs):
        """
        Prepare inputs, outputs, attrs.

        Flattens ``inputs`` and converts each entry into an eager Tensor named
        after the static input Variable at the same flattened position. It
        also creates fresh eager Tensors to receive the program outputs.

        Conversion rules per flattened input:
          * ``np.ndarray``: wrapped in a new ``core.eager.Tensor`` (zero-copy,
            non-persistable) on the current expected place.
          * ``core.eager.Tensor``: used as-is. The exception is a Tensor with
            ``stop_gradient=True`` on a different place, which is copied to
            the expected place first. NOTE: the resulting Tensor's ``name`` is
            overwritten, which renames a caller-owned Tensor in place when no
            copy was made.
          * anything else: skipped (not fed), but still counted for position.

        Args:
            inputs (tuple|list): call arguments, possibly nested.

        Returns:
            tuple: ``(input_vars, out_vars)``. ``input_vars`` is a list of
            eager Tensors to feed. ``out_vars`` holds one eager Tensor per
            entry of ``self._outputs.var_ids``; outputs that share a
            variable name share one Tensor.
        """
        assert isinstance(inputs, (tuple, list))
        # Flatten inputs with nested structure into single list.
        flatten_inputs = paddle.utils.flatten(inputs)
        # Convert variable into Tensor and feed in training data.
        input_vars = []
        expected_place = framework._current_expected_place()
        for i, value in enumerate(flatten_inputs):
            if isinstance(value, np.ndarray):
                var = core.eager.Tensor(
                    value=value,
                    name=self._inputs[i].desc.name(),
                    persistable=False,
                    place=expected_place,
                    zero_copy=True,
                )
            elif isinstance(value, core.eager.Tensor):
                # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
                # into CUDAPlace when it's as input of multi Ops. so we move it in advance
                # to avoid this problem.
                if value.stop_gradient and not value.place._equals(
                    expected_place
                ):
                    var = value._copy_to(expected_place, False)
                    var.stop_gradient = True
                else:
                    var = value
                var.name = self._inputs[i].desc.name()
            else:
                continue
            input_vars.append(var)

        # mapping from name(string) -> Tensor
        out_varbase_map = {}

        def create_out(var_id):
            # Create (or reuse, when the name was already seen) the eager
            # Tensor that receives output ``self._outputs[var_id]``.
            var = self._outputs[var_id]
            assert isinstance(var, framework.Variable)
            var_desc = var.desc
            var_name = var_desc.name()
            if var_name in out_varbase_map:
                return out_varbase_map[var_name]

            var_base = core.eager.Tensor(
                var_desc.dtype(),
                var_desc.shape(),
                var_name,
                var_desc.type(),
                False,
            )
            var_base.stop_gradient = var.stop_gradient
            out_varbase_map[var_name] = var_base
            return var_base

        # Create Tensor to receive output data.
        out_vars = list(map(create_out, self._outputs.var_ids))

        return input_vars, out_vars
926

927
    def _create_scope_vec(self, program_id=None, use_scope_cache=False):
928
        # Hold forward variables
J
Jiabin Yang 已提交
929
        tmp_scope_vec = None
930 931 932
        inner_scope = self._get_scope(
            program_id=program_id, use_scope_cache=use_scope_cache
        )
W
wanghuancoder 已提交
933
        tmp_scope_vec = [inner_scope]
934
        return tmp_scope_vec
935

936
    def _create_cuda_graph_vec(self):
        """
        Create the ``"cuda_graph"`` placeholder Tensor.

        The Tensor is built with FP32 dtype, empty dims and RAW var type, and
        has ``stop_gradient=True``. It presumably carries a CUDA-graph handle
        for the run-program op; confirm against callers. The trailing ``True``
        positional argument mirrors the ``False`` passed in ``_prepare``'s
        output Tensors (presumably ``persistable``; verify).
        """
        placeholder = core.eager.Tensor(
            core.VarDesc.VarType.FP32,
            [],
            "cuda_graph",
            core.VarDesc.VarType.RAW,
            True,
        )
        placeholder.stop_gradient = True
        return placeholder

947 948
    def _restore_out(self, out_vars):
        """
W
wanghuancoder 已提交
949
        Restores same nested outputs by only replacing the Variable with Tensor.
950 951 952 953 954 955
        """

        flatten_outputs = self._outputs.tolist()
        for i, idx in enumerate(self._outputs.var_ids):
            flatten_outputs[idx] = out_vars[i]
        outs = self._outputs.restore(flatten_outputs)
956
        if outs is not None and len(outs) == 1:
957 958 959 960
            outs = outs[0]

        return outs

961 962 963 964
    @switch_to_static_graph
    def _clone_for_test(self, main_program):
        """Return ``main_program.clone(for_test=True)``, run in static-graph mode."""
        test_program = main_program.clone(for_test=True)
        return test_program

965
    def _is_no_value(self, var):
        """
        Tell whether ``var`` is the sentinel marking a missing return value.

        The sentinel is an eager Tensor of shape ``[1]`` whose only element
        equals ``RETURN_NO_VALUE_MAGIC_NUM``.

        Returns:
            bool
        """
        if not isinstance(var, core.eager.Tensor):
            return False
        if var.shape != [1]:
            return False
        # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
        return bool(var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM)

    def _remove_no_value(self, out_vars):
        """
        Strip "no value" sentinels produced by variable-length return statements.

        * A single Tensor sentinel becomes ``None``; any other Tensor is
          returned unchanged.
        * For a tuple/list, sentinels are filtered out and the container type
          is preserved (plain ``tuple`` or ``list``). If filtering removed
          something and nothing is left, ``None`` is returned. If exactly one
          element is left, that element is returned unwrapped.
        * Any other value is returned unchanged.
        """
        if isinstance(out_vars, core.eager.Tensor):
            return None if self._is_no_value(out_vars) else out_vars
        if not isinstance(out_vars, (tuple, list)):
            return out_vars

        kept = [var for var in out_vars if not self._is_no_value(var)]
        if isinstance(out_vars, tuple):
            kept = tuple(kept)

        # Collapse only when something was actually removed, so an input that
        # was already empty or single-element keeps its container.
        if len(kept) < len(out_vars):
            if not kept:
                return None
            if len(kept) == 1:
                return kept[0]
        return kept

1000
    def _set_grad_type(self, params, train_program):
        """
        Align each parameter's gradient type with the grad var declared in
        block 0 of ``train_program``.

        NOTE: if the user sets sparse gradient mode, the param's gradient is
        SelectedRows, not LoDTensor, while the tracer would otherwise create
        the grad from the forward Tensor (LoDTensor). Without this fix-up,
        RunProgramOp would have to force SelectedRows into LoDTensor, which
        may not be what the user wants.
        """
        for param in params:
            grad_var_name = param.name + core.grad_var_suffix()
            grad_var_desc = train_program.desc.block(0).find_var(
                grad_var_name.encode()
            )
            # NOTE: a missing grad var desc may be legitimate (e.g. batch_norm).
            if grad_var_desc is not None:
                param._set_grad_type(grad_var_desc.type())

1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027
    def _remove_op_call_stack(self, main_program):
        """
        Drop the ``op_callstack`` attribute from every op in every block.

        The recorded Python call stacks point into low-level transformation
        code and would only confuse users in error messages. ``main_program``
        is modified in place and also returned.
        """
        assert isinstance(main_program, framework.Program)
        every_op = (op for block in main_program.blocks for op in block.ops)
        for op in every_op:
            if not op.has_attr("op_callstack"):
                continue
            op._remove_attr("op_callstack")
        return main_program

1028 1029 1030
    def _check_params_all_inited(self, main_program):
        """
        Validate ``self._params`` against the parameters of ``main_program``.

        Checks, in order:
            1. ``self._params`` is a list or tuple (else ``TypeError``).
            2. Every entry is an eager Tensor created in dygraph
               (else ``TypeError``).
            3. Every ``framework.Parameter`` found in any block of
               ``main_program`` has a name present in ``self._params``, since
               those parameters share data with the dygraph originals.
               Otherwise ``ValueError``; this typically means a layer was
               created inside the ``@to_static`` function.
        """
        if not isinstance(self._params, (list, tuple)):
            raise TypeError(
                "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
                % type(self._params)
            )

        # self._params contains parameters and buffers with persistable=True.
        known_names = set()
        for index, item in enumerate(self._params):
            if not isinstance(item, core.eager.Tensor):
                raise TypeError(
                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.format(
                        index, type(item)
                    )
                )
            known_names.add(item.name)

        for block in main_program.blocks:
            for name, var in block.vars.items():
                if not isinstance(var, framework.Parameter):
                    continue
                if name in known_names:
                    continue
                raise ValueError(
                    "\n\tWe don't support to define layer with parameters in the function decorated by `@to_static`."
                    "\n\tBut we found parameter(%s) was created in the decorated function."
                    "\n"
                    "\n\tRevise suggestion: "
                    "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
                    "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
                    % name
                )
1065

1066
    def _valid_vars(self, vars):
1067
        return vars if vars else None
1068

1069

1070
def partial_program_from(concrete_program, from_method=False):
    """
    Build a ``PartialProgramLayer`` from ``concrete_program``.

    Args:
        concrete_program: object exposing ``inputs``, ``main_program``,
            ``outputs``, ``parameters`` and ``kwargs`` (presumably a
            ConcreteProgram; confirm against callers).
        from_method (bool): True when the traced callable was a method. Its
            first input (``self``) is then dropped.

    Returns:
        PartialProgramLayer
    """
    program_inputs = concrete_program.inputs

    # NOTE(SigureMo): Remove the first arg `self` from method args.
    if program_inputs and from_method:
        program_inputs = program_inputs[1:]

    layer_args = (
        concrete_program.main_program,
        program_inputs,
        concrete_program.outputs,
        concrete_program.parameters,
    )
    return PartialProgramLayer(*layer_args, **concrete_program.kwargs)
1084 1085 1086


@switch_to_static_graph
def add_build_strategy_for(
    program, start_op_index, end_op_index, build_strategy=None, skip_vars=None
):
    """
    Build a standalone Program from a range of block-0 ops of ``program``.

    When ``start_op_index < end_op_index``, the op range (presumably
    half-open, ``[start, end)``; confirm with ``core.Graph``) is turned into
    a graph and compiled with ``build_strategy``. The compiled graph is then
    converted back into a Program.

    Otherwise the result is a new, op-free Program whose block 0 holds clones
    of ``program``'s block-0 variable descs.

    Args:
        program (Program): the source program.
        start_op_index (int): first op index of the range.
        end_op_index (int): end op index of the range.
        build_strategy (BuildStrategy, optional): strategy for compilation.
        skip_vars (list[str], optional): variable names attached to the graph
            as ``"skip_gc_vars"``, so they are kept alive.

    Returns:
        Program: the built program.
    """
    if start_op_index >= end_op_index:
        # can't just create a new program, we need copy the vardesc.
        built_program = paddle.static.Program()
        for var in program.block(0).vars.values():
            built_program.block(0)._clone_variable(var, False)
        return built_program

    sub_graph = core.Graph(program.desc, start_op_index, end_op_index)
    compiled_program = paddle.static.CompiledProgram(
        sub_graph,
        build_strategy=build_strategy,
    )
    if skip_vars:
        # TODO(Aurelius84): Need to unify name with C++, such as kSkipVarNames.
        compiled_program._graph.set("skip_gc_vars", set(skip_vars))
    compiled_program._compile(core.Scope(), framework._current_expected_place())

    built_program = framework.IrGraph(compiled_program._graph).to_program()
    # NOTE(review): ``compiled_program`` was constructed from a core.Graph,
    # not a Program. If CompiledProgram stores ``_program = None`` in that case
    # (as Paddle's compiler.py appears to), this lr_scheduler copy never runs.
    # Verify before relying on lr_scheduler propagation.
    if hasattr(compiled_program._program, 'lr_scheduler'):
        built_program.lr_scheduler = compiled_program._program.lr_scheduler
    return built_program