From 08b44e67b3ca8b0f40ef8150bd5ddae635108022 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Sat, 6 May 2023 20:24:33 +0800 Subject: [PATCH] [inference][trt] add lookup_table op trt converter, use trt gather layer (#53554) * add lookup_table op trt converter * update --- .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../convert/fused_lookup_tables_op.cc | 4 +- .../tensorrt/convert/lookup_table_op.cc | 47 +++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/lookup_table_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 13c0137c7d8..d09c344ca16 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -100,6 +100,7 @@ list( skip_merge_layernorm_op.cc generic_and_custom_plugin_creater.cc fused_lookup_tables_op.cc + lookup_table_op.cc elementwiseadd_transpose_op.cc skip_groupnorm_act_op.cc preln_groupnorm_act_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc b/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc index 27afced4125..aaeea2d7258 100644 --- a/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc @@ -110,4 +110,6 @@ class FusedLookupTablesOpConverter : public OpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(lookup_table, FusedLookupTablesOpConverter); +// NOTE(liuyuanle): We will remove the implementation here later. Ref to +// tensorrt/convert/lookup_table_op.cc. 
+// REGISTER_TRT_OP_CONVERTER(lookup_table, FusedLookupTablesOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/lookup_table_op.cc b/paddle/fluid/inference/tensorrt/convert/lookup_table_op.cc
new file mode 100644
index 00000000000..e517a00ec94
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/lookup_table_op.cc
@@ -0,0 +1,57 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+// Converts a lookup_table op into a TensorRT Gather layer: the "W" input is
+// gathered with the "Ids" input, and the resulting layer is handed to
+// RreplenishLayerAndOutput together with the name of the "Out" output.
+class LookupTableOpConverter : public OpConverter {
+ public:
+  // `op` is the op description to convert; `scope` is not read here;
+  // `test_mode` is forwarded unchanged to RreplenishLayerAndOutput.
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    framework::OpDesc desc(op, nullptr);
+    VLOG(3)
+        << "convert lookup_table(lookup_table_v2) op to TensorRT IGatherLayer";
+
+    // First variable name bound to each slot of the op.
+    const auto indices_name = desc.Input("Ids").front();
+    const auto table_name = desc.Input("W").front();
+    const auto output_name = desc.Output("Out").front();
+
+    // ITensor handles the engine returns for those names.
+    auto* indices = engine_->GetITensor(indices_name);
+    auto* table = engine_->GetITensor(table_name);
+
+    // Last Gather argument; presumably the gather axis, as in TensorRT's
+    // addGather(data, indices, axis) -- confirm against TRT_ENGINE_ADD_LAYER.
+    constexpr int kGatherAxis = 0;
+    auto* layer =
+        TRT_ENGINE_ADD_LAYER(engine_, Gather, *table, *indices, kGatherAxis);
+    RreplenishLayerAndOutput(layer, "gather", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(lookup_table, LookupTableOpConverter);
-- 
GitLab