提交 09b53c08 编写于 作者: L Luo Tao

add remove_var from c++ end

上级 95710456
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <deque> #include <deque>
#include <memory> #include <memory>
#include <set> #include <set>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -96,6 +97,8 @@ class BlockDesc { ...@@ -96,6 +97,8 @@ class BlockDesc {
*/ */
void RemoveOp(size_t s, size_t e); void RemoveOp(size_t s, size_t e);
void RemoveVar(const std::string &name) { vars_.erase(name); }
std::vector<OpDesc *> AllOps() const; std::vector<OpDesc *> AllOps() const;
size_t OpSize() const { return ops_.size(); } size_t OpSize() const { return ops_.size(); }
......
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#include "paddle/fluid/pybind/protobuf.h" #include "paddle/fluid/pybind/protobuf.h"
#include <deque> #include <deque>
#include <iostream> #include <iostream>
#include <string>
#include <tuple>
#include "paddle/fluid/framework/backward.h" #include "paddle/fluid/framework/backward.h"
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
...@@ -98,7 +100,7 @@ namespace pybind { ...@@ -98,7 +100,7 @@ namespace pybind {
using namespace paddle::framework; // NOLINT using namespace paddle::framework; // NOLINT
template <typename T> template <typename T>
static py::bytes SerializeMessage(T &self) { static py::bytes SerializeMessage(T &self) { // NOLINT
// Check IsInitialized in Python // Check IsInitialized in Python
std::string retv; std::string retv;
PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv), PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv),
...@@ -107,7 +109,7 @@ static py::bytes SerializeMessage(T &self) { ...@@ -107,7 +109,7 @@ static py::bytes SerializeMessage(T &self) {
} }
// Bind Methods // Bind Methods
void BindProgramDesc(py::module &m) { void BindProgramDesc(py::module &m) { // NOLINT
py::class_<ProgramDesc>(m, "ProgramDesc", "") py::class_<ProgramDesc>(m, "ProgramDesc", "")
.def(py::init<>()) .def(py::init<>())
.def("__init__", .def("__init__",
...@@ -151,7 +153,7 @@ void BindProgramDesc(py::module &m) { ...@@ -151,7 +153,7 @@ void BindProgramDesc(py::module &m) {
}); });
} }
void BindBlockDesc(py::module &m) { void BindBlockDesc(py::module &m) { // NOLINT
py::class_<BlockDesc>(m, "BlockDesc", "") py::class_<BlockDesc>(m, "BlockDesc", "")
.def_property_readonly("id", &BlockDesc::ID) .def_property_readonly("id", &BlockDesc::ID)
.def_property_readonly("parent", &BlockDesc::Parent) .def_property_readonly("parent", &BlockDesc::Parent)
...@@ -200,13 +202,19 @@ void BindBlockDesc(py::module &m) { ...@@ -200,13 +202,19 @@ void BindBlockDesc(py::module &m) {
return self.FindVarRecursive(name); return self.FindVarRecursive(name);
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("remove_var",
[](BlockDesc &self, py::bytes byte_name) {
std::string name = byte_name;
return self.RemoveVar(name);
},
py::return_value_policy::reference)
.def("all_vars", &BlockDesc::AllVars, py::return_value_policy::reference) .def("all_vars", &BlockDesc::AllVars, py::return_value_policy::reference)
.def("op_size", &BlockDesc::OpSize) .def("op_size", &BlockDesc::OpSize)
.def("op", &BlockDesc::Op, py::return_value_policy::reference) .def("op", &BlockDesc::Op, py::return_value_policy::reference)
.def("serialize_to_string", SerializeMessage<BlockDesc>); .def("serialize_to_string", SerializeMessage<BlockDesc>);
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) { // NOLINT
py::class_<VarDesc> var_desc(m, "VarDesc", ""); py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc var_desc
.def("name", .def("name",
...@@ -257,7 +265,7 @@ void BindVarDsec(py::module &m) { ...@@ -257,7 +265,7 @@ void BindVarDsec(py::module &m) {
.value("RAW", proto::VarType::RAW); .value("RAW", proto::VarType::RAW);
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) { // NOLINT
py::enum_<proto::AttrType>(m, "AttrType", "") py::enum_<proto::AttrType>(m, "AttrType", "")
.value("INT", proto::AttrType::INT) .value("INT", proto::AttrType::INT)
.value("INTS", proto::AttrType::INTS) .value("INTS", proto::AttrType::INTS)
......
...@@ -19,9 +19,9 @@ from paddle.fluid.framework import Program ...@@ -19,9 +19,9 @@ from paddle.fluid.framework import Program
class TestOpDesc(unittest.TestCase): class TestOpDesc(unittest.TestCase):
def test_op_desc(self): def test_op_desc(self):
prog = core.ProgramDesc() program_desc = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(program_desc)
block = prog.block(0) block = program_desc.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
op = block.append_op() op = block.append_op()
self.assertIsNotNone(op) self.assertIsNotNone(op)
...@@ -67,7 +67,7 @@ class TestOpDesc(unittest.TestCase): ...@@ -67,7 +67,7 @@ class TestOpDesc(unittest.TestCase):
self.assertEqual(8, len(op.attr_names())) self.assertEqual(8, len(op.attr_names()))
op.set_block_attr("block_attr", prog.block(0)) op.set_block_attr("block_attr", program_desc.block(0))
self.assertEqual(0, op.block_attr("block_attr")) self.assertEqual(0, op.block_attr("block_attr"))
mul_op = block.append_op() mul_op = block.append_op()
...@@ -88,20 +88,20 @@ class TestProgramDesc(unittest.TestCase): ...@@ -88,20 +88,20 @@ class TestProgramDesc(unittest.TestCase):
del program_desc del program_desc
def test_append_block(self): def test_append_block(self):
prog_desc = core.ProgramDesc() program_desc = core.ProgramDesc()
self.assertIsNotNone(prog_desc) self.assertIsNotNone(program_desc)
block_root = prog_desc.block(0) block_root = program_desc.block(0)
self.assertIsNotNone(block_root) self.assertIsNotNone(block_root)
self.assertEqual(block_root.id, 0) self.assertEqual(block_root.id, 0)
block1 = prog_desc.append_block(block_root) block1 = program_desc.append_block(block_root)
block2 = prog_desc.append_block(block1) block2 = program_desc.append_block(block1)
self.assertIsNotNone(block1) self.assertIsNotNone(block1)
self.assertEqual(block1.id, block2.parent) self.assertEqual(block1.id, block2.parent)
self.assertEqual(block_root.id, block1.parent) self.assertEqual(block_root.id, block1.parent)
block3 = prog_desc.append_block(block_root) block3 = program_desc.append_block(block_root)
self.assertEqual(block3.parent, block_root.id) self.assertEqual(block3.parent, block_root.id)
self.assertEqual(prog_desc.block(1).id, 1) self.assertEqual(program_desc.block(1).id, 1)
self.assertEqual(4, prog_desc.num_blocks()) self.assertEqual(4, program_desc.num_blocks())
class TestVarDesc(unittest.TestCase): class TestVarDesc(unittest.TestCase):
...@@ -162,9 +162,9 @@ class TestVarDesc(unittest.TestCase): ...@@ -162,9 +162,9 @@ class TestVarDesc(unittest.TestCase):
class TestBlockDesc(unittest.TestCase): class TestBlockDesc(unittest.TestCase):
def test_add_var(self): def test_add_var(self):
prog = core.ProgramDesc() program_desc = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(program_desc)
block = prog.block(0) block = program_desc.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
var1 = block.var("var1") var1 = block.var("var1")
var2 = block.var("var2") var2 = block.var("var2")
...@@ -175,9 +175,9 @@ class TestBlockDesc(unittest.TestCase): ...@@ -175,9 +175,9 @@ class TestBlockDesc(unittest.TestCase):
self.assertEqual(var2_re, var2) self.assertEqual(var2_re, var2)
def test_add_op(self): def test_add_op(self):
prog = core.ProgramDesc() program_desc = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(program_desc)
block = prog.block(0) block = program_desc.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
op1 = block.append_op() op1 = block.append_op()
op2 = block.append_op() op2 = block.append_op()
...@@ -189,9 +189,9 @@ class TestBlockDesc(unittest.TestCase): ...@@ -189,9 +189,9 @@ class TestBlockDesc(unittest.TestCase):
def test_remove_op(self): def test_remove_op(self):
program = Program() program = Program()
prog = program.desc program_desc = program.desc
self.assertIsNotNone(prog) self.assertIsNotNone(program_desc)
block = prog.block(0) block = program_desc.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
op0 = block.append_op() op0 = block.append_op()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册