diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 13c0137c7d895b1b77a3936e29567dd58e698000..d09c344ca1643a50997c0c69130b8c600f78c60c 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -100,6 +100,7 @@ list( skip_merge_layernorm_op.cc generic_and_custom_plugin_creater.cc fused_lookup_tables_op.cc + lookup_table_op.cc elementwiseadd_transpose_op.cc skip_groupnorm_act_op.cc preln_groupnorm_act_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc b/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc index 27afced4125f82054fbc2480c7b9686ac0927c3e..aaeea2d725809eac1c26ec51b9c79d1f5d251d26 100644 --- a/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc @@ -110,4 +110,6 @@ class FusedLookupTablesOpConverter : public OpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(lookup_table, FusedLookupTablesOpConverter); +// NOTE(liuyuanle): We will remove the implementation here later. Ref to +// tensorrt/convert/lookup_table_op.cc. +// REGISTER_TRT_OP_CONVERTER(lookup_table, FusedLookupTablesOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/lookup_table_op.cc b/paddle/fluid/inference/tensorrt/convert/lookup_table_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e517a00ec94487d34f4e7c8e32feae71ac9015a7 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/lookup_table_op.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+// Converts a lookup_table op into a TensorRT Gather layer (gather axis 0).
+class LookupTableOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    framework::OpDesc desc(op, nullptr);
+    VLOG(3)
+        << "convert lookup_table(lookup_table_v2) op to TensorRT IGatherLayer";
+    // Resolve the variable names wired to this op: indices, table, result.
+    const auto indices_var = desc.Input("Ids").front();
+    const auto table_var = desc.Input("W").front();
+    const auto result_var = desc.Output("Out").front();
+    // Look up the engine's ITensor for each of the two input names.
+    auto* indices = engine_->GetITensor(indices_var);
+    auto* table = engine_->GetITensor(table_var);
+    // "W" is passed as the gathered data, "Ids" as the indices, with axis 0.
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Gather, *table, *indices, 0);
+    RreplenishLayerAndOutput(layer, "gather", {result_var}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(lookup_table, LookupTableOpConverter);