diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index 18ea8850dc5bfb15afa5584a2d6241ba2da8e9ed..6186dd72688e74d6b365f484347b98929eb04e70 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -18,6 +18,9 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/op_version_registry.h" +#ifdef PADDLE_WITH_TENSORRT +#include "paddle/fluid/inference/tensorrt/helper.h" +#endif namespace paddle { namespace framework { @@ -105,6 +108,20 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init("skip_layernorm_fuse", graph); + +#ifdef PADDLE_WITH_TENSORRT + auto trt_version = paddle::inference::tensorrt::GetTrtRuntimeVersion(); + if (std::get<0>(trt_version) * 1000 + std::get<1>(trt_version) * 100 + + std::get<2>(trt_version) * 10 < + 7200) { + VLOG(3) << "skip_layernorm oss plugin only available for trt version >= " + "7.2 Stop this pass"; + return; + } +#else + // if no tensorrt, early stop + return; +#endif int found_subgraph_count = 0; GraphPatternDetector gpd;