From 80f5e25e4e7dc7163364b621d106e340a7bfa9ad Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 22 Dec 2022 20:01:56 +0800 Subject: [PATCH] [Paddle Inference]fix gather_nd (#49266) * fix reshape, gather_nd --- paddle/fluid/inference/tensorrt/convert/stack_op.cc | 5 ++++- paddle/fluid/inference/tensorrt/op_teller.cc | 11 +---------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/stack_op.cc b/paddle/fluid/inference/tensorrt/convert/stack_op.cc index c60d2578ec..70e634c387 100644 --- a/paddle/fluid/inference/tensorrt/convert/stack_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/stack_op.cc @@ -42,6 +42,7 @@ class StackOpConverter : public OpConverter { auto input = op_desc.Input("X"); int input_num = input.size(); std::vector inputs; + auto output_name = op_desc.Output("Y").front(); for (int i = 0; i < input_num; ++i) { inputs.push_back(engine_->GetITensor(input[i])); @@ -76,13 +77,15 @@ class StackOpConverter : public OpConverter { auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[i]); reshape_layer->setInput(1, *after_shape_tensor); inputs[i] = reshape_layer->getOutput(0); + reshape_layer->setName(("stack: reshape: (Output( " + std::to_string(i) + + " )" + output_name + ")") + .c_str()); } auto* layer = TRT_ENGINE_ADD_LAYER( engine_, Concatenation, inputs.data(), inputs.size()); layer->setAxis(axis); - auto output_name = op_desc.Output("Y").front(); RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index c469c4fbf3..ad9b6156d9 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -609,18 +609,9 @@ struct SimpleOpTypeSetTeller : public Teller { "the pass."; return false; } - +#if IS_TRT_VERSION_LT(8200) auto index_var_name = desc.Input("Index")[0]; auto* index_var_desc = block->FindVar(index_var_name); - - // The index input must be int32 datatype. - if (index_var_desc->GetDataType() != - paddle::framework::proto::VarType_Type::VarType_Type_INT32) { - VLOG(3) << "gather_nd op Index input data type must be int32"; - return false; - } - -#if IS_TRT_VERSION_LT(8200) auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto index_shape = index_var_desc->GetShape(); -- GitLab