jit.py 49.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

17 18
import os
import pickle
19
import warnings
20
import functools
21
from collections import OrderedDict
22 23

import six
24
import paddle
25
from paddle.fluid import core
26 27
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
28
from paddle.fluid.layers.utils import flatten
29
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
30
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
31
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
32
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticFunction, unwrap_decorators
33
from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
34 35
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
36 37 38
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer
from paddle.fluid.framework import dygraph_only, in_dygraph_mode
39
from paddle.fluid.wrapped_decorator import wrap_decorator
40

41 42
# Public API of this module.
__all__ = [
    'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level',
    'set_verbosity', 'save', 'load'
]
45 46 47 48 49 50 51 52 53 54 55 56


def create_program_from_desc(program_desc):
    """
    Wraps an existing ``program_desc`` in a fresh ``Program``.

    The new Program gets ``program_desc`` as its desc and a single root
    ``Block`` with index 0, then ``_sync_with_cpp()`` is called on it
    (presumably to rebuild the Python-side objects from the desc).

    Args:
        program_desc: the program description to wrap.

    Returns:
        Program: the newly built program.
    """
    new_program = Program()
    new_program.desc = program_desc
    root_block = Block(new_program, 0)
    new_program.blocks = [root_block]
    new_program._sync_with_cpp()
    return new_program


def _extract_vars(inputs, result_list):
    """
    Appends every Variable found in ``inputs`` to ``result_list``.

    ``inputs`` may be a single Variable or an arbitrarily nested list/tuple of
    Variables; nested containers are walked recursively, left to right.

    Raises:
        TypeError: when a leaf is neither a Variable nor a list/tuple. Any
            Variables met before the offending leaf stay in ``result_list``.
    """
    if isinstance(inputs, Variable):
        result_list.append(inputs)
        return

    if not isinstance(inputs, (list, tuple)):
        raise TypeError(
            "The type of 'each element of inputs' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
            format(type(inputs)))

    for element in inputs:
        _extract_vars(element, result_list)
65 66 67 68 69 70 71 72


def extract_vars(inputs):
    """
    Returns a flat list of the Variables contained in ``inputs``.

    ``inputs`` is a Variable or a nested list/tuple of Variables; see
    ``_extract_vars`` for the traversal and the ``TypeError`` it raises.
    """
    flat_vars = []
    _extract_vars(inputs, flat_vars)
    return flat_vars


73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
def _dygraph_to_static_func_(dygraph_func):
    """
    Converts imperative dygraph APIs into declarative function APIs. Decorator
    @dygraph_to_static_func only converts imperative dygraph APIs into
    declarative net-building APIs, which means it doesn't return immediate
    digital result as imperative mode. Users should handle Program and Executor
    by themselves.

    Note:
    This decorator is NOT our recommended way to transform imperative function
    to declarative function. We will remove this decorator after we finalize
    cleaning up code.

    Args:
        dygraph_func (callable): callable imperative function.

    Returns:
        Callable: converting imperative dygraph APIs into declarative
        net-building APIs.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import numpy as np
          from paddle.fluid.dygraph.jit import dygraph_to_static_func

          @dygraph_to_static_func
          def func(x):
              if fluid.layers.mean(x) < 0:
                  x_v = x - 1
              else:
                  x_v = x + 1

               return x_v

          x = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='float64')

          x_v = func(x)
          exe = fluid.Executor(fluid.CPUPlace())
          out = exe.run(fetch_list=[x_v])
          print(out[0])
          # [[1. 1. 1.]
          #  [1. 1. 1.]
          #  [1. 1. 1.]]

    """

    # TODO: remove this decorator after we finalize training API
122 123
    def __impl__(*args, **kwargs):
        program_translator = ProgramTranslator()
124
        if in_dygraph_mode() or not program_translator.enable_to_static:
125
            logging_utils.warn(
126
                "The decorator 'dygraph_to_static_func' doesn't work in "
127
                "dygraph mode or set ProgramTranslator.enable to False. "
128 129 130 131
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)
        static_func = program_translator.get_func(dygraph_func)
        return static_func(*args, **kwargs)
132 133 134 135

    return __impl__


136
# Public decorator (exported in `__all__`) built from
# `_dygraph_to_static_func_` via `wrap_decorator`.
dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)
137

138

139 140 141 142 143 144
def copy_decorator_attrs(original_func, decorated_obj):
    """
    Makes ``decorated_obj`` look like ``original_func`` by copying its
    identifying attributes onto it, in place.

    Copies ``__name__``, ``__doc__`` and, when present, ``__module__``; points
    ``__wrapped__`` at ``original_func``; sets ``_decorator_name`` to
    ``"declarative"``.

    Args:
        original_func(callable): the original decorated function.
        decorated_obj(StaticFunction): the target decorated StaticFunction object.

    Returns:
        The same ``decorated_obj`` that was passed in.
    """
    attrs_to_set = (
        ("__name__", original_func.__name__),
        ("_decorator_name", "declarative"),
        ("__wrapped__", original_func),
        ("__doc__", original_func.__doc__),
    )
    for attr_name, attr_value in attrs_to_set:
        setattr(decorated_obj, attr_name, attr_value)

    if hasattr(original_func, "__module__"):
        setattr(decorated_obj, "__module__", original_func.__module__)

    return decorated_obj


def declarative(function=None, input_spec=None):
    """
    Converts imperative dygraph APIs into declarative function APIs. Decorator
    @declarative handles the Program and Executor of static mode and returns
    the result as dygraph Tensor(s). Users could use the returned dygraph
    Tensor(s) to do imperative training, inference, or other operations. If the
    decorated function calls other imperative function, the called one will be
    converted into declarative function as well.

    Can be used as ``@declarative``, ``@declarative(input_spec=...)`` or as a
    plain call ``declarative(func_or_layer, input_spec=...)``.

    Args:
        function (callable|Layer, optional): callable imperative function, or a
            ``Layer`` whose ``forward`` method is replaced by its converted
            version. If None, a decorator is returned. Default None.
        input_spec(list[InputSpec], optional): list of InputSpec to specify the
            shape/dtype/name information of each input Tensor. Default None.

    Returns:
        A ``StaticFunction`` wrapping ``function`` when a callable is given;
        the same ``Layer`` object (with ``forward`` converted) when a Layer is
        given; otherwise a decorator that performs the conversion.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.jit import to_static

            @to_static
            def func(x):
                if paddle.mean(x) < 0:
                    x_v = x - 1
                else:
                    x_v = x + 1
                return x_v

            x = paddle.ones([1, 2], dtype='float32')
            x_v = func(x)
            print(x_v) # [[2. 2.]]

    """

    def decorated(python_func):
        """
        Decorates a python function into a StaticFunction object.
        """
        # Strip any existing decorators so we always wrap the raw function,
        # then mirror its name/doc/module onto the new StaticFunction.
        _, raw_func = unwrap_decorators(python_func)
        static_func = StaticFunction(function=raw_func, input_spec=input_spec)
        return copy_decorator_attrs(
            original_func=raw_func, decorated_obj=static_func)

    # usage: `@declarative` / `@declarative(input_spec=...)`
    if function is None:
        return decorated

    # usage: `declarative(foo, ...)` with a plain callable
    if not isinstance(function, Layer):
        return decorated(function)

    # usage: `declarative(layer, ...)` -- convert the layer's forward in place
    if isinstance(function.forward, StaticFunction):
        logging_utils.warn(
            "`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
            format(function.__class__.__name__))
    function.forward = decorated(function.forward)
    return function
226 227


228
class _SaveLoadConfig(object):
229 230 231 232 233
    def __init__(self):
        self._output_spec = None
        self._model_filename = None
        self._params_filename = None
        self._separate_params = False
234 235
        # used for `paddle.load`
        self._keep_name_table = False
236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253

        # NOTE: Users rarely use following configs, so these configs are not open to users,
        # reducing user learning costs, but we retain the configuration capabilities

        # If True, programs are modified to only support direct inference deployment. 
        # Otherwise,more information will be stored for flexible optimization and re-training. 
        # Currently, only True is supported
        self._export_for_deployment = True

        # If True, It will save inference program only, and do not save params of Program
        self._program_only = False

    @property
    def output_spec(self):
        return self._output_spec

    @output_spec.setter
    def output_spec(self, spec):
254 255
        if spec is None:
            return
256 257
        if not isinstance(spec, list):
            raise TypeError(
258
                "The config `output_spec` should be 'list', but received input type is %s."
259 260 261 262
                % type(input))
            for var in spec:
                if not isinstance(var, core.VarBase):
                    raise TypeError(
263
                        "The element in config `output_spec` list should be 'Variable', but received element's type is %s."
264 265 266 267 268 269 270 271 272
                        % type(var))
        self._output_spec = spec

    @property
    def model_filename(self):
        return self._model_filename

    @model_filename.setter
    def model_filename(self, filename):
273 274
        if filename is None:
            return
275 276
        if not isinstance(filename, six.string_types):
            raise TypeError(
277
                "The config `model_filename` should be str, but received input's type is %s."
278 279
                % type(filename))
        if len(filename) == 0:
280
            raise ValueError("The config `model_filename` is empty string.")
281 282 283 284 285 286 287 288
        self._model_filename = filename

    @property
    def params_filename(self):
        return self._params_filename

    @params_filename.setter
    def params_filename(self, filename):
289 290
        if filename is None:
            return
291 292
        if not isinstance(filename, six.string_types):
            raise TypeError(
293
                "The config `params_filename` should be str, but received input's type is %s."
294 295
                % type(filename))
        if len(filename) == 0:
296
            raise ValueError("The config `params_filename` is empty string.")
297 298
        self._params_filename = filename

299 300 301 302 303 304
    @property
    def keep_name_table(self):
        return self._keep_name_table

    @keep_name_table.setter
    def keep_name_table(self, value):
305 306
        if value is None:
            return
307 308
        if not isinstance(value, bool):
            raise TypeError(
309
                "The config `keep_name_table` should be bool value, but received input's type is %s."
310 311 312
                % type(value))
        self._keep_name_table = value

313

314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
def _parse_save_configs(configs):
    supported_configs = ['output_spec']

    # input check
    for key in configs:
        if key not in supported_configs:
            raise ValueError(
                "The additional config (%s) of `paddle.jit.save` is not supported."
                % (key))

    # construct inner config
    inner_config = _SaveLoadConfig()
    inner_config.output_spec = configs.get('output_spec', None)

    return inner_config


def _parse_load_config(configs):
    supported_configs = ['model_filename', 'params_filename']

    # input check
    for key in configs:
        if key not in supported_configs:
            raise ValueError(
                "The additional config (%s) of `paddle.jit.load` is not supported."
                % (key))

    # construct inner config
    inner_config = _SaveLoadConfig()
    inner_config.model_filename = configs.get('model_filename', None)
    inner_config.params_filename = configs.get('params_filename', None)

    return inner_config


349 350 351 352 353 354 355 356 357 358
def _get_input_var_names(inputs, input_spec):
    """
    Chooses the names of the feed variables for the saved inference program.

    Args:
        inputs: (nested) program inputs; only the ``Variable`` entries count.
        input_spec (list|None): specs whose ``.name`` selects inputs.

    Returns:
        list[str]: all input variable names when ``input_spec`` is None or has
        the same length as the inputs (mismatching spec names then only
        warn); otherwise the spec names, in spec order (pruned inputs).

    Raises:
        ValueError: when pruning and a spec has no name or names an input
            that does not exist.
    """
    name_none_error = ("The %s's name is None. "
                       "When using jit.save, please set InputSepc's name in "
                       "to_static(input_spec=[]) and jit.save(input_spec=[]) "
                       "and make sure they are consistent.")
    name_no_exists_error = (
        "The tensor `%s` does not exists. "
        "Please make sure the name of InputSpec or example Tensor "
        "in input_spec is the same as the name of InputSpec in "
        "`to_static` decorated on the Layer.forward method.")

    input_var_names = [
        var.name for var in flatten(inputs) if isinstance(var, Variable)
    ]

    # No spec given: keep every input, nothing is pruned.
    if input_spec is None:
        return input_var_names

    # Spec covers every input: still no pruning, only warn about bad names.
    if len(input_spec) == len(input_var_names):
        for spec in input_spec:
            if spec.name is None:
                warnings.warn(name_none_error % spec)
            elif spec.name not in input_var_names:
                warnings.warn(name_no_exists_error % spec.name)
        return input_var_names

    # Spec covers a subset: prune to exactly the named inputs.
    pruned_names = []
    for spec in input_spec:
        if spec.name is None:
            # name is None, the input_spec only can be InputSpec
            raise ValueError(name_none_error % spec)
        if spec.name not in input_var_names:
            # the input_spec can be `InputSpec` or `VarBase`
            raise ValueError(name_no_exists_error % spec.name)
        pruned_names.append(spec.name)
    return pruned_names


def _get_output_vars(outputs, output_spec):
    """
    Chooses the target (fetch) variables for the saved inference program.

    Args:
        outputs: (nested) program outputs; only ``Variable`` entries count,
            de-duplicated by name, keeping first-seen order.
        output_spec (list|None): tensors whose ``.name`` selects outputs.

    Returns:
        list: all output variables when ``output_spec`` is None or has the
        same length as the outputs (unknown names then only warn); otherwise
        the selected variables in ``output_spec`` order. Always a list --
        previously two branches returned a ``dict.values()`` view instead.

    Raises:
        ValueError: when pruning and a spec name is not an output variable.
    """
    name_no_exists_error = ("The tensor `%s` does not exists. "
                            "Please make sure the name of example Tensor "
                            "in configs.output_spec is the output tensor of "
                            "Layer.forward method.")

    output_vars_dict = OrderedDict()
    for var in flatten(outputs):
        if isinstance(var, Variable):
            output_vars_dict[var.name] = var
    all_output_vars = list(output_vars_dict.values())

    if output_spec is None:
        return all_output_vars

    if len(output_spec) == len(output_vars_dict):
        # no prune: keep every output, only warn about unknown names
        for var in output_spec:
            if var.name not in output_vars_dict:
                warnings.warn(name_no_exists_error % var.name)
        return all_output_vars

    # prune: keep exactly the requested outputs
    result_list = []
    for var in output_spec:
        if var.name not in output_vars_dict:
            raise ValueError(name_no_exists_error % var.name)
        result_list.append(output_vars_dict[var.name])
    return result_list


418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460
# NOTE(chenweihang): [ Handling of use cases of API paddle.jit.load ]
# `paddle.jit.load` may be used to load saved results of:
# 1. Expected cases:
#   - paddle.jit.save
#   - paddle.static.save_inference_model
#   - paddle.fluid.io.save_inference_model
# 2. Error cases:
#   - paddle.save: no .pdmodel for prefix
#   - paddle.static.save: no .pdiparams but .pdparams exists
#   - paddle.fluid.io.save_params/save_persistables: no __model__
# TODO(chenweihang): polish error message in above error cases
def _build_load_path_and_config(path, config):
    """
    Resolves ``path`` to the directory to load from and fills ``config``.

    Two layouts are supported:
      - prefix format (``path + INFER_MODEL_SUFFIX`` exists): returns the
        directory part of ``path`` and overrides ``config.model_filename`` /
        ``config.params_filename`` from the prefix, warning if the user had
        set them.
      - directory format (``path`` is a directory, e.g. the old
        ``save_inference_model`` output): returns ``path`` and leaves
        ``config`` untouched.

    Raises:
        ValueError: if both layouts exist (ambiguous) or neither does.
    """
    # NOTE(chenweihang): If both [prefix save format] and [directory save format] exist,
    # raise error, avoid confusing behavior
    prefix_format_path = path + INFER_MODEL_SUFFIX
    prefix_format_exist = os.path.exists(prefix_format_path)
    directory_format_exist = os.path.isdir(path)

    if prefix_format_exist and directory_format_exist:
        raise ValueError(
            "The %s.pdmodel and %s directory exist at the same time, "
            "don't know which one to load, please make sure that the specified target "
            "of ``path`` is unique." % (path, path))
    if not (prefix_format_exist or directory_format_exist):
        raise ValueError("The ``path`` (%s) to load model not exists." % path)

    if not prefix_format_exist:
        # Compatible with the old save_inference_model format
        return path, config

    file_prefix = os.path.basename(path)
    if config.model_filename is not None:
        warnings.warn(
            "When loading the result saved with the "
            "specified file prefix, the ``model_filename`` config does "
            "not take effect.")
    config.model_filename = file_prefix + INFER_MODEL_SUFFIX
    if config.params_filename is not None:
        warnings.warn(
            "When loading the result saved with the "
            "specified file prefix, the ``params_filename`` config does "
            "not take effect.")
    config.params_filename = file_prefix + INFER_PARAMS_SUFFIX
    return os.path.dirname(path), config
463 464


465
@switch_to_static_graph
def save(layer, path, input_spec=None, **configs):
    """
    Saves input Layer as ``paddle.jit.TranslatedLayer``
    format model, which can be used for inference or fine-tuning after loading.

    It will save the translated program and all related persistable
    variables of input Layer to given ``path`` .

    ``path`` is the prefix of the saved objects: the translated program is
    saved to a file with suffix ``.pdmodel``, the persistable variables to a
    file with suffix ``.pdiparams``, and some additional variable description
    information to a file with suffix ``.pdiparams.info``; that additional
    information is used for fine-tuning.

    If the Layer has static functions other than ``forward``, each one is
    saved as ``file_prefix.<function name>.pdmodel`` / ``.pdiparams``. If the
    Layer is wrapped by ``paddle.DataParallel``, the wrapped inner layer is
    saved.

    The saved model can be loaded by the following APIs:
      - ``paddle.jit.load``
      - ``paddle.static.load_inference_model``
      - Other C++ inference APIs

    Args:
        layer (Layer): The Layer to be saved.
        path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
        input_spec (list[InputSpec|Tensor], optional): Describes the input of the saved model's forward
            method, which can be described by InputSpec or example Tensor. If None, all input variables of
            the original Layer's forward method would be the inputs of the saved model. Default None.
        **configs (dict, optional): Other save configuration options for compatibility. We do not
            recommend using these configurations, they may be removed in the future. If not necessary,
            DO NOT use them. Default None.
            The following options are currently supported:
            (1) output_spec (list[Tensor]): Selects the output targets of the saved model.
            By default, all return variables of original Layer's forward method are kept as the
            output of the saved model. If the provided ``output_spec`` list is not all output variables,
            the saved model will be pruned according to the given ``output_spec`` list.

    Returns:
        None

    Raises:
        RuntimeError: if ProgramTranslator is disabled.
        TypeError: if ``layer`` is not a Layer, ``input_spec`` is not a list,
            or an ``input_spec`` element is neither InputSpec nor Tensor.
        ValueError: if the file prefix of ``path`` is empty, if
            ``input_spec`` is given while the Layer has static functions other
            than ``forward``, or if an unsupported config is passed.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            # 1. train & save model.

            # create network
            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # train
            train(layer, loader, loss_fn, adam)

            # save
            path = "example_model/linear"
            paddle.jit.save(layer, path)
    """

    # 1. input build & check
    prog_translator = ProgramTranslator()
    if not prog_translator.enable_to_static:
        raise RuntimeError(
            "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
        )
    if not isinstance(layer, Layer):
        raise TypeError(
            "The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
            % type(layer))

    # NOTE(chenweihang): If the input layer be wrapped by DataParallel,
    # the args and kwargs of forward method will can't be parsed by
    # function_spec, so here we save DataParallel._layers instead
    # DataParallel it self
    # NOTE(chenweihang): using inner_layer, do not change input layer
    if isinstance(layer, paddle.DataParallel):
        inner_layer = layer._layers
    else:
        inner_layer = layer

    # path check
    file_prefix = os.path.basename(path)
    if file_prefix == "":
        raise ValueError(
            "The input path MUST be format of dirname/file_prefix "
            "[dirname\\file_prefix in Windows system], but received "
            "file_prefix is empty string.")

    dirname = os.path.dirname(path)
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)

    # avoid change user given input_spec
    inner_input_spec = None
    if input_spec is not None:
        for attr_func in dir(inner_layer):
            static_func = getattr(inner_layer, attr_func, None)
            if isinstance(static_func,
                          StaticFunction) and 'forward' != attr_func:
                raise ValueError(
                    "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
                    % type(input_spec))

        if not isinstance(input_spec, list):
            raise TypeError(
                "The input input_spec should be 'list', but received input_spec's type is %s."
                % type(input_spec))
        inner_input_spec = []
        for var in flatten(input_spec):
            if isinstance(var, paddle.static.InputSpec):
                inner_input_spec.append(var)
            elif isinstance(var, (core.VarBase, Variable)):
                inner_input_spec.append(
                    paddle.static.InputSpec.from_tensor(var))
            else:
                raise TypeError(
                    "The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
                    % type(var))

    # 2. parse configs
    configs = _parse_save_configs(configs)

    # Loop-invariant: the saving helper and the target directory are the same
    # for every static function saved below.
    from paddle.fluid.io import save_inference_model
    model_path = dirname

    scope = core.Scope()
    extra_var_info = dict()
    for attr_func in dir(inner_layer):
        static_func = getattr(inner_layer, attr_func, None)
        if isinstance(static_func, StaticFunction):
            concrete_program = static_func.concrete_program
        elif 'forward' == attr_func:
            # transform in jit.save, if input_spec is incomplete, declarative will throw error
            static_forward = declarative(
                inner_layer.forward, input_spec=inner_input_spec)
            concrete_program = static_forward.concrete_program
            # the input_spec has been used in declarative, which is equal to
            # @declarative with input_spec and jit.save without input_spec,
            # avoid needless warning
            inner_input_spec = None
        else:
            continue

        # 3. build input & output of save_inference_model
        # NOTE(chenweihang): [ Get input variables name ]
        # There are two cases, whether to prune the inputs or not
        # - not prune inputs (recommend):
        #   - the len(input_spec) == len((concrete_program.inputs) - 1
        #   - here can use concrete_program.inputs directly
        # - prune inputs:
        #   - the input_spec length < len((concrete_program.inputs) - 1
        #   - the input_spec's name should be in concrete_program.inputs
        input_var_names = _get_input_var_names(concrete_program.inputs,
                                               inner_input_spec)

        # NOTE(chenweihang): [ Get output variables ]
        # the rule is like [ Get input variables name ]. For output var,
        # we only support VarBase spec, and actually, we only need the
        # var name of output, and we don't recommended to use output_spec
        output_vars = _get_output_vars(concrete_program.outputs,
                                       configs.output_spec)

        # NOTE(chenweihang): we maintain the mapping of variable name to
        # structured name, the buffer variable (non-persistable)
        # saved to inference program may not need by dygraph Layer,
        # we only record the state_dict variable's structured name
        state_names_dict = dict()
        for structured_name, var in six.iteritems(inner_layer.state_dict()):
            state_names_dict[var.name] = structured_name

        # 4. share parameters from Layer to scope & record var info
        for param_or_buffer in concrete_program.parameters:
            # share to scope
            param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor(
            )
            src_tensor = param_or_buffer.value().get_tensor()
            param_or_buffer_tensor._share_data_with(src_tensor)
            # record var info (first occurrence wins when several static
            # functions share the same parameter)
            if param_or_buffer.name not in extra_var_info:
                extra_info_dict = dict()
                if param_or_buffer.name in state_names_dict:
                    extra_info_dict['structured_name'] = state_names_dict[
                        param_or_buffer.name]
                extra_info_dict['stop_gradient'] = param_or_buffer.stop_gradient
                if isinstance(param_or_buffer, ParamBase):
                    extra_info_dict['trainable'] = param_or_buffer.trainable
                extra_var_info[param_or_buffer.name] = extra_info_dict

        # 5. save inference model
        # NOTE(chenweihang): because prefix contains model and params filename,
        # so we don't support set model_filename & params_filename
        if 'forward' == attr_func:
            model_filename = file_prefix + INFER_MODEL_SUFFIX
            params_filename = file_prefix + INFER_PARAMS_SUFFIX
        else:
            model_filename = file_prefix + '.' + attr_func + INFER_MODEL_SUFFIX
            params_filename = file_prefix + '.' + attr_func + INFER_PARAMS_SUFFIX

        with scope_guard(scope):
            save_inference_model(
                dirname=model_path,
                feeded_var_names=input_var_names,
                target_vars=output_vars,
                executor=Executor(_current_expected_place()),
                main_program=concrete_program.main_program.clone(),
                model_filename=model_filename,
                params_filename=params_filename,
                export_for_deployment=configs._export_for_deployment,
                program_only=configs._program_only)

    # NOTE(chenweihang): [ Save extra variable info ]
    # save_inference_model will lose some important variable information, including:
    #   - Variable name and correspondence (when saved variables as one file)
    #   - Variable.stop_gradient information
    #   - Which persistent variable are parameter and which are not
    #   - Parameter.trainable information
    #
    # The lost information cannot be recovered when it is loaded again,
    # so if we want to perform fine-tune after loading, we may need to
    # configure redundant information to proceed.
    #
    # Due to compatibility issues, we cannot change the original storage structure,
    # but we can save these information in `jit.save` without changing the original
    # storage to improve user experience. So we save extra information into
    # file `***.pdiparams.info`
    with scope_guard(scope):
        extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
        with open(extra_var_info_path, 'wb') as f:
            # protocol 2 keeps the file readable from Python 2 as well
            pickle.dump(extra_var_info, f, protocol=2)


@dygraph_only
def load(path, **configs):
    """
    :api_attr: imperative

    Load a model saved by ``paddle.jit.save``, ``paddle.static.save_inference_model`` or the
    paddle 1.x API ``paddle.fluid.io.save_inference_model`` as a ``paddle.jit.TranslatedLayer``,
    which can then be used for inference or fine-tune training.

    .. note::
        If you load a model saved by ``paddle.static.save_inference_model`` ,
        there are the following limitations when using it in fine-tuning:

        1. Imperative mode does not support LoDTensor. All of the original model's feed targets or parameters that depend on LoD are temporarily unavailable.
        2. All of the saved model's feed targets need to be passed into TranslatedLayer's forward function.
        3. The variables' ``stop_gradient`` information is lost and cannot be recovered.
        4. The parameters' ``trainable`` information is lost and cannot be recovered.

    Args:
        path (str): The path prefix to load model. The format is ``dirname/file_prefix`` or ``file_prefix`` .
        **configs (dict, optional): Other load configuration options for compatibility. We do not
            recommend using these configurations, they may be removed in the future. If not necessary,
            DO NOT use them. Default None.
            The following options are currently supported:
            (1) model_filename (str): The inference model file name of the paddle 1.x
            ``save_inference_model`` save format. Default file name is :code:`__model__` .
            (2) params_filename (str): The persistable variables file name of the paddle 1.x
            ``save_inference_model`` save format. No default file name, save variables separately
            by default.

    Returns:
        TranslatedLayer: A Layer object that can run the saved translated model.

    Examples:
        1. Load a model saved by ``paddle.jit.save``, then perform inference and fine-tune training.

        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, optimizer):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        optimizer.step()
                        optimizer.clear_grad()
                        print("Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            # 1. train & save model.

            # create network
            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # train
            train(layer, loader, loss_fn, adam)

            # save
            path = "example_model/linear"
            paddle.jit.save(layer, path)

            # 2. load model

            # load
            loaded_layer = paddle.jit.load(path)

            # inference
            loaded_layer.eval()
            x = paddle.randn([1, IMAGE_SIZE], 'float32')
            pred = loaded_layer(x)

            # fine-tune
            loaded_layer.train()
            adam = opt.Adam(learning_rate=0.001, parameters=loaded_layer.parameters())
            train(loaded_layer, loader, loss_fn, adam)


        2. Load a model saved by ``paddle.fluid.io.save_inference_model``, then perform inference and fine-tune training.

        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.static as static
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.nn.functional as F

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            paddle.enable_static()

            image = static.data(name='image', shape=[None, 784], dtype='float32')
            label = static.data(name='label', shape=[None, 1], dtype='int64')
            pred = static.nn.fc(x=image, size=10, activation='softmax')
            loss = F.cross_entropy(input=pred, label=label)
            avg_loss = paddle.mean(loss)

            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer.minimize(avg_loss)

            place = paddle.CPUPlace()
            exe = static.Executor(place)
            exe.run(static.default_startup_program())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                feed_list=[image, label],
                places=place,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # 1. train and save inference model
            for data in loader():
                exe.run(
                    static.default_main_program(),
                    feed=data,
                    fetch_list=[avg_loss])

            model_path = "fc.example.model"
            paddle.fluid.io.save_inference_model(
                model_path, ["image"], [pred], exe)

            # 2. load model

            # enable dygraph mode
            paddle.disable_static(place)

            # load
            fc = paddle.jit.load(model_path)

            # inference
            fc.eval()
            x = paddle.randn([1, IMAGE_SIZE], 'float32')
            pred = fc(x)

            # fine-tune
            fc.train()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=fc.parameters())
            loader = paddle.io.DataLoader(dataset,
                places=place,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)
            for epoch_id in range(EPOCH_NUM):
                for batch_id, (image, label) in enumerate(loader()):
                    out = fc(image)
                    loss = loss_fn(out, label)
                    loss.backward()
                    adam.step()
                    adam.clear_grad()
                    print("Epoch {} batch {}: loss = {}".format(
                        epoch_id, batch_id, np.mean(loss.numpy())))
    """
    # Normalize the user-supplied **configs into a load config, then resolve
    # the on-disk model path (which may also adjust the config).
    config = _parse_load_config(configs)
    model_path, config = _build_load_path_and_config(path, config)

    return TranslatedLayer._construct(model_path, config)


@dygraph_only
def _trace(layer,
           inputs,
           feed_prefix='feed_',
           fetch_prefix='fetch_',
           tmp_prefix='t_'):
    """
    Run ``layer(*inputs)`` with program-desc tracing enabled and build a
    static ``Program`` from the traced program description.

    Args:
        layer (Layer): the dygraph layer to trace.
        inputs (Tensor|list|tuple): the input(s) of ``layer``. A single
            non-list/tuple input is wrapped into a one-element list.
        feed_prefix (str): name prefix handed to the tracer for feed variables.
        fetch_prefix (str): name prefix handed to the tracer for fetch variables.
        tmp_prefix (str): name prefix handed to the tracer for temporary variables.

    Returns:
        tuple: ``(original_outputs, program, feed_names, fetch_names,
        parameters)``. ``original_outputs`` is exactly what
        ``layer(*inputs)`` returned; the remaining items come from the tracer
        and the program created from its description.
    """
    assert isinstance(layer, Layer)

    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]

    tracer = _dygraph_tracer()._get_program_desc_tracer()

    var_list = extract_vars(inputs)

    with program_desc_tracing_guard(True):
        try:
            original_outputs = layer(*inputs)
            # The tracer expects a list of output variables, so normalize a
            # single output into a one-element list.
            if isinstance(original_outputs, (list, tuple)):
                out_vars = list(original_outputs)
            else:
                out_vars = [original_outputs]

            program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
                var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
        finally:
            # Reset on the error path too: previously an exception from
            # ``layer(*inputs)`` or ``create_program_desc`` skipped this call.
            # NOTE(review): assumes the program-desc tracer obtained from
            # ``_dygraph_tracer()`` is reused across calls - TODO confirm.
            tracer.reset()

    with _dygraph_guard(None):
        program = create_program_from_desc(program_desc)

    return original_outputs, program, feed_names, fetch_names, parameters


class TracedLayer(object):
    """
    :api_attr: imperative

    TracedLayer is used to convert a forward dygraph model to a static
    graph model. This is mainly used to save the dygraph model for online
    inference using C++. Besides, users can also do inference in Python
    using the converted static graph model, which usually has better
    performance than the original dygraph model.

    TracedLayer would run the static graph model using :code:`Executor`
    and :code:`CompiledProgram` . The static graph model would share
    parameters with the dygraph model.

    TracedLayer objects should not be created by the constructor; they
    should be created by the static method :code:`TracedLayer.trace(layer, inputs)` .

    The TracedLayer can only be used to convert a data-independent dygraph
    model into a static graph model, which means the dygraph model should
    be independent of the tensor data and shape.
    """

    def __init__(self, program, parameters, feed_names, fetch_names):
        self._program = program
        self._feed_names = feed_names
        self._fetch_names = fetch_names
        self._params = parameters

        self._place = _current_expected_place()

        # Expose every dygraph parameter in a private scope under its own
        # name, sharing the underlying tensor data (``_share_data_with``)
        # rather than copying it.
        self._scope = core.Scope()
        for p in parameters:
            src_tensor = p.value().get_tensor()
            dst_tensor = self._scope.var(p.name).get_tensor()
            dst_tensor._share_data_with(src_tensor)

        self._exe = Executor(self._place)
        # Compiled lazily on the first __call__ so that set_strategy can
        # still be applied before running.
        self._compiled_program = None
        self._build_strategy = None
        self._exec_strategy = None

    @property
    def program(self):
        """The static ``Program`` created by tracing."""
        return self._program

    def _switch(self, is_test=True):
        """Set the ``is_test`` attribute of every op that has one, in all blocks."""
        for block_id in range(self._program.num_blocks):
            block = self._program.block(block_id)
            for op in block.ops:
                if op.has_attr("is_test"):
                    op._set_attr("is_test", is_test)

    @staticmethod
    @dygraph_only
    def trace(layer, inputs):
        """
        This method is the only allowed method to create a TracedLayer object.
        It calls the :code:`layer(*inputs)` method to run the dygraph
        model and converts it into a static graph model.

        Args:
            layer (paddle.nn.Layer): the layer object to be traced.
            inputs (list(Tensor)|tuple(Tensor)|Tensor): the input tensors of
                the layer object.

        Returns:
            tuple: A tuple of 2 items, whose first item is the output of
                :code:`layer(*inputs)` , and the second item is the created
                TracedLayer object.

        Examples:
            .. code-block:: python

                import paddle

                class ExampleLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(ExampleLayer, self).__init__()
                        self._fc = paddle.nn.Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)


                layer = ExampleLayer()
                in_var = paddle.uniform(shape=[2, 3], dtype='float32')
                out_dygraph, static_layer = paddle.jit.TracedLayer.trace(layer, inputs=[in_var])

                # run the static graph model using Executor inside
                out_static_graph = static_layer([in_var])

                print(len(out_static_graph)) # 1
                print(out_static_graph[0].shape) # (2, 10)

                # save the static graph model for inference
                static_layer.save_inference_model(dirname='./saved_infer_model')

        """
        assert isinstance(
            layer, Layer
        ), "The type of 'layer' in fluid.dygraph.jit.TracedLayer.trace must be fluid.dygraph.Layer, but received {}.".format(
            type(layer))
        outs, prog, feed, fetch, parameters = _trace(layer, inputs)
        traced = TracedLayer(prog, parameters, feed, fetch)
        return outs, traced

    def set_strategy(self, build_strategy=None, exec_strategy=None):
        """
        Set the strategies when running static graph model.

        Must be called before the first run; the strategies are consumed
        when the program is compiled.

        Args:
            build_strategy (BuildStrategy, optional): build strategy of
                :code:`CompiledProgram` inside TracedLayer. Default None.
            exec_strategy (ExecutionStrategy, optional): execution strategy of
                :code:`CompiledProgram` inside TracedLayer. Default None.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                class ExampleLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(ExampleLayer, self).__init__()
                        self._fc = paddle.nn.Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)

                layer = ExampleLayer()
                in_var = paddle.uniform(shape=[2, 3], dtype='float32')

                out_dygraph, static_layer = paddle.jit.TracedLayer.trace(layer, inputs=[in_var])

                build_strategy = paddle.static.BuildStrategy()
                build_strategy.enable_inplace = True

                exec_strategy = paddle.static.ExecutionStrategy()
                exec_strategy.num_threads = 2

                static_layer.set_strategy(build_strategy=build_strategy, exec_strategy=exec_strategy)
                out_static_graph = static_layer([in_var])

        """
        assert self._compiled_program is None, "Cannot set strategy after run"
        assert isinstance(
            build_strategy, (type(None), BuildStrategy)
        ), "The type of 'build_strategy' in fluid.dygraph.jit.TracedLayer.set_strategy must be fluid.BuildStrategy, but received {}.".format(
            type(build_strategy))
        assert isinstance(
            exec_strategy, (type(None), ExecutionStrategy)
        ), "The type of 'exec_strategy' in fluid.dygraph.jit.TracedLayer.set_strategy must be fluid.ExecutionStrategy, but received {}.".format(
            type(exec_strategy))
        self._build_strategy = build_strategy
        self._exec_strategy = exec_strategy

    @switch_to_static_graph
    def _compile(self):
        """Wrap the traced program in a data-parallel ``CompiledProgram``
        using the configured strategies and place."""
        self._compiled_program = CompiledProgram(
            self._program).with_data_parallel(
                build_strategy=self._build_strategy,
                exec_strategy=self._exec_strategy,
                places=self._place)

    def _build_feed(self, inputs):
        """Map each feed variable name to the corresponding input.

        In dygraph mode the underlying tensor of each input is fed; otherwise
        the inputs are fed as-is. Inputs are matched to feed names by position.
        """
        assert isinstance(inputs, (list, tuple)), \
            "Inputs should be a list or tuple of variables"
        assert len(inputs) == len(self._feed_names), \
            "The number of inputs ({}) does not match the number of feed " \
            "variables ({})".format(len(inputs), len(self._feed_names))
        if in_dygraph_mode():
            return {
                name: x.value().get_tensor()
                for x, name in zip(inputs, self._feed_names)
            }
        return dict(zip(self._feed_names, inputs))

    @switch_to_static_graph
    def _run(self, feed):
        """Run the compiled program and fetch all traced outputs."""
        return self._exe.run(self._compiled_program,
                             feed=feed,
                             fetch_list=self._fetch_names)

    def __call__(self, inputs):
        """Run the static graph model on ``inputs`` (a list or tuple), in this
        layer's private scope, compiling the program on first use.

        Returns the values fetched for all traced outputs.
        """
        with scope_guard(self._scope):
            if self._compiled_program is None:
                self._compile()

            return self._run(self._build_feed(inputs))

    @switch_to_static_graph
    def save_inference_model(self, dirname, feed=None, fetch=None):
        """
        Save the TracedLayer to a model for inference. The saved
        inference model can be loaded by C++ inference APIs.

        Args:
            dirname (str): the directory to save the inference model.
            feed (list[int], optional): the input variable indices of the saved
                inference model. If None, all input variables of the
                TracedLayer object would be the inputs of the saved inference
                model. Default None.
            fetch (list[int], optional): the output variable indices of the
                saved inference model. If None, all output variables of the
                TracedLayer object would be the outputs of the saved inference
                model. Default None.

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                class ExampleLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(ExampleLayer, self).__init__()
                        self._fc = paddle.nn.Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)

                save_dirname = './saved_infer_model'
                in_np = np.random.random([2, 3]).astype('float32')
                in_var = paddle.to_tensor(in_np)
                layer = ExampleLayer()

                out_dygraph, static_layer = paddle.jit.TracedLayer.trace(layer, inputs=[in_var])
                static_layer.save_inference_model(save_dirname, feed=[0], fetch=[0])

                paddle.enable_static()
                place = paddle.CPUPlace()
                exe = paddle.static.Executor(place)
                program, feed_vars, fetch_vars = paddle.static.load_inference_model(save_dirname,
                                                    exe)

                fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
                print(fetch.shape) # (2, 10)
        """
        check_type(dirname, "dirname", str,
                   "fluid.dygraph.jit.TracedLayer.save_inference_model")
        check_type(feed, "feed", (type(None), list),
                   "fluid.dygraph.jit.TracedLayer.save_inference_model")
        if isinstance(feed, list):
            for f in feed:
                check_type(f, "each element of feed", int,
                           "fluid.dygraph.jit.TracedLayer.save_inference_model")
        check_type(fetch, "fetch", (type(None), list),
                   "fluid.dygraph.jit.TracedLayer.save_inference_model")
        if isinstance(fetch, list):
            for f in fetch:
                check_type(f, "each element of fetch", int,
                           "fluid.dygraph.jit.TracedLayer.save_inference_model")

        from paddle.fluid.io import save_inference_model

        def get_feed_fetch(all_vars, partial_vars):
            # None selects everything; otherwise pick the given indices.
            if partial_vars is None:
                return all_vars

            return [all_vars[idx] for idx in partial_vars]

        with scope_guard(self._scope):
            feeded_var_names = get_feed_fetch(self._feed_names, feed)
            target_var_names = get_feed_fetch(self._fetch_names, fetch)
            target_vars = []
            for name in target_var_names:
                target_var = self._program.global_block().vars.get(name, None)
                assert target_var is not None, "{} cannot be found".format(name)
                target_vars.append(target_var)

            save_inference_model(
                dirname=dirname,
                feeded_var_names=feeded_var_names,
                target_vars=target_vars,
                executor=self._exe,
                main_program=self._program.clone())