未验证 提交 099cb75a 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-TRT] Fix trt dynamic shape ernie unit test on V100 (#38056)

* add restriction on plugin supportsFormat to eliminate errors from TensorRT8

* ernie-varlen is only supported on architecture >= sm75
上级 89bced5e
......@@ -192,8 +192,11 @@ bool QkvToContextPluginDynamic::supportsFormatCombination(
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
......
......@@ -73,8 +73,11 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
......
......@@ -83,8 +83,11 @@ SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT {
bool SlicePlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
return ((
#if IS_TRT_VERSION_LT(8000)
type == nvinfer1::DataType::kFLOAT ||
#endif
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kLINEAR));
} else {
return ((type == nvinfer1::DataType::kFLOAT) &&
......@@ -284,8 +287,11 @@ bool SlicePluginDynamic::supportsFormatCombination(
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
......
......@@ -280,14 +280,16 @@ void run(paddle_infer::Predictor* predictor, std::vector<float>* out_data) {
TEST(AnalysisPredictor, ernie_varlen) {
#if IS_TRT_VERSION_GE(7234)
auto predictor = InitPredictor();
std::vector<float> out_data;
run(predictor.get(), &out_data);
std::vector<float> ref_data{0.59814, 0.219882, 0.181978,
0.359796, 0.577414, 0.0627908};
float near_tolerance = 1e-3;
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(ref_data[i], out_data[i], near_tolerance);
if (platform::GetGPUComputeCapability(0) >= 75) {
auto predictor = InitPredictor();
std::vector<float> out_data;
run(predictor.get(), &out_data);
std::vector<float> ref_data{0.59814, 0.219882, 0.181978,
0.359796, 0.577414, 0.0627908};
float near_tolerance = 1e-3;
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(ref_data[i], out_data[i], near_tolerance);
}
}
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册