未验证 提交 099cb75a 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-TRT] Fix trt dynamic shape ernie unit test on V100 (#38056)

* add restriction on plugin supportsFormat to eliminate errors from TensorRT8

* ernie-varlen is only supported on architecture >= sm75
上级 89bced5e
...@@ -192,8 +192,11 @@ bool QkvToContextPluginDynamic::supportsFormatCombination( ...@@ -192,8 +192,11 @@ bool QkvToContextPluginDynamic::supportsFormatCombination(
if (pos == 0) { if (pos == 0) {
if (with_fp16_) { if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE #ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT || return (
in.type == nvinfer1::DataType::kHALF) && #if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
#else #else
return (in.type == nvinfer1::DataType::kFLOAT) && return (in.type == nvinfer1::DataType::kFLOAT) &&
......
...@@ -73,8 +73,11 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination( ...@@ -73,8 +73,11 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
if (pos == 0) { if (pos == 0) {
if (with_fp16_) { if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE #ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT || return (
in.type == nvinfer1::DataType::kHALF) && #if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
#else #else
return (in.type == nvinfer1::DataType::kFLOAT) && return (in.type == nvinfer1::DataType::kFLOAT) &&
......
...@@ -83,8 +83,11 @@ SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT { ...@@ -83,8 +83,11 @@ SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT {
bool SlicePlugin::supportsFormat( bool SlicePlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT { nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
if (with_fp16_) { if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT || return ((
type == nvinfer1::DataType::kHALF) && #if IS_TRT_VERSION_LT(8000)
type == nvinfer1::DataType::kFLOAT ||
#endif
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kLINEAR)); (format == nvinfer1::PluginFormat::kLINEAR));
} else { } else {
return ((type == nvinfer1::DataType::kFLOAT) && return ((type == nvinfer1::DataType::kFLOAT) &&
...@@ -284,8 +287,11 @@ bool SlicePluginDynamic::supportsFormatCombination( ...@@ -284,8 +287,11 @@ bool SlicePluginDynamic::supportsFormatCombination(
const nvinfer1::PluginTensorDesc &in = in_out[pos]; const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) { if (pos == 0) {
if (with_fp16_) { if (with_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT || return (
in.type == nvinfer1::DataType::kHALF) && #if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
} else { } else {
return (in.type == nvinfer1::DataType::kFLOAT) && return (in.type == nvinfer1::DataType::kFLOAT) &&
......
...@@ -280,14 +280,16 @@ void run(paddle_infer::Predictor* predictor, std::vector<float>* out_data) { ...@@ -280,14 +280,16 @@ void run(paddle_infer::Predictor* predictor, std::vector<float>* out_data) {
TEST(AnalysisPredictor, ernie_varlen) { TEST(AnalysisPredictor, ernie_varlen) {
#if IS_TRT_VERSION_GE(7234) #if IS_TRT_VERSION_GE(7234)
auto predictor = InitPredictor(); if (platform::GetGPUComputeCapability(0) >= 75) {
std::vector<float> out_data; auto predictor = InitPredictor();
run(predictor.get(), &out_data); std::vector<float> out_data;
std::vector<float> ref_data{0.59814, 0.219882, 0.181978, run(predictor.get(), &out_data);
0.359796, 0.577414, 0.0627908}; std::vector<float> ref_data{0.59814, 0.219882, 0.181978,
float near_tolerance = 1e-3; 0.359796, 0.577414, 0.0627908};
for (size_t i = 0; i < out_data.size(); i++) { float near_tolerance = 1e-3;
EXPECT_NEAR(ref_data[i], out_data[i], near_tolerance); for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(ref_data[i], out_data[i], near_tolerance);
}
} }
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册