diff --git a/paddle/fluid/framework/prune.cc b/paddle/fluid/framework/prune.cc index f5e20e041e9e44437907f384c7cadae007a9f510..cc0240017de9fa1b74c4460471e1b320294e85ad 100644 --- a/paddle/fluid/framework/prune.cc +++ b/paddle/fluid/framework/prune.cc @@ -210,6 +210,23 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, should_run.push_back(true); } else { should_run.push_back(false); + // If the output of an op modifies feed vars, the op should not clip. + // For example, in the transformer structure, the third parameter returned + // by beam_search op is generally assigned to a feed var. Cutting the + // assign op will cause an error. + if (parent_block_id != -1) { + bool flag = false; + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (feed_var_names.count(argu)) { + flag = true; + } + } + } + if (flag) { + should_run.back() = true; + } + } } } diff --git a/paddle/fluid/framework/prune_test.cc b/paddle/fluid/framework/prune_test.cc index eb5c241a8372a460483c70e38f962168b1cdbbc0..12fa0c61f8121d475a0cf2aa78e4bb995a01b132 100644 --- a/paddle/fluid/framework/prune_test.cc +++ b/paddle/fluid/framework/prune_test.cc @@ -185,3 +185,34 @@ TEST(Prune, recurrrent_op) { EXPECT_EQ(pruned.blocks(0).ops_size(), 2); EXPECT_EQ(pruned.blocks(1).ops_size(), 1); } + +// If the output of an op modifies feed vars, the op should not clip. +TEST(Prune, recurrrent_op_2) { + f::ProgramDesc program; + f::BlockDesc *block = program.MutableBlock(0); + f::BlockDesc *sub_block = program.AppendBlock(*block); + AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, + f::AttributeMap{}, block); + + std::vector state_var_name(1, "y"); + AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1, c1"}}}, + {{"ex_states", state_var_name}, + {"states", state_var_name}, + {"sub_block", sub_block}}, + block); + + EXPECT_TRUE(sub_block != nullptr); + AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"a"}}}, + f::AttributeMap{}, sub_block); + + f::proto::ProgramDesc *pdesc = program.Proto(); + pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true); + + f::proto::ProgramDesc pruned; + std::set feed_var_names = {"x", "a"}; + + f::Prune(*pdesc, feed_var_names, &pruned); + EXPECT_EQ(pruned.blocks_size(), 2); + EXPECT_EQ(pruned.blocks(0).ops_size(), 2); + EXPECT_EQ(pruned.blocks(1).ops_size(), 1); +} diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index f3606dae19e57a6c9845e2782e4552099b29e652..b5526383d6c510703402526e19d88d6c3b28bd8b 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -3040,7 +3040,8 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None): 'beam_search_encode') helper = LayerHelper('beam_search_decode', **locals()) sentence_ids = helper.create_variable_for_type_inference(dtype=ids.dtype) - sentence_scores = helper.create_variable_for_type_inference(dtype=ids.dtype) + sentence_scores = helper.create_variable_for_type_inference( + dtype=scores.dtype) helper.append_op( type="beam_search_decode",