io.py 40.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import six
import pickle
import numpy as np

22
import paddle
23 24 25 26
from paddle import compat as cpt
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid import backward
27
from paddle.fluid import unique_name
28 29 30 31 32 33
from paddle.fluid.dygraph import layers
from paddle.fluid.layers import nn
from paddle.fluid.dygraph.base import switch_to_static_graph

# Public API of this module: only TranslatedLayer is exported by
# ``from ... import *``; everything else here is an internal helper.
__all__ = ['TranslatedLayer']

34 35 36 37
# File suffixes of a saved inference model as read by this module:
# the serialized program, the (combined) parameters file, and the pickled
# extra variable-info file that sits next to the parameters file
# (i.e. ``<params file> + ".info"``).
INFER_MODEL_SUFFIX = ".pdmodel"
INFER_PARAMS_SUFFIX = ".pdiparams"
INFER_PARAMS_INFO_SUFFIX = ".pdiparams.info"

# Naming pieces for variables created while loading a saved model.
# LOADED_VAR_SUFFIX is appended by _append_loaded_suffix ("x" ==> "x.load_0").
# NOTE(review): PARAMETER_NAME_PREFIX / BUFFER_NAME_PREFIX are presumably used
# to name TranslatedLayer parameters/buffers; their usage is not visible here.
LOADED_VAR_SUFFIX = "load"
PARAMETER_NAME_PREFIX = "param"
BUFFER_NAME_PREFIX = "buffer"
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116


def _load_program_desc(model_file_path):
    """
    Deserialize a ``core.ProgramDesc`` from the binary file at
    ``model_file_path``.

    Raises:
        ValueError: if the program version stored in the file is not
            supported by the running core.
    """
    with open(model_file_path, "rb") as model_file:
        desc = core.ProgramDesc(model_file.read())

    version = desc._version()
    if core._is_program_version_supported(version):
        return desc
    raise ValueError("Unsupported program version: %d\n" % version)


def _is_persistable(var_desc):
    """
    Return whether ``var_desc`` counts as a persistable variable.

    Feed-minibatch, fetch-list, reader and raw variables are never counted,
    whatever their ``persistable`` flag says; for every other type the
    variable's own ``persistable()`` flag decides.
    """
    never_persistable = (core.VarDesc.VarType.FEED_MINIBATCH,
                         core.VarDesc.VarType.FETCH_LIST,
                         core.VarDesc.VarType.READER,
                         core.VarDesc.VarType.RAW)
    var_type = var_desc.type()
    if var_type in never_persistable:
        return False
    return var_desc.persistable()


def _is_parameter(persistable_var_desc, program_desc):
    """
    Decide whether a persistable variable is a parameter of ``program_desc``.

    The variable is rejected as soon as some op (in any block) lists it as an
    output without also listing it as an input. Ops that both read and write
    it (such as batch_norm_op) do not disqualify it.

    Note: no op is actually required to read the variable; a persistable
    variable that is never written is reported as a parameter.
    """
    var_name = persistable_var_desc.name()
    for block_idx in six.moves.range(program_desc.num_blocks()):
        block = program_desc.block(block_idx)
        for op_idx in six.moves.range(block.op_size()):
            op = block.op(op_idx)
            if var_name not in op.output_arg_names():
                continue
            # written by this op: only acceptable if the op also reads it
            if var_name not in op.input_arg_names():
                return False
    return True


def _get_persistable_vars(program_desc):
    """
    Collect the persistable VarDescs (see ``_is_persistable``) of every block
    of ``program_desc``, block by block in block-index order.
    """
    return [
        var_desc
        for block_idx in six.moves.range(program_desc.num_blocks())
        for var_desc in program_desc.block(block_idx).all_vars()
        if _is_persistable(var_desc)
    ]


def _get_persistable_var_names(program_desc):
    """
    Get all persistable variable names in ProgramDesc, in the same order in
    which ``_get_persistable_vars`` returns the variables.
    """
    return [var_desc.name() for var_desc in _get_persistable_vars(program_desc)]


def _get_all_var_names(program_desc):
    """
    Return the set of names of every variable declared in any block of
    ``program_desc``.
    """
    return {
        var_desc.name()
        for block_idx in six.moves.range(program_desc.num_blocks())
        for var_desc in program_desc.block(block_idx).all_vars()
    }


117
@switch_to_static_graph
def _append_loaded_suffix(name):
    """
    Append loaded suffix to the given variable name

    e.g. x ==> x.load_0, x.load_0 ==> x.load_0.load_0

    The numeric tail comes from ``unique_name.generate_with_ignorable_key``
    applied to ``"<name>.<LOADED_VAR_SUFFIX>"``.

    Args:
        name: variable name; converted to text with ``cpt.to_text`` first.

    Returns:
        the generated unique name.
    """
    suffix = LOADED_VAR_SUFFIX
    name = cpt.to_text(name)
    new_name = unique_name.generate_with_ignorable_key('.'.join((name, suffix)))
    return new_name
127 128


129 130 131
@switch_to_static_graph
def _generate_unique_var_name(prefix):
    """
    Return a unique name derived from ``prefix``, produced by
    ``unique_name.generate_with_ignorable_key`` in static graph mode.
    """
    unique = unique_name.generate_with_ignorable_key(prefix)
    return unique
132 133 134


def _append_loaded_suffix_to_var(program_desc):
    """
    Rename every persistable variable of ``program_desc`` in place by
    appending the loaded suffix (see ``_append_loaded_suffix``).

    The variable desc itself is renamed, and in every block the variable
    entry and all op inputs/outputs that use the old name are renamed too.

    Returns:
        dict: new (suffixed) name -> original name.
    """
    suffix_varname_dict = dict()
    persistable_vars = _get_persistable_vars(program_desc)
    for var_desc in persistable_vars:
        old_name = var_desc.name()
        new_name = _append_loaded_suffix(old_name)
        suffix_varname_dict[new_name] = old_name
        var_desc.set_name(new_name)
        for block_idx in six.moves.range(program_desc.num_blocks()):
            block = program_desc.block(block_idx)
            block._rename_var(cpt.to_bytes(old_name), cpt.to_bytes(new_name))
            for op_idx in six.moves.range(block.op_size()):
                op = block.op(op_idx)
                op._rename_input(old_name, new_name)
                op._rename_output(old_name, new_name)
    return suffix_varname_dict
150 151


152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
@switch_to_static_graph
def _generate_unique_var_name_sync_with_main_program(prefix):
    """
    Return a unique name derived from ``prefix`` using
    ``unique_name.generate`` (rather than ``generate_with_ignorable_key``),
    evaluated in static graph mode.
    """
    generated = unique_name.generate(prefix)
    return generated


def _get_loaded_var_new_old(program_desc, all_new_old_dict_all):
    """
    Restrict the new-name -> old-name mapping ``all_new_old_dict_all`` to the
    persistable variables of ``program_desc``.

    Raises:
        KeyError: if a persistable variable's name is not a key of
            ``all_new_old_dict_all``.
    """
    return {
        var_desc.name(): all_new_old_dict_all[var_desc.name()]
        for var_desc in _get_persistable_vars(program_desc)
    }


def _rename_var_program_desc(program_desc):
    """
    Change the name of the loaded persistable variables. Use
    'unique_name.generate' to avoid duplication
    e.g. x ==> x_0, x_0 ==> x_1

    For every persistable variable, a trailing ``_<number>`` segment is
    stripped and a fresh name is generated from the remaining prefix. A
    generated name is rejected, and another one generated, while it equals
    the name of any *other* variable declared anywhere in the program. Every
    op input/output that refers to a renamed variable is updated, then the
    program desc is flushed.

    Args:
        program_desc: the ProgramDesc to rename in place.

    Returns:
        tuple(dict, dict): ``(new name -> old name, old name -> new name)``
            for every persistable variable.
    """
    import collections

    dict_rename_var_old_new = dict()
    dict_rename_var_new_old = dict()

    # Names of all variables before renaming. A name may be declared in more
    # than one block, so count occurrences: when checking a candidate name we
    # must discount exactly one declaration, the variable being renamed.
    # (The previous index-based exclusion used a per-block index into this
    # program-wide list, which was wrong for every block but the first.)
    old_name_counts = collections.Counter(
        var.name()
        for b_idx in six.moves.range(program_desc.num_blocks())
        for var in program_desc.block(b_idx).all_vars())

    for b_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(b_idx)
        for var in cur_block.all_vars():
            if not _is_persistable(var):
                continue
            name_old = var.name()
            # drop a trailing numeric segment: "fc_0.w_0" ==> "fc_0.w"
            head, sep, tail = name_old.rpartition('_')
            prefix = head if sep and tail.isnumeric() else name_old
            while True:
                name_new = _generate_unique_var_name_sync_with_main_program(
                    prefix)
                own_declaration = 1 if name_new == name_old else 0
                if old_name_counts[name_new] - own_declaration <= 0:
                    break
            if name_old != name_new:
                cur_block._rename_var(
                    cpt.to_bytes(name_old), cpt.to_bytes(name_new))
            dict_rename_var_old_new[name_old] = name_new
            dict_rename_var_new_old[name_new] = name_old

    # propagate the renames to every op that reads or writes a renamed var
    for b_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(b_idx)
        for op_idx in six.moves.range(cur_block.op_size()):
            op = cur_block.op(op_idx)
            for input_arg_name in op.input_arg_names():
                new_name = dict_rename_var_old_new.get(input_arg_name)
                if new_name is not None and new_name != input_arg_name:
                    op._rename_input(input_arg_name, new_name)
            for output_arg_name in op.output_arg_names():
                new_name = dict_rename_var_old_new.get(output_arg_name)
                if new_name is not None and new_name != output_arg_name:
                    op._rename_output(output_arg_name, new_name)
    program_desc.flush()
    return dict_rename_var_new_old, dict_rename_var_old_new


225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261
@switch_to_static_graph
def _build_program_by_desc(program_desc):
    """
    Wrap ``program_desc`` in a new ``framework.Program``.

    One Python ``framework.Block`` is created per block of the desc, and the
    Python-side objects are then synced with the C++ desc.
    """
    program = framework.Program()
    program.desc = program_desc
    blocks = []
    for block_idx in six.moves.range(program.desc.num_blocks()):
        blocks.append(framework.Block(program, block_idx))
    program.blocks = blocks
    program._sync_with_cpp()
    return program


def _change_is_test_status(program_desc, is_test):
    """
    Set the ``is_test`` attribute to ``is_test`` on every op, in every block,
    that has such an attribute. Ops without ``is_test`` are left untouched.
    """
    for block_idx in six.moves.range(program_desc.num_blocks()):
        block = program_desc.block(block_idx)
        ops = (block.op(op_idx) for op_idx in six.moves.range(block.op_size()))
        for op in ops:
            if not op.has_attr('is_test'):
                continue
            op._set_attr('is_test', is_test)


class _ProgramHolder(object):
    """
    Holds the execution information of a Program.

    _ProgramHolder is the execution unit of TranslatedLayer,
    if TranslatedLayer contains multiple _ProgramHolder,
    it can execute multiple methods

    _ProgramHolder is an internal concept.

    The ProgramDesc given to the constructor is preprocessed in place into the
    forward (inference) program. A copy of it with backward ops appended is
    kept as the training program.
    """

    def __init__(self, program_desc):
        super(_ProgramHolder, self).__init__()

        # input, output, persistable var info
        self._input_descs = []
        self._output_descs = []
        self._persistable_names = []

        # execution scope
        self._inner_scope = core.Scope()

        # renamed persistable var name -> var name in the saved model;
        # filled in by _preprocess
        self._suffix_varname_dict = None

        # forward program
        self._infer_program_desc = self._preprocess(program_desc)
        # forward + backward program
        self._train_program_desc = self._append_backward_desc(
            self._infer_program_desc)

    @property
    def infer_program(self):
        """ProgramDesc of the forward (inference) program."""
        return self._infer_program_desc

    @property
    def train_program(self):
        """ProgramDesc of the forward + backward (training) program."""
        return self._train_program_desc

    @property
    def input_descs(self):
        """VarDescs of the inputs, taken from the removed feed ops."""
        return self._input_descs

    @property
    def output_descs(self):
        """VarDescs of the outputs (the outputs of the appended scale ops)."""
        return self._output_descs

    @property
    def persistable_names(self):
        """Names of all persistable variables of the forward program."""
        return self._persistable_names

    @property
    def scope(self):
        """The execution scope owned by this holder."""
        return self._inner_scope

    def _preprocess(self, program_desc):
        """
        Turn a saved inference ProgramDesc into the forward program, in place.

        Steps: rename the persistable vars, strip the feed/fetch (and
        ``save_infer_model/scale_``) ops while recording input/output descs,
        append a scale op to every output, and record the persistable var
        names together with their names in the saved model.

        Returns:
            the same ``program_desc`` object, modified.
        """
        # 0. rename variables of 'program_desc'
        rename_new_old_dict, _ = _rename_var_program_desc(program_desc)

        # 1. Prune original program
        # remove feed, fetch and scale-1 op, remove op_callstack attr
        ops_to_remove = []
        root_block = program_desc.block(0)
        for i in six.moves.range(root_block.op_size()):
            op = root_block.op(i)
            if op.type() == 'feed':
                ops_to_remove.append(i)
                feed_var_name = cpt.to_bytes(op.input('X')[0])
                root_block._remove_var(feed_var_name)
                self._input_descs.append(
                    root_block.find_var(cpt.to_bytes(op.output('Out')[0])))
            elif op.type() == 'scale' and op.output('Out')[0].startswith(
                    'save_infer_model/scale_'):
                ops_to_remove.append(i)
                out_var_name = cpt.to_bytes(op.output('Out')[0])
                root_block._remove_var(out_var_name)
                self._output_descs.append(
                    root_block.find_var(cpt.to_bytes(op.input('X')[0])))
            elif op.type() == 'fetch':
                ops_to_remove.append(i)
                fetch_var_name = cpt.to_bytes(op.output('Out')[0])
                root_block._remove_var(fetch_var_name)
                # NOTE: some old pre-train models have no extra scale_op
                if not op.input('X')[0].startswith('save_infer_model/scale_'):
                    self._output_descs.append(
                        root_block.find_var(cpt.to_bytes(op.input('X')[0])))
            else:
                if op.has_attr("op_callstack"):
                    op.remove_attr("op_callstack")

        # remove from the back so the indices still to be removed stay valid
        for op_idx in reversed(ops_to_remove):
            root_block._remove_op(op_idx, op_idx + 1)

        # 2. Input processing, reverse feed vars
        self._input_descs.reverse()

        # 3. Output processing, add scale for outputs
        tmp_program = _build_program_by_desc(program_desc)
        # NOTE: [why need append scale for outputs]
        # When dealing with some more complex pre-training models, there
        # will be situations where the pre-training model has multiple
        # fetch outputs. In the scenario of multiple fetch outputs,
        # there is a special case where multiple outputs of the model
        # may be on the same branch. According to the user's subsequent
        # use, multiple outputs may be associated with multiple branches.
        # These subsequent operations are added in TranslatedLayer is
        # agnostic during initialization, which results in subsequent
        # gradient accumulation operations that are required on the
        # output node in the middle of the branch will not be performed,
        # resulting in error, details see pull request:
        # [https://github.com/PaddlePaddle/Paddle/pull/24627]
        self._append_scale_to_output(tmp_program)

        # 4. Persistable vars processing
        # - record renamed name -> name in the saved model
        # NOTE: [why need to append suffix to persistable vars]
        # Dygraph and static graph mode use the same naming mechanism.
        # If users want to load the model fine-tune, it is possible
        # to add the existing Layer in the loaded model to enhance
        # the network. For example, the original saved model has linear,
        # and later after loading, a new linear is added. At this time,
        # there will be a problem of duplicate names. Therefore the
        # persistable vars of the loaded model are renamed (step 0,
        # _rename_var_program_desc) and the mapping back to their saved
        # names is recorded here.
        self._suffix_varname_dict = _get_loaded_var_new_old(program_desc,
                                                            rename_new_old_dict)

        # - get persistable var
        self._persistable_names = _get_persistable_var_names(program_desc)

        return program_desc

    @switch_to_static_graph
    def _append_scale_to_output(self, program):
        """
        Append ``scale(out, 1.)`` after every recorded output in ``program``
        and make the scale outputs the new output descs.
        """
        # 1. append scale & save var
        scale_output_vars = []
        with framework.program_guard(program):
            for i, out in enumerate(self._output_descs):
                var = program.global_block().var(out.name())
                var = nn.scale(
                    var, 1., name="translated_layer/scale_{}".format(i))
                scale_output_vars.append(var)
        # 2. update output names & descs
        for i, var in enumerate(scale_output_vars):
            self._output_descs[i] = var.desc

    @switch_to_static_graph
    def _append_backward_desc(self, infer_program_desc):
        """
        Return a copy of ``infer_program_desc`` in training mode (``is_test``
        set to False) with backward ops for all outputs appended.
        """
        program_desc_copy = core.ProgramDesc(infer_program_desc)

        # 1. set all `is_test` attributes to False
        _change_is_test_status(program_desc_copy, False)

        # 2. prepare program and related var
        # NOTE: To reuse backward interfaces, build Program firstly.
        # Originally, there is no need to build a program, but need to almost
        # rewrite a series of methods for append_backward for program_desc.
        # Therefore, in order to reuse the method of backward.py, build the program here.
        program = _build_program_by_desc(program_desc_copy)

        targets = []
        for out in self._output_descs:
            targets.append(program.global_block().var(out.name()))

        # 3. append backward
        backward.gradients(targets=targets, inputs=[])
        return program.desc


# [ TranslatedLayer : Run program in imperative mode ]
# 
# DESIGN IDEA: use a special operator `RunProgram` to execute the program inside the operator.
#
# Op's Inputs:
#   - the input variable of the user feed
#   - the necessary parameters of the network
# Op's Outputs:
#   - the output variable of fetch
# 
# This op receives a complete program desc, internally creates scope
# and executor, executes this program. Key points:
#
# 1. Data Sharing: 
#   The varBase of the dynamic graph is not in the scope, so before the op
#   executes the program internally, create persistent variables with the
#   same name as feed, parameters, and fetch in the scope, and share the
#   LoDTensor of the op input.
# 
# 2. Forward and Backward Separation:
#   Because the dynamic graph op performs the forward and backward separately,
#   in the forward op RunProgram, we only execute the forward part of whole program,
#   and in the backward op RunProgramGrad, we execute the backward part of program.
#   We cannot split the program into separate forward and backward programs,
#   because that would break the execution logic of some control flow ops.


# NOTE: [compatible] deal with models saved by save_inference_model,
# which need to get var info from the program desc
def _load_persistable_vars_by_program(model_path,
                                      program_holder,
                                      params_filename=None):
    """
    Load the persistable variables of a model whose variable info must be
    read from the program itself (models saved by save_inference_model).

    If ``params_filename`` is None, each variable is loaded by a ``load`` op
    from its own file ``<model_path>/<saved name>``. Otherwise all variables
    are loaded by one ``load_combine`` op from ``<model_path>/<params_filename>``,
    in the order of their sorted saved names.

    Args:
        model_path (str): model directory (the caller checks the path).
        program_holder (_ProgramHolder): holder of the loaded program.
        params_filename (str|None): name of the combined params file.

    Returns:
        dict: renamed variable name -> loaded variable.
    """
    # make sure the path has been checked
    persistable_vars = _get_persistable_vars(program_holder.infer_program)
    load_var_dict = {}
    # names of persistable vars classified as parameters; cached because
    # _is_parameter scans every op of the program
    parameter_names = set()
    for each_var in persistable_vars:
        orig_each_name = program_holder._suffix_varname_dict[each_var.name()]
        if _is_parameter(each_var, program_holder.infer_program):
            parameter_names.add(each_var.name())
            # create output varbase
            new_var = framework.ParamBase(
                shape=each_var.shape(),
                dtype=each_var.dtype(),
                name=each_var.name(),
                type=each_var.type(),
                persistable=True)
        else:
            new_var = framework._varbase_creator(
                type=each_var.type(),
                name=each_var.name(),
                shape=each_var.shape(),
                dtype=each_var.dtype(),
                persistable=True)
        if params_filename is None:
            framework._dygraph_tracer().trace_op(
                type='load',
                inputs={},
                outputs={'Out': new_var},
                attrs={'file_path': os.path.join(model_path, orig_each_name)})
        new_var.stop_gradient = False
        load_var_dict[each_var.name()] = new_var

    if params_filename is not None:
        # load_combine fills its outputs in order, so list the variables
        # sorted by their names in the saved model
        dict_name_old_new = {
            v: k
            for k, v in program_holder._suffix_varname_dict.items()
        }
        load_var_list = [
            load_var_dict[dict_name_old_new[name]]
            for name in sorted(dict_name_old_new.keys())
        ]

        framework._dygraph_tracer().trace_op(
            type='load_combine',
            inputs={},
            outputs={'Out': load_var_list},
            attrs={'file_path': os.path.join(model_path, params_filename)})

        # re-assert trainability of parameters after load_combine
        for var_name in parameter_names:
            load_var_dict[var_name].stop_gradient = False

    # NOTE: [Recovery stop gradient information based on the program]
    # After loading the model, the stop_gradient information
    # of the original variable is lost, but if a parameter does not
    # have a corresponding @GRAD variable in the backward program,
    # it can be said that it is also stop_gradient
    all_var_names = _get_all_var_names(program_holder.train_program)
    for var_name in load_var_dict:
        grad_var_name = var_name + core.grad_var_suffix()
        if grad_var_name not in all_var_names:
            load_var_dict[var_name].stop_gradient = True

    return load_var_dict


510 511
def _load_persistable_vars(model_path, var_info_path, program_holder,
                           params_filename):
512 513
    # 1. load extra var info
    with open(var_info_path, 'rb') as f:
514
        extra_var_info = pickle.load(f)
515 516 517 518

    # 2. construct var dict
    load_var_dict = dict()
    load_var_list = []
519 520 521 522
    inv_suffix_varname_dict = {
        value: key
        for key, value in program_holder._suffix_varname_dict.items()
    }
523 524 525 526 527 528 529 530 531 532

    # NOTE(chenweihang): we need load persistable vars based the program,
    # because the program may be pruned when `save_inference_model`, some
    # var in `extra_var_info` may have been pruned 
    for name in sorted(inv_suffix_varname_dict):
        if name not in extra_var_info:
            raise RuntimeError(
                "The model to be loaded is not complete."
                "The variable `%s` of program cannot be found in loaded model.",
                name)
533 534
        # get suffix var name, see [why need to append suffix to persistable vars]
        new_name = inv_suffix_varname_dict[name]
535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551
        # create output varbase
        if extra_var_info[name].get('trainable', None) is not None:
            # use default shape and dtype
            new_var = framework.ParamBase(
                shape=[1],  # only to pass check, this shape is not meaningful
                dtype=core.VarDesc.VarType.FP32,
                name=new_name,
                persistable=True)
        else:
            new_var = framework._varbase_creator(
                name=new_name, persistable=True)

        new_var.stop_gradient = extra_var_info[name]['stop_gradient']
        load_var_dict[new_name] = new_var
        load_var_list.append(new_var)

    # 3. load all vars
552 553 554 555 556 557 558 559 560 561 562
    assert params_filename is not None, "params_filename should not be None."
    var_file_path = os.path.join(model_path, params_filename)
    if not os.path.exists(var_file_path):
        if len(extra_var_info) != 0:
            raise ValueError("The model to be loaded is incomplete.")
    else:
        framework._dygraph_tracer().trace_op(
            type='load_combine',
            inputs={},
            outputs={'Out': load_var_list},
            attrs={'file_path': var_file_path})
563 564 565 566

    return load_var_dict


567 568 569 570 571 572 573 574 575
# NOTE(chenweihang): to adapt paddle.load to get state_dict
def _remove_varname_suffix(var_dict, program_holder):
    no_suffix_var_dict = dict()
    for var_name in var_dict:
        no_suffix_name = program_holder._suffix_varname_dict[var_name]
        no_suffix_var_dict[no_suffix_name] = var_dict[var_name]
    return no_suffix_var_dict


576 577 578 579 580 581 582 583
def _construct_program_holders(model_path, model_filename=None):
    # make sure the path has been checked
    program_holder_dict = dict()

    if model_filename is not None:
        # [compatible] if assign model_filename, only can load one program as Layer.forward
        model_filename = os.path.basename(model_filename)
        model_file_path = os.path.join(model_path, model_filename)
584 585 586 587 588 589 590 591 592 593 594 595 596 597 598
        model_name = model_filename[:-len(INFER_MODEL_SUFFIX)]
        #Load every file that meets the requirements in the directory model_path.
        for filename in os.listdir(model_path):
            if model_filename == filename:
                func_name = 'forward'
                model_file_path = os.path.join(model_path, model_filename)
            elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
                    model_name):
                func_name = filename[len(model_name) + 1:-len(
                    INFER_MODEL_SUFFIX)]
                model_file_path = os.path.join(model_path, filename)
            else:
                continue
            program_holder_dict[func_name] = _ProgramHolder(
                _load_program_desc(model_file_path))
599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616
    else:
        for _, _, file_names in os.walk(model_path):
            for name in file_names:
                if 'model' in name:
                    model_file_path = os.path.join(model_path, name)
                    method_name = name.strip('_')
                    if method_name == 'model':
                        method_name = 'forward'
                    else:
                        method_name.replace('model', '')
                    program_holder_dict[method_name] = _ProgramHolder(
                        _load_program_desc(model_file_path))

    return program_holder_dict


def _construct_params_and_buffers(model_path,
                                  programs,
                                  params_filename=None,
                                  append_suffix=True):
    """
    Load the persistable variables (parameters and buffers) of all programs.

    If ``<params_filename>.info`` exists in ``model_path``, the ``forward``
    variables are loaded from ``params_filename``. The variables of every
    other method are loaded from ``<model_name>.<method>`` +
    INFER_PARAMS_SUFFIX, using the same var-info file. Otherwise the legacy
    layout is assumed and variable info is taken from the ``forward``
    program.

    Args:
        model_path (str): model directory.
        programs (dict): method name -> _ProgramHolder.
        params_filename (str|None): file name of the forward params.
        append_suffix (bool): if False, re-key the result by the names in the
            saved model (using the ``forward`` holder's mapping).

    Returns:
        dict: variable name -> loaded variable.
    """
    var_info_filename = str(params_filename) + ".info"
    var_info_path = os.path.join(model_path, var_info_filename)

    if os.path.exists(var_info_path):
        var_dict = _load_persistable_vars(model_path, var_info_path,
                                          programs['forward'], params_filename)
        model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)]
        # other methods are stored as "<model_name>.<method><INFER_PARAMS_SUFFIX>";
        # requiring the '.' keeps sibling models such as "<model_name>2..." out
        method_prefix = model_name + '.'
        # Load every file that meets the requirements in the directory model_path.
        for file_name in os.listdir(model_path):
            if file_name == params_filename:
                continue
            if not (file_name.endswith(INFER_PARAMS_SUFFIX) and
                    file_name.startswith(method_prefix)):
                continue
            func_name = file_name[len(method_prefix):-len(INFER_PARAMS_SUFFIX)]
            # NOTE(review): every method reuses the forward var-info file
            # (var_info_path); presumably it describes the vars of all
            # methods - confirm against the saving side.
            var_dict.update(
                _load_persistable_vars(model_path, var_info_path,
                                       programs[func_name], file_name))
    else:
        var_dict = _load_persistable_vars_by_program(
            model_path, programs['forward'], params_filename)

    if not append_suffix:
        var_dict = _remove_varname_suffix(var_dict, programs['forward'])

    return var_dict


class TranslatedLayer(layers.Layer):
    """
    TranslatedLayer is a ``paddle.nn.Layer`` that holds a model loaded by
    :ref:`api_paddle_jit_load`. It can be used like a general Layer object
    in eval or train mode.

    .. note::
        TranslatedLayer objects should not be created with the constructor;
        they can only be loaded and constructed by :ref:`api_paddle_jit_load`.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            # 1. train & save model.

            # create network
            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            # train
            train(layer, loader, loss_fn, adam)

            # save
            model_path = "linear.example.model"
            paddle.jit.save(layer, model_path)

            # 2. load model as TranslatedLayer

            # load
            translated_layer = paddle.jit.load(model_path)

            # inference
            translated_layer.eval()
            x = paddle.randn([1, IMAGE_SIZE], 'float32')
            pred = translated_layer(x)

            # fine-tune
            translated_layer.train()
            adam = opt.Adam(learning_rate=0.001, parameters=translated_layer.parameters())
            train(translated_layer, loader, loss_fn, adam)

    """

    def __init__(self, programs, persistable_vars):
        """
        Initialize the layer from loaded programs and persistable variables.

        Args:
            programs (dict): Maps a method name to its ``_ProgramHolder``.
            persistable_vars (dict): Maps the original variable name to the
                loaded ``ParamBase`` (registered as a parameter) or
                ``VarBase`` (registered as a buffer).

        Raises:
            TypeError: If ``programs`` or ``persistable_vars`` is not a dict,
                or a persistable variable is neither a ParamBase nor a VarBase.
        """
        super(TranslatedLayer, self).__init__()

        if not isinstance(programs, dict):
            raise TypeError(
                "TranslatedLayer need to use _ProgramHolder's dict for initialization."
            )
        if not isinstance(persistable_vars, dict):
            raise TypeError(
                "TranslatedLayer need to use persistable variable dict for initialization."
            )

        self._program_holder_dict = programs

        # NOTE(chenweihang): [ why not use var name directly? ]
        # When adding a parameter or buffer to a Layer with the following
        # APIs, the variable name can't contain `.`, because that may cause
        # an AttributeError when accessing the newly added parameter or
        # buffer in the form of `self.**.**`. However, ParamBase or VarBase
        # names originally contain `.`, such as `linear_0.w_0`, so a new
        # name is generated here for each var.
        self._persistable_var_name_dict = dict()
        # The generated var names of this TranslatedLayer object are counted
        # from 0 inside this guard.
        with unique_name.guard():
            for name, var in persistable_vars.items():
                if isinstance(var, framework.ParamBase):
                    dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
                    self._persistable_var_name_dict[name] = dy_name
                    self.add_parameter(dy_name, var)
                elif isinstance(var, core.VarBase):
                    dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
                    self._persistable_var_name_dict[name] = dy_name
                    self.register_buffer(dy_name, var)
                else:
                    raise TypeError(
                        "Adding persistent variable `%s` of type %s to layer is not supported now"
                        % (name, type(var)))

        # True selects the infer program, False the train program when an
        # execution method runs.
        self._is_test = True

    @staticmethod
    @framework.dygraph_only
    def _construct(model_path, configs=None):
        """
        Build a TranslatedLayer from the model files stored in ``model_path``.

        Args:
            model_path (str): Directory holding the saved model.
            configs (object|None): Optional config providing
                ``model_filename`` and ``params_filename``.

        Returns:
            TranslatedLayer: The loaded layer, switched to eval mode.

        Raises:
            ValueError: If ``model_path`` is not an existing directory.
        """
        # 0. validate the directory and pick up optional file names
        model_path = os.path.normpath(model_path)
        if not os.path.isdir(model_path):
            raise ValueError("There is no directory named '%s'" % model_path)
        if configs is None:
            model_filename, params_filename = None, None
        else:
            model_filename = configs.model_filename
            params_filename = configs.params_filename

        # 1. load the program descs and wrap each in a _ProgramHolder
        programs = _construct_program_holders(model_path, model_filename)

        # 2. load the layer's parameters and buffers
        persistable_vars = _construct_params_and_buffers(model_path, programs,
                                                         params_filename)

        # 3. build the layer object
        translated_layer = TranslatedLayer(programs, persistable_vars)

        # 4. attach one execution method per loaded program
        for method_name in programs:
            execution_method = TranslatedLayer._execution_method_creator(
                method_name, programs[method_name])
            setattr(TranslatedLayer, method_name, execution_method)

        # 5. a freshly loaded layer starts in eval mode
        translated_layer.eval()
        return translated_layer

    @staticmethod
    def _execution_method_creator(method_name, program_holder):
        """
        Create the function that executes the loaded program of a method.

        The returned function converts its inputs to VarBase and gathers the
        layer's persistable variables. It then traces a ``run_program`` op
        over the program of ``method_name``.

        Args:
            method_name (str): Name of the method to create.
            program_holder (_ProgramHolder): Holder the method is created for.
                It is kept for interface compatibility. At call time the
                holder is looked up from the calling layer instance, because
                the method is attached to the shared TranslatedLayer class.

        Returns:
            function: The execution method; its ``__name__`` is ``method_name``.
        """

        def __impl__(self, *inputs):
            # NOTE: this function is set on the TranslatedLayer class, which
            # every loaded model shares. Using the holder captured at creation
            # time would make a later load rebind the method for all earlier
            # instances, so always use the calling instance's own holder.
            holder = self._get_program_holder(method_name)

            # 1. prepare inputs, outputs, attrs
            input_vars = []
            for i, value in enumerate(inputs):
                if not isinstance(value, (np.ndarray, core.VarBase)):
                    raise TypeError(
                        "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s."
                        % type(value))
                # NOTE: In order to unify the API, firstly convert the input to VarBase
                if isinstance(value, np.ndarray):
                    var = core.VarBase(
                        value=value,
                        name=holder.input_descs[i].name(),
                        persistable=False,
                        place=framework._current_expected_place(),
                        zero_copy=True)
                else:
                    var = value
                    # NOTE: we changed var name here,
                    # but it may be an important name set by user
                    var.name = holder.input_descs[i].name()
                input_vars.append(var)

            persistable_vars = []
            for var_name in holder.persistable_names:
                dy_var_name = self._persistable_var_name_dict[var_name]
                if dy_var_name in self._parameters:
                    persistable_vars.append(self._parameters[dy_var_name])
                elif dy_var_name in self._buffers:
                    persistable_vars.append(self._buffers[dy_var_name])
                else:
                    raise ValueError(
                        "The persistable variable %s is not exists in current TranslatedLayer."
                        % var_name)

            output_vars = []
            for var_desc in holder.output_descs:
                out_var = core.VarBase(var_desc.dtype(),
                                       var_desc.shape(),
                                       var_desc.name(), var_desc.type(), False)
                output_vars.append(out_var)

            # hold forward variables
            tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                         "program_out_scope",
                                         core.VarDesc.VarType.STEP_SCOPES, True)
            tmp_scope_vec.value().set_scope(holder.scope)

            # 2. run program by op
            trace_program = holder.infer_program if self._is_test else holder.train_program
            # NOTE(review): end_op_index comes from the infer program even in
            # train mode; presumably the train program is the infer program
            # plus appended backward ops, so this marks the end of the forward
            # ops. Confirm against _ProgramHolder.
            end_op_index = holder.infer_program.block(0).op_size()
            framework._dygraph_tracer().trace_op(
                type='run_program',
                inputs={'X': input_vars,
                        'Params': persistable_vars},
                outputs={'Out': output_vars,
                         'OutScope': tmp_scope_vec},
                attrs={
                    'global_block': trace_program.block(0),
                    'start_op_index': 0,
                    'end_op_index': end_op_index,
                    'is_test': self._is_test
                })

            # NOTE: [ why need set param's gradient type here ]
            # If the user sets sparse gradient mode, the param's gradient
            # will be SelectedRows, not LoDTensor. But the tracer will just
            # set the param grad VarBase from the forward VarBase (LoDTensor).
            # If we don't change grad_var type here, RunProgramOp needs to
            # transform SelectedRows to LoDTensor forcibly, which may not be
            # the result the user wanted.
            for persistable_var in persistable_vars:
                # Use the persistable var's own name; the grad var is looked
                # up per parameter/buffer.
                grad_var_name = persistable_var.name + core.grad_var_suffix()
                grad_var = trace_program.block(0).find_var(
                    cpt.to_bytes(grad_var_name))
                # NOTE: cannot find var desc maybe not problem,
                # such as in batch_norm
                if grad_var is None:
                    continue
                persistable_var._set_grad_type(grad_var.type())

            # 3. prepare output, keep same form with inputs
            if len(output_vars) == 1:
                return output_vars[0]
            return output_vars

        __impl__.__name__ = method_name
        return __impl__

    def train(self):
        """
        Switch the layer to training mode.

        Later calls of the execution methods trace the train program. The
        ``training`` flag is updated too, so it agrees with the mode that is
        actually used.
        """
        self._is_test = False
        self.training = True

    def eval(self):
        """
        Switch the layer to evaluation mode.

        Later calls of the execution methods trace the infer program. The
        ``training`` flag is updated too, so it agrees with the mode that is
        actually used.
        """
        self._is_test = True
        self.training = False

    def program(self, method_name='forward'):
        """
        Get the translated program of the specified method.

        Args:
            method_name (str): Name of the method whose program is
                returned. Default: 'forward'.

        Returns:
            Program: A program rebuilt from the method's inference program.

        Raises:
            ValueError: If the loaded layer has no method named ``method_name``.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn as nn
                import paddle.optimizer as opt

                BATCH_SIZE = 16
                BATCH_NUM = 4
                EPOCH_NUM = 4

                IMAGE_SIZE = 784
                CLASS_NUM = 10

                # define a random dataset
                class RandomDataset(paddle.io.Dataset):
                    def __init__(self, num_samples):
                        self.num_samples = num_samples

                    def __getitem__(self, idx):
                        image = np.random.random([IMAGE_SIZE]).astype('float32')
                        label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
                        return image, label

                    def __len__(self):
                        return self.num_samples

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

                    @paddle.jit.to_static
                    def forward(self, x):
                        return self._linear(x)

                def train(layer, loader, loss_fn, opt):
                    for epoch_id in range(EPOCH_NUM):
                        for batch_id, (image, label) in enumerate(loader()):
                            out = layer(image)
                            loss = loss_fn(out, label)
                            loss.backward()
                            opt.step()
                            opt.clear_grad()
                            print("Epoch {} batch {}: loss = {}".format(
                                epoch_id, batch_id, np.mean(loss.numpy())))

                # create network
                layer = LinearNet()
                loss_fn = nn.CrossEntropyLoss()
                adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

                # create data loader
                dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
                loader = paddle.io.DataLoader(dataset,
                    batch_size=BATCH_SIZE,
                    shuffle=True,
                    drop_last=True,
                    num_workers=2)

                # train
                train(layer, loader, loss_fn, adam)

                # save
                model_path = "linear.example.model"
                paddle.jit.save(layer, model_path)

                # load
                translated_layer = paddle.jit.load(model_path)

                # get program
                program = translated_layer.program()
        """
        # The inference program desc of the method is rebuilt into a Program.
        infer_program_desc = self._get_program_holder(method_name).infer_program
        return _build_program_by_desc(infer_program_desc)

    def _get_program_holder(self, method_name='forward'):
        """
        Return the ``_ProgramHolder`` registered for ``method_name``.

        Raises:
            ValueError: If this layer holds no program for ``method_name``.
        """
        holder = self._program_holder_dict.get(method_name)
        if holder is not None:
            return holder
        raise ValueError(
            "The method `%s` does not exist in loaded TranslatedLayer." %
            method_name)

    def _input_spec(self, method_name='forward'):
        """
        Describe the inputs of ``method_name`` as a list of InputSpec.

        One spec is built per input var desc of the method's program holder,
        copying its shape, dtype and name.
        """
        holder = self._get_program_holder(method_name)
        return [
            paddle.static.InputSpec(
                shape=desc.shape(), dtype=desc.dtype(), name=desc.name())
            for desc in holder.input_descs
        ]

    def _output_spec(self, method_name='forward'):
        """
        Describe the outputs of ``method_name`` as a list of InputSpec.

        One spec is built per output var desc of the method's program holder,
        copying its shape, dtype and name.
        """
        holder = self._get_program_holder(method_name)
        # NOTE(chenweihang): InputSpec describes a tensor, not just an input.
        # The name may not be ideal, but InputSpec is used here to describe
        # the output tensors as well.
        return [
            paddle.static.InputSpec(
                shape=desc.shape(), dtype=desc.dtype(), name=desc.name())
            for desc in holder.output_descs
        ]