graph.py

import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import collections
import numpy as np
import copy

__all__ = ['Block', 'Variable', 'Program', 'Operator']
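
# Python-side wrappers over the C++ ProgramDesc IR: a Program owns Blocks, a
# Block owns Variables and Operators, and every wrapper delegates its state to
# a `desc` object held by the C++ core.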


def get_all_op_protos():
    """
    Get all registered op protos from the PaddlePaddle C++ end.
    :return: A list of registered OpProto.
    """
    protostrs = core.get_all_op_protos()
    ret_values = []
    for pbstr in protostrs:
        op_proto = framework_pb2.OpProto.FromString(str(pbstr))
        ret_values.append(op_proto)
    return ret_values


class OpProtoHolder(object):
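    # Singleton that caches every registered OpProto by operator type, so
    # Operator construction can look up input/output/attr metadata without
    # re-querying the C++ end.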
    @classmethod
    def instance(cls):
        if not hasattr(cls, '_instance'):
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        assert not hasattr(
            self.__class__,
            '_instance'), 'Please use `instance()` to get OpProtoHolder object!'
        op_protos = get_all_op_protos()
        self.op_proto_map = {}
        for proto in op_protos:
            self.op_proto_map[proto.type] = proto

    def get_op_proto(self, type):
        assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type
        return self.op_proto_map[type]


class Variable(object):
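    # Python view of one variable in a Block. Shape, data type and lod_level
    # live in the desc owned by the C++ core; this class validates arguments
    # and forwards them.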
    def __init__(self,
                 block,
                 name=None,
                 shape=None,
                 dtype=None,
                 lod_level=None,
                 **kwargs):
        self.block = block

        if name is None:
            name = Variable._unique_var_name_()
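        # Reuse the desc if this name already exists in the block; the C++
        # core raises EnforceNotMet for an unknown name, in which case a new
        # desc is created.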
        try:
            self.desc = self.block.desc.var(name)
            is_new_var = False
        except core.EnforceNotMet:
            self.desc = self.block.desc.new_var(name)
            is_new_var = True

        if shape is not None:
            if is_new_var:
                self.desc.set_shape(shape)
            else:
                old_shape = self.shape
                shape = tuple(shape)
                if shape != old_shape:
                    raise ValueError(
                        "Variable {0} has been created before. the previous "
                        "shape is {1}; the new shape is {2}. They are not "
                        "matched.".format(self.name, old_shape, shape))
        if dtype is not None:
            if not isinstance(dtype, core.DataType):
                dtype = Variable._convert_np_dtype_to_dtype_(dtype)
            if is_new_var:
                self.desc.set_data_type(dtype)
            else:
                old_dtype = self.data_type
                if dtype != old_dtype:
                    raise ValueError("Variable {0} has been created before. "
                                     "The previous data type is {1}; the new "
                                     "data type is {2}. They do not "
                                     "match.".format(self.name, old_dtype,
                                                     dtype))

        if lod_level is not None:
            if is_new_var:
                self.desc.set_lod_level(lod_level)
            else:
                if lod_level != self.lod_level:
                    raise ValueError("Variable {0} has been created before. "
                                     "The previous lod_level is {1}; the new "
                                     "lod_level is {2}. They are not "
                                     "matched".format(self.name, self.lod_level,
                                                      lod_level))
        self.block.vars[name] = self
        self.op = None

    @property
    def name(self):
        return self.desc.name()

    @property
    def shape(self):
        # convert to tuple, make it as same as numpy API.
        return tuple(self.desc.shape())

    @property
    def data_type(self):
        return self.desc.data_type()

    @property
    def lod_level(self):
        return self.desc.lod_level()

    @staticmethod
    def _unique_var_name_():
        uid = core.unique_integer()  # unique during whole process.
        return "_generated_var_%d" % uid

    @staticmethod
    def _convert_np_dtype_to_dtype_(np_dtype):
        dtype = np.dtype(np_dtype)
        if dtype == np.float32:
            return core.DataType.FP32
        elif dtype == np.float64:
            return core.DataType.FP64
        elif dtype == np.float16:
            return core.DataType.FP16
        elif dtype == np.int32:
            return core.DataType.INT32
        elif dtype == np.int16:
            return core.DataType.INT16
        elif dtype == np.int64:
            return core.DataType.INT64
        elif dtype == np.bool:
            return core.DataType.BOOL
        else:
            raise ValueError("Not supported numpy dtype " + str(dtype))

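# A minimal usage sketch (illustrative names, not part of this module):
#
#   blk = g_program.current_block()
#   x = blk.create_var(name="x", shape=[32, 784], dtype=np.float32)
#   assert x.shape == (32, 784)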

class Operator(object):
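    # Python view of one operator. Inputs, outputs and attributes are
    # validated against the registered OpProto and written into the C++-side
    # desc.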
    def __init__(self, block, desc, type, inputs=None, outputs=None,
                 attrs=None):
        self.block = block
        self.desc = desc
        self.proto = OpProtoHolder.instance().get_op_proto(type)
        self.desc.set_type(type)

        if inputs is not None:
            for in_proto in self.proto.inputs:
                in_argus = inputs[in_proto.name]
                if not isinstance(in_argus, list):
                    in_argus = [in_argus]
                if not in_proto.duplicable and len(in_argus) > 1:
                    raise ValueError(
                        "Input %s expects only one input, but %d are given." %
                        (in_proto.name, len(in_argus)))
                in_argu_names = []
                for argu in in_argus:
                    in_argu_names.append(argu.name)
                self.desc.set_input(in_proto.name, in_argu_names)

        if outputs is not None:
            for out_proto in self.proto.outputs:
                out_argus = outputs[out_proto.name]
                if not isinstance(out_argus, list):
                    out_argus = [out_argus]
                if not out_proto.duplicable and len(out_argus) > 1:
                    raise ValueError(
                        "Output %s expects only one output, but %d are given." %
                        (out_proto.name, len(out_argus)))
                out_argu_names = []
                for argu in out_argus:
                    out_argu_names.append(argu.name)
                self.desc.set_output(out_proto.name, out_argu_names)

        if attrs is not None:
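            # Block-valued attributes (e.g. the sub-block of a control-flow
            # op) must go through set_block_attr; all other attribute types
            # use set_attr.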
            for attr in self.proto.attrs:
                attr_name = attr.name
                if attr_name not in attrs:
                    continue
                if not isinstance(attrs[attr_name], Block):
                    self.desc.set_attr(attr_name, attrs[attr_name])
                else:
                    self.desc.set_block_attr(attr_name, attrs[attr_name].desc)

    @property
    def type(self):
        return self.desc.type()

    def input(self, name):
        return self.desc.input(name)

    @property
    def input_names(self):
        return self.desc.input_names()

    def output(self, name):
        return self.desc.output(name)

    @property
    def output_names(self):
        return self.desc.output_names()

    def has_attr(self, name):
        return self.desc.has_attr(name)

    def attr_type(self, name):
        return self.desc.attr_type(name)

    @property
    def attr_names(self):
        return self.desc.attr_names()

    def attr(self, name):
        return self.desc.attr(name)

    def block_attr(self, name):
        return self.desc.block_attr(name)


class Block(object):
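    # Python view of one block: an ordered list of operators plus the
    # variables they read and write, mirroring one block of the ProgramDesc.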
    def __init__(self, program, idx):
        self.desc = program.desc.block(idx)
        self.vars = dict()  # var_name --> var
        self.ops = collections.deque()  # operator list
        self.program = program

    @property
    def parent_idx(self):
        return self.desc.parent

    @property
    def idx(self):
        return self.desc.id

    def create_var(self, *args, **kwargs):
        return Variable(self, *args, **kwargs)

    def create_parameter(self, *args, **kwargs):
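        # Parameters always live in the global block, regardless of which
        # block they are requested from.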
        global_block = self.program.global_block()
        return Parameter(global_block, *args, **kwargs)

    def append_op(self, *args, **kwargs):
        op_desc = self.desc.append_op()
        op = Operator(self, op_desc, *args, **kwargs)
        self.ops.append(op)
        return op

    def prepend_op(self, *args, **kwargs):
        op_desc = self.desc.prepend_op()
        op = Operator(self, op_desc, *args, **kwargs)
        self.ops.appendleft(op)
        return op


class Program(object):
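    # Singleton wrapper of the whole ProgramDesc. Blocks form a tree:
    # create_block() enters a new nested block and rollback() returns to its
    # parent.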
    @classmethod
    def instance(cls):
        # From https://stackoverflow.com/questions/8212053
        # Making Program as a Singleton class.
        if not hasattr(cls, '_instance'):
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        assert not hasattr(self.__class__,
                           '_instance'), 'Do not call constructor directly!'
        self.desc = core.ProgramDesc.instance()
        self.blocks = [Block(self, 0)]
        self.current_block_idx = 0

    def global_block(self):
        return self.blocks[0]

    def current_block(self):
        return self.blocks[self.current_block_idx]

    def create_block(self):
        new_block_idx = len(self.blocks)
        self.desc.append_block(self.current_block().desc)
        self.current_block_idx = new_block_idx
        self.blocks.append(Block(self, self.current_block_idx))
        return self.current_block()

    def rollback(self):
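        # Leave the innermost block and make its parent the current block
        # again; pairs with create_block() when building nested blocks.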
        self.current_block_idx = self.current_block().parent_idx


class Parameter(Variable):
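    # A trainable Variable that lives in the global block and prepends its
    # own initializer op to that block when constructed.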
    def __init__(self, block, shape, dtype, **kwargs):
        if shape is None or dtype is None:
            raise ValueError("Parameter must set shape and dtype")
        if len(shape) == 0:
            raise ValueError("Parameter shape cannot be empty")

        for each in shape:
            if each < 0:
                raise ValueError("Parameter shape should not be related with "
                                 "batch-size")

        Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
        self.trainable = kwargs.get('trainable', True)
        self.init_attr = kwargs.get('initialize_attr', {
            'type': 'uniform_random',
            'min': -1.0,
            'max': 1.0
        })

        self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
        self._append_initialize_ops_()

    def _append_initialize_ops_(self):
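        # Prepend the initializer op (uniform_random by default) to the block
        # so the parameter is filled before any op that consumes it runs.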
        attr = copy.deepcopy(self.init_attr)
        op_type = attr.pop('type', None)
        block = self.block
        assert isinstance(block, Block)
        shape = self.shape
        attr['dims'] = shape
        attr['data_type'] = int(self.data_type)
        op = block.prepend_op(
            type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr)
        self.op = op


# g_program is the global Program instance shared by the whole process.
g_program = Program.instance()
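
# A minimal end-to-end sketch (the "mul" op and its "X"/"Y"/"Out" argument
# names are assumptions about what the C++ core registers in this build):
#
#   block = g_program.global_block()
#   w = block.create_parameter(name="w", shape=[784, 100], dtype=np.float32)
#   x = block.create_var(name="x", shape=[32, 784], dtype=np.float32)
#   out = block.create_var(name="out", shape=[32, 100], dtype=np.float32)
#   block.append_op(type="mul", inputs={"X": x, "Y": w},
#                   outputs={"Out": out})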