未验证 提交 ba2f8f0c 编写于 作者: W Wilber 提交者: GitHub

fix embedding_eltwise_layernorm_fuse_pass. test=develop (#24592)

上级 1ad6317b
...@@ -36,8 +36,9 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name, ...@@ -36,8 +36,9 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name, static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
const std::string& arg) { const std::string& arg) {
PDNode* node = pattern->NewNode(name) PDNode* node = pattern->NewNode(name)
->assert_is_op_output("lookup_table") ->assert_is_only_output_of_op("lookup_table")
->assert_is_op_input("elementwise_add", arg); ->assert_is_op_input("elementwise_add", arg)
->AsIntermediate();
return node; return node;
} }
void Embedding2Eltwise1Pattern::operator()() { void Embedding2Eltwise1Pattern::operator()() {
...@@ -94,7 +95,8 @@ void SkipLayerNorm::operator()() { ...@@ -94,7 +95,8 @@ void SkipLayerNorm::operator()() {
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->assert_is_op_input("layer_norm", "X"); ->assert_is_op_input("layer_norm", "X")
->AsIntermediate();
auto* layer_norm = auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr()) auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册