未验证 提交 5e744096 编写于 作者: Z zlsh80826 提交者: GitHub

kNCHW is deprecated, should use kLINEAR (#33777)

上级 d91352c0
...@@ -42,10 +42,10 @@ bool GeluPlugin::supportsFormat(nvinfer1::DataType type, ...@@ -42,10 +42,10 @@ bool GeluPlugin::supportsFormat(nvinfer1::DataType type,
if (with_fp16_) { if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT || return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) && type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW)); (format == nvinfer1::PluginFormat::kLINEAR));
} else { } else {
return ((type == nvinfer1::DataType::kFLOAT) && return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW)); (format == nvinfer1::PluginFormat::kLINEAR));
} }
} }
......
...@@ -112,7 +112,7 @@ class InstanceNormPlugin : public PluginTensorRT { ...@@ -112,7 +112,7 @@ class InstanceNormPlugin : public PluginTensorRT {
nvinfer1::PluginFormat format) const override { nvinfer1::PluginFormat format) const override {
return ((type == nvinfer1::DataType::kFLOAT || return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) && type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW)); (format == nvinfer1::PluginFormat::kLINEAR));
} }
}; };
......
...@@ -174,7 +174,7 @@ bool PoolPluginDynamic::supportsFormatCombination( ...@@ -174,7 +174,7 @@ bool PoolPluginDynamic::supportsFormatCombination(
(in_out && pos < (nb_inputs + nb_outputs)); (in_out && pos < (nb_inputs + nb_outputs));
return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
in_out[pos].format == nvinfer1::PluginFormat::kNCHW); in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
} }
nvinfer1::DataType PoolPluginDynamic::getOutputDataType( nvinfer1::DataType PoolPluginDynamic::getOutputDataType(
......
...@@ -129,7 +129,7 @@ bool PReluPluginDynamic::supportsFormatCombination( ...@@ -129,7 +129,7 @@ bool PReluPluginDynamic::supportsFormatCombination(
(in_out && pos < (nb_inputs + nb_outputs)); (in_out && pos < (nb_inputs + nb_outputs));
return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
in_out[pos].format == nvinfer1::PluginFormat::kNCHW); in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
} }
nvinfer1::DataType PReluPluginDynamic::getOutputDataType( nvinfer1::DataType PReluPluginDynamic::getOutputDataType(
......
...@@ -90,10 +90,10 @@ bool SlicePlugin::supportsFormat(nvinfer1::DataType type, ...@@ -90,10 +90,10 @@ bool SlicePlugin::supportsFormat(nvinfer1::DataType type,
if (with_fp16_) { if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT || return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) && type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW)); (format == nvinfer1::PluginFormat::kLINEAR));
} else { } else {
return ((type == nvinfer1::DataType::kFLOAT) && return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW)); (format == nvinfer1::PluginFormat::kLINEAR));
} }
} }
......
...@@ -33,7 +33,7 @@ TEST(split_op_plugin, test_plugin) { ...@@ -33,7 +33,7 @@ TEST(split_op_plugin, test_plugin) {
input_dims.push_back(in_dims); input_dims.push_back(in_dims);
sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2, sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2,
input_types.data(), nullptr, nullptr, nullptr, input_types.data(), nullptr, nullptr, nullptr,
nvinfer1::PluginFormat::kNCHW, 4); nvinfer1::PluginFormat::kLINEAR, 4);
sp_plugin.initialize(); sp_plugin.initialize();
sp_plugin.getPluginType(); sp_plugin.getPluginType();
sp_plugin.canBroadcastInputAcrossBatch(0); sp_plugin.canBroadcastInputAcrossBatch(0);
......
...@@ -68,7 +68,7 @@ size_t PluginTensorRT::getBaseSerializationSize() { ...@@ -68,7 +68,7 @@ size_t PluginTensorRT::getBaseSerializationSize() {
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type, bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const { nvinfer1::PluginFormat format) const {
return ((type == nvinfer1::DataType::kFLOAT) && return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW)); (format == nvinfer1::PluginFormat::kLINEAR));
} }
void PluginTensorRT::configureWithFormat( void PluginTensorRT::configureWithFormat(
......
...@@ -181,7 +181,7 @@ class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext { ...@@ -181,7 +181,7 @@ class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
bool supportsFormat(nvinfer1::DataType type, bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override { nvinfer1::PluginFormat format) const override {
return ((type == nvinfer1::DataType::kFLOAT) && return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW)); (format == nvinfer1::PluginFormat::kLINEAR));
} }
// Initialize the layer for execution. // Initialize the layer for execution.
// This is called when the engine is created. // This is called when the engine is created.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册