#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle.fluid.core as core
import paddle.fluid.proto.framework_pb2 as framework_pb2


def get_all_op_protos():
    """
    Fetch every OpProto registered on the PaddlePaddle C++ side.

    Each serialized proto returned by ``core.get_all_op_protos()`` is wrapped
    in ``bytes`` and parsed into a ``framework_pb2.OpProto`` message.

    :return: A list of registered OpProto, in the order the C++ end reports
        them.
    """
    return [
        framework_pb2.OpProto.FromString(bytes(serialized))
        for serialized in core.get_all_op_protos()
    ]


def is_str(s):
    """
    Tell whether ``s`` is a text string.

    :param s: Any object.
    :return: True if ``s`` is an instance of ``str`` (subclasses included),
        otherwise False.
    """
    return isinstance(s, (str,))


class OpDescCreationMethod(object):
    """
    Convert the user's input (only keyword arguments are supported) to an
    OpDesc based on the OpProto.

    :param op_proto: The OpProto object.
    :type op_proto: framework_pb2.OpProto
    :raises TypeError: if ``op_proto`` is not a ``framework_pb2.OpProto``.
    """

    def __init__(self, op_proto):
        if not isinstance(op_proto, framework_pb2.OpProto):
            raise TypeError(
                "Type of op_proto should be OpProto in PaddlePaddle.")
        self.__op_proto__ = op_proto
        # Extra attributes registered for this op type on the C++ side. They
        # are not listed in op_proto.attrs, so __call__ handles them in a
        # separate pass. (Values are presumably defaults; they are not used.)
        self.__extra_attrs__ = core.get_op_extra_attrs(op_proto.type)

    def __call__(self, *args, **kwargs):
        """
        Convert user's input to OpDesc. Only keyword arguments are supported.

        Each input/output parameter of the proto is looked up by name in
        ``kwargs``. A single string is treated as a one-element list, and a
        missing parameter becomes an empty list.

        Attribute values, both proto attrs and extra attrs, are copied only
        when given and not ``None``. Numpy arrays are converted to Python
        lists first. Proto attributes marked ``generated`` are skipped.

        :return: The OpDesc based on user input.
        :rtype: framework_pb2.OpDesc
        :raises ValueError: if positional arguments are given, or a
            non-duplicable input/output receives more than one argument.
        :raises NotImplementedError: if an attribute has an unsupported type.
        """
        if len(args) != 0:
            raise ValueError("Only keyword arguments are supported.")

        op_desc = framework_pb2.OpDesc()
        self._fill_vars(op_desc.inputs, self.__op_proto__.inputs, kwargs,
                        "Input", "input")
        self._fill_vars(op_desc.outputs, self.__op_proto__.outputs, kwargs,
                        "Output", "output")

        # Types
        op_desc.type = self.__op_proto__.type

        # Attributes declared in the proto; their type comes from the proto.
        for attr in self.__op_proto__.attrs:
            if attr.generated:
                continue
            user_defined_attr = kwargs.get(attr.name, None)
            if user_defined_attr is not None:
                new_attr = op_desc.attrs.add()
                new_attr.name = attr.name
                new_attr.type = attr.type
                self._set_attr_value(new_attr, attr.type, user_defined_attr)

        # Extra attributes are not in the proto, so their type is queried
        # from the C++ side. ("get_attrtibute_type" is the core API's name.)
        for attr_name, _default in self.__extra_attrs__.items():
            user_defined_attr = kwargs.get(attr_name, None)
            if user_defined_attr is not None:
                attr_type = int(
                    core.get_attrtibute_type(op_desc.type, attr_name))
                new_attr = op_desc.attrs.add()
                new_attr.name = attr_name
                new_attr.type = attr_type
                self._set_attr_value(new_attr, attr_type, user_defined_attr)

        return op_desc

    @staticmethod
    def _fill_vars(desc_vars, proto_vars, kwargs, label, noun):
        """
        Append one var entry to ``desc_vars`` per parameter in ``proto_vars``.

        :param desc_vars: Repeated field of the OpDesc (inputs or outputs).
        :param proto_vars: Matching repeated field of the OpProto.
        :param kwargs: User keyword arguments keyed by parameter name.
        :param label: Capitalised kind used in the error ("Input"/"Output").
        :param noun: Lower-case kind used in the error ("input"/"output").
        :raises ValueError: if a non-duplicable parameter gets more than one
            argument.
        """
        for parameter in proto_vars:
            arguments = kwargs.get(parameter.name, [])
            # A bare string names a single variable.
            if is_str(arguments):
                arguments = [arguments]

            if not parameter.duplicable and len(arguments) > 1:
                raise ValueError(
                    "%s %s expects only one %s, but %d are given." %
                    (label, parameter.name, noun, len(arguments)))

            var = desc_vars.add()
            var.parameter = parameter.name
            var.arguments.extend(arguments)

    @staticmethod
    def _set_attr_value(new_attr, attr_type, value):
        """
        Store ``value`` in the field of ``new_attr`` that matches ``attr_type``.

        Shared by proto attributes and extra attributes, so both paths accept
        the same set of types (including FLOAT64 and LONGS).

        :raises NotImplementedError: if ``attr_type`` is not supported.
        """
        if isinstance(value, np.ndarray):
            value = value.tolist()
        if attr_type == framework_pb2.INT:
            new_attr.i = value
        elif attr_type == framework_pb2.FLOAT:
            new_attr.f = value
        elif attr_type == framework_pb2.LONG:
            new_attr.l = value
        elif attr_type == framework_pb2.STRING:
            new_attr.s = value
        elif attr_type == framework_pb2.BOOLEAN:
            new_attr.b = value
        elif attr_type == framework_pb2.INTS:
            new_attr.ints.extend(value)
        elif attr_type == framework_pb2.FLOATS:
            new_attr.floats.extend(value)
        elif attr_type == framework_pb2.STRINGS:
            new_attr.strings.extend(value)
        elif attr_type == framework_pb2.BOOLEANS:
            new_attr.bools.extend(value)
        elif attr_type == framework_pb2.LONGS:
            new_attr.longs.extend(value)
        elif attr_type == framework_pb2.FLOAT64:
            new_attr.float64 = value
        else:
            raise NotImplementedError(
                "A not supported attribute type: %s." % (str(attr_type)))

    @staticmethod
    def any_is_true(generator):
        """
        Reduce a boolean array to a single boolean parameter. If any element in
        the array is True, this function will return True, otherwise False.
        """
        return any(generator)


class OpInfo(object):
    """
    Plain record describing one registered operator.

    :param name: Operator type name.
    :param method: Callable that creates the operator from keyword arguments.
    :param inputs: Input descriptors as stored by the caller.
    :param outputs: Output descriptors as stored by the caller.
    :param attrs: Names of the attributes declared in the op proto.
    :param extra_attrs: Names of the extra attributes of the op.
    """

    def __init__(self, name, method, inputs, outputs, attrs, extra_attrs):
        self.name, self.method = name, method
        self.inputs, self.outputs = inputs, outputs
        self.attrs, self.extra_attrs = attrs, extra_attrs


def create_op_creation_method(op_proto):
    """
    Generate op creation method for an OpProto.

    :param op_proto: The OpProto describing the operator.
    :type op_proto: framework_pb2.OpProto
    :return: An OpInfo with these fields:

        - ``method``: builds an OpDesc from keyword arguments via
          OpDescCreationMethod, then returns ``core.Operator.create`` applied
          to the serialized OpDesc.
        - ``inputs`` / ``outputs``: lists of ``(name, duplicable)`` pairs.
        - ``attrs``: the proto attribute names.
        - ``extra_attrs``: the extra attribute names reported by
          ``core.get_op_extra_attrs``.
    :raises TypeError: if ``op_proto`` is not an OpProto (raised by
        OpDescCreationMethod).
    """
    method = OpDescCreationMethod(op_proto)

    def __impl__(*args, **kwargs):
        opdesc = method(*args, **kwargs)
        return core.Operator.create(opdesc.SerializeToString())

    extra_attrs_map = core.get_op_extra_attrs(op_proto.type)

    return OpInfo(method=__impl__,
                  name=op_proto.type,
                  inputs=[(var.name, var.duplicable)
                          for var in op_proto.inputs],
                  outputs=[(var.name, var.duplicable)
                           for var in op_proto.outputs],
                  attrs=[attr.name for attr in op_proto.attrs],
                  extra_attrs=list(extra_attrs_map.keys()))


class OperatorFactory(object):
    """
    Registry of creation methods for every operator registered in C++.

    ``op_methods`` maps an op type name to the OpInfo produced by
    ``create_op_creation_method``. Calling the factory creates an operator,
    either as ``factory("mul", X=..., ...)`` or as
    ``factory(type="mul", X=..., ...)``.
    """

    # Raised when __call__ receives positional arguments other than the type.
    _ARGS_ERROR = ("Except the argument \"type\", "
                   "all of the other arguments should be keyword arguments.")

    def __init__(self):
        self.op_methods = dict()

        for op_proto in get_all_op_protos():
            method = create_op_creation_method(op_proto)
            self.op_methods[method.name] = method

    def __call__(self, *args, **kwargs):
        """
        Create an operator of the requested type.

        The type is taken from the ``type`` keyword if present (then no
        positional arguments are allowed). Otherwise it must be the single
        positional argument. All remaining keyword arguments are forwarded to
        the op's creation method.

        :raises ValueError: on malformed arguments or an unregistered type.
        """
        if "type" in kwargs:
            if len(args) != 0:
                raise ValueError(self._ARGS_ERROR)
            t = kwargs.pop("type")
        else:
            if len(args) != 1:
                raise ValueError(self._ARGS_ERROR)
            t = args[0]

        return self.get_op_info(t).method(**kwargs)

    def types(self):
        """Return the names of all registered operator types."""
        return list(self.op_methods.keys())

    def get_op_info(self, t):
        """
        Return the OpInfo registered for type ``t``.

        :raises ValueError: if ``t`` is not registered.
        """
        if t not in self.op_methods:
            raise ValueError("The operator: %s is not registered." % t)
        return self.op_methods[t]

    def get_op_input_names(self, type):
        """Return the input parameter names of op ``type``."""
        return [var[0] for var in self.get_op_info(type).inputs]

    def get_op_inputs(self, type):
        """Return the ``(name, duplicable)`` input pairs of op ``type``."""
        return self.get_op_info(type).inputs

    def get_op_output_names(self, type):
        """Return the output parameter names of op ``type``."""
        return [var[0] for var in self.get_op_info(type).outputs]

    def get_op_outputs(self, type):
        """Return the ``(name, duplicable)`` output pairs of op ``type``."""
        return self.get_op_info(type).outputs

    def get_op_attr_names(self, type):
        """Return the proto attribute names of op ``type``."""
        return self.get_op_info(type).attrs

    def get_op_extra_attr_names(self, type):
        """Return the extra attribute names of op ``type``."""
        return self.get_op_info(type).extra_attrs


class __RecurrentOp__(object):
    """
    Callable that builds a ``core.RecurrentOp`` from keyword arguments.

    The OpProto of type ``"recurrent"`` is looked up when an instance is
    constructed and stored on that instance. If no such op is registered,
    ``__proto__`` stays ``None``, and calling the instance fails with the
    TypeError raised by OpDescCreationMethod.
    """
    __proto__ = None
    type = "recurrent"

    def __init__(self):
        # Find this op type's proto and store it on the instance. This is
        # per-instance storage, not a class-level cache.
        if self.__proto__ is None:
            for op_proto in get_all_op_protos():
                if op_proto.type == self.type:
                    self.__proto__ = op_proto
                    break

    def __call__(self, *args, **kwargs):
        """
        Build the OpDesc from ``kwargs`` and create the recurrent op.

        :raises ValueError: if positional arguments are given (rejected by
            OpDescCreationMethod).
        """
        if self.type not in args and "type" not in kwargs:
            kwargs["type"] = self.type
        # create proto
        create_method = OpDescCreationMethod(self.__proto__)
        proto = create_method(*args, **kwargs)
        # create rnnop
        return core.RecurrentOp.create(proto.SerializeToString())


class __DynamicRecurrentOp__(object):
    """
    Callable that builds a ``core.DynamicRecurrentOp`` from keyword arguments.

    The OpProto of type ``"dynamic_recurrent"`` is looked up when an instance
    is constructed and stored on that instance. If no such op is registered,
    ``__proto__`` stays ``None``, and calling the instance fails with the
    TypeError raised by OpDescCreationMethod.
    """
    __proto__ = None
    type = "dynamic_recurrent"

    def __init__(self):
        # Find the dynamic_recurrent op's proto and store it on the instance.
        # This is per-instance storage, not a class-level cache.
        if self.__proto__ is None:
            for op_proto in get_all_op_protos():
                if op_proto.type == self.type:
                    self.__proto__ = op_proto
                    break

    def __call__(self, *args, **kwargs):
        """
        Build the OpDesc from ``kwargs`` and create the dynamic recurrent op.

        :raises ValueError: if positional arguments are given (rejected by
            OpDescCreationMethod).
        """
        if self.type not in args and "type" not in kwargs:
            kwargs["type"] = self.type
        # create proto
        create_method = OpDescCreationMethod(self.__proto__)
        proto = create_method(*args, **kwargs)
        # create dynamic rnn op
        return core.DynamicRecurrentOp.create(proto.SerializeToString())


class __CondOp__(object):
    """
    Callable that builds a ``core.CondOp`` from keyword arguments.

    The OpProto of type ``"cond"`` is looked up when an instance is
    constructed and stored on that instance. If no such op is registered,
    ``__proto__`` stays ``None``, and calling the instance fails with the
    TypeError raised by OpDescCreationMethod.
    """
    __proto__ = None
    type = "cond"

    def __init__(self):
        # Find the cond op's proto and store it on the instance. This is
        # per-instance storage, not a class-level cache.
        if self.__proto__ is None:
            for op_proto in get_all_op_protos():
                if op_proto.type == self.type:
                    self.__proto__ = op_proto
                    break

    def __call__(self, *args, **kwargs):
        """
        Build the OpDesc from ``kwargs`` and create the cond op.

        :raises ValueError: if positional arguments are given (rejected by
            OpDescCreationMethod).
        """
        if self.type not in args and "type" not in kwargs:
            kwargs["type"] = self.type
        # create proto
        create_method = OpDescCreationMethod(self.__proto__)
        proto = create_method(*args, **kwargs)
        # create condop
        return core.CondOp.create(proto.SerializeToString())


# The default global factory: maps every registered op type to its creator.
Operator = OperatorFactory()

# Module-level singletons. Calling these builds core.RecurrentOp,
# core.DynamicRecurrentOp and core.CondOp objects, respectively.
RecurrentOp = __RecurrentOp__()
DynamicRecurrentOp = __DynamicRecurrentOp__()
CondOp = __CondOp__()