diff --git a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu index 3d84855bcbddb92c07f8536a40172687adaab7de..62cf059de492a16cff93c0972302292cf1dd2db0 100644 --- a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu @@ -42,10 +42,10 @@ bool GeluPlugin::supportsFormat(nvinfer1::DataType type, if (with_fp16_) { return ((type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } else { return ((type == nvinfer1::DataType::kFLOAT) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } } diff --git a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h index f413505bdf43e9828050a5e3dc6851bd1effcb8d..421c4c7970ec68db9dd2c922439addcadb8059e5 100644 --- a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h @@ -112,7 +112,7 @@ class InstanceNormPlugin : public PluginTensorRT { nvinfer1::PluginFormat format) const override { return ((type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } }; diff --git a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu index fb8043a9d90e4b0274087c7943eebefb5e545bc9..0d3b8ca1b4244a2f5fb79ef5cbc160a41c638f9a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu @@ -174,7 +174,7 @@ bool PoolPluginDynamic::supportsFormatCombination( (in_out && pos < (nb_inputs + nb_outputs)); return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && - in_out[pos].format == nvinfer1::PluginFormat::kNCHW); + in_out[pos].format == nvinfer1::PluginFormat::kLINEAR); } nvinfer1::DataType PoolPluginDynamic::getOutputDataType( diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index ad3618bc67b0457efbfba52cf6ecfa1c8ee7c398..09e39a3b9876f0951b3064c632ba65701158e3e4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -129,7 +129,7 @@ bool PReluPluginDynamic::supportsFormatCombination( (in_out && pos < (nb_inputs + nb_outputs)); return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && - in_out[pos].format == nvinfer1::PluginFormat::kNCHW); + in_out[pos].format == nvinfer1::PluginFormat::kLINEAR); } nvinfer1::DataType PReluPluginDynamic::getOutputDataType( diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index 42d9018fd057952975ffc572828701662e5ba231..e976496ec44ca83bc34b9e7b031d247f67a8082d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -90,10 +90,10 @@ bool SlicePlugin::supportsFormat(nvinfer1::DataType type, if (with_fp16_) { return ((type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } else { return ((type == nvinfer1::DataType::kFLOAT) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } } diff --git a/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc index 6636513a555f9e638e1dfdb54986010c76785e2a..46f585e6557460c850b6419049b4dbf31d592509 100644 --- a/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc +++ b/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc @@ -33,7 +33,7 @@ TEST(split_op_plugin, test_plugin) { input_dims.push_back(in_dims); sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2, input_types.data(), nullptr, nullptr, nullptr, - nvinfer1::PluginFormat::kNCHW, 4); + nvinfer1::PluginFormat::kLINEAR, 4); sp_plugin.initialize(); sp_plugin.getPluginType(); sp_plugin.canBroadcastInputAcrossBatch(0); diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc index 55bc786746beafcf7b2df98d54e9391e6a59ba24..e2f3810cc34e012a073942deafaf61616d9000e2 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc @@ -68,7 +68,7 @@ size_t PluginTensorRT::getBaseSerializationSize() { bool PluginTensorRT::supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const { return ((type == nvinfer1::DataType::kFLOAT) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } void PluginTensorRT::configureWithFormat( diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h index 37be06bba3aebd518e8775c7ed8ccd24aa7fa7e1..9c4add0688987d0dcb979c56b3686b224b80ab4d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -181,7 +181,7 @@ class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext { bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const override { return ((type == nvinfer1::DataType::kFLOAT) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } // Initialize the layer for execution. // This is called when the engine is created.