jit.py 47.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

17 18
import os
import pickle
19
import warnings
20
import functools
21
from collections import OrderedDict
22 23

import six
24
import paddle
25
from paddle.fluid import core
26 27
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
28
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
29
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
30
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
31
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticFunction, unwrap_decorators
32
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
33 34
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
35 36 37
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer
from paddle.fluid.framework import dygraph_only, in_dygraph_mode
38
from paddle.fluid.wrapped_decorator import wrap_decorator
39

40 41
# Public API exported by this module.
__all__ = [
    'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level',
    'set_verbosity', 'save', 'load'
]
44 45 46 47 48 49 50 51 52 53 54 55


def create_program_from_desc(program_desc):
    """
    Wraps an existing program description in a fresh ``Program``.

    Args:
        program_desc: the program description to adopt (presumably a
            ``core.ProgramDesc`` -- TODO confirm). It is assigned to
            ``desc`` directly, not copied.

    Returns:
        Program: a new Program holding ``program_desc`` with a single root
        block, after syncing its Python-side state via ``_sync_with_cpp``.
    """
    new_program = Program()
    new_program.desc = program_desc
    root_block = Block(new_program, 0)
    new_program.blocks = [root_block]
    new_program._sync_with_cpp()
    return new_program


def _extract_vars(inputs, result_list):
    """
    Recursively appends every Variable found in ``inputs`` to ``result_list``.

    Args:
        inputs (Variable|list|tuple): a Variable, or a (possibly nested)
            list/tuple of Variables.
        result_list (list): accumulator that is mutated in place.

    Raises:
        TypeError: if any leaf element is neither a Variable nor a list/tuple.
    """
    if isinstance(inputs, Variable):
        result_list.append(inputs)
    elif isinstance(inputs, (list, tuple)):
        # Depth-first walk keeps the original ordering of the leaves.
        for var in inputs:
            _extract_vars(var, result_list)
    else:
        raise TypeError(
            "The type of 'each element of inputs' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
            format(type(inputs)))
64 65 66 67 68 69 70 71


def extract_vars(inputs):
    """
    Flattens ``inputs`` into a new list of Variables.

    Args:
        inputs (Variable|list|tuple): a Variable or a nested list/tuple of
            Variables.

    Returns:
        list[Variable]: every Variable found, in depth-first order.

    Raises:
        TypeError: raised by ``_extract_vars`` for unsupported elements.
    """
    collected = []
    _extract_vars(inputs, collected)
    return collected


72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
def _dygraph_to_static_func_(dygraph_func):
    """
    Converts imperative dygraph APIs into declarative function APIs. Decorator
    @dygraph_to_static_func only converts imperative dygraph APIs into
    declarative net-building APIs, which means it doesn't return an immediate
    numerical result as imperative mode does. Users should handle Program and
    Executor by themselves.

    Note:
    This decorator is NOT our recommended way to transform imperative function
    to declarative function. We will remove this decorator after we finalize
    cleaning up code.

    Args:
        dygraph_func (callable): callable imperative function.

    Returns:
        Callable: converting imperative dygraph APIs into declarative
        net-building APIs.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import numpy as np
          from paddle.fluid.dygraph.jit import dygraph_to_static_func

          @dygraph_to_static_func
          def func(x):
              if fluid.layers.mean(x) < 0:
                  x_v = x - 1
              else:
                  x_v = x + 1

              return x_v

          x = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='float64')

          x_v = func(x)
          exe = fluid.Executor(fluid.CPUPlace())
          out = exe.run(fetch_list=[x_v])
          print(out[0])
          # [[1. 1. 1.]
          #  [1. 1. 1.]
          #  [1. 1. 1.]]

    """

    # TODO: remove this decorator after we finalize training API
    def __impl__(*args, **kwargs):
        program_translator = ProgramTranslator()
        # Conversion is skipped in dygraph mode or when to-static is disabled;
        # the original function then runs unchanged.
        if in_dygraph_mode() or not program_translator.enable_to_static:
            logging_utils.warn(
                "The decorator 'dygraph_to_static_func' doesn't work in "
                "dygraph mode or set ProgramTranslator.enable to False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)
        static_func = program_translator.get_func(dygraph_func)
        return static_func(*args, **kwargs)

    return __impl__


dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)
136

137

138 139 140 141 142 143
def copy_decorator_attrs(original_func, decorated_obj):
    """
    Transfers identifying metadata from ``original_func`` onto ``decorated_obj``.

    ``__name__`` and ``__doc__`` are always copied, and ``__module__`` is
    copied when the original function has one. In addition, ``__wrapped__``
    is pointed at the original function and ``_decorator_name`` is set to
    ``"declarative"``.

    Args:
        original_func(callable): the function that was decorated.
        decorated_obj(StaticFunction): the decorator's result; it is mutated
            in place.

    Returns:
        The same ``decorated_obj`` instance.
    """
    mirrored_attrs = ["__name__", "__doc__"]
    if hasattr(original_func, "__module__"):
        mirrored_attrs.append("__module__")
    for attr_name in mirrored_attrs:
        setattr(decorated_obj, attr_name, getattr(original_func, attr_name))

    decorated_obj.__wrapped__ = original_func
    decorated_obj._decorator_name = "declarative"
    return decorated_obj


def declarative(function=None, input_spec=None):
    """
    Converts imperative dygraph APIs into declarative function APIs. Decorator
    @declarative handles the Program and Executor of static mode and returns
    the result as dygraph Tensor(s). Users could use the returned dygraph
    Tensor(s) to do imperative training, inference, or other operations. If the
    decorated function calls other imperative function, the called one will be
    converted into declarative function as well.

    Args:
        function (callable|Layer, optional): callable imperative function, or a
            Layer whose ``forward`` method should be converted. If None, a
            decorator is returned instead. Default None.
        input_spec(list[InputSpec], optional): list of InputSpec to specify the
            shape/dtype/name information of each input Tensor. Default None.

    Returns:
        StaticFunction|Layer|callable: a StaticFunction wrapping ``function``;
        if ``function`` is a Layer, the same Layer with its ``forward``
        replaced by a StaticFunction; if ``function`` is None, a decorator
        performing that conversion. Calling the resulting StaticFunction
        returns Tensor(s) containing the numerical result.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import numpy as np
          from paddle.fluid.dygraph.jit import declarative

          fluid.enable_dygraph()

          @declarative
          def func(x):
              x = fluid.dygraph.to_variable(x)
              if fluid.layers.mean(x) < 0:
                  x_v = x - 1
              else:
                  x_v = x + 1
              return x_v

          x = np.ones([1, 2])
          x_v = func(x)
          print(x_v.numpy()) # [[2. 2.]]

    """

    def decorated(python_func):
        """
        Decorates a python function into a StaticFunction object.
        """
        # Step 1. unwrap the function if it is already decorated.
        _, python_func = unwrap_decorators(python_func)

        # Step 2. copy some attributes from original python function.
        static_layer = copy_decorator_attrs(
            original_func=python_func,
            decorated_obj=StaticFunction(
                function=python_func, input_spec=input_spec))

        return static_layer

    # for usage: `declarative(foo, ...)`
    if function is not None:
        if isinstance(function, Layer):
            if isinstance(function.forward, StaticFunction):
                class_name = function.__class__.__name__
                logging_utils.warn(
                    "`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
                    format(class_name))
            # Convert the Layer in place by swapping its forward method.
            function.forward = decorated(function.forward)
            return function
        else:
            return decorated(function)

    # for usage: `@declarative`
    return decorated
229 230


C
Chen Weihang 已提交
231
class _SaveLoadConfig(object):
232 233 234 235 236
    def __init__(self):
        self._output_spec = None
        self._model_filename = None
        self._params_filename = None
        self._separate_params = False
237 238
        # used for `paddle.load`
        self._keep_name_table = False
239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256

        # NOTE: Users rarely use following configs, so these configs are not open to users,
        # reducing user learning costs, but we retain the configuration capabilities

        # If True, programs are modified to only support direct inference deployment. 
        # Otherwise,more information will be stored for flexible optimization and re-training. 
        # Currently, only True is supported
        self._export_for_deployment = True

        # If True, It will save inference program only, and do not save params of Program
        self._program_only = False

    @property
    def output_spec(self):
        return self._output_spec

    @output_spec.setter
    def output_spec(self, spec):
C
Chen Weihang 已提交
257 258
        if spec is None:
            return
259 260
        if not isinstance(spec, list):
            raise TypeError(
C
Chen Weihang 已提交
261
                "The config `output_spec` should be 'list', but received input type is %s."
262 263 264 265
                % type(input))
            for var in spec:
                if not isinstance(var, core.VarBase):
                    raise TypeError(
C
Chen Weihang 已提交
266
                        "The element in config `output_spec` list should be 'Variable', but received element's type is %s."
267 268 269 270 271 272 273 274 275
                        % type(var))
        self._output_spec = spec

    @property
    def model_filename(self):
        return self._model_filename

    @model_filename.setter
    def model_filename(self, filename):
C
Chen Weihang 已提交
276 277
        if filename is None:
            return
278 279
        if not isinstance(filename, six.string_types):
            raise TypeError(
C
Chen Weihang 已提交
280
                "The config `model_filename` should be str, but received input's type is %s."
281 282
                % type(filename))
        if len(filename) == 0:
C
Chen Weihang 已提交
283
            raise ValueError("The config `model_filename` is empty string.")
284 285 286 287 288 289 290 291
        self._model_filename = filename

    @property
    def params_filename(self):
        return self._params_filename

    @params_filename.setter
    def params_filename(self, filename):
C
Chen Weihang 已提交
292 293
        if filename is None:
            return
294 295
        if not isinstance(filename, six.string_types):
            raise TypeError(
C
Chen Weihang 已提交
296
                "The config `params_filename` should be str, but received input's type is %s."
297 298
                % type(filename))
        if len(filename) == 0:
C
Chen Weihang 已提交
299
            raise ValueError("The config `params_filename` is empty string.")
300 301 302 303 304 305 306 307 308 309 310 311 312 313
        self._params_filename = filename

    # NOTE: [why not use params_filename=None control params saved separately]
    # The new save interface does not recommend parameters to be saved separately. 
    # Here, the concept should be separated as clearly as possible. 
    # Setting params_filename=None only means that the saved file name is set 
    # and without any other meaning. New separate_params control for file saved
    # separately can makes the concept clearer.
    @property
    def separate_params(self):
        return self._separate_params

    @separate_params.setter
    def separate_params(self, value):
C
Chen Weihang 已提交
314 315
        if value is None:
            return None
316 317
        if not isinstance(value, bool):
            raise TypeError(
C
Chen Weihang 已提交
318
                "The config `separate_params` should be bool value, but received input's type is %s."
319 320 321
                % type(value))
        self._separate_params = value

322 323 324 325 326 327
    @property
    def keep_name_table(self):
        return self._keep_name_table

    @keep_name_table.setter
    def keep_name_table(self, value):
C
Chen Weihang 已提交
328 329
        if value is None:
            return
330 331
        if not isinstance(value, bool):
            raise TypeError(
C
Chen Weihang 已提交
332
                "The config `keep_name_table` should be bool value, but received input's type is %s."
333 334 335
                % type(value))
        self._keep_name_table = value

336

C
Chen Weihang 已提交
337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377
def _parse_save_configs(configs):
    supported_configs = [
        'output_spec', 'model_filename', 'params_filename', 'separate_params'
    ]

    # input check
    for key in configs:
        if key not in supported_configs:
            raise ValueError(
                "The additional config (%s) of `paddle.jit.save` is not supported."
                % (key))

    # construct inner config
    inner_config = _SaveLoadConfig()
    inner_config.output_spec = configs.get('output_spec', None)
    inner_config.model_filename = configs.get('model_filename', None)
    inner_config.params_filename = configs.get('params_filename', None)
    inner_config.separate_params = configs.get('separate_params', None)

    return inner_config


def _parse_load_config(configs):
    supported_configs = ['model_filename', 'params_filename', 'separate_params']

    # input check
    for key in configs:
        if key not in supported_configs:
            raise ValueError(
                "The additional config (%s) of `paddle.jit.load` is not supported."
                % (key))

    # construct inner config
    inner_config = _SaveLoadConfig()
    inner_config.model_filename = configs.get('model_filename', None)
    inner_config.params_filename = configs.get('params_filename', None)
    inner_config.separate_params = configs.get('separate_params', None)

    return inner_config


378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444
def _get_input_var_names(inputs, input_spec):
    name_none_error = "The %s's name is None. " \
        "When using jit.save, please set InputSepc's name in " \
        "to_static(input_spec=[]) and jit.save(input_spec=[]) " \
        "and make sure they are consistent."
    name_no_exists_error = "The tensor `%s` does not exists. " \
        "Please make sure the name of InputSpec or example Tensor " \
        "in input_spec is the same as the name of InputSpec in " \
        "`to_static` decorated on the Layer.forward method."
    result_list = []
    input_var_names = [var.name for var in inputs if isinstance(var, Variable)]
    if input_spec is None:
        # no prune
        result_list = input_var_names
    elif input_spec is not None and len(input_spec) == len(input_var_names):
        # no prune
        result_list = input_var_names
        # if input spec name not in input_var_names, only raise warning 
        for spec in input_spec:
            if spec.name is None:
                warnings.warn(name_none_error % spec)
            elif spec.name not in input_var_names:
                warnings.warn(name_no_exists_error % spec.name)
            else:
                # do nothing
                pass
    else:
        # prune
        for spec in input_spec:
            if spec.name is None:
                # name is None, the input_spec only can be InputSpec
                raise ValueError(name_none_error % spec)
            elif spec.name not in input_var_names:
                # the input_spec can be `InputSpec` or `VarBase`
                raise ValueError(name_no_exists_error % spec.name)
            else:
                result_list.append(spec.name)

    return result_list


def _get_output_vars(outputs, output_spec):
    name_no_exists_error = "The tensor `%s` does not exists. " \
        "Please make sure the name of example Tensor " \
        "in configs.output_spec is the output tensor of " \
        "Layer.forward method."
    result_list = []
    output_vars_dict = OrderedDict()
    for var in outputs:
        if isinstance(var, Variable):
            output_vars_dict[var.name] = var
    if output_spec is None:
        result_list = output_vars_dict.values()
    elif output_spec is not None and len(output_spec) == len(output_vars_dict):
        result_list = output_vars_dict.values()
        for var in output_spec:
            if var.name not in output_vars_dict:
                warnings.warn(name_no_exists_error % var.name)
    else:
        for var in output_spec:
            if var.name not in output_vars_dict:
                raise ValueError(name_no_exists_error % var.name)
            else:
                result_list.append(output_vars_dict[var.name])
    return result_list


445
@switch_to_static_graph
def save(layer, model_path, input_spec=None, **configs):
    """
    Saves input declarative Layer as :ref:`api_imperative_TranslatedLayer`
    format model, which can be used for inference or fine-tuning after loading.

    It will save the translated program and all related persistable
    variables of input declarative Layer to given ``model_path``.

    The default saved translated program file name is ``__model__``,
    and the default saved persistable variables file name is ``__variables__``.
    It also saves some additional variable description information to the file
    ``__variables.info__``; this additional information is used in fine-tuning.

    The saved model can be loaded by the following APIs:
      - :ref:`api_imperative_jit_load`
      - :ref:`api_fluid_io_load_inference_model` (need pass ``params_filename='__variables__'``)
      - Other C++ inference APIs

    Args:
        layer (Layer): the Layer to be saved. If ``layer.forward`` is not
            already decorated by `@declarative`, it is converted here using
            ``input_spec``.
        model_path (str): the directory to save the model.
        input_spec (list[InputSpec|Tensor], optional): Describes the input of the saved model.
            It is the example inputs that will be passed to saved TranslatedLayer's forward
            function. If None, all input variables of the original Layer's forward function
            would be the inputs of the saved model. Default None.
        configs (dict, optional): other save configuration options for compatibility. We do not
            recommend using these configurations, if not necessary, DO NOT use them. Default None.
            The following options are currently supported:
            (1) output_spec (list[Tensor]): Selects the output targets of the saved model.
            By default, all return variables of original Layer's forward function are kept as the
            output of the saved model. If the provided ``output_spec`` list is not all output variables,
            the saved model will be pruned according to the given ``output_spec`` list.
            (2) model_filename (string): The name of file to save the translated program of target Layer.
            Default filename is :code:`__model__` .
            (3) params_filename (string): The name of file to save all persistable variables in target Layer.
            Default file name is :code:`__variables__` .
            (4) separate_params (bool): Configure whether to save the Layer parameters as separate files.
            If True, each parameter will be saved to a file separately, the file name is the parameter name,
            and the params_filename configuration will not take effect. Default False.

    Returns:
        None

    Raises:
        RuntimeError: if ProgramTranslator's to-static conversion is disabled.
        TypeError: if ``layer`` is not a Layer, or ``input_spec`` is not a list
            of InputSpec/Tensor.
        ValueError: for unsupported ``configs`` keys or unmatched spec names.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            # enable dygraph mode
            place = paddle.CPUPlace()
            paddle.disable_static(place)

            # 1. train & save model.

            # create network
            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                places=place,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # train
            train(layer, loader, loss_fn, adam)

            # save
            model_path = "linear.example.model"
            paddle.jit.save(layer, model_path)
    """

    # 1. input check
    prog_translator = ProgramTranslator()
    if not prog_translator.enable_to_static:
        raise RuntimeError(
            "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
        )
    if not isinstance(layer, Layer):
        raise TypeError(
            "The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
            % type(layer))

    configs = _parse_save_configs(configs)

    # avoid change user given input_spec
    inner_input_spec = None
    if input_spec is not None:
        if not isinstance(input_spec, list):
            raise TypeError(
                "The input input_spec should be 'list', but received input_spec's type is %s."
                % type(input_spec))
        inner_input_spec = []
        for var in input_spec:
            if isinstance(var, paddle.static.InputSpec):
                inner_input_spec.append(var)
            elif isinstance(var, (core.VarBase, Variable)):
                inner_input_spec.append(
                    paddle.static.InputSpec.from_tensor(var))
            else:
                raise TypeError(
                    "The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
                    % type(var))

    # 2. get program from Layer
    # TODO(chenweihang): add support for other method, not only forward
    if isinstance(layer.forward, StaticFunction):
        concrete_program = layer.forward.concrete_program
    else:
        # transform in jit.save, if input_spec is incomplete, declarative will throw error
        static_forward = declarative(layer.forward, input_spec=inner_input_spec)
        concrete_program = static_forward.concrete_program
        # the input_spec has been used in declarative, which is equal to
        # @declarative with input_spec and jit.save without input_spec,
        # avoid needless warning
        inner_input_spec = None

    # 3. build input & output of save_inference_model
    # NOTE(chenweihang): [ Get input variables name ]
    # There are two cases, whether to prune the inputs or not
    # - not prune inputs (recommend):
    #   - the len(input_spec) == len((concrete_program.inputs) - 1
    #   - here can use concrete_program.inputs directly
    # - prune inputs:
    #   - the input_spec length < len((concrete_program.inputs) - 1
    #   - the input_spec's name should be in concrete_program.inputs
    input_var_names = _get_input_var_names(concrete_program.inputs,
                                           inner_input_spec)

    # NOTE(chenweihang): [ Get output variables ]
    # the rule is like [ Get input variables name ]. For output var,
    # we only support VarBase spec, and actually, we only need the
    # var name of output, and we don't recommended to use output_spec
    output_vars = _get_output_vars(concrete_program.outputs,
                                   configs.output_spec)

    # NOTE(chenweihang): we maintain the mapping of variable name to
    # structured name, the buffer variable (non-persistable)
    # saved to inference program may not need by dygraph Layer,
    # we only record the state_dict variable's structured name
    state_names_dict = dict()
    for structured_name, var in six.iteritems(layer.state_dict()):
        state_names_dict[var.name] = structured_name

    # 4. share parameters from Layer to scope & record var info
    scope = core.Scope()
    extra_var_info = dict()
    for param_or_buffer in concrete_program.parameters:
        # share to scope
        param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor()
        src_tensor = param_or_buffer.value().get_tensor()
        param_or_buffer_tensor._share_data_with(src_tensor)
        # record var info
        extra_info_dict = dict()
        if param_or_buffer.name in state_names_dict:
            extra_info_dict['structured_name'] = state_names_dict[
                param_or_buffer.name]
        extra_info_dict['stop_gradient'] = param_or_buffer.stop_gradient
        if isinstance(param_or_buffer, ParamBase):
            extra_info_dict['trainable'] = param_or_buffer.trainable
        extra_var_info[param_or_buffer.name] = extra_info_dict

    # 5. save inference model
    from paddle.fluid.io import save_inference_model

    # VARIABLE_FILENAME keep naming style consistent with '__model__'
    if configs.params_filename is None:
        configs.params_filename = VARIABLE_FILENAME

    with scope_guard(scope):
        save_inference_model(
            dirname=model_path,
            feeded_var_names=input_var_names,
            target_vars=output_vars,
            executor=Executor(_current_expected_place()),
            main_program=concrete_program.main_program.clone(),
            model_filename=configs.model_filename,
            params_filename=None
            if configs.separate_params else configs.params_filename,
            export_for_deployment=configs._export_for_deployment,
            program_only=configs._program_only)

        # NOTE(chenweihang): [ Save extra variable info ]
        # save_inference_model will lose some important variable information, including:
        #   - Variable name and correspondence (when saved variables as one file)
        #   - Variable.stop_gradient information
        #   - Which persistent variable are parameter and which are not
        #   - Parameter.trainable information
        #
        # The lost information cannot be recovered when it is loaded again,
        # so if we want to perform fine-tune after loading, we may need to
        # configure redundant information to proceed.
        #
        # Due to compatibility issues, we cannot change the original storage structure,
        # but we can save these information in `jit.save` without changing the original
        # storage to improve user experience. So we save extra information into
        # file `__variables.info__`
        extra_var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
        with open(extra_var_info_path, 'wb') as f:
            # protocol=2 keeps the file readable from Python 2 as well.
            pickle.dump(extra_var_info, f, protocol=2)


@dygraph_only
def load(model_path, **configs):
    """
    :api_attr: imperative

    Load model saved by :ref:`api_imperative_jit_save` or :ref:`api_fluid_io_save_inference_model`
    as :ref:`api_imperative_TranslatedLayer`, then perform inference or fine-tune training.

    .. note::
        For some historical reasons, if you load model saved by :ref:`api_fluid_io_save_inference_model`,
        there will be the following limitations when using it in fine-tuning:
        1. Imperative mode does not support LoDTensor. All original model's feed targets or parameters that depend on LoD are temporarily unavailable.
        2. All saved model's feed targets need to be passed into TranslatedLayer's forward function.
        3. The variable's ``stop_gradient`` information is lost and can not be recovered.
        4. The parameter's ``trainable`` information is lost and can not be recovered.

    Args:
        model_path (str): The directory path where the model is saved.
        configs (dict, optional): other load configuration options for compatibility. We do not
            recommend using these configurations, if not necessary, DO NOT use them. Default None.
            The following options are currently supported:
            (1) model_filename (string): The filename to load the translated program of target Layer.
            Default filename is :code:`__model__` .
            (2) params_filename (string): The filename to load all persistable variables in target Layer.
            Default file name is :code:`__variables__` .
            (3) separate_params (bool): Configure whether to load the Layer parameters from separate files.
            If True, each parameter will be loaded from a file separately, the file name is the parameter name,
            and the params_filename configuration will not take effect. Default False.

    Returns:
        TranslatedLayer: A Layer object that can run the saved translated model.

    Examples:
        1. Load model saved by :ref:`api_imperative_jit_save` then performing inference and fine-tune training.

        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            # enable dygraph mode
            place = paddle.CPUPlace()
            paddle.disable_static(place)

            # 1. train & save model.

            # create network
            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                places=place,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # train
            train(layer, loader, loss_fn, adam)

            # save
            model_path = "linear.example.model"
            paddle.jit.save(layer, model_path)

            # 2. load model

            # load
            loaded_layer = paddle.jit.load(model_path)

            # inference
            loaded_layer.eval()
            x = paddle.randn([1, IMAGE_SIZE], 'float32')
            pred = loaded_layer(x)

            # fine-tune
            loaded_layer.train()
            adam = opt.Adam(learning_rate=0.001, parameters=loaded_layer.parameters())
            train(loaded_layer, loader, loss_fn, adam)


        2. Load model saved by :ref:`api_fluid_io_save_inference_model` then performing inference and fine-tune training.

        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.fluid as fluid
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            image = fluid.data(name='image', shape=[None, 784], dtype='float32')
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            pred = fluid.layers.fc(input=image, size=10, act='softmax')
            loss = fluid.layers.cross_entropy(input=pred, label=label)
            avg_loss = fluid.layers.mean(loss)

            optimizer = fluid.optimizer.SGD(learning_rate=0.001)
            optimizer.minimize(avg_loss)

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                feed_list=[image, label],
                places=place,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # 1. train and save inference model
            for data in loader():
                exe.run(
                    fluid.default_main_program(),
                    feed=data,
                    fetch_list=[avg_loss])

            model_path = "fc.example.model"
            fluid.io.save_inference_model(
                model_path, ["image"], [pred], exe)

            # 2. load model

            # enable dygraph mode
            paddle.disable_static(place)

            # load
            fc = paddle.jit.load(model_path)

            # inference
            fc.eval()
            x = paddle.randn([1, IMAGE_SIZE], 'float32')
            pred = fc(x)

            # fine-tune
            fc.train()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=fc.parameters())
            loader = paddle.io.DataLoader(dataset,
                places=place,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)
            for epoch_id in range(EPOCH_NUM):
                for batch_id, (image, label) in enumerate(loader()):
                    out = fc(image)
                    loss = loss_fn(out, label)
                    loss.backward()
                    adam.step()
                    adam.clear_grad()
                    print("Epoch {} batch {}: loss = {}".format(
                        epoch_id, batch_id, np.mean(loss.numpy())))
    """
    # Turn the raw keyword options into the config object that
    # TranslatedLayer._construct expects, then build the loaded layer.
    config = _parse_load_config(configs)
    return TranslatedLayer._construct(model_path, config)


@dygraph_only
def _trace(layer,
           inputs,
           feed_prefix='feed_',
           fetch_prefix='fetch_',
           tmp_prefix='t_'):
    """
    Run ``layer(*inputs)`` once under program-desc tracing and build a static
    ``Program`` from what was recorded.

    Args:
        layer (Layer): the dygraph layer to trace.
        inputs (Tensor|list|tuple): the input(s) of ``layer``; a single
            non-list/tuple input is wrapped into a one-element list.
        feed_prefix (str): forwarded to ``tracer.create_program_desc``
            (presumably the name prefix for feed variables).
        fetch_prefix (str): forwarded to ``tracer.create_program_desc``
            (presumably the name prefix for fetch variables).
        tmp_prefix (str): forwarded to ``tracer.create_program_desc``
            (presumably the name prefix for temporary variables).

    Returns:
        tuple: ``(original_outputs, program, feed_names, fetch_names, parameters)``,
        where ``original_outputs`` is exactly what ``layer(*inputs)`` returned and
        the remaining items come from the tracer / ``create_program_from_desc``.
    """
    assert isinstance(layer, Layer)

    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]

    tracer = _dygraph_tracer()._get_program_desc_tracer()

    var_list = extract_vars(inputs)

    with program_desc_tracing_guard(True):
        # The tracer is obtained from _dygraph_tracer() rather than created per
        # call, so reset it even when the forward pass or program construction
        # raises; otherwise a failed trace would leave its recorded state behind
        # for the next trace.
        try:
            original_outputs = layer(*inputs)
            if isinstance(original_outputs, (list, tuple)):
                out_vars = list(original_outputs)
            else:
                out_vars = [original_outputs]

            program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
                var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
        finally:
            tracer.reset()

    # Build the Program with dygraph mode disabled.
    with _dygraph_guard(None):
        program = create_program_from_desc(program_desc)

    return original_outputs, program, feed_names, fetch_names, parameters


class TracedLayer(object):
    """
    :api_attr: imperative

    TracedLayer is used to convert a forward dygraph model to a static
    graph model. This is mainly used to save the dygraph model for online
    inference using C++. Besides, users can also do inference in Python
    using the converted static graph model, which usually has better
    performance than the original dygraph model.

    TracedLayer would run the static graph model using :code:`Executor`
    and :code:`CompiledProgram` . The static graph model would share
    parameters with the dygraph model.

    TracedLayer objects should not be created by the constructor; they should
    be created by the static method :code:`TracedLayer.trace(layer, inputs)` .

    The TracedLayer can only be used to convert the data-independent dygraph
    model into the static graph model, which means the dygraph model should
    be independent of the tensor data and shape.
    """

    def __init__(self, program, parameters, feed_names, fetch_names):
        """
        Wrap a traced static ``program``. Use :code:`TracedLayer.trace` instead
        of calling this directly.

        Args:
            program (Program): the traced static graph program.
            parameters (list): the dygraph parameters used by ``program``; their
                tensors are shared (not copied) into this object's private scope.
            feed_names (list[str]): names of the program's feed variables.
            fetch_names (list[str]): names of the program's fetch variables.
        """
        self._program = program
        self._feed_names = feed_names
        self._fetch_names = fetch_names
        # NOTE(review): presumably kept so the dygraph parameters whose storage
        # is shared into self._scope below remain referenced.
        self._params = parameters

        self._place = _current_expected_place()

        # A private scope whose parameter variables alias the dygraph
        # parameters' tensors, so the static graph sees the same weights.
        self._scope = core.Scope()
        for p in parameters:
            src_tensor = p.value().get_tensor()
            dst_tensor = self._scope.var(p.name).get_tensor()
            dst_tensor._share_data_with(src_tensor)

        self._exe = Executor(self._place)
        # Compiled lazily on the first __call__, so strategies can still be set.
        self._compiled_program = None
        self._build_strategy = None
        self._exec_strategy = None

    @property
    def program(self):
        """Program: the traced static graph program."""
        return self._program

    def _switch(self, is_test=True):
        """
        Set the ``is_test`` attribute to ``is_test`` on every operator, in every
        block of the traced program, that has such an attribute.
        """
        for block_id in range(self._program.num_blocks):
            block = self._program.block(block_id)
            for op in block.ops:
                if op.has_attr("is_test"):
                    op._set_attr("is_test", is_test)

    @staticmethod
    @dygraph_only
    def trace(layer, inputs):
        """
        This method is the only allowed method to create TracedLayer object.
        It would call the :code:`layer(*inputs)` method to run the dygraph
        model and convert it into a static graph model.

        Args:
            layer (dygraph.Layer): the layer object to be traced.
            inputs (list(Tensor)|tuple(Tensor)|Tensor): the input tensors of
                the layer object.

        Returns:
            tuple: A tuple of 2 items, whose first item is the output of
                :code:`layer(*inputs)` , and the second item is the created
                TracedLayer object.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
                import numpy as np

                class ExampleLayer(fluid.dygraph.Layer):
                    def __init__(self):
                        super(ExampleLayer, self).__init__()
                        self._fc = Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)

                with fluid.dygraph.guard():
                    layer = ExampleLayer()
                    in_np = np.random.random([2, 3]).astype('float32')
                    in_var = to_variable(in_np)
                    out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])

                    # run the static graph model using Executor inside
                    out_static_graph = static_layer([in_var])

                    print(len(out_static_graph)) # 1
                    print(out_static_graph[0].shape) # (2, 10)

                    # save the static graph model for inference
                    static_layer.save_inference_model(dirname='./saved_infer_model')
        """
        assert isinstance(
            layer, Layer
        ), "The type of 'layer' in fluid.dygraph.jit.TracedLayer.trace must be fluid.dygraph.Layer, but received {}.".format(
            type(layer))
        outs, prog, feed, fetch, parameters = _trace(layer, inputs)
        traced = TracedLayer(prog, parameters, feed, fetch)
        return outs, traced

    def set_strategy(self, build_strategy=None, exec_strategy=None):
        """
        Set the strategies when running static graph model. Must be called
        before the first run, because the program is compiled on first run.

        Args:
            build_strategy (BuildStrategy, optional): build strategy of
                :code:`CompiledProgram` inside TracedLayer. Default None.
            exec_strategy (ExecutionStrategy, optional): execution strategy of
                :code:`CompiledProgram` inside TracedLayer. Default None.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
                import numpy as np

                class ExampleLayer(fluid.dygraph.Layer):
                    def __init__(self):
                        super(ExampleLayer, self).__init__()
                        self._fc = Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)

                with fluid.dygraph.guard():
                    layer = ExampleLayer()
                    in_np = np.random.random([2, 3]).astype('float32')
                    in_var = to_variable(in_np)

                    out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])

                    build_strategy = fluid.BuildStrategy()
                    build_strategy.enable_inplace = True

                    exec_strategy = fluid.ExecutionStrategy()
                    exec_strategy.num_threads = 2

                    static_layer.set_strategy(build_strategy=build_strategy, exec_strategy=exec_strategy)
                    out_static_graph = static_layer([in_var])
        """
        assert self._compiled_program is None, "Cannot set strategy after run"
        assert isinstance(
            build_strategy, (type(None), BuildStrategy)
        ), "The type of 'build_strategy' in fluid.dygraph.jit.TracedLayer.set_strategy must be fluid.BuildStrategy, but received {}.".format(
            type(build_strategy))
        assert isinstance(
            exec_strategy, (type(None), ExecutionStrategy)
        ), "The type of 'exec_strategy' in fluid.dygraph.jit.TracedLayer.set_strategy must be fluid.ExecutionStrategy, but received {}.".format(
            type(exec_strategy))
        self._build_strategy = build_strategy
        self._exec_strategy = exec_strategy

    @switch_to_static_graph
    def _compile(self):
        """Compile the traced program with the configured strategies on self._place."""
        self._compiled_program = CompiledProgram(
            self._program).with_data_parallel(
                build_strategy=self._build_strategy,
                exec_strategy=self._exec_strategy,
                places=self._place)

    def _build_feed(self, inputs):
        """
        Map each feed variable name to the corresponding positional input.

        In dygraph mode each input's underlying tensor
        (``x.value().get_tensor()``) is fed; otherwise the input is fed as-is.
        """
        assert isinstance(inputs, (list, tuple)), \
            "Inputs should be a list or tuple of variables"
        assert len(inputs) == len(self._feed_names), \
            "The number of inputs ({}) does not match the number of feed variables ({})".format(
                len(inputs), len(self._feed_names))
        if in_dygraph_mode():
            return {
                name: x.value().get_tensor()
                for x, name in zip(inputs, self._feed_names)
            }
        return dict(zip(self._feed_names, inputs))

    @switch_to_static_graph
    def _run(self, feed):
        """Execute the compiled program and return the fetched outputs."""
        return self._exe.run(self._compiled_program,
                             feed=feed,
                             fetch_list=self._fetch_names)

    def __call__(self, inputs):
        """
        Run the static graph model on ``inputs`` (a list or tuple, one entry
        per feed variable) inside this object's private scope, compiling the
        program on the first call.
        """
        with scope_guard(self._scope):
            if self._compiled_program is None:
                self._compile()

            return self._run(self._build_feed(inputs))

    @switch_to_static_graph
    def save_inference_model(self, dirname, feed=None, fetch=None):
        """
        Save the TracedLayer to a model for inference. The saved
        inference model can be loaded by C++ inference APIs.

        Args:
            dirname (str): the directory to save the inference model.
            feed (list[int], optional): the input variable indices of the saved
                inference model. If None, all input variables of the
                TracedLayer object would be the inputs of the saved inference
                model. Default None.
            fetch (list[int], optional): the output variable indices of the
                saved inference model. If None, all output variables of the
                TracedLayer object would be the outputs of the saved inference
                model. Default None.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
                import numpy as np

                class ExampleLayer(fluid.dygraph.Layer):
                    def __init__(self):
                        super(ExampleLayer, self).__init__()
                        self._fc = Linear(3, 10)

                    def forward(self, input):
                        return self._fc(input)

                save_dirname = './saved_infer_model'
                in_np = np.random.random([2, 3]).astype('float32')

                with fluid.dygraph.guard():
                    layer = ExampleLayer()
                    in_var = to_variable(in_np)
                    out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
                    static_layer.save_inference_model(save_dirname, feed=[0], fetch=[0])

                place = fluid.CPUPlace()
                exe = fluid.Executor(place)
                program, feed_vars, fetch_vars = fluid.io.load_inference_model(save_dirname,
                                                    exe)

                fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
                print(fetch.shape) # (2, 10)
        """
        check_type(dirname, "dirname", str,
                   "fluid.dygraph.jit.TracedLayer.save_inference_model")
        check_type(feed, "feed", (type(None), list),
                   "fluid.dygraph.jit.TracedLayer.save_inference_model")
        if isinstance(feed, list):
            for f in feed:
                check_type(f, "each element of feed", int,
                           "fluid.dygraph.jit.TracedLayer.save_inference_model")
        check_type(fetch, "fetch", (type(None), list),
                   "fluid.dygraph.jit.TracedLayer.save_inference_model")
        if isinstance(fetch, list):
            for f in fetch:
                check_type(f, "each element of fetch", int,
                           "fluid.dygraph.jit.TracedLayer.save_inference_model")

        from paddle.fluid.io import save_inference_model

        def get_feed_fetch(all_vars, partial_vars):
            # None selects everything; otherwise pick the given indices in order.
            if partial_vars is None:
                return all_vars

            return [all_vars[idx] for idx in partial_vars]

        with scope_guard(self._scope):
            feeded_var_names = get_feed_fetch(self._feed_names, feed)
            target_var_names = get_feed_fetch(self._fetch_names, fetch)
            target_vars = []
            for name in target_var_names:
                target_var = self._program.global_block().vars.get(name, None)
                assert target_var is not None, "{} cannot be found".format(name)
                target_vars.append(target_var)

            save_inference_model(
                dirname=dirname,
                feeded_var_names=feeded_var_names,
                target_vars=target_vars,
                executor=self._exe,
                main_program=self._program.clone())