未验证 提交 6a6c7493 编写于 作者: Y YuanRisheng 提交者: GitHub

[BugFix]Fix concat bugs when call onednn kernel (#46518) (#46845)

* fix concat bug

* fix ci bugs

* fix ci bugs
上级 d091d1b0
...@@ -78,7 +78,7 @@ function(kernel_declare TARGET_LIST) ...@@ -78,7 +78,7 @@ function(kernel_declare TARGET_LIST)
string( string(
REGEX REGEX
MATCH MATCH
"(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[ \t\r\n\/]*[a-z0-9_]*" "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z]*,[ \\\t\r\n]*[A-Z_]*"
first_registry first_registry
"${kernel_impl}") "${kernel_impl}")
if(NOT first_registry STREQUAL "") if(NOT first_registry STREQUAL "")
...@@ -89,38 +89,23 @@ function(kernel_declare TARGET_LIST) ...@@ -89,38 +89,23 @@ function(kernel_declare TARGET_LIST)
continue() continue()
endif() endif()
endif() endif()
# parse the first kernel name # parse the registerd kernel message
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_name "${first_registry}") string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_msg "${first_registry}")
string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_name string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_msg
"${kernel_name}") "${kernel_msg}")
string(REPLACE "," "" kernel_name "${kernel_name}") string(REPLACE "," ";" kernel_msg "${kernel_msg}")
string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}") string(REGEX REPLACE "[ \\\t\r\n]+" "" kernel_msg "${kernel_msg}")
string(REGEX REPLACE "//cuda_only" "" kernel_name "${kernel_name}") string(REGEX REPLACE "//cuda_only" "" kernel_msg "${kernel_msg}")
list(GET kernel_msg 0 kernel_name)
list(GET kernel_msg 1 kernel_backend)
list(GET kernel_msg 2 kernel_layout)
# append kernel declare into declarations.h # append kernel declare into declarations.h
# TODO(chenweihang): default declare ALL_LAYOUT for each kernel file(
if(${kernel_path} MATCHES "./cpu\/") APPEND ${kernel_declare_file}
file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, ${kernel_backend}, ${kernel_layout});\n"
"PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n") )
elseif(${kernel_path} MATCHES "./gpu\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, GPU, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./xpu\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./gpudnn\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, GPUDNN, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./kps\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, KPS, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./onednn\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, OneDNN, ALL_LAYOUT);\n")
else()
# deal with device independent kernel, now we use CPU temporaary
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
endif()
endif() endif()
endforeach() endforeach()
endfunction() endfunction()
......
...@@ -156,7 +156,7 @@ void ConcatKernel(const Context& dev_ctx, ...@@ -156,7 +156,7 @@ void ConcatKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(concat, PD_REGISTER_KERNEL(concat,
OneDNN, OneDNN,
ALL_LAYOUT, ONEDNN,
phi::ConcatKernel, phi::ConcatKernel,
float, float,
phi::dtype::bfloat16, phi::dtype::bfloat16,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册