import copy
import itertools

from framework import Variable, default_main_program, default_startup_program, unique_name, dtype_is_floating
from paddle.v2.fluid.initializer import Constant, Xavier


class LayerHelper(object):
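    """Helper shared by the paddle.v2.fluid layer functions.

    A LayerHelper holds the keyword arguments passed to a layer function and
    provides the common plumbing for building a layer: unique naming, access
    to the main and startup programs, handling of param_attr / bias_attr, and
    creation of parameters, temporary variables, bias and activation ops.
    """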
    def __init__(self, layer_type, **kwargs):
        self.kwargs = kwargs
        self.layer_type = layer_type
        name = self.kwargs.get('name', None)
        if name is None:
            self.kwargs['name'] = unique_name(self.layer_type)

    @property
    def name(self):
        return self.kwargs['name']

    @property
    def main_program(self):
        prog = self.kwargs.get('main_program', None)
        if prog is None:
            return default_main_program()
        else:
            return prog

    @property
    def startup_program(self):
        prog = self.kwargs.get('startup_program', None)
        if prog is None:
            return default_startup_program()
        else:
            return prog

    def append_op(self, *args, **kwargs):
        return self.main_program.current_block().append_op(*args, **kwargs)

    def multiple_input(self, input_param_name='input'):
        inputs = self.kwargs.get(input_param_name, [])
        type_error = TypeError(
            "Input of {0} layer should be Variable or sequence of Variable".
            format(self.layer_type))
        if isinstance(inputs, Variable):
            inputs = [inputs]
        elif not isinstance(inputs, list) and not isinstance(inputs, tuple):
            raise type_error
        else:
            for each in inputs:
                if not isinstance(each, Variable):
                    raise type_error
        return inputs

    def input(self, input_param_name='input'):
        inputs = self.multiple_input(input_param_name)
        if len(inputs) != 1:
            raise "{0} layer only takes one input".format(self.layer_type)
        return inputs[0]

    @property
    def param_attr(self):
        default = {'name': None}
        actual = self.kwargs.get('param_attr', None)
        if actual is None:
            actual = default
        for default_field in default.keys():
            if default_field not in actual:
                actual[default_field] = default[default_field]
        return actual

    @property
    def bias_attr(self):
        default = {'name': None}
        bias_attr = self.kwargs.get('bias_attr', None)
        if bias_attr is None:
            bias_attr = default

        if isinstance(bias_attr, dict):
            for default_field in default.keys():
                if default_field not in bias_attr:
                    bias_attr[default_field] = default[default_field]
        return bias_attr

    def multiple_param_attr(self, length):
        param_attr = self.param_attr
        if isinstance(param_attr, dict):
            param_attr = [param_attr]

        if len(param_attr) != 1 and len(param_attr) != length:
            raise ValueError("parameter number mismatch")
        elif len(param_attr) == 1 and length != 1:
            tmp = [None] * length
            for i in xrange(length):
                tmp[i] = copy.deepcopy(param_attr[0])
            param_attr = tmp
        return param_attr

    def iter_inputs_and_params(self, input_param_name='input'):
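        # Yields (input_variable, param_attr) pairs. When a single param_attr
        # is given for several inputs, multiple_param_attr deep-copies it so
        # that each input gets its own attribute dict.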
        inputs = self.multiple_input(input_param_name)
        param_attrs = self.multiple_param_attr(len(inputs))
        for ipt, param_attr in itertools.izip(inputs, param_attrs):
            yield ipt, param_attr

    def input_dtype(self, input_param_name='input'):
        inputs = self.multiple_input(input_param_name)
        dtype = None
        for each in inputs:
            if dtype is None:
                dtype = each.dtype
            elif dtype != each.dtype:
                raise ValueError("Data Type mismatch")
        return dtype

    def create_parameter(self, attr, shape, dtype, suffix='w',
                         initializer=None):
        # Deepcopy the attr so that parameters can be shared in program
        attr_copy = copy.deepcopy(attr)
        if initializer is not None:
            attr_copy['initializer'] = initializer
        else:
            attr_copy['initializer'] = self._get_default_initializer(dtype)
        if attr_copy['name'] is None:
            attr_copy['name'] = unique_name(".".join([self.name, suffix]))
        # The parameter is created twice under the same name: in the startup
        # program, where its initializer runs, and in the main program, where
        # it is used by the layer's operators.
        self.startup_program.global_block().create_parameter(
            dtype=dtype, shape=shape, **attr_copy)
        return self.main_program.global_block().create_parameter(
            name=attr_copy['name'],
            dtype=dtype,
            shape=shape,
            trainable=attr_copy.get('trainable', True))

    def create_tmp_variable(self, dtype):
        return self.main_program.current_block().create_var(
            name=unique_name(".".join([self.name, 'tmp'])),
            dtype=dtype,
            persistable=False)

    def create_variable(self, *args, **kwargs):
        return self.main_program.current_block().create_var(*args, **kwargs)

    def create_global_variable(self, persistable=False, *args, **kwargs):
        return self.main_program.global_block().create_var(
            *args, persistable=persistable, **kwargs)

    def set_variable_initializer(self, var, initializer):
        assert isinstance(var, Variable)
        self.startup_program.global_block().create_var(
            name=var.name,
            type=var.type,
            dtype=var.dtype,
            shape=var.shape,
            persistable=True,
            initializer=initializer)

    def append_bias_op(self,
                       input_var,
                       bias_initializer,
                       dim_start=1,
                       dim_end=None):
        """
        Append a bias operator and return its output. If the user does not
        set bias_attr, append_bias_op returns input_var unchanged.

        :param input_var: the input variable. len(input_var.shape) must be
            greater than or equal to 2.
        :param bias_initializer: an instance of a subclass of Initializer,
            used to initialize the bias.
        :param dim_start:
        :param dim_end: the shape of the bias is
            input_var.shape[dim_start:dim_end]. The bias is broadcast over the
            other dimensions and added to input_var to produce the output.
        """
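        # Illustrative example (not from the original source): for an input of
        # shape [N, C, H, W] with dim_start=1 and dim_end=2, the bias has
        # shape [C] and is broadcast over N, H and W through the
        # elementwise_add 'axis' attribute set below.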
        size = list(input_var.shape[dim_start:dim_end])
        bias_attr = self.bias_attr
        if not bias_attr:
            return input_var

        b = self.create_parameter(
            attr=bias_attr,
            shape=size,
            dtype=input_var.dtype,
            suffix='b',
            initializer=bias_initializer)
        tmp = self.create_tmp_variable(dtype=input_var.dtype)
        self.append_op(
            type='elementwise_add',
            inputs={'X': [input_var],
                    'Y': [b]},
            outputs={'Out': [tmp]},
            attrs={'axis': dim_start})
        return tmp

    def append_activation(self, input_var):
        act = self.kwargs.get('act', None)
        if act is None:
            return input_var
        if isinstance(act, basestring):
            act = {'type': act}
        tmp = self.create_tmp_variable(dtype=input_var.dtype)
        act_type = act.pop('type')
        self.append_op(
            type=act_type,
            inputs={"X": [input_var]},
            outputs={"Y": [tmp]},
            attrs=act)
        return tmp

    def _get_default_initializer(self, dtype):
        if dtype is None or dtype_is_floating(dtype):
            return Xavier()
        else:
            # For integer and boolean types, initialize with all zeros
            return Constant()
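

# Rough usage sketch (illustrative only, not part of this module). A layer
# function typically drives LayerHelper roughly as below; the layer name,
# shapes, and the 'mul' op wiring here are assumptions for illustration,
# not the actual implementation of any particular layer.
#
#   def fc_layer(input, size, param_attr=None, bias_attr=None, act=None,
#                main_program=None, startup_program=None):
#       helper = LayerHelper('fc', **locals())
#       dtype = helper.input_dtype()
#       mul_results = []
#       for input_var, param_attr in helper.iter_inputs_and_params():
#           w = helper.create_parameter(
#               attr=param_attr,
#               shape=[input_var.shape[1], size],
#               dtype=dtype)
#           tmp = helper.create_tmp_variable(dtype)
#           helper.append_op(
#               type='mul',
#               inputs={'X': [input_var], 'Y': [w]},
#               outputs={'Out': [tmp]})
#           mul_results.append(tmp)
#       pre_activation = helper.append_bias_op(mul_results[0], Constant(0.0))
#       return helper.append_activation(pre_activation)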