未验证 提交 ff8b2cb7 编写于 作者: H HongyuJia 提交者: GitHub

[Kernel Selection] Simplify kernel selection process in phi, reduce search number to half (#47771)

* simplify SelectKernelOrThrowError function in phi

* opt kernel_selection process

* polish code, fix backend error
上级 c2e77ba3
...@@ -114,16 +114,13 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -114,16 +114,13 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernels_.end(), kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
KernelKey kernel_key = const_kernel_key; KernelKey kernel_key = KernelKey(const_kernel_key.backend(),
phi::DataLayout::ALL_LAYOUT,
const_kernel_key.dtype());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (kernel_key.backend() == Backend::GPUDNN) { if (kernel_key.backend() == Backend::GPUDNN) {
auto kernel_iter = iter->second.find( auto kernel_iter = iter->second.find(
{Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()}); {Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()});
if (kernel_iter == iter->second.end() &&
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
kernel_iter = iter->second.find(
{Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()});
}
if (kernel_iter != iter->second.end()) { if (kernel_iter != iter->second.end()) {
return {kernel_iter->second, false}; return {kernel_iter->second, false};
} }
...@@ -132,13 +129,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -132,13 +129,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
} }
#endif #endif
auto kernel_iter = iter->second.find(kernel_key); auto kernel_iter = iter->second.find(kernel_key);
// TODO(chenweihang): polish refind impl here
if (kernel_iter == iter->second.end() &&
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
phi::KernelKey any_layout_kernel_key(
kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key);
}
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU, kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU,
...@@ -162,12 +152,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -162,12 +152,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
phi::KernelKey cpu_kernel_key( phi::KernelKey cpu_kernel_key(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
kernel_iter = iter->second.find(cpu_kernel_key); kernel_iter = iter->second.find(cpu_kernel_key);
if (kernel_iter == iter->second.end() &&
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
phi::KernelKey any_layout_kernel_key(
phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key);
}
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
kernel_iter, kernel_iter,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册