diff --git a/paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc b/paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc index c6ed96a674d1342b5986f0565d23c0410efa17aa..2b926ccdf122146b2b81eeb8ecd451954736cae2 100644 --- a/paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc @@ -24,13 +24,24 @@ class GatherNdOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(4) << "convert a paddle gather_nd op to tensorrt gather_nd plugin"; framework::OpDesc op_desc(op, nullptr); + auto input = engine_->GetITensor(op_desc.Input("X")[0]); + auto index = engine_->GetITensor(op_desc.Input("Index")[0]); + auto output_name = op_desc.Output("Out")[0]; + + // AddGatherV2 is supported by the trt version of 8.2. +#if IS_TRT_VERSION_GE(8200) + VLOG(3) << "convert gather_nd op to tensorrt gather_nd layer"; + + auto layer = TRT_ENGINE_ADD_LAYER( + engine_, GatherV2, *input, *index, nvinfer1::GatherMode::kND); + layer->setNbElementWiseDims(0); + RreplenishLayerAndOutput(layer, "gather_nd", {output_name}, test_mode); +#else + VLOG(4) << "convert a paddle gather_nd op to tensorrt gather_nd plugin"; // Declare inputs std::vector<nvinfer1::ITensor*> inputs; - auto* input = engine_->GetITensor(op_desc.Input("X")[0]); - auto* index = engine_->GetITensor(op_desc.Input("Index")[0]); inputs.emplace_back(input); inputs.emplace_back(index); @@ -41,7 +52,6 @@ class GatherNdOpConverter : public OpConverter { layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin); std::string layer_name = "gather_nd (Output: "; - auto output_name = op_desc.Output("Out")[0]; layer->getOutput(0)->setName(output_name.c_str()); engine_->SetITensor(output_name, layer->getOutput(0)); layer_name += output_name; @@ -49,6 +59,7 @@ engine_->DeclareOutput(output_name); } layer->setName((layer_name + ")").c_str()); 
+#endif } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 24d5209baaa4ffae9368f0146f483736570ace7c..363b3132a1536b64862433fd5edfc51d483612fc 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -566,9 +566,8 @@ struct SimpleOpTypeSetTeller : public Teller { "the pass."; return false; } - auto x_var_name = desc.Input("X")[0]; + auto index_var_name = desc.Input("Index")[0]; - auto* x_var_desc = block->FindVar(x_var_name); auto* index_var_desc = block->FindVar(index_var_name); // The index input must be int32 datatype. @@ -578,6 +577,9 @@ struct SimpleOpTypeSetTeller : public Teller { return false; } +#if IS_TRT_VERSION_LT(8200) + auto x_var_name = desc.Input("X")[0]; + auto* x_var_desc = block->FindVar(x_var_name); const auto index_shape = index_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape(); if (x_shape.size() <= 2) { @@ -591,6 +593,7 @@ struct SimpleOpTypeSetTeller : public Teller { << " ] not equal to x dims size [" << x_shape.size() << "]"; return false; } +#endif } if (op_type == "anchor_generator") { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py index 10d3048a26cbe552db622a3d2e241d5de07a332e..8b32e5516b97e6976dd5eeb9f8b651e34809d5c6 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py @@ -69,11 +69,11 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest): ) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { - "input_data": [1, 8, 8, 8], + "input_data": [2, 32, 64, 64], "index_data": [1], } self.dynamic_shape.max_input_shape = { - "input_data": [4, 32, 64, 64], + "input_data": [2, 32, 64, 64], 
"index_data": [1], } self.dynamic_shape.opt_input_shape = { @@ -159,11 +159,11 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest): ) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { - "input_data": [1, 8, 8, 8], + "input_data": [2, 32, 64, 64], "index_data": [2], } self.dynamic_shape.max_input_shape = { - "input_data": [4, 32, 64, 64], + "input_data": [2, 32, 64, 64], "index_data": [2], } self.dynamic_shape.opt_input_shape = { @@ -249,11 +249,11 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest): ) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { - "input_data": [1, 8, 8, 8], + "input_data": [2, 32, 64, 64], "index_data": [2, 2], } self.dynamic_shape.max_input_shape = { - "input_data": [4, 32, 64, 64], + "input_data": [2, 32, 64, 64], "index_data": [2, 2], } self.dynamic_shape.opt_input_shape = { @@ -339,11 +339,11 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest): ) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { - "input_data": [1, 8, 8, 8], + "input_data": [2, 32, 64, 64], "index_data": [2, 2, 4], } self.dynamic_shape.max_input_shape = { - "input_data": [4, 32, 64, 64], + "input_data": [2, 32, 64, 64], "index_data": [2, 2, 4], } self.dynamic_shape.opt_input_shape = { @@ -429,15 +429,15 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest): ) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { - "input_data": [1, 4], + "input_data": [2, 32], "index_data": [2, 2], } self.dynamic_shape.max_input_shape = { - "input_data": [4, 64], + "input_data": [2, 32], "index_data": [2, 2], } self.dynamic_shape.opt_input_shape = { - "input_data": [2, 8], + "input_data": [2, 32], "index_data": [2, 2], } @@ -521,15 +521,15 @@ class 
TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest): ) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { - "input_data": [1, 4, 4], - "index_data": [1, 1, 1], + "input_data": [16, 32, 256], + "index_data": [2, 2, 2], } self.dynamic_shape.max_input_shape = { - "input_data": [16, 64, 512], - "index_data": [4, 2, 4], + "input_data": [16, 32, 256], + "index_data": [2, 2, 2], } self.dynamic_shape.opt_input_shape = { - "input_data": [2, 8, 64], + "input_data": [16, 32, 256], "index_data": [2, 2, 2], }