未验证 提交 8c75039c 编写于 作者: W wanghuancoder 提交者: GitHub

refine stride flag (#57005)

* refine stride flag
上级 30c9dad1
......@@ -14,6 +14,10 @@
#include "paddle/phi/core/kernel_factory.h"
#include <regex>
#include <string>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/flags.h"
......@@ -33,6 +37,10 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel,
true,
"Whether to use strdie kernel if op support stride.");
// Comma-separated list of kernel names that must never be dispatched to their
// strided variant, even when FLAGS_use_stride_kernel is on. The value is split
// on ',' only and each entry is compared to the kernel name verbatim, so do
// not put spaces around the commas. Empty entries are ignored; the default
// (empty string) blacklists nothing.
PHI_DEFINE_EXPORTED_string(
    stride_kernel_blacklist,
    "",
    "Comma-separated list of kernel names that must not use the strided "
    "kernel, even if the op supports it.");
// Flag declarations made through the PD_DECLARE_* macros. The matching
// definitions are not in this part of the file (presumably they live in other
// translation units -- confirm).
PD_DECLARE_int32(low_precision_op_list);
PD_DECLARE_bool(enable_api_kernel_fallback);
PD_DECLARE_bool(run_kp_kernel);
......@@ -226,14 +234,26 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
if (FLAGS_use_stride_kernel && use_strided_kernel) {
auto stride_kernel_iter = iter->second.find(
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
? paddle::experimental::Backend::GPU
: const_kernel_key.backend(),
phi::DataLayout::STRIDED,
const_kernel_key.dtype()});
if (stride_kernel_iter != iter->second.end()) {
return {stride_kernel_iter->second, false, true};
std::regex reg(",");
std::unordered_set<std::string> elems{
std::sregex_token_iterator(FLAGS_stride_kernel_blacklist.begin(),
FLAGS_stride_kernel_blacklist.end(),
reg,
-1),
std::sregex_token_iterator()};
elems.erase("");
if (!elems.count(kernel_name)) {
auto stride_kernel_iter = iter->second.find(
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
? paddle::experimental::Backend::GPU
: const_kernel_key.backend(),
phi::DataLayout::STRIDED,
const_kernel_key.dtype()});
if (stride_kernel_iter != iter->second.end()) {
VLOG(1) << "use strided kernel, kernel_name = " << kernel_name;
return {stride_kernel_iter->second, false, true};
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册