From f24eadd90fd27efe3123f1a854e8b61cb2f22e23 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com>
Date: Fri, 24 Feb 2023 16:45:51 +0800
Subject: [PATCH] [Paddle-TRT] allow plugin fall back to fp16 when int8 (#50554)

* allow fall back to fp16 when int8

* refine code

* refine code

* refine code
---
 paddle/fluid/inference/tensorrt/engine.h | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 421842cf563..aa5a7657e28 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -358,7 +358,9 @@ class TensorRTEngine {
   bool WithFp16() {
     bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
     bool support_fp16 = infer_builder_->platformHasFastFp16();
-    return enable_fp16 && support_fp16;
+    // below is consistent with setFlag in engine.cc
+    bool fall_back_fp16 = WithInt8() && !use_dla_;
+    return (enable_fp16 || fall_back_fp16) && support_fp16;
   }
 
   bool WithInt8() {
-- 
GitLab