未验证 提交 8c75039c 编写于 作者: W wanghuancoder 提交者: GitHub

refine stride flag (#57005)

* refine stride flag
上级 30c9dad1
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include <regex>
#include <string>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/utils/flags.h" #include "paddle/utils/flags.h"
...@@ -33,6 +37,10 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel, ...@@ -33,6 +37,10 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel,
true, true,
"Whether to use strdie kernel if op support stride."); "Whether to use strdie kernel if op support stride.");
PHI_DEFINE_EXPORTED_string(stride_kernel_blacklist,
"",
"It controls the strided kernel subset do not use.");
PD_DECLARE_int32(low_precision_op_list); PD_DECLARE_int32(low_precision_op_list);
PD_DECLARE_bool(enable_api_kernel_fallback); PD_DECLARE_bool(enable_api_kernel_fallback);
PD_DECLARE_bool(run_kp_kernel); PD_DECLARE_bool(run_kp_kernel);
...@@ -226,14 +234,26 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -226,14 +234,26 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
if (FLAGS_use_stride_kernel && use_strided_kernel) { if (FLAGS_use_stride_kernel && use_strided_kernel) {
auto stride_kernel_iter = iter->second.find( std::regex reg(",");
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN std::unordered_set<std::string> elems{
? paddle::experimental::Backend::GPU std::sregex_token_iterator(FLAGS_stride_kernel_blacklist.begin(),
: const_kernel_key.backend(), FLAGS_stride_kernel_blacklist.end(),
phi::DataLayout::STRIDED, reg,
const_kernel_key.dtype()}); -1),
if (stride_kernel_iter != iter->second.end()) { std::sregex_token_iterator()};
return {stride_kernel_iter->second, false, true}; elems.erase("");
if (!elems.count(kernel_name)) {
auto stride_kernel_iter = iter->second.find(
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
? paddle::experimental::Backend::GPU
: const_kernel_key.backend(),
phi::DataLayout::STRIDED,
const_kernel_key.dtype()});
if (stride_kernel_iter != iter->second.end()) {
VLOG(1) << "use strided kernel, kernel_name = " << kernel_name;
return {stride_kernel_iter->second, false, true};
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册