diff --git a/paddle/fluid/inference/tensorrt/convert/stack_op.cc b/paddle/fluid/inference/tensorrt/convert/stack_op.cc index c60d2578ec0a333fd705a76155b01aa57ff64855..70e634c38779df6465eeed8b11fe7656f3d23336 100644 --- a/paddle/fluid/inference/tensorrt/convert/stack_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/stack_op.cc @@ -42,6 +42,7 @@ class StackOpConverter : public OpConverter { auto input = op_desc.Input("X"); int input_num = input.size(); std::vector inputs; + auto output_name = op_desc.Output("Y").front(); for (int i = 0; i < input_num; ++i) { inputs.push_back(engine_->GetITensor(input[i])); @@ -76,13 +77,15 @@ class StackOpConverter : public OpConverter { auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[i]); reshape_layer->setInput(1, *after_shape_tensor); inputs[i] = reshape_layer->getOutput(0); + reshape_layer->setName(("stack: reshape: (Output( " + std::to_string(i) + + " )" + output_name + ")") + .c_str()); } auto* layer = TRT_ENGINE_ADD_LAYER( engine_, Concatenation, inputs.data(), inputs.size()); layer->setAxis(axis); - auto output_name = op_desc.Output("Y").front(); RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index c469c4fbf3c91dbe61a33d88dba9873ebd613afb..ad9b6156d9caf4a6edc31c3ace7eba78bdd1940f 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -609,18 +609,9 @@ struct SimpleOpTypeSetTeller : public Teller { "the pass."; return false; } - +#if IS_TRT_VERSION_LT(8200) auto index_var_name = desc.Input("Index")[0]; auto* index_var_desc = block->FindVar(index_var_name); - - // The index input must be int32 datatype. - if (index_var_desc->GetDataType() != - paddle::framework::proto::VarType_Type::VarType_Type_INT32) { - VLOG(3) << "gather_nd op Index input data type must be int32"; - return false; - } - -#if IS_TRT_VERSION_LT(8200) auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto index_shape = index_var_desc->GetShape();