partial_program.py 37.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
16

17
import paddle
18
from paddle.fluid import framework, backward, core, program_guard
19
from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
20
from paddle.fluid.dygraph import layers
21
from paddle.fluid.dygraph.base import switch_to_static_graph
22
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
23
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
24 25
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
26 27
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
28
from paddle.fluid.framework import _apply_pass
29
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
30 31
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
32
from paddle import _legacy_C_ops
33

34 35 36 37 38 39 40 41 42

class NestSequence(object):
    """
    A wrapper that flattens a (possibly nested) sequence into a single list
    and can later rebuild the original nested structure from a flat list of
    values.

    Args:
        raw_input: the nested sequence to wrap.
        need_check(bool): if True, log a warning when the flattened input
            contains entries that are not tensors. Default False.
    """

    def __init__(self, raw_input, need_check=False):
        self.__raw_input = raw_input
        self.__input_list = self.tolist()
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    @staticmethod
    def _is_tensor(value):
        # Static-graph Variables as well as legacy/eager dygraph tensors.
        return isinstance(
            value, (framework.Variable, core.VarBase, core.eager.Tensor))

    def tolist(self):
        """
        Return the wrapped input flattened into a single list.
        """
        return flatten(self.__raw_input)

    def restore(self, value_list):
        """
        Rebuild the nested structure of the wrapped input, filling it with
        the items of ``value_list`` (which must have one item per flattened
        entry).
        """
        assert len(self.__input_list) == len(value_list)
        return pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        # Positions (in the flattened list) of the entries that are tensors.
        return [
            position for position, item in enumerate(self.__input_list)
            if self._is_tensor(item)
        ]

    def _check_non_variable(self, need_check):
        """
        Warn once if the flattened input contains non-tensor values.
        """
        if not need_check:
            return
        warning_types = {
            type(item)
            for item in self.__input_list if not self._is_tensor(item)
        }
        if not warning_types:
            return
        logging_utils.warn(
            "Output of traced function contains non-tensor type values: {}. "
            "Currently, We don't support to update them while training and will return "
            "what we first saw. Please try to return them as tensor.".format(
                list(warning_types)))

    @property
    def var_ids(self):
        """Indices of the tensor entries in the flattened input."""
        return self.__var_ids

    def __getitem__(self, item):
        return self.__input_list[item]
93

94

95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
class LazyInitialized(object):
    """
    Descriptor to implement lazy initialization of a property.

    On first access through an instance, the wrapped function is called and
    its result is stored on the instance under the attribute name this
    descriptor is bound to in its owner class. Later lookups then find the
    instance attribute first (this is a non-data descriptor), so the
    function runs only once per instance.

    Accessing the descriptor through the class returns the descriptor itself.
    """

    def __init__(self, function):
        self.function = function
        # Attribute name in the owner class. For private ``__name`` methods
        # the class stores the descriptor under the mangled ``_Owner__name``,
        # which differs from ``function.__name__``. Caching under
        # ``function.__name__`` would therefore never be hit.
        self._attr_name = None

    def __set_name__(self, owner, name):
        # Invoked automatically by Python 3.6+ at class creation time.
        self._attr_name = name

    def _resolve_attr_name(self, cls):
        """Return the name this descriptor is stored under in ``cls``."""
        if self._attr_name is None:
            # Fallback when ``__set_name__`` was not called: find the
            # descriptor in the owner's MRO.
            for klass in cls.__mro__:
                for name, value in vars(klass).items():
                    if value is self:
                        self._attr_name = name
                        return name
            return self.function.__name__
        return self._attr_name

    def __get__(self, instance, cls):
        if instance is None:
            return self
        val = self.function(instance)
        setattr(instance, self._resolve_attr_name(type(instance)), val)
        return val


def _change_is_test_status(program, is_test):
    # change all `is_test` attributes
    for block in program.blocks:
        for op in block.ops:
            if op.has_attr('is_test'):
                op._set_attr('is_test', is_test)
    return program


118
class PartialProgramLayer:
119 120 121 122 123
    """
    PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
    and executes them as a static subgraph.

    .. note::
        1. This is a very low level API. Users should not use this API
           directly. Please use `partial_program_from(concrete_program)`
           to create it.
        2. LoDTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains ops need to be executed.
        inputs(list[Variable]): The input list of the decorated function by `@declarative`.
        outputs(list[Variable]): The output list of the decorated function by `@declarative`.
        parameters(list[VarBase]|None): All trainable parameters included in the program. Default None.

    Returns:
        Layer: A Layer object that run all ops internally in static mode.
    """

139 140 141 142 143
    def __init__(self,
                 main_program,
                 inputs,
                 outputs,
                 parameters=None,
                 **kwargs):
        """
        Args:
            main_program(Program): forward program to wrap; it is checked and
                used to prune ``parameters`` by ``_verify_program``.
            inputs(list): nested inputs of the decorated function.
            outputs(list): nested outputs of the decorated function; a warning
                is logged if they contain non-tensor values.
            parameters(list|None): trainable parameters. Default None means
                no parameters.
            **kwargs: only ``build_strategy`` is read. It must be a
                ``BuildStrategy``; a fresh ``BuildStrategy()`` is used when
                it is absent.
        """
        super(PartialProgramLayer, self).__init__()
        self._inputs = NestSequence(inputs)
        self._outputs = NestSequence(outputs, need_check=True)
        self._params = [] if parameters is None else parameters

        self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
        assert isinstance(self._build_strategy, BuildStrategy)

        # _verify_program prunes self._params, so it must run after it is set.
        self._origin_main_program = self._verify_program(main_program)

        self._cuda_graph_vec = self._create_cuda_graph_vec()
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0

        # The layer starts in training mode.
        self.training = True

        # AMP op lists are taken from the active dygraph tracer, if any.
        tracer = framework._dygraph_tracer()
        white_list, black_list = (tracer._get_amp_op_list()
                                  if tracer else (None, None))
        self._amp_list = AutoMixedPrecisionLists(
            custom_white_list=white_list, custom_black_list=black_list)

        # program_id -> list of core.Scope (see _get_scope).
        self._scope_cache = {}

    def _get_scope(self, program_id=None, use_scope_cache=False):
        """
        Return a ``core.Scope`` to run the program in.

        Without ``use_scope_cache`` a new scope is created on every call.
        With it, scopes are pooled per ``program_id`` in ``self._scope_cache``.
        The first pooled scope whose ``_can_reuesd`` flag is truthy is
        returned; otherwise a new scope is created and added to the pool.
        """
        if not use_scope_cache:
            return core.Scope()
        # NOTE(review): the `_can_reuesd` spelling is kept as-is; presumably
        # it matches the core.Scope binding — confirm before renaming.
        pool = self._scope_cache.setdefault(program_id, [])
        for candidate in pool:
            if candidate._can_reuesd:
                return candidate
        fresh_scope = core.Scope()
        pool.append(fresh_scope)
        return fresh_scope

188 189 190 191 192 193 194 195
    @LazyInitialized
    def __fake_vars(self):
        """Lazily built result of ``_create_fake_var()``."""
        fake_vars = _create_fake_var()
        return fake_vars

    @LazyInitialized
    def _double_grads(self):
        """Tensors for the ``@GRAD`` vars of the original main program."""
        origin_program = self._origin_main_program
        return self._get_double_grads(origin_program)

196 197 198 199 200 201 202 203 204 205 206
    # fp32 program: forward only (infer) or forward + backward (train).
    @switch_to_static_graph
    def _create_program(self, is_infer_mode=False):
        """
        Build the fp32 program from the original main program.

        In infer mode this is a clone made with ``for_test``. Otherwise the
        original program is cloned with backward ops appended
        (``_append_backward_desc``) and the parameters' grad types are set.
        """
        if not is_infer_mode:
            train_program = self._append_backward_desc(
                self._origin_main_program)
            # Note: Only set grad type once after initializing train program. So we put it here.
            self._set_grad_type(self._params, train_program)
            return train_program
        return self._origin_main_program.clone(for_test=is_infer_mode)
207

208 209 210 211 212 213 214 215 216 217 218
    @switch_to_static_graph
    def _create_amp_program(self, is_infer_mode=False):
        """
        Build the AMP program: a clone of the original program rewritten by
        ``rewrite_program`` with ``self._amp_list``. In train mode, backward
        ops are appended and parameter grad types are set.
        """
        rewritten = self._origin_main_program.clone(for_test=is_infer_mode)
        with program_guard(rewritten):
            rewrite_program(rewritten, self._amp_list)
        if is_infer_mode:
            return rewritten
        with_backward = self._append_backward_desc(rewritten)
        self._set_grad_type(self._params, with_backward)
        return with_backward
219

220 221 222 223 224 225 226 227 228 229 230 231 232 233 234
    @switch_to_static_graph
    def _create_pure_fp16_program(self, is_infer_mode=False):
        """
        Build the pure-fp16 program: a clone of the original program cast by
        ``cast_model_to_fp16`` (with ``use_fp16_guard=False``). In train mode,
        backward ops are appended and parameter grad types are set.
        """
        fp16_program = self._origin_main_program.clone(for_test=is_infer_mode)
        with program_guard(fp16_program):
            cast_model_to_fp16(fp16_program,
                               self._amp_list,
                               use_fp16_guard=False)
        if is_infer_mode:
            return fp16_program
        fp16_train_program = self._append_backward_desc(fp16_program)
        self._set_grad_type(self._params, fp16_train_program)
        return fp16_train_program
235

236
    @switch_to_static_graph
    def _create_forward_backward_train_program(self):
        """Split a fresh fp32 train program into [forward, backward] programs."""
        full_program = self._create_program()
        # The forward part ends where the fp32 infer program's ops end.
        split_at = self._infer_program.desc.block(0).op_size()
        return self._get_forward_backward_program_form(full_program, split_at)
242

243 244 245 246 247 248 249 250 251 252 253 254 255 256
    @switch_to_static_graph
    def _create_forward_backward_train_amp_program(self):
        """Split a fresh AMP train program into [forward, backward] programs."""
        full_program = self._create_amp_program()
        # The forward part ends where the AMP infer program's ops end.
        split_at = self._infer_amp_program.desc.block(0).op_size()
        return self._get_forward_backward_program_form(full_program, split_at)

    @switch_to_static_graph
    def _create_forward_backward_train_pure_fp16_program(self):
        """Split a fresh pure-fp16 train program into [forward, backward]."""
        full_program = self._create_pure_fp16_program()
        # The forward part ends where the pure-fp16 infer program's ops end.
        split_at = self._infer_pure_fp16_program.desc.block(0).op_size()
        return self._get_forward_backward_program_form(full_program, split_at)
257 258

    @LazyInitialized
    def _train_program(self):
        """fp32 forward + backward program, built on first access."""
        program = self._create_program(is_infer_mode=False)
        return program
261

262
    @LazyInitialized
    def _infer_program(self):
        """fp32 test-mode clone of the original program, built on first access."""
        program = self._create_program(True)
        return program
265

266 267 268 269 270 271 272
    @LazyInitialized
    def _train_amp_program(self):
        """AMP forward + backward program, built on first access."""
        program = self._create_amp_program(is_infer_mode=False)
        return program

    @LazyInitialized
    def _infer_amp_program(self):
        """AMP test-mode program, built on first access."""
        program = self._create_amp_program(True)
        return program
273 274 275

    @LazyInitialized
    def _train_pure_fp16_program(self):
        """Pure-fp16 forward + backward program, built on first access."""
        program = self._create_pure_fp16_program(is_infer_mode=False)
        return program
277

278
    @LazyInitialized
    def _infer_pure_fp16_program(self):
        """Pure-fp16 test-mode program, built on first access."""
        program = self._create_pure_fp16_program(True)
        return program
281

282
    @LazyInitialized
    def _train_forward_backward_program(self):
        """[forward, backward] pair for fp32 training, built on first access."""
        return self._create_forward_backward_train_program()
286 287

    @LazyInitialized
    def _train_amp_forward_backward_program(self):
        """[forward, backward] pair for AMP training, built on first access."""
        return self._create_forward_backward_train_amp_program()

    @LazyInitialized
    def _train_pure_fp16_forward_backward_program(self):
        """[forward, backward] pair for pure-fp16 training, built on first access."""
        return self._create_forward_backward_train_pure_fp16_program()

    @property
    def whole_program(self):
        """
        Complete program for the current precision mode (AMP, pure fp16 or
        fp32): the train variant when ``self.training`` is set, otherwise the
        infer variant.
        """
        if _in_amp_guard():
            return (self._train_amp_program
                    if self.training else self._infer_amp_program)
        if _in_pure_fp16_guard():
            return (self._train_pure_fp16_program
                    if self.training else self._infer_pure_fp16_program)
        return self._train_program if self.training else self._infer_program

    @property
    def forward_program(self):
        """
        Forward program for the current precision mode.

        In training, this is element 0 of the matching [forward, backward]
        pair. Otherwise it is the matching infer program.
        """
        if not self.training:
            return self.infer_program
        if _in_amp_guard():
            fwd_bwd = self._train_amp_forward_backward_program
        elif _in_pure_fp16_guard():
            fwd_bwd = self._train_pure_fp16_forward_backward_program
        else:
            fwd_bwd = self._train_forward_backward_program
        return fwd_bwd[0]

    @property
    def backward_program(self):
        """
        Backward program for the current precision mode.

        In training, this is element 1 of the matching [forward, backward]
        pair. Outside training, a new empty ``paddle.static.Program()`` is
        returned.
        """
        if not self.training:
            return paddle.static.Program()
        if _in_amp_guard():
            fwd_bwd = self._train_amp_forward_backward_program
        elif _in_pure_fp16_guard():
            fwd_bwd = self._train_pure_fp16_forward_backward_program
        else:
            fwd_bwd = self._train_forward_backward_program
        return fwd_bwd[1]
348

349 350
    @LazyInitialized
    def _train_program_id(self):
        """
        Hash id of ``_train_program``. The id is also passed, together with
        ``self._build_strategy``, to ``core._set_cached_executor_build_strategy``.
        """
        train_id = _hash_with_id(self._train_program, self)
        core._set_cached_executor_build_strategy(train_id,
                                                 self._build_strategy)
        return train_id
355

356 357 358 359
    @LazyInitialized
    def _infer_program_id(self):
        """Hash id of ``_infer_program``."""
        infer_id = _hash_with_id(self._infer_program, self)
        return infer_id

360 361 362 363 364 365 366
    @LazyInitialized
    def _train_amp_program_id(self):
        """
        Hash id of ``_train_amp_program``. The id is also passed, together
        with ``self._build_strategy``, to
        ``core._set_cached_executor_build_strategy``.
        """
        amp_train_id = _hash_with_id(self._train_amp_program, self)
        core._set_cached_executor_build_strategy(amp_train_id,
                                                 self._build_strategy)
        return amp_train_id

367 368 369 370
    @LazyInitialized
    def _infer_amp_program_id(self):
        """Hash id of ``_infer_amp_program``."""
        amp_infer_id = _hash_with_id(self._infer_amp_program, self)
        return amp_infer_id

371 372 373 374 375 376 377
    @LazyInitialized
    def _train_pure_fp16_program_id(self):
        """
        Hash id of ``_train_pure_fp16_program``. The id is also passed,
        together with ``self._build_strategy``, to
        ``core._set_cached_executor_build_strategy``.
        """
        fp16_train_id = _hash_with_id(self._train_pure_fp16_program, self)
        core._set_cached_executor_build_strategy(fp16_train_id,
                                                 self._build_strategy)
        return fp16_train_id

378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398
    @LazyInitialized
    def _infer_pure_fp16_program_id(self):
        """Hash id of ``_infer_pure_fp16_program``."""
        fp16_infer_id = _hash_with_id(self._infer_pure_fp16_program, self)
        return fp16_infer_id

    @property
    def whole_program_id(self):
        """
        Id of the program ``whole_program`` selects: same precision-mode and
        train/infer dispatch.
        """
        if _in_amp_guard():
            return (self._train_amp_program_id
                    if self.training else self._infer_amp_program_id)
        if _in_pure_fp16_guard():
            return (self._train_pure_fp16_program_id
                    if self.training else self._infer_pure_fp16_program_id)
        return (self._train_program_id
                if self.training else self._infer_program_id)

399 400 401 402 403 404 405 406 407 408 409 410
    def _verify_program(self, main_program):
        """
        Check that the parameters used by ``main_program`` are initialized
        and drop the entries of ``self._params`` that no op references.

        Returns:
            Program: ``main_program`` itself (it is not copied).
        """
        self._check_params_all_inited(main_program)
        self._prune_unused_params(main_program)
        verified_program = main_program
        return verified_program

411 412
    def prepare_gradient_aggregation(self, start_idx, main_program,
                                     target_program):
        """
        Why do we need to add gradient aggregation operations?
        In some cases, if non-leaf nodes are used as outputs, gradient overwriting will occur, such as
        def forward(self, in):
            x = 2 * in  # <---- x is a non-leaf node in program.
            y = x + 3
            return x, y

        loss = forward(in)[0].sum()
        loss.backward()  # <----- x@grad will be overwritten by elementwise_add_grad Op

        For every such output, the ops of ``target_program`` block 0 at index
        >= ``start_idx`` that write its gradient are redirected to a separate
        ``<name>@dy2static@GRAD`` var. A ``sum`` op inserted right after the
        last of them adds that var back into the original gradient var.

        Args:
            start_idx(int): first op index of ``target_program`` block 0 to scan.
            main_program(Program): program used to decide whether an output
                is read by any op.
            target_program(Program): program that is modified in place.
        """

        def _need_aggregation(var):
            """
            True if ``var`` is a float32/float64 LOD_TENSOR or SELECTED_ROWS
            Variable that some op in ``main_program`` block 0 reads.
            """
            if not isinstance(var, framework.Variable):
                return False
            if var.type not in (core.VarDesc.VarType.LOD_TENSOR,
                                core.VarDesc.VarType.SELECTED_ROWS):
                return False
            if var.dtype not in (paddle.float32, paddle.float64):
                return False
            return any(var.name in op.input_arg_names
                       for op in main_program.block(0).ops)

        def _insert_aggregation_ops_for_var(target_program, var):
            block = target_program.block(0)
            var_grad_name = var.grad_name
            new_grad_name = var.name + "@dy2static" + "@GRAD"
            # Ops (with their positions) from start_idx on that write var's grad.
            writers = [(position, op) for position, op in enumerate(block.ops)
                       if position >= start_idx
                       and var_grad_name in op.output_arg_names]

            # There may be no writer when stop_gradient works, and more than
            # one because we may have a fill_constant op.
            if not writers:
                return None
            # step1: create the var var.name@dy2static@GRAD
            block.create_var(name=new_grad_name,
                             type=var.type,
                             dtype=var.dtype,
                             shape=var.shape)
            # step2: rename var.name@GRAD to var.name@dy2static@GRAD in those ops
            for _, op in writers:
                op._rename_input(var_grad_name, new_grad_name)
                op._rename_output(var_grad_name, new_grad_name)
            # step3: insert a sum op to aggregate the gradient:
            #        var.name@GRAD = sum(var.name@GRAD, var.name@dy2static@GRAD)
            block._insert_op(writers[-1][0] + 1,
                             type='sum',
                             inputs={'X': [var_grad_name, new_grad_name]},
                             outputs={"Out": var_grad_name})
            return None

        outputs_to_fix = [
            out for out in self._outputs.tolist() if _need_aggregation(out)
        ]
        for out in outputs_to_fix:
            _insert_aggregation_ops_for_var(target_program, out)

480
    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        """
        Clone ``main_program`` for training and append its backward ops.

        Args:
            main_program(Program): forward program; it is cloned, not modified.

        Returns:
            Program: the clone, with every ``is_test`` attribute set to False.
            When there are parameters, backward ops for the Variable outputs
            are appended. Gradient-aggregation ``sum`` ops are then inserted
            by ``prepare_gradient_aggregation``.
        """
        # make sure all status of is_test are False in train mode.
        program = _change_is_test_status(main_program.clone(), is_test=False)
        targets = [
            program.global_block().var(out.name)
            for out in self._outputs.tolist()
            if isinstance(out, framework.Variable)
        ]

        if targets and self._params:
            backward.gradients(targets=targets, inputs=[])

        # First op index considered for gradient aggregation: the forward ops
        # plus 2 ops per backward target (the same `2 *` offset that
        # `_get_forward_backward_program_form` applies to tensor outputs).
        # Only Variable outputs become targets, so count `targets` rather than
        # every flattened output. Counting non-tensor outputs pushed start_idx
        # past real grad ops and hid them from aggregation.
        start_idx = len(main_program.block(0).ops) + 2 * len(targets)

        self.prepare_gradient_aggregation(start_idx, main_program, program)

        return program

499 500 501 502 503 504 505 506 507 508
    def _prune_unused_params(self, program):
        """
        Prune the parameters not used anywhere in the program.

        The `@declarative` decorator may wrap only a sub-function, so some
        parameters created in `__init__` are never referenced. A parameter is
        kept if its name appears among the input or output args of any op in
        any block. Pruning the rest avoids unnecessary work in
        `run_program_op`.
        """

        def _is_referenced(param_name):
            return any(param_name in op.input_arg_names
                       or param_name in op.output_arg_names
                       for block in program.blocks for op in block.ops)

        self._params = [
            param for param in self._params if _is_referenced(param.name)
        ]

521 522 523 524 525 526
    def _get_double_grads(self, program):
        """
        Create a tensor for every var whose name contains ``@GRAD`` in any
        block of ``program``, and pass the list through ``self._valid_vars``.

        Tensors are ``core.eager.Tensor`` in eager mode, ``core.VarBase``
        otherwise. They are built from the var desc's dtype, shape, name and
        type, with ``False`` as the last constructor argument.
        """

        def _make_tensor(desc):
            ctor_args = (desc.dtype(), desc.shape(), desc.name(), desc.type(),
                         False)
            if framework._in_eager_mode_:
                return core.eager.Tensor(*ctor_args)
            return core.VarBase(*ctor_args)

        grad_tensors = [
            _make_tensor(block.vars[var_name].desc)
            for block in program.blocks
            for var_name in block.vars
            if "@GRAD" in var_name
        ]
        return self._valid_vars(grad_tensors)
540

541
    def _get_end_op_index(self):
        """
        Number of ops in block 0 of the infer program for the current
        precision mode. This is where the forward part of the whole program
        ends.
        """
        # `infer_program` performs the same AMP / pure-fp16 / fp32 dispatch.
        return self.infer_program.desc.block(0).op_size()

550 551
    def __call__(self, inputs):
        """
        Run the wrapped program on ``inputs`` via ``_legacy_C_ops.run_program``
        and return the outputs in their original nested structure.
        """
        in_vars, out_vars = self._prepare(inputs)
        self._cast_fp16_if_pure_fp16(in_vars)

        # Flat list of alternating attribute names and values for run_program.
        attrs = [
            'global_block', self.program.desc.block(0),
            'start_op_index', 0,
            'end_op_index', self._get_end_op_index(),
            'is_test', not self.training,
            'program_id', self.program_id,
        ]
        if self._cuda_graph_capture_mode:
            attrs += [
                'cuda_graph_capture_mode', self._cuda_graph_capture_mode,
                'cuda_graph_pool_id', self._cuda_graph_pool_id,
            ]

        use_interpretorcore = (_is_enable_standalone_executor()
                               and _is_dy2st_enable_standalone_executor())
        attrs += ['use_interpretorcore', use_interpretorcore]
        if use_interpretorcore:
            attrs += [
                'forward_global_block', self.forward_program.desc.block(0),
                'backward_global_block', self.backward_program.desc.block(0),
            ]

        run_args = [
            self._valid_vars(in_vars),
            self._valid_vars(self._params),
            self._valid_vars(out_vars),
        ]
        if use_interpretorcore:
            # Interpreter-core path reuses scopes cached per program id.
            run_args.append(
                self._create_scope_vec(program_id=self.program_id,
                                       use_scope_cache=True))
        else:
            run_args.append(self._create_scope_vec())
        run_args.append(self._double_grads)
        run_args.append(self._cuda_graph_vec)
        run_args.extend(attrs)
        _legacy_C_ops.run_program(*run_args)

        return self._remove_no_value(self._restore_out(out_vars))
589

590 591 592 593
    def _cast_fp16_if_pure_fp16(self, in_vars):
        """
        Under a pure-fp16 guard, replace (in place) every input whose
        same-named var in the program's global block is float16 with a
        float16 copy that keeps the original name.
        """
        if not _in_pure_fp16_guard():
            return
        for position, var in enumerate(in_vars):
            var_name = var.name
            global_block = self.program.global_block()
            if (global_block.has_var(var_name)
                    and global_block.var(var_name).dtype == paddle.float16):
                casted = var.astype('float16')
                casted.name = var_name
                in_vars[position] = casted

600 601
    @property
    def program(self):
        """Alias of ``whole_program``."""
        current = self.whole_program
        return current
603

604 605
    @property
    def program_id(self):
        """Alias of ``whole_program_id``."""
        current_id = self.whole_program_id
        return current_id
607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624

    @property
    def train_program(self):
        """Train program for the current precision mode (AMP, pure fp16, fp32)."""
        if _in_amp_guard():
            selected = self._train_amp_program
        elif _in_pure_fp16_guard():
            selected = self._train_pure_fp16_program
        else:
            selected = self._train_program
        return selected

    @property
    def infer_program(self):
        """Infer program for the current precision mode (AMP, pure fp16, fp32)."""
        if _in_amp_guard():
            selected = self._infer_amp_program
        elif _in_pure_fp16_guard():
            selected = self._infer_pure_fp16_program
        else:
            selected = self._infer_program
        return selected
625

626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683
    @switch_to_static_graph
    def _get_forward_backward_program_form(self, whole_program,
                                           forward_end_op_index):
        """
        Split ``whole_program`` into forward and backward programs.

        Each part is built with ``add_build_strategy_for`` and
        ``self._build_strategy``, and the inplace pass is applied to both.

        Returns:
            list: ``[forward_program, backward_program]``.
        """
        strategy = self._build_strategy
        forward_part = add_build_strategy_for(whole_program, 0,
                                              forward_end_op_index, strategy)
        # Backward starts 2 ops per tensor output after the forward end.
        # NOTE(review): presumably these are the grad-seeding ops appended by
        # backward for each output — confirm.
        backward_begin = forward_end_op_index + 2 * len(self._outputs.var_ids)
        backward_end = whole_program.desc.block(0).op_size()
        backward_part = add_build_strategy_for(whole_program, backward_begin,
                                               backward_end, strategy)
        self._apply_inplace_pass(forward_part, backward_part)
        return [forward_part, backward_part]

    def _apply_inplace_pass(self, forward_program, backward_program):
        """
        Run ``buffer_shared_inplace_pass`` on the forward and backward programs.

        The pass skips the following vars:
        - data vars of each program;
        - the layer's Variable inputs and outputs;
        - for the forward program only, the names returned by
          ``core.parse_safe_eager_deletion_skip_vars`` for the backward program.
        """
        attr_types = {
            "use_cuda": "bool",
            "mem_opt_skip_vars": "list[str]",
            "for_partial_block": "bool"
        }
        empty_startup_program = paddle.static.Program()
        use_cuda = bool(core.is_compiled_with_cuda())

        forward_skip = [
            name for name, var in forward_program.global_block().vars.items()
            if var.is_data
        ]
        backward_skip = [
            name for name, var in backward_program.global_block().vars.items()
            if var.is_data
        ]
        io_names = [
            var.desc.name()
            for sequence in (self._inputs, self._outputs) for var in sequence
            if isinstance(var, paddle.fluid.framework.Variable)
        ]
        forward_skip.extend(io_names)
        backward_skip.extend(io_names)
        forward_skip.extend(
            core.parse_safe_eager_deletion_skip_vars(backward_program.desc))

        for target, skip_vars in ((forward_program, forward_skip),
                                  (backward_program, backward_skip)):
            pass_attrs = {
                "use_cuda": use_cuda,
                "mem_opt_skip_vars": skip_vars,
                "for_partial_block": True
            }
            _apply_pass(target, empty_startup_program,
                        "buffer_shared_inplace_pass", pass_attrs, attr_types)

684 685 686 687 688
    def _prepare(self, inputs):
        """
        Prepare inputs, outputs, attrs.

        Returns:
            tuple: ``(input_vars, out_vars)``.

            ``input_vars`` holds the flattened numpy / tensor inputs converted
            to tensors named after the matching program inputs; other values
            are skipped. ``out_vars`` holds one receiving tensor per tensor
            output; outputs with the same name share a tensor.
        """
        assert isinstance(inputs, (tuple, list))
        flatten_inputs = flatten(inputs)
        expected_place = framework._current_expected_place()

        def _to_input_var(value, index):
            # Returns None for values that are neither ndarrays nor tensors.
            if isinstance(value, np.ndarray):
                tensor_cls = (core.eager.Tensor
                              if framework._in_eager_mode_ else core.VarBase)
                return tensor_cls(value=value,
                                  name=self._inputs[index].desc.name(),
                                  persistable=False,
                                  place=expected_place,
                                  zero_copy=True)
            if isinstance(value, (core.VarBase, core.eager.Tensor)):
                # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
                # into CUDAPlace when it's as input of multi Ops. so we move it in advance
                # to avoid this problem.
                needs_copy = (value.stop_gradient
                              and not value.place._equals(expected_place))
                if needs_copy:
                    tensor = value._copy_to(expected_place, False)
                    tensor.stop_gradient = True
                else:
                    tensor = value
                tensor.name = self._inputs[index].desc.name()
                return tensor
            return None

        input_vars = []
        for index, value in enumerate(flatten_inputs):
            converted = _to_input_var(value, index)
            if converted is not None:
                input_vars.append(converted)

        # name(string) -> tensor, so repeated outputs share one tensor.
        created_outputs = {}

        def _make_output(var_id):
            out_var = self._outputs[var_id]
            assert isinstance(out_var, framework.Variable)
            desc = out_var.desc
            key = desc.name()
            if key in created_outputs:
                return created_outputs[key]
            tensor_cls = (core.eager.Tensor
                          if framework._in_eager_mode_ else core.VarBase)
            tensor = tensor_cls(desc.dtype(), desc.shape(), key, desc.type(),
                                False)
            tensor.stop_gradient = out_var.stop_gradient
            created_outputs[key] = tensor
            return tensor

        out_vars = [_make_output(var_id) for var_id in self._outputs.var_ids]
        return input_vars, out_vars
751

752
    def _create_scope_vec(self, program_id=None, use_scope_cache=False):
        """
        Build the container wrapping the scope that holds forward variables.

        Args:
            program_id: forwarded unchanged to ``self._get_scope``.
            use_scope_cache(bool): forwarded unchanged to ``self._get_scope``.

        Returns:
            In eager mode, a one-element list holding the scope. Otherwise a
            ``core.VarBase`` of type ``STEP_SCOPES`` named "program_out_scope"
            whose value is bound to that scope.
        """
        scope = self._get_scope(program_id=program_id,
                                use_scope_cache=use_scope_cache)
        if framework._in_eager_mode_:
            # Eager mode passes the scope itself, wrapped in a list.
            return [scope]

        # Legacy dygraph carries the scope inside a STEP_SCOPES VarBase.
        scope_holder = core.VarBase(core.VarDesc.VarType.FP32, [],
                                    "program_out_scope",
                                    core.VarDesc.VarType.STEP_SCOPES, True)
        scope_holder.value().set_scope(scope)
        return scope_holder

    def _create_cuda_graph_vec(self):
        """
        Create the RAW-typed placeholder variable named "cuda_graph", with
        ``stop_gradient`` set to True.

        NOTE(review): unlike the other VarBase helpers here, this does not
        branch on eager mode and always builds a ``core.VarBase`` — confirm
        that is intended.
        """
        cuda_graph_var = core.VarBase(core.VarDesc.VarType.FP32, [],
                                      "cuda_graph", core.VarDesc.VarType.RAW,
                                      True)
        cuda_graph_var.stop_gradient = True
        return cuda_graph_var

    def _restore_out(self, out_vars):
        """
        Restores same nested outputs by only replacing the Variable with VarBase.
        """

        flatten_outputs = self._outputs.tolist()
        for i, idx in enumerate(self._outputs.var_ids):
            flatten_outputs[idx] = out_vars[i]
        outs = self._outputs.restore(flatten_outputs)
781
        if outs is not None and len(outs) == 1:
782 783 784 785
            outs = outs[0]

        return outs

786 787 788 789
    @switch_to_static_graph
    def _clone_for_test(self, main_program):
        """
        Return a clone of ``main_program`` created with ``for_test=True``.

        Runs under ``switch_to_static_graph``.
        """
        test_program = main_program.clone(for_test=True)
        return test_program

    def _is_no_value(self, var):
        """
        Tell whether ``var`` is the "no return value" sentinel.

        Returns:
            True only for a VarBase / eager Tensor whose shape is [1] and
            whose single element equals ``RETURN_NO_VALUE_MAGIC_NUM``;
            False for anything else.
        """
        if not isinstance(var, (core.VarBase, core.eager.Tensor)):
            return False
        if not var.shape == [1]:
            return False
        # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
        return bool(var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM)

    def _remove_no_value(self, out_vars):
        """
        Drop "no return value" sentinels (as judged by ``_is_no_value``) from
        the outputs of a variable-length return statement.

        Args:
            out_vars: a single VarBase / eager Tensor, a tuple or list of
                outputs, or any other object.

        Returns:
            - single tensor: None if it is a sentinel, else the tensor itself.
            - tuple/list: a new tuple (for tuple input) or list (otherwise)
              without the sentinels. If at least one item was removed and
              none remain, None is returned. If exactly one remains after a
              removal, that item is returned.
            - anything else: returned unchanged.
        """
        if isinstance(out_vars, (core.VarBase, core.eager.Tensor)):
            return None if self._is_no_value(out_vars) else out_vars

        if not isinstance(out_vars, (tuple, list)):
            return out_vars

        kept = [var for var in out_vars if not self._is_no_value(var)]
        if isinstance(out_vars, tuple):
            kept = tuple(kept)

        # Collapse only when something was actually removed, so an input that
        # was already empty or single-element keeps its container.
        if len(out_vars) > len(kept):
            if not kept:
                return None
            if len(kept) == 1:
                return kept[0]
        return kept

    def _set_grad_type(self, params, train_program):
        """
        Make each parameter's gradient type match the gradient variable
        declared in block 0 of ``train_program``.

        Why: with sparse gradients the gradient is SelectedRows rather than
        LoDTensor, but the tracer creates the parameter's grad VarBase from
        the forward VarBase (LoDTensor). Without this fix-up, RunProgramOp
        would have to convert SelectedRows to LoDTensor, which may not be
        what the user wants.

        Parameters whose gradient variable is not found are skipped.
        """
        for param in params:
            grad_key = (param.name + core.grad_var_suffix()).encode()
            grad_desc = train_program.desc.block(0).find_var(grad_key)
            # A missing grad var desc can be legitimate, e.g. in batch_norm.
            if grad_desc is not None:
                param._set_grad_type(grad_desc.type())

    def _remove_op_call_stack(self, main_program):
        """
        Strip the ``op_callstack`` attribute from every op in every block of
        ``main_program``. This hides low-level Python call stacks from the
        transformation, which would otherwise confuse users in error messages.

        Returns:
            The same ``main_program`` object, modified in place.
        """
        assert isinstance(main_program, framework.Program)
        all_ops = (op for block in main_program.blocks for op in block.ops)
        for op in all_ops:
            if op.has_attr("op_callstack"):
                op._remove_attr("op_callstack")
        return main_program

    def _check_params_all_inited(self, main_program):
        """
        Check that all parameters used by ``main_program`` are already
        initialized:
            1. every item of self._params must be a VarBase / eager Tensor
               (i.e. created in dygraph);
            2. every ``framework.Parameter`` in any block of ``main_program``
               must have its name among self._params, because they share
               data with the dygraph ParamBase.

        Raises:
            TypeError: self._params is not a list/tuple, or contains an item
                that is not a VarBase / eager Tensor.
            ValueError: a Parameter in ``main_program`` is not in self._params
                (i.e. it was created inside the decorated function).
        """
        if not isinstance(self._params, (list, tuple)):
            raise TypeError(
                "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
                % type(self._params))

        # self._params contains parameters and buffers with persistable=True.
        known_names = set()
        for idx, item in enumerate(self._params):
            if not isinstance(item, (core.VarBase, core.eager.Tensor)):
                raise TypeError(
                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'
                    .format(idx, type(item)))
            known_names.add(item.name)

        for block in main_program.blocks:
            for name, var in block.vars.items():
                if not isinstance(var, framework.Parameter):
                    continue
                if name in known_names:
                    continue
                raise ValueError(
                    "\n\tWe don't support to define layer with parameters in the function decorated by `@to_static`."
                    "\n\tBut we found parameter(%s) was created in the decorated function."
                    "\n"
                    "\n\tRevise suggestion: "
                    "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
                    "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
                    % name)

    def _valid_vars(self, vars):
        """
        Note: run_program_op.InferShape requires `X`/'Out' not be null.
        But it's common in dy2static, fake varBase is created to handle the
        problem.
        """
        return vars if vars else self.__fake_vars

895

896
def _create_fake_var():
    """
    Create a one-element list holding a placeholder RAW variable named
    "Fake_var", used to stand in for empty inputs or outputs.

    The tensor class is ``core.eager.Tensor`` in eager mode and
    ``core.VarBase`` otherwise.

    NOTE(review): the original docstring says the var is forced on CPU; no
    place is passed here, so confirm that against the constructors.
    """
    # Only the selected branch is evaluated, so core.eager.Tensor is not
    # touched outside eager mode.
    tensor_cls = core.eager.Tensor if framework._in_eager_mode_ else core.VarBase
    fake_var = tensor_cls(core.VarDesc.VarType.FP32, [], "Fake_var",
                          core.VarDesc.VarType.RAW, False)
    return [fake_var]


def partial_program_from(concrete_program):
    """
    Build a ``PartialProgramLayer`` from ``concrete_program``.

    If the first input is a ``layers.Layer`` instance, it is dropped before
    construction. It is presumably the bound ``self`` of a decorated
    method — confirm.
    """
    inputs = concrete_program.inputs
    leads_with_layer = bool(inputs) and isinstance(inputs[0], layers.Layer)
    if leads_with_layer:
        inputs = inputs[1:]

    return PartialProgramLayer(concrete_program.main_program,
                               inputs,
                               concrete_program.outputs,
                               concrete_program.parameters,
                               **concrete_program.kwargs)


@switch_to_static_graph
def add_build_strategy_for(program,
                           start_op_index,
                           end_op_index,
                           build_strategy=None):
    """
    Compile the op range of ``program`` given by ``start_op_index`` and
    ``end_op_index`` with ``build_strategy``, and convert the result back
    into a Program. The exact range semantics come from ``core.Graph``;
    presumably half-open — confirm.

    Returns:
        ``program`` itself when ``start_op_index`` is not less than
        ``end_op_index``. Otherwise a new Program built from the compiled
        IR graph, carrying over ``lr_sheduler`` if the compiled program has
        that attribute.
    """
    if not start_op_index < end_op_index:
        return program

    compiled_program = paddle.static.CompiledProgram(
        core.Graph(program.desc, start_op_index, end_op_index),
        build_strategy=build_strategy)
    compiled_program._compile(core.Scope(),
                              framework._current_expected_place())
    built_program = framework.IrGraph(compiled_program._graph).to_program()
    if hasattr(compiled_program._program, 'lr_sheduler'):
        built_program.lr_sheduler = compiled_program._program.lr_sheduler
    return built_program