提交 09b53c08 编写于 作者: L Luo Tao

add remove_var from c++ end

上级 95710456
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <deque>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -96,6 +97,8 @@ class BlockDesc {
*/
void RemoveOp(size_t s, size_t e);
void RemoveVar(const std::string &name) { vars_.erase(name); }
std::vector<OpDesc *> AllOps() const;
size_t OpSize() const { return ops_.size(); }
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#include "paddle/fluid/pybind/protobuf.h"
#include <deque>
#include <iostream>
#include <string>
#include <tuple>
#include "paddle/fluid/framework/backward.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
......@@ -98,7 +100,7 @@ namespace pybind {
using namespace paddle::framework; // NOLINT
template <typename T>
static py::bytes SerializeMessage(T &self) {
static py::bytes SerializeMessage(T &self) { // NOLINT
// Check IsInitialized in Python
std::string retv;
PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv),
......@@ -107,7 +109,7 @@ static py::bytes SerializeMessage(T &self) {
}
// Bind Methods
void BindProgramDesc(py::module &m) {
void BindProgramDesc(py::module &m) { // NOLINT
py::class_<ProgramDesc>(m, "ProgramDesc", "")
.def(py::init<>())
.def("__init__",
......@@ -151,7 +153,7 @@ void BindProgramDesc(py::module &m) {
});
}
void BindBlockDesc(py::module &m) {
void BindBlockDesc(py::module &m) { // NOLINT
py::class_<BlockDesc>(m, "BlockDesc", "")
.def_property_readonly("id", &BlockDesc::ID)
.def_property_readonly("parent", &BlockDesc::Parent)
......@@ -200,13 +202,19 @@ void BindBlockDesc(py::module &m) {
return self.FindVarRecursive(name);
},
py::return_value_policy::reference)
.def("remove_var",
[](BlockDesc &self, py::bytes byte_name) {
std::string name = byte_name;
return self.RemoveVar(name);
},
py::return_value_policy::reference)
.def("all_vars", &BlockDesc::AllVars, py::return_value_policy::reference)
.def("op_size", &BlockDesc::OpSize)
.def("op", &BlockDesc::Op, py::return_value_policy::reference)
.def("serialize_to_string", SerializeMessage<BlockDesc>);
}
void BindVarDsec(py::module &m) {
void BindVarDsec(py::module &m) { // NOLINT
py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc
.def("name",
......@@ -257,7 +265,7 @@ void BindVarDsec(py::module &m) {
.value("RAW", proto::VarType::RAW);
}
void BindOpDesc(py::module &m) {
void BindOpDesc(py::module &m) { // NOLINT
py::enum_<proto::AttrType>(m, "AttrType", "")
.value("INT", proto::AttrType::INT)
.value("INTS", proto::AttrType::INTS)
......
......@@ -19,9 +19,9 @@ from paddle.fluid.framework import Program
class TestOpDesc(unittest.TestCase):
def test_op_desc(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc)
block = program_desc.block(0)
self.assertIsNotNone(block)
op = block.append_op()
self.assertIsNotNone(op)
......@@ -67,7 +67,7 @@ class TestOpDesc(unittest.TestCase):
self.assertEqual(8, len(op.attr_names()))
op.set_block_attr("block_attr", prog.block(0))
op.set_block_attr("block_attr", program_desc.block(0))
self.assertEqual(0, op.block_attr("block_attr"))
mul_op = block.append_op()
......@@ -88,20 +88,20 @@ class TestProgramDesc(unittest.TestCase):
del program_desc
def test_append_block(self):
prog_desc = core.ProgramDesc()
self.assertIsNotNone(prog_desc)
block_root = prog_desc.block(0)
program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc)
block_root = program_desc.block(0)
self.assertIsNotNone(block_root)
self.assertEqual(block_root.id, 0)
block1 = prog_desc.append_block(block_root)
block2 = prog_desc.append_block(block1)
block1 = program_desc.append_block(block_root)
block2 = program_desc.append_block(block1)
self.assertIsNotNone(block1)
self.assertEqual(block1.id, block2.parent)
self.assertEqual(block_root.id, block1.parent)
block3 = prog_desc.append_block(block_root)
block3 = program_desc.append_block(block_root)
self.assertEqual(block3.parent, block_root.id)
self.assertEqual(prog_desc.block(1).id, 1)
self.assertEqual(4, prog_desc.num_blocks())
self.assertEqual(program_desc.block(1).id, 1)
self.assertEqual(4, program_desc.num_blocks())
class TestVarDesc(unittest.TestCase):
......@@ -162,9 +162,9 @@ class TestVarDesc(unittest.TestCase):
class TestBlockDesc(unittest.TestCase):
def test_add_var(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc)
block = program_desc.block(0)
self.assertIsNotNone(block)
var1 = block.var("var1")
var2 = block.var("var2")
......@@ -175,9 +175,9 @@ class TestBlockDesc(unittest.TestCase):
self.assertEqual(var2_re, var2)
def test_add_op(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc)
block = program_desc.block(0)
self.assertIsNotNone(block)
op1 = block.append_op()
op2 = block.append_op()
......@@ -189,9 +189,9 @@ class TestBlockDesc(unittest.TestCase):
def test_remove_op(self):
program = Program()
prog = program.desc
self.assertIsNotNone(prog)
block = prog.block(0)
program_desc = program.desc
self.assertIsNotNone(program_desc)
block = program_desc.block(0)
self.assertIsNotNone(block)
op0 = block.append_op()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册