diff --git a/paddle/fluid/framework/prune.cc b/paddle/fluid/framework/prune.cc
index 0c27d41f33ffb3d76a3e024da21ff74514f386d5..c58cb8ad2ace9927d85a22cb400e2b91af331cbd 100644
--- a/paddle/fluid/framework/prune.cc
+++ b/paddle/fluid/framework/prune.cc
@@ -35,6 +35,10 @@ namespace framework {
 const char kFeedOpType[] = "feed";
 const char kFetchOpType[] = "fetch";
 
+const char kRecurrent[] = "recurrent";
+const char kStates[] = "states";
+const char kExStates[] = "ex_states";
+
 bool HasDependentInputVar(
     const proto::OpDesc& op_desc,
     const std::unordered_set<std::string>& dependent_vars) {
@@ -173,6 +177,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
       auto* op = op_field->Add();
       *op = input.blocks(block_id).ops(i);
       if (HasSubBlock(*op)) {
+        VLOG(2) << "Pruning op which has sub block: " << op->type();
         // create sub_block_dependent_vars here to help prune the sub block
         std::unordered_set<std::string> sub_block_dependent_vars;
         for (auto& var : op->inputs()) {
@@ -189,6 +194,20 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
             }
           }
         }
+
+        // Recurrent op's states are also dependent vars
+        if (op->type() == kRecurrent) {
+          auto& attributes = op->attrs();
+          for (auto& attr : attributes) {
+            if (attr.name() == kStates || attr.name() == kExStates) {
+              for (auto& argu : attr.strings()) {
+                if (feed_var_names.count(argu) == 0) {
+                  sub_block_dependent_vars.insert(argu);
+                }
+              }
+            }
+          }
+        }
         // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
         // output_block_id is the idx of the current block in the output desc
         prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
diff --git a/paddle/fluid/framework/prune_test.cc b/paddle/fluid/framework/prune_test.cc
index 210e61a4dec7ee01e3b1d4db2a7d4934156d9b04..eb5c241a8372a460483c70e38f962168b1cdbbc0 100644
--- a/paddle/fluid/framework/prune_test.cc
+++ b/paddle/fluid/framework/prune_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 */
 #include <gtest/gtest.h>
 #include <set>
 #include <string>
+#include <vector>
 #include "paddle/fluid/framework/attribute.h"
 #include "paddle/fluid/framework/operator.h"
@@ -61,12 +62,12 @@ TEST(Prune, one_operator) {
   f::proto::ProgramDesc pruned;
   std::set<std::string> feed_var_names = {};
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 0);
 
   feed_var_names.insert("a");
   pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 1);
 }
 
 TEST(Prune, forward) {
@@ -88,7 +89,7 @@
     f::proto::ProgramDesc pruned;
     pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
     f::Prune(*pdesc, feed_var_names, &pruned);
-    PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
+    EXPECT_EQ(pruned.blocks(0).ops_size(), i + 1);
   }
 }
 
@@ -111,7 +112,7 @@ TEST(Prune, multi_input_op) {
   f::proto::ProgramDesc pruned;
   std::set<std::string> feed_var_names = {"a0", "a1", "a2"};
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 4);
 }
 
 TEST(Prune, multi_output_op) {
@@ -131,7 +132,7 @@
   f::proto::ProgramDesc pruned;
   std::set<std::string> feed_var_names = {"a"};
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
 }
 
 TEST(Prune, multi_target) {
@@ -152,5 +153,35 @@
   f::proto::ProgramDesc pruned;
   std::set<std::string> feed_var_names = {"a"};
   f::Prune(*pdesc, feed_var_names, &pruned);
-  PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 3);
 }
+
+TEST(Prune, recurrent_op) {
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::BlockDesc *sub_block = program.AppendBlock(*block);
+  AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
+        f::AttributeMap{}, block);
+
+  std::vector<std::string> state_var_name(1, "y");
+  AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1", "c1"}}},
+        {{"ex_states", state_var_name},
+         {"states", state_var_name},
+         {"sub_block", sub_block}},
+        block);
+
+  EXPECT_TRUE(sub_block != nullptr);
+  AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"y"}}},
+        f::AttributeMap{}, sub_block);
+
+  f::proto::ProgramDesc *pdesc = program.Proto();
+  pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
+
+  f::proto::ProgramDesc pruned;
+  std::set<std::string> feed_var_names = {"a"};
+
+  f::Prune(*pdesc, feed_var_names, &pruned);
+  EXPECT_EQ(pruned.blocks_size(), 2);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
+  EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
+}