Unverified commit e1171142, authored by Huihuang Zheng, committed by GitHub

Set states of recurrent op as dependent vars in prune (#19865)

* Set states of recurrent op as dependent vars in the prune pass used by save_inference_model

This PR fixes a save/load inference model problem for RNN models.

The cause of the bug: save_inference_model prunes OPs that don't contribute to the Output, but in recurrent_op the States are not Outputs, so OPs that refer only to States were pruned away.

The fix adds the States of recurrent_op as dependent vars so that OPs referring to States are no longer pruned.
Parent c5eedcf6
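Before the diff itself, here is a minimal standalone sketch of the idea. It uses simplified stand-in types; OpStub and CollectRecurrentStates are illustrative names, not Paddle APIs. When pruning reaches a recurrent op, every variable named in its "states"/"ex_states" attributes is treated as a dependency of the sub block, unless it is a feed variable:

// Standalone illustration, not the Paddle code: fold a recurrent op's
// "states"/"ex_states" attribute strings into the dependent-var set.
#include <iostream>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Simplified stand-in for proto::OpDesc: an op type plus string-list attrs.
struct OpStub {
  std::string type;
  std::unordered_map<std::string, std::vector<std::string>> str_attrs;
};

void CollectRecurrentStates(const OpStub& op,
                            const std::set<std::string>& feed_var_names,
                            std::unordered_set<std::string>* dependent_vars) {
  if (op.type != "recurrent") return;
  for (const std::string key : {"states", "ex_states"}) {
    auto it = op.str_attrs.find(key);
    if (it == op.str_attrs.end()) continue;
    for (const auto& name : it->second) {
      // Feed variables are provided from outside the program, so no
      // producer op needs to be kept for them.
      if (feed_var_names.count(name) == 0) dependent_vars->insert(name);
    }
  }
}

int main() {
  OpStub rnn{"recurrent", {{"states", {"y"}}, {"ex_states", {"y"}}}};
  std::unordered_set<std::string> deps;
  CollectRecurrentStates(rnn, /*feed_var_names=*/{"a"}, &deps);
  for (const auto& v : deps) std::cout << v << "\n";  // prints "y"
}

The new branch in prune_impl below does the same thing against the real proto::OpDesc attributes.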
@@ -35,6 +35,10 @@ namespace framework {
 const char kFeedOpType[] = "feed";
 const char kFetchOpType[] = "fetch";
+const char kRecurrent[] = "recurrent";
+const char kStates[] = "states";
+const char kExStates[] = "ex_states";
 
 bool HasDependentInputVar(
     const proto::OpDesc& op_desc,
     const std::unordered_set<std::string>& dependent_vars) {
@@ -173,6 +177,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
     auto* op = op_field->Add();
     *op = input.blocks(block_id).ops(i);
     if (HasSubBlock(*op)) {
+      VLOG(2) << "Pruning op which has sub block: " << op->type();
       // create sub_block_dependent_vars here to help prune the sub block
       std::unordered_set<std::string> sub_block_dependent_vars;
       for (auto& var : op->inputs()) {
@@ -189,6 +194,20 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
           }
         }
       }
+      // Recurrent op's states are also dependent vars
+      if (op->type() == kRecurrent) {
+        auto& attributes = op->attrs();
+        for (auto& attr : attributes) {
+          if (attr.name() == kStates || attr.name() == kExStates) {
+            for (auto& argu : attr.strings()) {
+              if (feed_var_names.count(argu) == 0) {
+                sub_block_dependent_vars.insert(argu);
+              }
+            }
+          }
+        }
+      }
+
       // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
       // output_block_id is the idx of the current block in the output desc
       prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
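For orientation, prune_impl is roughly a backward mark-and-sweep over a block: ops are visited from last to first, an op survives if it is a target or if it produces a variable some already-kept op needs, and each kept op's inputs join the needed set (the job that checks like HasDependentInputVar support). A self-contained sketch of that sweep, using the same stand-in style as above — an illustration, not the real implementation:

// Standalone illustration of a backward mark-and-sweep pruning pass:
// walk ops from last to first, keep targets and producers of needed vars,
// and grow the needed-var set with the inputs of every kept op.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct SimpleOp {
  std::string name;
  bool is_target;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

std::vector<SimpleOp> PruneBlock(const std::vector<SimpleOp>& ops,
                                 std::unordered_set<std::string> needed) {
  // `needed` is taken by value: it is a working copy grown during the sweep.
  std::vector<SimpleOp> kept;
  for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
    bool produces_needed =
        std::any_of(it->outputs.begin(), it->outputs.end(),
                    [&](const std::string& v) { return needed.count(v) > 0; });
    if (it->is_target || produces_needed) {
      needed.insert(it->inputs.begin(), it->inputs.end());
      kept.push_back(*it);
    }
  }
  std::reverse(kept.begin(), kept.end());  // restore program order
  return kept;
}

int main() {
  // "mul" feeds the target "fc"; "dropout" feeds nothing downstream.
  std::vector<SimpleOp> ops = {{"mul", false, {"a"}, {"b"}},
                               {"dropout", false, {"a"}, {"d"}},
                               {"fc", true, {"b"}, {"c"}}};
  for (const auto& op : PruneBlock(ops, {})) std::cout << op.name << "\n";
  // prints: mul, fc
}

Under this scheme a recurrent op's state variables, being neither outputs nor targets, would never enter the needed set on their own — hence the explicit insertion added above.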
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <gtest/gtest.h>
 #include <set>
 #include <string>
+#include <vector>
 #include "paddle/fluid/framework/attribute.h"
 #include "paddle/fluid/framework/operator.h"
@@ -61,12 +62,12 @@ TEST(Prune, one_operator) {
   f::proto::ProgramDesc pruned;
   std::set<std::string> feed_var_names = {};
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 0);
 
   feed_var_names.insert("a");
   pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 1);
 }
 
 TEST(Prune, forward) {
@@ -88,7 +89,7 @@ TEST(Prune, forward) {
     f::proto::ProgramDesc pruned;
     pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
     f::Prune(*pdesc, feed_var_names, &pruned);
-    PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
+    EXPECT_EQ(pruned.blocks(0).ops_size(), i + 1);
   }
 }
@@ -111,7 +112,7 @@ TEST(Prune, multi_input_op) {
   f::proto::ProgramDesc pruned;
   std::set<std::string> feed_var_names = {"a0", "a1", "a2"};
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 4);
 }
 
 TEST(Prune, multi_output_op) {
@@ -131,7 +132,7 @@ TEST(Prune, multi_output_op) {
   f::proto::ProgramDesc pruned;
   std::set<std::string> feed_var_names = {"a"};
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
 }
 
 TEST(Prune, multi_target) {
@@ -152,5 +153,35 @@ TEST(Prune, multi_target) {
   f::proto::ProgramDesc pruned;
   std::set<std::string> feed_var_names = {"a"};
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 3);
 }
+
+TEST(Prune, recurrent_op) {
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::BlockDesc *sub_block = program.AppendBlock(*block);
+  AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
+        f::AttributeMap{}, block);
+
+  std::vector<std::string> state_var_name(1, "y");
+  AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1", "c1"}}},
+        {{"ex_states", state_var_name},
+         {"states", state_var_name},
+         {"sub_block", sub_block}},
+        block);
+
+  EXPECT_TRUE(sub_block != nullptr);
+  AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"y"}}},
+        f::AttributeMap{}, sub_block);
+
+  f::proto::ProgramDesc *pdesc = program.Proto();
+  pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
+  f::proto::ProgramDesc pruned;
+  std::set<std::string> feed_var_names = {"a"};
+  f::Prune(*pdesc, feed_var_names, &pruned);
+
+  // Block 0 keeps both ops; the sub block keeps the producer of state "y".
+  EXPECT_EQ(pruned.blocks_size(), 2);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
+  EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
+}
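A note on the test scaffolding: the AddOp helper used throughout these tests is defined near the top of prune_test.cc and is elided from this diff. Judging from its call sites, it is presumably a thin wrapper along these lines (a sketch assuming `namespace f = paddle::framework;` and the file's existing includes):

// Sketch of the elided AddOp test helper: append an op of the given type
// to `block` and wire up its inputs, outputs, and attributes.
void AddOp(const std::string &type, const f::VariableNameMap &inputs,
           const f::VariableNameMap &outputs, f::AttributeMap attrs,
           f::BlockDesc *block) {
  auto *op = block->AppendOp();
  op->SetType(type);
  for (const auto &kv : inputs) op->SetInput(kv.first, kv.second);
  for (const auto &kv : outputs) op->SetOutput(kv.first, kv.second);
  op->SetAttrMap(attrs);
}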