Unverified · Commit 38dad3b9 · authored by Wang Bojun, committed by GitHub

[TRT] Add sm version check for TensorRT flash attention and cross attention pass/plugin (#50830)

* add sm version check

* use GetGPUComputeCapability
Parent commit: 12075f2a
......@@ -553,6 +553,13 @@ void TrtCrossMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
"8.5.2.2. Stop this pass";
return;
}
// SM-version gate: the TRT fused cross-attention OSS plugin only ships
// kernels for NVIDIA GPUs with compute capability >= 8.0 (sm80+), so skip
// this fuse pass on older devices instead of producing an unrunnable graph.
int sm = platform::GetGPUComputeCapability(platform::GetCurrentDeviceId());
if (sm < 80) {
  // Fix: this is the CROSS attention pass — the message previously said
  // "Flash attention" (copy/paste from the sibling pass), which made the
  // skip reason misleading in logs. Also normalized the ". " spacing.
  VLOG(3) << "Cross attention oss plugin only available for nvidia gpu with "
             "sm >= 80, but got sm = "
          << sm << ". Stop this pass";
  return;
}
#else
// if no tensorrt, early stop
return;
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#endif
namespace paddle {
namespace framework {
......@@ -545,10 +546,18 @@ void TrtFlashMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
"8.5.2.2. Stop this pass";
return;
}
// SM-version gate: the TRT flash-attention (fMHA) OSS plugin only ships
// kernels for NVIDIA GPUs with compute capability >= 8.0 (sm80+), so skip
// this fuse pass on older devices instead of producing an unrunnable graph.
int sm = platform::GetGPUComputeCapability(platform::GetCurrentDeviceId());
if (sm < 80) {
VLOG(3) << "Flash attention oss plugin only available for nvidia gpu with "
"sm >= 80, but got sm = "
<< sm << " . Stop this pass";
return;
}
#else
// if no tensorrt, early stop
return;
#endif
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!with_dynamic_shape) {
VLOG(3) << "Flash attention oss plugin need trt "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.