diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 7468279438be24f349d9d1e7bf1daa04651ec746..5b93115b3e9789d8fadcb90d272ac479a5e92337 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -1,5 +1,37 @@ import paddle.v2.framework.core as core +__all__ = ['Block', 'Variable', 'Program'] + + +class Variable(object): + def __init__(self, block, name=None, shape=None, dtype=None, + lod_level=None): + self.block = block + + if name is None: + name = Variable._unique_var_name_() + self.proto = self.block.proto.new_var(name) + + if shape is not None: + self.proto.set_shape(shape) + + if dtype is not None: + # TODO(yuyang18): Convert dtype from numpy.dtype + self.proto.set_data_type(dtype) + + if lod_level is not None: + # TODO(yuyang18): set_lod_level is not defined. + self.proto.set_lod_level(lod_level) + + self.block.vars[name] = self + + # TODO(yuyang18): Get methods + + @staticmethod + def _unique_var_name_(): + uid = core.unique_integer() # unique during whole process. + return "_generated_var_%d" % uid + class Block(object): def __init__(self, program, idx): @@ -16,6 +48,9 @@ class Block(object): def idx(self): return self.proto.id + def create_var(self, *args, **kwargs): + return Variable(self, *args, **kwargs) + class Program(object): @classmethod