提交 569616b3 编写于 作者: Y Yu Yang

Complete Variable for Python API

上级 1e41a675
......@@ -12,23 +12,68 @@ class Variable(object):
if name is None:
name = Variable._unique_var_name_()
self.proto = self.block.proto.new_var(name)
try:
self.proto = self.block.proto.var(name)
is_new_var = False
except core.EnforceNotMet:
self.proto = self.block.proto.new_var(name)
is_new_var = True
if shape is not None:
self.proto.set_shape(shape)
if is_new_var:
self.proto.set_shape(shape)
else:
old_shape = self.shape
shape = tuple(shape)
if shape != old_shape:
raise ValueError(
"Variable {0} has been created before. the previous "
"shape is {1}; the new shape is {2}. They are not "
"matched.".format(self.name, old_shape, shape))
if dtype is not None:
if not isinstance(dtype, core.DataType):
dtype = Variable._convert_np_dtype_to_dtype_(dtype)
self.proto.set_data_type(dtype)
if is_new_var:
self.proto.set_data_type(dtype)
else:
old_dtype = self.data_type()
if dtype != old_shape:
raise ValueError("Variable {0} has been created before. "
"The previous data type is {1}; the new "
"data type is {2}. They are not "
"matched.".format(self.name, old_dtype,
dtype))
if lod_level is not None:
self.proto.set_lod_level(lod_level)
if is_new_var:
self.proto.set_lod_level(lod_level)
else:
if lod_level != self.lod_level:
raise ValueError("Variable {0} has been created before. "
"The previous lod_level is {1}; the new "
"lod_level is {2}. They are not "
"matched".format(self.name, self.lod_level,
lod_level))
self.block.vars[name] = self
self.op = None
# TODO(yuyang18): Get methods
@property
def name(self):
return self.proto.name()
@property
def shape(self):
# convert to tuple, make it as same as numpy API.
return tuple(self.proto.shape())
@property
def data_type(self):
return self.proto.data_type()
@property
def lod_level(self):
return self.proto.lod_level()
@staticmethod
def _unique_var_name_():
......@@ -79,7 +124,7 @@ class Operator(object):
# TODO
pass
# TODO: Getters
# TODO: Getters
class Block(object):
......
import unittest
from paddle.v2.framework.graph import Variable
from paddle.v2.framework.graph import Variable, g_program
import paddle.v2.framework.core as core
import numpy as np
......@@ -17,6 +17,24 @@ class TestVariable(unittest.TestCase):
self.assertEqual(DT.BOOL, convert("bool"))
self.assertRaises(ValueError, lambda: convert("int8"))
def test_var(self):
b = g_program.current_block()
w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
w = b.create_var(name='fc.w')
self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
self.assertRaises(ValueError,
lambda: b.create_var(name="fc.w", shape=(24, 100)))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册