diff --git a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu index e083e9633dc29f3358c37d1c421a83318f87be2b..f335c63fa36614ed8b32123c7ce10ff5828a441e 100644 --- a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu @@ -288,14 +288,31 @@ bool GenericPlugin::supportsFormatCombination( int nb_inputs, int nb_outputs) TRT_NOEXCEPT { if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") { - if (pos == 0) return in_out[pos].type == nvinfer1::DataType::kFLOAT; - if (pos == 1) return in_out[pos].type == nvinfer1::DataType::kINT32; + if (pos == 0) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 1) + return (in_out[pos].type == nvinfer1::DataType::kINT32) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 2) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); } else if (op_desc_.Type() == "scatter_nd_add") { - if (pos == 0) return in_out[pos].type == nvinfer1::DataType::kFLOAT; - if (pos == 1) return in_out[pos].type == nvinfer1::DataType::kINT32; - if (pos == 2) return in_out[pos].type == nvinfer1::DataType::kFLOAT; + if (pos == 0) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 1) + return (in_out[pos].type == nvinfer1::DataType::kINT32) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 2) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + if (pos == 3) + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); } else { - return in_out[pos].type == nvinfer1::DataType::kFLOAT; + return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); } }