未验证 提交 099cb75a 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-TRT] Fix trt dynamic shape ernie unit test on V100 (#38056)

* Add a restriction to the plugins' supportsFormat / supportsFormatCombination checks to eliminate errors from TensorRT 8

* ernie-varlen is only supported on GPU architectures >= sm75, so the unit test now runs only when the compute capability is at least 75
上级 89bced5e
...@@ -192,7 +192,10 @@ bool QkvToContextPluginDynamic::supportsFormatCombination( ...@@ -192,7 +192,10 @@ bool QkvToContextPluginDynamic::supportsFormatCombination(
if (pos == 0) { if (pos == 0) {
if (with_fp16_) { if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE #ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT || return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) && in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
#else #else
......
...@@ -73,7 +73,10 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination( ...@@ -73,7 +73,10 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
if (pos == 0) { if (pos == 0) {
if (with_fp16_) { if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE #ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT || return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) && in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
#else #else
......
...@@ -83,7 +83,10 @@ SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT { ...@@ -83,7 +83,10 @@ SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT {
bool SlicePlugin::supportsFormat( bool SlicePlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT { nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
if (with_fp16_) { if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT || return ((
#if IS_TRT_VERSION_LT(8000)
type == nvinfer1::DataType::kFLOAT ||
#endif
type == nvinfer1::DataType::kHALF) && type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kLINEAR)); (format == nvinfer1::PluginFormat::kLINEAR));
} else { } else {
...@@ -284,7 +287,10 @@ bool SlicePluginDynamic::supportsFormatCombination( ...@@ -284,7 +287,10 @@ bool SlicePluginDynamic::supportsFormatCombination(
const nvinfer1::PluginTensorDesc &in = in_out[pos]; const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) { if (pos == 0) {
if (with_fp16_) { if (with_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT || return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) && in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
} else { } else {
......
...@@ -280,6 +280,7 @@ void run(paddle_infer::Predictor* predictor, std::vector<float>* out_data) { ...@@ -280,6 +280,7 @@ void run(paddle_infer::Predictor* predictor, std::vector<float>* out_data) {
TEST(AnalysisPredictor, ernie_varlen) { TEST(AnalysisPredictor, ernie_varlen) {
#if IS_TRT_VERSION_GE(7234) #if IS_TRT_VERSION_GE(7234)
if (platform::GetGPUComputeCapability(0) >= 75) {
auto predictor = InitPredictor(); auto predictor = InitPredictor();
std::vector<float> out_data; std::vector<float> out_data;
run(predictor.get(), &out_data); run(predictor.get(), &out_data);
...@@ -289,6 +290,7 @@ TEST(AnalysisPredictor, ernie_varlen) { ...@@ -289,6 +290,7 @@ TEST(AnalysisPredictor, ernie_varlen) {
for (size_t i = 0; i < out_data.size(); i++) { for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(ref_data[i], out_data[i], near_tolerance); EXPECT_NEAR(ref_data[i], out_data[i], near_tolerance);
} }
}
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册