# dump_v2_config.py
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections

from paddle.trainer_config_helpers.layers import LayerOutput
from paddle.v2.layer import parse_network
from paddle.proto import TrainerConfig_pb2

__all__ = ["dump_v2_config"]


def dump_v2_config(topology, save_path, binary=False):
    """Dump the network topology to a specified file.

    This function is only used to dump network defined by using PaddlePaddle V2
    APIs. This function will NOT dump configurations related to PaddlePaddle
    optimizer.

    :param topology: The output layers (can be more than one layers given in a
                     Python List or Tuple) of the entire network. Using the
                     specified layers (if more than one layer is given) as root,
                     traversing back to the data layer(s), all the layers
                     connected to the specified output layers will be dumped.
                     Layers not connected to the specified will not be dumped.
    :type topology: LayerOutput|List|Tuple
    :param save_path: The path to save the dumped network topology.
    :type save_path: str
    :param binary: Whether to dump the serialized network topology or not.
                   The default value is false. NOTE that, if you call this
                   function to generate network topology for PaddlePaddle C-API,
                   a serialized version of network topology is required. When
                   using PaddlePaddle C-API, this flag MUST be set to True.
    :type binary: bool
    :raises AssertionError: If topology is a sequence and one of its elements
                            is not a LayerOutput.
    :raises RuntimeError: If topology is neither a LayerOutput nor a sequence.
    """
    # The Sequence ABC lives in collections.abc on Python 3 (the alias in
    # ``collections`` was removed in 3.10); Python 2 only has the old location.
    try:
        from collections.abc import Sequence
    except ImportError:
        from collections import Sequence

    if isinstance(topology, LayerOutput):
        # Normalize a single output layer to a one-element list.
        topology = [topology]
    elif isinstance(topology, Sequence):
        for out_layer in topology:
            assert isinstance(out_layer, LayerOutput), (
                "The type of each element in the parameter topology "
                "should be LayerOutput.")
    else:
        raise RuntimeError("Error input type for parameter topology.")

    model_str = parse_network(topology)

    # Serialized output is raw bytes, so it must go through a binary-mode
    # handle: text mode rejects bytes on Python 3 and applies newline
    # translation on Windows, which would corrupt the dump.
    with open(save_path, "wb" if binary else "w") as fout:
        if binary:
            fout.write(model_str.SerializeToString())
        else:
            fout.write(str(model_str))