From 6915c924a4922b2c92c7584e6e15a6c3ee45d945 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 25 Sep 2017 17:22:29 -0700 Subject: [PATCH] Fix bug --- paddle/pybind/protobuf.cc | 20 ++++++++++++------- .../v2/framework/tests/test_protobuf_descs.py | 4 ++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 5d5782a6f8..3388b5cfdc 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -130,16 +130,15 @@ public: VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); } - void SetDataType(int type_id) { - desc_.mutable_lod_tensor()->set_data_type( - static_cast(type_id)); + void SetDataType(framework::DataType data_type) { + desc_.mutable_lod_tensor()->set_data_type(data_type); } std::vector Shape() { return RepeatedToVector(desc_.lod_tensor().dims()); } - int DataType() { return desc_.lod_tensor().data_type(); } + framework::DataType DataType() { return desc_.lod_tensor().data_type(); } private: VarDesc desc_; @@ -502,14 +501,21 @@ void BindBlockDesc(py::module &m) { } void BindVarDsec(py::module &m) { + py::enum_(m, "DataType", "") + .value("BOOL", DataType::BOOL) + .value("INT16", DataType::INT16) + .value("INT32", DataType::INT32) + .value("INT64", DataType::INT64) + .value("FP16", DataType::FP16) + .value("FP32", DataType::FP32) + .value("FP64", DataType::FP64); + py::class_(m, "VarDesc", "") .def("name", &VarDescBind::Name, py::return_value_policy::reference) .def("set_shape", &VarDescBind::SetShape) .def("set_data_type", &VarDescBind::SetDataType) .def("shape", &VarDescBind::Shape, py::return_value_policy::reference) - .def("data_type", - &VarDescBind::DataType, - py::return_value_policy::reference); + .def("data_type", &VarDescBind::DataType); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 8bb50cbbd1..13d819abf4 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -97,8 +97,8 @@ class TestVarDesc(unittest.TestCase): program_desc = core.ProgramDesc.__create_program_desc__() block = program_desc.block(0) var = block.new_var('my_var') - var.set_data_type(2) - self.assertEqual(2, var.data_type) + var.set_data_type(core.DataType.INT32) + self.assertEqual(core.DataType.INT32, var.data_type()) class TestBlockDesc(unittest.TestCase): -- GitLab