提交 6915c924 编写于 作者: F fengjiayi

Fix bug

上级 4fb106af
...@@ -130,16 +130,15 @@ public: ...@@ -130,16 +130,15 @@ public:
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
} }
void SetDataType(int type_id) { void SetDataType(framework::DataType data_type) {
desc_.mutable_lod_tensor()->set_data_type( desc_.mutable_lod_tensor()->set_data_type(data_type);
static_cast<enum DataType>(type_id));
} }
std::vector<int64_t> Shape() { std::vector<int64_t> Shape() {
return RepeatedToVector(desc_.lod_tensor().dims()); return RepeatedToVector(desc_.lod_tensor().dims());
} }
int DataType() { return desc_.lod_tensor().data_type(); } framework::DataType DataType() { return desc_.lod_tensor().data_type(); }
private: private:
VarDesc desc_; VarDesc desc_;
...@@ -502,14 +501,21 @@ void BindBlockDesc(py::module &m) { ...@@ -502,14 +501,21 @@ void BindBlockDesc(py::module &m) {
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
py::enum_<framework::DataType>(m, "DataType", "")
.value("BOOL", DataType::BOOL)
.value("INT16", DataType::INT16)
.value("INT32", DataType::INT32)
.value("INT64", DataType::INT64)
.value("FP16", DataType::FP16)
.value("FP32", DataType::FP32)
.value("FP64", DataType::FP64);
py::class_<VarDescBind>(m, "VarDesc", "") py::class_<VarDescBind>(m, "VarDesc", "")
.def("name", &VarDescBind::Name, py::return_value_policy::reference) .def("name", &VarDescBind::Name, py::return_value_policy::reference)
.def("set_shape", &VarDescBind::SetShape) .def("set_shape", &VarDescBind::SetShape)
.def("set_data_type", &VarDescBind::SetDataType) .def("set_data_type", &VarDescBind::SetDataType)
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference) .def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type", .def("data_type", &VarDescBind::DataType);
&VarDescBind::DataType,
py::return_value_policy::reference);
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
......
...@@ -97,8 +97,8 @@ class TestVarDesc(unittest.TestCase): ...@@ -97,8 +97,8 @@ class TestVarDesc(unittest.TestCase):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0) block = program_desc.block(0)
var = block.new_var('my_var') var = block.new_var('my_var')
var.set_data_type(2) var.set_data_type(core.DataType.INT32)
self.assertEqual(2, var.data_type) self.assertEqual(core.DataType.INT32, var.data_type())
class TestBlockDesc(unittest.TestCase): class TestBlockDesc(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册