From 75b16781181984a1db55e501d42ece37cb34d0ee Mon Sep 17 00:00:00 2001 From: weishengying <63448337+weishengying@users.noreply.github.com> Date: Tue, 18 Oct 2022 20:57:49 +0800 Subject: [PATCH] Fix bugs in the General Plugin Mechanism (#47072) --- .../tensorrt/plugin/generic_plugin.cu | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu index e083e9633d..f335c63fa3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu @@ -288,14 +288,31 @@ bool GenericPlugin::supportsFormatCombination( int nb_inputs, int nb_outputs) TRT_NOEXCEPT { if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") { - if (pos == 0) return in_out[pos].type == nvinfer1::DataType::kFLOAT; - if (pos == 1) return in_out[pos].type == nvinfer1::DataType::kINT32; + if (pos == 0) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 1) + return (in_out[pos].type == nvinfer1::DataType::kINT32) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 2) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); } else if (op_desc_.Type() == "scatter_nd_add") { - if (pos == 0) return in_out[pos].type == nvinfer1::DataType::kFLOAT; - if (pos == 1) return in_out[pos].type == nvinfer1::DataType::kINT32; - if (pos == 2) return in_out[pos].type == nvinfer1::DataType::kFLOAT; + if (pos == 0) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 1) + return (in_out[pos].type == nvinfer1::DataType::kINT32) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 2) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 3) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); } else { - return in_out[pos].type == nvinfer1::DataType::kFLOAT; + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); } } -- GitLab