From 5e744096b0f5f049ff25ce2e191516cfcecd9dc7 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Mon, 28 Jun 2021 10:56:04 +0800 Subject: [PATCH] kNCHW is deprecated, should use kLINEAR (#33777) --- paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu | 4 ++-- .../fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h | 2 +- paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu | 2 +- paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu | 2 +- paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu | 4 ++-- paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc | 2 +- paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc | 2 +- paddle/fluid/inference/tensorrt/plugin/trt_plugin.h | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu index 3d84855bcb..62cf059de4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu @@ -42,10 +42,10 @@ bool GeluPlugin::supportsFormat(nvinfer1::DataType type, if (with_fp16_) { return ((type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } else { return ((type == nvinfer1::DataType::kFLOAT) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } } diff --git a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h index f413505bdf..421c4c7970 100644 --- a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h @@ -112,7 +112,7 @@ class InstanceNormPlugin : public PluginTensorRT { nvinfer1::PluginFormat format) const override { return ((type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } }; diff --git a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu index fb8043a9d9..0d3b8ca1b4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu @@ -174,7 +174,7 @@ bool PoolPluginDynamic::supportsFormatCombination( (in_out && pos < (nb_inputs + nb_outputs)); return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && - in_out[pos].format == nvinfer1::PluginFormat::kNCHW); + in_out[pos].format == nvinfer1::PluginFormat::kLINEAR); } nvinfer1::DataType PoolPluginDynamic::getOutputDataType( diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index ad3618bc67..09e39a3b98 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -129,7 +129,7 @@ bool PReluPluginDynamic::supportsFormatCombination( (in_out && pos < (nb_inputs + nb_outputs)); return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && - in_out[pos].format == nvinfer1::PluginFormat::kNCHW); + in_out[pos].format == nvinfer1::PluginFormat::kLINEAR); } nvinfer1::DataType PReluPluginDynamic::getOutputDataType( diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index 42d9018fd0..e976496ec4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -90,10 +90,10 @@ bool SlicePlugin::supportsFormat(nvinfer1::DataType type, if (with_fp16_) { return ((type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } else { return ((type == nvinfer1::DataType::kFLOAT) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } } diff --git a/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc index 6636513a55..46f585e655 100644 --- a/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc +++ b/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc @@ -33,7 +33,7 @@ TEST(split_op_plugin, test_plugin) { input_dims.push_back(in_dims); sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2, input_types.data(), nullptr, nullptr, nullptr, - nvinfer1::PluginFormat::kNCHW, 4); + nvinfer1::PluginFormat::kLINEAR, 4); sp_plugin.initialize(); sp_plugin.getPluginType(); sp_plugin.canBroadcastInputAcrossBatch(0); diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc index 55bc786746..e2f3810cc3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc @@ -68,7 +68,7 @@ size_t PluginTensorRT::getBaseSerializationSize() { bool PluginTensorRT::supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const { return ((type == nvinfer1::DataType::kFLOAT) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } void PluginTensorRT::configureWithFormat( diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h index 37be06bba3..9c4add0688 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -181,7 +181,7 @@ class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext { bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const override { return ((type == nvinfer1::DataType::kFLOAT) && - (format == nvinfer1::PluginFormat::kNCHW)); + (format == nvinfer1::PluginFormat::kLINEAR)); } // Initialize the layer for execution. // This is called when the engine is created. -- GitLab