diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index 468423e0e8e7b8c9ebc14b7568c9c3bd21645ea7..873969b2a884f6d9e133fe87bf72725c36ce8b98 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include @@ -96,6 +97,8 @@ class BlockDesc { */ void RemoveOp(size_t s, size_t e); + void RemoveVar(const std::string &name) { vars_.erase(name); } + std::vector AllOps() const; size_t OpSize() const { return ops_.size(); } diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 45a64f43846e79c27295e52c59dca6bdfaa120a3..985984983a2239f6961bf519bae27fcbb9e7d6d3 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -15,6 +15,8 @@ limitations under the License. */ #include "paddle/fluid/pybind/protobuf.h" #include #include +#include +#include #include "paddle/fluid/framework/backward.h" #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" @@ -98,7 +100,7 @@ namespace pybind { using namespace paddle::framework; // NOLINT template -static py::bytes SerializeMessage(T &self) { +static py::bytes SerializeMessage(T &self) { // NOLINT // Check IsInitialized in Python std::string retv; PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv), @@ -107,7 +109,7 @@ static py::bytes SerializeMessage(T &self) { } // Bind Methods -void BindProgramDesc(py::module &m) { +void BindProgramDesc(py::module &m) { // NOLINT py::class_(m, "ProgramDesc", "") .def(py::init<>()) .def("__init__", @@ -151,7 +153,7 @@ void BindProgramDesc(py::module &m) { }); } -void BindBlockDesc(py::module &m) { +void BindBlockDesc(py::module &m) { // NOLINT py::class_(m, "BlockDesc", "") .def_property_readonly("id", &BlockDesc::ID) .def_property_readonly("parent", &BlockDesc::Parent) @@ -200,13 +202,19 @@ void BindBlockDesc(py::module &m) { return self.FindVarRecursive(name); }, py::return_value_policy::reference) + .def("remove_var", + [](BlockDesc &self, py::bytes byte_name) { + std::string name = byte_name; + return self.RemoveVar(name); + }, + py::return_value_policy::reference) .def("all_vars", &BlockDesc::AllVars, py::return_value_policy::reference) .def("op_size", &BlockDesc::OpSize) .def("op", &BlockDesc::Op, py::return_value_policy::reference) .def("serialize_to_string", SerializeMessage); } -void BindVarDsec(py::module &m) { +void BindVarDsec(py::module &m) { // NOLINT py::class_ var_desc(m, "VarDesc", ""); var_desc .def("name", @@ -257,7 +265,7 @@ void BindVarDsec(py::module &m) { .value("RAW", proto::VarType::RAW); } -void BindOpDesc(py::module &m) { +void BindOpDesc(py::module &m) { // NOLINT py::enum_(m, "AttrType", "") .value("INT", proto::AttrType::INT) .value("INTS", proto::AttrType::INTS) diff --git a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py index e4cf4a8bce8a53c0348130716dc18c61ac9a5913..f98a8bbc68a4315df3ae761f2e52b8f11cb620c6 100644 --- a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py +++ b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py @@ -19,9 +19,9 @@ from paddle.fluid.framework import Program class TestOpDesc(unittest.TestCase): def test_op_desc(self): - prog = core.ProgramDesc() - self.assertIsNotNone(prog) - block = prog.block(0) + program_desc = core.ProgramDesc() + self.assertIsNotNone(program_desc) + block = program_desc.block(0) self.assertIsNotNone(block) op = block.append_op() self.assertIsNotNone(op) @@ -67,7 +67,7 @@ class TestOpDesc(unittest.TestCase): self.assertEqual(8, len(op.attr_names())) - op.set_block_attr("block_attr", prog.block(0)) + op.set_block_attr("block_attr", program_desc.block(0)) self.assertEqual(0, op.block_attr("block_attr")) mul_op = block.append_op() @@ -88,20 +88,20 @@ class TestProgramDesc(unittest.TestCase): del program_desc def test_append_block(self): - prog_desc = core.ProgramDesc() - self.assertIsNotNone(prog_desc) - block_root = prog_desc.block(0) + program_desc = core.ProgramDesc() + self.assertIsNotNone(program_desc) + block_root = program_desc.block(0) self.assertIsNotNone(block_root) self.assertEqual(block_root.id, 0) - block1 = prog_desc.append_block(block_root) - block2 = prog_desc.append_block(block1) + block1 = program_desc.append_block(block_root) + block2 = program_desc.append_block(block1) self.assertIsNotNone(block1) self.assertEqual(block1.id, block2.parent) self.assertEqual(block_root.id, block1.parent) - block3 = prog_desc.append_block(block_root) + block3 = program_desc.append_block(block_root) self.assertEqual(block3.parent, block_root.id) - self.assertEqual(prog_desc.block(1).id, 1) - self.assertEqual(4, prog_desc.num_blocks()) + self.assertEqual(program_desc.block(1).id, 1) + self.assertEqual(4, program_desc.num_blocks()) class TestVarDesc(unittest.TestCase): @@ -162,9 +162,9 @@ class TestVarDesc(unittest.TestCase): class TestBlockDesc(unittest.TestCase): def test_add_var(self): - prog = core.ProgramDesc() - self.assertIsNotNone(prog) - block = prog.block(0) + program_desc = core.ProgramDesc() + self.assertIsNotNone(program_desc) + block = program_desc.block(0) self.assertIsNotNone(block) var1 = block.var("var1") var2 = block.var("var2") @@ -175,9 +175,9 @@ class TestBlockDesc(unittest.TestCase): self.assertEqual(var2_re, var2) def test_add_op(self): - prog = core.ProgramDesc() - self.assertIsNotNone(prog) - block = prog.block(0) + program_desc = core.ProgramDesc() + self.assertIsNotNone(program_desc) + block = program_desc.block(0) self.assertIsNotNone(block) op1 = block.append_op() op2 = block.append_op() @@ -189,9 +189,9 @@ class TestBlockDesc(unittest.TestCase): def test_remove_op(self): program = Program() - prog = program.desc - self.assertIsNotNone(prog) - block = prog.block(0) + program_desc = program.desc + self.assertIsNotNone(program_desc) + block = program_desc.block(0) self.assertIsNotNone(block) op0 = block.append_op()