#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import copy
import six

from .framework import Parameter, dtype_is_floating, in_dygraph_mode, OpProtoHolder, _global_flags
from . import unique_name
from paddle.fluid.initializer import Constant, Xavier
from .param_attr import ParamAttr
from . import core
from six.moves import zip
from .layer_helper_base import LayerHelperBase
from .dygraph_utils import _append_activation_in_dygraph


class LayerHelper(LayerHelperBase):
    def __init__(self, layer_type, **kwargs):
        self.kwargs = kwargs
        name = self.kwargs.get('name', None)
        # TODO(panyx0718, minqiyang): dygraph mode
        # can not use both `layer_type` and `name`. Deprecate LayerHelper
        # and write a Helper for dygraph mode.
        if name is None:
            self.kwargs['name'] = unique_name.generate(layer_type)

        super(LayerHelper, self).__init__(
            self.kwargs['name'], layer_type=layer_type)
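
    # Usage sketch (hypothetical layer code, not part of this module):
    # layer functions typically build their helper from locals(), so the
    # helper can pick `name`, `param_attr`, `bias_attr`, `act`, ... out of
    # kwargs:
    #
    #     def fc(input, size, param_attr=None, bias_attr=None, act=None,
    #            name=None):
    #         helper = LayerHelper('fc', **locals())
    #
    # When `name` is None, unique_name.generate('fc') yields names such as
    # 'fc_0', 'fc_1', ... for successive layers.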

    def append_op(self, *args, **kwargs):
        return self.main_program.current_block().append_op(*args, **kwargs)

    def multiple_input(self, input_param_name='input'):
        inputs = self.kwargs.get(input_param_name, [])
        ret = []
        if isinstance(inputs, (list, tuple)):
            for inp in inputs:
                ret.append(self.to_variable(inp))
        else:
            ret.append(self.to_variable(inputs))
        return ret

    def input(self, input_param_name='input'):
        inputs = self.multiple_input(input_param_name)
        if len(inputs) != 1:
            raise "{0} layer only takes one input".format(self.layer_type)
        return inputs[0]
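
    # Example (sketch): multiple_input normalizes the `input` kwarg so that
    # layer code never branches on whether the caller passed one Variable
    # or a list of Variables:
    #
    #     LayerHelper('sums', input=[x, y]).multiple_input()  # -> [x, y]
    #     LayerHelper('relu', input=x).input()                # -> x
    #
    # `x` and `y` stand for Variables created elsewhere; input() raises
    # unless exactly one input was given.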

    @property
    def param_attr(self):
        return ParamAttr._to_attr(self.kwargs.get('param_attr', None))

    @property
    def bias_attr(self):
        return ParamAttr._to_attr(self.kwargs.get('bias_attr', None))

    # TODO(jiabin): reconstruct this in LayerObjHelper and avoid the dependency on param_attr
    def multiple_param_attr(self, length):
        param_attr = self.param_attr
        if isinstance(param_attr, ParamAttr):
            param_attr = [param_attr]

        if len(param_attr) != 1 and len(param_attr) != length:
            raise ValueError("parameter number mismatch")
        elif len(param_attr) == 1 and length != 1:
            tmp = [None] * length
            for i in six.moves.range(length):
                tmp[i] = copy.deepcopy(param_attr[0])
            param_attr = tmp
        return param_attr

    def iter_inputs_and_params(self, input_param_name='input'):
        inputs = self.multiple_input(input_param_name)
        param_attrs = self.multiple_param_attr(len(inputs))
        for ipt, param_attr in zip(inputs, param_attrs):
            yield ipt, param_attr
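
    # Example (sketch, hypothetical names): with two inputs and a single
    # ParamAttr, the attr is deep-copied once per input, so each input is
    # paired with its own attribute object:
    #
    #     helper = LayerHelper('mul', input=[a, b], param_attr=w_attr)
    #     for ipt, attr in helper.iter_inputs_and_params():
    #         w = helper.create_parameter(
    #             attr=attr, shape=ipt.shape[1:], dtype=ipt.dtype)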

    def input_dtype(self, input_param_name='input'):
        inputs = self.multiple_input(input_param_name)
        dtype = None
        for each in inputs:
            if dtype is None:
                dtype = each.dtype
            elif dtype != each.dtype:
                raise ValueError("Data Type mismatch: %d to %d" %
                                 (dtype, each.dtype))
        return dtype

    def get_parameter(self, name):
        param = self.main_program.global_block().var(name)
        if not isinstance(param, Parameter):
            raise ValueError("no Parameter name %s found" % name)
        return param

    # TODO(jiabin): reconstruct this in LayerObjHelper and avoid the dependency on bias_attr
    def append_bias_op(self, input_var, dim_start=1, dim_end=None):
        """
X
xuwei06 已提交
111
        Append bias operator and return its output. If the user does not set
112
        bias_attr, append_bias_op will return input_var
X
xuwei06 已提交
113

114 115 116 117
        :param input_var: the input variable. The len(input_var.shape) is
        larger or equal than 2.
        :bias_initializer: an instance of a subclass of Initializer used to
        initialize the bias
X
xuwei06 已提交
118 119
        :param dim_start:
        :param dim_end: the shape of the bias will be
X
xuwei06 已提交
120
        input_var.shape[dim_start:dim_end]. The bias is broadcasted to other
X
xuwei06 已提交
121
        dimensions and added to input_var to get the output
122
        """
        size = list(input_var.shape[dim_start:dim_end])
        bias_attr = self.bias_attr
        if not bias_attr:
            return input_var

        b = self.create_parameter(
            attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True)
        tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
        self.append_op(
            type='elementwise_add',
            inputs={'X': [input_var],
                    'Y': [b]},
            outputs={'Out': [tmp]},
            attrs={'axis': dim_start})
        return tmp
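
    # Example (sketch): for a conv output of shape [N, C, H, W] and the
    # defaults dim_start=1, dim_end=None, the bias parameter has shape
    # [C, H, W] and is added with axis=1, broadcasting over the batch
    # dimension; fc-style layers instead pass dim_start=1, dim_end=2 so
    # the bias shape is just [size]:
    #
    #     out = helper.append_bias_op(conv_out)
    #     out = helper.append_bias_op(fc_out, dim_start=1, dim_end=2)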

    # TODO(jiabin): reconstruct this in LayerObjHelper and avoid the dependency on act
    def append_activation(self, input_var):
        act = self.kwargs.get('act', None)
        if act is None:
            return input_var
        if isinstance(act, six.string_types):
            act = {'type': act}
        else:
            raise TypeError(str(act) + " should be unicode or str")

        use_cudnn = None
        if self.kwargs.get('use_cudnn'):
            use_cudnn = self.kwargs.get('use_cudnn')
            act['use_cudnn'] = use_cudnn
        use_mkldnn = self.kwargs.get(
            'use_mkldnn', _global_flags().get("FLAGS_use_mkldnn", False))
        if use_mkldnn:
            act['use_mkldnn'] = use_mkldnn
        act_type = act.pop('type')
        if in_dygraph_mode():
            res = _append_activation_in_dygraph(input_var, act_type, use_cudnn,
                                                use_mkldnn)
            return res
        else:
            tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
            self.append_op(
                type=act_type,
                inputs={"X": [input_var]},
                outputs={"Out": [tmp]},
                attrs=act)
            return tmp
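
    # Example (sketch, hypothetical layer code): a layer that accepted
    # act='relu' in its kwargs would typically finish with:
    #
    #     pre_act = helper.append_bias_op(mul_out)
    #     return helper.append_activation(pre_act)
    #
    # In static-graph mode this appends a `relu` op; in dygraph mode it
    # dispatches to _append_activation_in_dygraph instead.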

    # TODO(jiabin): should we remove this since it has never been used?
    def _get_default_initializer(self, dtype):
        if dtype is None or dtype_is_floating(dtype) is True:
            return Xavier()
        else:
            # For integer and boolean types, initialize with all zeros
            return Constant()

    # TODO(jiabin): reconstruct this in LayerObjHelper and avoid the dependency on kwargs
    def is_instance(self, param_name, cls):
        param = self.kwargs.get(param_name, None)
        if not isinstance(param, cls):
            raise TypeError("The input {0} parameter of method {1} must be {2}",
                            param_name, self.layer_type, cls.__name__)
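
    # Example (sketch, hypothetical kwargs): a layer that requires its
    # `param_attr` kwarg to be a ParamAttr instance could validate it with
    #
    #     helper.is_instance('param_attr', ParamAttr)
    #
    # which raises TypeError when the kwarg is missing or of another type.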