# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from paddle.proto.ModelConfig_pb2 import ModelConfig
import paddle.trainer_config_helpers as conf_helps
import layer as v2_layer
import config_base
import cPickle
from paddle.trainer import config_parser as cp

__all__ = ['Topology']


class Topology(object):
    """
    Topology is used to store the information about all layers
    and network configs.
    """

    def __init__(self, layers, extra_layers=None):
        def __check__(layers):
            if not isinstance(layers, collections.Sequence):
                layers = [layers]
            for layer in layers:
                __check_layer_type__(layer)
            return layers

        layers = __check__(layers)
        self.layers = layers
        if extra_layers is not None:
            extra_layers = __check__(extra_layers)

        self.__model_config__ = v2_layer.parse_network(
            layers, extra_layers=extra_layers)

        if extra_layers is not None:
            self.layers.extend(extra_layers)

        assert isinstance(self.__model_config__, ModelConfig)

    def update_from_default(self):
        # HACK(typhoonzero): update ParameterConfig(proto) in case optimizers
        # are defined after layers, or between layers.
        # Must be called from trainer.__init__().
        for parameter in self.__model_config__.parameters:
            if parameter.momentum == 0.0 and cp.g_default_momentum:
                parameter.momentum = cp.g_default_momentum
            if parameter.decay_rate == 0.0 and cp.g_default_decay_rate:
                parameter.decay_rate = cp.g_default_decay_rate
            if parameter.initial_mean == 0.0:
                parameter.initial_mean = cp.g_default_initial_mean
            if parameter.initial_std == 0.01:
                parameter.initial_std = cp.g_default_initial_std
            if parameter.initial_strategy == 0:
                parameter.initial_strategy = cp.g_default_initial_strategy
            if parameter.initial_smart == False:
                parameter.initial_smart = cp.g_default_initial_smart
            if parameter.num_batches_regularization == 1 and \
                    cp.g_default_num_batches_regularization:
                parameter.num_batches_regularization = \
                    cp.g_default_num_batches_regularization
            if parameter.gradient_clipping_threshold == 0.0 and \
                    cp.g_default_gradient_clipping_threshold:
                parameter.gradient_clipping_threshold = \
                    cp.g_default_gradient_clipping_threshold
            if parameter.device == -1 and cp.g_default_device:
                parameter.device = cp.g_default_device
            # FIXME(typhoonzero): ignored: update_hooks, g_default_compact_func

    def use_sparse_updater(self):
        """
        check if any parameter require to use sparse_update
        :return:
        """
        use_sparse = False
        for parameter in self.__model_config__.parameters:
            if parameter.sparse_update or parameter.sparse_remote_update:
                use_sparse = True
                break
        return use_sparse

    def proto(self):
        return self.__model_config__

    def get_layer(self, name):
        """
        get v2.Layer Class instance by layer name
        :param name:
        :return:
        """
        return v2_layer.get_layer(name)

    def data_layers(self):
        """
        get all data layer
        :return:
        """
        data_layers = {}
        for layer in self.proto().layers:
            l = v2_layer.get_layer(layer.name)
            if l and l.layer_type == conf_helps.LayerType.DATA:
                data_layers[layer.name] = l
        return data_layers

    def data_type(self):
        """
        Get the data types of the input layers from the proto, e.g.:
        [('image', dense_vector(768)), ('label', integer_value(10))]
        """
        data_layers = self.data_layers()

        return [(nm, data_layers[nm].data_type)
                for nm in self.proto().input_layer_names]
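
    # Hedged illustration: the (name, data_type) pairs returned above follow
    # self.proto().input_layer_names, so they could be used to build a
    # feeding map of input name -> column index for a data reader:
    #
    #     feeding = dict((name, i)
    #                    for i, (name, _) in enumerate(topology.data_type()))
    #     # e.g. {'image': 0, 'label': 1}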

    def get_layer_proto(self, name):
        """
        Get the LayerConfig proto of the layer with the given name, or None
        if no such layer exists in the model config.
        """
        for layer in self.__model_config__.layers:
            if layer.name == name:
                return layer
        return None

    def serialize_for_inference(self, stream):
        """
        Dump the serialized network proto and the input data types to
        `stream` for later use in inference.
        """
        protobin = self.proto().SerializeToString()
        data_type = self.data_type()
        cPickle.dump({
            'protobin': protobin,
            'data_type': data_type
        }, stream, cPickle.HIGHEST_PROTOCOL)
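
    # The matching read on the inference side is sketched here (hedged; the
    # file name is hypothetical). The pickled dict carries the serialized
    # ModelConfig under 'protobin' and the input types under 'data_type':
    #
    #     with open('inference_topology.pkl', 'rb') as f:
    #         bundle = cPickle.load(f)
    #     model_config = ModelConfig()
    #     model_config.ParseFromString(bundle['protobin'])
    #     input_types = bundle['data_type']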


def __check_layer_type__(layer):
    if not isinstance(layer, config_base.Layer):
        raise ValueError('layer should have type paddle.v2.config_base.Layer')