From 76a6b88d233aae2c7f6804f3c6aeb19e903dd54a Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 24 Feb 2022 09:52:33 +0800 Subject: [PATCH] [PHi] Skip kernel declare for cuda only kernel on rocm (#39869) * skip kernel declare for cuda only kernel on rocm * fix error --- cmake/pten.cmake | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cmake/pten.cmake b/cmake/pten.cmake index 9a3552efce8..5645ac6cfa3 100644 --- a/cmake/pten.cmake +++ b/cmake/pten.cmake @@ -58,15 +58,21 @@ endfunction() function(kernel_declare TARGET_LIST) foreach(kernel_path ${TARGET_LIST}) file(READ ${kernel_path} kernel_impl) - # TODO(chenweihang): rename PD_REGISTER_KERNEL to PD_REGISTER_KERNEL - # NOTE(chenweihang): now we don't recommend to use digit in kernel name - string(REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}") + string(REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[ \t\r\n\/]*[a-z0-9_]*" first_registry "${kernel_impl}") if (NOT first_registry STREQUAL "") + # some gpu kernel only can run on cuda, not support rocm, so we add this branch + if (WITH_ROCM) + string(FIND "${first_registry}" "cuda_only" pos) + if(pos GREATER 1) + continue() + endif() + endif() # parse the first kernel name string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_name "${first_registry}") string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}") string(REPLACE "," "" kernel_name "${kernel_name}") string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}") + string(REGEX REPLACE "//cuda_only" "" kernel_name "${kernel_name}") # append kernel declare into declarations.h # TODO(chenweihang): default declare ALL_LAYOUT for each kernel if (${kernel_path} MATCHES "./cpu\/") -- GitLab