diff --git a/cmake/pten.cmake b/cmake/pten.cmake
index 9a3552efce8e12d21f1c79b70ac649c43af2c085..5645ac6cfa3039afdad0514abade5c9ea9b35408 100644
--- a/cmake/pten.cmake
+++ b/cmake/pten.cmake
@@ -58,15 +58,23 @@ endfunction()
 function(kernel_declare TARGET_LIST)
     foreach(kernel_path ${TARGET_LIST})
         file(READ ${kernel_path} kernel_impl)
-        # TODO(chenweihang): rename PD_REGISTER_KERNEL to PD_REGISTER_KERNEL
-        # NOTE(chenweihang): now we don't recommend to use digit in kernel name
-        string(REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
+        # match the first registration macro and kernel name, plus an optional "// cuda_only" marker after the comma
+        string(REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[ \t\r\n\/]*[a-z0-9_]*" first_registry "${kernel_impl}")
         if (NOT first_registry STREQUAL "")
+            # some GPU kernels can only run on CUDA and do not support ROCm, so skip declaring them in ROCm builds
+            if (WITH_ROCM)
+                string(FIND "${first_registry}" "cuda_only" pos)
+                # string(FIND) yields -1 when the marker is absent; any other value means it was found
+                if(NOT pos EQUAL -1)
+                    continue()
+                endif()
+            endif()
             # parse the first kernel name
             string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
             string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}")
             string(REPLACE "," "" kernel_name "${kernel_name}")
             string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}")
+            string(REPLACE "//cuda_only" "" kernel_name "${kernel_name}")
             # append kernel declare into declarations.h
             # TODO(chenweihang): default declare ALL_LAYOUT for each kernel
             if (${kernel_path} MATCHES "./cpu\/")