From cdd7b956e7a99401a09d99130a1757be2d979bf3 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Wed, 9 Nov 2022 10:30:10 +0800 Subject: [PATCH] [Paddle Inference]upgrade scale and slice op convert for Paddle-TensorRT (#47746) * upgrade scale and slice op convert for Paddle-TensorRT --- .../inference/tensorrt/convert/scale_op.cc | 254 ++++++---- .../inference/tensorrt/convert/slice_op.cc | 172 +++---- paddle/fluid/inference/tensorrt/op_teller.cc | 60 ++- .../inference/tensorrt/plugin/CMakeLists.txt | 1 - .../tensorrt/plugin/slice_op_plugin.cu | 435 ------------------ .../tensorrt/plugin/slice_op_plugin.h | 187 -------- .../ir/inference/test_trt_convert_clip.py | 7 +- 7 files changed, 294 insertions(+), 822 deletions(-) delete mode 100644 paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu delete mode 100644 paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h diff --git a/paddle/fluid/inference/tensorrt/convert/scale_op.cc b/paddle/fluid/inference/tensorrt/convert/scale_op.cc index a3b2e65ac49..d770c21a9ad 100644 --- a/paddle/fluid/inference/tensorrt/convert/scale_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/scale_op.cc @@ -49,106 +49,178 @@ class ScaleOpConverter : public OpConverter { PADDLE_GET_CONST(bool, op_desc.GetAttr("bias_after_scale")); float bias = PADDLE_GET_CONST(float, op_desc.GetAttr("bias")); float scale = PADDLE_GET_CONST(float, op_desc.GetAttr("scale")); - auto create_weights = [&](float data, std::string type) -> float* { - std::unique_ptr tmp_tensor(new phi::DenseTensor()); - tmp_tensor->Resize({1}); - auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); - tmp_data[0] = data; - engine_->SetWeights(out_name + "_scale_op_" + type, - std::move(tmp_tensor)); - return tmp_data; - }; - - int dynamic_shape_offset = engine_->with_dynamic_shape() ? 1 : 0; - - float* bias_ptr = create_weights(bias, "bias"); - float* scale_ptr = create_weights(scale, "scale"); - - TensorRTEngine::Weight scale_weights{ - nvinfer1::DataType::kFLOAT, static_cast(scale_ptr), 1}; - TensorRTEngine::Weight shift_weights{ - nvinfer1::DataType::kFLOAT, static_cast(bias_ptr), 1}; - TensorRTEngine::Weight power_weights{ - nvinfer1::DataType::kFLOAT, nullptr, 0}; nvinfer1::ILayer* layer = nullptr; + if (engine_->with_dynamic_shape()) { + nvinfer1::ITensor* bias_tensor = Add1DConstantLayer(bias); + bool is_bias_0 = (bias < 1e-06 && bias > -1e-06); + + std::vector bias_shapes(input->getDimensions().nbDims, 1); + auto* bias_shapes_tensor = Add1DConstantLayer(bias_shapes); + auto* reshape_layer_bias = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bias_tensor); + reshape_layer_bias->setInput(1, *bias_shapes_tensor); + + bool has_scale_tensor; + nvinfer1::ITensor* scale_tensor; + bool is_scale_1; + + auto scale_inputs = op_desc.Inputs(); + if (scale_inputs.find("ScaleTensor") != scale_inputs.end() && + op_desc.Input("ScaleTensor").size()) { // has EndsTensor input + has_scale_tensor = true; + scale_tensor = engine_->GetITensor(op_desc.Input("ScaleTensor")[0]); + is_scale_1 = false; + } else { + has_scale_tensor = false; + scale_tensor = Add1DConstantLayer(scale); + is_scale_1 = ((scale - 1.0) < 1e-06 && (scale - 1.0) > -1e-06); + } - auto input_dim = input->getDimensions(); - - nvinfer1::IShuffleLayer* expand_layer = nullptr; - nvinfer1::IShuffleLayer* squeeze_layer = nullptr; - - if (input_dim.nbDims < 3 + dynamic_shape_offset) { - nvinfer1::Dims expand_shape; - expand_shape.nbDims = 3 + dynamic_shape_offset; - for (int i = 0; i < 3 + dynamic_shape_offset; i++) { - if (i < input_dim.nbDims) { - expand_shape.d[i] = input_dim.d[i] < 0 ? 0 : input_dim.d[i]; + std::vector scale_shapes(input->getDimensions().nbDims, 1); + auto* scale_shapes_tensor = Add1DConstantLayer(scale_shapes); + auto* reshape_layer_scale = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scale_tensor); + reshape_layer_scale->setInput(1, *scale_shapes_tensor); + + if (!has_scale_tensor && is_scale_1 && is_bias_0) { + layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input); + } else { + if (bias_after_scale) { + if (!is_scale_1) { + layer = TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *input, + *reshape_layer_scale->getOutput(0), + nvinfer1::ElementWiseOperation::kPROD); + input = layer->getOutput(0); + } + if (!is_bias_0) { + layer = TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *input, + *reshape_layer_bias->getOutput(0), + nvinfer1::ElementWiseOperation::kSUM); + } } else { - expand_shape.d[i] = 1; + if (!is_bias_0) { + layer = TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *input, + *reshape_layer_bias->getOutput(0), + nvinfer1::ElementWiseOperation::kSUM); + input = layer->getOutput(0); + } + if (!is_scale_1) { + layer = TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *input, + *reshape_layer_scale->getOutput(0), + nvinfer1::ElementWiseOperation::kPROD); + } } } - expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); - expand_layer->setReshapeDimensions(expand_shape); - input = expand_layer->getOutput(0); - expand_layer->getOutput(0)->setName( - ("before_reshape_out: " + out_name).c_str()); - expand_layer->setName( - ("Scale: before_reshape (Output: " + out_name + ")").c_str()); - } - - if (bias_after_scale) { - layer = TRT_ENGINE_ADD_LAYER(engine_, - Scale, - *input, - nvinfer1::ScaleMode::kUNIFORM, - shift_weights.get(), - scale_weights.get(), - power_weights.get()); - layer->getOutput(0)->setName( - ("bias_after_scale_out: " + out_name).c_str()); - layer->setName(("Scale: scale (Output: " + out_name + ")").c_str()); } else { - // add bias - layer = TRT_ENGINE_ADD_LAYER(engine_, - Scale, - *(input), - nvinfer1::ScaleMode::kUNIFORM, - shift_weights.get(), - power_weights.get(), - power_weights.get()); - layer->getOutput(0)->setName( - ("bias_before_scale:bias_out: " + out_name).c_str()); - layer->setName(("Scale: scale_bias (Output: " + out_name + ")").c_str()); - // mul scale - layer = TRT_ENGINE_ADD_LAYER(engine_, - Scale, - *(layer->getOutput(0)), - nvinfer1::ScaleMode::kUNIFORM, - power_weights.get(), - scale_weights.get(), - power_weights.get()); - layer->getOutput(0)->setName( - ("bias_before_scale:scale_out: " + out_name).c_str()); - layer->setName(("Scale: scale_scale (Output: " + out_name + ")").c_str()); - } + auto create_weights = [&](float data, std::string type) -> float* { + std::unique_ptr tmp_tensor(new phi::DenseTensor()); + tmp_tensor->Resize({1}); + auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); + tmp_data[0] = data; + engine_->SetWeights(out_name + "_scale_op_" + type, + std::move(tmp_tensor)); + return tmp_data; + }; + + float* bias_ptr = create_weights(bias, "bias"); + float* scale_ptr = create_weights(scale, "scale"); + + TensorRTEngine::Weight scale_weights{ + nvinfer1::DataType::kFLOAT, static_cast(scale_ptr), 1}; + TensorRTEngine::Weight shift_weights{ + nvinfer1::DataType::kFLOAT, static_cast(bias_ptr), 1}; + TensorRTEngine::Weight power_weights{ + nvinfer1::DataType::kFLOAT, nullptr, 0}; + + auto input_dim = input->getDimensions(); + + nvinfer1::IShuffleLayer* expand_layer = nullptr; + nvinfer1::IShuffleLayer* squeeze_layer = nullptr; + + if (input_dim.nbDims < 3) { + nvinfer1::Dims expand_shape; + expand_shape.nbDims = 3; + for (int i = 0; i < 3; i++) { + if (i < input_dim.nbDims) { + expand_shape.d[i] = input_dim.d[i] < 0 ? 0 : input_dim.d[i]; + } else { + expand_shape.d[i] = 1; + } + } + expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + expand_layer->setReshapeDimensions(expand_shape); + input = expand_layer->getOutput(0); + expand_layer->getOutput(0)->setName( + ("before_reshape_out: " + out_name).c_str()); + expand_layer->setName( + ("Scale: before_reshape (Output: " + out_name + ")").c_str()); + } - PADDLE_ENFORCE_EQ(layer != nullptr, - true, - platform::errors::Fatal("Create scale layer failed.")); + if (bias_after_scale) { + layer = TRT_ENGINE_ADD_LAYER(engine_, + Scale, + *input, + nvinfer1::ScaleMode::kUNIFORM, + shift_weights.get(), + scale_weights.get(), + power_weights.get()); + layer->getOutput(0)->setName( + ("bias_after_scale_out: " + out_name).c_str()); + layer->setName(("Scale: scale (Output: " + out_name + ")").c_str()); + } else { + // add bias + layer = TRT_ENGINE_ADD_LAYER(engine_, + Scale, + *(input), + nvinfer1::ScaleMode::kUNIFORM, + shift_weights.get(), + power_weights.get(), + power_weights.get()); + layer->getOutput(0)->setName( + ("bias_before_scale:bias_out: " + out_name).c_str()); + layer->setName( + ("Scale: scale_bias (Output: " + out_name + ")").c_str()); + // mul scale + layer = TRT_ENGINE_ADD_LAYER(engine_, + Scale, + *(layer->getOutput(0)), + nvinfer1::ScaleMode::kUNIFORM, + power_weights.get(), + scale_weights.get(), + power_weights.get()); + layer->getOutput(0)->setName( + ("bias_before_scale:scale_out: " + out_name).c_str()); + layer->setName( + ("Scale: scale_scale (Output: " + out_name + ")").c_str()); + } - if (input_dim.nbDims < 3 + dynamic_shape_offset) { - nvinfer1::Dims squeeze_shape; - squeeze_shape.nbDims = input_dim.nbDims; - for (int i = 0; i < squeeze_shape.nbDims; i++) { - squeeze_shape.d[i] = input_dim.d[i] < 0 ? 0 : input_dim.d[i]; + PADDLE_ENFORCE_EQ(layer != nullptr, + true, + platform::errors::Fatal("Create scale layer failed.")); + + if (input_dim.nbDims < 3) { + nvinfer1::Dims squeeze_shape; + squeeze_shape.nbDims = input_dim.nbDims; + for (int i = 0; i < squeeze_shape.nbDims; i++) { + squeeze_shape.d[i] = input_dim.d[i] < 0 ? 0 : input_dim.d[i]; + } + squeeze_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0))); + squeeze_layer->setReshapeDimensions(squeeze_shape); + layer = static_cast(squeeze_layer); + layer->getOutput(0)->setName( + ("after_reshape_out: " + out_name).c_str()); + layer->setName( + ("Scale: Shuffle_reshape (Output: " + out_name + ")").c_str()); } - squeeze_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0))); - squeeze_layer->setReshapeDimensions(squeeze_shape); - layer = static_cast(squeeze_layer); - layer->getOutput(0)->setName(("after_reshape_out: " + out_name).c_str()); - layer->setName( - ("Scale: Shuffle_reshape (Output: " + out_name + ")").c_str()); } RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode); } diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index 3b663ba5a61..0081a7d8069 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -10,7 +10,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" -#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h" namespace paddle { namespace inference { @@ -34,7 +33,6 @@ class SliceOpConverter : public OpConverter { out_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold")); engine_->SetTensorDynamicRange(input, out_scale); } - std::vector axes = PADDLE_GET_CONST(std::vector, op_desc.GetAttr("axes")); std::vector starts = @@ -43,82 +41,85 @@ class SliceOpConverter : public OpConverter { PADDLE_GET_CONST(std::vector, op_desc.GetAttr("ends")); std::vector decrease_axises = PADDLE_GET_CONST(std::vector, op_desc.GetAttr("decrease_axis")); - auto input_dims = input->getDimensions(); - if (!engine_->with_dynamic_shape()) { - // notice that input shape is [CHW] without batch axis when input has - // static shape - for (size_t i = input_dims.nbDims; i > 0; i--) { - input_dims.d[i] = input_dims.d[i - 1]; - } - input_dims.d[0] = 1; // fake batchsize, not useful here - for (size_t i = 0; i < axes.size(); i++) { - if (starts[i] < 0) { - starts[i] = std::max(starts[i] + input_dims.d[axes[i]], 0); - } - if (ends[i] < 0) { - ends[i] = std::max(ends[i] + input_dims.d[axes[i]], 0); - } - ends[i] = std::min(ends[i], input_dims.d[axes[i]]); - PADDLE_ENFORCE_GT( - ends[i], - starts[i], - platform::errors::InvalidArgument( - "Attr(ends) should be greater than attr(starts) in " - "slice op. But received ends = %d, starts = %d.", - ends[i], - starts[i])); - } - } - nvinfer1::ILayer* layer = nullptr; + if (engine_->with_dynamic_shape()) { -#if IS_TRT_VERSION_GE(6000) - auto nchw_input_dims = input->getDimensions(); + auto* shape_tensor = Shape(input); nvinfer1::Dims trt_start_dims; - trt_start_dims.nbDims = nchw_input_dims.nbDims; - memset(trt_start_dims.d, 0, sizeof(int32_t) * nchw_input_dims.nbDims); + trt_start_dims.nbDims = input_dims.nbDims; + memset(trt_start_dims.d, 0, sizeof(int32_t) * input_dims.nbDims); nvinfer1::Dims trt_size_dims = trt_start_dims; - nvinfer1::Dims trt_end_dims = trt_start_dims; nvinfer1::Dims trt_step_dims = trt_start_dims; for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1; - - // input : [N,C,H,W] - bool has_neg_indices = false; - for (size_t i = 0; i < axes.size(); i++) { - int trt_axis = axes[i]; - trt_start_dims.d[trt_axis] = starts[i]; - trt_end_dims.d[trt_axis] = ends[i]; - if (starts[i] < 0 || ends[i] < 0) has_neg_indices = true; + nvinfer1::ITensor* start_tensor = nullptr; + nvinfer1::ITensor* end_tensor = nullptr; + + std::vector starts_tensor; + std::vector ends_tensor; + for (int32_t i = 0; i < input_dims.nbDims; ++i) { + starts_tensor.push_back(Add1DConstantLayer(0)); + ends_tensor.push_back(GetEleTensorOfShape(shape_tensor, i)); } - auto* shape_tensor = Shape(input); - auto* start_tensor = Add1DConstantLayer(trt_start_dims); - if (has_neg_indices) { - start_tensor = FixNegIndices(shape_tensor, start_tensor); - } - - std::vector end_vec_tensor; - for (int i = 0; i < trt_end_dims.nbDims; i++) { - end_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, i)); + auto slice_inputs = op_desc.Inputs(); + if (slice_inputs.find("StartsTensor") != slice_inputs.end() && + op_desc.Input("StartsTensor").size()) { // has StartsTensor input + for (size_t i = 0; i < axes.size(); ++i) { + starts_tensor[axes[i]] = GetEleTensorOfShape( + engine_->GetITensor(op_desc.Input("StartsTensor")[0]), i); + } + } else { + PADDLE_ENFORCE_EQ(starts.size(), + axes.size(), + platform::errors::InvalidArgument( + "The size of this starts: %d must be " + "equal to the axes: %d.", + starts.size(), + axes.size())); + for (size_t i = 0; i < axes.size(); i++) { // same as starts.size() + if (starts[i] < 0) { + starts_tensor[axes[i]] = + Max(Sum(Add1DConstantLayer(starts[i]), + GetEleTensorOfShape(shape_tensor, axes[i])), + Add1DConstantLayer(0)); + } else { + starts_tensor[axes[i]] = + Min(Add1DConstantLayer(starts[i]), + GetEleTensorOfShape(shape_tensor, axes[i])); + } + } } + start_tensor = Concat(starts_tensor); - for (size_t i = 0; i < axes.size(); i++) { - int trt_axis = axes[i]; - if (ends[i] >= 0) { - end_vec_tensor[trt_axis] = Add1DConstantLayer(ends[i]); - } else { - end_vec_tensor[trt_axis] = - Sum(end_vec_tensor[trt_axis], Add1DConstantLayer(ends[i])); + if (slice_inputs.find("EndsTensor") != slice_inputs.end() && + op_desc.Input("EndsTensor").size()) { // has EndsTensor input + for (size_t i = 0; i < axes.size(); ++i) { + ends_tensor[axes[i]] = GetEleTensorOfShape( + engine_->GetITensor(op_desc.Input("EndsTensor")[0]), i); + } + } else { + PADDLE_ENFORCE_EQ(ends.size(), + axes.size(), + platform::errors::InvalidArgument( + "The size of this ends: %d must be " + "equal to the axes: %d.", + ends.size(), + axes.size())); + for (size_t i = 0; i < axes.size(); i++) { // same as ends.size() + if (ends[i] < 0) { + ends_tensor[axes[i]] = + Max(Sum(Add1DConstantLayer(ends[i]), + GetEleTensorOfShape(shape_tensor, axes[i])), + Add1DConstantLayer(0)); + } else { + ends_tensor[axes[i]] = + Min(Add1DConstantLayer(ends[i]), + GetEleTensorOfShape(shape_tensor, axes[i])); + } } } - -// CI failed in trt 6015 but success in 7134, may be a trt bug -#if IS_TRT_VERSION_GE(7134) - auto* size_tensor = - Sub(Min(Concat(end_vec_tensor), shape_tensor), start_tensor); -#else - auto* size_tensor = Sub(Concat(end_vec_tensor), start_tensor); -#endif + end_tensor = Concat(ends_tensor); + auto* size_tensor = Sub(end_tensor, start_tensor); layer = TRT_ENGINE_ADD_LAYER( engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims); @@ -139,16 +140,30 @@ class SliceOpConverter : public OpConverter { layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0)); layer->setInput(1, *real_size_tensor); } -#else - bool with_fp16 = - engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0]; - plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic( - starts, ends, axes, decrease_axis, with_fp16); - layer = engine_->AddDynamicPlugin(&input, 1, plugin); -#endif } else { -#if IS_TRT_VERSION_GE(6000) + // notice that input shape is [CHW] without batch axis when input has + // static shape + for (size_t i = input_dims.nbDims; i > 0; i--) { + input_dims.d[i] = input_dims.d[i - 1]; + } + input_dims.d[0] = 1; // fake batchsize, not useful here + for (size_t i = 0; i < axes.size(); i++) { + if (starts[i] < 0) { + starts[i] = std::max(starts[i] + input_dims.d[axes[i]], 0); + } + if (ends[i] < 0) { + ends[i] = std::max(ends[i] + input_dims.d[axes[i]], 0); + } + ends[i] = std::min(ends[i], input_dims.d[axes[i]]); + PADDLE_ENFORCE_GT( + ends[i], + starts[i], + platform::errors::InvalidArgument( + "Attr(ends) should be greater than attr(starts) in " + "slice op. But received ends = %d, starts = %d.", + ends[i], + starts[i])); + } auto chw_input_dims = input->getDimensions(); nvinfer1::Dims trt_start_dims; trt_start_dims.nbDims = chw_input_dims.nbDims; @@ -189,13 +204,6 @@ class SliceOpConverter : public OpConverter { reshape_layer->setReshapeDimensions(real_trt_size_dims); layer = static_cast(reshape_layer); } -#else - bool with_fp16 = - engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - plugin::SlicePlugin* plugin = - new plugin::SlicePlugin(starts, ends, axes, with_fp16); - layer = engine_->AddPlugin(&input, 1, plugin); -#endif } RreplenishLayerAndOutput(layer, "slice", {output_name}, test_mode); } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6741ac8bc26..b4f8b7e929f 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1172,25 +1172,13 @@ struct SimpleOpTypeSetTeller : public Teller { } } } - - if (!desc.HasAttr("axes") || !desc.HasAttr("starts") || - !desc.HasAttr("ends")) { + std::vector axes; + if (!desc.HasAttr("axes")) { VLOG(3) << "The necessary attributes of the slice operator axes " - "or starts or ends are missing."; + " are missing."; return false; } else { - std::vector axes = - PADDLE_GET_CONST(std::vector, desc.GetAttr("axes")); - std::vector starts = - PADDLE_GET_CONST(std::vector, desc.GetAttr("starts")); - std::vector ends = - PADDLE_GET_CONST(std::vector, desc.GetAttr("ends")); - - if (axes.size() != starts.size() || axes.size() != ends.size()) { - VLOG(3) << "The shape of attributes of the slice operator axes " - "or starts or ends are not equal."; - return false; - } + axes = PADDLE_GET_CONST(std::vector, desc.GetAttr("axes")); if (!with_dynamic_shape) { for (size_t i = 0; i < axes.size(); i++) { if (axes[i] == 0) { @@ -1203,14 +1191,42 @@ struct SimpleOpTypeSetTeller : public Teller { } // not support following four inputs for slice in paddle-trt auto slice_inputs = desc.Inputs(); // its size == 5 - if (slice_inputs.find("StartsTensor") != slice_inputs.end()) { - if (desc.Input("StartsTensor").size()) { + if (slice_inputs.find("StartsTensor") != slice_inputs.end() && + desc.Input("StartsTensor").size()) { + VLOG(3) << "The Slice has StartsTensor input."; + } else { + if (!desc.HasAttr("starts")) { + VLOG(3) << "The necessary attributes of the slice operator starts or " + "StartsTensor" + " are missing."; return false; + } else { + std::vector starts = + PADDLE_GET_CONST(std::vector, desc.GetAttr("starts")); + if (axes.size() != starts.size()) { + VLOG(3) << "The shape of attributes of the slice operator axes " + "and starts are not equal."; + return false; + } } } - if (slice_inputs.find("EndsTensor") != slice_inputs.end()) { - if (desc.Input("EndsTensor").size()) { + if (slice_inputs.find("EndsTensor") != slice_inputs.end() && + desc.Input("EndsTensor").size()) { + VLOG(3) << "The Slice has EndsTensor input."; + } else { + if (!desc.HasAttr("ends")) { + VLOG(3) << "The necessary attributes of the slice operator ends or " + "EndsTensor" + " are missing."; return false; + } else { + std::vector ends = + PADDLE_GET_CONST(std::vector, desc.GetAttr("ends")); + if (axes.size() != ends.size()) { + VLOG(3) << "The shape of attributes of the slice operator axes " + "and ends are not equal."; + return false; + } } } if (slice_inputs.find("StartsTensorList") != slice_inputs.end()) { @@ -1833,10 +1849,6 @@ struct SimpleOpTypeSetTeller : public Teller { auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); - if (x_shape.size() == 1) { - VLOG(3) << "clip op does not support input's dim is 1 in tensorrt."; - return false; - } } if (op_type == "reduce_sum" || op_type == "reduce_mean") { diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index a1544065cfc..9d3d4c55323 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -14,7 +14,6 @@ list( emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu - slice_op_plugin.cu hard_swish_op_plugin.cu stack_op_plugin.cu anchor_generator_op_plugin.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu deleted file mode 100644 index 031202fb772..00000000000 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ /dev/null @@ -1,435 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include -#include // NOLINT -#include - -#include "glog/logging.h" -#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -template -__global__ void SliceKernel( - int num, int dims, const T *input, const int *offsets_info, T *output) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - extern __shared__ int shared_data[]; - - for (int i = threadIdx.x; i < dims * 3; i += blockDim.x) { - shared_data[i] = offsets_info[i]; - } - __syncthreads(); - - if (idx < num) { - int t_idx = idx; - int in_idx = 0; - for (int i = dims - 1; i >= 0; i--) { - // output_shape - auto t = t_idx % shared_data[i * 3 + 1]; - // out offset - auto s = t + shared_data[i * 3]; - // input_seg_offset - in_idx = in_idx + shared_data[i * 3 + 2] * s; - t_idx = t_idx / shared_data[i * 3 + 1]; - } - output[idx] = input[in_idx]; - } -} - -SlicePlugin::SlicePlugin(std::vector starts, - std::vector ends, - std::vector axes, - bool with_fp16) - : starts_(starts), ends_(ends), axes_(axes) { - with_fp16_ = with_fp16; -} - -SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) { - deserializeBase(serial_data, serial_length); - DeserializeValue(&serial_data, &serial_length, &starts_); - DeserializeValue(&serial_data, &serial_length, &ends_); - DeserializeValue(&serial_data, &serial_length, &axes_); - DeserializeValue(&serial_data, &serial_length, &with_fp16_); - DeserializeValue(&serial_data, &serial_length, &offset_info_); -} - -SlicePlugin::~SlicePlugin() { cudaFree(offset_temp_data_); } - -SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT { - return new SlicePlugin(starts_, ends_, axes_, with_fp16_); -} - -bool SlicePlugin::supportsFormat( - nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT { - if (with_fp16_) { - return ((type == nvinfer1::DataType::kFLOAT || - type == nvinfer1::DataType::kHALF || - type == nvinfer1::DataType::kINT32) && - (format == nvinfer1::PluginFormat::kLINEAR)); - } else { - return ((type == nvinfer1::DataType::kFLOAT || - type == nvinfer1::DataType::kINT32) && - (format == nvinfer1::PluginFormat::kLINEAR)); - } -} - -nvinfer1::Dims SlicePlugin::getOutputDimensions( - int index, const nvinfer1::Dims *inputs, int nb_input_dims) TRT_NOEXCEPT { - auto in_dims = inputs[0]; - nvinfer1::Dims out_dims = in_dims; - for (size_t i = 0; i < axes_.size(); i++) { - int start = starts_[i]; - int end = ends_[i]; - out_dims.d[axes_[i] - 1] = end - start; - } - return out_dims; -} - -int SlicePlugin::enqueue(int batch_size, - const void *const *inputs, -#if IS_TRT_VERSION_LT(8000) - void **outputs, - void *workspace, - cudaStream_t stream) { -#else - void *const *outputs, - void *workspace, - cudaStream_t stream) TRT_NOEXCEPT { -#endif - auto input_dims = getInputDims(0); - - // notice input dims is [C, H, W], add input batch dim here - auto out_dims = getOutputDimensions(0, &input_dims, 1); - input_dims.nbDims += 1; - out_dims.nbDims += 1; - for (auto i = input_dims.nbDims; i > 0; --i) { - input_dims.d[i] = input_dims.d[i - 1]; - out_dims.d[i] = out_dims.d[i - 1]; - } - input_dims.d[0] = batch_size; - out_dims.d[0] = batch_size; - - auto num_dims = input_dims.nbDims; - size_t out_num = ProductDim(out_dims); - - std::vector seg_offsets; - std::vector offsets; - std::vector extends; - - offsets.resize(num_dims); - extends.resize(num_dims); - seg_offsets.resize(num_dims); - - seg_offsets[num_dims - 1] = 1; - for (int i = num_dims - 2; i >= 0; i--) { - seg_offsets[i] = input_dims.d[i + 1] * seg_offsets[i + 1]; - } - for (size_t i = 0; i < num_dims; ++i) { - offsets[i] = 0; - extends[i] = out_dims.d[i]; - } - for (size_t i = 0; i < axes_.size(); ++i) { - offsets[axes_[i]] = starts_[i]; - } - - std::vector offset_info; - for (size_t i = 0; i < num_dims; ++i) { - offset_info.push_back(offsets[i]); - offset_info.push_back(extends[i]); - offset_info.push_back(seg_offsets[i]); - } - - if (offset_temp_data_ == nullptr) { - cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int)); - } - - cudaMemcpyAsync(offset_temp_data_, - offset_info.data(), - sizeof(int) * 3 * num_dims, - cudaMemcpyHostToDevice, - stream); - - int threads = 256; - int blocks = (out_num + threads - 1) / threads; - auto input_type = getDataType(); - if (input_type == nvinfer1::DataType::kFLOAT) { - VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32"; - const float *input1 = static_cast(inputs[0]); - float *output = static_cast(outputs[0]); - SliceKernel<<>>( - out_num, num_dims, input1, offset_temp_data_, output); - } else if (input_type == nvinfer1::DataType::kHALF) { - VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16"; - const half *input1 = static_cast(inputs[0]); - half *output = static_cast(outputs[0]); - SliceKernel<<>>( - out_num, num_dims, input1, offset_temp_data_, output); - } else if (input_type == nvinfer1::DataType::kINT32) { - VLOG(1) << "TRT Plugin DataType selected. Slice-->int32"; - const int *input1 = static_cast(inputs[0]); - int *output = static_cast(outputs[0]); - SliceKernel<<>>( - out_num, num_dims, input1, offset_temp_data_, output); - } else { - PADDLE_THROW(platform::errors::Fatal( - "The Slice TRT Plugin's input type should be float, half or int.")); - } - return cudaGetLastError() != cudaSuccess; -} - -size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT { - return getBaseSerializationSize() + SerializedSize(starts_) + - SerializedSize(ends_) + SerializedSize(axes_) + - SerializedSize(with_fp16_) + SerializedSize(offset_info_); -} - -void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { - serializeBase(buffer); - SerializeValue(&buffer, starts_); - SerializeValue(&buffer, ends_); - SerializeValue(&buffer, axes_); - SerializeValue(&buffer, with_fp16_); - SerializeValue(&buffer, offset_info_); -} - -// Dynamic Plugin below. -#if IS_TRT_VERSION_GE(6000) -SlicePluginDynamic::SlicePluginDynamic(std::vector starts, - std::vector ends, - std::vector axes, - int decrease_axis, - bool with_fp16) - : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) { - with_fp16_ = with_fp16; -} - -SlicePluginDynamic::SlicePluginDynamic(void const *serialData, - size_t serialLength) { - DeserializeValue(&serialData, &serialLength, &starts_); - DeserializeValue(&serialData, &serialLength, &ends_); - DeserializeValue(&serialData, &serialLength, &axes_); - DeserializeValue(&serialData, &serialLength, &decrease_axis_); - DeserializeValue(&serialData, &serialLength, &with_fp16_); - DeserializeValue(&serialData, &serialLength, &offset_info_); -} - -void SlicePluginDynamic::destroy() TRT_NOEXCEPT { - cudaFree(offset_temp_data_); - delete this; -} - -int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; } - -size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT { - size_t size = SerializedSize(starts_) + SerializedSize(ends_) + - SerializedSize(axes_) + SerializedSize(decrease_axis_) + - SerializedSize(with_fp16_) + SerializedSize(offset_info_); - - return size; -} - -void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { - SerializeValue(&buffer, starts_); - SerializeValue(&buffer, ends_); - SerializeValue(&buffer, axes_); - SerializeValue(&buffer, decrease_axis_); - SerializeValue(&buffer, with_fp16_); - SerializeValue(&buffer, offset_info_); -} - -nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions( - int output_index, - const nvinfer1::DimsExprs *inputs, - int nb_inputs, - nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { - auto in_dims = inputs[0]; - nvinfer1::DimsExprs ret = in_dims; - // start, ends should greater 0 - for (size_t i = 0; i < axes_.size(); i++) { - int start = starts_[i]; - int end = ends_[i]; -#if IS_TRT_VERSION_GE(7200) - ret.d[axes_[i]] = expr_builder.operation( - nvinfer1::DimensionOperation::kSUB, - *expr_builder.operation(nvinfer1::DimensionOperation::kMIN, - *expr_builder.constant(ends_[i]), - *in_dims.d[axes_[i]]), - *expr_builder.constant(start)); -#else - ret.d[axes_[i]] = expr_builder.constant(end - start); -#endif - } - if (decrease_axis_ != -1) { - nvinfer1::DimsExprs res; - res.nbDims = ret.nbDims - 1; - int j = 0; - for (size_t i = 0; i < in_dims.nbDims; i++) { - if (decrease_axis_ == i) continue; - res.d[j++] = expr_builder.operation(nvinfer1::DimensionOperation::kMAX, - *expr_builder.constant(0), - *ret.d[i]); - } - return res; - } - return ret; -} - -bool SlicePluginDynamic::supportsFormatCombination( - int pos, - const nvinfer1::PluginTensorDesc *in_out, - int nb_inputs, - int nb_outputs) TRT_NOEXCEPT { - PADDLE_ENFORCE_NOT_NULL( - in_out, - platform::errors::InvalidArgument( - "The input of swish plugin shoule not be nullptr.")); - - PADDLE_ENFORCE_LT( - pos, - nb_inputs + nb_outputs, - platform::errors::InvalidArgument("The pos(%d) should be less than the " - "num(%d) of the input and the output.", - pos, - nb_inputs + nb_outputs)); - - const nvinfer1::PluginTensorDesc &in = in_out[pos]; - if (pos == 0) { - if (with_fp16_) { - return (in.type == nvinfer1::DataType::kFLOAT || - in.type == nvinfer1::DataType::kHALF || - in.type == nvinfer1::DataType::kINT32) && - (in.format == nvinfer1::TensorFormat::kLINEAR); - } else { - return (in.type == nvinfer1::DataType::kFLOAT || - in.type == nvinfer1::DataType::kINT32) && - (in.format == nvinfer1::TensorFormat::kLINEAR); - } - } - const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; - // output - return in.type == prev.type && in.format == prev.format; -} - -nvinfer1::DataType SlicePluginDynamic::getOutputDataType( - int index, - const nvinfer1::DataType *input_types, - int nb_inputs) const TRT_NOEXCEPT { - PADDLE_ENFORCE_EQ(index, - 0, - platform::errors::InvalidArgument( - "The Slice Plugin only has one input, so the " - "index value should be 0, but get %d.", - index)); - PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT || - input_types[0] == nvinfer1::DataType::kHALF || - input_types[0] == nvinfer1::DataType::kINT32), - true, - platform::errors::InvalidArgument( - "The input type should be half, float or int")); - return input_types[0]; -} - -int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, - const nvinfer1::PluginTensorDesc *output_desc, - const void *const *inputs, - void *const *outputs, - void *workspace, - cudaStream_t stream) TRT_NOEXCEPT { - auto input_dims = input_desc[0].dims; - auto out_dims = output_desc[0].dims; - if (decrease_axis_ != -1) { - out_dims = input_dims; - out_dims.d[decrease_axis_] = 1; - } - auto num_dims = input_dims.nbDims; - size_t out_num = ProductDim(out_dims); - - std::vector seg_offsets; - std::vector offsets; - std::vector extends; - - offsets.resize(num_dims); - extends.resize(num_dims); - seg_offsets.resize(num_dims); - - seg_offsets[num_dims - 1] = 1; - for (int i = num_dims - 2; i >= 0; i--) { - seg_offsets[i] = input_dims.d[i + 1] * seg_offsets[i + 1]; - } - - for (size_t i = 0; i < num_dims; ++i) { - offsets[i] = 0; - extends[i] = out_dims.d[i]; - } - for (size_t i = 0; i < axes_.size(); ++i) { - offsets[axes_[i]] = starts_[i]; - } - - offset_info_.resize(num_dims * 3); - for (size_t i = 0; i < num_dims; ++i) { - offset_info_[i * 3 + 0] = offsets[i]; - offset_info_[i * 3 + 1] = extends[i]; - offset_info_[i * 3 + 2] = seg_offsets[i]; - } - - if (offset_temp_data_ == nullptr) { - cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int)); - } - - cudaMemcpyAsync(offset_temp_data_, - offset_info_.data(), - sizeof(int) * 3 * num_dims, - cudaMemcpyHostToDevice, - stream); - - int threads = 256; - int blocks = (out_num + threads - 1) / threads; - auto input_type = input_desc[0].type; - if (input_type == nvinfer1::DataType::kFLOAT) { - VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32"; - const float *input1 = static_cast(inputs[0]); - float *output = static_cast(outputs[0]); - SliceKernel<<>>( - out_num, num_dims, input1, offset_temp_data_, output); - } else if (input_type == nvinfer1::DataType::kHALF) { - VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16"; - const half *input1 = static_cast(inputs[0]); - half *output = static_cast(outputs[0]); - SliceKernel<<>>( - out_num, num_dims, input1, offset_temp_data_, output); - } else if (input_type == nvinfer1::DataType::kINT32) { - VLOG(1) << "TRT Plugin DataType selected. Slice-->int32"; - const int *input1 = static_cast(inputs[0]); - int *output = static_cast(outputs[0]); - SliceKernel<<>>( - out_num, num_dims, input1, offset_temp_data_, output); - } else { - PADDLE_THROW(platform::errors::Fatal( - "The Slice TRT Plugin's input type should be float, half or int.")); - } - return cudaGetLastError() != cudaSuccess; -} -#endif - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h deleted file mode 100644 index d853f855597..00000000000 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -class SlicePlugin : public PluginTensorRT { - public: - explicit SlicePlugin(std::vector starts, - std::vector ends, - std::vector axes, - bool with_fp16); - - // It was used for tensorrt deserialization. - // It should not be called by users. - SlicePlugin(void const* serial_data, size_t serial_length); - ~SlicePlugin(); - SlicePlugin* clone() const TRT_NOEXCEPT override; - - const char* getPluginType() const TRT_NOEXCEPT override { - return "slice_plugin"; - } - int getNbOutputs() const TRT_NOEXCEPT override { return 1; } - int initialize() TRT_NOEXCEPT override { return 0; } - bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) - const TRT_NOEXCEPT override; - nvinfer1::Dims getOutputDimensions(int index, - const nvinfer1::Dims* inputs, - int nb_input_dims) TRT_NOEXCEPT override; -#if IS_TRT_VERSION_LT(8000) - int enqueue(int batch_size, - const void* const* inputs, - void** outputs, -#else - int enqueue(int batch_size, - const void* const* inputs, - void* const* outputs, -#endif - void* workspace, - cudaStream_t stream) TRT_NOEXCEPT override; - - size_t getSerializationSize() const TRT_NOEXCEPT override; - - // TRT will call this func to serialize the configuration of TRT - // It should not be called by users. - void serialize(void* buffer) const TRT_NOEXCEPT override; - - private: - std::vector starts_; - std::vector ends_; - std::vector axes_; - int* offset_temp_data_{nullptr}; - std::vector offset_info_; -}; - -class SlicePluginCreator : public TensorRTPluginCreator { - public: - const char* getPluginName() const TRT_NOEXCEPT override { - return "slice_plugin"; - } - - const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } - - nvinfer1::IPluginV2* deserializePlugin(const char* name, - const void* serial_data, - size_t serial_length) - TRT_NOEXCEPT override { - return new SlicePlugin(serial_data, serial_length); - } -}; -REGISTER_TRT_PLUGIN_V2(SlicePluginCreator); - -#if IS_TRT_VERSION_GE(6000) -class SlicePluginDynamic : public DynamicPluginTensorRT { - public: - explicit SlicePluginDynamic(std::vector starts, - std::vector ends, - std::vector axes, - int decrease_axis, - bool with_fp16); - - nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { - return new SlicePluginDynamic( - starts_, ends_, axes_, decrease_axis_, with_fp16_); - } - - SlicePluginDynamic(void const* serialData, size_t serialLength); - - const char* getPluginType() const TRT_NOEXCEPT override { - return "slice_plugin_dynamic"; - } - int getNbOutputs() const TRT_NOEXCEPT override { return 1; } - int initialize() TRT_NOEXCEPT override; - - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void* buffer) const TRT_NOEXCEPT override; - - nvinfer1::DimsExprs getOutputDimensions(int output_index, - const nvinfer1::DimsExprs* inputs, - int nb_inputs, - nvinfer1::IExprBuilder& expr_builder) - TRT_NOEXCEPT override; - - bool supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) TRT_NOEXCEPT override {} - - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT override { - return 0; - } - - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) TRT_NOEXCEPT override; - nvinfer1::DataType getOutputDataType(int index, - const nvinfer1::DataType* inputTypes, - int nbInputs) const - TRT_NOEXCEPT override; - - void destroy() TRT_NOEXCEPT override; - - private: - std::vector starts_; - std::vector ends_; - std::vector axes_; - int decrease_axis_; - int* offset_temp_data_{nullptr}; - std::vector offset_info_; -}; - -class SlicePluginDynamicCreator : public TensorRTPluginCreator { - public: - const char* getPluginName() const TRT_NOEXCEPT override { - return "slice_plugin_dynamic"; - } - - const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } - - nvinfer1::IPluginV2* deserializePlugin(const char* name, - const void* serialData, - size_t serialLength) - TRT_NOEXCEPT override { - return new SlicePluginDynamic(serialData, serialLength); - } -}; -REGISTER_TRT_PLUGIN_V2(SlicePluginDynamicCreator); - -#endif - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_clip.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_clip.py index 18d5adb284b..77334116eb8 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_clip.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_clip.py @@ -120,10 +120,13 @@ class TrtConvertClipTest(TrtLayerAutoScanTest): self.dynamic_shape.opt_input_shape = {} def generate_trt_nodes_num(attrs, dynamic_shape): - if self.input_num == 3 or self.dims == 1: + if self.input_num == 3: return 0, 3 else: - return 1, 2 + if not dynamic_shape and self.dims == 1: + return 0, 3 + else: + return 1, 2 attrs = [ program_config.ops[i].attrs for i in range(len(program_config.ops)) -- GitLab