未验证 提交 76a6b88d 编写于 作者: C Chen Weihang 提交者: GitHub

[PHi] Skip kernel declare for cuda only kernel on rocm (#39869)

* skip kernel declare for cuda only kernel on rocm

* fix error
上级 2457a7d1
......@@ -58,15 +58,21 @@ endfunction()
function(kernel_declare TARGET_LIST)
foreach(kernel_path ${TARGET_LIST})
file(READ ${kernel_path} kernel_impl)
# TODO(chenweihang): rename PD_REGISTER_KERNEL to PD_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
string(REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
string(REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[ \t\r\n\/]*[a-z0-9_]*" first_registry "${kernel_impl}")
if (NOT first_registry STREQUAL "")
# some gpu kernel only can run on cuda, not support rocm, so we add this branch
if (WITH_ROCM)
string(FIND "${first_registry}" "cuda_only" pos)
if(pos GREATER 1)
continue()
endif()
endif()
# parse the first kernel name
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}")
string(REPLACE "," "" kernel_name "${kernel_name}")
string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}")
string(REGEX REPLACE "//cuda_only" "" kernel_name "${kernel_name}")
# append kernel declare into declarations.h
# TODO(chenweihang): default declare ALL_LAYOUT for each kernel
if (${kernel_path} MATCHES "./cpu\/")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册