未验证 提交 ffc39605 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference]restart looup_table_v2 (#49119)

* restart looup_table_v2
上级 922f0868
......@@ -64,8 +64,6 @@ void TrtEmbedding2Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed");
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
......@@ -79,10 +77,8 @@ void TrtEmbedding2Eltwise1Pattern::operator()() {
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add");
feed1->LinksTo({lookup_table1_x});
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
feed2->LinksTo({lookup_table2_x});
lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})
......@@ -95,7 +91,6 @@ void TrtEmbedding1Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
......@@ -110,7 +105,6 @@ void TrtEmbedding1Eltwise1Pattern::operator()() {
->assert_is_op_output("elementwise_add");
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
feed1->LinksTo({lookup_table1_x});
eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
.LinksTo({eltwise_add_out});
}
......
......@@ -151,6 +151,14 @@ class OpConverter {
platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
}
// lookup_table_v2 == lookup_table
if (op_desc.Type() == "lookup_table_v2") {
it = Registry<OpConverter>::Global().Lookup("lookup_table");
PADDLE_ENFORCE_NOT_NULL(
it,
platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
}
if (!it) {
it = Registry<OpConverter>::Global().Lookup(op_desc.Type());
}
......
......@@ -2573,7 +2573,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"lookup_table",
"merge_layernorm",
"skip_merge_layernorm",
// "lookup_table_v2",
"lookup_table_v2",
"expand_v2"};
std::unordered_set<std::string> teller_set{
......@@ -2719,7 +2719,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"merge_layernorm",
"skip_merge_layernorm",
"lookup_table",
// "lookup_table_v2",
"lookup_table_v2",
"expand_v2"};
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册