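"""
Helper utilities for building layers in paddle.v2.fluid.

LayerHelper encapsulates the bookkeeping that layer functions share: resolving
the main and startup programs, generating unique names, validating inputs,
creating and initializing parameters, creating temporary variables, and
appending bias and activation operators.
"""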
import copy
import itertools

from framework import Variable, g_main_program, \
    g_startup_program, unique_name, dtype_is_floating
from paddle.v2.fluid.initializer import Constant, Xavier


class LayerHelper(object):
    def __init__(self, layer_type, **kwargs):
        self.kwargs = kwargs
        self.layer_type = layer_type
        name = self.kwargs.get('name', None)
        if name is None:
            self.kwargs['name'] = unique_name(self.layer_type)

    @property
    def name(self):
        return self.kwargs['name']

    @property
    def main_program(self):
        prog = self.kwargs.get('main_program', None)
        if prog is None:
            return g_main_program
        else:
            return prog

    @property
    def startup_program(self):
        prog = self.kwargs.get('startup_program', None)
        if prog is None:
            return g_startup_program
        else:
            return prog

    def append_op(self, *args, **kwargs):
        return self.main_program.current_block().append_op(*args, **kwargs)

    def multiple_input(self, input_param_name='input'):
        inputs = self.kwargs.get(input_param_name, [])
        type_error = TypeError(
            "Input of {0} layer should be Variable or sequence of Variable".
            format(self.layer_type))
        if isinstance(inputs, Variable):
            inputs = [inputs]
        elif not isinstance(inputs, list) and not isinstance(inputs, tuple):
            raise type_error
        else:
            for each in inputs:
                if not isinstance(each, Variable):
                    raise type_error
        return inputs

    def input(self, input_param_name='input'):
        inputs = self.multiple_input(input_param_name)
        if len(inputs) != 1:
            raise ValueError(
                "{0} layer only takes one input".format(self.layer_type))
        return inputs[0]

    @property
    def param_attr(self):
        default = {'name': None}
        actual = self.kwargs.get('param_attr', None)
        if actual is None:
            actual = default
        for default_field in default.keys():
            if default_field not in actual:
                actual[default_field] = default[default_field]
        return actual

    @property
    def bias_attr(self):
        default = {'name': None}
        bias_attr = self.kwargs.get('bias_attr', None)
        if bias_attr is None:
            bias_attr = default

        if isinstance(bias_attr, dict):
            for default_field in default.keys():
                if default_field not in bias_attr:
                    bias_attr[default_field] = default[default_field]
        return bias_attr

    def multiple_param_attr(self, length):
        param_attr = self.param_attr
        if isinstance(param_attr, dict):
            param_attr = [param_attr]

        if len(param_attr) != 1 and len(param_attr) != length:
            raise ValueError("parameter number mismatch")
        elif len(param_attr) == 1 and length != 1:
            tmp = [None] * length
            for i in xrange(length):
                tmp[i] = copy.deepcopy(param_attr[0])
            param_attr = tmp
        return param_attr

    def iter_inputs_and_params(self, input_param_name='input'):
        inputs = self.multiple_input(input_param_name)
        param_attrs = self.multiple_param_attr(len(inputs))
        for ipt, param_attr in itertools.izip(inputs, param_attrs):
            yield ipt, param_attr

    def input_dtype(self, input_param_name='input'):
        inputs = self.multiple_input(input_param_name)
        dtype = None
        for each in inputs:
            if dtype is None:
                dtype = each.dtype
            elif dtype != each.dtype:
                raise ValueError("Data Type mismatch")
        return dtype

    def create_parameter(self, attr, shape, dtype, suffix='w',
                         initializer=None):
        # Deepcopy the attr so that parameters can be shared in program
        attr_copy = copy.deepcopy(attr)
        if initializer is not None:
            attr_copy['initializer'] = initializer
        else:
            attr_copy['initializer'] = self._get_default_initializer(dtype)
        if attr_copy['name'] is None:
            attr_copy['name'] = unique_name(".".join([self.name, suffix]))
        # The parameter is created in the startup program (where it gets
        # initialized) and in the main program (where the layer's ops use it).
        self.startup_program.global_block().create_parameter(
            dtype=dtype, shape=shape, **attr_copy)
        return self.main_program.global_block().create_parameter(
            name=attr_copy['name'],
            dtype=dtype,
            shape=shape,
            trainable=attr_copy.get('trainable', True))

    def create_tmp_variable(self, dtype):
        return self.main_program.current_block().create_var(
            name=unique_name(".".join([self.name, 'tmp'])),
            dtype=dtype,
            persistable=False)

    def create_variable(self, *args, **kwargs):
        return self.main_program.current_block().create_var(*args, **kwargs)

    def create_global_variable(self, persistable=False, *args, **kwargs):
        return self.main_program.global_block().create_var(
            *args, persistable=persistable, **kwargs)

    def set_variable_initializer(self, var, initializer):
        assert isinstance(var, Variable)
        # Register the variable in the startup program's global block so that
        # it is created as persistable and initialized when the startup
        # program runs.
        self.startup_program.global_block().create_var(
            name=var.name,
            type=var.type,
            dtype=var.dtype,
            shape=var.shape,
            persistable=True,
            initializer=initializer)

    def append_bias_op(self,
                       input_var,
                       bias_initializer,
                       dim_start=1,
                       dim_end=None):
        """
        Append a bias operator and return its output. If the user did not set
        bias_attr, append_bias_op returns input_var unchanged.

        :param input_var: the input variable. len(input_var.shape) must be
        greater than or equal to 2.
        :param bias_initializer: an instance of a subclass of Initializer,
        used to initialize the bias.
        :param dim_start: index of the first dimension of input_var that the
        bias covers.
        :param dim_end: index one past the last dimension that the bias
        covers. The shape of the bias is input_var.shape[dim_start:dim_end];
        the bias is broadcast over the remaining dimensions and added to
        input_var to produce the output. For example, with an input of shape
        (N, C, H, W), dim_start=1 and dim_end=None, the bias has shape
        (C, H, W) and is broadcast over the batch dimension N.
        """
        size = list(input_var.shape[dim_start:dim_end])
        bias_attr = self.bias_attr
        if not bias_attr:
            return input_var

        b = self.create_parameter(
            attr=bias_attr,
            shape=size,
            dtype=input_var.dtype,
            suffix='b',
            initializer=bias_initializer)
        tmp = self.create_tmp_variable(dtype=input_var.dtype)
        self.append_op(
            type='elementwise_add',
            inputs={'X': [input_var],
                    'Y': [b]},
            outputs={'Out': [tmp]},
            attrs={'axis': dim_start})
        return tmp

    def append_activation(self, input_var):
        act = self.kwargs.get('act', None)
        if act is None:
            return input_var
        if isinstance(act, basestring):
            act = {'type': act}
        tmp = self.create_tmp_variable(dtype=input_var.dtype)
        act_type = act.pop('type')
        self.append_op(
            type=act_type,
            inputs={"X": [input_var]},
            outputs={"Y": [tmp]},
            attrs=act)
        return tmp

    def _get_default_initializer(self, dtype):
        if dtype is None or dtype_is_floating(dtype) is True:
            return Xavier()
        else:
            # For integer and boolean types, initialize with all zeros
            return Constant()
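

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). A layer function
# in paddle.v2.fluid.layers typically drives LayerHelper roughly as below; the
# fc-style signature, the weight shape, and the 'mul' op inputs/outputs are
# assumptions made for the example, not the actual fc layer definition.
#
#     def fc_like(input, size, param_attr=None, bias_attr=None, act=None,
#                 name=None, main_program=None, startup_program=None):
#         helper = LayerHelper('fc', **locals())
#         dtype = helper.input_dtype()
#         mul_results = []
#         for input_var, attr in helper.iter_inputs_and_params():
#             w = helper.create_parameter(
#                 attr=attr, shape=[input_var.shape[-1], size], dtype=dtype)
#             tmp = helper.create_tmp_variable(dtype)
#             helper.append_op(
#                 type='mul',
#                 inputs={'X': [input_var],
#                         'Y': [w]},
#                 outputs={'Out': [tmp]})
#             mul_results.append(tmp)
#         pre_bias = mul_results[0]  # single-input case, for brevity
#         pre_activation = helper.append_bias_op(pre_bias, Constant())
#         return helper.append_activation(pre_activation)
# ----------------------------------------------------------------------------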