提交 5d9ce046 编写于 作者: F fengjiayi 提交者: GitHub

Debug string for Python ProtoBuf (#4800)

* Add debug string for Python ProtoBuf

and Rename `Sync` to `Flush`

* Add check of ProtoBuf initialization
上级 2c46666e
...@@ -66,7 +66,7 @@ std::vector<OpDescBind *> BlockDescBind::AllOps() const { ...@@ -66,7 +66,7 @@ std::vector<OpDescBind *> BlockDescBind::AllOps() const {
return res; return res;
} }
void BlockDescBind::Sync() { void BlockDescBind::Flush() {
if (need_update_) { if (need_update_) {
auto &op_field = *this->desc_->mutable_ops(); auto &op_field = *this->desc_->mutable_ops();
op_field.Clear(); op_field.Clear();
...@@ -91,5 +91,10 @@ BlockDescBind *BlockDescBind::ParentBlock() const { ...@@ -91,5 +91,10 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx())); return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
} }
BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -65,9 +65,9 @@ class BlockDescBind { ...@@ -65,9 +65,9 @@ class BlockDescBind {
std::vector<OpDescBind *> AllOps() const; std::vector<OpDescBind *> AllOps() const;
void Sync(); void Flush();
BlockDesc *RawPtr() { return desc_; } BlockDesc *Proto();
private: private:
ProgramDescBind *prog_; // not_own ProgramDescBind *prog_; // not_own
......
...@@ -32,7 +32,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, ...@@ -32,7 +32,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
} }
OpDesc *OpDescBind::Proto() { OpDesc *OpDescBind::Proto() {
Sync(); Flush();
return &op_desc_; return &op_desc_;
} }
...@@ -101,7 +101,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { ...@@ -101,7 +101,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
} }
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr(); BlockDesc *desc = block.Proto();
this->attrs_[name] = desc; this->attrs_[name] = desc;
need_update_ = true; need_update_ = true;
} }
...@@ -165,7 +165,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -165,7 +165,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };
void OpDescBind::Sync() { void OpDescBind::Flush() {
if (need_update_) { if (need_update_) {
this->op_desc_.mutable_inputs()->Clear(); this->op_desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) { for (auto &ipt : inputs_) {
......
...@@ -89,8 +89,6 @@ class OpDescBind { ...@@ -89,8 +89,6 @@ class OpDescBind {
this->need_update_ = true; this->need_update_ = true;
} }
void Sync();
const VariableNameMap &Inputs() const { return inputs_; } const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; } const VariableNameMap &Outputs() const { return outputs_; }
...@@ -104,6 +102,8 @@ class OpDescBind { ...@@ -104,6 +102,8 @@ class OpDescBind {
void InferShape(const BlockDescBind &block) const; void InferShape(const BlockDescBind &block) const;
void Flush();
private: private:
template <typename MapType> template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) { static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
...@@ -45,7 +45,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { ...@@ -45,7 +45,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
ProgramDesc *ProgramDescBind::Proto() { ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Sync(); block->Flush();
} }
return prog_; return prog_;
} }
......
...@@ -123,7 +123,18 @@ void BindProgramDesc(py::module &m) { ...@@ -123,7 +123,18 @@ void BindProgramDesc(py::module &m) {
AppendBackward(program_desc, no_grad_vars); AppendBackward(program_desc, no_grad_vars);
}) })
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size); .def("num_blocks", &ProgramDescBind::Size)
.def("serialize_to_string",
[](ProgramDescBind &program_desc) -> py::bytes {
const ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"ProgramDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize ProgramDesc Error. This could be a bug of Paddle.");
return res;
});
} }
void BindBlockDesc(py::module &m) { void BindBlockDesc(py::module &m) {
...@@ -149,7 +160,17 @@ void BindBlockDesc(py::module &m) { ...@@ -149,7 +160,17 @@ void BindBlockDesc(py::module &m) {
.def("all_vars", &BlockDescBind::AllVars, .def("all_vars", &BlockDescBind::AllVars,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("all_ops", &BlockDescBind::AllOps, .def("all_ops", &BlockDescBind::AllOps,
py::return_value_policy::reference); py::return_value_policy::reference)
.def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes {
const BlockDesc *desc = block_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"BlockDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize BlockDesc Error. This could be a bug of Paddle.");
return res;
});
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
...@@ -177,7 +198,17 @@ void BindVarDsec(py::module &m) { ...@@ -177,7 +198,17 @@ void BindVarDsec(py::module &m) {
.def("lod_level", &VarDescBind::GetLodLevel) .def("lod_level", &VarDescBind::GetLodLevel)
.def("set_lod_level", &VarDescBind::SetLoDLevel) .def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType) .def("type", &VarDescBind::GetType)
.def("set_type", &VarDescBind::SetType); .def("set_type", &VarDescBind::SetType)
.def("serialize_to_string", [](VarDescBind &var_desc) -> py::bytes {
const VarDesc *desc = var_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"VarDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize VarDesc Error. This could be a bug of Paddle.");
return res;
});
py::enum_<VarDesc::VarType>(var_desc, "VarType", "") py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", VarDesc::LOD_TENSOR) .value("LOD_TENSOR", VarDesc::LOD_TENSOR)
...@@ -213,7 +244,17 @@ void BindOpDesc(py::module &m) { ...@@ -213,7 +244,17 @@ void BindOpDesc(py::module &m) {
.def("set_block_attr", &OpDescBind::SetBlockAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("block_attr", &OpDescBind::GetBlockAttr) .def("block_attr", &OpDescBind::GetBlockAttr)
.def("check_attrs", &OpDescBind::CheckAttrs) .def("check_attrs", &OpDescBind::CheckAttrs)
.def("infer_shape", &OpDescBind::InferShape); .def("infer_shape", &OpDescBind::InferShape)
.def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes {
const OpDesc *desc = op_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"OpDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize OpDesc Error. This could be a bug of Paddle.");
return res;
});
} }
} // namespace pybind } // namespace pybind
......
...@@ -73,6 +73,13 @@ class Variable(object): ...@@ -73,6 +73,13 @@ class Variable(object):
self.block.vars[name] = self self.block.vars[name] = self
self.op = None self.op = None
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.VarDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property @property
def name(self): def name(self):
return self.desc.name() return self.desc.name()
...@@ -210,6 +217,13 @@ class Operator(object): ...@@ -210,6 +217,13 @@ class Operator(object):
self.desc.check_attrs() self.desc.check_attrs()
self.desc.infer_shape(self.block.desc) self.desc.infer_shape(self.block.desc)
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property @property
def type(self): def type(self):
return self.desc.type() return self.desc.type()
...@@ -252,6 +266,13 @@ class Block(object): ...@@ -252,6 +266,13 @@ class Block(object):
self.ops = collections.deque() # operator list self.ops = collections.deque() # operator list
self.program = program self.program = program
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.BlockDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property @property
def parent_idx(self): def parent_idx(self):
return self.desc.parent return self.desc.parent
...@@ -296,6 +317,13 @@ class Program(object): ...@@ -296,6 +317,13 @@ class Program(object):
self.blocks = [Block(self, 0)] self.blocks = [Block(self, 0)]
self.current_block_idx = 0 self.current_block_idx = 0
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
def global_block(self): def global_block(self):
return self.blocks[0] return self.blocks[0]
......
...@@ -34,6 +34,8 @@ class TestOperator(unittest.TestCase): ...@@ -34,6 +34,8 @@ class TestOperator(unittest.TestCase):
"Y": mul_y}, "Y": mul_y},
outputs={"Out": [mul_out]}, outputs={"Out": [mul_out]},
attrs={"x_num_col_dims": 1}) attrs={"x_num_col_dims": 1})
self.assertNotEqual(str(mul_op), "")
self.assertEqual(mul_op.type, "mul") self.assertEqual(mul_op.type, "mul")
self.assertEqual(mul_op.input_names, ["X", "Y"]) self.assertEqual(mul_op.input_names, ["X", "Y"])
self.assertEqual(mul_op.input("X"), ["mul.x"]) self.assertEqual(mul_op.input("X"), ["mul.x"])
......
...@@ -21,6 +21,7 @@ class TestVariable(unittest.TestCase): ...@@ -21,6 +21,7 @@ class TestVariable(unittest.TestCase):
b = g_program.current_block() b = g_program.current_block()
w = b.create_var( w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w") dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertNotEqual(str(w), "")
self.assertEqual(core.DataType.FP64, w.data_type) self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape) self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name) self.assertEqual("fc.w", w.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册