未验证 提交 ffc39605 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference]restart looup_table_v2 (#49119)

* restart looup_table_v2
上级 922f0868
...@@ -64,8 +64,6 @@ void TrtEmbedding2Eltwise1Pattern::operator()() { ...@@ -64,8 +64,6 @@ void TrtEmbedding2Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true); create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table", std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"}; "lookup_table_v2"};
auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed");
auto* lookup_table1 = auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
...@@ -79,10 +77,8 @@ void TrtEmbedding2Eltwise1Pattern::operator()() { ...@@ -79,10 +77,8 @@ void TrtEmbedding2Eltwise1Pattern::operator()() {
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add"); ->assert_is_op_output("elementwise_add");
feed1->LinksTo({lookup_table1_x});
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out}); .LinksTo({lookup_table1_out});
feed2->LinksTo({lookup_table2_x});
lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w}) lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out}); .LinksTo({lookup_table2_out});
eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out}) eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})
...@@ -95,7 +91,6 @@ void TrtEmbedding1Eltwise1Pattern::operator()() { ...@@ -95,7 +91,6 @@ void TrtEmbedding1Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table", std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"}; "lookup_table_v2"};
auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
auto* lookup_table1 = auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
...@@ -110,7 +105,6 @@ void TrtEmbedding1Eltwise1Pattern::operator()() { ...@@ -110,7 +105,6 @@ void TrtEmbedding1Eltwise1Pattern::operator()() {
->assert_is_op_output("elementwise_add"); ->assert_is_op_output("elementwise_add");
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out}); .LinksTo({lookup_table1_out});
feed1->LinksTo({lookup_table1_x});
eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in}) eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
.LinksTo({eltwise_add_out}); .LinksTo({eltwise_add_out});
} }
......
...@@ -151,6 +151,14 @@ class OpConverter { ...@@ -151,6 +151,14 @@ class OpConverter {
platform::errors::Unimplemented("no OpConverter for optype [%s]", platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type())); op_desc.Type()));
} }
// lookup_table_v2 == lookup_table
if (op_desc.Type() == "lookup_table_v2") {
it = Registry<OpConverter>::Global().Lookup("lookup_table");
PADDLE_ENFORCE_NOT_NULL(
it,
platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
}
if (!it) { if (!it) {
it = Registry<OpConverter>::Global().Lookup(op_desc.Type()); it = Registry<OpConverter>::Global().Lookup(op_desc.Type());
} }
......
...@@ -2573,7 +2573,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2573,7 +2573,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"lookup_table", "lookup_table",
"merge_layernorm", "merge_layernorm",
"skip_merge_layernorm", "skip_merge_layernorm",
// "lookup_table_v2", "lookup_table_v2",
"expand_v2"}; "expand_v2"};
std::unordered_set<std::string> teller_set{ std::unordered_set<std::string> teller_set{
...@@ -2719,7 +2719,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2719,7 +2719,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"merge_layernorm", "merge_layernorm",
"skip_merge_layernorm", "skip_merge_layernorm",
"lookup_table", "lookup_table",
// "lookup_table_v2", "lookup_table_v2",
"expand_v2"}; "expand_v2"};
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册