diff --git a/cmake/pten.cmake b/cmake/pten.cmake index be265940f3f1d9af11a9da475b4e7ef303afe84e..acc30aa22996a7ccb8fbe0b04ecd5b8365887bc1 100644 --- a/cmake/pten.cmake +++ b/cmake/pten.cmake @@ -93,7 +93,7 @@ function(kernel_library TARGET) set(all_srcs) set(kernel_deps) - set(oneValueArgs "") + set(oneValueArgs SUB_DIR) set(multiValueArgs SRCS DEPS) cmake_parse_arguments(kernel_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -135,8 +135,17 @@ function(kernel_library TARGET) foreach(src ${all_srcs}) file(READ ${src} target_content) string(REGEX MATCHALL "#include \"paddle\/pten\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content}) + if ("${kernel_library_SUB_DIR}" STREQUAL "") + string(REGEX MATCHALL "#include \"paddle\/pten\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content}) + else() + string(REGEX MATCHALL "#include \"paddle\/pten\/kernels\/${kernel_library_SUB_DIR}\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content}) + endif() foreach(include_kernel ${include_kernels}) + if ("${kernel_library_SUB_DIR}" STREQUAL "") string(REGEX REPLACE "#include \"paddle\/pten\/kernels\/" "" kernel_name ${include_kernel}) + else() + string(REGEX REPLACE "#include \"paddle\/pten\/kernels\/${kernel_library_SUB_DIR}\/" "" kernel_name ${include_kernel}) + endif() string(REGEX REPLACE ".h\"" "" kernel_name ${kernel_name}) list(APPEND kernel_deps ${kernel_name}) endforeach() @@ -250,7 +259,7 @@ endfunction() function(register_kernels) set(options "") - set(oneValueArgs "") + set(oneValueArgs SUB_DIR) set(multiValueArgs EXCLUDES DEPS) cmake_parse_arguments(register_kernels "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -263,9 +272,9 @@ function(register_kernels) list(FIND register_kernels_EXCLUDES ${target} _index) if (${_index} EQUAL -1) if (${register_kernels_DEPS_len} GREATER 0) - kernel_library(${target} DEPS ${register_kernels_DEPS}) + kernel_library(${target} DEPS ${register_kernels_DEPS} SUB_DIR ${register_kernels_SUB_DIR}) else() - kernel_library(${target}) + kernel_library(${target} SUB_DIR ${register_kernels_SUB_DIR}) endif() endif() endforeach()