From ffc396055ddc48b34c0d4691c4ad5ca143a8df68 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Mon, 19 Dec 2022 14:47:47 +0800 Subject: [PATCH] [Paddle Inference]restart looup_table_v2 (#49119) * restart looup_table_v2 --- .../ir/trt_embedding_eltwise_layernorm_fuse_pass.cc | 6 ------ paddle/fluid/inference/tensorrt/convert/op_converter.h | 8 ++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc index 23ebbddf57..8bb0c8ce67 100644 --- a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc @@ -64,8 +64,6 @@ void TrtEmbedding2Eltwise1Pattern::operator()() { create_emb_vars(pattern, lookup_table2_w_repr(), "W", true); std::unordered_set embedding_ops{"lookup_table", "lookup_table_v2"}; - auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed"); - auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed"); auto* lookup_table1 = pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); @@ -79,10 +77,8 @@ void TrtEmbedding2Eltwise1Pattern::operator()() { pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) ->assert_is_op_output("elementwise_add"); - feed1->LinksTo({lookup_table1_x}); lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) .LinksTo({lookup_table1_out}); - feed2->LinksTo({lookup_table2_x}); lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w}) .LinksTo({lookup_table2_out}); eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out}) @@ -95,7 +91,6 @@ void TrtEmbedding1Eltwise1Pattern::operator()() { create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); std::unordered_set embedding_ops{"lookup_table", "lookup_table_v2"}; - auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed"); auto* lookup_table1 = pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); @@ -110,7 +105,6 @@ void TrtEmbedding1Eltwise1Pattern::operator()() { ->assert_is_op_output("elementwise_add"); lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) .LinksTo({lookup_table1_out}); - feed1->LinksTo({lookup_table1_x}); eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in}) .LinksTo({eltwise_add_out}); } diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 74074303b9..453c9b7ff6 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -151,6 +151,14 @@ class OpConverter { platform::errors::Unimplemented("no OpConverter for optype [%s]", op_desc.Type())); } + // lookup_table_v2 == lookup_table + if (op_desc.Type() == "lookup_table_v2") { + it = Registry::Global().Lookup("lookup_table"); + PADDLE_ENFORCE_NOT_NULL( + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); + } if (!it) { it = Registry::Global().Lookup(op_desc.Type()); } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index bd224b136c..fea54fea3f 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2573,7 +2573,7 @@ struct SimpleOpTypeSetTeller : public Teller { "lookup_table", "merge_layernorm", "skip_merge_layernorm", - // "lookup_table_v2", + "lookup_table_v2", "expand_v2"}; std::unordered_set teller_set{ @@ -2719,7 +2719,7 @@ struct SimpleOpTypeSetTeller : public Teller { "merge_layernorm", "skip_merge_layernorm", "lookup_table", - // "lookup_table_v2", + "lookup_table_v2", "expand_v2"}; }; -- GitLab