From 569616b329db71bfc4739021d55e0a74179732e2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Oct 2017 14:04:36 -0700 Subject: [PATCH] Complete Variable for Python API --- python/paddle/v2/framework/graph.py | 59 ++++++++++++++++--- .../v2/framework/tests/test_variable.py | 20 ++++++- 2 files changed, 71 insertions(+), 8 deletions(-) diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index a7a3ca62c7f..a66e7a9d731 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -12,23 +12,68 @@ class Variable(object): if name is None: name = Variable._unique_var_name_() - self.proto = self.block.proto.new_var(name) + try: + self.proto = self.block.proto.var(name) + is_new_var = False + except core.EnforceNotMet: + self.proto = self.block.proto.new_var(name) + is_new_var = True if shape is not None: - self.proto.set_shape(shape) - + if is_new_var: + self.proto.set_shape(shape) + else: + old_shape = self.shape + shape = tuple(shape) + if shape != old_shape: + raise ValueError( + "Variable {0} has been created before. the previous " + "shape is {1}; the new shape is {2}. They are not " + "matched.".format(self.name, old_shape, shape)) if dtype is not None: if not isinstance(dtype, core.DataType): dtype = Variable._convert_np_dtype_to_dtype_(dtype) - self.proto.set_data_type(dtype) + if is_new_var: + self.proto.set_data_type(dtype) + else: + old_dtype = self.data_type() + if dtype != old_shape: + raise ValueError("Variable {0} has been created before. " + "The previous data type is {1}; the new " + "data type is {2}. They are not " + "matched.".format(self.name, old_dtype, + dtype)) if lod_level is not None: - self.proto.set_lod_level(lod_level) + if is_new_var: + self.proto.set_lod_level(lod_level) + else: + if lod_level != self.lod_level: + raise ValueError("Variable {0} has been created before. " + "The previous lod_level is {1}; the new " + "lod_level is {2}. They are not " + "matched".format(self.name, self.lod_level, + lod_level)) self.block.vars[name] = self self.op = None - # TODO(yuyang18): Get methods + @property + def name(self): + return self.proto.name() + + @property + def shape(self): + # convert to tuple, make it as same as numpy API. + return tuple(self.proto.shape()) + + @property + def data_type(self): + return self.proto.data_type() + + @property + def lod_level(self): + return self.proto.lod_level() @staticmethod def _unique_var_name_(): @@ -79,7 +124,7 @@ class Operator(object): # TODO pass - # TODO: Getters + # TODO: Getters class Block(object): diff --git a/python/paddle/v2/framework/tests/test_variable.py b/python/paddle/v2/framework/tests/test_variable.py index dd23eac0cd1..8ea1083ff65 100644 --- a/python/paddle/v2/framework/tests/test_variable.py +++ b/python/paddle/v2/framework/tests/test_variable.py @@ -1,5 +1,5 @@ import unittest -from paddle.v2.framework.graph import Variable +from paddle.v2.framework.graph import Variable, g_program import paddle.v2.framework.core as core import numpy as np @@ -17,6 +17,24 @@ class TestVariable(unittest.TestCase): self.assertEqual(DT.BOOL, convert("bool")) self.assertRaises(ValueError, lambda: convert("int8")) + def test_var(self): + b = g_program.current_block() + w = b.create_var( + dtype="float64", shape=[784, 100], lod_level=0, name="fc.w") + self.assertEqual(core.DataType.FP64, w.data_type) + self.assertEqual((784, 100), w.shape) + self.assertEqual("fc.w", w.name) + self.assertEqual(0, w.lod_level) + + w = b.create_var(name='fc.w') + self.assertEqual(core.DataType.FP64, w.data_type) + self.assertEqual((784, 100), w.shape) + self.assertEqual("fc.w", w.name) + self.assertEqual(0, w.lod_level) + + self.assertRaises(ValueError, + lambda: b.create_var(name="fc.w", shape=(24, 100))) + if __name__ == '__main__': unittest.main() -- GitLab