未验证 提交 8c75039c 编写于 作者: W wanghuancoder 提交者: GitHub

refine stride flag (#57005)

* refine stride flag
上级 30c9dad1
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include <regex>
#include <string>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/utils/flags.h" #include "paddle/utils/flags.h"
...@@ -33,6 +37,10 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel, ...@@ -33,6 +37,10 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel,
true, true,
"Whether to use strdie kernel if op support stride."); "Whether to use strdie kernel if op support stride.");
PHI_DEFINE_EXPORTED_string(stride_kernel_blacklist,
"",
"It controls the strided kernel subset do not use.");
PD_DECLARE_int32(low_precision_op_list); PD_DECLARE_int32(low_precision_op_list);
PD_DECLARE_bool(enable_api_kernel_fallback); PD_DECLARE_bool(enable_api_kernel_fallback);
PD_DECLARE_bool(run_kp_kernel); PD_DECLARE_bool(run_kp_kernel);
...@@ -226,6 +234,16 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -226,6 +234,16 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
if (FLAGS_use_stride_kernel && use_strided_kernel) { if (FLAGS_use_stride_kernel && use_strided_kernel) {
std::regex reg(",");
std::unordered_set<std::string> elems{
std::sregex_token_iterator(FLAGS_stride_kernel_blacklist.begin(),
FLAGS_stride_kernel_blacklist.end(),
reg,
-1),
std::sregex_token_iterator()};
elems.erase("");
if (!elems.count(kernel_name)) {
auto stride_kernel_iter = iter->second.find( auto stride_kernel_iter = iter->second.find(
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN {const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
? paddle::experimental::Backend::GPU ? paddle::experimental::Backend::GPU
...@@ -233,9 +251,11 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -233,9 +251,11 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
phi::DataLayout::STRIDED, phi::DataLayout::STRIDED,
const_kernel_key.dtype()}); const_kernel_key.dtype()});
if (stride_kernel_iter != iter->second.end()) { if (stride_kernel_iter != iter->second.end()) {
VLOG(1) << "use strided kernel, kernel_name = " << kernel_name;
return {stride_kernel_iter->second, false, true}; return {stride_kernel_iter->second, false, true};
} }
} }
}
KernelKey kernel_key = KernelKey(const_kernel_key.backend(), KernelKey kernel_key = KernelKey(const_kernel_key.backend(),
phi::DataLayout::ALL_LAYOUT, phi::DataLayout::ALL_LAYOUT,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册