diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index d6f6e60fe2d3de76ebb668c246cb625e237ac492..ae9c16e0cc7106bcd15fe24bafc0cf3954b1bc7f 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -110,6 +110,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( << "] is not registered."; } #endif + auto kernel_iter = iter->second.find(kernel_key); // TODO(chenweihang): polish refind impl here if (kernel_iter == iter->second.end() && @@ -118,6 +119,22 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); kernel_iter = iter->second.find(any_layout_kernel_key); } + +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (kernel_iter == iter->second.end()) { + // Fallback CPU backend + phi::KernelKey cpu_kernel_key( + phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); + kernel_iter = iter->second.find(cpu_kernel_key); + if (kernel_iter == iter->second.end() && + kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { + phi::KernelKey any_layout_kernel_key( + phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); + kernel_iter = iter->second.find(any_layout_kernel_key); + } + } +#endif + PADDLE_ENFORCE_NE( kernel_iter, iter->second.end(),