未验证 提交 c4a972c2 编写于 作者: W Wilber 提交者: GitHub

fix prune for transformer model. (#26422)

上级 50609f0f
......@@ -210,6 +210,23 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
should_run.push_back(true);
} else {
should_run.push_back(false);
// If the output of an op modifies feed vars, the op should not clip.
// For example, in the transformer structure, the third parameter returned
// by beam_search op is generally assigned to a feed var. Cutting the
// assign op will cause an error.
if (parent_block_id != -1) {
bool flag = false;
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu)) {
flag = true;
}
}
}
if (flag) {
should_run.back() = true;
}
}
}
}
......
......@@ -185,3 +185,34 @@ TEST(Prune, recurrrent_op) {
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
}
// If the output of an op modifies feed vars, the op should not clip.
TEST(Prune, recurrrent_op_2) {
f::ProgramDesc program;
f::BlockDesc *block = program.MutableBlock(0);
f::BlockDesc *sub_block = program.AppendBlock(*block);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
f::AttributeMap{}, block);
std::vector<std::string> state_var_name(1, "y");
AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1, c1"}}},
{{"ex_states", state_var_name},
{"states", state_var_name},
{"sub_block", sub_block}},
block);
EXPECT_TRUE(sub_block != nullptr);
AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"a"}}},
f::AttributeMap{}, sub_block);
f::proto::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
f::proto::ProgramDesc pruned;
std::set<std::string> feed_var_names = {"x", "a"};
f::Prune(*pdesc, feed_var_names, &pruned);
EXPECT_EQ(pruned.blocks_size(), 2);
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
}
......@@ -3040,7 +3040,8 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None):
'beam_search_encode')
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_variable_for_type_inference(dtype=ids.dtype)
sentence_scores = helper.create_variable_for_type_inference(dtype=ids.dtype)
sentence_scores = helper.create_variable_for_type_inference(
dtype=scores.dtype)
helper.append_op(
type="beam_search_decode",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册