提交 bddb4060 编写于 作者: Y Yu Yang

Buggy code

上级 027fc62c
...@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/pybind/protobuf.h" #include "paddle/pybind/protobuf.h"
#include <deque>
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
using namespace paddle::framework; // NOLINT
template <typename T> template <typename T>
inline std::vector<T> RepeatedToVector( inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T> &repeated_field) { const google::protobuf::RepeatedField<T> &repeated_field) {
...@@ -36,45 +39,154 @@ inline void VectorToRepeated(const std::vector<T> &vec, ...@@ -36,45 +39,154 @@ inline void VectorToRepeated(const std::vector<T> &vec,
} }
} }
class ProgramDescBind;
class OpDescBind;
class BlockDescBind;
class OpDescBind {
public:
explicit OpDescBind(BlockDescBind *block) : block_(block) {}
operator OpDesc *() { return &op_desc_; }
private:
BlockDescBind *block_;
OpDesc op_desc_;
};
class BlockDescBind {
public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {}
~BlockDescBind() {
std::cerr << "dtor " << this << "," << desc_ << std::endl;
}
int32_t id() const {
std::cerr << "desc ptr " << desc_ << std::endl;
return desc_->idx();
}
int32_t Parent() const { return desc_->parent_idx(); }
OpDescBind *AppendOp() {
need_update_ = true;
ops_.emplace_back(this);
return &ops_.back();
}
void Sync() {
if (need_update_) {
auto &op_field = *this->desc_->mutable_ops();
op_field.Clear();
op_field.Reserve(static_cast<int>(ops_.size()));
for (auto &op_desc : ops_) {
op_field.AddAllocated(op_desc);
}
}
}
private:
ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own
bool need_update_;
std::deque<OpDescBind> ops_;
};
using ProgDescMap =
std::unordered_map<ProgramDesc *, std::unique_ptr<ProgramDescBind>>;
static ProgDescMap *g_bind_map = nullptr;
class ProgramDescBind {
public:
static ProgramDescBind &Instance(ProgramDesc *prog) {
if (g_bind_map == nullptr) {
g_bind_map = new ProgDescMap();
}
auto &map = *g_bind_map;
auto &ptr = map[prog];
if (ptr == nullptr) {
ptr.reset(new ProgramDescBind(prog));
}
return *ptr;
}
BlockDescBind *AppendBlock(BlockDescBind *parent) {
auto *b = prog_->add_blocks();
std::cerr << "block ptr " << b << std::endl;
std::cerr << "pass ptr " << parent << std::endl;
b->set_parent_idx(parent->id());
b->set_idx(prog_->blocks_size() - 1);
blocks_.emplace_back(this, b);
return &blocks_.back();
}
BlockDescBind *Root() { return &blocks_.front(); }
BlockDescBind *Block(size_t idx) { return &blocks_[idx]; }
std::string DebugString() { return Proto()->DebugString(); }
size_t Size() const { return blocks_.size(); }
ProgramDesc *Proto() {
for (auto &block : blocks_) {
block.Sync();
}
return prog_;
}
private:
explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) {
for (auto &block : *prog->mutable_blocks()) {
blocks_.emplace_back(this, &block);
}
}
// Not owned
ProgramDesc *prog_;
std::vector<BlockDescBind> blocks_;
};
void BindProgramDesc(py::module &m) { void BindProgramDesc(py::module &m) {
using namespace paddle::framework; // NOLINT py::class_<ProgramDescBind>(m, "ProgramDesc", "")
py::class_<ProgramDesc>(m, "ProgramDesc", "")
.def_static("instance", .def_static("instance",
[] { return &GetProgramDesc(); }, []() -> ProgramDescBind * {
return &ProgramDescBind::Instance(&GetProgramDesc());
},
py::return_value_policy::reference) py::return_value_policy::reference)
.def_static("__create_program_desc__", .def_static("__create_program_desc__",
[] { []() -> ProgramDescBind * {
// Only used for unit-test // Only used for unit-test
auto *prog_desc = new ProgramDesc; auto *prog_desc = new ProgramDesc;
auto *block = prog_desc->mutable_blocks()->Add(); auto *block = prog_desc->mutable_blocks()->Add();
block->set_idx(0); block->set_idx(0);
block->set_parent_idx(-1); block->set_parent_idx(-1);
return prog_desc; return &ProgramDescBind::Instance(prog_desc);
}) },
py::return_value_policy::reference)
.def("append_block", .def("append_block",
[](ProgramDesc &self, BlockDesc &parent) { &ProgramDescBind::AppendBlock,
auto desc = self.add_blocks();
desc->set_idx(self.mutable_blocks()->size() - 1);
desc->set_parent_idx(parent.idx());
return desc;
},
py::return_value_policy::reference) py::return_value_policy::reference)
.def("root_block", .def("root_block",
[](ProgramDesc &self) { return self.mutable_blocks()->Mutable(0); }, &ProgramDescBind::Root,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("block", .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
[](ProgramDesc &self, int id) { return self.blocks(id); }, .def("__str__", &ProgramDescBind::DebugString)
py::return_value_policy::reference) .def("num_blocks", &ProgramDescBind::Size);
.def("__str__", [](ProgramDesc &self) { return self.DebugString(); });
} }
void BindBlockDesc(py::module &m) { void BindBlockDesc(py::module &m) {
using namespace paddle::framework; // NOLINT using namespace paddle::framework; // NOLINT
py::class_<BlockDesc>(m, "BlockDesc", "") py::class_<BlockDescBind>(m, "BlockDesc", "")
.def("id", [](BlockDesc &self) { return self.idx(); }) .def_property_readonly("id", &BlockDescBind::id)
.def("parent", [](BlockDesc &self) { return self.parent_idx(); }) .def_property_readonly("parent", &BlockDescBind::Parent)
.def("append_op", .def("append_op",
[](BlockDesc &self) { return self.add_ops(); }, &BlockDescBind::AppendOp,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("new_var", .def("new_var",
[](BlockDesc &self) { return self.add_vars(); }, [](BlockDesc &self) { return self.add_vars(); },
...@@ -82,73 +194,76 @@ void BindBlockDesc(py::module &m) { ...@@ -82,73 +194,76 @@ void BindBlockDesc(py::module &m) {
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
using namespace paddle::framework; // NOLINT py::class_<VarDesc>(m, "VarDesc", "");
py::class_<VarDesc>(m, "VarDesc", "") // using namespace paddle::framework; // NOLINT
.def(py::init<>()) // py::class_<VarDesc>(m, "VarDesc", "")
.def("set_name", // .def(py::init<>())
[](VarDesc &self, const std::string &name) { self.set_name(name); }) // .def("set_name",
.def("set_shape", // [](VarDesc &self, const std::string &name) { self.set_name(name);
[](VarDesc &self, const std::vector<int64_t> &dims) { // })
VectorToRepeated(dims, self.mutable_lod_tensor()->mutable_dims()); // .def("set_shape",
}) // [](VarDesc &self, const std::vector<int64_t> &dims) {
.def("set_data_type", // VectorToRepeated(dims,
[](VarDesc &self, int type_id) { // self.mutable_lod_tensor()->mutable_dims());
LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor(); // })
lod_tensor_desc->set_data_type(static_cast<DataType>(type_id)); // .def("set_data_type",
}) // [](VarDesc &self, int type_id) {
.def("shape", [](VarDesc &self) { // LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor();
const LoDTensorDesc &lod_tensor_desc = self.lod_tensor(); // lod_tensor_desc->set_data_type(static_cast<DataType>(type_id));
return RepeatedToVector(lod_tensor_desc.dims()); // })
}); // .def("shape", [](VarDesc &self) {
// const LoDTensorDesc &lod_tensor_desc = self.lod_tensor();
// return RepeatedToVector(lod_tensor_desc.dims());
// });
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
using namespace paddle::framework; // NOLINT // auto op_desc_set_var = [](OpDesc::Var *var,
auto op_desc_set_var = [](OpDesc::Var *var, // const std::string &parameter,
const std::string &parameter, // const std::vector<std::string> &arguments) {
const std::vector<std::string> &arguments) { // var->set_parameter(parameter);
var->set_parameter(parameter); // VectorToRepeated(arguments, var->mutable_arguments());
VectorToRepeated(arguments, var->mutable_arguments()); // };
}; //
// auto op_desc_set_attr = [](OpDesc &desc, const std::string &name) {
auto op_desc_set_attr = [](OpDesc &desc, const std::string &name) { // auto attr = desc.add_attrs();
auto attr = desc.add_attrs(); // attr->set_name(name);
attr->set_name(name); // return attr;
return attr; // };
}; py::class_<OpDescBind>(m, "OpDesc", "");
py::class_<OpDesc>(m, "OpDesc", "") // .def("type", [](OpDesc &op) { return op.type(); })
.def("type", [](OpDesc &op) { return op.type(); }) // .def("set_input",
.def("set_input", // [op_desc_set_var](OpDesc &self,
[op_desc_set_var](OpDesc &self, // const std::string &parameter,
const std::string &parameter, // const std::vector<std::string> &arguments) {
const std::vector<std::string> &arguments) { // auto ipt = self.add_inputs();
auto ipt = self.add_inputs(); // op_desc_set_var(ipt, parameter, arguments);
op_desc_set_var(ipt, parameter, arguments); // })
}) // .def("input_names",
.def("input_names", // [](OpDesc &self) {
[](OpDesc &self) { // std::vector<std::string> ret_val;
std::vector<std::string> ret_val; // ret_val.reserve(static_cast<size_t>(self.inputs().size()));
ret_val.reserve(static_cast<size_t>(self.inputs().size())); // std::transform(
std::transform( // self.inputs().begin(),
self.inputs().begin(), // self.inputs().end(),
self.inputs().end(), // std::back_inserter(ret_val),
std::back_inserter(ret_val), // [](const OpDesc::Var &var) { return var.parameter(); });
[](const OpDesc::Var &var) { return var.parameter(); }); // return ret_val;
return ret_val; // })
}) // .def("__str__", [](OpDesc &self) { return self.DebugString(); })
.def("__str__", [](OpDesc &self) { return self.DebugString(); }) // .def("set_output",
.def("set_output", // [op_desc_set_var](OpDesc &self,
[op_desc_set_var](OpDesc &self, // const std::string &parameter,
const std::string &parameter, // const std::vector<std::string> &arguments) {
const std::vector<std::string> &arguments) { // auto opt = self.add_outputs();
auto opt = self.add_outputs(); // op_desc_set_var(opt, parameter, arguments);
op_desc_set_var(opt, parameter, arguments); // })
}) // .def("set_attr",
.def("set_attr", // [op_desc_set_attr](OpDesc &self, const std::string &name, int i)
[op_desc_set_attr](OpDesc &self, const std::string &name, int i) { // {
op_desc_set_attr(self, name)->set_i(i); // op_desc_set_attr(self, name)->set_i(i);
}); // });
} }
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
...@@ -9,21 +9,28 @@ class TestProgramDesc(unittest.TestCase): ...@@ -9,21 +9,28 @@ class TestProgramDesc(unittest.TestCase):
del program_desc del program_desc
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc.instance()
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
self.assertIsNotNone(program_desc.root_block()) self.assertIsNotNone(program_desc.block(0))
del program_desc del program_desc
def test_append_block(self): def test_append_block(self):
prog_desc = core.ProgramDesc.__create_program_desc__() prog_desc = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog_desc) self.assertIsNotNone(prog_desc)
block_root = prog_desc.root_block() block_root = prog_desc.block(0)
self.assertEqual(block_root.id(), 0) self.assertIsNotNone(block_root)
print 'here'
self.assertEqual(block_root.id, 0)
block1 = prog_desc.append_block(block_root) block1 = prog_desc.append_block(block_root)
block2 = prog_desc.append_block(block1) block2 = prog_desc.append_block(block1)
self.assertEqual(block1.id(), block2.parent()) self.assertIsNotNone(block1)
self.assertEqual(block_root.id(), block1.parent()) print 'here'
self.assertEqual(block1.id, block2.parent)
print 'here'
self.assertEqual(block_root.id, block1.parent)
print 'here'
block3 = prog_desc.append_block(block_root) block3 = prog_desc.append_block(block_root)
self.assertEqual(block3.parent(), block_root.id()) self.assertEqual(block3.parent, block_root.id)
self.assertEqual(prog_desc.block(1).id(), 1) self.assertEqual(prog_desc.block(1).id, 1)
self.assertEqual(4, prog_desc.num_blocks())
class TestVarDesc(unittest.TestCase): class TestVarDesc(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册