未验证 提交 f2c96bc2 编写于 作者: S sneaxiy 提交者: GitHub

Fix generate_kernels.py in CUDA 12.0 (#52232)

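* fix generate_kernels.py in CUDA 12.0: use cuda_arch_bin "80 86" for every CUDA version from 11.1 up (no longer only below 12.0), default start_idx in find_arch_range to the last DEFAULT_ARCH index, and abort the CMake configure step when kernel generation returns a non-zero result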

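* fix attrs bug: pass the operator attributes of memory_efficient_attention with the `attrs=` keyword instead of `args=`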
上级 2e9fd5e4
...@@ -171,7 +171,7 @@ function(select_nvcc_arch_flags out_variable out_arch_bin) ...@@ -171,7 +171,7 @@ function(select_nvcc_arch_flags out_variable out_arch_bin)
else() else()
if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.1) # CUDA 11.0 if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.1) # CUDA 11.0
set(cuda_arch_bin "80") set(cuda_arch_bin "80")
elseif(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.1+ else()
set(cuda_arch_bin "80 86") set(cuda_arch_bin "80 86")
endif() endif()
endif() endif()
......
...@@ -129,7 +129,16 @@ if(WITH_CUTLASS) ...@@ -129,7 +129,16 @@ if(WITH_CUTLASS)
COMMAND COMMAND
${PYTHON_EXECUTABLE} ${PYTHON_EXECUTABLE}
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py
--cuda_arch "${NVCC_ARCH_BIN}") --cuda_arch "${NVCC_ARCH_BIN}"
RESULT_VARIABLE memory_efficient_attention_gen_res)
if(NOT memory_efficient_attention_gen_res EQUAL 0)
message(
FATAL_ERROR
"The memory efficient attention kernel generation errors with NVCC_ARCH_BIN=${NVCC_ARCH_BIN}"
)
endif()
file(GLOB cutlass_cu "fusion/cutlass/conv2d/generated/*.cu" file(GLOB cutlass_cu "fusion/cutlass/conv2d/generated/*.cu"
"fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu" "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu"
"fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu") "fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu")
......
...@@ -44,7 +44,7 @@ def find_arch_range(min_arch, max_arch): ...@@ -44,7 +44,7 @@ def find_arch_range(min_arch, max_arch):
assert min_arch <= max_arch assert min_arch <= max_arch
n = len(DEFAULT_ARCH) n = len(DEFAULT_ARCH)
start_idx = 0 start_idx = n - 1
for i in range(n - 1): for i in range(n - 1):
if DEFAULT_ARCH[i] <= min_arch and min_arch < DEFAULT_ARCH[i + 1]: if DEFAULT_ARCH[i] <= min_arch and min_arch < DEFAULT_ARCH[i + 1]:
start_idx = i start_idx = i
...@@ -54,6 +54,7 @@ def find_arch_range(min_arch, max_arch): ...@@ -54,6 +54,7 @@ def find_arch_range(min_arch, max_arch):
for i in range(n - 1): for i in range(n - 1):
if DEFAULT_ARCH[i] <= max_arch and max_arch < DEFAULT_ARCH[i + 1]: if DEFAULT_ARCH[i] <= max_arch and max_arch < DEFAULT_ARCH[i + 1]:
end_idx = i + 1 end_idx = i + 1
return DEFAULT_ARCH[start_idx:end_idx] return DEFAULT_ARCH[start_idx:end_idx]
......
...@@ -134,7 +134,7 @@ def memory_efficient_attention( ...@@ -134,7 +134,7 @@ def memory_efficient_attention(
"causal_diagonal": causal_diagonal, "causal_diagonal": causal_diagonal,
"seqlen_k": seqlen_k, "seqlen_k": seqlen_k,
}, },
args={ attrs={
"max_seqlen_q": max_seqlen_q, "max_seqlen_q": max_seqlen_q,
"max_seqlen_k": max_seqlen_k, "max_seqlen_k": max_seqlen_k,
"causal": causal, "causal": causal,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册