未验证 提交 e5c7ca48 编写于 作者: C Chen Weihang 提交者: GitHub

auto parse kernel deps by include (#38438)

上级 acef85b2
...@@ -18,7 +18,7 @@ function(kernel_declare TARGET_LIST) ...@@ -18,7 +18,7 @@ function(kernel_declare TARGET_LIST)
file(READ ${kernel_path} kernel_impl) file(READ ${kernel_path} kernel_impl)
# TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL # TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name # NOTE(chenweihang): now we don't recommend to use digit in kernel name
string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z_]*," first_registry "${kernel_impl}") string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
if (NOT first_registry STREQUAL "") if (NOT first_registry STREQUAL "")
# parse the first kernel name # parse the first kernel name
string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}") string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}")
...@@ -49,6 +49,9 @@ function(kernel_library TARGET) ...@@ -49,6 +49,9 @@ function(kernel_library TARGET)
set(gpu_srcs) set(gpu_srcs)
set(xpu_srcs) set(xpu_srcs)
set(npu_srcs) set(npu_srcs)
# parse and save the deps kerenl targets
set(all_srcs)
set(kernel_deps)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
...@@ -57,7 +60,6 @@ function(kernel_library TARGET) ...@@ -57,7 +60,6 @@ function(kernel_library TARGET)
list(LENGTH kernel_library_SRCS kernel_library_SRCS_len) list(LENGTH kernel_library_SRCS kernel_library_SRCS_len)
# one kernel only match one impl file in each backend # one kernel only match one impl file in each backend
# TODO(chenweihang): parse compile deps by include headers
if (${kernel_library_SRCS_len} EQUAL 0) if (${kernel_library_SRCS_len} EQUAL 0)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
list(APPEND common_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) list(APPEND common_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
...@@ -84,6 +86,23 @@ function(kernel_library TARGET) ...@@ -84,6 +86,23 @@ function(kernel_library TARGET)
# TODO(chenweihang): impl compile by source later # TODO(chenweihang): impl compile by source later
endif() endif()
list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h)
list(APPEND all_srcs ${common_srcs})
list(APPEND all_srcs ${cpu_srcs})
list(APPEND all_srcs ${gpu_srcs})
list(APPEND all_srcs ${xpu_srcs})
foreach(src ${all_srcs})
file(READ ${src} target_content)
string(REGEX MATCHALL "#include \"paddle\/pten\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content})
foreach(include_kernel ${include_kernels})
string(REGEX REPLACE "#include \"paddle\/pten\/kernels\/" "" kernel_name ${include_kernel})
string(REGEX REPLACE ".h\"" "" kernel_name ${kernel_name})
list(APPEND kernel_deps ${kernel_name})
endforeach()
endforeach()
list(REMOVE_DUPLICATES kernel_deps)
list(REMOVE_ITEM kernel_deps ${TARGET})
list(LENGTH common_srcs common_srcs_len) list(LENGTH common_srcs common_srcs_len)
list(LENGTH cpu_srcs cpu_srcs_len) list(LENGTH cpu_srcs cpu_srcs_len)
list(LENGTH gpu_srcs gpu_srcs_len) list(LENGTH gpu_srcs gpu_srcs_len)
...@@ -95,11 +114,11 @@ function(kernel_library TARGET) ...@@ -95,11 +114,11 @@ function(kernel_library TARGET)
# we will use this implementation and will not adopt the implementation # we will use this implementation and will not adopt the implementation
# under specific devices # under specific devices
if (WITH_GPU) if (WITH_GPU)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS}) nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM) elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS}) hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else() else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS}) cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif() endif()
else() else()
# If the kernel has a header file declaration, but no corresponding # If the kernel has a header file declaration, but no corresponding
...@@ -110,15 +129,15 @@ function(kernel_library TARGET) ...@@ -110,15 +129,15 @@ function(kernel_library TARGET)
else() else()
if (WITH_GPU) if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS}) nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif() endif()
elseif (WITH_ROCM) elseif (WITH_ROCM)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS}) hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif() endif()
else() else()
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${npu_srcs_len} GREATER 0) if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${npu_srcs_len} GREATER 0)
cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${npu_srcs} DEPS ${kernel_library_DEPS}) cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${npu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif() endif()
endif() endif()
endif() endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册