diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 126c2ce1c7d0efe640b8095b4861c3a96cff42c2..de6db60730b2434379711deb52721ed826246af7 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -50,12 +50,12 @@ public: VarDesc *Proto() { return &desc_; } - void SetShape(const vector &dims) { + void SetShape(const std::vector &dims) { VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); } void SetDataType(int type_id) { - desc_.mutable_lod_tensor()->set_data_type(const_cast(type_id)); + desc_.mutable_lod_tensor()->set_data_type(static_cast(type_id)); } std::vector Shape() { @@ -86,7 +86,8 @@ public: int32_t Parent() const { return desc_->parent_idx(); } - VarDescBind *NewVar(const std::string &name) { + VarDescBind *NewVar(py::bytes name_bytes) { + std::string name = name_bytes; need_update_ = true; auto it = vars_.find(name); PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name); @@ -224,16 +225,15 @@ void BindBlockDesc(py::module &m) { &BlockDescBind::AppendOp, py::return_value_policy::reference) .def("new_var", - [](BlockDesc &self) { return self.add_vars(); }, + &BlockDescBind::NewVar, py::return_value_policy::reference); } void BindVarDsec(py::module &m) { - py::class_(m, "VarDesc", "") - .def(py::init<>()) - .def("set_shape", VarDescBind::SetShape) - .def("set_data_type", VarDescBind::SetDataType) - .def("shape", VarDescBind::Shape); + py::class_(m, "VarDesc", "") + .def("set_shape", &VarDescBind::SetShape) + .def("set_data_type", &VarDescBind::SetDataType) + .def("shape", &VarDescBind::Shape); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index fbe1f7152bb12dd7ddc69894a9ce2a892784d8f0..f1074f6bb5054edd2a1a9f0b1846aaff70917965 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -33,7 +33,7 @@ class TestVarDesc(unittest.TestCase): def test_shape(self): program_desc = core.ProgramDesc.instance() block = program_desc.root_block() - var = block.new_var() + var = block.new_var('my_var') src_shape = [3, 2, 10, 8] var.set_shape(src_shape) res_shape = var.shape()