partial_program.py 41.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
16 17
from copy import deepcopy

18
import numpy as np
19

20
import paddle
21
from paddle import _legacy_C_ops
22
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
23
from paddle.fluid import backward, core, framework, program_guard
24
from paddle.fluid.compiler import BuildStrategy
25
from paddle.fluid.data_feeder import check_type, convert_dtype
26 27
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
28
from paddle.fluid.unique_name import guard as UniqueNameGuard
29
from paddle.optimizer.lr import LRScheduler
30 31

from . import logging_utils
32 33 34
from .utils import (
    RETURN_NO_VALUE_MAGIC_NUM,
    backend_guard,
35
    construct_grad_names,
36
    tensor_name_guard,
37
)
38

39 40
# Intentionally empty: `from <this module> import *` exports no names.
__all__ = []

41

42
class NestSequence:
    """
    Wrap a (possibly nested) sequence so that it can be flattened into a flat
    list and later rebuilt with the same nesting.

    Args:
        raw_input: The nested structure to wrap.
        need_check(bool): If True, log a warning when the flattened values
            contain anything that is neither a ``framework.Variable`` nor a
            ``core.eager.Tensor``. Default False.
    """

    def __init__(self, raw_input, need_check=False):
        # Order matters: each step reads the attribute set by the previous one.
        self.__raw_input = raw_input
        self.__input_list = self.tolist()
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    def tolist(self):
        """
        Return the leaves of the wrapped structure as a flat list
        (computed with ``paddle.utils.flatten`` on every call).
        """
        flat_values = paddle.utils.flatten(self.__raw_input)
        return flat_values

    def restore(self, value_list):
        """
        Rebuild the original nesting, filled with the items of ``value_list``.

        ``value_list`` must contain exactly as many items as the flattened
        input (checked with an assertion).
        """
        expected_len = len(self.__input_list)
        assert expected_len == len(value_list)
        return paddle.utils.pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        """Return the flat positions that hold a Variable or eager Tensor."""
        tensor_types = (framework.Variable, core.eager.Tensor)
        return [
            position
            for position, value in enumerate(self.__input_list)
            if isinstance(value, tensor_types)
        ]

    def _check_non_variable(self, need_check):
        """
        Warn (through ``logging_utils.warn``) when ``need_check`` is set and the
        flattened values include non-tensor types.
        """
        if not need_check:
            return
        tensor_types = (framework.Variable, core.eager.Tensor)
        warning_types = {
            type(value)
            for value in self.__input_list
            if not isinstance(value, tensor_types)
        }
        if not warning_types:
            return
        logging_utils.warn(
            "Output of traced function contains non-tensor type values: {}. "
            "Currently, We don't support to update them while training and will return "
            "what we first saw. Please try to return them as tensor.".format(
                list(warning_types)
            )
        )

    @property
    def var_ids(self):
        """Flat positions of the tensor-typed values (computed at init)."""
        return self.__var_ids

    def __getitem__(self, item):
        """Index into the flattened values captured at construction."""
        return self.__input_list[item]
99

100

101
class LazyInitialized:
    """
    Descriptor to implement lazy initialization of property.

    On the first access through an instance, the wrapped function is called
    with that instance. The result is stored on the instance under the
    function's name. Because this class defines no ``__set__`` (it is a
    non-data descriptor), the stored instance attribute shadows the
    descriptor on later lookups. The function therefore runs at most once per
    instance.

    Accessing the attribute on the class itself returns the descriptor.
    """

    def __init__(self, function):
        self.function = function
        # Keep the wrapped function's docstring visible to help()/introspection.
        self.__doc__ = getattr(function, '__doc__', None)

    def __get__(self, instance, cls):
        if instance is None:
            # Class-level access (e.g. ``SomeClass.attr``): there is no
            # instance to initialize. Calling ``function(None)`` would crash.
            return self
        val = self.function(instance)
        # Cache on the instance so later lookups bypass this descriptor.
        setattr(instance, self.function.__name__, val)
        return val


115 116 117 118 119
class ProgramInfo:
    """
    A helper class to record inference programs and their op counts.

    Entries are keyed by precision mode: 'fp32', 'amp' or 'fp16'.
    """

    def __init__(self):
        # -1 marks "no program recorded yet" for each mode.
        self.op_size = dict.fromkeys(('fp32', 'amp', 'fp16'), -1)
        self.programs = {}
        self.mode = "infer"

    def __call__(self, key, prog_creator):
        """
        Return ``(program, op_size)`` for ``key``.

        The first call for a key builds the program with
        ``prog_creator(is_infer_mode=True)``. It records the op count of
        block 0, and later calls reuse the recorded entry.
        """
        assert key in ['fp32', 'amp', 'fp16']
        if key in self.programs:
            return self.programs[key], self.op_size[key]

        created = prog_creator(is_infer_mode=True)
        self.programs[key] = created
        self.op_size[key] = created.desc.block(0).op_size()
        return created, self.op_size[key]
140 141


X
xiongkun 已提交
142
class PartialProgramLayerHook:
    """
    Extension points invoked by ``PartialProgramLayer`` while it builds its
    programs. Subclasses override the callbacks they need.

    The caller assigns each callback's return value back (for example
    ``program = hooker.before_append_backward(program)``). The defaults
    therefore return their inputs unchanged, so the base hook is a true
    no-op instead of replacing programs with None.
    """

    def before_append_backward(self, forward_program):
        """
        Called on the cloned forward program before backward ops are added.

        Returns:
            Program: The program to differentiate (default: unchanged).
        """
        return forward_program

    def after_append_backward(self, whole_program, backward_start_idx):
        """
        Called after backward ops have been appended.

        Returns:
            tuple: ``(whole_program, backward_start_idx)``, possibly adjusted
            (default: unchanged).
        """
        return whole_program, backward_start_idx

    def after_infer(self, infer_program):
        """
        Called on a freshly created inference program.

        Returns:
            Program: The inference program to use (default: unchanged).
        """
        return infer_program


153
class PartialProgramLayer:
154
    """
    PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
    and executes them as a static subgraph.

    .. note::
        **1. This is a very low level API. Users should not use this API
             directly. Please use `partial_program_from(concrete_program)`
             to create it.
        **2. LoDTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains the ops to be executed.
        inputs(list[Variable]): The input list of the function decorated by `@to_static`.
        outputs(list[Variable]): The output list of the function decorated by `@to_static`.
        name_generator: The unique-name generator that is activated (via `UniqueNameGuard`) while the layer runs.
        parameters(list[Tensor]|None): All trainable parameters included in the program. Default None.

    Returns:
        Layer: A Layer object that runs all ops internally in static graph mode.
    """

174
    def __init__(
        self,
        main_program,
        inputs,
        outputs,
        name_generator,
        parameters=None,
        **kwargs
    ):
        """
        Args:
            main_program(Program): Forward program of the traced function.
            inputs: (Nested) inputs of the traced function.
            outputs: (Nested) outputs of the traced function. A warning is
                logged if they contain non-tensor values.
            name_generator: Unique-name generator activated while running.
            parameters(list[Tensor]|None): Trainable parameters. None means [].
            **kwargs: Optional ``build_strategy`` (must be a ``BuildStrategy``;
                a default one is created when omitted) and ``backend``.
        """
        super().__init__()
        self._inputs = NestSequence(inputs)
        self._outputs = NestSequence(outputs, need_check=True)
        self._params = [] if parameters is None else parameters
        self._name_generator = name_generator

        if 'build_strategy' in kwargs:
            self._build_strategy = kwargs['build_strategy']
        else:
            self._build_strategy = BuildStrategy()
        assert isinstance(self._build_strategy, BuildStrategy)

        # Must come after self._params is set: verification prunes it.
        self._origin_main_program = self._verify_program(main_program)
        with paddle.fluid.framework._dygraph_guard(
            paddle.fluid.dygraph.Tracer()
        ):
            self._cuda_graph_vec = self._create_cuda_graph_vec()
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0
        # The layer starts in train mode.
        self.training = True
        self._infer_info = ProgramInfo()
        self._forward_end_index_map = {}

        amp_dtype = None
        custom_white_list = custom_black_list = None
        tracer = framework._dygraph_tracer()
        if tracer:
            custom_white_list, custom_black_list = tracer._get_amp_op_list()
            amp_dtype = tracer._amp_dtype
        if amp_dtype in ('float16', 'bfloat16'):
            # AMP training: op lists used when casting the amp/fp16 programs.
            self._amp_list = (
                paddle.static.amp.fp16_lists.AutoMixedPrecisionLists(
                    custom_white_list=custom_white_list,
                    custom_black_list=custom_black_list,
                    dtype=amp_dtype,
                )
            )

        # program_id -> list of scopes
        self._scope_cache = {}
        self._hooker = None
        self._backend = kwargs.get('backend', None)
        self._grad_var_names = {}
224

225 226 227 228
    def __call__(self, inputs):
        """
        Execute static graph by Interpreter and Return dynamic Tensors.

        The inputs are prepared (and cast under pure fp16) and the attributes
        are assembled. The learning rate is synced, then ``run_program`` is
        executed and the outputs are restored to the traced function's nesting.
        """
        with UniqueNameGuard(self._name_generator):
            in_vars, out_vars, in_var_names = self._prepare(inputs)
            self._cast_fp16_if_pure_fp16(in_vars)
            attrs = self._prepare_attributes()
            attrs += ["x_names", in_var_names]

            self._sync_lr_value_with_scheduler()

            with tensor_name_guard(in_vars, in_var_names):
                # Evaluated in the same order as the run_program arguments.
                valid_inputs = self._valid_vars(in_vars)
                valid_params = self._valid_vars(self._params)
                valid_outputs = self._valid_vars(out_vars)
                scope_vec = self._create_scope_vec(
                    program_id=self.program_id, use_scope_cache=True
                )
                _legacy_C_ops.run_program(
                    valid_inputs,
                    valid_params,
                    valid_outputs,
                    scope_vec,
                    self._double_grads,
                    self._cuda_graph_vec,
                    *attrs
                )

            self._update_stop_gradient(out_vars)
            return self._remove_no_value(self._restore_out(out_vars))
253

254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
    def _sync_lr_value_with_scheduler(self):
        """
        Update ``lr_var`` with the value produced by ``lr_scheduler``.

        This only applies when the origin main program carries both an
        ``lr_scheduler`` and an ``lr_var`` attribute. Otherwise it is a no-op.

        Raises:
            AssertionError: If ``lr_scheduler`` is not an ``LRScheduler``.
        """
        main_program = self._origin_main_program
        if not (
            hasattr(main_program, 'lr_scheduler')
            and hasattr(main_program, 'lr_var')
        ):
            return

        lr_scheduler = main_program.lr_scheduler
        lr_var = main_program.lr_var
        assert isinstance(lr_scheduler, LRScheduler), "must be LRScheduler"

        lr_value = lr_scheduler()
        # Convert to the variable's own dtype before writing it.
        data = np.array(lr_value).astype(convert_dtype(lr_var.dtype))
        lr_var.set_value(data)

X
xiongkun 已提交
269 270 271
    def set_hooker(self, hooker):
        """
        Install ``hooker`` (e.g. a ``PartialProgramLayerHook``). Its callbacks
        are invoked while programs are created. Any falsy value such as None
        disables the hook calls, since they are guarded by ``if self._hooker``.
        """
        self._hooker = hooker

272 273 274 275 276 277 278 279
    def _get_scope(self, program_id=None, use_scope_cache=False):
        """
        Return a scope for running ``program_id``.

        Without ``use_scope_cache`` a new ``core.Scope`` is returned each time.
        With it, the first cached scope of ``program_id`` whose ``_can_reused``
        flag is truthy is returned. If there is none, a new scope is created
        and appended to that program's cache list.
        """
        if not use_scope_cache:
            return core.Scope()

        cached_scopes = self._scope_cache.get(program_id)
        if cached_scopes is None:
            fresh_scope = core.Scope()
            self._scope_cache[program_id] = [fresh_scope]
            return fresh_scope

        reusable = next((s for s in cached_scopes if s._can_reused), None)
        if reusable is not None:
            return reusable

        fresh_scope = core.Scope()
        cached_scopes.append(fresh_scope)
        return fresh_scope

288 289
    @LazyInitialized
    def _double_grads(self):
        """Value passed to ``run_program`` for double grads; currently None."""
        # TODO: check the affects.
        return None
292

293 294 295 296
    # whole
    @switch_to_static_graph
    def _create_program(self, is_infer_mode=False):
        """
        Build the fp32 program from the origin main program.

        Args:
            is_infer_mode(bool): If truthy, return
                ``clone(for_test=is_infer_mode)`` of the origin program, passed
                through the hooker's ``after_infer`` when a hooker is set.
                Otherwise return a training program with backward ops appended.
        """
        if not is_infer_mode:
            train_program = self._append_backward_desc(
                self._origin_main_program
            )
            # Note: Only set grad type once after initializing train program. So we put it here.
            self._set_grad_type(self._params, train_program)
            return train_program

        infer_program = self._origin_main_program.clone(
            for_test=is_infer_mode
        )
        if self._hooker:
            infer_program = self._hooker.after_infer(infer_program)
        return infer_program
310

311 312 313 314
    @switch_to_static_graph
    def _create_amp_program(self, is_infer_mode=False):
        """
        Build an AMP (level 'O1') variant of the origin program.

        The origin program is cloned and cast with ``cast_model_to_fp16``
        using ``self._amp_list``. In infer mode the hooker's ``after_infer`` is
        applied. Otherwise backward ops are appended and the params' grad type
        is set.
        """
        amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
        with program_guard(amp_program):
            paddle.static.amp.fp16_utils.cast_model_to_fp16(
                amp_program, self._amp_list, use_fp16_guard=False, level='O1'
            )

        if not is_infer_mode:
            train_amp_program = self._append_backward_desc(amp_program)
            self._set_grad_type(self._params, train_amp_program)
            return train_amp_program

        if self._hooker:
            amp_program = self._hooker.after_infer(amp_program)
        return amp_program
326

327 328 329
    @switch_to_static_graph
    def _create_pure_fp16_program(self, is_infer_mode=False):
        """
        Build a pure-fp16 variant of the origin program.

        The origin program is cloned and cast with ``cast_model_to_fp16``
        using ``self._amp_list``. No ``level`` argument is passed, unlike the
        AMP variant. In infer mode the hooker's ``after_infer`` is applied.
        Otherwise backward ops are appended and the params' grad type is set.
        """
        fp16_program = self._origin_main_program.clone(
            for_test=is_infer_mode
        )
        with program_guard(fp16_program):
            paddle.static.amp.fp16_utils.cast_model_to_fp16(
                fp16_program, self._amp_list, use_fp16_guard=False
            )

        if not is_infer_mode:
            train_fp16_program = self._append_backward_desc(fp16_program)
            self._set_grad_type(self._params, train_fp16_program)
            return train_fp16_program

        if self._hooker:
            fp16_program = self._hooker.after_infer(fp16_program)
        return fp16_program
347

348
    @switch_to_static_graph
    def _create_forward_backward_train_program(self):
        """Split the fp32 train program into ``[forward, backward]`` programs."""
        train_program = self._train_program
        end_idx = self.get_forward_end_op_idx(train_program)
        assert end_idx >= 0
        return self._get_forward_backward_program_form(train_program, end_idx)
357

358 359
    @switch_to_static_graph
    def _create_forward_backward_train_amp_program(self):
        """Split the AMP train program into ``[forward, backward]`` programs."""
        amp_train_program = self._train_amp_program
        end_idx = self.get_forward_end_op_idx(amp_train_program)
        assert end_idx >= 0
        return self._get_forward_backward_program_form(
            amp_train_program, end_idx
        )
367 368 369

    @switch_to_static_graph
    def _create_forward_backward_train_pure_fp16_program(self):
        """Split the pure-fp16 train program into ``[forward, backward]``."""
        fp16_train_program = self._train_pure_fp16_program
        end_idx = self.get_forward_end_op_idx(fp16_train_program)
        assert end_idx >= 0
        return self._get_forward_backward_program_form(
            fp16_train_program, end_idx
        )
377 378

    @LazyInitialized
    def _train_program(self):
        """fp32 training program (forward + backward), built once."""
        return self._create_program(is_infer_mode=False)
381

382
    @LazyInitialized
    def _infer_program(self):
        """fp32 inference program with build strategy applied, built once."""
        infer_prog, forward_end = self._infer_info('fp32', self._create_program)
        return self._build_infer_program(infer_prog, forward_end)
386

387 388 389 390 391 392
    @LazyInitialized
    def _train_amp_program(self):
        """AMP training program (forward + backward), built once."""
        return self._create_amp_program(is_infer_mode=False)

    @LazyInitialized
    def _infer_amp_program(self):
        """AMP inference program with build strategy applied, built once."""
        infer_prog, forward_end = self._infer_info(
            'amp', self._create_amp_program
        )
        return self._build_infer_program(infer_prog, forward_end)
395 396 397

    @LazyInitialized
    def _train_pure_fp16_program(self):
        """Pure-fp16 training program (forward + backward), built once."""
        return self._create_pure_fp16_program(is_infer_mode=False)
399

400
    @LazyInitialized
    def _infer_pure_fp16_program(self):
        """Pure-fp16 inference program with build strategy applied, built once."""
        infer_prog, forward_end = self._infer_info(
            'fp16', self._create_pure_fp16_program
        )
        return self._build_infer_program(infer_prog, forward_end)
406

407
    @LazyInitialized
    def _train_forward_backward_program(self):
        """Cached ``[forward, backward]`` split of the fp32 train program."""
        return self._create_forward_backward_train_program()
411 412

    @LazyInitialized
    def _train_amp_forward_backward_program(self):
        """Cached ``[forward, backward]`` split of the AMP train program."""
        return self._create_forward_backward_train_amp_program()

417 418 419 420
    @LazyInitialized
    def _empty_backward_program_for_eval(self):
        """
        Empty program used as the backward program in eval mode. It is created
        once per instance, so the same object is returned on every access.
        """
        empty_program = paddle.static.Program()
        return empty_program

421 422 423 424 425
    @LazyInitialized
    def _train_pure_fp16_forward_backward_program(self):
        """Cached ``[forward, backward]`` split of the pure-fp16 train program."""
        return self._create_forward_backward_train_pure_fp16_program()

426 427
    @LazyInitialized
    def _train_program_id(self):
        """
        Hash id of the fp32 train program. It is also passed, together with
        ``self._build_strategy``, to ``core._set_cached_executor_build_strategy``.
        """
        train_id = paddle.utils._hash_with_id(self._train_program, self)
        core._set_cached_executor_build_strategy(
            train_id, self._build_strategy
        )
        return train_id
433

434 435
    @LazyInitialized
    def _infer_program_id(self):
        """Hash id of the fp32 inference program."""
        infer_id = paddle.utils._hash_with_id(self._infer_program, self)
        return infer_id
437

438 439
    @LazyInitialized
    def _train_amp_program_id(self):
        """
        Hash id of the AMP train program. It is also passed, together with
        ``self._build_strategy``, to ``core._set_cached_executor_build_strategy``.
        """
        amp_train_id = paddle.utils._hash_with_id(self._train_amp_program, self)
        core._set_cached_executor_build_strategy(
            amp_train_id, self._build_strategy
        )
        return amp_train_id

446 447
    @LazyInitialized
    def _infer_amp_program_id(self):
        """Hash id of the AMP inference program."""
        amp_infer_id = paddle.utils._hash_with_id(self._infer_amp_program, self)
        return amp_infer_id
449

450 451
    @LazyInitialized
    def _train_pure_fp16_program_id(self):
        """
        Hash id of the pure-fp16 train program. It is also passed, together
        with ``self._build_strategy``, to
        ``core._set_cached_executor_build_strategy``.
        """
        fp16_train_id = paddle.utils._hash_with_id(
            self._train_pure_fp16_program, self
        )
        core._set_cached_executor_build_strategy(
            fp16_train_id, self._build_strategy
        )
        return fp16_train_id

460 461
    @LazyInitialized
    def _infer_pure_fp16_program_id(self):
        """Hash id of the pure-fp16 inference program."""
        fp16_infer_id = paddle.utils._hash_with_id(
            self._infer_pure_fp16_program, self
        )
        return fp16_infer_id
463

X
xiongkun 已提交
464
    def get_forward_end_op_idx(self, program):
        """
        Return the forward end op index recorded for ``program`` by
        ``_append_backward_desc``. Raises KeyError if none was recorded.
        """
        program_key = paddle.utils._hash_with_id(program, self)
        return self._forward_end_index_map[program_key]
X
xiongkun 已提交
468

469
    @property
    def program(self):
        """
        Return current train or eval program: ``train_program`` in training
        mode, ``infer_program`` otherwise.
        """
        return self.train_program if self.training else self.infer_program

    @property
    def program_id(self):
        """
        Return current train or eval program hash id.

        The id is chosen by the training flag and the active precision guard
        (AMP, then pure fp16, else plain fp32).
        """
        in_train = self.training
        if _in_amp_guard():
            return (
                self._train_amp_program_id
                if in_train
                else self._infer_amp_program_id
            )
        if _in_pure_fp16_guard():
            return (
                self._train_pure_fp16_program_id
                if in_train
                else self._infer_pure_fp16_program_id
            )
        return self._train_program_id if in_train else self._infer_program_id

499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551
    @property
    def train_program(self):
        """Training program matching the active precision guard."""
        if _in_amp_guard():
            return self._train_amp_program
        if _in_pure_fp16_guard():
            return self._train_pure_fp16_program
        return self._train_program

    @property
    def infer_program(self):
        """Inference program matching the active precision guard."""
        if _in_amp_guard():
            return self._infer_amp_program
        if _in_pure_fp16_guard():
            return self._infer_pure_fp16_program
        return self._infer_program

    @property
    def forward_program(self):
        """
        Forward half of the current train program, or the whole inference
        program in eval mode.
        """
        if not self.training:
            return self.infer_program
        if _in_amp_guard():
            forward_backward = self._train_amp_forward_backward_program
        elif _in_pure_fp16_guard():
            forward_backward = self._train_pure_fp16_forward_backward_program
        else:
            forward_backward = self._train_forward_backward_program
        return forward_backward[0]

    @property
    def backward_program(self):
        """
        Backward half of the current train program, or a cached empty program
        in eval mode.
        """
        if not self.training:
            # Can't just return paddle.static.Program(). This is a property, so
            # a temporary Program would be created on every access and garbage
            # collected right after an expression such as
            #     self.backward_program.desc.block(0)
            # in PartialProgramLayer.__call__. RunProgramAPI could then see an
            # invalid backward_program address, so a cached empty program is
            # returned instead.
            return self._empty_backward_program_for_eval
        if _in_amp_guard():
            forward_backward = self._train_amp_forward_backward_program
        elif _in_pure_fp16_guard():
            forward_backward = self._train_pure_fp16_forward_backward_program
        else:
            forward_backward = self._train_forward_backward_program
        return forward_backward[1]

552 553 554 555 556 557 558 559 560 561 562 563
    def _verify_program(self, main_program):
        """
        Check parameter initialization for ``main_program`` and prune the
        parameters it never references. ``main_program`` itself is returned.
        """
        # Presumably asserts that the params used by the program are
        # initialized -- see _check_params_all_inited (TODO confirm).
        self._check_params_all_inited(main_program)
        # Keep only the params referenced by some op of the program.
        self._prune_unused_params(main_program)
        return main_program

564 565 566
    def prepare_gradient_aggregation(
        self, start_idx, main_program, target_program
    ):
        """
        Insert ``sum`` ops so that gradients of outputs are aggregated rather
        than overwritten.

        Why is this needed? When a non-leaf node is also returned as an output,
        its gradient can be overwritten, for example::

            def forward(self, inp):
                x = 2 * inp  # <---- x is a non-leaf node in program.
                y = x + 3
                return x, y

            loss = forward(inp)[0].sum()
            loss.backward()  # <----- x@GRAD is overwritten by elementwise_add_grad Op

        An output is processed when it is a ``framework.Variable`` of type
        LOD_TENSOR or SELECTED_ROWS and dtype float32/float64, and it is an
        input of some op in block 0 of ``main_program``. For each such
        variable, every op in block 0 of ``target_program`` with index
        ``>= start_idx`` that outputs ``var.grad_name`` has that name renamed
        (in its inputs and outputs) to ``var.name + "@dy2static@GRAD"``. A
        ``sum`` op writing ``var.grad_name`` from both names is then inserted
        right after the last such op.

        Args:
            start_idx(int): Lowest op index searched in ``target_program``.
            main_program(Program): Program used to decide whether an output is
                read by some op.
            target_program(Program): Program modified in place.
        """

        def _need_aggregation(var):
            """Return True if ``var`` qualifies for gradient aggregation."""
            if not isinstance(var, framework.Variable):
                return False
            if var.type not in [
                core.VarDesc.VarType.LOD_TENSOR,
                core.VarDesc.VarType.SELECTED_ROWS,
            ]:
                return False
            if var.dtype not in [paddle.float32, paddle.float64]:
                return False
            return any(
                var.name in op.input_arg_names
                for op in main_program.block(0).ops
            )

        def _insert_aggregation_ops_for_var(target_program, var):
            """Redirect ``var``'s grad producers and insert the ``sum`` op."""
            var_grad_name = var.grad_name
            new_grad_name = var.name + "@dy2static" + "@GRAD"
            block = target_program.block(0)
            producers = [
                (op_idx, op)
                for op_idx, op in enumerate(block.ops)
                if op_idx >= start_idx
                and any(
                    out_arg == var_grad_name for out_arg in op.output_arg_names
                )
            ]

            # There may be no producer when stop_gradient works, and there may
            # be more than one, e.g. because of a fill_constant op.
            if not producers:
                return None

            # step1: create the new var named var.name@dy2static@GRAD
            block.create_var(
                name=new_grad_name,
                type=var.type,
                dtype=var.dtype,
                shape=var.shape,
            )
            # step2: rename var.name@GRAD to var.name@dy2static@GRAD
            for _, op in producers:
                op._rename_input(var_grad_name, new_grad_name)
                op._rename_output(var_grad_name, new_grad_name)
            # step3: var.name@GRAD = sum(var.name@GRAD, var.name@dy2static@GRAD)
            last_producer_idx = producers[-1][0]
            block._insert_op(
                last_producer_idx + 1,
                type='sum',
                inputs={'X': [var_grad_name, new_grad_name]},
                outputs={"Out": var_grad_name},
            )
            return None

        # Select every candidate before target_program is mutated.
        to_processed_vars = [
            out for out in self._outputs.tolist() if _need_aggregation(out)
        ]
        for out_var in to_processed_vars:
            _insert_aggregation_ops_for_var(target_program, out_var)

642
    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        """
        Clone ``main_program`` and append backward ops for its outputs.

        Steps:
            1. Clone the program (``for_test=False``) and apply the hooker's
               ``before_append_backward`` when a hooker is set.
            2. If any output is a ``framework.Variable``:
               - append gradient ops for those outputs;
               - record the input/param/output gradient names in
                 ``self._grad_var_names``;
               - apply the hooker's ``after_append_backward``;
               - insert gradient-aggregation ops.
            3. Record the forward end op index of the resulting program in
               ``self._forward_end_index_map``.

        Args:
            main_program(Program): The forward program. It is not modified; a
                clone is differentiated.

        Returns:
            Program: The (possibly hooked) clone, with backward ops appended
            when there were Variable outputs.
        """
        program = main_program.clone(for_test=False)
        if self._hooker:
            program = self._hooker.before_append_backward(program)

        # Flatten once: ``tolist`` re-flattens the nested outputs on each call.
        output_list = self._outputs.tolist()
        num_outputs = len(output_list)
        targets = [
            program.global_block().var(out.name)
            for out in output_list
            if isinstance(out, framework.Variable)
        ]

        # Forward op count plus one slot per output. The slots are subtracted
        # again when recording the forward end index below.
        start_idx = len(program.block(0).ops) + num_outputs
        if targets:
            with backend_guard(self._backend):
                check_type(
                    targets,
                    'targets',
                    (framework.Variable, list, tuple),
                    'paddle.static.gradients',
                )
                grad_info_map = backward.calc_gradient_helper(
                    targets=targets, inputs=[]
                )

                x_vars = [
                    program.block(0).var(var.name)
                    for var in self._inputs
                    if isinstance(var, framework.Variable)
                ]
                param_vars = [
                    program.block(0).var(param.name) for param in self._params
                ]
                out_vars = [
                    program.block(0).var(var.name)
                    for var in self._outputs
                    if isinstance(var, framework.Variable)
                ]

                self._grad_var_names = construct_grad_names(
                    grad_info_map, x_vars, param_vars, out_vars
                )

            if self._hooker:
                program, start_idx = self._hooker.after_append_backward(
                    program, start_idx
                )
            self.prepare_gradient_aggregation(
                start_idx + 1, main_program, program
            )

        self._forward_end_index_map[
            paddle.utils._hash_with_id(program, self)
        ] = start_idx - num_outputs
        return program

697 698 699
    def _prune_unused_params(self, program):
        """
        Prune the parameters not used anywhere in the program.

        The `@to_static` may only decorate a sub function, which can leave
        some parameters created in `__init__` unused. Only parameters whose
        name appears as an input or output argument of some op, in any block
        of ``program``, are kept. This avoids unnecessary operations in
        `run_program_op`.

        Args:
            program(Program): The program whose ops are scanned.

        Side effects:
            Replaces ``self._params`` with the filtered list. The original
            order and any duplicates are preserved.
        """
        # Collect every referenced argument name once. Rescanning all ops for
        # each parameter would be O(ops * params); this is O(ops + params).
        referenced_names = set()
        for block in program.blocks:
            for op in block.ops:
                referenced_names.update(op.input_arg_names)
                referenced_names.update(op.output_arg_names)

        self._params = [
            param for param in self._params if param.name in referenced_names
        ]

722 723 724 725 726 727 728 729 730 731 732
    def _cast_fp16_if_pure_fp16(self, in_vars):
        """
        Under the pure-fp16 guard, cast inputs to float16 in place.

        An input is cast when ``self.program``'s global block has a variable
        of the same name with dtype float16. The casted tensor replaces the
        entry in ``in_vars`` and keeps the original name.
        """
        if not _in_pure_fp16_guard():
            return
        for position, tensor in enumerate(in_vars):
            var_name = tensor.name
            global_block = self.program.global_block()
            if (
                global_block.has_var(var_name)
                and global_block.var(var_name).dtype == paddle.float16
            ):
                casted = tensor.astype('float16')
                casted.name = var_name
                in_vars[position] = casted
733

734
    def _prepare_attributes(self):
        """
        Build the flat ``[name, value, name, value, ...]`` attribute list
        passed to ``run_program``.

        The list always contains the forward/backward global blocks,
        ``is_test`` and ``program_id``. In training mode the gradient names
        are added, and the CUDA-graph settings are added when a capture mode
        is set.
        """
        # Same evaluation order as the original literal list.
        forward_block = self.forward_program.desc.block(0)
        backward_block = self.backward_program.desc.block(0)
        is_test = not self.training
        current_program_id = self.program_id
        attrs = [
            'forward_global_block',
            forward_block,
            'backward_global_block',
            backward_block,
            'is_test',
            is_test,
            'program_id',
            current_program_id,
        ]

        if self.training:
            # NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
            # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
            # the correct names of the parameter grads from program. And out grads are similar to above.
            grad_names = self._grad_var_names
            attrs += [
                'param_grad_names',
                grad_names.get('param', []),
                'out_grad_names',
                grad_names.get('out', []),
                'x_grad_names',
                grad_names.get('x', []),
            ]

        if self._cuda_graph_capture_mode:
            attrs += [
                'cuda_graph_capture_mode',
                self._cuda_graph_capture_mode,
                'cuda_graph_pool_id',
                self._cuda_graph_pool_id,
            ]
        return attrs
770

771 772 773 774 775 776 777 778 779 780 781 782
    @switch_to_static_graph
    def _build_infer_program(self, infer_program, forward_end_op_index):
        """
        Apply the build strategy to ops ``[0, forward_end_op_index)`` of
        ``infer_program``, then run the inplace pass on the result.

        Returns:
            Program: The built inference program.
        """
        skip_gc_vars = self._parse_skip_gc_vars(infer_program)
        built_program = add_build_strategy_for(
            infer_program,
            0,
            forward_end_op_index,
            self._build_strategy,
            skip_gc_vars,
        )
        self._apply_inplace_pass(built_program, None)
        return built_program
783

784
    @switch_to_static_graph
    def _get_forward_backward_program_form(
        self, whole_program, forward_end_op_index
    ):
        """
        Split ``whole_program`` into separately built forward and backward
        programs.

        Args:
            whole_program(Program): Program containing forward and backward ops.
            forward_end_op_index(int): End (exclusive) of the forward ops.

        Returns:
            list[Program]: ``[forward_program, backward_program]``.
        """
        # NOTE(dev): The build strategy is applied to the backward part first
        # to avoid skipping more gc variables.
        # The backward part starts one op per tensor output (as counted by
        # ``self._outputs.var_ids``) after the forward end.
        backward_begin = forward_end_op_index + len(self._outputs.var_ids)
        backward_end = whole_program.desc.block(0).op_size()

        # For Backward process in CINN, all param@GRAD should be skipped for
        # GC, because they will be shared in scope and used by optimizer.
        skip_from_program = self._parse_skip_gc_vars(whole_program)
        backward_skip_vars = skip_from_program + self._grad_var_names.get(
            'param', []
        )
        backward_program = add_build_strategy_for(
            whole_program,
            backward_begin,
            backward_end,
            self._build_strategy,
            backward_skip_vars,
        )

        forward_skip_vars = self._parse_skip_gc_vars(
            whole_program, backward_program
        )
        forward_program = add_build_strategy_for(
            whole_program,
            0,
            forward_end_op_index,
            self._build_strategy,
            forward_skip_vars,
        )

        self._apply_inplace_pass(forward_program, backward_program)
        return [forward_program, backward_program]

    def _apply_inplace_pass(self, forward_program, backward_program):
        attr_types = {
            "use_cuda": "bool",
            "mem_opt_skip_vars": "list[str]",
827
            "for_partial_block": "bool",
828 829 830 831
        }
        empty_startup_program = paddle.static.Program()
        use_cuda = True if core.is_compiled_with_cuda() else False
        # skip data var
832 833 834 835
        forward_mem_opt_skip_vars = self._parse_skip_gc_vars(
            forward_program, backward_program
        )
        backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
836 837 838 839 840 841
        if forward_program:
            attrs = {
                "use_cuda": use_cuda,
                "mem_opt_skip_vars": forward_mem_opt_skip_vars,
                "for_partial_block": True,
            }
842 843 844 845 846 847 848 849
            if not os.getenv("FLAGS_enable_new_ir_in_executor"):
                _apply_pass(
                    forward_program,
                    empty_startup_program,
                    "buffer_shared_inplace_pass",
                    attrs,
                    attr_types,
                )
850 851 852 853 854 855
        if backward_program:
            attrs = {
                "use_cuda": use_cuda,
                "mem_opt_skip_vars": backward_mem_opt_skip_vars,
                "for_partial_block": True,
            }
856 857 858 859 860 861 862 863
            if not os.getenv("FLAGS_enable_new_ir_in_executor"):
                _apply_pass(
                    backward_program,
                    empty_startup_program,
                    "buffer_shared_inplace_pass",
                    attrs,
                    attr_types,
                )
    @LazyInitialized
    def _inout_var_names(self):
        """
        Names of the static ``Variable`` entries of ``self._inputs``,
        followed by those of ``self._outputs``. Entries that are not
        ``Variable`` instances are ignored.
        """
        variable_type = paddle.fluid.framework.Variable
        return [
            item.desc.name()
            for collection in (self._inputs, self._outputs)
            for item in collection
            if isinstance(item, variable_type)
        ]

    def _parse_skip_gc_vars(self, program, backward_program=None):
        """
        Parse variables that need to skip GC after execute it.
        If specify backward_program, it will keep the variables used in backward.
        """
        # skip data var, DO NOT ignore this deepcopy
        skip_vars = deepcopy(self._inout_var_names)
        for var_name, var in program.global_block().vars.items():
            if var.is_data:
                skip_vars.append(var_name)

        if backward_program:
            for var_name in core.parse_safe_eager_deletion_skip_vars(
892
                backward_program.desc, True
893 894 895 896
            ):
                skip_vars.append(var_name)
        return skip_vars

    def _prepare(self, inputs):
        """
        Prepare the eager input and output Tensors for running the program.

        Args:
            inputs (tuple|list): the call's inputs, possibly nested.

        Returns:
            tuple: ``(input_vars, out_vars, input_var_names)``.
            ``input_vars`` holds one eager Tensor for every numpy array or
            eager Tensor leaf of the flattened ``inputs`` (other leaves are
            skipped), and ``input_var_names`` holds the matching variable
            names taken from ``self._inputs``. ``out_vars`` holds one Tensor
            per entry of ``self._outputs.var_ids``; outputs whose variables
            share a name share the same Tensor.
        """
        assert isinstance(inputs, (tuple, list))
        # Flatten inputs with nested structure into a single list.
        flatten_inputs = paddle.utils.flatten(inputs)
        # Convert each fed leaf into an eager Tensor on the expected place.
        input_vars = []
        input_var_names = []
        expected_place = framework._current_expected_place()
        for i, value in enumerate(flatten_inputs):
            if not isinstance(value, (np.ndarray, core.eager.Tensor)):
                # Only numpy arrays and eager Tensors are fed to the program.
                continue
            var_name = self._inputs[i].desc.name()
            if isinstance(value, np.ndarray):
                var = core.eager.Tensor(
                    value=value,
                    name=var_name,
                    persistable=False,
                    place=expected_place,
                    zero_copy=True,
                )
            elif value.stop_gradient and not value.place._equals(
                expected_place
            ):
                # NOTE(Aurelius84): An input on CPUPlace would be transferred
                # to CUDAPlace once for every op that consumes it, so move it
                # to the expected place in advance.
                var = value._copy_to(expected_place, False)
                var.stop_gradient = True
            else:
                var = value
            input_var_names.append(var_name)
            input_vars.append(var)

        # Maps variable name -> output Tensor, so outputs that share a
        # variable name also share one Tensor.
        out_tensor_map = {}

        def create_out(var_id):
            var = self._outputs[var_id]
            assert isinstance(var, framework.Variable)
            var_desc = var.desc
            out_name = var_desc.name()
            if out_name in out_tensor_map:
                return out_tensor_map[out_name]

            out = core.eager.Tensor(
                var_desc.dtype(),
                var_desc.shape(),
                out_name,
                var_desc.type(),
                False,
            )
            out.stop_gradient = var.stop_gradient
            out_tensor_map[out_name] = out
            return out

        # Create Tensors to receive the output data.
        out_vars = [create_out(var_id) for var_id in self._outputs.var_ids]

        return input_vars, out_vars, input_var_names
    def _create_scope_vec(self, program_id=None, use_scope_cache=False):
962
        # Hold forward variables
J
Jiabin Yang 已提交
963
        tmp_scope_vec = None
964 965 966
        inner_scope = self._get_scope(
            program_id=program_id, use_scope_cache=use_scope_cache
        )
W
wanghuancoder 已提交
967
        tmp_scope_vec = [inner_scope]
968
        return tmp_scope_vec
    def _create_cuda_graph_vec(self):
        """
        Create a stop-gradient Tensor named ``"cuda_graph"`` with FP32 dtype,
        an empty shape and the RAW variable type.
        """
        placeholder = core.eager.Tensor(
            core.VarDesc.VarType.FP32,
            [],
            "cuda_graph",
            core.VarDesc.VarType.RAW,
            True,
        )
        placeholder.stop_gradient = True
        return placeholder

    def _update_stop_gradient(self, out_vars):
        """
        Copy ``stop_gradient`` from each static output Variable onto the eager
        Tensor at the same position of ``out_vars``.

        ``out_vars`` is paired positionally with ``self._outputs.var_ids``.
        """
        for var_id, eager_tensor in zip(self._outputs.var_ids, out_vars):
            static_var = self._outputs[var_id]
            assert isinstance(static_var, framework.Variable)
            eager_tensor.stop_gradient = static_var.stop_gradient

    def _restore_out(self, out_vars):
        """
W
wanghuancoder 已提交
994
        Restores same nested outputs by only replacing the Variable with Tensor.
995 996 997 998 999 1000
        """

        flatten_outputs = self._outputs.tolist()
        for i, idx in enumerate(self._outputs.var_ids):
            flatten_outputs[idx] = out_vars[i]
        outs = self._outputs.restore(flatten_outputs)
1001
        if outs is not None and len(outs) == 1:
1002 1003 1004 1005
            outs = outs[0]

        return outs

    @switch_to_static_graph
    def _clone_for_test(self, main_program):
        """Return ``main_program.clone(for_test=True)``."""
        test_program = main_program.clone(for_test=True)
        return test_program

    def _is_no_value(self, var):
        """
        Tell whether ``var`` is a "no value" marker. A marker is an eager
        Tensor of shape ``[1]`` whose single element equals
        ``RETURN_NO_VALUE_MAGIC_NUM``.
        """
        if not isinstance(var, core.eager.Tensor) or var.shape != [1]:
            return False
        # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
        return bool(var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM)

    def _remove_no_value(self, out_vars):
        """
        Drop the "no value" markers produced by return statements of varying
        length (see ``_is_no_value``).

        - A single Tensor that is a marker becomes ``None``. Any other single
          Tensor is returned unchanged.
        - For a tuple or list, markers are filtered out and the result is a
          plain tuple or list, matching the input. If filtering removed
          something and nothing is left, ``None`` is returned. If exactly one
          item is left, that item is returned on its own.
        - Anything else is returned unchanged.
        """
        if isinstance(out_vars, core.eager.Tensor):
            return None if self._is_no_value(out_vars) else out_vars
        if not isinstance(out_vars, (tuple, list)):
            return out_vars

        kept = [item for item in out_vars if not self._is_no_value(item)]
        if isinstance(out_vars, tuple):
            kept = tuple(kept)

        # Only collapse the result when something was actually removed, so an
        # originally empty or single-element container is returned as is.
        if len(kept) < len(out_vars):
            if not kept:
                return None
            if len(kept) == 1:
                return kept[0]
        return kept

    def _set_grad_type(self, params, train_program):
        """
        Make each parameter's gradient type match the type of its gradient
        variable (``param.name + core.grad_var_suffix()``) in block 0 of
        ``train_program``.

        NOTE: With sparse gradients, a parameter's gradient is SelectedRows,
        not LoDTensor, but the tracer sets the param grad Tensor from the
        forward Tensor (LoDTensor). Syncing the type here avoids forcing
        RunProgramOp to convert SelectedRows to LoDTensor, which may not be
        what the user wants.
        """
        for param in params:
            grad_desc = train_program.desc.block(0).find_var(
                (param.name + core.grad_var_suffix()).encode()
            )
            # A missing grad var desc is not necessarily a problem, e.g. for
            # batch_norm.
            if grad_desc is not None:
                param._set_grad_type(grad_desc.type())

    def _remove_op_call_stack(self, main_program):
        """
        Strip the ``op_callstack`` attribute from every op of ``main_program``.

        The recorded Python call stacks carry low-level frames from the
        program transformations that only confuse users in error messages.
        The program is modified in place and also returned.
        """
        assert isinstance(main_program, framework.Program)
        all_ops = (op for block in main_program.blocks for op in block.ops)
        for op in all_ops:
            if op.has_attr("op_callstack"):
                op._remove_attr("op_callstack")
        return main_program

    def _check_params_all_inited(self, main_program):
        """
        Check all params from main program are already initialized, see details as follows:
W
wanghuancoder 已提交
1076
            1. all parameters in self._params should be type `framework.EagerParamBase` which are created in dygraph.
1077
            2. all parameters from transformed program can be found in self._params.
W
wanghuancoder 已提交
1078
               Because they share same data with EagerParamBase of original dygraph.
1079 1080 1081 1082
        """
        if not isinstance(self._params, (list, tuple)):
            raise TypeError(
                "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
1083 1084
                % type(self._params)
            )
1085

1086 1087 1088
        param_and_buffer_names_set = set()
        for i, var in enumerate(self._params):
            # self._params constains parameters and buffers with persistable=True.
W
wanghuancoder 已提交
1089
            if not isinstance(var, core.eager.Tensor):
1090
                raise TypeError(
1091 1092 1093 1094
                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.format(
                        i, type(var)
                    )
                )
1095
            param_and_buffer_names_set.add(var.name)
1096 1097

        for block in main_program.blocks:
1098
            for name, var in block.vars.items():
1099
                if isinstance(var, framework.Parameter):
1100
                    if name not in param_and_buffer_names_set:
1101
                        raise ValueError(
1102 1103 1104 1105 1106 1107
                            "\n\tWe don't support to define layer with parameters in the function decorated by `@to_static`."
                            "\n\tBut we found parameter(%s) was created in the decorated function."
                            "\n"
                            "\n\tRevise suggestion: "
                            "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
                            "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
1108 1109
                            % name
                        )
    def _valid_vars(self, vars):
1112
        return vars if vars else None


def partial_program_from(concrete_program, from_method=False):
    """
    Build a ``PartialProgramLayer`` from a concrete program.

    Args:
        concrete_program: supplies ``main_program``, ``inputs``, ``outputs``,
            ``name_generator``, ``parameters`` and extra ``kwargs``, all
            forwarded to ``PartialProgramLayer``.
        from_method (bool): when True and the inputs are non-empty, the
            first input (the method's ``self`` argument) is dropped.

    Returns:
        PartialProgramLayer: the newly created layer.
    """
    layer_inputs = concrete_program.inputs
    # NOTE(SigureMo): Remove the first arg `self` from method args.
    if layer_inputs and from_method:
        layer_inputs = layer_inputs[1:]

    return PartialProgramLayer(
        concrete_program.main_program,
        layer_inputs,
        concrete_program.outputs,
        concrete_program.name_generator,
        concrete_program.parameters,
        **concrete_program.kwargs
    )


@switch_to_static_graph
def add_build_strategy_for(
    program, start_op_index, end_op_index, build_strategy=None, skip_vars=None
):
    """
    Compile ops ``[start_op_index, end_op_index)`` of ``program`` with
    ``build_strategy`` and convert the compiled graph back into a Program.

    Args:
        program: source program.
        start_op_index (int): index of the first op to include.
        end_op_index (int): exclusive end op index.
        build_strategy (BuildStrategy, optional): passed to CompiledProgram.
        skip_vars (list[str], optional): when non-empty, stored on the graph
            as ``skip_gc_vars``.

    Returns:
        Program: the built program. For an empty op range it is a new
        Program whose block 0 only holds clones of the variables in
        ``program``'s block 0.
    """
    if start_op_index >= end_op_index:
        # An empty op range can't yield a blank Program: the var descs must
        # still be copied over.
        built_program = paddle.static.Program()
        for var in program.block(0).vars.values():
            built_program.block(0)._clone_variable(var, False)
        return built_program

    compiled_program = paddle.static.CompiledProgram(
        core.Graph(program.desc, start_op_index, end_op_index),
        build_strategy=build_strategy,
    )
    if skip_vars:
        # TODO(Aurelius84): Need to unify name with C++, such as kSkipVarNames.
        compiled_program._graph.set("skip_gc_vars", set(skip_vars))
    compiled_program._compile(
        core.Scope(), framework._current_expected_place()
    )
    built_program = framework.IrGraph(compiled_program._graph).to_program()
    # NOTE(review): compiled_program was built from a core.Graph, not a
    # Program; confirm its ``_program`` can carry an ``lr_scheduler``,
    # otherwise this copy never happens.
    if hasattr(compiled_program._program, 'lr_scheduler'):
        built_program.lr_scheduler = compiled_program._program.lr_scheduler
    return built_program