未验证 提交 5e744096 编写于 作者: Z zlsh80826 提交者: GitHub

kNCHW is deprecated, should use kLINEAR (#33777)

上级 d91352c0
......@@ -42,10 +42,10 @@ bool GeluPlugin::supportsFormat(nvinfer1::DataType type,
if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW));
(format == nvinfer1::PluginFormat::kLINEAR));
} else {
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW));
(format == nvinfer1::PluginFormat::kLINEAR));
}
}
......
......@@ -112,7 +112,7 @@ class InstanceNormPlugin : public PluginTensorRT {
nvinfer1::PluginFormat format) const override {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW));
(format == nvinfer1::PluginFormat::kLINEAR));
}
};
......
......@@ -174,7 +174,7 @@ bool PoolPluginDynamic::supportsFormatCombination(
(in_out && pos < (nb_inputs + nb_outputs));
return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
in_out[pos].format == nvinfer1::PluginFormat::kNCHW);
in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
}
nvinfer1::DataType PoolPluginDynamic::getOutputDataType(
......
......@@ -129,7 +129,7 @@ bool PReluPluginDynamic::supportsFormatCombination(
(in_out && pos < (nb_inputs + nb_outputs));
return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
in_out[pos].format == nvinfer1::PluginFormat::kNCHW);
in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
}
nvinfer1::DataType PReluPluginDynamic::getOutputDataType(
......
......@@ -90,10 +90,10 @@ bool SlicePlugin::supportsFormat(nvinfer1::DataType type,
if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW));
(format == nvinfer1::PluginFormat::kLINEAR));
} else {
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW));
(format == nvinfer1::PluginFormat::kLINEAR));
}
}
......
......@@ -33,7 +33,7 @@ TEST(split_op_plugin, test_plugin) {
input_dims.push_back(in_dims);
sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2,
input_types.data(), nullptr, nullptr, nullptr,
nvinfer1::PluginFormat::kNCHW, 4);
nvinfer1::PluginFormat::kLINEAR, 4);
sp_plugin.initialize();
sp_plugin.getPluginType();
sp_plugin.canBroadcastInputAcrossBatch(0);
......
......@@ -68,7 +68,7 @@ size_t PluginTensorRT::getBaseSerializationSize() {
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW));
(format == nvinfer1::PluginFormat::kLINEAR));
}
void PluginTensorRT::configureWithFormat(
......
......@@ -181,7 +181,7 @@ class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override {
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW));
(format == nvinfer1::PluginFormat::kLINEAR));
}
// Initialize the layer for execution.
// This is called when the engine is created.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册