api.py 66.2 KB
Newer Older
1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
M
Ming-Xu Huang 已提交
2
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16 17 18
# Temporary disable isort to avoid circular import
# This can be removed after the circular import is resolved
# isort: skip_file
19 20
from __future__ import annotations

21 22
import os
import pickle
23
import warnings
24
from collections import OrderedDict
25
import inspect
M
Ming-Xu Huang 已提交
26
import threading
H
hjyp 已提交
27
from typing import Any, List
28

29
import paddle
J
Jiabin Yang 已提交
30
from paddle.fluid import core, dygraph
31 32 33 34 35
from paddle.fluid.compiler import (
    BuildStrategy,
    CompiledProgram,
    ExecutionStrategy,
)
36
from paddle.fluid.data_feeder import check_type
37
from paddle.fluid.layers.utils import flatten, pack_sequence_as
38 39 40 41
from paddle.fluid.dygraph.base import (
    program_desc_tracing_guard,
    switch_to_static_graph,
)
42 43
from .dy2static import logging_utils
from .dy2static.convert_call_func import (
44 45
    ConversionOptions,
    CONVERSION_OPTIONS,
H
hjyp 已提交
46
    add_ignore_module,
47
)
48
from .dy2static.program_translator import (
49 50 51 52
    ProgramTranslator,
    StaticFunction,
    unwrap_decorators,
)
53
from paddle.jit.translated_layer import (
54 55 56 57 58 59
    TranslatedLayer,
    INFER_MODEL_SUFFIX,
    INFER_PARAMS_SUFFIX,
    INFER_PARAMS_INFO_SUFFIX,
    INFER_PROPERTY_SUFFIX,
)
60 61
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
62 63 64 65 66 67 68 69 70 71 72 73 74
from paddle.fluid.framework import (
    Block,
    ParamBase,
    Program,
    Variable,
    Parameter,
    EagerParamBase,
)
from paddle.fluid.framework import (
    _current_expected_place,
    _dygraph_guard,
    _dygraph_tracer,
)
J
Jiabin Yang 已提交
75
from paddle.fluid.framework import dygraph_only, _non_static_mode
76
from paddle.fluid.wrapped_decorator import wrap_decorator
77

78 79 80 81 82 83 84 85 86

def create_program_from_desc(program_desc):
    """Wrap an existing program desc in a Python-side ``Program``.

    Args:
        program_desc: the program description object to attach.

    Returns:
        Program: a fresh ``Program`` whose ``desc`` is ``program_desc``, with
        its block list seeded with block 0 and then ``_sync_with_cpp()`` called.
    """
    new_program = Program()
    new_program.desc = program_desc
    # Seed with the global block (index 0); `_sync_with_cpp` presumably
    # reconciles the remaining Python-side state with the desc.
    new_program.blocks = [Block(new_program, 0)]
    new_program._sync_with_cpp()
    return new_program


87
def _extract_vars(inputs, result_list, err_tag='inputs'):
    """Append every ``Variable`` found in ``inputs`` to ``result_list``.

    ``inputs`` may be a single ``Variable`` or an arbitrarily nested list/tuple
    of them; variables are appended in depth-first, left-to-right order.

    Args:
        inputs: a ``Variable`` or nested list/tuple of ``Variable``.
        result_list (list): accumulator that is appended to in place.
        err_tag (str): name used in the error message to describe ``inputs``.

    Raises:
        TypeError: if any leaf is neither a ``Variable`` nor a list/tuple.
            Variables visited before the offending leaf are already appended.
    """
    if isinstance(inputs, Variable):
        result_list.append(inputs)
        return
    if not isinstance(inputs, (list, tuple)):
        raise TypeError(
            "The type of 'each element of {}' in paddle.jit.TracedLayer.trace must be fluid.Variable, but received {}.".format(
                err_tag, type(inputs)
            )
        )
    for element in inputs:
        _extract_vars(element, result_list, err_tag)
99 100


101
def extract_vars(inputs, err_tag='inputs'):
    """Return a flat list of the ``Variable`` objects contained in ``inputs``.

    Convenience wrapper around ``_extract_vars`` that supplies a fresh
    accumulator list.

    Args:
        inputs: a ``Variable`` or nested list/tuple of ``Variable``.
        err_tag (str): name used in the error message to describe ``inputs``.

    Returns:
        list: the collected variables in depth-first order.
    """
    collected = []
    _extract_vars(inputs, collected, err_tag)
    return collected


107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
def _dygraph_to_static_func_(dygraph_func):
    """
    Converts imperative dygraph APIs into declarative function APIs. Decorator
    @dygraph_to_static_func only converts imperative dygraph APIs into
    declarative net-building APIs, which means it doesn't return immediate
    numerical results as imperative mode does. Users should handle Program
    and Executor by themselves.

    Note:
    This decorator is NOT our recommended way to transform imperative function
    to declarative function. We will remove this decorator after we finalize
    cleaning up code.

    Args:
        dygraph_func (callable): callable imperative function.

    Returns:
        Callable: a wrapper that runs the statically converted version of
        ``dygraph_func``. In dygraph mode, or when conversion is disabled
        (``ProgramTranslator().enable_to_static`` is False), the wrapper logs a
        warning and calls ``dygraph_func`` directly instead.

    Examples:
        .. code-block:: python

          import paddle
          import paddle.fluid as fluid
          import numpy as np
          from paddle.jit.api import dygraph_to_static_func

          @dygraph_to_static_func
          def func(x):
              if paddle.mean(x) < 0:
                  x_v = x - 1
              else:
                  x_v = x + 1

              return x_v

          x = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='float64')

          x_v = func(x)
          exe = fluid.Executor(fluid.CPUPlace())
          out = exe.run(fetch_list=[x_v])
          print(out[0])
          # [[1. 1. 1.]
          #  [1. 1. 1.]
          #  [1. 1. 1.]]

    """

    # TODO: remove this decorator after we finalize training API
    def __impl__(*args, **kwargs):
        program_translator = ProgramTranslator()
        # Only convert while building a static graph with conversion enabled;
        # otherwise fall back to running the original dygraph function.
        if _non_static_mode() or not program_translator.enable_to_static:
            logging_utils.warn(
                "The decorator 'dygraph_to_static_func' doesn't work in "
                "dygraph mode or set 'paddle.jit.enable_to_static' to False. "
                "We will just return dygraph output."
            )
            return dygraph_func(*args, **kwargs)
        static_func = program_translator.get_func(dygraph_func)
        return static_func(*args, **kwargs)

    return __impl__


dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)
172

173

174 175 176 177 178 179
def copy_decorator_attrs(original_func, decorated_obj):
    """
    Transfer identifying metadata from ``original_func`` onto ``decorated_obj``.

    Sets ``__name__`` and ``__doc__`` from the original, points ``__wrapped__``
    back at it, tags the object with ``_decorator_name = "to_static"``, and
    copies ``__module__`` when the original defines one.

    Args:
        original_func(callable): the function being decorated.
        decorated_obj(StaticFunction): the object that replaces it.

    Returns:
        The same ``decorated_obj`` instance, mutated in place.
    """
    for attr_name, attr_value in (
        ("__name__", original_func.__name__),
        ("_decorator_name", "to_static"),
        ("__wrapped__", original_func),
        ("__doc__", original_func.__doc__),
    ):
        setattr(decorated_obj, attr_name, attr_value)

    if hasattr(original_func, "__module__"):
        decorated_obj.__module__ = original_func.__module__

    return decorated_obj


H
hjyp 已提交
194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
def ignore_module(modules: List[Any]) -> None:
    """
    Adds modules to be ignored by dynamic-to-static transcription.
    Builtin modules that have been ignored are collections, pdb, copy, inspect, re, numpy, logging, six

    The given modules are forwarded unchanged to ``add_ignore_module``.

    Args:
        modules (List[Any]): Ignored modules that you want to add

    Examples:
        .. code-block:: python

            import scipy
            import astor

            import paddle
            from paddle.jit import ignore_module

            modules = [
               scipy,
               astor
            ]

            ignore_module(modules)

    """
    add_ignore_module(modules)


H
hjyp 已提交
222
def to_static(
    function=None, input_spec=None, build_strategy=None, property=False
):
    """
    Converts imperative dygraph APIs into declarative function APIs. Decorator
    @to_static handles the Program and Executor of static graph mode and returns
    the result as dygraph Tensor(s). Users could use the returned dygraph
    Tensor(s) to do imperative training, inference, or other operations. If the
    decorated function calls other imperative function, the called one will be
    converted into declarative function as well.

    Args:
        function (callable|Layer, optional): callable imperative function, or a
            Layer whose ``forward`` method should be converted in place. When
            None, a decorator configured with the remaining arguments is returned.
        input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specify the shape/dtype/name
            information of each input Tensor.
        build_strategy(BuildStrategy|None): This argument is used to compile the
            converted program with the specified options, such as operators' fusion
            in the computational graph and memory optimization during the execution
            of the computational graph. For more information about build_strategy,
            please refer to :code:`paddle.static.BuildStrategy`. The default is None.
        property(bool, Optional): whether the function is a python property. The default is False.

    Returns:
        The converted ``StaticFunction`` when ``function`` is a plain callable;
        the same Layer (with ``forward`` replaced) when ``function`` is a Layer;
        otherwise a decorator. Calling the converted function returns Tensor(s)
        containing the numerical result.

    Raises:
        TypeError: if ``build_strategy`` is given (truthy) but is not a
            ``paddle.static.BuildStrategy``.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.jit import to_static

            @to_static
            def func(x):
                if paddle.mean(x) < 0:
                    x_v = x - 1
                else:
                    x_v = x + 1
                return x_v

            x = paddle.ones([1, 2], dtype='float32')
            x_v = func(x)
            print(x_v) # [[2. 2.]]

    """
    # Normalize and validate the strategy first; `decorated` below reads the
    # final value through its closure.
    build_strategy = build_strategy or BuildStrategy()
    if not isinstance(build_strategy, BuildStrategy):
        raise TypeError(
            "Required type(build_strategy) shall be `paddle.static.BuildStrategy`, but received {}".format(
                type(build_strategy).__name__
            )
        )

    def decorated(python_func):
        """Wrap ``python_func`` into a ``StaticFunction`` object."""
        # Strip any previous decoration so we always wrap the plain function.
        _, plain_func = unwrap_decorators(python_func)
        static_fn = StaticFunction(
            function=plain_func,
            input_spec=input_spec,
            build_strategy=build_strategy,
            property=property,
        )
        return copy_decorator_attrs(
            original_func=plain_func, decorated_obj=static_fn
        )

    # Called without a target, e.g. `@to_static(input_spec=...)`:
    # hand back the configured decorator.
    if function is None:
        return decorated

    # `to_static(layer, ...)`: convert the Layer's forward in place.
    if isinstance(function, Layer):
        if isinstance(function.forward, StaticFunction):
            logging_utils.warn(
                "`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".format(
                    function.__class__.__name__
                )
            )
        function.forward = decorated(function.forward)
        return function

    # `@to_static` or `to_static(func, ...)` on a plain callable.
    return decorated(function)
313 314


315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354
def not_to_static(func=None):
    """
    A decorator that suppresses the dynamic-to-static conversion of a function.

    Args:
        func(callable): The function to decorate. When None (i.e. used as
            ``@not_to_static()``), the decorator itself is returned.

    Returns:
        callable: the same function, tagged so that it won't be converted in
        Dynamic-to-Static.

    Examples:
        .. code-block:: python

            import paddle

            @paddle.jit.not_to_static
            def func_not_to_static(x):
                res = x - 1
                return res

            @paddle.jit.to_static
            def func(x):
                if paddle.mean(x) < 0:
                    out = func_not_to_static(x)
                else:
                    out = x + 1
                return out

            x = paddle.ones([1, 2], dtype='float32')
            out = func(x)
            print(out) # [[2. 2.]]
    """
    if func is not None:
        # Attach conversion options marking this function as "do not convert".
        setattr(func, CONVERSION_OPTIONS, ConversionOptions(not_convert=True))
        return func
    # Used with parentheses and no argument: return the decorator itself.
    return not_to_static


355
class _SaveLoadConfig:
356 357 358 359 360
    def __init__(self):
        self._output_spec = None
        self._model_filename = None
        self._params_filename = None
        self._separate_params = False
361 362
        # used for `paddle.load`
        self._keep_name_table = False
363 364 365 366

        # NOTE: Users rarely use following configs, so these configs are not open to users,
        # reducing user learning costs, but we retain the configuration capabilities

367 368
        # If True, programs are modified to only support direct inference deployment.
        # Otherwise,more information will be stored for flexible optimization and re-training.
369 370 371 372 373
        # Currently, only True is supported
        self._export_for_deployment = True

        # If True, It will save inference program only, and do not save params of Program
        self._program_only = False
374
        self.with_hook = False
375

376 377 378
        # if True, multi `StaticFunction` will share params in one file.
        self.combine_params = False

379 380 381 382 383 384
    @property
    def output_spec(self):
        return self._output_spec

    @output_spec.setter
    def output_spec(self, spec):
385 386
        if spec is None:
            return
387 388
        if not isinstance(spec, list):
            raise TypeError(
389
                "The config `output_spec` should be 'list', but received input type is %s."
390 391
                % type(input)
            )
392 393 394
            for var in spec:
                if not isinstance(var, core.VarBase):
                    raise TypeError(
395
                        "The element in config `output_spec` list should be 'Variable', but received element's type is %s."
396 397
                        % type(var)
                    )
398 399 400 401 402 403 404 405
        self._output_spec = spec

    @property
    def model_filename(self):
        return self._model_filename

    @model_filename.setter
    def model_filename(self, filename):
406 407
        if filename is None:
            return
408
        if not isinstance(filename, str):
409
            raise TypeError(
410
                "The config `model_filename` should be str, but received input's type is %s."
411 412
                % type(filename)
            )
413
        if len(filename) == 0:
414
            raise ValueError("The config `model_filename` is empty string.")
415 416 417 418 419 420 421 422
        self._model_filename = filename

    @property
    def params_filename(self):
        return self._params_filename

    @params_filename.setter
    def params_filename(self, filename):
423 424
        if filename is None:
            return
425
        if not isinstance(filename, str):
426
            raise TypeError(
427
                "The config `params_filename` should be str, but received input's type is %s."
428 429
                % type(filename)
            )
430
        if len(filename) == 0:
431
            raise ValueError("The config `params_filename` is empty string.")
432 433
        self._params_filename = filename

434 435 436 437 438 439
    @property
    def keep_name_table(self):
        return self._keep_name_table

    @keep_name_table.setter
    def keep_name_table(self, value):
440 441
        if value is None:
            return
442 443
        if not isinstance(value, bool):
            raise TypeError(
444
                "The config `keep_name_table` should be bool value, but received input's type is %s."
445 446
                % type(value)
            )
447 448
        self._keep_name_table = value

449

450
def _parse_save_configs(configs):
    """Validate ``paddle.jit.save`` extra configs and build a ``_SaveLoadConfig``.

    Args:
        configs (dict): the ``**configs`` passed to ``paddle.jit.save``.

    Returns:
        _SaveLoadConfig: with every supported option set from ``configs``
        or its default.

    Raises:
        ValueError: if ``configs`` contains an unsupported key (the first one
            encountered is reported).
    """
    # (config key, default) pairs; options are applied in this order.
    option_defaults = (
        ('output_spec', None),
        ('with_hook', False),
        ('combine_params', False),
        ('clip_extra', True),
        ('skip_forward', False),
    )
    supported_keys = tuple(name for name, _ in option_defaults)

    unsupported = [key for key in configs if key not in supported_keys]
    if unsupported:
        raise ValueError(
            "The additional config (%s) of `paddle.jit.save` is not supported."
            % (unsupported[0])
        )

    inner_config = _SaveLoadConfig()
    for name, default in option_defaults:
        setattr(inner_config, name, configs.get(name, default))
    return inner_config


def _parse_load_config(configs):
    """Validate ``paddle.jit.load`` extra configs and build a ``_SaveLoadConfig``.

    Args:
        configs (dict): the ``**configs`` passed to ``paddle.jit.load``.

    Returns:
        _SaveLoadConfig: with ``model_filename`` / ``params_filename`` taken
        from ``configs`` (missing keys leave them as None).

    Raises:
        ValueError: on the first unsupported key.
    """
    supported_keys = ('model_filename', 'params_filename')
    for key in configs:
        if key in supported_keys:
            continue
        raise ValueError(
            "The additional config (%s) of `paddle.jit.load` is not supported."
            % (key)
        )

    inner_config = _SaveLoadConfig()
    for key in supported_keys:
        setattr(inner_config, key, configs.get(key, None))
    return inner_config


497
def _get_input_var_names(inputs, input_spec):
    """Choose the names of the input variables the saved program should feed.

    Args:
        inputs: (possibly nested) inputs of the traced program; only the
            ``Variable`` leaves are considered.
        input_spec: user-provided spec list, or None. Entries that are not
            ``paddle.static.InputSpec`` are ignored.

    Returns:
        list[str]: every input variable name when ``input_spec`` is None or has
        exactly as many ``InputSpec`` entries as there are input variables (in
        that case bad spec names only produce warnings); otherwise the names of
        the specs, in spec order (the remaining inputs are pruned).

    Raises:
        ValueError: when pruning, if a spec has no name or its name is not one
            of the input variable names.
    """
    name_none_error = (
        "The %s's name is None. "
        "When using jit.save, please set InputSepc's name in "
        "to_static(input_spec=[]) and jit.save(input_spec=[]) "
        "and make sure they are consistent."
    )
    name_no_exists_error = (
        "The tensor `%s` does not exists. "
        "Please make sure the name of InputSpec or example Tensor "
        "in input_spec is the same as the name of InputSpec in "
        "`to_static` decorated on the Layer.forward method."
    )

    input_var_names = [
        var.name for var in flatten(inputs) if isinstance(var, Variable)
    ]
    if input_spec is None:
        # Nothing to prune against.
        return input_var_names

    # Non-tensor spec entries (ints, strings, ...) have no feed variable.
    tensor_specs = [
        spec
        for spec in input_spec
        if isinstance(spec, paddle.static.InputSpec)
    ]

    if len(tensor_specs) == len(input_var_names):
        # Same arity: keep every input, only warn about suspicious names.
        for spec in tensor_specs:
            if spec.name is None:
                warnings.warn(name_none_error % spec)
            elif spec.name not in input_var_names:
                warnings.warn(name_no_exists_error % spec.name)
        return input_var_names

    # Different arity: prune down to the named specs, which must all resolve.
    pruned_names = []
    for spec in tensor_specs:
        if spec.name is None:
            raise ValueError(name_none_error % spec)
        if spec.name not in input_var_names:
            raise ValueError(name_no_exists_error % spec.name)
        pruned_names.append(spec.name)
    return pruned_names


552
def _get_output_vars(outputs, output_spec, with_hook=False):
    """Choose which output variables of the traced program are saved.

    Args:
        outputs: (possibly nested) outputs; only ``Variable`` leaves count.
            Variables are de-duplicated by name (first-seen order, last one wins).
        output_spec: list of example outputs selecting the saved outputs by
            ``.name``, or None for all outputs.
        with_hook (bool): whether pre/post hooks were found on the outermost
            layer.

    Returns:
        list: all output variables when ``output_spec`` is None or has the same
        length as the de-duplicated outputs (unknown names only warn then);
        otherwise the variables named by ``output_spec``, in spec order.

    Raises:
        RuntimeError: if a non-empty ``output_spec`` is combined with hooks.
        ValueError: when selecting, if a spec name matches no output variable.
    """
    name_no_exists_error = (
        "The tensor `%s` does not exists. "
        "Please make sure the name of example Tensor "
        "in configs.output_spec is the output tensor of "
        "Layer.forward method."
    )
    if output_spec and with_hook:
        raise RuntimeError(
            "Currently not support specify output_spec while founding pre/post hooks in your outermost layer."
        )

    vars_by_name = OrderedDict(
        (var.name, var) for var in flatten(outputs) if isinstance(var, Variable)
    )
    all_vars = list(vars_by_name.values())

    if output_spec is None:
        return all_vars

    if len(output_spec) == len(vars_by_name):
        # Same arity: keep everything, only warn about unknown names.
        for spec_var in output_spec:
            if spec_var.name not in vars_by_name:
                warnings.warn(name_no_exists_error % spec_var.name)
        return all_vars

    # Different arity: select exactly the requested outputs.
    selected = []
    for spec_var in output_spec:
        if spec_var.name not in vars_by_name:
            raise ValueError(name_no_exists_error % spec_var.name)
        selected.append(vars_by_name[spec_var.name])
    return selected


584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604
# NOTE(chenweihang): [ Handling of use cases of API paddle.jit.load ]
# `paddle.jit.load` may be used to load saved results of:
# 1. Expected cases:
#   - paddle.jit.save
#   - paddle.static.save_inference_model
#   - paddle.fluid.io.save_inference_model
# 2. Error cases:
#   - paddle.save: no .pdmodel for prefix
#   - paddle.static.save: no .pdiparams but .pdparams exists
#   - paddle.fluid.io.save_params/save_persistables: no __model__
# TODO(chenweihang): polish error message in above error cases
def _build_load_path_and_config(path, config):
    """Resolve the directory to load from and fill in file names in ``config``.

    Args:
        path (str): either a file prefix (``<path>.pdmodel`` exists) or a
            directory in the old ``save_inference_model`` layout.
        config (_SaveLoadConfig): load options; updated in place for the
            prefix format.

    Returns:
        tuple: ``(model_path, config)``. For the prefix format ``model_path`` is
        ``os.path.dirname(path)`` and the model/params file names are derived
        from the prefix (user-provided names are overridden with a warning);
        for the directory format ``model_path`` is ``path`` and ``config`` is
        returned untouched.

    Raises:
        ValueError: if both formats exist, or neither does.
    """
    # NOTE(chenweihang): If both [prefix save format] and [directory save format] exist,
    # raise error, avoid confusing behavior
    prefix_format_path = path + INFER_MODEL_SUFFIX
    has_prefix_format = os.path.exists(prefix_format_path)
    has_directory_format = os.path.isdir(path)

    if has_prefix_format and has_directory_format:
        raise ValueError(
            "The %s.pdmodel and %s directory exist at the same time, "
            "don't know which one to load, please make sure that the specified target "
            "of ``path`` is unique." % (path, path)
        )
    if not (has_prefix_format or has_directory_format):
        raise ValueError(
            "The ``path`` (%s) to load model not exists. "
            "Please make sure that *.pdmodel exists or "
            "don't using ``skip_forward=True`` to jit.save." % path
        )

    if not has_prefix_format:
        # Compatible with the old save_inference_model format
        return path, config

    file_prefix = os.path.basename(path)
    if config.model_filename is not None:
        warnings.warn(
            "When loading the result saved with the "
            "specified file prefix, the ``model_filename`` config does "
            "not take effect."
        )
    config.model_filename = file_prefix + INFER_MODEL_SUFFIX
    if config.params_filename is not None:
        warnings.warn(
            "When loading the result saved with the "
            "specified file prefix, the ``params_filename`` config does "
            "not take effect."
        )
    config.params_filename = file_prefix + INFER_PARAMS_SUFFIX
    return os.path.dirname(path), config
636 637


M
Ming-Xu Huang 已提交
638 639 640 641
# Registry of save pre-hooks invoked by `_run_save_pre_hooks`. Registration,
# removal and clearing below all take `_save_pre_hooks_lock` before mutating it.
_save_pre_hooks_lock = threading.Lock()
_save_pre_hooks = []


642
class HookRemoveHelper:
    """Handle returned when a save pre-hook is registered.

    Calling ``remove()`` unregisters the hook it was created for.
    """

    def __init__(self, hook):
        # The registered callable, passed back to the registry on removal.
        self._hook = hook

    def remove(self):
        """Unregister the wrapped hook via ``_remove_save_pre_hook``."""
        _remove_save_pre_hook(self._hook)


def _register_save_pre_hook(hook):
    """
    Register a save pre-hook for `paddle.jit.save`.
    This hook will be executed before `save` function has been invoked.

    hook(layer, input_spec, configs) -> None
    - layer (Layer|function): This argument is corresponding to `layer` in `paddle.jit.save`.
    - input_spec (list or tuple[InputSpec|Tensor|Python built-in variable]): This argument is corresponding to `input_spec` in `paddle.jit.save`.
    - configs (dict): This argument is corresponding to `configs` in `paddle.jit.save`.

    Registering a hook that is already registered does not add it twice.

    Args:
        hook(function): a function registered as a save pre-hook

    Returns:
        HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            IMAGE_SIZE = 256
            CLASS_NUM = 10

            class LinearNet(paddle.nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM)

                def forward(self, x):
                    return self._linear(x)

            saving_count = 0
            def save_pre_hook(layer, input_spec, configs):
                global saving_count
                saving_count += 1

            remove_handler = paddle.jit.register_save_pre_hook(save_pre_hook)

            layer = LinearNet()
            paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
            # saving_count == 1

            remove_handler.remove()
            paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
            # saving_count == 1
    """
    # `with` guarantees the lock is released even if the membership test
    # (which may call a user-defined __eq__) raises.
    with _save_pre_hooks_lock:
        if hook not in _save_pre_hooks:
            _save_pre_hooks.append(hook)
    return HookRemoveHelper(hook)


def _clear_save_pre_hooks():
    """Unregister every save pre-hook."""
    # `with` releases the lock on every exit path.
    with _save_pre_hooks_lock:
        _save_pre_hooks.clear()


def _remove_save_pre_hook(hook):
    """Unregister ``hook``; does nothing if it is not currently registered."""
    # `with` releases the lock even if comparing hooks raises.
    with _save_pre_hooks_lock:
        if hook in _save_pre_hooks:
            _save_pre_hooks.remove(hook)


726
@wrap_decorator
def _run_save_pre_hooks(func):
    """Decorator that runs all registered save pre-hooks before ``func``.

    Each hook is called as ``hook(layer, input_spec, configs)`` with the
    arguments of the decorated save call (``configs`` is the keyword-argument
    dict, also forwarded to ``func``). The wrapped function's return value is
    passed through.
    """

    def wrapper(layer, path, input_spec=None, **configs):
        # Snapshot under the same lock that guards registration/removal so the
        # list cannot change mid-iteration. The lock is released before hooks
        # run, so a hook may itself (un)register hooks without deadlocking on
        # the non-reentrant lock.
        with _save_pre_hooks_lock:
            hooks = list(_save_pre_hooks)
        for hook in hooks:
            hook(layer, input_spec, configs)
        return func(layer, path, input_spec, **configs)

    return wrapper


737
def _save_property(filename: str, property_vals: list[tuple[Any, str]]):
    """class property serialization.

    Args:
        filename (str): *.meta
        property_vals (list[tuple[Any, str]]): class property, given as
            ``(value, key)`` pairs. Supported values are float, int, str, or a
            non-empty list/tuple whose first element is float, int or str.

    Raises:
        ValueError: for an unsupported value type, an empty list/tuple, or a
            list/tuple whose first element has an unsupported type. In that
            case the output file is not written.
    """

    def set_property(meta, key, val):
        # float is checked before int; bool values fall through to int.
        if isinstance(val, float):
            meta.set_float(key, val)
        elif isinstance(val, int):
            meta.set_int(key, val)
        elif isinstance(val, str):
            meta.set_string(key, val)
        elif isinstance(val, (tuple, list)):
            if not val:
                # The element type (and hence the setter) cannot be inferred.
                raise ValueError(
                    f"Not supported empty {type(val).__name__} for property: {key}"
                )
            first = val[0]
            # The list setter is chosen from the type of the first element.
            if isinstance(first, float):
                meta.set_floats(key, val)
            elif isinstance(first, int):
                meta.set_ints(key, val)
            elif isinstance(first, str):
                meta.set_strings(key, val)
            else:
                raise ValueError(
                    f"Not supported val type: {type(val)} of {type(first)}"
                )
        else:
            raise ValueError(f"Not supported val type: {type(val)}")

    # Build the whole property set before touching the file so that a bad
    # value does not leave an empty/truncated file behind.
    meta = paddle.framework.core.Property()
    for item in property_vals:
        val, key = item[0], item[1]
        set_property(meta, key, val)

    with open(filename, 'wb') as f:
        f.write(meta.serialize_to_string())


M
Ming-Xu Huang 已提交
771
@_run_save_pre_hooks
@switch_to_static_graph
def save(layer, path, input_spec=None, **configs):
    """
    Saves input Layer or function as ``paddle.jit.TranslatedLayer``
    format model, which can be used for inference or fine-tuning after loading.

    It will save the translated program and all related persistable
    variables of input Layer to given ``path`` .

    ``path`` is the prefix of saved objects, and the saved translated program file
    suffix is ``.pdmodel`` , the saved persistable variables file suffix is ``.pdiparams`` ,
    and here also saved some additional variable description information to a file,
    its suffix is ``.pdiparams.info``, these additional information is used in fine-tuning.

    The saved model can be loaded by the following APIs:
      - ``paddle.jit.load``
      - ``paddle.static.load_inference_model``
      - Other C++ inference APIs

    .. note::
        When using ``paddle.jit.save`` to save a function, parameters will not be saved. If you have to
        save the parameter, please pass the Layer containing function and parameter to ``paddle.jit.save``.

    Args:
        layer (Layer|function): The Layer or function to be saved.
        path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
        input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional): Describes the input of the saved model's forward
            method, which can be described by InputSpec or example Tensor. Moreover, we support to specify non-tensor type argument,
            such as int, float, string, or list/dict of them. If None, all input variables of
            the original Layer's forward method would be the inputs of the saved model. Default None.
        **configs (dict, optional): Other save configuration options for compatibility. We do not
            recommend using these configurations, they may be removed in the future. If not necessary,
            DO NOT use them. Default None.
            The following options are currently supported:
            (1) output_spec (list[Tensor]): Selects the output targets of the saved model.
            By default, all return variables of original Layer's forward method are kept as the
            output of the saved model. If the provided ``output_spec`` list is not all output variables,
            the saved model will be pruned according to the given ``output_spec`` list.

    Returns:
        None

    Examples:
        .. code-block:: python

            # example 1: save layer
            import numpy as np
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            # 1. train & save model.

            # create network
            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # train
            train(layer, loader, loss_fn, adam)

            # save
            path = "example_model/linear"
            paddle.jit.save(layer, path)

            # example 2: save function
            import paddle
            from paddle.static import InputSpec


            def save_function():
                @paddle.jit.to_static
                def fun(inputs):
                    return paddle.tanh(inputs)

                path = 'test_jit_save_load_function_1/func'
                inps = paddle.rand([3, 6])
                origin = fun(inps)

                paddle.jit.save(fun, path)
                load_func = paddle.jit.load(path)

                load_result = load_func(inps)
                print((load_result - origin).abs().max() < 1e-10)

            save_function()
    """

    # 1. input build & check
    prog_translator = ProgramTranslator()
    if not prog_translator.enable_to_static:
        raise RuntimeError(
            "The paddle.jit.save doesn't work when setting 'paddle.jit.enable_to_static' to False."
        )

    if not (
        isinstance(layer, Layer)
        or inspect.isfunction(layer)
        or isinstance(layer, StaticFunction)
    ):
        raise TypeError(
            "The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s."
            % type(layer)
        )
    elif inspect.isfunction(layer) or isinstance(layer, StaticFunction):
        warnings.warn(
            'What you save is a function, and `jit.save` will generate the name of the model file according to `path` you specify. When loading these files with `jit.load`, you get a `TranslatedLayer` whose inference result is the same as the inference result of the function you saved.'
        )

    # NOTE(chenweihang): If the input layer be wrapped by DataParallel,
    # the args and kwargs of forward method will can't be parsed by
    # function_spec, so here we save DataParallel._layers instead
    # DataParallel it self
    # NOTE(chenweihang): using inner_layer, do not change input layer
    if isinstance(layer, paddle.DataParallel):
        inner_layer = layer._layers
    else:
        inner_layer = layer

    # path check
    file_prefix = os.path.basename(path)
    if file_prefix == "":
        raise ValueError(
            "The input path MUST be format of dirname/file_prefix "
            "[dirname\\file_prefix in Windows system], but received "
            "file_prefix is empty string."
        )

    dirname = os.path.dirname(path)
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)

    # avoid change user given input_spec
    inner_input_spec = None
    if input_spec is not None:
        if isinstance(layer, Layer):
            for attr_func in dir(inner_layer):
                static_func = getattr(inner_layer, attr_func, None)
                if (
                    isinstance(static_func, StaticFunction)
                    and 'forward' != attr_func
                ):
                    raise ValueError(
                        "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
                        % type(input_spec)
                    )

        if not isinstance(input_spec, (list, tuple)):
            raise TypeError(
                "The input input_spec should be 'list', but received input_spec's type is %s."
                % type(input_spec)
            )
        inner_input_spec = []
        for var in flatten(input_spec):
            if isinstance(var, paddle.static.InputSpec):
                inner_input_spec.append(var)
            elif isinstance(var, (core.VarBase, core.eager.Tensor, Variable)):
                inner_input_spec.append(
                    paddle.static.InputSpec.from_tensor(var)
                )
            else:
                # NOTE(Aurelius84): Support non-Tensor type in `input_spec`.
                inner_input_spec.append(var)

    # parse configs
    configs = _parse_save_configs(configs)
    # whether outermost layer has pre/post hook, if does, we need also save
    # these operators in program.
    with_hook = configs.with_hook
    combine_params = configs.combine_params
    if combine_params:
        configs._program_only = True

    scope = core.Scope()
    extra_var_info = {}
    if isinstance(layer, Layer):
        functions = dir(inner_layer)
        if inner_layer._forward_pre_hooks or inner_layer._forward_post_hooks:
            with_hook = True
    else:
        # layer is function
        functions = [
            layer,
        ]

    combine_vars = {}
    property_vals = []  # (value, key)
    concrete_program = None
    # Directory all model files are written to. Bound before the loop so the
    # shared-params/property saving below does not depend on the loop having
    # saved at least one program (previously a NameError in that case).
    model_path = dirname
    for attr_func in functions:
        if isinstance(layer, Layer):
            static_func = getattr(inner_layer, attr_func, None)
            if isinstance(static_func, StaticFunction):
                if static_func.is_property:
                    # property method to be exported
                    immediate_val = static_func()
                    property_vals.append(
                        (
                            immediate_val,
                            layer.__class__.__name__ + '.' + attr_func,
                        )
                    )
                    continue

                concrete_program = (
                    static_func.concrete_program_specify_input_spec(
                        inner_input_spec, with_hook=with_hook
                    )
                )
            elif 'forward' == attr_func:
                if configs.skip_forward:
                    # do not jit.save forward function
                    continue

                # transform in jit.save, if input_spec is incomplete, declarative will throw error
                # inner_input_spec is list[InputSpec], it should be packed with same structure
                # as original input_spec here.
                if inner_input_spec:
                    inner_input_spec = pack_sequence_as(
                        input_spec, inner_input_spec
                    )
                static_forward = to_static(
                    inner_layer.forward, input_spec=inner_input_spec
                )
                concrete_program = (
                    static_forward.concrete_program_specify_input_spec(
                        with_hook=with_hook
                    )
                )
                # the input_spec has been used in declarative, which is equal to
                # @to_static with input_spec and jit.save without input_spec,
                # avoid needless warning
                inner_input_spec = None
            else:
                continue
        else:
            # When layer is a function
            if isinstance(attr_func, StaticFunction):
                if attr_func.is_property:
                    # property method to be exported
                    immediate_val = attr_func()
                    # NOTE(review): unlike the Layer branch, the key recorded
                    # here is the StaticFunction object rather than a str, and
                    # `_save_property` passes keys straight to `meta.set_*`.
                    # Confirm this is intended.
                    property_vals.append((immediate_val, attr_func))
                    continue

                concrete_program = (
                    attr_func.concrete_program_specify_input_spec(
                        inner_input_spec
                    )
                )
            else:
                if inner_input_spec:
                    inner_input_spec = pack_sequence_as(
                        input_spec, inner_input_spec
                    )
                static_function = to_static(
                    attr_func, input_spec=inner_input_spec
                )
                concrete_program = static_function.concrete_program

                if static_function._class_instance is None:
                    warnings.warn(
                        '`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'.format(
                            layer
                        )
                    )

        # when save multi `StaticFunction`, all `StaticFunction` share params.
        dygraph_state_dict = None
        if isinstance(inner_layer, Layer):
            dygraph_state_dict = inner_layer.to_static_state_dict()
        elif isinstance(attr_func, StaticFunction):
            if attr_func._class_instance:
                dygraph_state_dict = (
                    attr_func._class_instance.to_static_state_dict()
                )

        # NOTE(chenweihang): we maintain the mapping of variable name to
        # structured name, the buffer variable (non-persistable)
        # saved to inference program may not need by dygraph Layer,
        # we only record the state_dict variable's structured name.
        # Both mappings are always bound here, even when there is no state
        # dict, so the code below never hits an unbound name.
        state_names_dict = {}
        state_var_dict = {}
        if dygraph_state_dict:
            for structured_name, var in dygraph_state_dict.items():
                state_names_dict[var.name] = structured_name
                state_var_dict[var.name] = var

        # 3. share parameters from Layer to scope & record var info
        with dygraph.guard():
            for param_or_buffer in concrete_program.parameters:
                # share to scope
                if param_or_buffer.type == core.VarDesc.VarType.VOCAB:
                    scr_tensor = param_or_buffer.value().get_map_tensor()
                    tgt_var = scope.var(param_or_buffer.name)
                    tgt_var.set_vocab(scr_tensor)
                else:
                    param_or_buffer_tensor = scope.var(
                        param_or_buffer.name
                    ).get_tensor()
                    if param_or_buffer.name in state_var_dict:
                        # Prefer the state-dict variable of the owning Layer.
                        src_tensor = (
                            state_var_dict[param_or_buffer.name]
                            .value()
                            .get_tensor()
                        )
                    else:
                        # Not covered by any state dict (e.g. a plain
                        # function): fall back to the program's own variable
                        # instead of failing with a lookup error.
                        src_tensor = param_or_buffer.value().get_tensor()
                    param_or_buffer_tensor._share_data_with(src_tensor)
                # record var info
                if param_or_buffer.name not in extra_var_info:
                    extra_info_dict = {}
                    if param_or_buffer.name in state_names_dict:
                        extra_info_dict['structured_name'] = state_names_dict[
                            param_or_buffer.name
                        ]
                    extra_info_dict[
                        'stop_gradient'
                    ] = param_or_buffer.stop_gradient
                    if isinstance(param_or_buffer, (ParamBase, EagerParamBase)):
                        extra_info_dict['trainable'] = param_or_buffer.trainable
                    extra_var_info[param_or_buffer.name] = extra_info_dict

        # 4. build input & output of save_inference_model
        # NOTE(chenweihang): [ Get input variables name ]
        # There are two cases, whether to prune the inputs or not
        # - not prune inputs (recommend):
        #   - the len(input_spec) == len((concrete_program.inputs) - 1
        #   - here can use concrete_program.inputs directly
        # - prune inputs:
        #   - the input_spec length < len((concrete_program.inputs) - 1
        #   - the input_spec's name should be in concrete_program.inputs
        input_var_names = _get_input_var_names(
            concrete_program.inputs, inner_input_spec
        )

        # NOTE(chenweihang): [ Get output variables ]
        # the rule is like [ Get input variables name ]. For output var,
        # we only support VarBase spec, and actually, we only need the
        # var name of output, and we don't recommended to use output_spec
        output_vars = _get_output_vars(
            concrete_program.outputs, configs.output_spec, with_hook
        )

        # 5. save inference model
        from paddle.fluid.io import save_inference_model

        # NOTE(chenweihang): because prefix contains model and params filename,
        # so we don't support set model_filename & params_filename
        if 'forward' == attr_func or not isinstance(layer, Layer):
            model_filename = file_prefix + INFER_MODEL_SUFFIX
            params_filename = file_prefix + INFER_PARAMS_SUFFIX
        else:
            model_filename = file_prefix + '.' + attr_func + INFER_MODEL_SUFFIX
            params_filename = (
                file_prefix + '.' + attr_func + INFER_PARAMS_SUFFIX
            )

        with scope_guard(scope):
            save_inference_model(
                dirname=model_path,
                feeded_var_names=input_var_names,
                target_vars=output_vars,
                executor=Executor(_current_expected_place()),
                main_program=concrete_program.main_program.clone(),
                model_filename=model_filename,
                params_filename=params_filename,
                export_for_deployment=configs._export_for_deployment,
                program_only=configs._program_only,
                clip_extra=configs.clip_extra,
            )

        if combine_params:
            clone_main_program = concrete_program.main_program.clone()
            clone_main_program = clone_main_program._prune_with_input(
                input_var_names, output_vars
            )
            for block in clone_main_program.blocks:
                combine_vars.update(block.vars)

    # save shared params
    if combine_params:
        # sort vars by name
        ordered_vars = [
            var
            for _, var in sorted(combine_vars.items(), key=lambda item: item[0])
        ]

        params_filename = file_prefix + INFER_PARAMS_SUFFIX
        with scope_guard(scope):
            paddle.static.save_vars(
                Executor(_current_expected_place()),
                dirname=model_path,
                vars=list(filter(paddle.fluid.io.is_persistable, ordered_vars)),
                filename=params_filename,
            )
        # save property
        property_save_path = os.path.join(
            os.path.normpath(model_path), file_prefix + INFER_PROPERTY_SUFFIX
        )
        _save_property(property_save_path, property_vals)

    # NOTE(chenweihang): [ Save extra variable info ]
    # save_inference_model will lose some important variable information, including:
    #   - Variable name and correspondence (when saved variables as one file)
    #   - Variable.stop_gradient information
    #   - Which persistent variable are parameter and which are not
    #   - Parameter.trainable information
    #
    # The lost information cannot be recovered when it is loaded again,
    # so if we want to perform fine-tune after loading, we may need to
    # configure redundant information to proceed.
    #
    # Due to compatibility issues, we cannot change the original storage structure,
    # but we can save these information in `jit.save` without changing the original
    # storage to improve user experience. So we save extra information into
    # file `***.pdiparams.info`

    # "layer" can only be Layer or function or StaticFunction.
    contain_parameter = concrete_program is not None and any(
        isinstance(var, Parameter)
        for var in concrete_program.main_program.list_vars()
    )

    if (isinstance(layer, Layer) or contain_parameter) and extra_var_info:
        with scope_guard(scope):
            extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
            with open(extra_var_info_path, 'wb') as f:
                pickle.dump(extra_var_info, f, protocol=2)
1249 1250 1251


@dygraph_only
def load(path, **configs):
    """
    :api_attr: imperative

    Load model saved by ``paddle.jit.save`` or ``paddle.static.save_inference_model`` or
    paddle 1.x API ``paddle.fluid.io.save_inference_model`` as ``paddle.jit.TranslatedLayer``,
    then performing inference or fine-tune training.

    .. note::
        If you load model saved by ``paddle.static.save_inference_model`` ,
        there will be the following limitations when using it in fine-tuning:
        1. Imperative mode does not support LoDTensor. All original model's feed targets or parameters that depend on LoD are temporarily unavailable.
        2. All saved model's feed targets need to be passed into TranslatedLayer's forward function.
        3. The variable's ``stop_gradient`` information is lost and can not be recovered.
        4. The parameter's ``trainable`` information is lost and can not be recovered.

    Args:
        path (str): The path prefix to load model. The format is ``dirname/file_prefix`` or ``file_prefix`` .
        **configs (dict, optional): Other load configuration options for compatibility. We do not
            recommend using these configurations, they may be removed in the future. If not necessary,
            DO NOT use them. Default None.
            The following options are currently supported:
            (1) model_filename (str): The inference model file name of the paddle 1.x
            ``save_inference_model`` save format. Default file name is :code:`__model__` .
            (2) params_filename (str): The persistable variables file name of the paddle 1.x
            ``save_inference_model`` save format. No default file name, save variables separately
            by default.

    Returns:
        TranslatedLayer: A Layer object can run saved translated model.

    Examples:
        1. Load model saved by ``paddle.jit.save`` then performing inference and fine-tune training.

        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            # 1. train & save model.

            # create network
            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # train
            train(layer, loader, loss_fn, adam)

            # save
            path = "example_model/linear"
            paddle.jit.save(layer, path)

            # 2. load model

            # load
            loaded_layer = paddle.jit.load(path)

            # inference
            loaded_layer.eval()
            x = paddle.randn([1, IMAGE_SIZE], 'float32')
            pred = loaded_layer(x)

            # fine-tune
            loaded_layer.train()
            adam = opt.Adam(learning_rate=0.001, parameters=loaded_layer.parameters())
            train(loaded_layer, loader, loss_fn, adam)


        2. Load model saved by ``paddle.fluid.io.save_inference_model`` then performing inference and fine-tune training.

        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.static as static
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.nn.functional as F

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            paddle.enable_static()

            image = static.data(name='image', shape=[None, 784], dtype='float32')
            label = static.data(name='label', shape=[None, 1], dtype='int64')
            pred = static.nn.fc(x=image, size=10, activation='softmax')
            loss = F.cross_entropy(input=pred, label=label)
            avg_loss = paddle.mean(loss)

            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer.minimize(avg_loss)

            place = paddle.CPUPlace()
            exe = static.Executor(place)
            exe.run(static.default_startup_program())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                feed_list=[image, label],
                places=place,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                return_list=False,
                num_workers=2)

            # 1. train and save inference model
            for data in loader():
                exe.run(
                    static.default_main_program(),
                    feed=data,
                    fetch_list=[avg_loss])

            model_path = "fc.example.model"
            paddle.fluid.io.save_inference_model(
                model_path, ["image"], [pred], exe)

            # 2. load model

            # enable dygraph mode
            paddle.disable_static(place)

            # load
            fc = paddle.jit.load(model_path)

            # inference
            fc.eval()
            x = paddle.randn([1, IMAGE_SIZE], 'float32')
            pred = fc(x)

            # fine-tune
            fc.train()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=fc.parameters())
            loader = paddle.io.DataLoader(dataset,
                places=place,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)
            for epoch_id in range(EPOCH_NUM):
                for batch_id, (image, label) in enumerate(loader()):
                    out = fc(image)
                    loss = loss_fn(out, label)
                    loss.backward()
                    adam.step()
                    adam.clear_grad()
                    print("Epoch {} batch {}: loss = {}".format(
                        epoch_id, batch_id, np.mean(loss.numpy())))
    """
    # 1. construct correct config
    config = _parse_load_config(configs)
    # Resolve the on-disk model path and fill in file names in the config.
    model_path, config = _build_load_path_and_config(path, config)

    return TranslatedLayer._construct(model_path, config)
1478 1479


1480
@dygraph_only
def _trace(
    layer, inputs, feed_prefix='feed_', fetch_prefix='fetch_', tmp_prefix='t_'
):
    """Run ``layer`` once in dygraph mode while recording a static program.

    Args:
        layer (Layer): the dygraph layer to trace.
        inputs (Tensor|list|tuple): the input(s) of ``layer``; a value that
            is not a list/tuple is wrapped in a single-element list.
        feed_prefix (str): prefix passed to the tracer for feed variable names.
        fetch_prefix (str): prefix passed to the tracer for fetch variable names.
        tmp_prefix (str): prefix passed to the tracer for temporary variable names.

    Returns:
        tuple: ``(original_outputs, program, feed_names, fetch_names,
        parameters)``. ``original_outputs`` is exactly what
        ``layer(*inputs)`` returned, and ``program`` is created from the
        recorded program desc.
    """
    assert isinstance(layer, Layer)

    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]

    tracer = _dygraph_tracer()._get_program_desc_tracer()

    var_list = extract_vars(inputs)

    with program_desc_tracing_guard(True):
        try:
            original_outputs = layer(*inputs)
            if not isinstance(original_outputs, (list, tuple)):
                outputs = [original_outputs]
            else:
                outputs = original_outputs
            out_vars = extract_vars(outputs, err_tag='outputs')

            (
                program_desc,
                feed_names,
                fetch_names,
                parameters,
            ) = tracer.create_program_desc(
                var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix
            )
        finally:
            # Reset the tracer on every exit path. Previously a failing
            # forward pass or desc creation skipped the reset, presumably
            # leaving recorded state behind for the next trace.
            tracer.reset()

    with _dygraph_guard(None):
        program = create_program_from_desc(program_desc)

    return original_outputs, program, feed_names, fetch_names, parameters
1515 1516


1517
class TracedLayer:
    """
    :api_attr: imperative

    TracedLayer is used to convert a forward dygraph model to a static
    graph model. This is mainly used to save the dygraph model for online
    inference using C++. Besides, users can also do inference in Python
    using the converted static graph model, which usually has better
    performance than the original dygraph model.

    TracedLayer runs the static graph model using :code:`Executor`
    and :code:`CompiledProgram` . The static graph model shares
    parameters with the dygraph model.

    TracedLayer objects should not be created through the constructor; they
    should be created by the static method :code:`TracedLayer.trace(layer, inputs)` .

    The TracedLayer can only be used to convert a data-independent dygraph
    model into a static graph model, which means the dygraph model must
    be independent of the tensor data and shape.
    """

    def __init__(self, program, parameters, feed_names, fetch_names):
        """Bind a traced static ``program`` to the given ``parameters``.

        Args:
            program: the static graph program to run.
            parameters: parameters whose tensors are shared into a private
                scope under each parameter's ``name``.
            feed_names: names of the program's feed (input) variables.
            fetch_names: names of the program's fetch (output) variables.
        """
        self._program = program
        self._feed_names = feed_names
        self._fetch_names = fetch_names
        self._params = parameters

        self._place = _current_expected_place()

        # Share (not copy) each parameter's tensor into a private scope so the
        # static program operates on the same data as the dygraph parameters.
        private_scope = core.Scope()
        self._scope = private_scope
        for param in parameters:
            source = param.value().get_tensor()
            target = private_scope.var(param.name).get_tensor()
            target._share_data_with(source)

        self._exe = Executor(self._place)
        # Compilation is deferred to the first call; strategies may be set
        # before that happens.
        self._compiled_program = None
        self._build_strategy = None
        self._exec_strategy = None

    @property
    def program(self):
        """The static ``Program`` this TracedLayer executes."""
        traced_program = self._program
        return traced_program

    def _switch(self, is_test=True):
        """Assign ``is_test`` to every op (in every block) that has that attribute.

        Args:
            is_test (bool): value written to each op's ``is_test`` attribute.
                Default True.
        """
        program = self._program
        for idx in range(program.num_blocks):
            for op in program.block(idx).ops:
                if not op.has_attr("is_test"):
                    continue
                op._set_attr("is_test", is_test)

    @staticmethod
    @dygraph_only
    def trace(layer, inputs):
        """
        This method is the only allowed method to create a TracedLayer object.
        It calls :code:`layer(*inputs)` to run the dygraph model and converts
        it into a static graph model.

        Args:
            layer (paddle.nn.Layer): the layer object to be traced.
            inputs (list(Tensor)|tuple(Tensor)|Tensor): the input tensors of
                the layer object.

        Returns:
            tuple: A tuple of 2 items, whose first item is the output of
                :code:`layer(*inputs)` , and the second item is the created
                TracedLayer object.

        Examples:
            .. code-block:: python

                import paddle

                class ExampleLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._fc = paddle.nn.Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)

                layer = ExampleLayer()
                in_var = paddle.uniform(shape=[2, 3], dtype='float32')
                out_dygraph, static_layer = paddle.jit.TracedLayer.trace(layer, inputs=[in_var])

                # run the static graph model using Executor inside
                out_static_graph = static_layer([in_var])

                print(len(out_static_graph)) # 1
                print(out_static_graph[0].shape) # (2, 10)

                # save the static graph model for inference
                static_layer.save_inference_model('./saved_infer_model')

        """
        assert isinstance(layer, Layer), (
            "The type of 'layer' in paddle.jit.TracedLayer.trace must be "
            "fluid.dygraph.Layer, but received {}.".format(type(layer))
        )
        traced_outputs, program, feed_names, fetch_names, params = _trace(
            layer, inputs
        )
        return traced_outputs, TracedLayer(program, params, feed_names, fetch_names)

    def set_strategy(self, build_strategy=None, exec_strategy=None):
        """
        Set the strategies used when running the static graph model.

        Must be called before the first run, because the strategies are only
        consumed when the program is compiled.

        Args:
            build_strategy (BuildStrategy, optional): build strategy of
                :code:`CompiledProgram` inside TracedLayer. Default None.
            exec_strategy (ExecutionStrategy, optional): execution strategy of
                :code:`CompiledProgram` inside TracedLayer. Default None.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                class ExampleLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._fc = paddle.nn.Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)

                layer = ExampleLayer()
                in_var = paddle.uniform(shape=[2, 3], dtype='float32')

                out_dygraph, static_layer = paddle.jit.TracedLayer.trace(layer, inputs=[in_var])

                build_strategy = paddle.static.BuildStrategy()
                build_strategy.enable_inplace = True

                exec_strategy = paddle.static.ExecutionStrategy()
                exec_strategy.num_threads = 2

                static_layer.set_strategy(build_strategy=build_strategy, exec_strategy=exec_strategy)
                out_static_graph = static_layer([in_var])

        """
        assert self._compiled_program is None, "Cannot set strategy after run"
        assert isinstance(build_strategy, (type(None), BuildStrategy)), (
            "The type of 'build_strategy' in paddle.jit.TracedLayer.set_strategy "
            "must be fluid.BuildStrategy, but received {}.".format(
                type(build_strategy)
            )
        )
        assert isinstance(exec_strategy, (type(None), ExecutionStrategy)), (
            "The type of 'exec_strategy' in paddle.jit.TracedLayer.set_strategy "
            "must be fluid.ExecutionStrategy, but received {}.".format(
                type(exec_strategy)
            )
        )
        self._build_strategy, self._exec_strategy = build_strategy, exec_strategy

    @switch_to_static_graph
    def _compile(self):
        """Build ``self._compiled_program`` from ``self._program``.

        Uses ``with_data_parallel`` with the configured build/exec strategies
        and ``self._place``.
        """
        base = CompiledProgram(self._program)
        self._compiled_program = base.with_data_parallel(
            build_strategy=self._build_strategy,
            exec_strategy=self._exec_strategy,
            places=self._place,
        )

    def _build_feed(self, inputs):
        """Map each input to its feed variable name, in order.

        Args:
            inputs (list|tuple): exactly one input per feed name.

        Returns:
            dict: ``{feed_name: value}``. In non-static (dygraph) mode the
            value is ``x.value().get_tensor()`` of each input; otherwise the
            input itself.
        """
        assert isinstance(
            inputs, (list, tuple)
        ), "Inputs should be a list or tuple of variables"
        # Report both counts so a mismatched call is diagnosable, instead of
        # the previous bare AssertionError.
        assert len(inputs) == len(self._feed_names), (
            "The number of inputs ({}) does not match the number of feed "
            "variables ({}) of the TracedLayer.".format(
                len(inputs), len(self._feed_names)
            )
        )
        feed_dict = {}
        if _non_static_mode():
            for x, name in zip(inputs, self._feed_names):
                feed_dict[name] = x.value().get_tensor()
        else:
            for x, name in zip(inputs, self._feed_names):
                feed_dict[name] = x

        return feed_dict

    @switch_to_static_graph
    def _run(self, feed):
        """Run the compiled program on ``feed`` and return ``Executor.run``'s result."""
        fetch_list = self._fetch_names
        return self._exe.run(self._compiled_program, feed=feed, fetch_list=fetch_list)

    def __call__(self, inputs):
        """Run the static graph on ``inputs`` inside this layer's private scope.

        The program is compiled lazily on the first call.
        """
        with scope_guard(self._scope):
            if self._compiled_program is None:
                self._compile()
            feed = self._build_feed(inputs)
            return self._run(feed)

    @switch_to_static_graph
    def save_inference_model(self, path, feed=None, fetch=None, **kwargs):
        """
        Save the TracedLayer to a model for inference. The saved
        inference model can be loaded by C++ inference APIs.

        ``path`` is the prefix of saved objects, and the saved translated program file
        suffix is ``.pdmodel`` , the saved persistable variables file suffix is ``.pdiparams`` .

        Args:
            path(str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
            feed (list[int], optional): the input variable indices of the saved
                inference model. If None, all input variables of the
                TracedLayer object would be the inputs of the saved inference
                model. Default None.
            fetch (list[int], optional): the output variable indices of the
                saved inference model. If None, all output variables of the
                TracedLayer object would be the outputs of the saved inference
                model. Default None.
            kwargs: Supported keys include 'clip_extra'. Set it to True (the
                default) to clip extra information for every operator.

        Returns:
            None

        Raises:
            ValueError: if the file prefix of ``path`` is empty.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                class ExampleLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._fc = paddle.nn.Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)

                save_dirname = './saved_infer_model'
                in_np = np.random.random([2, 3]).astype('float32')
                in_var = paddle.to_tensor(in_np)
                layer = ExampleLayer()

                out_dygraph, static_layer = paddle.jit.TracedLayer.trace(layer, inputs=[in_var])
                static_layer.save_inference_model(save_dirname, feed=[0], fetch=[0])

                paddle.enable_static()
                place = paddle.CPUPlace()
                exe = paddle.static.Executor(place)
                program, feed_vars, fetch_vars = paddle.static.load_inference_model(save_dirname,
                                                    exe)

                fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
                print(fetch.shape) # (2, 10)
        """
        api_name = "paddle.jit.TracedLayer.save_inference_model"
        check_type(path, "path", str, api_name)
        check_type(feed, "feed", (type(None), list), api_name)
        if isinstance(feed, list):
            for f in feed:
                check_type(f, "each element of feed", int, api_name)
        check_type(fetch, "fetch", (type(None), list), api_name)
        if isinstance(fetch, list):
            for f in fetch:
                check_type(f, "each element of fetch", int, api_name)
        clip_extra = kwargs.get('clip_extra', True)

        # path check
        file_prefix = os.path.basename(path)
        if file_prefix == "":
            raise ValueError(
                "The input path MUST be format of dirname/file_prefix "
                "[dirname\\file_prefix in Windows system], but received "
                "file_prefix is empty string."
            )

        dirname = os.path.dirname(path)
        if dirname:
            # exist_ok=True avoids the race between a separate existence check
            # and the directory creation.
            os.makedirs(dirname, exist_ok=True)

        from paddle.fluid.io import save_inference_model

        def get_feed_fetch(all_vars, partial_vars):
            # Select the requested indices, or keep everything when None.
            if partial_vars is None:
                return all_vars

            return [all_vars[idx] for idx in partial_vars]

        with scope_guard(self._scope):
            feeded_var_names = get_feed_fetch(self._feed_names, feed)
            target_var_names = get_feed_fetch(self._fetch_names, fetch)
            target_vars = []
            for name in target_var_names:
                target_var = self._program.global_block().vars.get(name, None)
                assert target_var is not None, "{} cannot be found".format(name)
                target_vars.append(target_var)

            model_filename = file_prefix + INFER_MODEL_SUFFIX
            params_filename = file_prefix + INFER_PARAMS_SUFFIX

            save_inference_model(
                dirname=dirname,
                feeded_var_names=feeded_var_names,
                target_vars=target_vars,
                executor=self._exe,
                main_program=self._program.clone(),
                model_filename=model_filename,
                params_filename=params_filename,
                clip_extra=clip_extra,
            )