partial_program.py 28.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np
17
import six
18

19
import paddle
20
from paddle.fluid import framework, backward, core, program_guard
21
from paddle.fluid.dygraph import layers
22
from paddle.fluid.dygraph.base import switch_to_static_graph
23
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
24
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
25 26
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
27 28
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
29
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
30 31
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
32
import paddle.compat as cpt
33
from paddle import _C_ops, _legacy_C_ops
34

35 36 37 38 39 40 41 42 43

class NestSequence(object):
    """
    A wrapper class that flattens a nested sequence and can restore the
    original nested structure from a flat list of values.

    Args:
        raw_input: The (possibly nested) sequence to wrap.
        need_check(bool): If True, warn when the flattened sequence contains
            values that are not tensors/Variables. Default False.
    """

    def __init__(self, raw_input, need_check=False):
        self.__raw_input = raw_input
        self.__input_list = self.tolist()
        # Positions (in the flattened list) of tensor/Variable entries.
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    @staticmethod
    def _is_tensor(var):
        """
        Returns True if ``var`` is a static ``framework.Variable``, a legacy
        dygraph ``core.VarBase`` or an eager ``core.eager.Tensor``.
        """
        return isinstance(var,
                          (framework.Variable, core.VarBase, core.eager.Tensor))

    def tolist(self):
        """
        Flattens the nested sequences into single list.

        NOTE: callers may mutate the returned list, so this re-flattens on
        every call (assumes `flatten` builds a fresh list each time).
        """
        return flatten(self.__raw_input)

    def restore(self, value_list):
        """
        Restores the nested sequence from value list.

        Args:
            value_list(list): Flat values, one per flattened input entry.

        Returns:
            The values packed into the same nested structure as the raw input.
        """
        assert len(self.__input_list) == len(value_list), (
            "Expected {} values to restore the nested sequence, but received "
            "{}.".format(len(self.__input_list), len(value_list)))
        return pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        """
        Returns the indices of tensor/Variable entries in the flattened list.
        """
        return [
            idx for idx, var in enumerate(self.__input_list)
            if self._is_tensor(var)
        ]

    def _check_non_variable(self, need_check):
        """
        Raises warning if output of traced function contains non-tensor type values.
        """
        if not need_check:
            return
        warning_types = set(
            type(var) for var in self.__input_list if not self._is_tensor(var))
        if warning_types:
            logging_utils.warn(
                "Output of traced function contains non-tensor type values: {}. "
                "Currently, We don't support to update them while training and will return "
                "what we first saw. Please try to return them as tensor.".
                format(list(warning_types)))

    @property
    def var_ids(self):
        """Indices of tensor/Variable entries in the flattened sequence."""
        return self.__var_ids

    def __getitem__(self, item):
        """Returns the ``item``-th entry of the flattened sequence."""
        return self.__input_list[item]


class LazyInitialized(object):
    """
    Descriptor to implement lazy initialization of property.

    On first access through an instance, ``function(instance)`` is called and
    its result is stored on the instance under ``function.__name__``. Because
    this is a non-data descriptor, an instance attribute with that name then
    shadows it, so later accesses read the cached value directly.

    NOTE: the cache key is ``function.__name__``. For a name-mangled private
    method (``def __x``), that name differs from the attribute the class binds
    (``_Cls__x``), so such a value is recomputed on every access.
    """

    def __init__(self, function):
        self.function = function

    def __get__(self, instance, cls):
        # Accessed on the class itself (e.g. by introspection or help()):
        # there is no instance to initialize, so return the descriptor.
        if instance is None:
            return self
        val = self.function(instance)
        setattr(instance, self.function.__name__, val)
        return val


def _change_is_test_status(program, is_test):
    # change all `is_test` attributes
    for block in program.blocks:
        for op in block.ops:
            if op.has_attr('is_test'):
                op._set_attr('is_test', is_test)
    return program


119
class PartialProgramLayer:
    """
    PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
    and execute them as a static subgraph.

    .. note::
        1. This is a very low level API. Users should not use this API
           directly. Please use `partial_program_from(concrete_program)`
           to create it.
        2. LoDTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains ops need to be executed.
        inputs(list[Variable]): The input list of the decorated function by `@declarative`.
        outputs(list[Variable]): The output list of the decorated function by `@declarative`.
        parameters(list[VarBase]|None): All trainable parameters included in the program. Default None.
        **kwargs: Extra options. Only ``build_strategy`` (a ``BuildStrategy``,
            default ``BuildStrategy()``) is read.

    Returns:
        Layer: A Layer object that run all ops internally in static mode.
    """

    def __init__(self,
                 main_program,
                 inputs,
                 outputs,
                 parameters=None,
                 **kwargs):
        super(PartialProgramLayer, self).__init__()
        self._inputs = NestSequence(inputs)
        # `need_check=True` warns when the traced outputs contain non-tensors.
        self._outputs = NestSequence(outputs, need_check=True)
        self._params = parameters if parameters is not None else []

        self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
        assert isinstance(self._build_strategy, BuildStrategy)

        # NOTE: `_verify_program` prunes `self._params`, so it must run after
        # `self._params` is assigned above.
        self._origin_main_program = self._verify_program(main_program)
        self._cuda_graph_vec = self._create_cuda_graph_vec()
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0
        # Set default mode to train
        self.training = True

        # Use the AMP op white/black lists configured on the dygraph tracer,
        # if a tracer exists.
        custom_white_list, custom_black_list = None, None
        tracer = framework._dygraph_tracer()
        if tracer:
            custom_white_list, custom_black_list = tracer._get_amp_op_list()
        # For AMP training
        self._amp_list = AutoMixedPrecisionLists(
            custom_white_list=custom_white_list,
            custom_black_list=custom_black_list)

    # NOTE(review): `LazyInitialized` caches under `function.__name__`
    # ("__fake_vars"), but `self.__fake_vars` is name-mangled to
    # `_PartialProgramLayer__fake_vars`, so this value is rebuilt on every
    # access rather than cached. Kept as-is; confirm before changing, since
    # sharing one fake var across X/Params/Out of run_program is untested.
    @LazyInitialized
    def __fake_vars(self):
        return _create_fake_var()

    @LazyInitialized
    def _double_grads(self):
        return self._get_double_grads(self._origin_main_program)

    @LazyInitialized
    def _infer_program(self):
        """
        Lazy initialized property of infer_program.
        """
        return self._clone_for_test(self._origin_main_program)

    @LazyInitialized
    def _train_program(self):
        """
        Lazy initialized property of train_program.
        """
        train_program = self._append_backward_desc(self._origin_main_program)
        # Note: Only set grad type once after initializing train program. So we
        # put it here.
        self._set_grad_type(self._params, train_program)

        return train_program

    @LazyInitialized
    @switch_to_static_graph
    def _infer_amp_program(self):
        """
        Lazy initialized property of infer_amp_program.
        """
        infer_amp_program = self._origin_main_program.clone()
        with program_guard(infer_amp_program):
            rewrite_program(infer_amp_program, self._amp_list)

        return infer_amp_program

    @LazyInitialized
    def _train_amp_program(self):
        """
        Lazy initialized property of train_amp_program.
        """
        train_amp_program = self._append_backward_desc(self._infer_amp_program)
        self._set_grad_type(self._params, train_amp_program)
        return train_amp_program

    @LazyInitialized
    @switch_to_static_graph
    def _infer_pure_fp16_program(self):
        """
        Lazy initialized property of _infer_pure_fp16_program.
        """
        infer_pure_fp16_program = self._origin_main_program.clone()
        with program_guard(infer_pure_fp16_program):
            cast_model_to_fp16(infer_pure_fp16_program,
                               self._amp_list,
                               use_fp16_guard=False)

        return infer_pure_fp16_program

    @LazyInitialized
    def _train_pure_fp16_program(self):
        """
        Lazy initialized property of _train_pure_fp16_program.
        """
        train_pure_fp16_program = self._append_backward_desc(
            self._infer_pure_fp16_program)
        self._set_grad_type(self._params, train_pure_fp16_program)
        return train_pure_fp16_program

    @LazyInitialized
    def _infer_program_id(self):
        return _hash_with_id(self._infer_program, self)

    @LazyInitialized
    def _infer_pure_fp16_program_id(self):
        return _hash_with_id(self._infer_pure_fp16_program, self)

    @LazyInitialized
    def _infer_amp_program_id(self):
        return _hash_with_id(self._infer_amp_program, self)

    @LazyInitialized
    def _train_program_id(self):
        # Train programs also register the build strategy for their id.
        program_id = _hash_with_id(self._train_program, self)
        core._set_cached_executor_build_strategy(program_id,
                                                 self._build_strategy)

        return program_id

    @LazyInitialized
    def _train_amp_program_id(self):
        program_id = _hash_with_id(self._train_amp_program, self)
        core._set_cached_executor_build_strategy(program_id,
                                                 self._build_strategy)

        return program_id

    @LazyInitialized
    def _train_pure_fp16_program_id(self):
        program_id = _hash_with_id(self._train_pure_fp16_program, self)
        core._set_cached_executor_build_strategy(program_id,
                                                 self._build_strategy)

        return program_id

    def _verify_program(self, main_program):
        """
        Verify that the program parameters are initialized and prune the
        parameters that are not used by any op.

        Returns:
            Program: ``main_program`` itself (it is not copied or modified).
        """
        # 1. Check all params from main program can be found in self._params
        self._check_params_all_inited(main_program)
        # 2. Prune the parameters not used anywhere in the program.
        self._prune_unused_params(main_program)

        return main_program

    def prepare_gradient_aggregation(self, start_idx, main_program,
                                     target_program):
        """
        Why do we need to add gradient aggregation ops?
        In some cases, if non-leaf nodes are used as output, gradient overwriting will occur, such as
        def forward(self, in):
            x = 2 * in  # <---- x is a non-leaf node in program.
            y = x + 3
            return x, y

        loss = forward(in)[0].sum()
        loss.backward()  # <----- x@grad will be overwritten by elementwise_add_grad Op

        Args:
            start_idx(int): Only ops at index >= start_idx in block 0 of
                ``target_program`` are renamed.
            main_program(Program): The forward program. It is used to decide
                which outputs are also read as inputs by some op.
            target_program(Program): The program with backward ops. It is
                modified in place.
        """
        # Names read as input by any op in block 0 of the forward program.
        # Built once instead of rescanning every op for each output.
        consumed_names = set()
        for op in main_program.block(0).ops:
            consumed_names.update(op.input_arg_names)

        def _need_aggregation(var):
            """
            Returns True if ``var`` is a float32/float64 LOD_TENSOR or
            SELECTED_ROWS Variable that some forward op reads as input.
            """
            if not isinstance(var, framework.Variable) or var.type not in [
                    core.VarDesc.VarType.LOD_TENSOR,
                    core.VarDesc.VarType.SELECTED_ROWS
            ]:
                return False
            if var.dtype not in [paddle.float32, paddle.float64]:
                return False
            return var.name in consumed_names

        def _insert_aggregation_ops_for_var(target_program, var):
            suffix = "@dy2static"
            var_grad_name = var.grad_name
            new_grad_name = var.name + suffix + "@GRAD"
            # (index, op) pairs of ops at or after `start_idx` that write
            # `var_grad_name`.
            found_ops = [(idx, op)
                         for idx, op in enumerate(target_program.block(0).ops)
                         if idx >= start_idx
                         and var_grad_name in op.output_arg_names]

            # len(found_ops) may equal zero when stop_gradient works.
            # len(found_ops) may be > 1, because we may have a fill_constant op.
            if len(found_ops) == 0:
                return None
            # step1: create a new var named var.name@dy2static@GRAD
            target_program.block(0).create_var(name=new_grad_name,
                                               type=var.type,
                                               dtype=var.dtype,
                                               shape=var.shape)
            # step2: rename var.name@GRAD to var.name@dy2static@GRAD in the found ops
            for idx, op in found_ops:
                op._rename_input(var_grad_name, new_grad_name)
                op._rename_output(var_grad_name, new_grad_name)
            # step3: insert sum op to aggregate the gradient.
            #        var.name@GRAD = sum(var.name@dy2static@GRAD, var.name@GRAD)
            target_program.block(0)._insert_op(
                found_ops[-1][0] + 1,
                type='sum',
                inputs={'X': [var_grad_name, new_grad_name]},
                outputs={"Out": var_grad_name})
            return None

        to_processed_vars = list(
            filter(_need_aggregation, self._outputs.tolist()))
        for _var in to_processed_vars:
            _insert_aggregation_ops_for_var(target_program, _var)

    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        """
        Clone ``main_program`` in train mode, append backward ops for the
        Variable outputs and insert gradient aggregation ops.

        Returns:
            Program: The new (cloned) program; ``main_program`` is not modified.
        """
        # make sure all status of is_test are False in train mode.
        program = _change_is_test_status(main_program.clone(), is_test=False)
        flat_outputs = self._outputs.tolist()
        targets = [
            program.global_block().var(out.name) for out in flat_outputs
            if isinstance(out, framework.Variable)
        ]

        if targets and self._params:
            backward.gradients(targets=targets, inputs=[])

        # NOTE(review): assumes backward appends at most two ops per output
        # (e.g. fill_constant for the initial grad) ahead of the real gradient
        # ops. Confirm against `backward.gradients`.
        start_idx = len(main_program.block(0).ops) + 2 * len(flat_outputs)

        self.prepare_gradient_aggregation(start_idx, main_program, program)

        return program

    def _prune_unused_params(self, program):
        """
        Prune the parameters not used anywhere in the program.
        The `@declarative` may only decorated a sub function which
        contains some unused parameters created in `__init__`.
        So prune these parameters to avoid unnecessary operations in
        `run_program_op`.
        """
        required_params = []
        for param in self._params:
            # A param is kept if any op in any block reads or writes it.
            if any(param.name in op.input_arg_names
                   or param.name in op.output_arg_names
                   for block in program.blocks for op in block.ops):
                required_params.append(param)

        self._params = required_params

    def _get_double_grads(self, program):
        """
        Create a dygraph tensor for every variable whose name contains
        "@GRAD" in any block of ``program``.

        Returns:
            list: Those tensors, or the fake var list if there are none.
        """
        double_grads = []
        for block in program.blocks:
            for name in block.vars:
                if "@GRAD" in name:
                    var_desc = block.vars[name].desc
                    if not framework._in_eager_mode_:
                        var_base = core.VarBase(var_desc.dtype(),
                                                var_desc.shape(),
                                                var_desc.name(),
                                                var_desc.type(), False)
                    else:
                        var_base = core.eager.Tensor(var_desc.dtype(),
                                                     var_desc.shape(),
                                                     var_desc.name(),
                                                     var_desc.type(), False)
                    double_grads.append(var_base)
        return self._valid_vars(double_grads)

    def _get_end_op_index(self):
        """
        Returns the number of ops in block 0 of the inference program that
        matches the current AMP / pure-fp16 mode.
        """
        # `infer_program` already selects the AMP / pure-fp16 variant.
        return self.infer_program.desc.block(0).op_size()

    def __call__(self, inputs):
        """
        Run the wrapped program through `run_program` on ``inputs``.

        Returns:
            The outputs restored to the original nested structure, with
            "no value" placeholders removed.
        """
        in_vars, out_vars = self._prepare(inputs)

        attrs = [
            'global_block',
            self.program.desc.block(0), 'start_op_index', 0, 'end_op_index',
            self._get_end_op_index(), 'is_test', not self.training,
            'program_id', self.program_id
        ]
        if self._cuda_graph_capture_mode:
            attrs.extend(
                ('cuda_graph_capture_mode', self._cuda_graph_capture_mode,
                 'cuda_graph_pool_id', self._cuda_graph_pool_id))

        self._cast_fp16_if_pure_fp16(in_vars)

        _legacy_C_ops.run_program(self._valid_vars(in_vars),
                                  self._valid_vars(self._params),
                                  self._valid_vars(out_vars),
                                  self._create_scope_vec(), self._double_grads,
                                  self._cuda_graph_vec, *attrs)
        restored_nest_out = self._restore_out(out_vars)
        return self._remove_no_value(restored_nest_out)

    def _cast_fp16_if_pure_fp16(self, in_vars):
        """
        Under the pure-fp16 guard, cast (in place in ``in_vars``) each input
        whose same-named variable in the program is float16, keeping its name.
        """
        if _in_pure_fp16_guard():
            # The program cannot change inside this loop, so look it up once.
            global_block = self.program.global_block()
            for i, var in enumerate(in_vars):
                name = var.name
                if (global_block.has_var(name)
                        and global_block.var(name).dtype == paddle.float16):
                    in_vars[i] = var.astype('float16')
                    in_vars[i].name = name

    @property
    def program(self):
        """The train or infer program, depending on ``self.training``."""
        if self.training:
            return self.train_program
        else:
            return self.infer_program

    @property
    def program_id(self):
        """Id of the program selected by ``self.training`` and AMP mode."""
        if self.training:
            if _in_amp_guard():
                return self._train_amp_program_id
            elif _in_pure_fp16_guard():
                return self._train_pure_fp16_program_id
            else:
                return self._train_program_id
        else:
            if _in_amp_guard():
                return self._infer_amp_program_id
            elif _in_pure_fp16_guard():
                return self._infer_pure_fp16_program_id
            else:
                return self._infer_program_id

    @property
    def train_program(self):
        """Train program variant for the current AMP / pure-fp16 mode."""
        if _in_amp_guard():
            return self._train_amp_program
        elif _in_pure_fp16_guard():
            return self._train_pure_fp16_program
        else:
            return self._train_program

    @property
    def infer_program(self):
        """Infer program variant for the current AMP / pure-fp16 mode."""
        if _in_amp_guard():
            return self._infer_amp_program
        elif _in_pure_fp16_guard():
            return self._infer_pure_fp16_program
        else:
            return self._infer_program

    def _prepare(self, inputs):
        """
        Prepare inputs, outputs, attrs.

        Args:
            inputs(tuple|list): Possibly nested inputs. Their flattened order
                must match the flattened ``self._inputs``.

        Returns:
            tuple: ``(input_vars, out_vars)``. ``input_vars`` holds the tensor
            inputs (numpy arrays are converted; other non-tensor values are
            skipped). ``out_vars`` holds freshly created tensors that receive
            the outputs.
        """
        assert isinstance(inputs, (tuple, list))
        # Flatten inputs with nested structure into single list.
        flatten_inputs = flatten(inputs)
        # Convert variable into VarBase and feed in training data.
        input_vars = []
        expected_place = framework._current_expected_place()
        for i, value in enumerate(flatten_inputs):
            if isinstance(value, np.ndarray):
                if not framework._in_eager_mode_:
                    var = core.VarBase(value=value,
                                       name=self._inputs[i].desc.name(),
                                       persistable=False,
                                       place=expected_place,
                                       zero_copy=True)
                else:
                    var = core.eager.Tensor(value=value,
                                            name=self._inputs[i].desc.name(),
                                            persistable=False,
                                            place=expected_place,
                                            zero_copy=True)
            elif isinstance(value, (core.VarBase, core.eager.Tensor)):
                # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
                # into CUDAPlace when it's as input of multi Ops. so we move it in advance
                # to avoid this problem.
                if value.stop_gradient and not value.place._equals(
                        expected_place):
                    var = value._copy_to(expected_place, False)
                    var.stop_gradient = True
                else:
                    var = value
                var.name = self._inputs[i].desc.name()
            else:
                continue
            input_vars.append(var)

        # mapping from name(string) -> VarBase
        out_varbase_map = {}

        def create_out(var_id):
            var = self._outputs[var_id]
            assert isinstance(var, framework.Variable)
            var_desc = var.desc

            # The same output Variable may appear several times; reuse one tensor.
            if var_desc.name() in out_varbase_map:
                return out_varbase_map[var_desc.name()]

            if not framework._in_eager_mode_:
                var_base = core.VarBase(var_desc.dtype(), var_desc.shape(),
                                        var_desc.name(), var_desc.type(), False)
            else:
                var_base = core.eager.Tensor(var_desc.dtype(), var_desc.shape(),
                                             var_desc.name(), var_desc.type(),
                                             False)
            var_base.stop_gradient = var.stop_gradient
            out_varbase_map[var_desc.name()] = var_base
            return var_base

        # Create VarBase to receive output data.
        out_vars = list(map(create_out, self._outputs.var_ids))

        return input_vars, out_vars

    def _create_scope_vec(self):
        """
        Create a new ``core.Scope`` for one run. In legacy mode it is wrapped
        in a STEP_SCOPES VarBase; in eager mode it is returned in a list.
        """
        # Hold forward variables
        inner_scope = core.Scope()
        if not framework._in_eager_mode_:
            tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                         "program_out_scope",
                                         core.VarDesc.VarType.STEP_SCOPES, True)
            tmp_scope_vec.value().set_scope(inner_scope)
        else:
            tmp_scope_vec = [inner_scope]
        return tmp_scope_vec

    def _create_cuda_graph_vec(self):
        """Create the RAW-typed, stop-gradient "cuda_graph" holder var."""
        var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph",
                           core.VarDesc.VarType.RAW, True)
        var.stop_gradient = True
        return var

    def _restore_out(self, out_vars):
        """
        Restores same nested outputs by only replacing the Variable with VarBase.

        A single-element result is unwrapped to that element.
        """

        flatten_outputs = self._outputs.tolist()
        for i, idx in enumerate(self._outputs.var_ids):
            flatten_outputs[idx] = out_vars[i]
        outs = self._outputs.restore(flatten_outputs)
        if outs is not None and len(outs) == 1:
            outs = outs[0]

        return outs

    @switch_to_static_graph
    def _clone_for_test(self, main_program):
        return main_program.clone(for_test=True)

    def _is_no_value(self, var):
        """
        Returns True if ``var`` is a shape-[1] tensor holding
        ``RETURN_NO_VALUE_MAGIC_NUM``.
        """
        if isinstance(var,
                      (core.VarBase, core.eager.Tensor)) and var.shape == [1]:
            # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
            if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
                return True
        return False

    def _remove_no_value(self, out_vars):
        """
        Removes invalid value for various-length return statement
        """
        if isinstance(out_vars, (core.VarBase, core.eager.Tensor)):
            if self._is_no_value(out_vars):
                return None
            return out_vars
        elif isinstance(out_vars, (tuple, list)):
            if isinstance(out_vars, tuple):
                res = tuple(var for var in out_vars
                            if not self._is_no_value(var))
            else:
                # isinstance(out_vars, list)
                res = [var for var in out_vars if not self._is_no_value(var)]

            has_removed = (len(out_vars) > len(res))
            # len(out_vars) > len(res) means we have removed var. This is
            # preventing out_vars is empty or just one element at the beginning
            if len(res) == 0 and has_removed:
                return None
            elif len(res) == 1 and has_removed:
                return res[0]
            return res

        return out_vars

    def _set_grad_type(self, params, train_program):
        # NOTE: if user set sparse gradient mode, the param's gradient
        # will be SelectedRows, not LoDTensor. But tracer will just
        # set param grad VarBase by forward VarBase(LoDTensor)
        # If we don't change grad_var type here, RunProgramOp need
        # transform SelectedRows to LoDTensor forcibly, it may not
        # be user wanted result.
        for param in params:
            grad_name = param.name + core.grad_var_suffix()
            grad_var = train_program.desc.block(0).find_var(
                cpt.to_bytes(grad_name))
            # NOTE: cannot find var desc maybe no problem, such as in batch_norm
            if grad_var is None:
                continue
            param._set_grad_type(grad_var.type())

    def _remove_op_call_stack(self, main_program):
        """
        Remove op's python call stack with redundant low-level error messages related to
        transformations to avoid confusing users.
        """
        assert isinstance(main_program, framework.Program)
        for block in main_program.blocks:
            for op in block.ops:
                if op.has_attr("op_callstack"):
                    op._remove_attr("op_callstack")

        return main_program

    def _check_params_all_inited(self, main_program):
        """
        Check all params from main program are already initialized, see details as follows:
            1. all parameters in self._params should be type `framework.ParamBase` which are created in dygraph.
            2. all parameters from transformed program can be found in self._params.
               Because they share same data with ParamBase of original dygraph.

        Raises:
            TypeError: If ``self._params`` or one of its items has a wrong type.
            ValueError: If a Parameter of the program is not in ``self._params``.
        """
        if not isinstance(self._params, (list, tuple)):
            raise TypeError(
                "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
                % type(self._params))

        param_and_buffer_names_set = set()
        for i, var in enumerate(self._params):
            # self._params contains parameters and buffers with persistable=True.
            if not isinstance(var, (core.VarBase, core.eager.Tensor)):
                raise TypeError(
                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'
                    .format(i, type(var)))
            param_and_buffer_names_set.add(var.name)

        for block in main_program.blocks:
            for name, var in six.iteritems(block.vars):
                if isinstance(var, framework.Parameter):
                    if name not in param_and_buffer_names_set:
                        raise ValueError(
                            "\n\tWe don't support to define layer with parameters in the function decorated by `@to_static`."
                            "\n\tBut we found parameter(%s) was created in the decorated function."
                            "\n"
                            "\n\tRevise suggestion: "
                            "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
                            "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
                            % name)

    def _valid_vars(self, vars):
        """
        Note: run_program_op.InferShape requires `X`/'Out' not be null.
        But it's common in dy2static, fake varBase is created to handle the
        problem.
        """
        return vars if vars else self.__fake_vars


def _create_fake_var():
    """
    Create a placeholder "Fake_var" tensor list to handle empty input or output.

    The tensor is FP32 with an empty shape, RAW var type and persistable=False.
    No place is passed here. NOTE(review): the original note says it is forced
    onto CPU; confirm that is the constructor's default.

    Returns:
        list: A one-element list holding a ``core.VarBase`` in legacy dygraph
        mode or a ``core.eager.Tensor`` in eager mode.
    """
    # Only the class for the active mode is looked up, as in an if/else.
    tensor_cls = core.eager.Tensor if framework._in_eager_mode_ else core.VarBase
    return [
        tensor_cls(core.VarDesc.VarType.FP32, [], "Fake_var",
                   core.VarDesc.VarType.RAW, False)
    ]


def partial_program_from(concrete_program):
    """
    Build a ``PartialProgramLayer`` from ``concrete_program``.

    If the first input is a ``layers.Layer`` instance, it is dropped before the
    inputs are passed on. Presumably this is the bound ``self`` of a decorated
    Layer method; confirm with the caller.

    Args:
        concrete_program: Object providing ``inputs``, ``outputs``,
            ``main_program``, ``parameters`` and ``kwargs``.

    Returns:
        PartialProgramLayer: The layer wrapping ``concrete_program.main_program``.
    """
    inputs = concrete_program.inputs
    if inputs and isinstance(inputs[0], layers.Layer):
        inputs = inputs[1:]

    return PartialProgramLayer(concrete_program.main_program, inputs,
                               concrete_program.outputs,
                               concrete_program.parameters,
                               **concrete_program.kwargs)