未验证 提交 099cb75a 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-TRT] Fix trt dynamic shape ernie unit test on V100 (#38056)

* add restriction on plugin supportsFormat to eliminate errors from TensorRT8

* ernie-varlen is only supported on architecture >= sm75
上级 89bced5e
......@@ -192,7 +192,10 @@ bool QkvToContextPluginDynamic::supportsFormatCombination(
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
......
......@@ -73,7 +73,10 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
......
......@@ -83,7 +83,10 @@ SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT {
bool SlicePlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
if (with_fp16_) {
return ((type == nvinfer1::DataType::kFLOAT ||
return ((
#if IS_TRT_VERSION_LT(8000)
type == nvinfer1::DataType::kFLOAT ||
#endif
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kLINEAR));
} else {
......@@ -284,7 +287,10 @@ bool SlicePluginDynamic::supportsFormatCombination(
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT ||
return (
#if IS_TRT_VERSION_LT(8000)
in.type == nvinfer1::DataType::kFLOAT ||
#endif
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
} else {
......
......@@ -280,6 +280,7 @@ void run(paddle_infer::Predictor* predictor, std::vector<float>* out_data) {
TEST(AnalysisPredictor, ernie_varlen) {
#if IS_TRT_VERSION_GE(7234)
if (platform::GetGPUComputeCapability(0) >= 75) {
auto predictor = InitPredictor();
std::vector<float> out_data;
run(predictor.get(), &out_data);
......@@ -289,6 +290,7 @@ TEST(AnalysisPredictor, ernie_varlen) {
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(ref_data[i], out_data[i], near_tolerance);
}
}
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册