提交 6915c924 编写于 作者: F fengjiayi

Fix bug

上级 4fb106af
......@@ -130,16 +130,15 @@ public:
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
}
void SetDataType(int type_id) {
desc_.mutable_lod_tensor()->set_data_type(
static_cast<enum DataType>(type_id));
void SetDataType(framework::DataType data_type) {
desc_.mutable_lod_tensor()->set_data_type(data_type);
}
std::vector<int64_t> Shape() {
return RepeatedToVector(desc_.lod_tensor().dims());
}
int DataType() { return desc_.lod_tensor().data_type(); }
framework::DataType DataType() { return desc_.lod_tensor().data_type(); }
private:
VarDesc desc_;
......@@ -502,14 +501,21 @@ void BindBlockDesc(py::module &m) {
}
void BindVarDsec(py::module &m) {
py::enum_<framework::DataType>(m, "DataType", "")
.value("BOOL", DataType::BOOL)
.value("INT16", DataType::INT16)
.value("INT32", DataType::INT32)
.value("INT64", DataType::INT64)
.value("FP16", DataType::FP16)
.value("FP32", DataType::FP32)
.value("FP64", DataType::FP64);
py::class_<VarDescBind>(m, "VarDesc", "")
.def("name", &VarDescBind::Name, py::return_value_policy::reference)
.def("set_shape", &VarDescBind::SetShape)
.def("set_data_type", &VarDescBind::SetDataType)
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type",
&VarDescBind::DataType,
py::return_value_policy::reference);
.def("data_type", &VarDescBind::DataType);
}
void BindOpDesc(py::module &m) {
......
......@@ -97,8 +97,8 @@ class TestVarDesc(unittest.TestCase):
program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0)
var = block.new_var('my_var')
var.set_data_type(2)
self.assertEqual(2, var.data_type)
var.set_data_type(core.DataType.INT32)
self.assertEqual(core.DataType.INT32, var.data_type())
class TestBlockDesc(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册