From 38dad3b9733401d93bffa7934e3051164aed41a1 Mon Sep 17 00:00:00 2001
From: Wang Bojun <105858416+wwbitejotunn@users.noreply.github.com>
Date: Mon, 27 Feb 2023 15:56:37 +0800
Subject: [PATCH] [TRT] Add sm version check for TensorRT flash attention and
 cross attention pass/plugin (#50830)

* add sm version check

* use GetGPUComputeCapability
---
 .../framework/ir/trt_cross_multihead_matmul_fuse_pass.cc | 7 +++++++
 .../framework/ir/trt_flash_multihead_matmul_fuse_pass.cc | 9 +++++++++
 2 files changed, 16 insertions(+)

diff --git a/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.cc
index ee0075f3674..ae50563dfa4 100644
--- a/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.cc
@@ -553,6 +553,13 @@ void TrtCrossMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
                "8.5.2.2. Stop this pass";
     return;
   }
+  int sm = platform::GetGPUComputeCapability(platform::GetCurrentDeviceId());
+  if (sm < 80) {
+    VLOG(3) << "Flash attention oss plugin only available for nvidia gpu with "
+               "sm >= 80, but got sm = "
+            << sm << " . Stop this pass";
+    return;
+  }
 #else
   // if no tensorrt, early stop
   return;
diff --git a/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
index eb5390d2a99..2cafb36e93a 100644
--- a/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/op_version_registry.h"
 #ifdef PADDLE_WITH_TENSORRT
 #include "paddle/fluid/inference/tensorrt/helper.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
 #endif
 namespace paddle {
 namespace framework {
@@ -545,10 +546,18 @@ void TrtFlashMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
                "8.5.2.2. Stop this pass";
     return;
   }
+  int sm = platform::GetGPUComputeCapability(platform::GetCurrentDeviceId());
+  if (sm < 80) {
+    VLOG(3) << "Flash attention oss plugin only available for nvidia gpu with "
+               "sm >= 80, but got sm = "
+            << sm << " . Stop this pass";
+    return;
+  }
 #else
   // if no tensorrt, early stop
   return;
 #endif
+
   bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
   if (!with_dynamic_shape) {
     VLOG(3) << "Flash attention oss plugin need trt "
--
GitLab
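
For context, the guard added to both fuse passes amounts to: query the current device's compute capability and skip the flash/cross attention plugin path when it is below SM 80 (Ampere). The sketch below illustrates the same check standalone with the plain CUDA runtime API instead of Paddle's `platform::GetGPUComputeCapability` helper; the `GetComputeCapability` function and the `main` driver here are hypothetical illustrations, not code from this PR.

```cpp
// Minimal sketch: reproduce the "sm >= 80" gate used by the fuse passes
// with the CUDA runtime API. Assumes a CUDA-capable build environment.
#include <cstdio>
#include <cuda_runtime.h>

// Returns major*10 + minor, e.g. 80 for sm_80 (A100), 86 for sm_86.
static int GetComputeCapability(int device_id) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  return prop.major * 10 + prop.minor;
}

int main() {
  int device_id = 0;
  cudaGetDevice(&device_id);
  int sm = GetComputeCapability(device_id);
  if (sm < 80) {
    // Mirrors the early return added in the passes: on pre-Ampere GPUs the
    // flash attention OSS plugin path is skipped.
    std::printf("sm = %d < 80, flash attention plugin path skipped\n", sm);
    return 0;
  }
  std::printf("sm = %d, flash attention plugin path available\n", sm);
  return 0;
}
```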