diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index b77d5525d4508056c9d6d487e63e500265e1d700..4c39975ec94f95d3299efe58474d9db43654ec22 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -66,7 +66,7 @@ std::vector BlockDescBind::AllOps() const { return res; } -void BlockDescBind::Sync() { +void BlockDescBind::Flush() { if (need_update_) { auto &op_field = *this->desc_->mutable_ops(); op_field.Clear(); @@ -91,5 +91,10 @@ BlockDescBind *BlockDescBind::ParentBlock() const { return prog_->Block(static_cast(this->desc_->parent_idx())); } +BlockDesc *BlockDescBind::Proto() { + Flush(); + return desc_; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 9d453e1d6f42077df3886d9645e1ab59eaf1aa1d..cb39eb40d4606e33f461f5f4f81336ae80210572 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -65,9 +65,9 @@ class BlockDescBind { std::vector AllOps() const; - void Sync(); + void Flush(); - BlockDesc *RawPtr() { return desc_; } + BlockDesc *Proto(); private: ProgramDescBind *prog_; // not_own diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index a5d515bbca729220ca6df5fa07d02f1b3f025109..ef207dc54ebe6cc72d9f1e428dd2aaed5ad3dbf0 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -32,7 +32,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, } OpDesc *OpDescBind::Proto() { - Sync(); + Flush(); return &op_desc_; } @@ -101,7 +101,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { } void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { - BlockDesc *desc = block.RawPtr(); + BlockDesc *desc = block.Proto(); this->attrs_[name] = desc; need_update_ = true; } @@ -165,7 +165,7 @@ struct SetAttrDescVisitor : public boost::static_visitor { void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } }; -void OpDescBind::Sync() { +void OpDescBind::Flush() { if (need_update_) { this->op_desc_.mutable_inputs()->Clear(); for (auto &ipt : inputs_) { diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 90155fadeac148bd9cae4ce9066ac4ce8d9df52d..73b5cf846f702fe21277ae139156ec9784aa79b3 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -89,8 +89,6 @@ class OpDescBind { this->need_update_ = true; } - void Sync(); - const VariableNameMap &Inputs() const { return inputs_; } const VariableNameMap &Outputs() const { return outputs_; } @@ -104,6 +102,8 @@ class OpDescBind { void InferShape(const BlockDescBind &block) const; + void Flush(); + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index e89f9a46d587b6378aa3be92306c5680093e1926..fcb7292884275d972377983cb3ba1bcd86fb8348 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -45,7 +45,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { ProgramDesc *ProgramDescBind::Proto() { for (auto &block : blocks_) { - block->Sync(); + block->Flush(); } return prog_; } diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 89c625b42d4855ced8be7b0a7b1d191f3365f799..ec9b7ee9dd5076de600fb596d52f6a9fac7069a4 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -123,7 +123,18 @@ void BindProgramDesc(py::module &m) { AppendBackward(program_desc, no_grad_vars); }) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) - .def("num_blocks", &ProgramDescBind::Size); + .def("num_blocks", &ProgramDescBind::Size) + .def("serialize_to_string", + [](ProgramDescBind &program_desc) -> py::bytes { + const ProgramDesc *desc = program_desc.Proto(); + PADDLE_ENFORCE(desc->IsInitialized(), + "ProgramDesc has not been initialized."); + std::string res; + PADDLE_ENFORCE( + desc->SerializeToString(&res), + "Serialize ProgramDesc Error. This could be a bug of Paddle."); + return res; + }); } void BindBlockDesc(py::module &m) { @@ -149,7 +160,17 @@ void BindBlockDesc(py::module &m) { .def("all_vars", &BlockDescBind::AllVars, py::return_value_policy::reference) .def("all_ops", &BlockDescBind::AllOps, - py::return_value_policy::reference); + py::return_value_policy::reference) + .def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes { + const BlockDesc *desc = block_desc.Proto(); + PADDLE_ENFORCE(desc->IsInitialized(), + "BlockDesc has not been initialized."); + std::string res; + PADDLE_ENFORCE( + desc->SerializeToString(&res), + "Serialize BlockDesc Error. This could be a bug of Paddle."); + return res; + }); } void BindVarDsec(py::module &m) { @@ -177,7 +198,17 @@ void BindVarDsec(py::module &m) { .def("lod_level", &VarDescBind::GetLodLevel) .def("set_lod_level", &VarDescBind::SetLoDLevel) .def("type", &VarDescBind::GetType) - .def("set_type", &VarDescBind::SetType); + .def("set_type", &VarDescBind::SetType) + .def("serialize_to_string", [](VarDescBind &var_desc) -> py::bytes { + const VarDesc *desc = var_desc.Proto(); + PADDLE_ENFORCE(desc->IsInitialized(), + "VarDesc has not been initialized."); + std::string res; + PADDLE_ENFORCE( + desc->SerializeToString(&res), + "Serialize VarDesc Error. This could be a bug of Paddle."); + return res; + }); py::enum_(var_desc, "VarType", "") .value("LOD_TENSOR", VarDesc::LOD_TENSOR) @@ -213,7 +244,17 @@ void BindOpDesc(py::module &m) { .def("set_block_attr", &OpDescBind::SetBlockAttr) .def("block_attr", &OpDescBind::GetBlockAttr) .def("check_attrs", &OpDescBind::CheckAttrs) - .def("infer_shape", &OpDescBind::InferShape); + .def("infer_shape", &OpDescBind::InferShape) + .def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes { + const OpDesc *desc = op_desc.Proto(); + PADDLE_ENFORCE(desc->IsInitialized(), + "OpDesc has not been initialized."); + std::string res; + PADDLE_ENFORCE( + desc->SerializeToString(&res), + "Serialize OpDesc Error. This could be a bug of Paddle."); + return res; + }); } } // namespace pybind diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index c57d7239960106be747153faacc03f5ab5174bea..10e5726a85f435b08997083094223ac2a0a15b61 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -73,6 +73,13 @@ class Variable(object): self.block.vars[name] = self self.op = None + def __str__(self): + protostr = self.desc.serialize_to_string() + proto = framework_pb2.VarDesc.FromString(str(protostr)) + return proto.__str__() + + __repr__ = __str__ + @property def name(self): return self.desc.name() @@ -210,6 +217,13 @@ class Operator(object): self.desc.check_attrs() self.desc.infer_shape(self.block.desc) + def __str__(self): + protostr = self.desc.serialize_to_string() + proto = framework_pb2.OpDesc.FromString(str(protostr)) + return proto.__str__() + + __repr__ = __str__ + @property def type(self): return self.desc.type() @@ -252,6 +266,13 @@ class Block(object): self.ops = collections.deque() # operator list self.program = program + def __str__(self): + protostr = self.desc.serialize_to_string() + proto = framework_pb2.BlockDesc.FromString(str(protostr)) + return proto.__str__() + + __repr__ = __str__ + @property def parent_idx(self): return self.desc.parent @@ -296,6 +317,13 @@ class Program(object): self.blocks = [Block(self, 0)] self.current_block_idx = 0 + def __str__(self): + protostr = self.desc.serialize_to_string() + proto = framework_pb2.ProgramDesc.FromString(str(protostr)) + return proto.__str__() + + __repr__ = __str__ + def global_block(self): return self.blocks[0] diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py index d7a85d8e4e883efd268c53a0e4977533040a0a14..dfe39c98f7f4fe266d5ec0c4a9ed14ab02e40e3a 100644 --- a/python/paddle/v2/framework/tests/test_operator_desc.py +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -34,6 +34,8 @@ class TestOperator(unittest.TestCase): "Y": mul_y}, outputs={"Out": [mul_out]}, attrs={"x_num_col_dims": 1}) + + self.assertNotEqual(str(mul_op), "") self.assertEqual(mul_op.type, "mul") self.assertEqual(mul_op.input_names, ["X", "Y"]) self.assertEqual(mul_op.input("X"), ["mul.x"]) diff --git a/python/paddle/v2/framework/tests/test_variable.py b/python/paddle/v2/framework/tests/test_variable.py index 695aaaee6c0c1d035349b1d1716c24bab81e607b..6fb934c743a6271c352a74495cc543b62ac2b9d9 100644 --- a/python/paddle/v2/framework/tests/test_variable.py +++ b/python/paddle/v2/framework/tests/test_variable.py @@ -21,6 +21,7 @@ class TestVariable(unittest.TestCase): b = g_program.current_block() w = b.create_var( dtype="float64", shape=[784, 100], lod_level=0, name="fc.w") + self.assertNotEqual(str(w), "") self.assertEqual(core.DataType.FP64, w.data_type) self.assertEqual((784, 100), w.shape) self.assertEqual("fc.w", w.name)