# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import six
import pickle
import numpy as np

import paddle
from paddle import compat as cpt
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid import backward
from paddle.fluid import unique_name
from paddle.fluid.dygraph import layers
from paddle.fluid.layers import nn
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import in_dygraph_mode

__all__ = ['TranslatedLayer']

# File suffixes of the saved inference model: program, combined params,
# and the extra var info stored alongside the params.
INFER_MODEL_SUFFIX = ".pdmodel"
INFER_PARAMS_SUFFIX = ".pdiparams"
INFER_PARAMS_INFO_SUFFIX = ".pdiparams.info"

# Name fragments for loaded variables. LOADED_VAR_SUFFIX is used by
# `_append_loaded_suffix`; NOTE(review): the two prefixes are not referenced
# in the visible part of this file — presumably used when naming loaded
# parameters / buffers elsewhere.
LOADED_VAR_SUFFIX = "load"
PARAMETER_NAME_PREFIX = "param"
BUFFER_NAME_PREFIX = "buffer"


def _load_program_desc(model_file_path):
    # 1. parse program desc
    with open(model_file_path, "rb") as f:
        program_desc_str = f.read()

    program_desc = core.ProgramDesc(program_desc_str)
    if not core._is_program_version_supported(program_desc._version()):
        raise ValueError("Unsupported program version: %d\n" %
                         program_desc._version())

    return program_desc


def _is_persistable(var_desc):
    """
    Return whether ``var_desc`` should be treated as persistable.

    Feed/fetch lists, readers and raw variables are never considered
    persistable; any other var reports its own ``persistable()`` flag.
    """
    excluded_types = (
        core.VarDesc.VarType.FEED_MINIBATCH,
        core.VarDesc.VarType.FETCH_LIST,
        core.VarDesc.VarType.READER,
        core.VarDesc.VarType.RAW,
    )
    if var_desc.type() in excluded_types:
        return False
    return var_desc.persistable()


def _is_parameter(persistable_var_desc, program_desc):
    """
    Decide whether a persistable var is a parameter of ``program_desc``.

    The var is rejected as soon as some op writes it without also reading
    it (ops such as batch_norm read and write the same var and are
    tolerated); otherwise it is accepted.

    NOTE(review): a var that no op reads is also accepted, even though the
    intent is "param should be input of op"; confirm this is desired.
    """
    var_name = persistable_var_desc.name()

    def _all_ops():
        # yield every op of every block, in block order
        for block_idx in six.moves.range(program_desc.num_blocks()):
            cur_block = program_desc.block(block_idx)
            for op_idx in six.moves.range(cur_block.op_size()):
                yield cur_block.op(op_idx)

    # 1. ops that read the var
    reading_ops = [
        op for op in _all_ops() if var_name in op.input_arg_names()
    ]

    # 2. every op that writes the var must also read it
    for op in _all_ops():
        if var_name in op.output_arg_names() and op not in reading_ops:
            return False
    return True


def _get_persistable_vars(program_desc):
    """Collect the persistable VarDescs of all blocks, in block order."""
    collected = []
    for block_idx in six.moves.range(program_desc.num_blocks()):
        all_vars = program_desc.block(block_idx).all_vars()
        collected.extend(var for var in all_vars if _is_persistable(var))
    return collected


def _get_persistable_var_names(program_desc):
    """
    Get all persistable variable names in ProgramDesc, in block order.
    """
    return [var.name() for var in _get_persistable_vars(program_desc)]


def _get_all_var_names(program_desc):
    """Return the set of names of every variable in every block."""
    return {
        var.name()
        for block_idx in six.moves.range(program_desc.num_blocks())
        for var in program_desc.block(block_idx).all_vars()
    }


@switch_to_static_graph
def _append_loaded_suffix(name):
    """
    Append loaded suffix to the given variable name

    e.g. x ==> x.load_0, x.load_0 ==> x.load_0.load_0

    The unique name is produced by
    ``unique_name.generate_with_ignorable_key('<name>.load')``.
    """
    suffix = LOADED_VAR_SUFFIX
    name = cpt.to_text(name)
    new_name = unique_name.generate_with_ignorable_key('.'.join((name, suffix)))
    return new_name


@switch_to_static_graph
def _generate_unique_var_name(prefix):
    """Return a unique name for ``prefix`` via ``generate_with_ignorable_key``."""
    return unique_name.generate_with_ignorable_key(prefix)


def _append_loaded_suffix_to_var(program_desc):
    """
    Rename every persistable var of ``program_desc`` with a loaded suffix.

    Each persistable var gets the name returned by ``_append_loaded_suffix``.
    The rename is applied to the var desc, to the var in every block, and to
    the inputs/outputs of every op.

    Args:
        program_desc(ProgramDesc): program modified in place.

    Returns:
        dict: maps each new (suffixed) name to its original name.
    """
    suffix_varname_dict = dict()
    persistable_vars = _get_persistable_vars(program_desc)
    for var_desc in persistable_vars:
        old_name = var_desc.name()
        new_name = _append_loaded_suffix(old_name)
        suffix_varname_dict[new_name] = old_name
        var_desc.set_name(new_name)
        for block_idx in six.moves.range(program_desc.num_blocks()):
            block = program_desc.block(block_idx)
            block._rename_var(cpt.to_bytes(old_name), cpt.to_bytes(new_name))
            for op_idx in six.moves.range(block.op_size()):
                op = block.op(op_idx)
                op._rename_input(old_name, new_name)
                op._rename_output(old_name, new_name)
    return suffix_varname_dict


@switch_to_static_graph
def _generate_unique_var_name_sync_with_main_program(prefix):
    """
    Return a unique name for ``prefix`` via ``unique_name.generate``.

    Unlike ``_generate_unique_var_name``, this does not use the
    ignorable-key variant of the generator.
    """
    return unique_name.generate(prefix)


def _get_loaded_var_new_old(program_desc, all_new_old_dict_all):
    """
    Restrict a new-name -> old-name mapping to the persistable vars.

    Args:
        program_desc(ProgramDesc): program whose persistable vars are kept.
        all_new_old_dict_all(dict): new -> old name mapping covering (at
            least) every persistable var; a missing name raises KeyError.

    Returns:
        dict: new -> old name for the persistable vars only.
    """
    persistable_names = (
        var_desc.name() for var_desc in _get_persistable_vars(program_desc))
    return {name: all_new_old_dict_all[name] for name in persistable_names}


def _rename_var_program_desc(program_desc, include=None, exclude=None):
    """
    Change the name of the loaded variables. Use 'unique_name.generate' to
    avoid duplication.

    e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0.
    If 'include' is not `None`, variables that are not in include are not renamed.
    If 'exclude' is not `None`, variables that are in exclude are not renamed.

    Args:
        program_desc(ProgramDesc): the variables in it will be modified.
        include(List): list of names of variables.
        exclude(List): list of names of variables.

    Returns:
        tuple(dict, dict): ``(new name -> old name, old name -> new name)``
        for every variable visited, including variables that kept their name.
    """
    dict_rename_var_old_new = dict()
    dict_rename_var_new_old = dict()

    # Sets make the per-variable membership tests O(1).
    include_names = None if include is None else set(include)
    exclude_names = None if exclude is None else set(exclude)

    # Count the original names of all variables across every block. A
    # generated name is accepted only if it collides with no variable other
    # than the one currently being renamed.
    old_name_counts = dict()
    for b_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(b_idx)
        for var in cur_block.all_vars():
            name = var.name()
            old_name_counts[name] = old_name_counts.get(name, 0) + 1

    for b_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(b_idx)
        for var in cur_block.all_vars():
            name_old = var.name()
            should_rename = (
                include_names is None or name_old in include_names) and (
                    exclude_names is None or name_old not in exclude_names)
            if should_rename:
                # drop a trailing numeric "_<n>" so the generator can append
                # a fresh counter
                temp_name = name_old.split('_')
                if len(temp_name) > 1 and temp_name[-1].isnumeric():
                    temp_name = "_".join(temp_name[:-1])
                else:
                    temp_name = name_old
                while True:
                    name_new = _generate_unique_var_name_sync_with_main_program(
                        temp_name)
                    # ignore this variable's own entry when checking clashes
                    collisions = old_name_counts.get(name_new, 0)
                    if name_new == name_old:
                        collisions -= 1
                    if collisions <= 0:
                        break
            else:
                name_new = name_old
            if name_old != name_new:
                cur_block._rename_var(
                    cpt.to_bytes(name_old), cpt.to_bytes(name_new))
            dict_rename_var_old_new[name_old] = name_new
            dict_rename_var_new_old[name_new] = name_old

    # propagate the renames to the inputs/outputs of every op
    for b_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(b_idx)
        for op_idx in six.moves.range(cur_block.op_size()):
            op = cur_block.op(op_idx)
            for input_arg_name in op.input_arg_names():
                new_name = dict_rename_var_old_new.get(input_arg_name)
                if new_name is not None and new_name != input_arg_name:
                    op._rename_input(input_arg_name, new_name)
            for output_arg_name in op.output_arg_names():
                new_name = dict_rename_var_old_new.get(output_arg_name)
                if new_name is not None and new_name != output_arg_name:
                    op._rename_output(output_arg_name, new_name)
    program_desc.flush()
    return dict_rename_var_new_old, dict_rename_var_old_new


@switch_to_static_graph
def _build_program_by_desc(program_desc):
    """
    Wrap ``program_desc`` in a ``framework.Program``.

    One ``framework.Block`` is created per block of the desc, and the
    Python-side program is then synced with the C++ desc.
    """
    prog = framework.Program()
    prog.desc = program_desc
    prog.blocks = [
        framework.Block(prog, i)
        for i in six.moves.range(prog.desc.num_blocks())
    ]
    prog._sync_with_cpp()
    return prog


def _change_is_test_status(program_desc, is_test):
    """Set ``is_test`` on every op (in every block) that has that attribute."""
    for block_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(block_idx)
        for op_idx in six.moves.range(cur_block.op_size()):
            cur_op = cur_block.op(op_idx)
            if not cur_op.has_attr('is_test'):
                continue
            cur_op._set_attr('is_test', is_test)


class _ProgramHolder(object):
    """
    Holds the execution information of a Program.

    _ProgramHolder is the execution unit of TranslatedLayer,
    if TranslatedLayer contains multiple _ProgramHolder,
    it can execute multiple methods

    _ProgramHolder is an internal concept.
    """

    def __init__(self, program_desc):
        super(_ProgramHolder, self).__init__()

        # input, output, persistable var info
        self._input_descs = []
        self._output_descs = []
        self._persistable_names = []

        # execution scope
        self._inner_scope = core.Scope()

        # renamed persistable var name -> original name, filled by _preprocess
        self._suffix_varname_dict = None

        # forward program
        self._infer_program_desc = self._preprocess(program_desc)
        # forward + backward program
        self._train_program_desc = self._append_backward_desc(
            self._infer_program_desc)

    @property
    def infer_program(self):
        """The preprocessed forward ProgramDesc."""
        return self._infer_program_desc

    @property
    def train_program(self):
        """The forward + backward ProgramDesc."""
        return self._train_program_desc

    @property
    def input_descs(self):
        """VarDescs of the former feed targets, in feed order."""
        return self._input_descs

    @property
    def output_descs(self):
        """VarDescs of the outputs (outputs of the appended scale ops)."""
        return self._output_descs

    @property
    def persistable_names(self):
        """Names of the persistable vars after renaming."""
        return self._persistable_names

    @property
    def scope(self):
        """The inner execution scope."""
        return self._inner_scope

    def _preprocess(self, program_desc):
        """
        Rename persistable vars, strip feed/fetch/save-scale ops and record
        the input/output descs. ``program_desc`` is modified in place and
        returned.
        """
        # rename persistable variables of 'program_desc'
        list_persistable_var = _get_persistable_var_names(program_desc)
        rename_new_old_dict, _ = _rename_var_program_desc(program_desc,
                                                          list_persistable_var)
        # 1. Prune original program
        # remove feed, fetch and scale-1 op, remove op_callstack attr
        ops_to_remove = []
        root_block = program_desc.block(0)
        for i in six.moves.range(root_block.op_size()):
            op = root_block.op(i)
            if op.type() == 'feed':
                ops_to_remove.append(i)
                feed_var_name = cpt.to_bytes(op.input('X')[0])
                root_block._remove_var(feed_var_name)
                self._input_descs.append(
                    root_block.find_var(cpt.to_bytes(op.output('Out')[0])))
            elif op.type() == 'scale' and op.output('Out')[0].startswith(
                    'save_infer_model/scale_'):
                ops_to_remove.append(i)
                out_var_name = cpt.to_bytes(op.output('Out')[0])
                root_block._remove_var(out_var_name)
                self._output_descs.append(
                    root_block.find_var(cpt.to_bytes(op.input('X')[0])))
            elif op.type() == 'fetch':
                ops_to_remove.append(i)
                fetch_var_name = cpt.to_bytes(op.output('Out')[0])
                root_block._remove_var(fetch_var_name)
                # NOTE: some old pre-train models have no extra scale_op
                if not op.input('X')[0].startswith('save_infer_model/scale_'):
                    self._output_descs.append(
                        root_block.find_var(cpt.to_bytes(op.input('X')[0])))
            else:
                if op.has_attr("op_callstack"):
                    op.remove_attr("op_callstack")

        # remove from the back so earlier indices stay valid
        for op_idx in reversed(ops_to_remove):
            root_block._remove_op(op_idx, op_idx + 1)

        # 2. Input processing, reverse feed vars
        self._input_descs.reverse()

        # 3. Output processing, add scale for outputs
        tmp_program = _build_program_by_desc(program_desc)
        # NOTE: [why need append scale for outputs]
        # When dealing with some more complex pre-training models, there
        # will be situations where the pre-training model has multiple
        # fetch outputs. In the scenario of multiple fetch outputs,
        # there is a special case where multiple outputs of the model
        # may be on the same branch. According to the user's subsequent
        # use, multiple outputs may be associated with multiple branches.
        # These subsequent operations are added in TranslatedLayer is
        # agnostic during initialization, which results in subsequent
        # gradient accumulation operations that are required on the
        # output node in the middle of the branch will not be performed,
        # resulting in error, details see pull request:
        # [https://github.com/PaddlePaddle/Paddle/pull/24627]
        self._append_scale_to_output(tmp_program)

        # 4. Persistable vars processing
        # - record the renamed persistable vars
        # NOTE: [why need to append suffix to persistable vars]
        # Dygraph and static graph mode use the same naming mechanism.
        # If users want to load the model fine-tune, it is possible
        # to add the existing Layer in the loaded model to enhance
        # the network. For example, the original saved model has linear,
        # and later after loading, a new linear is added. At this time,
        # there will be a problem of duplicate names, so the persistable
        # vars of the loaded model are renamed to unique names at the top
        # of this method (`_rename_var_program_desc`); here the
        # new name -> original name mapping is recorded.
        self._suffix_varname_dict = _get_loaded_var_new_old(program_desc,
                                                            rename_new_old_dict)

        # - get persistable var
        self._persistable_names = _get_persistable_var_names(program_desc)

        return program_desc

    @switch_to_static_graph
    def _append_scale_to_output(self, program):
        """
        Append a ``scale(x, 1.)`` op after every output var of ``program``
        and use the scale outputs as the new output descs.
        """
        # 1. append scale & save var
        scale_output_vars = []
        with framework.program_guard(program):
            for i, out in enumerate(self._output_descs):
                var = program.global_block().var(out.name())
                var = nn.scale(
                    var, 1., name="translated_layer/scale_{}".format(i))
                scale_output_vars.append(var)
        # 2. update output names & descs
        for i, var in enumerate(scale_output_vars):
            self._output_descs[i] = var.desc

    @switch_to_static_graph
    def _append_backward_desc(self, infer_program_desc):
        """
        Return a copy of ``infer_program_desc`` with every ``is_test``
        attribute set to False and backward ops appended for the outputs.
        """
        program_desc_copy = core.ProgramDesc(infer_program_desc)

        # 1. set all `is_test` attributes to False
        _change_is_test_status(program_desc_copy, False)

        # 2. prepare program and related var
        # NOTE: To reuse backward interfaces, build Program firstly.
        # Originally, there is no need to build a program, but need to almost
        # rewrite a series of methods for append_backward for program_desc.
        # Therefore, in order to reuse the method of backward.py, build the program here.
        program = _build_program_by_desc(program_desc_copy)

        # 3. Add the outputs which are only used for training and not saved in
        # inference program.
        for block_idx in six.moves.range(program.num_blocks):
            block = program.block(block_idx)
            for op in block.ops:
                if op.type == "batch_norm":
                    if "ReserveSpace" not in op.output_names or len(
                            op.output("ReserveSpace")) == 0:
                        reserve_space = block.create_var(
                            name=unique_name.generate_with_ignorable_key(
                                ".".join(["reserve_space", 'tmp'])),
                            dtype=block.var(op.input("X")[0]).dtype,
                            type=core.VarDesc.VarType.LOD_TENSOR,
                            persistable=False,
                            stop_gradient=True)
                        op.desc.set_output("ReserveSpace", [reserve_space.name])

        targets = []
        for out in self._output_descs:
            targets.append(program.global_block().var(out.name()))

        # 4. append backward
        backward.gradients(targets=targets, inputs=[])
        return program.desc


# [ TranslatedLayer : Run program in imperative mode ]
# 
# DESIGN IDEA: use a special operator `RunProgram` to execute the program inside the operator.
#
# Op's Inputs:
#   - the input variable of the user feed
#   - the necessary parameters of the network
# Op's Outputs:
#   - the output variable of fetch
# 
# This op receives a complete program desc, internally creates scope
# and executor, executes this program. Key points:
#
# 1. Data Sharing: 
#   The varBase of the dynamic graph is not in the scope, so before the op
#   executes the program internally, create persistent variables with the
#   same name as feed, parameters, and fetch in the scope, and share the
#   LoDTensor of the op input.
# 
# 2. Forward and Backward Separation:
#   Because the dynamic graph op performs the forward and backward separately,
#   in the forward op RunProgram, we only execute the forward part of whole program,
#   and in the backward op RunProgramGrad, we execute the backward part of program.
#   We cannot split the program into separate forward and backward programs,
#   because doing so would break the execution logic of some control flow ops.


# NOTE: [compatible] deal with models saved by save_inference_model,
# which need to get var info from the program desc
def _load_persistable_vars_by_program(model_path,
                                      program_holder,
                                      params_filename=None):
    """
    Create and load the persistable vars described by the program itself.

    Args:
        model_path(str): directory holding the saved parameter file(s).
        program_holder(_ProgramHolder): holder whose inference program
            describes the persistable vars (shape, dtype, type).
        params_filename(str|None): name of the combined parameter file. If
            None, each var is loaded from a separate file named after its
            original (pre-rename) name.

    Returns:
        dict: maps renamed var name to the created ParamBase / VarBase.
    """
    # make sure the path has been checked
    infer_program = program_holder.infer_program
    persistable_vars = _get_persistable_vars(infer_program)
    load_var_dict = {}
    # Names classified as parameters; cached so `_is_parameter` (a scan
    # over every op of the program) runs only once per var.
    param_names = set()
    for each_var in persistable_vars:
        var_name = each_var.name()
        orig_each_name = program_holder._suffix_varname_dict[var_name]
        if _is_parameter(each_var, infer_program):
            param_names.add(var_name)
            # create output varbase
            new_var = framework.ParamBase(
                shape=each_var.shape(),
                dtype=each_var.dtype(),
                name=var_name,
                type=each_var.type(),
                persistable=True)
        else:
            new_var = framework._varbase_creator(
                type=each_var.type(),
                name=var_name,
                shape=each_var.shape(),
                dtype=each_var.dtype(),
                persistable=True)
        if params_filename is None:
            # one file per var, named after the original var name
            framework._dygraph_tracer().trace_op(
                type='load',
                inputs={},
                outputs={'Out': new_var},
                attrs={'file_path': os.path.join(model_path, orig_each_name)})
        new_var.stop_gradient = False
        load_var_dict[var_name] = new_var

    if params_filename is not None:
        # vars are handed to load_combine ordered by their sorted original
        # names (presumably the order used when saving)
        dict_name_old_new = {
            v: k
            for k, v in program_holder._suffix_varname_dict.items()
        }
        load_var_list = [
            load_var_dict[dict_name_old_new[name]]
            for name in sorted(dict_name_old_new.keys())
        ]

        framework._dygraph_tracer().trace_op(
            type='load_combine',
            inputs={},
            outputs={'Out': load_var_list},
            attrs={'file_path': os.path.join(model_path, params_filename)})

        # NOTE(review): parameters are marked trainable again after tracing
        # load_combine; presumably tracing may reset stop_gradient on the
        # op outputs — confirm.
        for param_name in param_names:
            load_var_dict[param_name].stop_gradient = False

    # NOTE: [Recovery stop gradient information based on the program]
    # After loading the model, the stop_gradient information
    # of the original variable is lost, but if a parameter does not
    # have a corresponding @GRAD variable in the backward program,
    # it can be said that it is also stop_gradient
    all_var_names = _get_all_var_names(program_holder.train_program)
    for var_name in load_var_dict:
        grad_var_name = var_name + core.grad_var_suffix()
        if grad_var_name not in all_var_names:
            load_var_dict[var_name].stop_gradient = True

    return load_var_dict


538 539
def _load_persistable_vars(model_path, var_info_path, program_holder,
                           params_filename):
540 541
    # 1. load extra var info
    with open(var_info_path, 'rb') as f:
542
        extra_var_info = pickle.load(f)
543 544 545 546

    # 2. construct var dict
    load_var_dict = dict()
    load_var_list = []
547 548 549 550
    inv_suffix_varname_dict = {
        value: key
        for key, value in program_holder._suffix_varname_dict.items()
    }
551 552 553 554 555 556 557 558 559 560

    # NOTE(chenweihang): we need load persistable vars based the program,
    # because the program may be pruned when `save_inference_model`, some
    # var in `extra_var_info` may have been pruned 
    for name in sorted(inv_suffix_varname_dict):
        if name not in extra_var_info:
            raise RuntimeError(
                "The model to be loaded is not complete."
                "The variable `%s` of program cannot be found in loaded model.",
                name)
561 562
        # get suffix var name, see [why need to append suffix to persistable vars]
        new_name = inv_suffix_varname_dict[name]
563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579
        # create output varbase
        if extra_var_info[name].get('trainable', None) is not None:
            # use default shape and dtype
            new_var = framework.ParamBase(
                shape=[1],  # only to pass check, this shape is not meaningful
                dtype=core.VarDesc.VarType.FP32,
                name=new_name,
                persistable=True)
        else:
            new_var = framework._varbase_creator(
                name=new_name, persistable=True)

        new_var.stop_gradient = extra_var_info[name]['stop_gradient']
        load_var_dict[new_name] = new_var
        load_var_list.append(new_var)

    # 3. load all vars
580 581 582 583 584 585 586 587 588 589 590
    assert params_filename is not None, "params_filename should not be None."
    var_file_path = os.path.join(model_path, params_filename)
    if not os.path.exists(var_file_path):
        if len(extra_var_info) != 0:
            raise ValueError("The model to be loaded is incomplete.")
    else:
        framework._dygraph_tracer().trace_op(
            type='load_combine',
            inputs={},
            outputs={'Out': load_var_list},
            attrs={'file_path': var_file_path})
591 592 593 594

    return load_var_dict


595 596 597 598 599 600 601 602 603
# NOTE(chenweihang): to adapt paddle.load to get state_dict
def _remove_varname_suffix(var_dict, program_holder):
    no_suffix_var_dict = dict()
    for var_name in var_dict:
        no_suffix_name = program_holder._suffix_varname_dict[var_name]
        no_suffix_var_dict[no_suffix_name] = var_dict[var_name]
    return no_suffix_var_dict


604 605 606 607 608 609 610 611
def _construct_program_holders(model_path, model_filename=None):
    # make sure the path has been checked
    program_holder_dict = dict()

    if model_filename is not None:
        # [compatible] if assign model_filename, only can load one program as Layer.forward
        model_filename = os.path.basename(model_filename)
        model_file_path = os.path.join(model_path, model_filename)
612 613 614 615 616 617 618 619
        model_name = model_filename[:-len(INFER_MODEL_SUFFIX)]
        #Load every file that meets the requirements in the directory model_path.
        for filename in os.listdir(model_path):
            if model_filename == filename:
                func_name = 'forward'
                model_file_path = os.path.join(model_path, model_filename)
            elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
                    model_name):
620 621 622 623 624 625 626
                parsing_names = filename[len(model_name):-len(
                    INFER_MODEL_SUFFIX) + 1].split('.')
                if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
                    func_name = parsing_names[1]
                    model_file_path = os.path.join(model_path, filename)
                else:
                    continue
627 628 629 630
            else:
                continue
            program_holder_dict[func_name] = _ProgramHolder(
                _load_program_desc(model_file_path))
631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648
    else:
        for _, _, file_names in os.walk(model_path):
            for name in file_names:
                if 'model' in name:
                    model_file_path = os.path.join(model_path, name)
                    method_name = name.strip('_')
                    if method_name == 'model':
                        method_name = 'forward'
                    else:
                        method_name.replace('model', '')
                    program_holder_dict[method_name] = _ProgramHolder(
                        _load_program_desc(model_file_path))

    return program_holder_dict


def _construct_params_and_buffers(model_path,
                                  programs,
649 650
                                  params_filename=None,
                                  append_suffix=True):
651 652
    var_info_filename = str(params_filename) + ".info"
    var_info_path = os.path.join(model_path, var_info_filename)
653
    params_path = os.path.join(model_path, str(params_filename))
654

655 656
    if os.path.exists(var_info_path):
        var_dict = _load_persistable_vars(model_path, var_info_path,
657
                                          programs['forward'], params_filename)
658 659 660
        model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)]
        #Load every file that meets the requirements in the directory model_path.
        for file_name in os.listdir(model_path):
661 662 663 664 665 666 667 668
            if file_name.startswith(model_name) and file_name.endswith(
                    INFER_PARAMS_SUFFIX):
                parsing_names = file_name[len(model_name):-len(
                    INFER_PARAMS_SUFFIX) + 1].split('.')
                if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
                    func_name = parsing_names[1]
                else:
                    continue
669 670 671 672 673 674
            else:
                continue
            var_info_path = os.path.join(model_path, var_info_filename)
            var_dict.update(
                _load_persistable_vars(model_path, var_info_path, programs[
                    func_name], file_name))
675 676 677
    elif params_filename is not None and not os.path.exists(params_path):
        # When saving XX, there is only '*.pdmodel'
        return dict()
678 679 680
    else:
        var_dict = _load_persistable_vars_by_program(
            model_path, programs['forward'], params_filename)
681 682 683 684

    if not append_suffix:
        var_dict = _remove_varname_suffix(var_dict, programs['forward'])

685 686 687
    return var_dict


def _run_dygraph(instance, input, program_holder):

    # 1. prepare inputs, outputs, attrs
    input_vars = []
    for i, value in enumerate(input):
        if not isinstance(value, (np.ndarray, core.VarBase)):
            raise TypeError(
                "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s."
                % type(value))
        # NOTE: In order to unify the API, firstly convert the input to VarBase
        if isinstance(value, np.ndarray):
            var = core.VarBase(
                value=value,
                name=program_holder.input_descs[i].name(),
                persistable=False,
                place=framework._current_expected_place(),
                zero_copy=True)
        else:
            var = value
            # NOTE: we changed var name here, 
            # but it may be an important name set by user
            var.name = program_holder.input_descs[i].name()
        input_vars.append(var)
    if instance._input_args_names is None:
        instance._input_args_names = [
            ins.name() for ins in program_holder.input_descs
        ]

    persistable_vars = []
    for var_name in program_holder.persistable_names:
        dy_var_name = instance._persistable_var_name_dict[var_name]
        if dy_var_name in instance._parameters:
            persistable_vars.append(instance._parameters[dy_var_name])
        elif dy_var_name in instance._buffers:
            persistable_vars.append(instance._buffers[dy_var_name])
        else:
            raise ValueError(
                "The persistable variable %s does not exist in current TranslatedLayer."
                % var_name)

    output_vars = []
    for var_desc in program_holder.output_descs:
        var = core.VarBase(var_desc.dtype(),
                           var_desc.shape(),
                           var_desc.name(), var_desc.type(), False)
        output_vars.append(var)

    # hold forward variables
    tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                 "program_out_scope",
                                 core.VarDesc.VarType.STEP_SCOPES, True)
    tmp_scope_vec.value().set_scope(program_holder.scope)

    # 2. run program by op
    trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program
    end_op_index = program_holder.infer_program.block(0).op_size()
    framework._dygraph_tracer().trace_op(
        type='run_program',
        inputs={'X': input_vars,
                'Params': persistable_vars},
        outputs={'Out': output_vars,
                 'OutScope': tmp_scope_vec},
        attrs={
            'global_block': trace_program.block(0),
            'start_op_index': 0,
            'end_op_index': end_op_index,
            'is_test': instance._is_test
        })
    # NOTE: [ why need set param's gradient type here ]
    # if user set sparse gradient mode, the param's gradient
    # will be SelectedRows, not LoDTensor. But tracer will just
    # set param grad VarBase by forward VarBase(LoDTensor)
    # If we don't change grad_var type here, RunProgramOp need
    # transform SelectedRows to LoDTensor forcibly, it may not
    # be user wanted result.
    for persistable_var in persistable_vars:
        grad_var_name = var.name + core.grad_var_suffix()
        grad_var = trace_program.block(0).find_var(cpt.to_bytes(grad_var_name))
        # NOTE: cannot find var desc maybe not problem, 
        # such as in batch_norm
        if grad_var is None:
            continue
        persistable_var._set_grad_type(grad_var.type())

    # 3. prepare output, keep same form with inputs
    outs = output_vars
    if len(output_vars) == 1:
        outs = output_vars[0]
    return outs


def _run_static_graph(input, program_holder, trace_program):
    """
    Append the program of a TranslatedLayer method to the current default
    main program (static graph mode) and return its output variables.

    Args:
        input(tuple|list): input Variables of the method call; each is
            assigned to the corresponding program input.
        program_holder(_ProgramHolder): holds the input/output descs of the
            method.
        trace_program(ProgramDesc): the program desc to append. Its
            non-persistable variables are renamed in place before appending.

    Returns:
        Variable|list[Variable]: the single output Variable when the program
        has exactly one output, otherwise the list of output Variables.
    """
    main_program = framework.default_main_program()
    # Persistable variables (parameters) keep their names; every other
    # variable is renamed - presumably so it cannot clash with names already
    # present in the main program.
    param_var_names = _get_persistable_var_names(trace_program)
    _, dict_rename_var_old_new = _rename_var_program_desc(
        trace_program, exclude=param_var_names)
    trace_program.flush()
    # append blocks from 'trace_program'
    _append_block(main_program, trace_program, program_holder, input,
                  dict_rename_var_old_new)
    main_program._sync_with_cpp()
    outs = _get_output_from_program(main_program, program_holder,
                                    dict_rename_var_old_new)
    if len(outs) == 1:
        outs = outs[0]
    return outs


def _collect_current_and_parent_var(program, block_idx):
    '''
    Get variables in current block and its parent block.
    
    Args:
        program(Program): The program containing the current block.
        block_idx(int): index of current block.

    Returns:
        List: list of variables.
    '''
    vars = []
    if block_idx < 0:
        return vars
    for var in program.block(block_idx).vars:
        vars.append(var)
    parent_idx = program.block(block_idx).parent_idx
    if parent_idx > -1:
        vars += _collect_current_and_parent_var(program, parent_idx)
    return vars


def _append_block(dest_program,
                  src_program_desc,
                  program_holder,
                  input_variables,
                  dict_rename_var_old_new=None):
    '''
    Append Variables and Operators in 'src_program_desc' to dest_program.

    Block 0 of 'src_program_desc' is merged into the current block of
    'dest_program'. Every other source block ``k`` is appended as a new block
    whose index is ``offset + k``. Here ``offset`` is the index of the last
    block of 'dest_program' before any new block is created. Ops carrying a
    'sub_block' attribute are re-pointed to the block at that shifted index.

    Args:
        dest_program(Program): Variables and Operators are appended to it.
        src_program_desc(ProgramDesc): Variables in it will be appended to 'dest_program'.
        program_holder(_ProgramHolder): program_holder of TranslatedLayer
        input_variables(list): list of input variables
        dict_rename_var_old_new(None|dict): When using '_rename_var_program_desc',
            use it to map the name of the variable before it was modified and the new name.

    Raises:
        ValueError: if the number of input variables differs from the number
            of input descs in 'program_holder'.
        TypeError: if an appended op's 'sub_block' attribute is neither a
            BlockDesc nor a Block.
    '''

    origin_block_idx = dest_program.current_block_idx
    # Variables already visible from the current block (its own and its
    # ancestors') are not re-created.
    param_var_names = _collect_current_and_parent_var(dest_program,
                                                      origin_block_idx)
    append_var_from_block_desc_static(
        dest_program.block(origin_block_idx),
        src_program_desc.block(0),
        exclude=param_var_names)

    name_inp_desc = [inp.name() for inp in program_holder.input_descs]
    input_names = [inp.name for inp in input_variables]
    if len(name_inp_desc) != len(input_names):
        raise ValueError(
            "The number of input is invalid, expected {}, but received {}.".
            format(len(name_inp_desc), len(input_names)))
    # Feed each caller-provided variable into the program's (possibly
    # renamed) input variable.
    for i, out_name in enumerate(name_inp_desc):
        if dict_rename_var_old_new:
            out_name = dict_rename_var_old_new[out_name]
        dest_program.block(origin_block_idx).append_op(
            type='assign',
            inputs={'X': [input_names[i]]},
            outputs={'Out': [out_name]})

    append_ops = append_op_from_block_desc_static(
        dest_program.block(origin_block_idx), src_program_desc.block(0))
    dest_program._sync_with_cpp()

    offset_block_idx = dest_program.num_blocks - 1

    for src_block_idx in range(1, src_program_desc.num_blocks()):
        src_block = src_program_desc.block(src_block_idx)
        src_parent_idx = src_block.parent
        if src_parent_idx > 0:
            # Source block p (p > 0) is appended as dest block offset + p.
            parent_idx = offset_block_idx + src_parent_idx
        else:
            # Source block 0 was merged into the original current block.
            parent_idx = origin_block_idx
        dest_block = dest_program._create_block(parent_idx=parent_idx)
        append_var_from_block_desc_static(
            dest_block, src_block, exclude=param_var_names)
        append_ops += append_op_from_block_desc_static(dest_block,
                                                       src_block)

    dest_program._sync_with_cpp()
    for op in append_ops:
        if op.has_attr('sub_block'):
            sub = op.attr('sub_block')
            if isinstance(sub, framework.core.BlockDesc):
                origin_id = sub.id
            elif isinstance(sub, framework.Block):
                origin_id = sub.idx
            else:
                raise TypeError(
                    "Unsupported type of 'sub_block' attribute: %s" %
                    type(sub))
            op._set_attr('sub_block',
                         dest_program.block(offset_block_idx + origin_id))
    dest_program._sync_with_cpp()
    dest_program.current_block_idx = origin_block_idx


def _get_output_from_program(program,
                             program_holder,
                             dict_rename_var_old_new=None):
    """
    Get output name of 'program' according to program_holder
    """
    outs = list()
    for var in program_holder.output_descs:
        for idx in range(program.num_blocks):
            vars = program.block(idx).vars
            var_name = var.name()
            if dict_rename_var_old_new:
                var_name = dict_rename_var_old_new[var_name]
            if var_name in vars:
                out = vars[var_name]
                if out not in outs:
                    outs.append(out)
    return outs


def append_op_from_block_desc_static(block, src_block_desc):
    """
    Copy every Operator of 'src_block_desc', in order, into 'block'.

    Args:
        block(Block): the block that receives the copied OPs.
        src_block_desc(BlockDesc): the block desc whose OPs are copied.

    Returns:
        List: the Operators appended to 'block', in source order.
    """
    return [
        append_op_from_desc_static(block, src_block_desc.op(op_idx))
        for op_idx in range(src_block_desc.op_size())
    ]


def append_op_from_desc_static(block, op_desc):
    """
    Create an Operator in 'block' that is a copy of 'op_desc'.

    A new OpDesc is appended to 'block.desc' and filled from 'op_desc'. It is
    then wrapped in a Python-side ``framework.Operator``, which is registered
    in 'block.ops'.

    Args:
        block(Block): the block that receives the new OP.
        op_desc(OpDesc): the desc the new OP is copied from.

    Returns:
        Operator: OP appended to 'block'.
    """
    type_name = op_desc.type()
    desc_copy = block.desc.append_op()
    desc_copy.copy_from(op_desc)
    new_op = framework.Operator(block=block,
                                desc=desc_copy,
                                type=type_name,
                                inputs=None,
                                outputs=None,
                                attrs=None)
    block.ops.append(new_op)
    return new_op


def append_var_from_block_desc_static(block,
                                      src_block_desc,
                                      include=None,
                                      exclude=None):
    """
    Create in 'block' the variables of 'src_block_desc' it does not have yet.

    A variable is skipped when 'block' already has a variable of that name,
    when 'include' is given and does not list it, or when 'exclude' is given
    and lists it. dtype and shape are copied only for SELECTED_ROWS,
    LOD_TENSOR and LOD_TENSOR_ARRAY variables. lod_level is copied only for
    LOD_TENSOR and LOD_TENSOR_ARRAY; otherwise these arguments are None.

    Args:
        block(Block): append Variables of 'src_block_desc' to it.
        src_block_desc(BlockDesc): the block desc whose variables are copied.
        include(List): if not None, only names in it may be appended.
        exclude(List): if not None, names in it are never appended.

    Returns:
        List: list of the variables that are appended to 'block'.
    """
    var_types = core.VarDesc.VarType
    appended = []
    for src_var in src_block_desc.all_vars():
        name = src_var.name()
        wanted = (include is None or name in include) and (
            exclude is None or name not in exclude)
        if block.has_var(name) or not wanted:
            continue

        kind = src_var.type()
        has_dtype_and_shape = kind in (var_types.SELECTED_ROWS,
                                       var_types.LOD_TENSOR,
                                       var_types.LOD_TENSOR_ARRAY)
        has_lod = kind in (var_types.LOD_TENSOR, var_types.LOD_TENSOR_ARRAY)

        new_var = block.create_var(
            name=name,
            dtype=src_var.dtype() if has_dtype_and_shape else None,
            type=kind,
            shape=src_var.shape() if has_dtype_and_shape else None,
            lod_level=src_var.lod_level() if has_lod else None,
            persistable=src_var.persistable(),
            set_need_check_feed=src_var.need_check_feed())
        appended.append(new_var)
    return appended


class TranslatedLayer(layers.Layer):
    """
    TranslatedLayer is a ``paddle.nn.Layer`` for holding the model
    loaded by :ref:`api_paddle_jit_load` . It can be used like a
    general Layer object in eval or train mode.

    .. note::
        The TranslatedLayer objects should not be created by constructor, it only can be loaded and constructed by :ref:`api_paddle_jit_load` .

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            # 1. train & save model.

            # create network
            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # train
            train(layer, loader, loss_fn, adam)

            # save
            model_path = "linear.example.model"
            paddle.jit.save(layer, model_path)

            # 2. load model as TranslatedLayer

            # load
            translated_layer = paddle.jit.load(model_path)

            # inference
            translated_layer.eval()
            x = paddle.randn([1, IMAGE_SIZE], 'float32')
            pred = translated_layer(x)

            # fine-tune
            translated_layer.train()
            adam = opt.Adam(learning_rate=0.001, parameters=translated_layer.parameters())
            train(translated_layer, loader, loss_fn, adam)

    """

    def __init__(self, programs, persistable_vars):
        super(TranslatedLayer, self).__init__()

        if not isinstance(programs, dict):
            raise TypeError(
                "TranslatedLayer need to use _ProgramHolder's dict for initialization."
            )
        if not isinstance(persistable_vars, dict):
            raise TypeError(
                "TranslatedLayer need to use persistable variable dict for initialization."
            )

        # method name -> _ProgramHolder
        self._program_holder_dict = programs

        # NOTE(chenweihang): [ why not use var name directly? ]
        # When add parameter or buffer to Layer by follow apis,
        # the variable name can't contain `.`, because which may cause
        # AttributeError when access the newly added parameter or buffer
        # in the form of `self.**.**``, but the ParamBase or VarBase
        # name contains `.` originally, such as `linear_0.w_0`, so here
        # need to generate new var name for each var
        self._persistable_var_name_dict = dict()
        # the var names held by this TranslatedLayer object are counted from 0
        with unique_name.guard():
            for name, var in persistable_vars.items():
                if isinstance(var, framework.ParamBase):
                    dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
                    self._persistable_var_name_dict[name] = dy_name
                    self.add_parameter(dy_name, var)
                elif isinstance(var, core.VarBase):
                    dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
                    self._persistable_var_name_dict[name] = dy_name
                    self.register_buffer(dy_name, var)
                else:
                    raise TypeError(
                        "Adding persistent variable which  to layer is not supported now"
                    )

        # True: run the infer program; False: run the train program
        self._is_test = True
        # input variable names, filled from the first program holder's input descs
        self._input_args_names = None

    @staticmethod
    @framework.dygraph_only
    def _construct(model_path, configs=None):
        # 0. dir and filename check
        model_path = os.path.normpath(model_path)
        if not os.path.isdir(model_path):
            raise ValueError("There is no directory named '%s'" % model_path)
        model_filename = None
        params_filename = None
        if configs is not None:
            model_filename = configs.model_filename
            params_filename = configs.params_filename

        # 1. load program desc & construct _ProgramHolder
        programs = _construct_program_holders(model_path, model_filename)

        # 2. load layer parameters & buffers
        persistable_vars = _construct_params_and_buffers(model_path, programs,
                                                         params_filename)

        # 3. construct TranslatedLayer object
        translated_layer = TranslatedLayer(programs, persistable_vars)

        # 4. create TranslatedLayer's execution method
        for method_name, program_holder in programs.items():
            if translated_layer._input_args_names is None:
                translated_layer._input_args_names = [
                    ins.name() for ins in program_holder.input_descs
                ]
            # NOTE: the method is set on the class; each call looks up its
            # program holder on the instance by method name.
            setattr(TranslatedLayer, method_name,
                    TranslatedLayer._execution_method_creator(method_name,
                                                              program_holder))

        # 5. set TranslatedLayer's default mode to eval
        translated_layer.eval()

        return translated_layer

    @staticmethod
    def _execution_method_creator(method_name, program_holder):
        def __i_m_p_l__(self, *input):
            program_holder = self._program_holder_dict[__i_m_p_l__.__name__]
            # When using jit.save, it runs in static graph mode.
            # Run in dynamic graph mode when the model is inferring.
            if in_dygraph_mode():
                return _run_dygraph(self, input, program_holder)
            else:
                # NOTE(weixin): [ why not use 'program_holder.infer_program' directly? ]
                # When use '_run_static_graph(input, program_holder, program_holder.infer_program)',
                # because '_run_static_graph' modifies 'ProgramDesc', 'OpDesc.op_size()' will return a very large wrong number.
                # A Segmentation fault error may occur if used 'p=ProgramDesc(program_holder.infer_program)'.
                p = framework.Program._construct_from_desc(
                    core.ProgramDesc(program_holder.infer_program))
                return _run_static_graph(input, program_holder, p.desc)

        __i_m_p_l__.__name__ = method_name
        return __i_m_p_l__

    def train(self):
        """Switch to train mode: dygraph calls run the train program."""
        self._is_test = False
        # keep the ``training`` flag inherited from layers.Layer in sync,
        # since this override does not call the base implementation
        self.training = True

    def eval(self):
        """Switch to eval mode: dygraph calls run the infer program."""
        self._is_test = True
        # keep the ``training`` flag inherited from layers.Layer in sync,
        # since this override does not call the base implementation
        self.training = False

    def program(self, method_name='forward'):
        """
        Gets translated program of specified method.

        Args:
            - method_name (string): method name corresponding to the program
                to be obtained. Default: 'forward'.

        Returns:
            Program

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn as nn
                import paddle.optimizer as opt

                BATCH_SIZE = 16
                BATCH_NUM = 4
                EPOCH_NUM = 4

                IMAGE_SIZE = 784
                CLASS_NUM = 10

                # define a random dataset
                class RandomDataset(paddle.io.Dataset):
                    def __init__(self, num_samples):
                        self.num_samples = num_samples

                    def __getitem__(self, idx):
                        image = np.random.random([IMAGE_SIZE]).astype('float32')
                        label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
                        return image, label

                    def __len__(self):
                        return self.num_samples

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                    @paddle.jit.to_static
                    def forward(self, x):
                        return self._linear(x)

                def train(layer, loader, loss_fn, opt):
                    for epoch_id in range(EPOCH_NUM):
                        for batch_id, (image, label) in enumerate(loader()):
                            out = layer(image)
                            loss = loss_fn(out, label)
                            loss.backward()
                            opt.step()
                            opt.clear_grad()
                            print("Epoch {} batch {}: loss = {}".format(
                                epoch_id, batch_id, np.mean(loss.numpy())))

                # create network
                layer = LinearNet()
                loss_fn = nn.CrossEntropyLoss()
                adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

                # create data loader
                dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
                loader = paddle.io.DataLoader(dataset,
                    batch_size=BATCH_SIZE,
                    shuffle=True,
                    drop_last=True,
                    num_workers=2)

                # train
                train(layer, loader, loss_fn, adam)

                # save
                model_path = "linear.example.model"
                paddle.jit.save(layer, model_path)

                # load
                translated_layer = paddle.jit.load(model_path)

                # get program
                program = translated_layer.program()
        """
        # 1. get program holder
        program_holder = self._get_program_holder(method_name)

        # 2. get inference program desc
        program_desc = program_holder.infer_program

        # 3. construct program
        program = _build_program_by_desc(program_desc)
        return program

    def _get_program_holder(self, method_name='forward'):
        program_holder = self._program_holder_dict.get(method_name, None)
        if program_holder is None:
            raise ValueError(
                "The method `%s` does not exist in loaded TranslatedLayer." %
                method_name)
        return program_holder

    def _input_spec(self, method_name='forward'):
        # 1. get program holder
        program_holder = self._get_program_holder(method_name)

        # 2. build input spec by input desc
        input_spec = []
        for var_desc in program_holder.input_descs:
            spec = paddle.static.InputSpec(
                shape=var_desc.shape(),
                dtype=var_desc.dtype(),
                name=var_desc.name())
            input_spec.append(spec)

        return input_spec

    def _output_spec(self, method_name='forward'):
        # 1. get program holder
        program_holder = self._get_program_holder(method_name)

        # 2. build output spec by output desc
        output_spec = []
        for var_desc in program_holder.output_descs:
            # NOTE(chenweihang): InputSpec describes a tensor, not just input.
            # Maybe the name is not good enough. Here we use InputSpec to
            # construct the description of Output tensor
            spec = paddle.static.InputSpec(
                shape=var_desc.shape(),
                dtype=var_desc.dtype(),
                name=var_desc.name())
            output_spec.append(spec)

        return output_spec