Commit aa379ccb authored by fengjiayi, committed by GitHub

Add functions for restoring ProgramDescBind from ProgramDesc (#5109)

* complete restoring program_bind from program_desc

* Fix bugs

* fix compile errors

* fix errors and add unit tests

* rename some vars

* Follow comments
Parent b1cbdf03
@@ -120,6 +120,17 @@ BlockDesc *BlockDescBind::Proto() {
  Flush();
  return desc_;
}
+
+BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
+    : prog_(prog), desc_(desc), need_update_(false) {
+  for (const VarDesc &var_desc : desc_->vars()) {
+    vars_[var_desc.name()].reset(new VarDescBind(var_desc));
+  }
+  for (const OpDesc &op_desc : desc_->ops()) {
+    ops_.emplace_back(new OpDescBind(op_desc, prog));
+  }
+}
+
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
                             ProgramDescBind *prog)
    : prog_(prog), desc_(desc) {
...
@@ -36,8 +36,7 @@ class ProgramDescBind;
class BlockDescBind {
 public:
-  BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
-      : prog_(prog), desc_(desc), need_update_(false) {}
+  BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);

  BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
                ProgramDescBind *prog);
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
+#include "paddle/framework/program_desc.h"

namespace paddle {
namespace framework {
@@ -24,16 +25,47 @@ namespace framework {
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
                       const VariableNameMap &outputs,
                       const AttributeMap &attrs) {
-  op_desc_.set_type(type);
+  desc_.set_type(type);
  inputs_ = inputs;
  outputs_ = outputs;
  attrs_ = attrs;
  need_update_ = true;
}
+
+OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
+    : desc_(desc), need_update_(false) {
+  // restore inputs_
+  int input_size = desc_.inputs_size();
+  for (int i = 0; i < input_size; ++i) {
+    const OpDesc::Var &var = desc_.inputs(i);
+    std::vector<std::string> &args = inputs_[var.parameter()];
+    int argu_size = var.arguments_size();
+    args.reserve(argu_size);
+    for (int j = 0; j < argu_size; ++j) {
+      args.push_back(var.arguments(j));
+    }
+  }
+  // restore outputs_
+  int output_size = desc_.outputs_size();
+  for (int i = 0; i < output_size; ++i) {
+    const OpDesc::Var &var = desc_.outputs(i);
+    std::vector<std::string> &args = outputs_[var.parameter()];
+    int argu_size = var.arguments_size();
+    args.reserve(argu_size);
+    for (int j = 0; j < argu_size; ++j) {
+      args.push_back(var.arguments(j));
+    }
+  }
+  // restore attrs_
+  for (const OpDesc::Attr &attr : desc_.attrs()) {
+    std::string attr_name = attr.name();
+    attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
+  }
+}
+
OpDesc *OpDescBind::Proto() {
  Flush();
-  return &op_desc_;
+  return &desc_;
}

const std::vector<std::string> &OpDescBind::Input(
@@ -167,23 +199,23 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void OpDescBind::Flush() {
  if (need_update_) {
-    this->op_desc_.mutable_inputs()->Clear();
+    this->desc_.mutable_inputs()->Clear();
    for (auto &ipt : inputs_) {
-      auto *input = op_desc_.add_inputs();
+      auto *input = desc_.add_inputs();
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

-    this->op_desc_.mutable_outputs()->Clear();
+    this->desc_.mutable_outputs()->Clear();
    for (auto &opt : outputs_) {
-      auto *output = op_desc_.add_outputs();
+      auto *output = desc_.add_outputs();
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

-    this->op_desc_.mutable_attrs()->Clear();
+    this->desc_.mutable_attrs()->Clear();
    for (auto &attr : attrs_) {
-      auto *attr_desc = op_desc_.add_attrs();
+      auto *attr_desc = desc_.add_attrs();
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
          static_cast<framework::AttrType>(attr.second.which() - 1));
...
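For reference, here is a minimal sketch (not part of this commit) of the OpDesc proto layout that the new restoring constructor walks. It only assumes the generated protobuf accessors already used elsewhere in this diff (set_type, add_inputs, set_parameter, add_arguments); the variable names "x0" and "out0" are placeholders:

OpDesc op;
op.set_type("mul");
OpDesc::Var *in = op.add_inputs();
in->set_parameter("X");     // formal parameter name of the operator
in->add_arguments("x0");    // variable names bound to that parameter
OpDesc::Var *out = op.add_outputs();
out->set_parameter("Out");
out->add_arguments("out0");
// OpDescBind(const OpDesc &, ProgramDescBind *) copies each parameter's
// arguments into inputs_/outputs_ (e.g. inputs_["X"] == {"x0"}) and converts
// every OpDesc::Attr back into an Attribute via GetAttrValue.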
@@ -24,6 +24,7 @@ namespace paddle {
namespace framework {

class BlockDescBind;
+class ProgramDescBind;

class OpDescBind {
 public:
@@ -32,11 +33,13 @@ class OpDescBind {
  OpDescBind(const std::string &type, const VariableNameMap &inputs,
             const VariableNameMap &outputs, const AttributeMap &attrs);

+  OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
+
  OpDesc *Proto();

-  std::string Type() const { return op_desc_.type(); }
+  std::string Type() const { return desc_.type(); }

-  void SetType(const std::string &type) { op_desc_.set_type(type); }
+  void SetType(const std::string &type) { desc_.set_type(type); }

  const std::vector<std::string> &Input(const std::string &name) const;
@@ -117,7 +120,7 @@ class OpDescBind {
    return ret_val;
  }

-  OpDesc op_desc_;
+  OpDesc desc_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
...
@@ -19,9 +19,9 @@ namespace paddle {
namespace framework {

BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
-  auto *b = prog_.add_blocks();
+  auto *b = desc_.add_blocks();
  b->set_parent_idx(parent.ID());
-  b->set_idx(prog_.blocks_size() - 1);
+  b->set_idx(desc_.blocks_size() - 1);
  blocks_.emplace_back(new BlockDescBind(this, b));
  return blocks_.back().get();
}
@@ -30,23 +30,32 @@ ProgramDesc *ProgramDescBind::Proto() {
  for (auto &block : blocks_) {
    block->Flush();
  }
-  return &prog_;
+  return &desc_;
}

ProgramDescBind::ProgramDescBind() {
-  auto *block = prog_.mutable_blocks()->Add();
+  auto *block = desc_.mutable_blocks()->Add();
  block->set_idx(kRootBlockIndex);
  block->set_parent_idx(kNoneBlockIndex);
  blocks_.emplace_back(new BlockDescBind(this, block));
}

ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
-  prog_ = o.prog_;
-  for (int i = 0; i < prog_.blocks_size(); ++i) {
-    auto *block = prog_.mutable_blocks(i);
+  desc_ = o.desc_;
+  for (int i = 0; i < desc_.blocks_size(); ++i) {
+    auto *block = desc_.mutable_blocks(i);
    blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this));
  }
}
+
+ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
+  PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
+                 "Fail to parse program_desc from binary string.");
+  for (auto &block_desc : *desc_.mutable_blocks()) {
+    blocks_.emplace_back(new BlockDescBind(this, &block_desc));
+  }
+}
+
} // namespace framework
} // namespace paddle
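A minimal usage sketch of the round trip this new constructor enables, mirroring the C++ unit test added below (variable names and shapes are just the test's examples):

ProgramDescBind program;
auto *global_block = program.Block(0);
auto *x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetDataType(FP32);
x->SetShape({1000, 784});

std::string binary_str;
program.Proto()->SerializeToString(&binary_str);  // serialize to a binary string

ProgramDescBind restored(binary_str);              // parse it back and rebuild
auto *restored_block = restored.Block(0);          // blocks, vars and ops

The same flow is exposed to Python further down via the ProgramDesc(binary_str) binding and Program.parse_from_string.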
@@ -31,6 +31,8 @@ class ProgramDescBind {

  ProgramDescBind(const ProgramDescBind &o);

+  explicit ProgramDescBind(const std::string &binary_str);
+
  BlockDescBind *AppendBlock(const BlockDescBind &parent);

  BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
@@ -40,7 +42,7 @@ class ProgramDescBind {
  ProgramDesc *Proto();

 private:
-  ProgramDesc prog_;
+  ProgramDesc desc_;

  std::vector<std::unique_ptr<BlockDescBind>> blocks_;
};
...
@@ -59,7 +59,7 @@ TEST(ProgramDesc, copy_ctor) {
  };

  ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames());
-  ASSERT_EQ(3, global_block_copy->LocalVarNames().size());
+  ASSERT_EQ(3UL, global_block_copy->LocalVarNames().size());
  assert_same_var("X", x);
  assert_same_var("Y", y);
  assert_same_var("Out", out);
@@ -79,5 +79,67 @@ TEST(ProgramDesc, copy_ctor) {
  // Not check block's protostr are same it because the order of vars could be
  // different and it is correct.
}
+
+TEST(ProgramDescBind, serialize_and_deserialize) {
+  ProgramDescBind program_origin;
+  auto* global_block = program_origin.Block(0);
+  auto* x = global_block->Var("X");
+  x->SetType(VarDesc_VarType_LOD_TENSOR);
+  x->SetLoDLevel(0);
+  x->SetDataType(FP32);
+  x->SetShape({1000, 784});
+
+  auto* y = global_block->Var("Y");
+  y->SetType(VarDesc_VarType_LOD_TENSOR);
+  y->SetLoDLevel(0);
+  y->SetDataType(FP32);
+  y->SetShape({784, 100});
+
+  auto* op = global_block->AppendOp();
+  op->SetType("mul");
+  op->SetInput("X", {x->Name()});
+  op->SetInput("Y", {y->Name()});
+
+  auto* out = global_block->Var("Out");
+  out->SetType(VarDesc_VarType_LOD_TENSOR);
+  op->SetOutput("Y", {out->Name()});
+
+  std::string binary_str;
+  program_origin.Proto()->SerializeToString(&binary_str);
+
+  ProgramDescBind program_restored(binary_str);
+  auto* global_block_restored = program_restored.Block(0);
+  ASSERT_NE(global_block, global_block_restored);
+
+  auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
+    ASSERT_TRUE(global_block_restored->HasVar(name));
+    auto* restored = global_block_restored->Var(name);
+    ASSERT_NE(restored, var_before);
+    ASSERT_EQ(restored->Name(), var_before->Name());
+    ASSERT_EQ(restored->GetType(), var_before->GetType());
+    ASSERT_EQ(restored->Shape(), var_before->Shape());
+    ASSERT_EQ(restored->Proto()->SerializeAsString(),
+              var_before->Proto()->SerializeAsString());
+  };
+
+  ASSERT_EQ(global_block->LocalVarNames(),
+            global_block_restored->LocalVarNames());
+  ASSERT_EQ(3UL, global_block_restored->LocalVarNames().size());
+  assert_same_var("X", x);
+  assert_same_var("Y", y);
+  assert_same_var("Out", out);
+
+  for (size_t i = 0; i < global_block->OpSize(); ++i) {
+    auto op_origin = global_block->Op(i);
+    auto op_restored = global_block_restored->Op(i);
+
+    ASSERT_EQ(op_origin->Type(), op_restored->Type());
+    ASSERT_EQ(op_origin->Inputs(), op_restored->Inputs());
+    ASSERT_EQ(op_origin->Outputs(), op_restored->Outputs());
+
+    ASSERT_EQ(op_restored->Proto()->SerializeAsString(),
+              op_origin->Proto()->SerializeAsString());
+  }
+}
} // namespace framework
} // namespace paddle
@@ -59,6 +59,8 @@ class VarDescBind {
    desc_.set_type(VarDesc::LOD_TENSOR);
  }

+  explicit VarDescBind(const VarDesc &desc) : desc_(desc) {}
+
  VarDesc *Proto() { return &desc_; }

  std::string Name() const { return desc_.name(); }
...
@@ -105,6 +105,11 @@ void BindProgramDesc(py::module &m) {
           [](ProgramDescBind &self, const ProgramDescBind &other) {
             new (&self) ProgramDescBind(other);
           })
+      .def("__init__",
+           [](ProgramDescBind &self, const py::bytes &binary_str) {
+             std::string str(binary_str);
+             new (&self) ProgramDescBind(str);
+           })
      .def("append_block", &ProgramDescBind::AppendBlock,
           py::return_value_policy::reference)
      .def("append_backward",
...
@@ -440,6 +440,13 @@ class Program(object):
        p.sync_with_cpp()
        return p

+    @staticmethod
+    def parse_from_string(binary_str):
+        p = Program()
+        p.desc = core.ProgramDesc(binary_str)
+        p.sync_with_cpp()
+        return p
+
    def __repr__(self):
        return str(self)
...
@@ -52,6 +52,25 @@ class TestProgram(unittest.TestCase):
        print prog
        print prog.clone()

+    def test_parse_program_from_string(self):
+        prog = Program()
+
+        x = prog.global_block().create_var(
+            name='X', shape=[1000, 784], dtype='float32')
+        y = prog.global_block().create_var(
+            name='Y', shape=[784, 100], dtype='float32')
+        out = prog.global_block().create_var(name='Out', dtype='float32')
+        prog.global_block().append_op(
+            type="mul", inputs={'X': [x],
+                                'Y': [y]}, outputs={'Out': [out]})
+
+        binary_str = prog.desc.serialize_to_string()
+        prog_restored = Program.parse_from_string(binary_str)
+
+        print prog
+        print prog_restored
+
    def test_append_backward(self):
        prog = Program()
        block = prog.global_block()
...