#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from ..framework import Parameter, in_dygraph_mode, _global_flags
from ..param_attr import ParamAttr
from .. import core

from ..layer_helper_base import LayerHelperBase
from ..dygraph_utils import _append_activation_in_dygraph


class LayerObjectHelper(LayerHelperBase):
    def __init__(self, name):
        super().__init__(name, layer_type=name)

    def append_op(
        self,
        type=None,
        inputs=None,
        outputs=None,
        attrs=None,
        stop_gradient=None,
    ):
        """append an operator for this layer object.

           Args:
               type: operator type
               inputs: input variable of the operator
               dtype: data type of this parameter
               is_bias: if this is a bias parameter
               default_initializer: set the default initializer for this parameter

        Returns created parameter Variable.
        """
        return self.main_program.current_block().append_op(
            type=type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=stop_gradient,
        )
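
    # Illustrative sketch (not executed here; ``helper`` and the Variable
    # ``x`` are assumed to exist on the current block): appending a scale op
    # and capturing its output might look like
    #
    #     out = helper.create_variable_for_type_inference(dtype=x.dtype)
    #     helper.append_op(
    #         type='scale',
    #         inputs={'X': [x]},
    #         outputs={'Out': [out]},
    #         attrs={'scale': 2.0},
    #     )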

    def _multiple_input(self, inputs_in):
        inputs = inputs_in
        ret = []
        if isinstance(inputs, (list, tuple)):
            for inp in inputs:
                ret.append(self.to_variable(inp))
        else:
            ret.append(self.to_variable(inputs))
        return ret

    # TODO: make it public when we need it
    def _input(self, inputs_in):
        inputs = self._multiple_input(inputs_in)
        if len(inputs) != 1:
            raise "{0} layer only takes one input in".format(self.layer_type)
        return inputs[0]

    def _multiple_param_attr(self, length, param_attr_in=None):
        param_attr = param_attr_in
        if isinstance(param_attr, ParamAttr):
            param_attr = [param_attr]

        if len(param_attr) != 1 and len(param_attr) != length:
            raise ValueError(
                "parameter number mismatch in {}".format(self.name)
            )
        elif len(param_attr) == 1 and length != 1:
            tmp = [None] * length
            for i in range(length):
                tmp[i] = copy.deepcopy(param_attr[0])
            param_attr = tmp
        return param_attr
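
    # Illustrative sketch: a single ParamAttr is broadcast to the requested
    # length by deep-copying, so each input gets an independent attr:
    #
    #     attrs = helper._multiple_param_attr(3, ParamAttr())
    #     # len(attrs) == 3; mutating attrs[0] does not affect attrs[1:]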

    def iter_inputs_and_params(self, inputs_in, param_attr_in=None):
        """Access all inputs and params one by one

           Args:
               inputs_in: inputs to be iter
               param_attr_in: param_attr to be iter

        Returns input, param_attr
        """
        param_attr_in = ParamAttr._to_attr(param_attr_in)
        if isinstance(param_attr_in, bool):
            raise ValueError(
                'Param_attr should not be False in {}'.format(self.name)
            )
        inputs = inputs_in if (inputs_in is not None) else []
        inputs = self._multiple_input(inputs)
        param_attrs = self._multiple_param_attr(len(inputs), param_attr_in)
        for ipt, param_attr in zip(inputs, param_attrs):
            yield ipt, param_attr
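
    # Illustrative sketch (``a`` and ``b`` are assumed pre-existing
    # Variables): pairing two inputs with copies of one shared attr:
    #
    #     for ipt, attr in helper.iter_inputs_and_params([a, b], ParamAttr()):
    #         ...  # create one parameter per (ipt, attr) pair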

    def input_dtype(self, inputs_in):
        """Get input data type

           Args:
               inputs_in: inputs wanted know the data type

        Returns dtype of the input
        """
        inputs_in = inputs_in if (inputs_in is not None) else []
        inputs = self._multiple_input(inputs_in)
        dtype = None
        for each in inputs:
            if dtype is None:
                dtype = each.dtype
            elif dtype != each.dtype:
                raise ValueError(
                    "Data Type mismatch: %s to %s in %s"
                    % (dtype, each.dtype, self.name)
                )
        return dtype
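
    # Illustrative sketch: all inputs must share one dtype, which is
    # returned; mixed dtypes raise ValueError:
    #
    #     dtype = helper.input_dtype([a, b])  # e.g. VarDesc.VarType.FP32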

    def get_parameter(self, name):
        """Get parameter specifically

           Args:
               name: parameter's name

        Returns target parameter
        """
        param = self.main_program.global_block().var(name)
        if not isinstance(param, Parameter):
            raise ValueError(
                "no Parameter name %s found in %s" % (name, self.name)
            )
        return param
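
    # Illustrative sketch (``'fc_0.w_0'`` is a hypothetical parameter name):
    #
    #     w = helper.get_parameter('fc_0.w_0')  # raises ValueError if absent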

    # TODO: this should not be called anymore after all activation func move to Layers
    def append_activation(self, input_var, act=None, use_cudnn=None):
        """Append activation

            Args:
                input_var: the input variable. The len(input_var.shape) is
                larger or equal than 2.
                act: activation type
                use_cudnn: if use cudnn

        Return the Variable of after append activation
        """
        act = act
        if act is None:
            return input_var
        if isinstance(act, str):
            act = {'type': act}
        else:
            raise TypeError(
                "{} should be unicode or str in {}".format(act, self.name)
            )

        if (use_cudnn is not None) and use_cudnn:
            act['use_cudnn'] = use_cudnn
        use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]
        if (use_mkldnn is not None) and use_mkldnn:
            act['use_mkldnn'] = use_mkldnn
        act_type = act.pop('type')
        if in_dygraph_mode():
            res = _append_activation_in_dygraph(
                input_var, act_type, use_cudnn, use_mkldnn
            )
            return res
        else:
            tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
            self.append_op(
                type=act_type,
                inputs={"X": [input_var]},
                outputs={"Out": [tmp]},
                attrs=act,
            )
            return tmp
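
    # Illustrative sketch (``pre_act`` is assumed to be the raw output of a
    # preceding op): appending a relu is just
    #
    #     out = helper.append_activation(pre_act, act='relu')
    #     # returns pre_act unchanged when act is None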

    def is_instance(self, param, cls):
        """Check if the input parameter is instance of input class

            Args:
                param: parameter to be check
                cls: class of the parameter

        Return result of the check (True or False)
        """
        param = param
        if not isinstance(param, cls):
            raise TypeError(
                "The input {0} parameter of method {1} must be {2}, "
                "in layer {3}".format(
                    param, self.layer_type, cls.__name__, self.name
                )
            )
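
    # Illustrative sketch: guarding an argument's type inside a layer;
    # raises TypeError on mismatch, returns None otherwise:
    #
    #     helper.is_instance(param_attr_in, ParamAttr)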