diff --git a/paddle/fluid/framework/prune.cc b/paddle/fluid/framework/prune.cc index 919378c929185b12826c8b427d0e9a86a382bb2b..274b0ca0d903d4e89c7bceb74bc16581f03bb584 100644 --- a/paddle/fluid/framework/prune.cc +++ b/paddle/fluid/framework/prune.cc @@ -210,6 +210,23 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, should_run.push_back(true); } else { should_run.push_back(false); + // An op whose output modifies a feed var must not be pruned. + // For example, in the transformer structure, the third parameter returned + // by beam_search op is generally assigned to a feed var. Pruning the + // assign op will cause an error. + if (parent_block_id != -1) { + bool flag = false; + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (feed_var_names.count(argu)) { + flag = true; + } + } + } + if (flag) { + should_run.back() = true; + } + } } } diff --git a/paddle/fluid/framework/prune_test.cc b/paddle/fluid/framework/prune_test.cc index eb5c241a8372a460483c70e38f962168b1cdbbc0..12fa0c61f8121d475a0cf2aa78e4bb995a01b132 100644 --- a/paddle/fluid/framework/prune_test.cc +++ b/paddle/fluid/framework/prune_test.cc @@ -185,3 +185,34 @@ TEST(Prune, recurrrent_op) { EXPECT_EQ(pruned.blocks(0).ops_size(), 2); EXPECT_EQ(pruned.blocks(1).ops_size(), 1); } + +// An op whose output modifies a feed var must not be pruned.
+TEST(Prune, recurrrent_op_2) {
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::BlockDesc *sub_block = program.AppendBlock(*block);
+  AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
+        f::AttributeMap{}, block);
+
+  std::vector<std::string> state_var_name(1, "y");
+  AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1", "c1"}}},
+        {{"ex_states", state_var_name},
+         {"states", state_var_name},
+         {"sub_block", sub_block}},
+        block);
+
+  EXPECT_TRUE(sub_block != nullptr);
+  AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"a"}}},
+        f::AttributeMap{}, sub_block);
+
+  f::proto::ProgramDesc *pdesc = program.Proto();
+  pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);  // recurrent
+
+  f::proto::ProgramDesc pruned;
+  std::set<std::string> feed_var_names = {"x", "a"};  // "a" is an op output
+
+  f::Prune(*pdesc, feed_var_names, &pruned);
+  EXPECT_EQ(pruned.blocks_size(), 2);
+  EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
+  EXPECT_EQ(pruned.blocks(1).ops_size(), 1);  // op writing feed var "a" kept
+}
diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py
index 39d25a9c7ffd47212153a345a98781ed034c7020..4ec0770aaf0a559b9619e74f7c9178b2fd61062f 100644
--- a/python/paddle/fluid/layers/rnn.py
+++ b/python/paddle/fluid/layers/rnn.py
@@ -3102,7 +3102,8 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None):
                              'beam_search_encode')
     helper = LayerHelper('beam_search_decode', **locals())
     sentence_ids = helper.create_variable_for_type_inference(dtype=ids.dtype)
-    sentence_scores = helper.create_variable_for_type_inference(dtype=ids.dtype)
+    sentence_scores = helper.create_variable_for_type_inference(
+        dtype=scores.dtype)
 
     helper.append_op(
         type="beam_search_decode",