提交 429b2b0d 编写于 作者: W wanghuancoder 提交者: GitHub

Revert "refine stride flag (#57005)"

This reverts commit 8c75039c.
上级 bc601f58
...@@ -14,10 +14,6 @@ ...@@ -14,10 +14,6 @@
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include <regex>
#include <string>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/utils/flags.h" #include "paddle/utils/flags.h"
...@@ -37,10 +33,6 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel, ...@@ -37,10 +33,6 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel,
true, true,
"Whether to use strdie kernel if op support stride."); "Whether to use strdie kernel if op support stride.");
PHI_DEFINE_EXPORTED_string(stride_kernel_blacklist,
"",
"It controls the strided kernel subset do not use.");
PD_DECLARE_int32(low_precision_op_list); PD_DECLARE_int32(low_precision_op_list);
PD_DECLARE_bool(enable_api_kernel_fallback); PD_DECLARE_bool(enable_api_kernel_fallback);
PD_DECLARE_bool(run_kp_kernel); PD_DECLARE_bool(run_kp_kernel);
...@@ -234,26 +226,14 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -234,26 +226,14 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
if (FLAGS_use_stride_kernel && use_strided_kernel) { if (FLAGS_use_stride_kernel && use_strided_kernel) {
std::regex reg(","); auto stride_kernel_iter = iter->second.find(
std::unordered_set<std::string> elems{ {const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
std::sregex_token_iterator(FLAGS_stride_kernel_blacklist.begin(), ? paddle::experimental::Backend::GPU
FLAGS_stride_kernel_blacklist.end(), : const_kernel_key.backend(),
reg, phi::DataLayout::STRIDED,
-1), const_kernel_key.dtype()});
std::sregex_token_iterator()}; if (stride_kernel_iter != iter->second.end()) {
elems.erase(""); return {stride_kernel_iter->second, false, true};
if (!elems.count(kernel_name)) {
auto stride_kernel_iter = iter->second.find(
{const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN
? paddle::experimental::Backend::GPU
: const_kernel_key.backend(),
phi::DataLayout::STRIDED,
const_kernel_key.dtype()});
if (stride_kernel_iter != iter->second.end()) {
VLOG(1) << "use strided kernel, kernel_name = " << kernel_name;
return {stride_kernel_iter->second, false, true};
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册