# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from paddle.proto.ModelConfig_pb2 import ModelConfig
import paddle.trainer_config_helpers as conf_helps
import layer as v2_layer
import config_base
import cPickle
from paddle.trainer import config_parser as cp

__all__ = ['Topology']


class Topology(object):
    """
    Topology stores the information about all layers in a network and the
    parsed network configuration.
    """

    def __init__(self, layers, extra_layers=None):
        def __check__(layers):
            if not isinstance(layers, collections.Sequence):
                layers = [layers]
            for layer in layers:
                __check_layer_type__(layer)
            return layers

        layers = __check__(layers)
        self.layers = layers
        if extra_layers is not None:
            extra_layers = __check__(extra_layers)

        self.__model_config__ = v2_layer.parse_network(
            layers, extra_layers=extra_layers)

        if extra_layers is not None:
            self.layers.extend(extra_layers)

        assert isinstance(self.__model_config__, ModelConfig)

    def update_from_default(self):
        # HACK(typhoonzero): update ParameterConfig (proto) in case
        # optimizers are defined after layers, or between layers.
        # Must be called from trainer.__init__()
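        # For example, when the optimizer (and hence globals such as
        # cp.g_default_momentum) is created only after the layers were
        # parsed, the parameters above still carry the proto defaults
        # (momentum == 0.0, etc.); the loop below copies the global
        # defaults onto those untouched fields.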
        for parameter in self.__model_config__.parameters:
            if parameter.momentum == 0.0 and cp.g_default_momentum:
                parameter.momentum = cp.g_default_momentum
            if parameter.decay_rate == 0.0 and cp.g_default_decay_rate:
                parameter.decay_rate = cp.g_default_decay_rate
            if parameter.initial_mean == 0.0:
                parameter.initial_mean = cp.g_default_initial_mean
            if parameter.initial_std == 0.01:
                parameter.initial_std = cp.g_default_initial_std
            if parameter.initial_strategy == 0:
                parameter.initial_strategy = cp.g_default_initial_strategy
            if not parameter.initial_smart:
                parameter.initial_smart = cp.g_default_initial_smart
            if parameter.num_batches_regularization == 1 and \
                cp.g_default_num_batches_regularization:
                parameter.num_batches_regularization = \
                    cp.g_default_num_batches_regularization
            if parameter.gradient_clipping_threshold == 0.0 and \
                cp.g_default_gradient_clipping_threshold:
                parameter.gradient_clipping_threshold = \
                    cp.g_default_gradient_clipping_threshold
            if parameter.device == -1 and cp.g_default_device:
                parameter.device = cp.g_default_device
            # FIXME(typhoonzero): ignored: update_hooks, g_default_compact_func

    def use_sparse_updater(self):
        """
        check if any parameter require to use sparse_update
        :return:
        """
        use_sparse = False
        for parameter in self.__model_config__.parameters:
            if parameter.sparse_update or parameter.sparse_remote_update:
                use_sparse = True
                break
        return use_sparse

    def proto(self):
        return self.__model_config__

    def get_layer(self, name):
        """
        get v2.Layer Class instance by layer name
        :param name:
        :return:
        """
        return v2_layer.get_layer(name)

    def data_layers(self):
        """
        get all data layer
        :return:
        """
        data_layers = {}
        for layer in self.proto().layers:
            l = v2_layer.get_layer(layer.name)
            if l and l.layer_type == conf_helps.LayerType.DATA:
                data_layers[layer.name] = l
        return data_layers

    def data_type(self):
        """
        Get the data types of the input layers from the proto, such as:
        [('image', dense_vector(768)), ('label', integer_value(10))]
        """
        data_layers = self.data_layers()

        return [(nm, data_layers[nm].data_type)
                for nm in self.proto().input_layer_names]

    def get_layer_proto(self, name):
        for layer in self.__model_config__.layers:
            if layer.name == name:
                return layer
        return None

    def serialize_for_inference(self, stream):
        protobin = self.proto().SerializeToString()
        data_type = self.data_type()
        cPickle.dump({
            'protobin': protobin,
            'data_type': data_type
        }, stream, cPickle.HIGHEST_PROTOCOL)
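
    # A rough sketch of reading the serialized stream back on the inference
    # side (the exact loading code used by paddle.v2 may differ):
    #
    #     obj = cPickle.load(stream)
    #     conf = ModelConfig()
    #     conf.ParseFromString(obj['protobin'])
    #     input_types = obj['data_type']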


def __check_layer_type__(layer):
    if not isinstance(layer, config_base.Layer):
        raise ValueError('layer should have type paddle.v2.config_base.Layer')