diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 673e0ab80b75e7584aa72beeafee32089415655b..5d5782a6f8abb100b6f37b5187be8af8c15ce6d5 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -99,7 +99,7 @@ template inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { repeated_field->Reserve(vec.size()); - for (auto &elem : vec) { + for (const auto &elem : vec) { *repeated_field->Add() = elem; } } @@ -124,18 +124,23 @@ public: VarDesc *Proto() { return &desc_; } + py::bytes Name() { return desc_.name(); } + void SetShape(const std::vector &dims) { VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); } void SetDataType(int type_id) { - desc_.mutable_lod_tensor()->set_data_type(static_cast(type_id)); + desc_.mutable_lod_tensor()->set_data_type( + static_cast(type_id)); } std::vector Shape() { return RepeatedToVector(desc_.lod_tensor().dims()); } + int DataType() { return desc_.lod_tensor().data_type(); } + private: VarDesc desc_; }; @@ -322,6 +327,22 @@ public: return var; } + VarDescBind *Var(py::bytes name_bytes) const { + std::string name = name_bytes; + auto it = vars_.find(name); + PADDLE_ENFORCE( + it != vars_.end(), "Can not find variable %s in current block.", name); + return it->second.get(); + } + + std::vector AllVars() const { + std::vector res; + for (const auto &p : vars_) { + res.push_back(p.second.get()); + } + return res; + } + BlockDescBind *ParentBlock() const; OpDescBind *AppendOp() { @@ -336,6 +357,14 @@ public: return ops_.front().get(); } + std::vector AllOps() const { + std::vector res; + for (const auto &op : ops_) { + res.push_back(op.get()); + } + return res; + } + void Sync() { if (need_update_) { auto &op_field = *this->desc_->mutable_ops(); @@ -461,16 +490,26 @@ void BindBlockDesc(py::module &m) { .def("prepend_op", &BlockDescBind::PrependOp, py::return_value_policy::reference) - .def("new_var", - &BlockDescBind::NewVar, + .def( + "new_var", &BlockDescBind::NewVar, py::return_value_policy::reference) + .def("var", &BlockDescBind::Var, py::return_value_policy::reference) + .def("all_vars", + &BlockDescBind::AllVars, + py::return_value_policy::reference) + .def("all_ops", + &BlockDescBind::AllOps, py::return_value_policy::reference); } void BindVarDsec(py::module &m) { py::class_(m, "VarDesc", "") + .def("name", &VarDescBind::Name, py::return_value_policy::reference) .def("set_shape", &VarDescBind::SetShape) .def("set_data_type", &VarDescBind::SetDataType) - .def("shape", &VarDescBind::Shape); + .def("shape", &VarDescBind::Shape, py::return_value_policy::reference) + .def("data_type", + &VarDescBind::DataType, + py::return_value_policy::reference); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index aa9a0b33ac6ad7e37202a79fb45baddf3d889338..0dde9729a9b94fea0a27375a6bd7ce82c63a9190 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -57,7 +57,7 @@ class TestOpDesc(unittest.TestCase): class TestProgramDesc(unittest.TestCase): def test_instance(self): - program_desc = core.ProgramDesc.instance() + program_desc = core.ProgramDesc.__create_program_desc__() self.assertIsNotNone(program_desc) del program_desc program_desc = core.ProgramDesc.instance() @@ -84,7 +84,7 @@ class TestProgramDesc(unittest.TestCase): class TestVarDesc(unittest.TestCase): def test_shape(self): - program_desc = core.ProgramDesc.instance() + program_desc = core.ProgramDesc.__create_program_desc__() block = program_desc.block(0) var = block.new_var('my_var') src_shape = [3, 2, 10, 8] @@ -92,6 +92,39 @@ class TestVarDesc(unittest.TestCase): res_shape = var.shape() self.assertEqual(src_shape, res_shape) + def test_data_type(self): + program_desc = core.ProgramDesc.__create_program_desc__() + block = program_desc.block(0) + var = block.new_var('my_var') + var.set_data_type(2) + self.assertEqual(2, var.data_type) + + +class TestBlockDesc(unittest.TestCase): + def test_add_var(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + var1 = block.new_var("var1") + var2 = block.new_var("var2") + var3 = block.new_var("var3") + all_vars = block.all_vars() + self.assertEqual(set(all_vars), set([var1, var2, var3])) + var2_re = block.var("var2") + self.assertEqual(var2_re, var2) + + def test_add_op(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + op1 = block.append_op() + op2 = block.append_op() + op0 = block.prepend_op() + all_ops = block.all_ops() + self.assertEqual(all_ops, [op0, op1, op2]) + if __name__ == '__main__': unittest.main()