diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 8e0d574177d46a15ac6effa0ed474b45fd4b660e..992460fe8267c0661564cb138603589229293c0d 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -114,16 +114,13 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernels_.end(), phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); - KernelKey kernel_key = const_kernel_key; + KernelKey kernel_key = KernelKey(const_kernel_key.backend(), + phi::DataLayout::ALL_LAYOUT, + const_kernel_key.dtype()); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (kernel_key.backend() == Backend::GPUDNN) { auto kernel_iter = iter->second.find( - {Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()}); - if (kernel_iter == iter->second.end() && - kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { - kernel_iter = iter->second.find( - {Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()}); - } + {Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()}); if (kernel_iter != iter->second.end()) { return {kernel_iter->second, false}; } @@ -132,13 +129,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError( } #endif auto kernel_iter = iter->second.find(kernel_key); - // TODO(chenweihang): polish refind impl here - if (kernel_iter == iter->second.end() && - kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { - phi::KernelKey any_layout_kernel_key( - kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); - kernel_iter = iter->second.find(any_layout_kernel_key); - } PADDLE_ENFORCE_NE( kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU, @@ -162,12 +152,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError( phi::KernelKey cpu_kernel_key( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); kernel_iter = iter->second.find(cpu_kernel_key); - if (kernel_iter == iter->second.end() && - kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { - phi::KernelKey any_layout_kernel_key( - phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); - kernel_iter = iter->second.find(any_layout_kernel_key); - } PADDLE_ENFORCE_NE( kernel_iter,