diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 5d5782a6f8abb100b6f37b5187be8af8c15ce6d5..3388b5cfdc0640237a3ef586535ccf1812845c4e 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -130,16 +130,15 @@ public: VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); } - void SetDataType(int type_id) { - desc_.mutable_lod_tensor()->set_data_type( - static_cast(type_id)); + void SetDataType(framework::DataType data_type) { + desc_.mutable_lod_tensor()->set_data_type(data_type); } std::vector Shape() { return RepeatedToVector(desc_.lod_tensor().dims()); } - int DataType() { return desc_.lod_tensor().data_type(); } + framework::DataType DataType() { return desc_.lod_tensor().data_type(); } private: VarDesc desc_; @@ -502,14 +501,21 @@ void BindBlockDesc(py::module &m) { } void BindVarDsec(py::module &m) { + py::enum_(m, "DataType", "") + .value("BOOL", DataType::BOOL) + .value("INT16", DataType::INT16) + .value("INT32", DataType::INT32) + .value("INT64", DataType::INT64) + .value("FP16", DataType::FP16) + .value("FP32", DataType::FP32) + .value("FP64", DataType::FP64); + py::class_(m, "VarDesc", "") .def("name", &VarDescBind::Name, py::return_value_policy::reference) .def("set_shape", &VarDescBind::SetShape) .def("set_data_type", &VarDescBind::SetDataType) .def("shape", &VarDescBind::Shape, py::return_value_policy::reference) - .def("data_type", - &VarDescBind::DataType, - py::return_value_policy::reference); + .def("data_type", &VarDescBind::DataType); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 8bb50cbbd14eaf1eecd62a0d6b4e4cf3ee5d0184..13d819abf418f180f9e661d6c8cc03c6c70dd426 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -97,8 +97,8 @@ class TestVarDesc(unittest.TestCase): program_desc = core.ProgramDesc.__create_program_desc__() block = program_desc.block(0) var = block.new_var('my_var') - var.set_data_type(2) - self.assertEqual(2, var.data_type) + var.set_data_type(core.DataType.INT32) + self.assertEqual(core.DataType.INT32, var.data_type()) class TestBlockDesc(unittest.TestCase):