From 4e23af724733b21fec6253554f909f515bf89f96 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Fri, 31 Mar 2023 14:35:49 +0800 Subject: [PATCH] [Paddle-TRT] fix skiplayernorm, add trt_version check (#52342) * fix skiplayernorm, add trt_version check --- .../ir/trt_skip_layernorm_fuse_pass.cc | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index 18ea8850dc5..6186dd72688 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -18,6 +18,9 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/op_version_registry.h" +#ifdef PADDLE_WITH_TENSORRT +#include "paddle/fluid/inference/tensorrt/helper.h" +#endif namespace paddle { namespace framework { @@ -105,6 +108,20 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init("skip_layernorm_fuse", graph); + +#ifdef PADDLE_WITH_TENSORRT + auto trt_version = paddle::inference::tensorrt::GetTrtRuntimeVersion(); + if (std::get<0>(trt_version) * 1000 + std::get<1>(trt_version) * 100 + + std::get<2>(trt_version) * 10 < + 7200) { + VLOG(3) << "skip_layernorm oss plugin only available for trt version >= " + "7.2 Stop this pass"; + return; + } +#else + // if no tensorrt, early stop + return; +#endif int found_subgraph_count = 0; GraphPatternDetector gpd; -- GitLab