未验证 提交 b6321350 编写于 作者: M ming1753 提交者: GitHub

Modify kernel matching rules of infer prune (#54760)

上级 2401d48d
......@@ -86,59 +86,75 @@ function(kernel_declare TARGET_LIST)
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
if(NOT first_registry STREQUAL "")
set(kernel_declare_id "")
while(NOT first_registry STREQUAL "")
string(REPLACE "${first_registry}" "" kernel_impl "${kernel_impl}")
# some gpu kernel can run on cuda, but not support jetson, so we add this branch
if(WITH_NV_JETSON)
string(FIND "${first_registry}" "decode_jpeg" pos)
if(pos GREATER 1)
continue()
set(first_registry "")
endif()
endif()
# fusion group kernel is not supported in windows and mac
if(WIN32 OR APPLE)
string(FIND "${first_registry}" "fusion_group" pos)
if(pos GREATER 1)
continue()
set(first_registry "")
endif()
endif()
# some gpu kernel only can run on cuda, not support rocm, so we add this branch
if(WITH_ROCM)
string(FIND "${first_registry}" "cuda_only" pos)
if(pos GREATER 1)
continue()
set(first_registry "")
endif()
endif()
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
is_all_backend
"${first_registry}")
# parse the registerd kernel message
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(" "" kernel_msg
"${first_registry}")
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_msg "${kernel_msg}")
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_DTYPE(" "" kernel_msg
"${kernel_msg}")
string(REPLACE "," ";" kernel_msg "${kernel_msg}")
string(REGEX REPLACE "[ \\\t\r\n]+" "" kernel_msg "${kernel_msg}")
string(REGEX REPLACE "//cuda_only" "" kernel_msg "${kernel_msg}")
list(GET kernel_msg 0 kernel_name)
if(NOT is_all_backend STREQUAL "")
list(GET kernel_msg 1 kernel_layout)
set(kernel_backend "CPU")
else()
list(GET kernel_msg 1 kernel_backend)
list(GET kernel_msg 2 kernel_layout)
if(NOT first_registry STREQUAL "")
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
is_all_backend
"${first_registry}")
# parse the registerd kernel message
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(" ""
kernel_msg "${first_registry}")
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_msg "${kernel_msg}")
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_DTYPE(" "" kernel_msg
"${kernel_msg}")
string(REPLACE "," ";" kernel_msg "${kernel_msg}")
string(REGEX REPLACE "[ \\\t\r\n]+" "" kernel_msg "${kernel_msg}")
string(REGEX REPLACE "//cuda_only" "" kernel_msg "${kernel_msg}")
list(GET kernel_msg 0 kernel_name)
if(NOT is_all_backend STREQUAL "")
list(GET kernel_msg 1 kernel_layout)
set(kernel_backend "CPU")
else()
list(GET kernel_msg 1 kernel_backend)
list(GET kernel_msg 2 kernel_layout)
endif()
set(kernel_declare_id
"${kernel_declare_id}PD_DECLARE_KERNEL(${kernel_name}, ${kernel_backend}, ${kernel_layout});"
)
if("${KERNEL_LIST}" STREQUAL "")
set(first_registry "")
else()
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
endif()
endif()
# append kernel declare into declarations.h
file(
APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, ${kernel_backend}, ${kernel_layout});\n"
)
endwhile()
# append kernel declare into declarations.h
if(NOT kernel_declare_id STREQUAL "")
file(APPEND ${kernel_declare_file} "${kernel_declare_id}\n")
endif()
endforeach()
endfunction()
......@@ -203,18 +219,30 @@ function(prune_declaration_h)
file(APPEND ${kernel_declare_file_prune}
"#include \"paddle/phi/core/kernel_registry.h\"\n")
set(kernel_declare_list_prune)
foreach(kernel_registry IN LISTS kernel_registry_list)
if(NOT ${kernel_registry} EQUAL "")
if(NOT "${kernel_registry}" EQUAL "")
foreach(kernel_name IN LISTS kernel_list)
string(FIND ${kernel_registry} "(${kernel_name})" index1)
string(FIND ${kernel_registry} "(${kernel_name}," index2)
string(FIND "${kernel_registry}" "(${kernel_name})" index1)
string(FIND "${kernel_registry}" "(${kernel_name}," index2)
if((NOT ${index1} EQUAL "-1") OR (NOT ${index2} EQUAL "-1"))
file(APPEND ${kernel_declare_file_prune} "${kernel_registry}\n")
string(
REGEX
MATCH
"PD_DECLARE_KERNEL\\([a-z0-9_]*, [[a-z0-9_]*]?[a-zA-Z_]*, [A-Z_]*\\)"
first_registry
"${kernel_registry}")
list(APPEND kernel_declare_list_prune "${first_registry}")
endif()
endforeach()
endif()
endforeach()
list(REMOVE_DUPLICATES kernel_declare_list_prune)
foreach(kernel_declare_prune IN LISTS kernel_declare_list_prune)
file(APPEND ${kernel_declare_file_prune} "${kernel_declare_prune};\n")
endforeach()
file(WRITE ${kernel_declare_file} "")
file(STRINGS ${kernel_declare_file_prune} kernel_registry_list_tmp)
foreach(kernel_registry IN LISTS kernel_registry_list_tmp)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册