From 1e41a675d4111a826ffac45cbd197054d193d72e Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Oct 2017 13:39:38 -0700 Subject: [PATCH] Convert np.dtype to core.DataType --- python/paddle/v2/framework/graph.py | 25 +++++++++++++++++-- .../v2/framework/tests/test_variable.py | 22 ++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_variable.py diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 6f2a76a9835..a7a3ca62c7f 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -1,5 +1,6 @@ import paddle.v2.framework.core as core import collections +import numpy as np __all__ = ['Block', 'Variable', 'Program', 'Operator'] @@ -17,11 +18,11 @@ class Variable(object): self.proto.set_shape(shape) if dtype is not None: - # TODO(yuyang18): Convert dtype from numpy.dtype + if not isinstance(dtype, core.DataType): + dtype = Variable._convert_np_dtype_to_dtype_(dtype) self.proto.set_data_type(dtype) if lod_level is not None: - # TODO(yuyang18): set_lod_level is not defined. self.proto.set_lod_level(lod_level) self.block.vars[name] = self @@ -34,6 +35,26 @@ class Variable(object): uid = core.unique_integer() # unique during whole process. return "_generated_var_%d" % uid + @staticmethod + def _convert_np_dtype_to_dtype_(np_dtype): + dtype = np.dtype(np_dtype) + if dtype == np.float32: + return core.DataType.FP32 + elif dtype == np.float64: + return core.DataType.FP64 + elif dtype == np.float16: + return core.DataType.FP16 + elif dtype == np.int32: + return core.DataType.INT32 + elif dtype == np.int16: + return core.DataType.INT16 + elif dtype == np.int64: + return core.DataType.INT64 + elif dtype == np.bool: + return core.DataType.BOOL + else: + raise ValueError("Not supported numpy dtype " + str(dtype)) + class Operator(object): def __init__(self, diff --git a/python/paddle/v2/framework/tests/test_variable.py b/python/paddle/v2/framework/tests/test_variable.py new file mode 100644 index 00000000000..dd23eac0cd1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_variable.py @@ -0,0 +1,22 @@ +import unittest +from paddle.v2.framework.graph import Variable +import paddle.v2.framework.core as core +import numpy as np + + +class TestVariable(unittest.TestCase): + def test_np_dtype_convert(self): + DT = core.DataType + convert = Variable._convert_np_dtype_to_dtype_ + self.assertEqual(DT.FP32, convert(np.float32)) + self.assertEqual(DT.FP16, convert("float16")) + self.assertEqual(DT.FP64, convert("float64")) + self.assertEqual(DT.INT32, convert("int32")) + self.assertEqual(DT.INT16, convert("int16")) + self.assertEqual(DT.INT64, convert("int64")) + self.assertEqual(DT.BOOL, convert("bool")) + self.assertRaises(ValueError, lambda: convert("int8")) + + +if __name__ == '__main__': + unittest.main() -- GitLab