graph.py
import paddle.v2.framework.core as core
import collections
import numpy as np
import copy

__all__ = ['Block', 'Variable', 'Program', 'Operator']


class Variable(object):
    def __init__(self,
                 block,
                 name=None,
                 shape=None,
                 dtype=None,
                 lod_level=None,
                 **kwargs):
        self.block = block

        if name is None:
            name = Variable._unique_var_name_()
        try:
            # Reuse the existing VarDesc if this name is already in the block.
            self.desc = self.block.desc.var(name)
            is_new_var = False
        except core.EnforceNotMet:
            self.desc = self.block.desc.new_var(name)
            is_new_var = True

        if shape is not None:
            if is_new_var:
                self.desc.set_shape(shape)
            else:
                old_shape = self.shape
                shape = tuple(shape)
                if shape != old_shape:
                    raise ValueError(
                        "Variable {0} has been created before. the previous "
                        "shape is {1}; the new shape is {2}. They are not "
                        "matched.".format(self.name, old_shape, shape))
        if dtype is not None:
            if not isinstance(dtype, core.DataType):
                dtype = Variable._convert_np_dtype_to_dtype_(dtype)
            if is_new_var:
                self.desc.set_data_type(dtype)
            else:
                old_dtype = self.data_type
                if dtype != old_dtype:
                    raise ValueError("Variable {0} has been created before. "
                                     "The previous data type is {1}; the new "
                                     "data type is {2}. They are not "
                                     "matched.".format(self.name, old_dtype,
                                                       dtype))

        if lod_level is not None:
            if is_new_var:
                self.desc.set_lod_level(lod_level)
            else:
                if lod_level != self.lod_level:
                    raise ValueError("Variable {0} has been created before. "
                                     "The previous lod_level is {1}; the new "
                                     "lod_level is {2}. They are not "
                                     "matched".format(self.name, self.lod_level,
                                                      lod_level))
        self.block.vars[name] = self
        self.op = None

    @property
    def name(self):
        return self.desc.name()

    @property
    def shape(self):
        # Convert to a tuple to match the numpy API.
        return tuple(self.desc.shape())

    @property
    def data_type(self):
        return self.desc.data_type()

    @property
    def lod_level(self):
        return self.desc.lod_level()

    @staticmethod
    def _unique_var_name_():
        uid = core.unique_integer()  # unique during whole process.
        return "_generated_var_%d" % uid

    @staticmethod
    def _convert_np_dtype_to_dtype_(np_dtype):
        dtype = np.dtype(np_dtype)
        if dtype == np.float32:
            return core.DataType.FP32
        elif dtype == np.float64:
            return core.DataType.FP64
        elif dtype == np.float16:
            return core.DataType.FP16
        elif dtype == np.int32:
            return core.DataType.INT32
        elif dtype == np.int16:
            return core.DataType.INT16
        elif dtype == np.int64:
            return core.DataType.INT64
        elif dtype == np.bool_:
            return core.DataType.BOOL
        else:
            raise ValueError("Not supported numpy dtype " + str(dtype))

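# A minimal usage sketch for Variable (the names below, such as 'w', are
# illustrative): variables are created through a Block, and creating one under
# an existing name must agree with the earlier shape/dtype.
#
#     block = g_program.global_block()
#     w = block.create_var(name='w', shape=[784, 100], dtype='float32')
#     assert w.shape == (784, 100)
#     block.create_var(name='w', shape=[10, 10])  # raises ValueError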

class Operator(object):
    def __init__(self,
                 block,
                 desc,
                 type=None,
                 inputs=None,
                 outputs=None,
                 attrs=None):
        self.block = block
        self.desc = desc
        if type is not None:
            # TODO.
            pass
        if inputs is not None:
            # TODO
            pass
        if outputs is not None:
            # TODO
            pass
        if attrs is not None:
            # TODO
            pass

    # TODO: Getters


class Block(object):
    def __init__(self, program, idx):
        self.desc = program.desc.block(idx)
        self.vars = dict()  # var_name --> var
        self.ops = collections.deque()  # operator list
        self.program = program

    @property
    def parent_idx(self):
        return self.desc.parent

    @property
    def idx(self):
        return self.desc.id

    def create_var(self, *args, **kwargs):
        return Variable(self, *args, **kwargs)

    def create_parameter(self, *args, **kwargs):
        global_block = self.program.global_block()
        return Parameter(global_block, *args, **kwargs)

    def append_op(self, *args, **kwargs):
        op_desc = self.desc.append_op()
        op = Operator(self, op_desc, *args, **kwargs)
        self.ops.append(op)
        return op

    def prepend_op(self, *args, **kwargs):
        op_desc = self.desc.prepend_op()
        op = Operator(self, op_desc, *args, **kwargs)
        self.ops.appendleft(op)
        return op

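# A minimal usage sketch for Block (illustrative only; Operator does not yet
# wire up type/inputs/outputs/attrs, so only the op ordering is visible here).
#
#     block = g_program.current_block()
#     second = block.append_op()
#     first = block.prepend_op()
#     assert list(block.ops) == [first, second]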

class Program(object):
    @classmethod
    def instance(cls):
        # From https://stackoverflow.com/questions/8212053
        # Make Program a singleton class.
        if not hasattr(cls, '_instance'):
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        assert not hasattr(self.__class__,
                           '_instance'), 'Do not call constructor directly!'
        self.desc = core.ProgramDesc.instance()
        self.blocks = [Block(self, 0)]
        self.current_block_idx = 0

    def global_block(self):
        return self.blocks[0]

    def current_block(self):
        return self.blocks[self.current_block_idx]

    def create_block(self):
        new_block_idx = len(self.blocks)
        self.desc.append_block(self.current_block().desc)
        self.current_block_idx = new_block_idx
        self.blocks.append(Block(self, self.current_block_idx))
        return self.current_block()

    def rollback(self):
        self.current_block_idx = self.current_block().parent_idx

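# A minimal usage sketch for Program (illustrative): the class is a singleton,
# and nested blocks are entered with create_block() and left with rollback().
#
#     prog = Program.instance()
#     assert prog is Program.instance()
#     prog.create_block()
#     prog.rollback()
#     assert prog.current_block() is prog.global_block()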

class Parameter(Variable):
    def __init__(self, block, shape, dtype, **kwargs):
        if shape is None or dtype is None:
            raise ValueError("Parameter must set shape and dtype")
        if len(shape) == 0:
            raise ValueError("Parameter shape cannot be empty")

        for each in shape:
            if each < 0:
                raise ValueError("Parameter shape should not be related with "
                                 "batch-size")

        Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
        self.trainable = kwargs.get('trainable', True)
        self.init_attr = kwargs.get('initialize_attr', {
            'type': 'uniform_random',
            'min': -1.0,
            'max': 1.0
        })

        self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
        self._append_initialize_ops_()

    def _append_initialize_ops_(self):
        attr = copy.deepcopy(self.init_attr)
        op_type = attr.pop('type', None)
        block = self.block
        assert isinstance(block, Block)
        shape = self.shape
        attr['dims'] = shape
        attr['data_type'] = int(self.data_type)
        op = block.prepend_op(
            type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr)
        self.op = op


# program is a global instance.
g_program = Program.instance()
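

if __name__ == '__main__':
    # A minimal smoke-test sketch, assuming the paddle.v2.framework.core
    # bindings used above (ProgramDesc, BlockDesc, VarDesc) are available.
    # The names 'x' and 'fc.w' are illustrative.
    block = g_program.global_block()
    x = block.create_var(name='x', shape=[2, 3], dtype='float32')
    assert x.shape == (2, 3)

    w = block.create_parameter(name='fc.w', shape=[3, 4], dtype='float32')
    # The parameter prepends its initializer, so it is the first op in the block.
    assert w.trainable
    assert block.ops[0] is w.op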