未验证 提交 daf464af 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #12807 from panyx0718/fix

fix program_desc constructor
...@@ -55,11 +55,20 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ...@@ -55,11 +55,20 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
auto all_ops = blocks_[block_id]->AllOps(); auto all_ops = blocks_[block_id]->AllOps();
for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) {
auto &op = all_ops[op_id]; auto &op = all_ops[op_id];
for (const std::string &attr_name : op->AttrNames()) { for (const std::string &attr_name : op->AttrNames()) {
if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
int sub_block_id = int sub_block_id =
o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name); o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name);
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
} else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) {
std::vector<int> sub_block_ids =
o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name);
std::vector<BlockDesc *> block_descs;
for (int block_id : sub_block_ids) {
block_descs.push_back(MutableBlock(block_id));
}
op->SetBlocksAttr(attr_name, block_descs);
} }
} }
} }
...@@ -68,24 +77,16 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ...@@ -68,24 +77,16 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
desc_ = desc; desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) { InitFromProto();
blocks_.emplace_back(new BlockDesc(this, &block_desc));
}
for (auto &block : blocks_) {
for (auto *op : block->AllOps()) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
}
}
}
}
} }
ProgramDesc::ProgramDesc(const std::string &binary_str) { ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str), PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string."); "Fail to parse program_desc from binary string.");
InitFromProto();
}
void ProgramDesc::InitFromProto() {
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc)); blocks_.emplace_back(new BlockDesc(this, &block_desc));
} }
...@@ -95,6 +96,13 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { ...@@ -95,6 +96,13 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
if (attr.type() == proto::AttrType::BLOCK) { if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx(); size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
} else if (attr.type() == proto::AttrType::BLOCKS) {
auto blks_idx = attr.blocks_idx();
std::vector<BlockDesc *> block_descs;
for (int blk_idx : blks_idx) {
block_descs.push_back(this->MutableBlock(blk_idx));
}
op->SetBlocksAttr(attr.name(), block_descs);
} }
} }
} }
......
...@@ -76,6 +76,8 @@ class ProgramDesc { ...@@ -76,6 +76,8 @@ class ProgramDesc {
void SetFetchHolderName(const std::string &fetch_holder_name); void SetFetchHolderName(const std::string &fetch_holder_name);
private: private:
void InitFromProto();
proto::ProgramDesc desc_; proto::ProgramDesc desc_;
std::vector<std::unique_ptr<BlockDesc>> blocks_; std::vector<std::unique_ptr<BlockDesc>> blocks_;
......
...@@ -42,6 +42,19 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -42,6 +42,19 @@ TEST(ProgramDesc, copy_ctor) {
out->SetType(proto::VarType::LOD_TENSOR); out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()}); op->SetOutput("Y", {out->Name()});
BlockDesc* new_block = program.AppendBlock(*global_block);
op = new_block->AppendOp();
op->SetType("mul");
op = global_block->AppendOp();
op->SetType("op_with_subblock");
op->SetAttr("sub_block", new_block);
std::vector<BlockDesc*> sub_blocks;
sub_blocks.push_back(program.AppendBlock(*global_block));
sub_blocks.push_back(program.AppendBlock(*global_block));
op->SetAttr("sub_blocks", sub_blocks);
ProgramDesc program_copy(program); ProgramDesc program_copy(program);
auto* global_block_copy = program_copy.MutableBlock(0); auto* global_block_copy = program_copy.MutableBlock(0);
...@@ -64,6 +77,8 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -64,6 +77,8 @@ TEST(ProgramDesc, copy_ctor) {
assert_same_var("Y", y); assert_same_var("Y", y);
assert_same_var("Out", out); assert_same_var("Out", out);
bool found_sub_block = false;
bool found_sub_blocks = false;
for (size_t i = 0; i < global_block->OpSize(); ++i) { for (size_t i = 0; i < global_block->OpSize(); ++i) {
auto op_origin = global_block->Op(i); auto op_origin = global_block->Op(i);
auto op_copy = global_block_copy->Op(i); auto op_copy = global_block_copy->Op(i);
...@@ -74,8 +89,17 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -74,8 +89,17 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_EQ(op_copy->Proto()->SerializeAsString(), ASSERT_EQ(op_copy->Proto()->SerializeAsString(),
op_origin->Proto()->SerializeAsString()); op_origin->Proto()->SerializeAsString());
}
if (op->Type() == "op_with_subblock") {
ASSERT_EQ(1, op->GetBlockAttrId("sub_block"));
found_sub_block = true;
ASSERT_EQ(2, op->GetBlocksAttrIds("sub_blocks").size());
found_sub_blocks = true;
}
}
ASSERT_TRUE(found_sub_block);
ASSERT_TRUE(found_sub_blocks);
// Not check block's protostr are same it because the order of vars could be // Not check block's protostr are same it because the order of vars could be
// different and it is correct. // different and it is correct.
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册