未验证 提交 38dad3b9 编写于 作者: W Wang Bojun 提交者: GitHub

[TRT] Add sm version check for TensorRT flash attention and cross attention pass/plugin (#50830)

* add sm version check

* use GetGPUComputeCapability
上级 12075f2a
@@ -553,6 +553,13 @@ void TrtCrossMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
               "8.5.2.2. Stop this pass";
    return;
  }
int sm = platform::GetGPUComputeCapability(platform::GetCurrentDeviceId());
if (sm < 80) {
VLOG(3) << "Flash attention oss plugin only available for nvidia gpu with "
"sm >= 80, but got sm = "
<< sm << " . Stop this pass";
return;
}
#else
  // if no tensorrt, early stop
  return;
......
@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#endif
namespace paddle {
namespace framework {
@@ -545,10 +546,18 @@ void TrtFlashMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
               "8.5.2.2. Stop this pass";
    return;
  }
int sm = platform::GetGPUComputeCapability(platform::GetCurrentDeviceId());
if (sm < 80) {
VLOG(3) << "Flash attention oss plugin only available for nvidia gpu with "
"sm >= 80, but got sm = "
<< sm << " . Stop this pass";
return;
}
#else
  // if no tensorrt, early stop
  return;
#endif
  bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
  if (!with_dynamic_shape) {
    VLOG(3) << "Flash attention oss plugin need trt "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册