From ff8b2cb766ef7463de03157f1c1d322a1602cc7e Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Fri, 9 Dec 2022 11:22:12 +0800 Subject: [PATCH] [Kernel Selection] Simplify kernel selection process in phi, reduce search number to half (#47771) * simplify SelectKernelOrThrowError function in phi * opt kernel_selection process * polish code, fix backend error --- paddle/phi/core/kernel_factory.cc | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 8e0d574177..992460fe82 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -114,16 +114,13 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernels_.end(), phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); - KernelKey kernel_key = const_kernel_key; + KernelKey kernel_key = KernelKey(const_kernel_key.backend(), + phi::DataLayout::ALL_LAYOUT, + const_kernel_key.dtype()); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (kernel_key.backend() == Backend::GPUDNN) { auto kernel_iter = iter->second.find( - {Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()}); - if (kernel_iter == iter->second.end() && - kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { - kernel_iter = iter->second.find( - {Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()}); - } + {Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()}); if (kernel_iter != iter->second.end()) { return {kernel_iter->second, false}; } @@ -132,13 +129,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError( } #endif auto kernel_iter = iter->second.find(kernel_key); - // TODO(chenweihang): polish refind impl here - if (kernel_iter == iter->second.end() && - kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { - phi::KernelKey any_layout_kernel_key( - kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); - kernel_iter = iter->second.find(any_layout_kernel_key); - } PADDLE_ENFORCE_NE( kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU, @@ -162,12 +152,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError( phi::KernelKey cpu_kernel_key( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); kernel_iter = iter->second.find(cpu_kernel_key); - if (kernel_iter == iter->second.end() && - kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { - phi::KernelKey any_layout_kernel_key( - phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); - kernel_iter = iter->second.find(any_layout_kernel_key); - } PADDLE_ENFORCE_NE( kernel_iter, -- GitLab